"""Request-context, authentication, authorization and audit middleware for the API gateway."""

from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from uuid import uuid4

import httpx
from core_shared.security import build_internal_service_headers
from fastapi import Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse

from app.bootstrap.settings import ApiGatewaySettings
from app.domain.repositories import ApiKeyRepository
from app.infrastructure.api_keys import hash_api_key
from app.infrastructure.rate_limit import (
    apply_gateway_rate_limit_headers,
    enforce_gateway_rate_limit,
)

REQUEST_ID_HEADER = "x-request-id"

# Gateway paths that never require authentication.
_PUBLIC_GATEWAY_PATHS = frozenset({"/gateway/services/health"})

# HTTP methods that map to the "read" action in derived gateway permissions.
_READ_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})


@dataclass
class GatewayRequestContext:
    """Per-request state stored on ``request.state.gateway_context``."""

    # Echoed back to the client in the ``x-request-id`` response header.
    request_id: str
    # ``time.perf_counter()`` value captured when the middleware first saw the request.
    started_perf_counter: float
    # Id of the API key used to authenticate, if any.
    api_key_id: str | None = None
    # Verified bearer-token user, or the user-id header sent alongside an API key.
    user_id: str | None = None
    # Not assigned in this module; presumably filled by the proxy layer — confirm.
    target_service: str | None = None
    target_url: str | None = None


class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
    """Attach a request context, then authenticate, rate-limit and audit every request.

    Authentication (synchronous DB and ``httpx`` calls) and audit persistence
    (a synchronous call given the DB session factory) run in Starlette's
    thread pool so they do not block the event loop.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Process one request; every outcome is audited and tagged with the request id."""
        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
        request.state.gateway_context = GatewayRequestContext(
            request_id=request_id, started_perf_counter=perf_counter()
        )

        auth_response = await run_in_threadpool(authenticate_gateway_request, request)
        if auth_response is not None:
            await _persist_audit(
                request, status_code=auth_response.status_code, error_message=None
            )
            auth_response.headers[REQUEST_ID_HEADER] = request_id
            return auth_response

        settings = request.app.state.settings
        limiter = request.app.state.rate_limiter
        rate_limit_response = enforce_gateway_rate_limit(
            request=request, settings=settings, limiter=limiter
        )
        if rate_limit_response is not None:
            await _persist_audit(
                request,
                status_code=rate_limit_response.status_code,
                error_message="rate limit exceeded",
            )
            rate_limit_response.headers[REQUEST_ID_HEADER] = request_id
            return rate_limit_response

        try:
            response = await call_next(request)
        except Exception as exc:
            # Record the failure as a 500, then re-raise so the exception propagates unchanged.
            await _persist_audit(request, status_code=500, error_message=str(exc))
            raise

        await _persist_audit(request, status_code=response.status_code)
        response.headers[REQUEST_ID_HEADER] = request_id
        apply_gateway_rate_limit_headers(
            response=response, request=request, settings=settings, limiter=limiter
        )
        return response


async def _persist_audit(request: Request, *, status_code: int, **extra: str | None) -> None:
    """Write a gateway audit record for *request* in the thread pool.

    ``extra`` is forwarded verbatim (e.g. ``error_message``) so callers that
    omit it keep relying on ``persist_gateway_audit``'s own default.
    """
    # Imported lazily, presumably to avoid an import cycle with the audit module — confirm.
    from app.infrastructure.audit import persist_gateway_audit

    await run_in_threadpool(
        persist_gateway_audit,
        request=request,
        session_factory=request.app.state.session_factory,
        status_code=status_code,
        **extra,
    )


def get_gateway_request_context(request: Request) -> GatewayRequestContext:
    """Return the request's gateway context, creating and attaching one if absent.

    The fallback is stored on ``request.state`` so later lookups see the same
    object (and any identity fields written to it) for the rest of the request.
    """
    context = getattr(request.state, "gateway_context", None)
    if not isinstance(context, GatewayRequestContext):
        context = GatewayRequestContext(
            request_id=str(uuid4()), started_perf_counter=perf_counter()
        )
        request.state.gateway_context = context
    return context


def _resolve_settings(request: Request) -> ApiGatewaySettings:
    """Return the application's settings instance, or load a fresh one if none is attached."""
    settings = getattr(request.app.state, "settings", None)
    if isinstance(settings, ApiGatewaySettings):
        return settings
    return ApiGatewaySettings()


