Skip to content

Commit

Permalink
New rate limiting / api key check
Browse files Browse the repository at this point in the history
  • Loading branch information
caparker committed Oct 3, 2024
1 parent 127766f commit 5e23ca5
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 194 deletions.
57 changes: 15 additions & 42 deletions openaq_api/openaq_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any

import orjson
from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -21,10 +21,9 @@

from openaq_api.db import db_pool
from openaq_api.middleware import (
check_api_key,
CacheControlMiddleware,
LoggingMiddleware,
PrivatePathsMiddleware,
RateLimiterMiddleWare,
)
from openaq_api.models.logging import (
InfrastructureErrorLog,
Expand All @@ -33,6 +32,8 @@
WarnLog,
)



# from openaq_api.routers.auth import router as auth_router
from openaq_api.routers.averages import router as averages_router
from openaq_api.routers.cities import router as cities_router
Expand Down Expand Up @@ -105,8 +106,6 @@ def render(self, content: Any) -> bytes:
return orjson.dumps(content, default=default)


redis_client = None # initialize for generalize_schema.py


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -119,7 +118,7 @@ async def lifespan(app: FastAPI):
app.state.counter += 1
else:
app.state.counter = 0
app.state.redis_client = redis_client

yield
if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL:
logger.debug("Closing connection")
Expand All @@ -128,16 +127,21 @@ async def lifespan(app: FastAPI):
logger.debug("Connection closed")





app = FastAPI(
title="OpenAQ",
description="OpenAQ API",
version="2.0.0",
default_response_class=ORJSONResponse,
dependencies=[Depends(check_api_key)],
docs_url="/docs",
lifespan=lifespan,
)


app.redis = None
if settings.RATE_LIMITING is True:
if settings.RATE_LIMITING:
logger.debug("Connecting to redis")
Expand All @@ -150,25 +154,15 @@ async def lifespan(app: FastAPI):
decode_responses=True,
socket_timeout=5,
)
# attach to the app so it can be retrieved via the request
app.redis = redis_client
logger.debug("Redis connected")

except Exception as e:
logging.error(
InfrastructureErrorLog(detail=f"failed to connect to redis: {e}")
)
print(redis_client)
logger.debug("Redis connected")
if redis_client:
app.add_middleware(
RateLimiterMiddleWare,
redis_client=redis_client,
rate_amount_key=settings.RATE_AMOUNT_KEY,
rate_time=datetime.timedelta(minutes=settings.RATE_TIME),
)
else:
logger.warning(
WarnLog(
detail="valid redis client not provided but RATE_LIMITING set to TRUE"
)
)


app.add_middleware(
CORSMiddleware,
Expand All @@ -180,7 +174,6 @@ async def lifespan(app: FastAPI):
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900")
app.add_middleware(LoggingMiddleware)
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(PrivatePathsMiddleware)


class OpenAQValidationResponseDetail(BaseModel):
Expand All @@ -198,31 +191,11 @@ async def openaq_request_validation_exception_handler(
request: Request, exc: RequestValidationError
):
return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc)))
# return PlainTextResponse(str(exc))
# print("\n\n\n\n\n")
# print(str(exc))
# print("\n\n\n\n\n")
# detail = orjson.loads(str(exc))
# logger.debug(traceback.format_exc())
# logger.info(
# UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json()
# )
# detail = OpenAQValidationResponse(detail=detail)
# return ORJSONResponse(status_code=422, content=jsonable_encoder(detail))


@app.exception_handler(ValidationError)
async def openaq_exception_handler(request: Request, exc: ValidationError):
return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc)))
# detail = orjson.loads(exc.model_dump_json())
# logger.debug(traceback.format_exc())
# logger.error(
# ModelValidationError(
# request=request, detail=exc.jsmodel_dump_jsonon()
# ).model_dump_json()
# )
# return ORJSONResponse(status_code=422, content=jsonable_encoder(detail))
# return ORJSONResponse(status_code=500, content={"message": "internal server error"})


@app.get("/ping", include_in_schema=False)
Expand Down
Loading

0 comments on commit 5e23ca5

Please sign in to comment.