"""Request-context, authentication and audit middleware for the API gateway."""

from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from uuid import uuid4

import httpx
from fastapi import Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse

from app.bootstrap.settings import ApiGatewaySettings
from app.domain.repositories import ApiKeyRepository
from app.infrastructure.api_keys import hash_api_key

REQUEST_ID_HEADER = "x-request-id"
TENANT_ID_HEADER = "x-tenant-id"
DEFAULT_TENANT_ID = "public"

# Gateway paths that never require an API key, even when auth is enforced.
_PUBLIC_GATEWAY_PATHS = frozenset({"/gateway/services/health"})
# HTTP methods mapped to the "read" action; everything else is "write".
_READ_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})


@dataclass
class GatewayRequestContext:
    """Per-request state stored on ``request.state.gateway_context``."""

    # Echoed back to the client in the ``x-request-id`` response header.
    request_id: str
    # Tenant from header/query; replaced by the API key's tenant on successful auth.
    tenant_id: str
    # ``perf_counter()`` reading taken when the context was created.
    started_perf_counter: float
    # Id of the authenticated API key (set by ``authenticate_gateway_request``).
    api_key_id: str | None = None
    # Caller-supplied user id from the configured user-id header.
    user_id: str | None = None
    # Not set in this module; presumably filled in by the proxy layer (TODO confirm).
    target_service: str | None = None
    # Not set in this module; presumably filled in by the proxy layer (TODO confirm).
    target_url: str | None = None