def _is_expired(expires_time: datetime | None) -> bool:
    """Return whether *expires_time* is at or before the current UTC time.

    Naive timestamps are compared with naive UTC "now" and aware timestamps
    with aware UTC "now", so neither form raises ``TypeError``.
    ``None`` means the key never expires.
    """
    if expires_time is None:
        return False
    now = datetime.now(timezone.utc)
    if expires_time.utcoffset() is None:
        now = now.replace(tzinfo=None)
    return expires_time <= now


def authenticate_gateway_request(request: Request) -> Response | None:
    """Authenticate and authorize a gateway request.

    Returns ``None`` when the request may proceed; otherwise the error
    response to send instead (401/403 from local checks, or whatever the
    auth-service helpers return).

    No check is performed when ``auth_required`` is off, for paths outside
    ``/gateway/``, for public paths, for the login endpoint, or for the first
    API-key bootstrap request. A bearer token takes precedence over an API key.

    Performs blocking database and HTTP I/O; when called from async code, run
    it in a worker thread.
    """
    settings = _resolve_settings(request)
    if not settings.auth_required:
        return None
    path = request.url.path
    if not path.startswith("/gateway/") or path in _PUBLIC_GATEWAY_PATHS:
        return None
    if is_auth_login_request(request) or is_initial_api_key_bootstrap_request(request):
        return None

    bearer_token = get_bearer_token(request)
    if bearer_token is not None:
        return authenticate_bearer_token(request=request, settings=settings, token=bearer_token)

    api_key = request.headers.get(settings.api_key_header_name)
    if not api_key:
        return JSONResponse(
            status_code=401, content={"detail": "missing bearer token or api key"}
        )

    db = request.app.state.session_factory()
    try:
        repository = ApiKeyRepository(db)
        entity = repository.get_active_by_hash(key_hash=hash_api_key(api_key))
        if entity is None:
            return JSONResponse(status_code=401, content={"detail": "invalid api key"})
        if _is_expired(entity.expires_time):
            return JSONResponse(status_code=401, content={"detail": "api key expired"})

        context = get_gateway_request_context(request)
        context.api_key_id = entity.id
        # An empty user-id header is treated the same as a missing one.
        context.user_id = request.headers.get(settings.user_id_header_name) or None

        permission = derive_gateway_permission(request)
        if permission is not None and not api_key_scope_allows(
            scopes=entity.scopes, permission=permission
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "api key scope denied", "permission": permission},
            )

        if settings.authz_required:
            if context.user_id is None:
                return JSONResponse(status_code=401, content={"detail": "missing user id"})
            authz_response = check_auth_service_permission(
                settings=settings,
                user_id=context.user_id,
                permission=permission or "gateway:access",
            )
            if authz_response is not None:
                return authz_response

        # Only successful authentications update the key's last-used time.
        repository.touch_last_used_time(api_key_id=entity.id)
    finally:
        db.close()
    return None


def is_auth_login_request(request: Request) -> bool:
    """Return whether *request* is the unauthenticated login call."""
    return request.method.upper() == "POST" and request.url.path == "/gateway/identity/auth/login"


def is_initial_api_key_bootstrap_request(request: Request) -> bool:
    """Return whether *request* creates an API key while none exist yet.

    NOTE(review): two concurrent bootstrap requests can both pass this check;
    confirm the creation endpoint guards against minting more than one key.
    """
    if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
        return False
    db = request.app.state.session_factory()
    try:
        return not ApiKeyRepository(db).has_any()
    finally:
        db.close()


def get_bearer_token(request: Request) -> str | None:
    """Return the stripped token from an ``Authorization: Bearer <token>`` header, or ``None``."""
    authorization = request.headers.get("authorization")
    if authorization is None:
        return None
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer" or not token.strip():
        return None
    return token.strip()


def _extract_data(payload: object) -> dict | None:
    """Return the ``data`` object of an auth-service JSON envelope, or ``None`` if malformed."""
    if not isinstance(payload, dict):
        return None
    data = payload.get("data")
    return data if isinstance(data, dict) else None


