from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from uuid import uuid4

from fastapi import Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse

from app.bootstrap.settings import ApiGatewaySettings
from app.domain.repositories import ApiKeyRepository
from app.infrastructure.api_keys import hash_api_key

REQUEST_ID_HEADER = "x-request-id"
TENANT_ID_HEADER = "x-tenant-id"
DEFAULT_TENANT_ID = "public"

# Gateway paths that never require an API key, even when auth is enabled.
_PUBLIC_GATEWAY_PATHS = frozenset({"/gateway/services/health"})


@dataclass
class GatewayRequestContext:
    """Per-request gateway metadata stored on ``request.state.gateway_context``."""

    request_id: str
    tenant_id: str
    # perf_counter() reading taken when the middleware first handled the request.
    started_perf_counter: float
    # Filled in by authenticate_gateway_request once an API key is accepted.
    api_key_id: str | None = None
    target_service: str | None = None
    target_url: str | None = None


class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
    """Attach a request context, enforce API-key auth and audit each request.

    For every request this middleware:

    1. builds a :class:`GatewayRequestContext` (request id, tenant id, start
       time) and stores it on ``request.state.gateway_context``;
    2. runs :func:`authenticate_gateway_request` and, if it returns an error
       response, audits it and returns it without calling downstream;
    3. otherwise forwards the request, auditing either the downstream status
       code or a 500 (re-raising the exception) on failure;
    4. echoes the request id and tenant id back as response headers.

    Authentication and auditing are synchronous and use the database, so they
    are executed in Starlette's thread pool instead of on the event loop.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
        request.state.gateway_context = GatewayRequestContext(
            request_id=request_id,
            tenant_id=resolve_tenant_id(request),
            started_perf_counter=perf_counter(),
        )

        # Blocking DB lookups: keep them off the event loop.
        auth_response = await run_in_threadpool(authenticate_gateway_request, request)
        if auth_response is not None:
            await _persist_audit(
                request,
                auth_response.status_code,
                error_message=None,
            )
            return _with_context_headers(request, auth_response, request_id)

        try:
            response = await call_next(request)
        except Exception as exc:
            await _persist_audit(request, 500, error_message=str(exc))
            raise

        await _persist_audit(request, response.status_code)
        return _with_context_headers(request, response, request_id)


async def _persist_audit(request: Request, status_code: int, **extra: object) -> None:
    """Persist an audit record for ``request`` without blocking the event loop.

    ``extra`` is forwarded verbatim to ``persist_gateway_audit`` (e.g.
    ``error_message``) so each call site passes exactly the arguments it needs.
    """
    # Imported lazily as in the original code; presumably this avoids an
    # import cycle with the audit module — TODO confirm.
    from app.infrastructure.audit import persist_gateway_audit

    await run_in_threadpool(
        persist_gateway_audit,
        request=request,
        session_factory=request.app.state.session_factory,
        status_code=status_code,
        **extra,
    )


def _with_context_headers(request: Request, response: Response, request_id: str) -> Response:
    """Set the request-id and tenant-id headers on ``response`` and return it.

    The tenant id is read from the request context so that a tenant resolved
    from the API key (rather than the raw header/query value) is reported.
    """
    context = get_gateway_request_context(request)
    response.headers[REQUEST_ID_HEADER] = request_id
    response.headers[TENANT_ID_HEADER] = context.tenant_id
    return response


def resolve_tenant_id(request: Request) -> str:
    """Return the tenant requested by the caller.

    Precedence: the ``x-tenant-id`` header, then the ``tenant_id`` query
    parameter, then :data:`DEFAULT_TENANT_ID`. Empty values are skipped.
    """
    return (
        request.headers.get(TENANT_ID_HEADER)
        or request.query_params.get("tenant_id")
        or DEFAULT_TENANT_ID
    )


def get_gateway_request_context(request: Request) -> GatewayRequestContext:
    """Return the context stored on ``request.state`` by the middleware.

    If none is present (e.g. the middleware did not run), a fresh context with
    a random request id and the default tenant is returned. That fallback is
    NOT stored on the request, so mutations to it are not persisted.
    """
    context = getattr(request.state, "gateway_context", None)
    if isinstance(context, GatewayRequestContext):
        return context
    return GatewayRequestContext(
        request_id=str(uuid4()),
        tenant_id=DEFAULT_TENANT_ID,
        started_perf_counter=perf_counter(),
    )


def _json_error(status_code: int, detail: str) -> JSONResponse:
    """Build the ``{"detail": ...}`` error response used by the auth checks."""
    return JSONResponse(status_code=status_code, content={"detail": detail})


def _is_expired(expires_time: datetime | None, now: datetime | None = None) -> bool:
    """Return True when ``expires_time`` is set and not in the future.

    Naive timestamps are compared against naive UTC "now". Aware timestamps
    are compared against aware UTC "now", so mixing the two never raises.
    """
    if expires_time is None:
        return False
    current = now if now is not None else datetime.now(timezone.utc)
    if current.tzinfo is None:
        current = current.replace(tzinfo=timezone.utc)
    if expires_time.tzinfo is None:
        return expires_time <= current.replace(tzinfo=None)
    return expires_time <= current


def _requires_api_key(request: Request, settings: ApiGatewaySettings) -> bool:
    """Return True when ``request`` must present a valid API key."""
    path = request.url.path
    if not settings.auth_required or not path.startswith("/gateway/"):
        return False
    if path in _PUBLIC_GATEWAY_PATHS:
        return False
    # Checked last because it hits the database.
    return not is_initial_api_key_bootstrap_request(request)


def authenticate_gateway_request(request: Request) -> Response | None:
    """Validate the API key of a gateway request.

    Returns ``None`` when the request may proceed: auth is disabled, the path is
    outside ``/gateway/``, the path is public, the request is the initial key
    bootstrap, or the key is valid. Otherwise returns a JSON error response:

    * 401 ``missing api key`` / ``invalid api key`` / ``api key expired``
    * 403 ``api key tenant mismatch`` when the requested tenant is neither the
      default tenant nor the key's tenant.

    On success the request context's ``tenant_id`` and ``api_key_id`` are
    replaced with the key's values, and ``touch_last_used_time`` is called for
    the key.

    This function performs blocking database I/O. From async code, call it
    via a thread pool.
    """
    # NOTE(review): settings are constructed on every call; consider caching
    # if construction is costly and runtime changes need not be picked up.
    settings = ApiGatewaySettings()
    if not _requires_api_key(request, settings):
        return None

    api_key = request.headers.get(settings.api_key_header_name)
    if not api_key:
        return _json_error(401, "missing api key")

    db = request.app.state.session_factory()
    try:
        repository = ApiKeyRepository(db)
        entity = repository.get_active_by_hash(key_hash=hash_api_key(api_key))
        if entity is None:
            return _json_error(401, "invalid api key")
        if _is_expired(entity.expires_time):
            return _json_error(401, "api key expired")

        # A caller may ask for the default tenant (or none) and be upgraded
        # to the key's tenant, but may not request some other tenant.
        requested_tenant_id = resolve_tenant_id(request)
        if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
            return _json_error(403, "api key tenant mismatch")

        context = get_gateway_request_context(request)
        context.tenant_id = entity.tenant_id
        context.api_key_id = entity.id
        repository.touch_last_used_time(api_key_id=entity.id)
    finally:
        db.close()
    return None


def is_initial_api_key_bootstrap_request(request: Request) -> bool:
    """Return True for ``POST /gateway/api-keys`` while no API key exists yet.

    This lets the very first key be created without authentication.
    NOTE(review): two concurrent bootstrap requests could both observe an
    empty table; confirm whether this race matters for deployment.
    """
    if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
        return False
    db = request.app.state.session_factory()
    try:
        return not ApiKeyRepository(db).has_any()
    finally:
        db.close()