"""Gateway request-context middleware and helpers.

Attaches a :class:`GatewayRequestContext` (request id, tenant id, start time)
to each request, calls ``persist_gateway_audit`` once the downstream handler
has finished, and echoes the request/tenant ids back as response headers.
"""

import logging
from dataclasses import dataclass
from time import perf_counter
from uuid import uuid4

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint

logger = logging.getLogger(__name__)

REQUEST_ID_HEADER = "x-request-id"
TENANT_ID_HEADER = "x-tenant-id"
DEFAULT_TENANT_ID = "public"


@dataclass
class GatewayRequestContext:
    """Per-request metadata stored on ``request.state.gateway_context``."""

    # Inbound ``x-request-id`` header value, or a freshly generated UUID4 string.
    request_id: str
    # Tenant the request is attributed to (see ``resolve_tenant_id``).
    tenant_id: str
    # ``time.perf_counter()`` reading taken when the context was created.
    started_perf_counter: float
    # Not set anywhere in this module; presumably filled in later by routing
    # code once the upstream target is known — TODO confirm.
    target_service: str | None = None
    target_url: str | None = None


class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
    """Attach a GatewayRequestContext to each request and audit its outcome."""

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Run the downstream handler with a populated gateway context.

        The context is stored on ``request.state.gateway_context`` before the
        handler runs. Afterwards ``persist_gateway_audit`` is called with the
        final status code (500 plus the error message if the handler raised),
        and the request and tenant ids are set as response headers.

        Raises:
            Exception: any exception raised by the downstream handler is
                re-raised unchanged, even if auditing it also fails.
        """
        # Imported lazily, presumably to avoid an import cycle with the
        # infrastructure layer — TODO confirm before hoisting to module level.
        # NOTE(review): persist_gateway_audit is called synchronously (no
        # await); if it performs blocking DB I/O it stalls the event loop —
        # consider starlette.concurrency.run_in_threadpool.
        from app.infrastructure.audit import persist_gateway_audit

        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
        tenant_id = resolve_tenant_id(request)
        request.state.gateway_context = GatewayRequestContext(
            request_id=request_id,
            tenant_id=tenant_id,
            started_perf_counter=perf_counter(),
        )

        try:
            response = await call_next(request)
        except Exception as exc:
            # Auditing is attempted, but an audit failure must not replace the
            # handler's exception — that is the error callers need to see.
            try:
                persist_gateway_audit(
                    request=request,
                    session_factory=request.app.state.session_factory,
                    status_code=500,
                    error_message=str(exc),
                )
            except Exception:
                logger.exception(
                    "Failed to persist gateway audit for failed request %s",
                    request_id,
                )
            # Bare raise re-raises the handler's exception (exc), since the
            # inner except clause has already finished.
            raise

        # NOTE(review): unlike the error path, an audit failure here propagates
        # and the handler's response is discarded — confirm auditing is meant
        # to be mandatory for successful requests.
        persist_gateway_audit(
            request=request,
            session_factory=request.app.state.session_factory,
            status_code=response.status_code,
        )
        response.headers[REQUEST_ID_HEADER] = request_id
        response.headers[TENANT_ID_HEADER] = tenant_id
        return response


def resolve_tenant_id(request: Request) -> str:
    """Return the tenant id for *request*.

    Precedence: the ``x-tenant-id`` header, then the ``tenant_id`` query
    parameter, then ``DEFAULT_TENANT_ID``. Empty values are skipped.
    """
    # NOTE(review): both sources are client-supplied and are not validated
    # here; confirm tenant access is authorized downstream.
    return (
        request.headers.get(TENANT_ID_HEADER)
        or request.query_params.get("tenant_id")
        or DEFAULT_TENANT_ID
    )


def get_gateway_request_context(request: Request) -> GatewayRequestContext:
    """Return the context attached by GatewayRequestContextMiddleware.

    If no context is present (e.g. the middleware did not run), a fallback
    context with a fresh request id and ``DEFAULT_TENANT_ID`` is created and
    stored on ``request.state``, so later calls for the same request return
    the same context instead of a new request id each time.
    """
    context = getattr(request.state, "gateway_context", None)
    if isinstance(context, GatewayRequestContext):
        return context
    context = GatewayRequestContext(
        request_id=str(uuid4()),
        tenant_id=DEFAULT_TENANT_ID,
        started_perf_counter=perf_counter(),
    )
    request.state.gateway_context = context
    return context