# request_context.py — per-request gateway context, API-key authentication /
# authorization and audit middleware for the API gateway.
from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from uuid import uuid4

import httpx
from fastapi import Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse

from app.bootstrap.settings import ApiGatewaySettings
from app.domain.repositories import ApiKeyRepository
from app.infrastructure.api_keys import hash_api_key
  12. REQUEST_ID_HEADER = "x-request-id"
  13. TENANT_ID_HEADER = "x-tenant-id"
  14. DEFAULT_TENANT_ID = "public"
  15. @dataclass
  16. class GatewayRequestContext:
  17. request_id: str
  18. tenant_id: str
  19. started_perf_counter: float
  20. api_key_id: str | None = None
  21. user_id: str | None = None
  22. target_service: str | None = None
  23. target_url: str | None = None
  24. class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
  25. async def dispatch(
  26. self,
  27. request: Request,
  28. call_next: RequestResponseEndpoint,
  29. ) -> Response:
  30. request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
  31. tenant_id = resolve_tenant_id(request)
  32. request.state.gateway_context = GatewayRequestContext(
  33. request_id=request_id,
  34. tenant_id=tenant_id,
  35. started_perf_counter=perf_counter(),
  36. )
  37. auth_response = authenticate_gateway_request(request)
  38. if auth_response is not None:
  39. from app.infrastructure.audit import persist_gateway_audit
  40. persist_gateway_audit(
  41. request=request,
  42. session_factory=request.app.state.session_factory,
  43. status_code=auth_response.status_code,
  44. error_message=None,
  45. )
  46. context = get_gateway_request_context(request)
  47. auth_response.headers[REQUEST_ID_HEADER] = request_id
  48. auth_response.headers[TENANT_ID_HEADER] = context.tenant_id
  49. return auth_response
  50. try:
  51. response = await call_next(request)
  52. except Exception as exc:
  53. from app.infrastructure.audit import persist_gateway_audit
  54. persist_gateway_audit(
  55. request=request,
  56. session_factory=request.app.state.session_factory,
  57. status_code=500,
  58. error_message=str(exc),
  59. )
  60. raise
  61. from app.infrastructure.audit import persist_gateway_audit
  62. persist_gateway_audit(
  63. request=request,
  64. session_factory=request.app.state.session_factory,
  65. status_code=response.status_code,
  66. )
  67. context = get_gateway_request_context(request)
  68. response.headers[REQUEST_ID_HEADER] = request_id
  69. response.headers[TENANT_ID_HEADER] = context.tenant_id
  70. return response
  71. def resolve_tenant_id(request: Request) -> str:
  72. header_tenant_id = request.headers.get(TENANT_ID_HEADER)
  73. if header_tenant_id:
  74. return header_tenant_id
  75. query_tenant_id = request.query_params.get("tenant_id")
  76. if query_tenant_id:
  77. return query_tenant_id
  78. return DEFAULT_TENANT_ID
  79. def get_gateway_request_context(request: Request) -> GatewayRequestContext:
  80. context = getattr(request.state, "gateway_context", None)
  81. if isinstance(context, GatewayRequestContext):
  82. return context
  83. return GatewayRequestContext(
  84. request_id=str(uuid4()),
  85. tenant_id=DEFAULT_TENANT_ID,
  86. started_perf_counter=perf_counter(),
  87. )
  88. def authenticate_gateway_request(request: Request) -> Response | None:
  89. settings = ApiGatewaySettings()
  90. if not settings.auth_required:
  91. return None
  92. if not request.url.path.startswith("/gateway/"):
  93. return None
  94. if request.url.path in {"/gateway/services/health"}:
  95. return None
  96. if is_initial_api_key_bootstrap_request(request):
  97. return None
  98. api_key = request.headers.get(settings.api_key_header_name)
  99. if not api_key:
  100. return JSONResponse(
  101. status_code=401,
  102. content={"detail": "missing api key"},
  103. )
  104. db = request.app.state.session_factory()
  105. try:
  106. entity = ApiKeyRepository(db).get_active_by_hash(key_hash=hash_api_key(api_key))
  107. if entity is None:
  108. return JSONResponse(
  109. status_code=401,
  110. content={"detail": "invalid api key"},
  111. )
  112. if entity.expires_time is not None and entity.expires_time <= datetime.utcnow():
  113. return JSONResponse(
  114. status_code=401,
  115. content={"detail": "api key expired"},
  116. )
  117. context = get_gateway_request_context(request)
  118. requested_tenant_id = resolve_tenant_id(request)
  119. if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
  120. return JSONResponse(
  121. status_code=403,
  122. content={"detail": "api key tenant mismatch"},
  123. )
  124. context.tenant_id = entity.tenant_id
  125. context.api_key_id = entity.id
  126. context.user_id = request.headers.get(settings.user_id_header_name)
  127. permission = derive_gateway_permission(request)
  128. if permission is not None and not api_key_scope_allows(
  129. scopes=entity.scopes,
  130. permission=permission,
  131. ):
  132. return JSONResponse(
  133. status_code=403,
  134. content={"detail": "api key scope denied", "permission": permission},
  135. )
  136. if settings.authz_required:
  137. if context.user_id is None:
  138. return JSONResponse(
  139. status_code=401,
  140. content={"detail": "missing user id"},
  141. )
  142. authz_response = check_auth_service_permission(
  143. settings=settings,
  144. tenant_id=entity.tenant_id,
  145. user_id=context.user_id,
  146. permission=permission or "gateway:access",
  147. )
  148. if authz_response is not None:
  149. return authz_response
  150. ApiKeyRepository(db).touch_last_used_time(api_key_id=entity.id)
  151. finally:
  152. db.close()
  153. return None
  154. def is_initial_api_key_bootstrap_request(request: Request) -> bool:
  155. if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
  156. return False
  157. db = request.app.state.session_factory()
  158. try:
  159. return not ApiKeyRepository(db).has_any()
  160. finally:
  161. db.close()
  162. def derive_gateway_permission(request: Request) -> str | None:
  163. if not request.url.path.startswith("/gateway/"):
  164. return None
  165. path_parts = [part for part in request.url.path.split("/") if part]
  166. if len(path_parts) < 2:
  167. return "gateway:access"
  168. if path_parts[1] in {"services", "api-keys", "audits"}:
  169. resource = path_parts[1]
  170. else:
  171. resource = path_parts[1].replace("_", "-")
  172. action = "read" if request.method.upper() in {"GET", "HEAD", "OPTIONS"} else "write"
  173. return f"gateway:{resource}:{action}"
  174. def api_key_scope_allows(*, scopes: str | None, permission: str) -> bool:
  175. if scopes is None or not scopes.strip():
  176. return True
  177. scope_values = parse_scope_values(scopes)
  178. if "*" in scope_values or permission in scope_values:
  179. return True
  180. resource_prefix = permission.rsplit(":", 1)[0]
  181. return f"{resource_prefix}:*" in scope_values
  182. def parse_scope_values(scopes: str) -> set[str]:
  183. normalized = scopes.replace(",", " ").replace("\n", " ")
  184. return {item.strip() for item in normalized.split(" ") if item.strip()}
  185. def check_auth_service_permission(
  186. *,
  187. settings: ApiGatewaySettings,
  188. tenant_id: str,
  189. user_id: str,
  190. permission: str,
  191. ) -> Response | None:
  192. try:
  193. with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
  194. response = client.post(
  195. f"{settings.auth_service_url.rstrip('/')}/auth/permissions/check",
  196. json={
  197. "tenant_id": tenant_id,
  198. "user_id": user_id,
  199. "permission": permission,
  200. },
  201. )
  202. response.raise_for_status()
  203. payload = response.json()
  204. except (httpx.HTTPError, ValueError) as exc:
  205. return JSONResponse(
  206. status_code=503,
  207. content={"detail": "auth service permission check failed", "error": str(exc)},
  208. )
  209. allowed = payload.get("allowed")
  210. if allowed is True:
  211. return None
  212. reason = payload.get("reason")
  213. return JSONResponse(
  214. status_code=403,
  215. content={
  216. "detail": "permission denied",
  217. "permission": permission,
  218. "reason": reason if isinstance(reason, str) else "denied",
  219. },
  220. )