Skip to content

Commit

Permalink
Merge pull request #268 from openaq/staging
Browse files Browse the repository at this point in the history
Hot fix/api key logging (#267)
  • Loading branch information
russbiggs authored Sep 15, 2023
2 parents d5f7baf + 25e6ebf commit 0082237
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 20 deletions.
35 changes: 23 additions & 12 deletions openaq_api/openaq_api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
from datetime import timedelta
from os import environ
import json
from fastapi import Response, status
from fastapi.responses import JSONResponse
from redis.asyncio.cluster import RedisCluster
Expand All @@ -15,6 +14,7 @@
LogType,
TooManyRequestsLog,
UnauthorizedLog,
RedisErrorLog,
)

from .settings import settings
Expand Down Expand Up @@ -111,29 +111,33 @@ def __init__(
"""Init Middleware."""
super().__init__(app)
self.redis_client = redis_client
self.rate_amount = rate_amount # 100
self.rate_amount_key = rate_amount_key # 400
self.rate_amount = rate_amount
self.rate_amount_key = rate_amount_key
self.rate_time = rate_time

async def request_is_limited(self, key: str, limit: int, request: Request):
async def request_is_limited(self, key: str, limit: int, request: Request) -> bool:
if await self.redis_client.set(key, limit, nx=True):
await self.redis_client.expire(key, int(self.rate_time.total_seconds()))
count = await self.redis_client.get(key)
if count in ("-1", "-2"):
logger.error(
RedisErrorLog(
detail=f"redis has an invalid value for limit: {count} for key: {key}"
)
)
if count and int(count) > 0:
request.state.counter = await self.redis_client.decrby(key, 1)
return False
if int(count) < 0:
logger.error(f"rate limiter hit a value below zero: {count} for key: {key}")
return True

async def check_valid_key(self, key: str):
async def check_valid_key(self, key: str) -> bool:
if await self.redis_client.sismember("keys", key):
return True
return False

@staticmethod
def limited_path(route: str) -> bool:
allow_list = ["/", "/openapi.json", "/docs", "/register", "/assets"]
allow_list = ["/", "/openapi.json", "/docs", "/register"]
if route in allow_list:
return False
if "/v2/locations/tiles" in route:
Expand All @@ -157,16 +161,23 @@ async def dispatch(
key = request.client.host

if auth:
if not self.check_valid_key(auth):
logging.info(UnauthorizedLog(request=request).model_dump_json())
valid_key = await self.check_valid_key(auth)
if not valid_key:
logging.info(
UnauthorizedLog(
request=request, detail=f"invalid key used: {auth}"
).model_dump_json()
)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"message": "invalid credentials"},
)
key = auth
limit = self.rate_amount_key
request.state.counter = 0
limited = await self.request_is_limited(key, limit, request)
request.state.counter = limit
limited = False
if self.limited_path(route):
limited = await self.request_is_limited(key, limit, request)
if self.limited_path(route) and limited:
logging.info(
TooManyRequestsLog(
Expand Down
13 changes: 11 additions & 2 deletions openaq_api/openaq_api/models/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class LogType(StrEnum):
TOO_MANY_REQUESTS = "TOO_MANY_REQUESTS"
WARNING = "WARNING"
INFO = "INFO"
ERROR = "ERROR"


class BaseLog(BaseModel):
Expand Down Expand Up @@ -42,6 +43,10 @@ class WarnLog(BaseLog):
type: LogType = LogType.WARNING


class ErrorLog(BaseLog):
type: LogType = LogType.ERROR


class InfrastructureErrorLog(BaseLog):
type: LogType = LogType.INFRASTRUCTURE_ERROR

Expand Down Expand Up @@ -124,7 +129,7 @@ def params_keys(self) -> list:
return [] if self.params_obj is None else list(self.params_obj.keys())


class ErrorLog(HTTPLog):
class HTTPErrorLog(HTTPLog):
"""Log for HTTP 500.
Inherits from HTTPLog
Expand Down Expand Up @@ -163,10 +168,14 @@ class UnauthorizedLog(HTTPLog):
type: LogType = LogType.UNAUTHORIZED


class ModelValidationError(ErrorLog):
class ModelValidationError(HTTPErrorLog):
"""Log for model validations
Inherits from ErrorLog
"""

type: LogType = LogType.VALIDATION_ERROR


class RedisErrorLog(ErrorLog):
detail: str
15 changes: 9 additions & 6 deletions openaq_api/openaq_api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..db import DB
from ..forms.register import RegisterForm, UserExistsException
from ..models.auth import User
from ..models.logging import AuthLog, InfoLog, SESEmailLog
from ..models.logging import AuthLog, ErrorLog, InfoLog, SESEmailLog
from ..settings import settings

logger = logging.getLogger("auth")
Expand Down Expand Up @@ -207,11 +207,14 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()):
{"request": request, "error": True, "error_message": message},
)
else:
token = await db.get_user_token(row[0])
if request.app.state.redis_client:
redis_client = request.app.state.redis_client
await redis_client.sadd("keys", token)
send_api_key_email(token, row[3], row[4])
try:
token = await db.get_user_token(row[0])
if request.app.state.redis_client:
redis_client = request.app.state.redis_client
await redis_client.sadd("keys", token)
send_api_key_email(token, row[3], row[4])
except Exception as e:
logger.error(ErrorLog(detail=f"something went wrong: {e}"))
return templates.TemplateResponse(
"verify/index.html", {"request": request, "error": False, "verify": True}
)
Expand Down

0 comments on commit 0082237

Please sign in to comment.