request_context.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from uuid import uuid4

from fastapi import Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse

from app.bootstrap.settings import ApiGatewaySettings
from app.domain.repositories import ApiKeyRepository
from app.infrastructure.api_keys import hash_api_key
  11. REQUEST_ID_HEADER = "x-request-id"
  12. TENANT_ID_HEADER = "x-tenant-id"
  13. DEFAULT_TENANT_ID = "public"
  14. @dataclass
  15. class GatewayRequestContext:
  16. request_id: str
  17. tenant_id: str
  18. started_perf_counter: float
  19. api_key_id: str | None = None
  20. target_service: str | None = None
  21. target_url: str | None = None
  22. class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
  23. async def dispatch(
  24. self,
  25. request: Request,
  26. call_next: RequestResponseEndpoint,
  27. ) -> Response:
  28. request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
  29. tenant_id = resolve_tenant_id(request)
  30. request.state.gateway_context = GatewayRequestContext(
  31. request_id=request_id,
  32. tenant_id=tenant_id,
  33. started_perf_counter=perf_counter(),
  34. )
  35. auth_response = authenticate_gateway_request(request)
  36. if auth_response is not None:
  37. from app.infrastructure.audit import persist_gateway_audit
  38. persist_gateway_audit(
  39. request=request,
  40. session_factory=request.app.state.session_factory,
  41. status_code=auth_response.status_code,
  42. error_message=None,
  43. )
  44. context = get_gateway_request_context(request)
  45. auth_response.headers[REQUEST_ID_HEADER] = request_id
  46. auth_response.headers[TENANT_ID_HEADER] = context.tenant_id
  47. return auth_response
  48. try:
  49. response = await call_next(request)
  50. except Exception as exc:
  51. from app.infrastructure.audit import persist_gateway_audit
  52. persist_gateway_audit(
  53. request=request,
  54. session_factory=request.app.state.session_factory,
  55. status_code=500,
  56. error_message=str(exc),
  57. )
  58. raise
  59. from app.infrastructure.audit import persist_gateway_audit
  60. persist_gateway_audit(
  61. request=request,
  62. session_factory=request.app.state.session_factory,
  63. status_code=response.status_code,
  64. )
  65. context = get_gateway_request_context(request)
  66. response.headers[REQUEST_ID_HEADER] = request_id
  67. response.headers[TENANT_ID_HEADER] = context.tenant_id
  68. return response
  69. def resolve_tenant_id(request: Request) -> str:
  70. header_tenant_id = request.headers.get(TENANT_ID_HEADER)
  71. if header_tenant_id:
  72. return header_tenant_id
  73. query_tenant_id = request.query_params.get("tenant_id")
  74. if query_tenant_id:
  75. return query_tenant_id
  76. return DEFAULT_TENANT_ID
  77. def get_gateway_request_context(request: Request) -> GatewayRequestContext:
  78. context = getattr(request.state, "gateway_context", None)
  79. if isinstance(context, GatewayRequestContext):
  80. return context
  81. return GatewayRequestContext(
  82. request_id=str(uuid4()),
  83. tenant_id=DEFAULT_TENANT_ID,
  84. started_perf_counter=perf_counter(),
  85. )
  86. def authenticate_gateway_request(request: Request) -> Response | None:
  87. settings = ApiGatewaySettings()
  88. if not settings.auth_required:
  89. return None
  90. if not request.url.path.startswith("/gateway/"):
  91. return None
  92. if request.url.path in {"/gateway/services/health"}:
  93. return None
  94. if is_initial_api_key_bootstrap_request(request):
  95. return None
  96. api_key = request.headers.get(settings.api_key_header_name)
  97. if not api_key:
  98. return JSONResponse(
  99. status_code=401,
  100. content={"detail": "missing api key"},
  101. )
  102. db = request.app.state.session_factory()
  103. try:
  104. entity = ApiKeyRepository(db).get_active_by_hash(key_hash=hash_api_key(api_key))
  105. if entity is None:
  106. return JSONResponse(
  107. status_code=401,
  108. content={"detail": "invalid api key"},
  109. )
  110. if entity.expires_time is not None and entity.expires_time <= datetime.utcnow():
  111. return JSONResponse(
  112. status_code=401,
  113. content={"detail": "api key expired"},
  114. )
  115. context = get_gateway_request_context(request)
  116. requested_tenant_id = resolve_tenant_id(request)
  117. if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
  118. return JSONResponse(
  119. status_code=403,
  120. content={"detail": "api key tenant mismatch"},
  121. )
  122. context.tenant_id = entity.tenant_id
  123. context.api_key_id = entity.id
  124. ApiKeyRepository(db).touch_last_used_time(api_key_id=entity.id)
  125. finally:
  126. db.close()
  127. return None
  128. def is_initial_api_key_bootstrap_request(request: Request) -> bool:
  129. if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
  130. return False
  131. db = request.app.state.session_factory()
  132. try:
  133. return not ApiKeyRepository(db).has_any()
  134. finally:
  135. db.close()