def authenticate_bearer_token(
    *, request: Request, settings: ApiGatewaySettings, token: str
) -> Response | None:
    """Verify *token* with the auth service and record the user on the request context.

    Returns:
        ``None`` on success. Otherwise an error response: 503 if the auth
        service is unreachable, errors, or returns non-JSON; 401 for a
        malformed, inactive, or identity-less verification result; or the
        response from ``check_auth_service_permission`` when
        ``authz_required`` is set.
    """
    try:
        with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
            response = client.post(
                f"{settings.auth_service_url.rstrip('/')}/identity/auth/tokens/verify",
                headers=build_internal_service_headers(settings),
                json={"accessToken": token},
            )
            response.raise_for_status()
            payload = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        # NOTE(review): str(exc) may include the internal auth-service URL; consider not echoing it.
        return JSONResponse(
            status_code=503,
            content={"detail": "auth token verification failed", "error": str(exc)},
        )

    data = _extract_data(payload)
    if data is None:
        return JSONResponse(
            status_code=401, content={"detail": "invalid token verification response"}
        )
    if data.get("active") is not True:
        reason = data.get("reason")
        return JSONResponse(
            status_code=401,
            content={
                "detail": "invalid bearer token",
                "reason": reason if isinstance(reason, str) else "inactive",
            },
        )
    user_id = data.get("userId")
    if not isinstance(user_id, str):
        return JSONResponse(status_code=401, content={"detail": "invalid token identity"})

    context = get_gateway_request_context(request)
    context.user_id = user_id

    if settings.authz_required:
        permission = derive_gateway_permission(request) or "gateway:access"
        authz_response = check_auth_service_permission(
            settings=settings, user_id=user_id, permission=permission
        )
        if authz_response is not None:
            return authz_response
    return None


def derive_gateway_permission(request: Request) -> str | None:
    """Map a gateway request to a ``gateway:<resource>:<read|write>`` permission string.

    The resource is the first path segment after ``/gateway/``, with
    underscores normalized to hyphens. GET/HEAD/OPTIONS map to ``read``, and
    every other method maps to ``write``. Returns ``None`` for paths outside
    ``/gateway/`` and ``"gateway:access"`` when no resource segment is present.
    """
    path = request.url.path
    if not path.startswith("/gateway/"):
        return None
    path_parts = [part for part in path.split("/") if part]
    if len(path_parts) < 2:
        return "gateway:access"
    resource = path_parts[1].replace("_", "-")
    action = "read" if request.method.upper() in _READ_METHODS else "write"
    return f"gateway:{resource}:{action}"


def api_key_scope_allows(*, scopes: str | None, permission: str) -> bool:
    """Return whether an API key's scope list grants *permission*.

    A key with no scopes (``None`` or blank) is unrestricted. Otherwise the key
    must hold ``*``, the exact permission, or the wildcard for its resource
    prefix (e.g. ``gateway:services:*`` for ``gateway:services:read``).
    """
    if scopes is None or not scopes.strip():
        # NOTE(review): unscoped keys are granted everything — confirm this is intended.
        return True
    scope_values = parse_scope_values(scopes)
    resource_prefix = permission.rsplit(":", 1)[0]
    return not scope_values.isdisjoint({"*", permission, f"{resource_prefix}:*"})


def parse_scope_values(scopes: str) -> set[str]:
    """Split a scope string on commas and any whitespace into a set of non-empty scopes."""
    return set(scopes.replace(",", " ").split())


def check_auth_service_permission(
    *, settings: ApiGatewaySettings, user_id: str, permission: str
) -> Response | None:
    """Ask the auth service whether *user_id* holds *permission*.

    Returns:
        ``None`` when the service reports ``allowed`` as exactly ``True``.
        Otherwise an error response: 503 if the service is unreachable,
        errors, or returns non-JSON; 403 for a malformed response or a denial.
    """
    try:
        with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
            response = client.post(
                f"{settings.auth_service_url.rstrip('/')}/identity/permissions/check",
                headers=build_internal_service_headers(settings),
                json={
                    "userId": user_id,
                    "permission": permission,
                },
            )
            response.raise_for_status()
            payload = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        # NOTE(review): str(exc) may include the internal auth-service URL; consider not echoing it.
        return JSONResponse(
            status_code=503,
            content={"detail": "auth service permission check failed", "error": str(exc)},
        )

    data = _extract_data(payload)
    if data is None:
        return JSONResponse(
            status_code=403, content={"detail": "invalid permission check response"}
        )
    if data.get("allowed") is True:
        return None
    reason = data.get("reason")
    return JSONResponse(
        status_code=403,
        content={
            "detail": "permission denied",
            "permission": permission,
            "reason": reason if isinstance(reason, str) else "denied",
        },
    )