class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
    """Attach a request context, enforce gateway auth and audit every request.

    For each request this middleware:

    1. builds a :class:`GatewayRequestContext` (request id from the
       ``x-request-id`` header or a fresh UUID4, tenant from
       :func:`resolve_tenant_id`);
    2. runs :func:`authenticate_gateway_request` and, if it returns an error
       response, audits it and returns it without calling the downstream app;
    3. otherwise forwards the request downstream;
    4. persists an audit record for every outcome (auth rejection, downstream
       exception, normal response) and stamps the ``x-request-id`` /
       ``x-tenant-id`` headers on the returned response.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
        tenant_id = resolve_tenant_id(request)
        request.state.gateway_context = GatewayRequestContext(
            request_id=request_id,
            tenant_id=tenant_id,
            started_perf_counter=perf_counter(),
        )

        # Authentication performs synchronous DB queries and may make a
        # blocking HTTP call to the auth service; run it in a worker thread so
        # it does not stall the event loop for every other in-flight request.
        auth_response = await run_in_threadpool(authenticate_gateway_request, request)
        if auth_response is not None:
            await _persist_audit(
                request,
                status_code=auth_response.status_code,
                error_message=None,
            )
            return _stamp_context_headers(request, auth_response, request_id)

        try:
            response = await call_next(request)
        except Exception as exc:
            # Record the failure, then let the exception propagate unchanged.
            await _persist_audit(request, status_code=500, error_message=str(exc))
            raise

        # The success path intentionally passes no ``error_message``.
        await _persist_audit(request, status_code=response.status_code)
        return _stamp_context_headers(request, response, request_id)


async def _persist_audit(request: Request, *, status_code: int, **extra: object) -> None:
    """Persist one gateway audit record in a worker thread.

    ``extra`` is forwarded verbatim to ``persist_gateway_audit`` so callers
    pass ``error_message`` only when they have one.
    """
    # Imported lazily — presumably to avoid an import cycle with the
    # infrastructure layer (TODO confirm).
    from app.infrastructure.audit import persist_gateway_audit

    await run_in_threadpool(
        persist_gateway_audit,
        request=request,
        session_factory=request.app.state.session_factory,
        status_code=status_code,
        **extra,
    )


def _stamp_context_headers(request: Request, response: Response, request_id: str) -> Response:
    """Set the request-id and (possibly auth-updated) tenant-id response headers."""
    context = get_gateway_request_context(request)
    response.headers[REQUEST_ID_HEADER] = request_id
    response.headers[TENANT_ID_HEADER] = context.tenant_id
    return response


def resolve_tenant_id(request: Request) -> str:
    """Return the requested tenant id.

    Precedence: ``x-tenant-id`` header, then the ``tenant_id`` query
    parameter, then :data:`DEFAULT_TENANT_ID`. Empty values are ignored.
    """
    header_tenant_id = request.headers.get(TENANT_ID_HEADER)
    if header_tenant_id:
        return header_tenant_id
    query_tenant_id = request.query_params.get("tenant_id")
    if query_tenant_id:
        return query_tenant_id
    return DEFAULT_TENANT_ID


def get_gateway_request_context(request: Request) -> GatewayRequestContext:
    """Return the context attached by the middleware.

    When the middleware has not run, return a fresh default context. Note that
    this fallback is NOT stored on ``request.state``, so mutations to it are
    not visible to later callers.
    """
    context = getattr(request.state, "gateway_context", None)
    if isinstance(context, GatewayRequestContext):
        return context
    return GatewayRequestContext(
        request_id=str(uuid4()),
        tenant_id=DEFAULT_TENANT_ID,
        started_perf_counter=perf_counter(),
    )


def _is_expired(expires_time: datetime | None) -> bool:
    """Return True when ``expires_time`` is set and not in the future.

    Naive timestamps are interpreted as UTC. Aware timestamps are compared
    against an aware "now", so either column flavour works without raising
    ``TypeError``.
    """
    if expires_time is None:
        return False
    now = datetime.now(timezone.utc)
    if expires_time.tzinfo is None:
        now = now.replace(tzinfo=None)
    return expires_time <= now


def authenticate_gateway_request(request: Request) -> Response | None:
    """Authenticate (and optionally authorise) a gateway request.

    Returns ``None`` when the request may proceed. That covers: auth disabled,
    a non-``/gateway/`` path, a public path, the initial API-key bootstrap
    request, or all checks passed.

    Otherwise returns an error response:

    - 401: missing/invalid/expired API key, or missing user id when authz is
      required;
    - 403: the requested tenant differs from the key's tenant, or the key's
      scopes deny the derived permission;
    - any response from :func:`check_auth_service_permission` (403 denial or
      503 failure).

    A request for the default tenant is accepted with any valid key. On
    success the request context is re-scoped to the key's tenant, and the
    key id and caller-supplied user id are recorded. The key's last-used
    time is also touched.

    This function is synchronous and performs blocking DB (and possibly
    HTTP) I/O; async callers should run it in a worker thread.
    """
    settings = ApiGatewaySettings()
    if not settings.auth_required:
        return None
    path = request.url.path
    if not path.startswith("/gateway/"):
        return None
    if path in _PUBLIC_GATEWAY_PATHS:
        return None
    if is_initial_api_key_bootstrap_request(request):
        return None

    api_key = request.headers.get(settings.api_key_header_name)
    if not api_key:
        return JSONResponse(
            status_code=401,
            content={"detail": "missing api key"},
        )

    db = request.app.state.session_factory()
    try:
        repository = ApiKeyRepository(db)
        entity = repository.get_active_by_hash(key_hash=hash_api_key(api_key))
        if entity is None:
            return JSONResponse(
                status_code=401,
                content={"detail": "invalid api key"},
            )
        if _is_expired(entity.expires_time):
            return JSONResponse(
                status_code=401,
                content={"detail": "api key expired"},
            )

        context = get_gateway_request_context(request)
        requested_tenant_id = resolve_tenant_id(request)
        if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
            return JSONResponse(
                status_code=403,
                content={"detail": "api key tenant mismatch"},
            )
        # The key's tenant is authoritative from here on.
        context.tenant_id = entity.tenant_id
        context.api_key_id = entity.id
        context.user_id = request.headers.get(settings.user_id_header_name)

        permission = derive_gateway_permission(request)
        if permission is not None and not api_key_scope_allows(
            scopes=entity.scopes,
            permission=permission,
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "api key scope denied", "permission": permission},
            )

        if settings.authz_required:
            if context.user_id is None:
                return JSONResponse(
                    status_code=401,
                    content={"detail": "missing user id"},
                )
            authz_response = check_auth_service_permission(
                settings=settings,
                tenant_id=entity.tenant_id,
                user_id=context.user_id,
                permission=permission or "gateway:access",
            )
            if authz_response is not None:
                return authz_response

        repository.touch_last_used_time(api_key_id=entity.id)
    finally:
        db.close()
    return None


def is_initial_api_key_bootstrap_request(request: Request) -> bool:
    """Return True for ``POST /gateway/api-keys`` while no API key exists yet.

    This lets the very first key be created without credentials.

    NOTE(review): the "no keys exist" check and the key creation are not
    atomic here, so concurrent bootstrap requests could each be let through —
    confirm whether the creation handler guards against that.
    """
    if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
        return False
    db = request.app.state.session_factory()
    try:
        return not ApiKeyRepository(db).has_any()
    finally:
        db.close()


def derive_gateway_permission(request: Request) -> str | None:
    """Map a gateway request to a ``gateway:<resource>:<action>`` permission.

    - ``<resource>`` is the path segment after ``/gateway/``, with
      underscores normalised to hyphens.
    - ``<action>`` is ``read`` for GET/HEAD/OPTIONS and ``write`` otherwise.
    - A bare ``/gateway/`` maps to ``gateway:access``.
    - Non-gateway paths map to ``None``.
    """
    path = request.url.path
    if not path.startswith("/gateway/"):
        return None
    path_parts = [part for part in path.split("/") if part]
    if len(path_parts) < 2:
        return "gateway:access"
    # Built-in resources ("services", "api-keys", "audits") contain no
    # underscores, so this normalisation leaves them unchanged.
    resource = path_parts[1].replace("_", "-")
    action = "read" if request.method.upper() in _READ_METHODS else "write"
    return f"gateway:{resource}:{action}"


def api_key_scope_allows(*, scopes: str | None, permission: str) -> bool:
    """Return True if an API key's ``scopes`` grant ``permission``.

    - ``None`` or blank scopes grant everything.
    - Otherwise access is granted by ``*``, an exact match, or a resource
      wildcard. For example, ``gateway:services:*`` grants
      ``gateway:services:write``.
    """
    if scopes is None or not scopes.strip():
        return True
    scope_values = parse_scope_values(scopes)
    if "*" in scope_values or permission in scope_values:
        return True
    resource_prefix = permission.rsplit(":", 1)[0]
    return f"{resource_prefix}:*" in scope_values


def parse_scope_values(scopes: str) -> set[str]:
    """Split a scope string on commas and any whitespace into a set of scopes."""
    return set(scopes.replace(",", " ").split())


def check_auth_service_permission(
    *,
    settings: ApiGatewaySettings,
    tenant_id: str,
    user_id: str,
    permission: str,
) -> Response | None:
    """Ask the auth service whether ``user_id`` holds ``permission`` in ``tenant_id``.

    The request is ``POST {auth_service_url}/auth/permissions/check``, made
    with a blocking client.

    Returns:

    - ``None`` only when the service answers ``{"allowed": true}``;
    - a 503 response when the call fails (transport/HTTP error, invalid JSON,
      or a non-object JSON payload);
    - a 403 response for any other answer (fail-closed).
    """
    try:
        with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
            response = client.post(
                f"{settings.auth_service_url.rstrip('/')}/auth/permissions/check",
                json={
                    "tenant_id": tenant_id,
                    "user_id": user_id,
                    "permission": permission,
                },
            )
            response.raise_for_status()
            payload = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        return JSONResponse(
            status_code=503,
            content={"detail": "auth service permission check failed", "error": str(exc)},
        )

    # A JSON list/string/number would otherwise crash on ``.get`` with an
    # unhandled AttributeError (500); treat it as a failed check instead.
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=503,
            content={
                "detail": "auth service permission check failed",
                "error": "unexpected response payload",
            },
        )

    if payload.get("allowed") is True:
        return None
    reason = payload.get("reason")
    return JSONResponse(
        status_code=403,
        content={
            "detail": "permission denied",
            "permission": permission,
            "reason": reason if isinstance(reason, str) else "denied",
        },
    )