request_context.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from dataclasses import dataclass
  2. from time import perf_counter
  3. from uuid import uuid4
  4. from fastapi import Request, Response
  5. from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
  6. REQUEST_ID_HEADER = "x-request-id"
  7. TENANT_ID_HEADER = "x-tenant-id"
  8. DEFAULT_TENANT_ID = "public"
  9. @dataclass
  10. class GatewayRequestContext:
  11. request_id: str
  12. tenant_id: str
  13. started_perf_counter: float
  14. target_service: str | None = None
  15. target_url: str | None = None
  16. class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
  17. async def dispatch(
  18. self,
  19. request: Request,
  20. call_next: RequestResponseEndpoint,
  21. ) -> Response:
  22. request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
  23. tenant_id = resolve_tenant_id(request)
  24. request.state.gateway_context = GatewayRequestContext(
  25. request_id=request_id,
  26. tenant_id=tenant_id,
  27. started_perf_counter=perf_counter(),
  28. )
  29. try:
  30. response = await call_next(request)
  31. except Exception as exc:
  32. from app.infrastructure.audit import persist_gateway_audit
  33. persist_gateway_audit(
  34. request=request,
  35. session_factory=request.app.state.session_factory,
  36. status_code=500,
  37. error_message=str(exc),
  38. )
  39. raise
  40. from app.infrastructure.audit import persist_gateway_audit
  41. persist_gateway_audit(
  42. request=request,
  43. session_factory=request.app.state.session_factory,
  44. status_code=response.status_code,
  45. )
  46. response.headers[REQUEST_ID_HEADER] = request_id
  47. response.headers[TENANT_ID_HEADER] = tenant_id
  48. return response
  49. def resolve_tenant_id(request: Request) -> str:
  50. header_tenant_id = request.headers.get(TENANT_ID_HEADER)
  51. if header_tenant_id:
  52. return header_tenant_id
  53. query_tenant_id = request.query_params.get("tenant_id")
  54. if query_tenant_id:
  55. return query_tenant_id
  56. return DEFAULT_TENANT_ID
  57. def get_gateway_request_context(request: Request) -> GatewayRequestContext:
  58. context = getattr(request.state, "gateway_context", None)
  59. if isinstance(context, GatewayRequestContext):
  60. return context
  61. return GatewayRequestContext(
  62. request_id=str(uuid4()),
  63. tenant_id=DEFAULT_TENANT_ID,
  64. started_perf_counter=perf_counter(),
  65. )