request_context.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. from dataclasses import dataclass
  2. from datetime import datetime
  3. from time import perf_counter
  4. from uuid import uuid4
  5. import httpx
  6. from fastapi import Request, Response
  7. from starlette.responses import JSONResponse
  8. from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
  9. from core_shared.security import build_internal_service_headers
  10. from app.bootstrap.settings import ApiGatewaySettings
  11. from app.domain.repositories import ApiKeyRepository
  12. from app.infrastructure.api_keys import hash_api_key
  13. from app.infrastructure.rate_limit import (
  14. apply_gateway_rate_limit_headers,
  15. enforce_gateway_rate_limit,
  16. )
  17. REQUEST_ID_HEADER = "x-request-id"
  18. TENANT_ID_HEADER = "x-tenant-id"
  19. DEFAULT_TENANT_ID = "public"
  20. @dataclass
  21. class GatewayRequestContext:
  22. request_id: str
  23. tenant_id: str
  24. started_perf_counter: float
  25. api_key_id: str | None = None
  26. user_id: str | None = None
  27. target_service: str | None = None
  28. target_url: str | None = None
  29. class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
  30. async def dispatch(
  31. self,
  32. request: Request,
  33. call_next: RequestResponseEndpoint,
  34. ) -> Response:
  35. request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
  36. tenant_id = resolve_tenant_id(request)
  37. request.state.gateway_context = GatewayRequestContext(
  38. request_id=request_id,
  39. tenant_id=tenant_id,
  40. started_perf_counter=perf_counter(),
  41. )
  42. auth_response = authenticate_gateway_request(request)
  43. if auth_response is not None:
  44. from app.infrastructure.audit import persist_gateway_audit
  45. persist_gateway_audit(
  46. request=request,
  47. session_factory=request.app.state.session_factory,
  48. status_code=auth_response.status_code,
  49. error_message=None,
  50. )
  51. context = get_gateway_request_context(request)
  52. auth_response.headers[REQUEST_ID_HEADER] = request_id
  53. auth_response.headers[TENANT_ID_HEADER] = context.tenant_id
  54. return auth_response
  55. settings = request.app.state.settings
  56. rate_limit_response = enforce_gateway_rate_limit(
  57. request=request,
  58. settings=settings,
  59. limiter=request.app.state.rate_limiter,
  60. )
  61. if rate_limit_response is not None:
  62. from app.infrastructure.audit import persist_gateway_audit
  63. persist_gateway_audit(
  64. request=request,
  65. session_factory=request.app.state.session_factory,
  66. status_code=rate_limit_response.status_code,
  67. error_message="rate limit exceeded",
  68. )
  69. context = get_gateway_request_context(request)
  70. rate_limit_response.headers[REQUEST_ID_HEADER] = request_id
  71. rate_limit_response.headers[TENANT_ID_HEADER] = context.tenant_id
  72. return rate_limit_response
  73. try:
  74. response = await call_next(request)
  75. except Exception as exc:
  76. from app.infrastructure.audit import persist_gateway_audit
  77. persist_gateway_audit(
  78. request=request,
  79. session_factory=request.app.state.session_factory,
  80. status_code=500,
  81. error_message=str(exc),
  82. )
  83. raise
  84. from app.infrastructure.audit import persist_gateway_audit
  85. persist_gateway_audit(
  86. request=request,
  87. session_factory=request.app.state.session_factory,
  88. status_code=response.status_code,
  89. )
  90. context = get_gateway_request_context(request)
  91. response.headers[REQUEST_ID_HEADER] = request_id
  92. response.headers[TENANT_ID_HEADER] = context.tenant_id
  93. apply_gateway_rate_limit_headers(
  94. response=response,
  95. request=request,
  96. settings=settings,
  97. limiter=request.app.state.rate_limiter,
  98. )
  99. return response
  100. def resolve_tenant_id(request: Request) -> str:
  101. header_tenant_id = request.headers.get(TENANT_ID_HEADER)
  102. if header_tenant_id:
  103. return header_tenant_id
  104. query_tenant_id = request.query_params.get("tenant_id")
  105. if query_tenant_id:
  106. return query_tenant_id
  107. return DEFAULT_TENANT_ID
  108. def get_gateway_request_context(request: Request) -> GatewayRequestContext:
  109. context = getattr(request.state, "gateway_context", None)
  110. if isinstance(context, GatewayRequestContext):
  111. return context
  112. return GatewayRequestContext(
  113. request_id=str(uuid4()),
  114. tenant_id=DEFAULT_TENANT_ID,
  115. started_perf_counter=perf_counter(),
  116. )
  117. def authenticate_gateway_request(request: Request) -> Response | None:
  118. settings = ApiGatewaySettings()
  119. if not settings.auth_required:
  120. return None
  121. if not request.url.path.startswith("/gateway/"):
  122. return None
  123. if request.url.path in {"/gateway/services/health"}:
  124. return None
  125. if is_initial_api_key_bootstrap_request(request):
  126. return None
  127. api_key = request.headers.get(settings.api_key_header_name)
  128. if not api_key:
  129. return JSONResponse(
  130. status_code=401,
  131. content={"detail": "missing api key"},
  132. )
  133. db = request.app.state.session_factory()
  134. try:
  135. entity = ApiKeyRepository(db).get_active_by_hash(key_hash=hash_api_key(api_key))
  136. if entity is None:
  137. return JSONResponse(
  138. status_code=401,
  139. content={"detail": "invalid api key"},
  140. )
  141. if entity.expires_time is not None and entity.expires_time <= datetime.utcnow():
  142. return JSONResponse(
  143. status_code=401,
  144. content={"detail": "api key expired"},
  145. )
  146. context = get_gateway_request_context(request)
  147. requested_tenant_id = resolve_tenant_id(request)
  148. if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
  149. return JSONResponse(
  150. status_code=403,
  151. content={"detail": "api key tenant mismatch"},
  152. )
  153. context.tenant_id = entity.tenant_id
  154. context.api_key_id = entity.id
  155. context.user_id = request.headers.get(settings.user_id_header_name)
  156. permission = derive_gateway_permission(request)
  157. if permission is not None and not api_key_scope_allows(
  158. scopes=entity.scopes,
  159. permission=permission,
  160. ):
  161. return JSONResponse(
  162. status_code=403,
  163. content={"detail": "api key scope denied", "permission": permission},
  164. )
  165. if settings.authz_required:
  166. if context.user_id is None:
  167. return JSONResponse(
  168. status_code=401,
  169. content={"detail": "missing user id"},
  170. )
  171. authz_response = check_auth_service_permission(
  172. settings=settings,
  173. tenant_id=entity.tenant_id,
  174. user_id=context.user_id,
  175. permission=permission or "gateway:access",
  176. )
  177. if authz_response is not None:
  178. return authz_response
  179. ApiKeyRepository(db).touch_last_used_time(api_key_id=entity.id)
  180. finally:
  181. db.close()
  182. return None
  183. def is_initial_api_key_bootstrap_request(request: Request) -> bool:
  184. if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
  185. return False
  186. db = request.app.state.session_factory()
  187. try:
  188. return not ApiKeyRepository(db).has_any()
  189. finally:
  190. db.close()
  191. def derive_gateway_permission(request: Request) -> str | None:
  192. if not request.url.path.startswith("/gateway/"):
  193. return None
  194. path_parts = [part for part in request.url.path.split("/") if part]
  195. if len(path_parts) < 2:
  196. return "gateway:access"
  197. if path_parts[1] in {"services", "api-keys", "audits"}:
  198. resource = path_parts[1]
  199. else:
  200. resource = path_parts[1].replace("_", "-")
  201. action = "read" if request.method.upper() in {"GET", "HEAD", "OPTIONS"} else "write"
  202. return f"gateway:{resource}:{action}"
  203. def api_key_scope_allows(*, scopes: str | None, permission: str) -> bool:
  204. if scopes is None or not scopes.strip():
  205. return True
  206. scope_values = parse_scope_values(scopes)
  207. if "*" in scope_values or permission in scope_values:
  208. return True
  209. resource_prefix = permission.rsplit(":", 1)[0]
  210. return f"{resource_prefix}:*" in scope_values
  211. def parse_scope_values(scopes: str) -> set[str]:
  212. normalized = scopes.replace(",", " ").replace("\n", " ")
  213. return {item.strip() for item in normalized.split(" ") if item.strip()}
  214. def check_auth_service_permission(
  215. *,
  216. settings: ApiGatewaySettings,
  217. tenant_id: str,
  218. user_id: str,
  219. permission: str,
  220. ) -> Response | None:
  221. try:
  222. with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
  223. response = client.post(
  224. f"{settings.auth_service_url.rstrip('/')}/auth/permissions/check",
  225. headers=build_internal_service_headers(settings),
  226. json={
  227. "tenant_id": tenant_id,
  228. "user_id": user_id,
  229. "permission": permission,
  230. },
  231. )
  232. response.raise_for_status()
  233. payload = response.json()
  234. except (httpx.HTTPError, ValueError) as exc:
  235. return JSONResponse(
  236. status_code=503,
  237. content={"detail": "auth service permission check failed", "error": str(exc)},
  238. )
  239. allowed = payload.get("allowed")
  240. if allowed is True:
  241. return None
  242. reason = payload.get("reason")
  243. return JSONResponse(
  244. status_code=403,
  245. content={
  246. "detail": "permission denied",
  247. "permission": permission,
  248. "reason": reason if isinstance(reason, str) else "denied",
  249. },
  250. )