| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
import logging
from dataclasses import dataclass
from time import perf_counter
from uuid import uuid4

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
# HTTP header names used for request correlation and tenant selection.
# Both are read from the incoming request and echoed on the response.
REQUEST_ID_HEADER = "x-request-id"
TENANT_ID_HEADER = "x-tenant-id"

# Tenant assigned when the request supplies no tenant id at all.
DEFAULT_TENANT_ID = "public"
- @dataclass
- class GatewayRequestContext:
- request_id: str
- tenant_id: str
- started_perf_counter: float
- target_service: str | None = None
- target_url: str | None = None
class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
    """Attach a ``GatewayRequestContext`` to every request, audit it, echo ids.

    For each request this middleware:

    * reuses the client's ``REQUEST_ID_HEADER`` value or generates a UUID4;
    * resolves the tenant with ``resolve_tenant_id``;
    * stores a ``GatewayRequestContext`` on ``request.state.gateway_context``
      before calling the downstream app;
    * persists an audit record via ``persist_gateway_audit``. When the
      downstream app raises, the record has status 500 and the error text.
      Otherwise it has the response's status code;
    * echoes the request id and tenant id back as response headers, skipping
      any value that cannot safely be written as a header.
    """

    _logger = logging.getLogger(__name__)

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Run the downstream app for *request* with context, auditing and id echo.

        Raises:
            Exception: whatever the downstream app raised. It is re-raised
                unchanged after the failure has been audited, or after the
                failed audit attempt has been logged.
        """
        # Function-scope import kept from the original, presumably to avoid an
        # import cycle with the infrastructure layer (TODO confirm). It is
        # hoisted so both audit paths share a single import.
        from app.infrastructure.audit import persist_gateway_audit

        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
        tenant_id = resolve_tenant_id(request)
        request.state.gateway_context = GatewayRequestContext(
            request_id=request_id,
            tenant_id=tenant_id,
            started_perf_counter=perf_counter(),
        )

        # NOTE(review): persist_gateway_audit is called without ``await`` from
        # this coroutine. If it performs blocking database I/O, it stalls the
        # event loop while it runs. Consider
        # ``starlette.concurrency.run_in_threadpool`` — TODO confirm.
        try:
            response = await call_next(request)
        except Exception as exc:
            try:
                persist_gateway_audit(
                    request=request,
                    session_factory=request.app.state.session_factory,
                    status_code=500,
                    # Some exceptions stringify to "" (e.g. ``TimeoutError()``).
                    # Fall back to the class name so the audit row is not blank.
                    error_message=str(exc) or type(exc).__name__,
                )
            except Exception:
                # An audit failure must not replace the downstream error the
                # caller actually needs to see, so log it and fall through.
                self._logger.exception(
                    "Failed to persist gateway audit for failed request %s",
                    request_id,
                )
            # The inner handler has finished, so a bare ``raise`` re-raises ``exc``.
            raise

        persist_gateway_audit(
            request=request,
            session_factory=request.app.state.session_factory,
            status_code=response.status_code,
        )
        self._echo_header(response, REQUEST_ID_HEADER, request_id)
        self._echo_header(response, TENANT_ID_HEADER, tenant_id)
        return response

    @classmethod
    def _echo_header(cls, response: Response, name: str, value: str) -> None:
        """Set response header *name* to *value*, or skip it if *value* is unsafe.

        The tenant id may come verbatim from the query string, so it can contain
        characters that cannot be written as a header value:

        * non-latin-1 text makes Starlette's header encoding raise
          ``UnicodeEncodeError``;
        * CR/LF could enable header injection, or be rejected by the server.

        Such values are logged and skipped, rather than crashing a response
        whose success has already been audited.
        """
        if cls._is_safe_header_value(value):
            response.headers[name] = value
        else:
            # %r escapes control characters, so the log line stays intact.
            cls._logger.warning("Not echoing unsafe %s header value %r", name, value)

    @staticmethod
    def _is_safe_header_value(value: str) -> bool:
        """Return True if *value* is printable latin-1 text without surrounding whitespace."""
        if value != value.strip() or not value.isprintable():
            return False
        try:
            value.encode("latin-1")
        except UnicodeEncodeError:
            return False
        return True
def resolve_tenant_id(request: Request) -> str:
    """Return the tenant id for *request*.

    Precedence:

    1. the ``TENANT_ID_HEADER`` request header;
    2. the ``tenant_id`` query parameter;
    3. ``DEFAULT_TENANT_ID``.

    Empty values count as missing and fall through to the next source.

    NOTE: the returned value is taken verbatim from client input and is not
    validated here.
    """
    # ``or`` short-circuits, so the query string is only consulted when the
    # header is absent or empty, just like an explicit if-chain.
    return (
        request.headers.get(TENANT_ID_HEADER)
        or request.query_params.get("tenant_id")
        or DEFAULT_TENANT_ID
    )
def get_gateway_request_context(request: Request) -> GatewayRequestContext:
    """Return the ``GatewayRequestContext`` attached to *request*.

    Normally the middleware has already stored the context on
    ``request.state.gateway_context``, and that object is returned as-is.

    If the slot is missing (e.g. the middleware is not installed for this app),
    a fallback context is built the same way the middleware builds one:

    * request id from the ``REQUEST_ID_HEADER`` header, or a fresh UUID4;
    * tenant from ``resolve_tenant_id``.

    The fallback is stored on ``request.state``, so later calls for the same
    request return the same object, with the same id and any fields set on it.

    Its ``started_perf_counter`` is taken at the first call, not when the
    request actually arrived.
    """
    context = getattr(request.state, "gateway_context", None)
    if isinstance(context, GatewayRequestContext):
        return context

    fallback = GatewayRequestContext(
        request_id=request.headers.get(REQUEST_ID_HEADER) or str(uuid4()),
        tenant_id=resolve_tenant_id(request),
        started_perf_counter=perf_counter(),
    )
    if context is None:
        # Cache only when the slot is empty, so an unrelated object that other
        # code stored under the same attribute name is never clobbered.
        request.state.gateway_context = fallback
    return fallback
|