request_context.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from dataclasses import dataclass
  2. from datetime import datetime
  3. from time import perf_counter
  4. from uuid import uuid4
  5. import httpx
  6. from core_shared.security import build_internal_service_headers
  7. from fastapi import Request, Response
  8. from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
  9. from starlette.responses import JSONResponse
  10. from app.bootstrap.settings import ApiGatewaySettings
  11. from app.domain.repositories import ApiKeyRepository
  12. from app.infrastructure.api_keys import hash_api_key
  13. from app.infrastructure.rate_limit import (
  14. apply_gateway_rate_limit_headers,
  15. enforce_gateway_rate_limit,
  16. )
  17. REQUEST_ID_HEADER = "x-request-id"
  18. @dataclass
  19. class GatewayRequestContext:
  20. request_id: str
  21. started_perf_counter: float
  22. api_key_id: str | None = None
  23. user_id: str | None = None
  24. target_service: str | None = None
  25. target_url: str | None = None
  26. class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
  27. async def dispatch(
  28. self,
  29. request: Request,
  30. call_next: RequestResponseEndpoint) -> Response:
  31. request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
  32. request.state.gateway_context = GatewayRequestContext(
  33. request_id=request_id,
  34. started_perf_counter=perf_counter())
  35. auth_response = authenticate_gateway_request(request)
  36. if auth_response is not None:
  37. from app.infrastructure.audit import persist_gateway_audit
  38. persist_gateway_audit(
  39. request=request,
  40. session_factory=request.app.state.session_factory,
  41. status_code=auth_response.status_code,
  42. error_message=None)
  43. context = get_gateway_request_context(request)
  44. auth_response.headers[REQUEST_ID_HEADER] = request_id
  45. return auth_response
  46. settings = request.app.state.settings
  47. rate_limit_response = enforce_gateway_rate_limit(
  48. request=request,
  49. settings=settings,
  50. limiter=request.app.state.rate_limiter)
  51. if rate_limit_response is not None:
  52. from app.infrastructure.audit import persist_gateway_audit
  53. persist_gateway_audit(
  54. request=request,
  55. session_factory=request.app.state.session_factory,
  56. status_code=rate_limit_response.status_code,
  57. error_message="rate limit exceeded")
  58. context = get_gateway_request_context(request)
  59. rate_limit_response.headers[REQUEST_ID_HEADER] = request_id
  60. return rate_limit_response
  61. try:
  62. response = await call_next(request)
  63. except Exception as exc:
  64. from app.infrastructure.audit import persist_gateway_audit
  65. persist_gateway_audit(
  66. request=request,
  67. session_factory=request.app.state.session_factory,
  68. status_code=500,
  69. error_message=str(exc))
  70. raise
  71. from app.infrastructure.audit import persist_gateway_audit
  72. persist_gateway_audit(
  73. request=request,
  74. session_factory=request.app.state.session_factory,
  75. status_code=response.status_code)
  76. context = get_gateway_request_context(request)
  77. response.headers[REQUEST_ID_HEADER] = request_id
  78. apply_gateway_rate_limit_headers(
  79. response=response,
  80. request=request,
  81. settings=settings,
  82. limiter=request.app.state.rate_limiter)
  83. return response
  84. def get_gateway_request_context(request: Request) -> GatewayRequestContext:
  85. context = getattr(request.state, "gateway_context", None)
  86. if isinstance(context, GatewayRequestContext):
  87. return context
  88. return GatewayRequestContext(
  89. request_id=str(uuid4()),
  90. started_perf_counter=perf_counter())
  91. def authenticate_gateway_request(request: Request) -> Response | None:
  92. settings = ApiGatewaySettings()
  93. if not settings.auth_required:
  94. return None
  95. if not request.url.path.startswith("/gateway/"):
  96. return None
  97. if request.url.path in {"/gateway/services/health"}:
  98. return None
  99. if is_auth_login_request(request):
  100. return None
  101. if is_initial_api_key_bootstrap_request(request):
  102. return None
  103. bearer_token = get_bearer_token(request)
  104. if bearer_token is not None:
  105. return authenticate_bearer_token(request=request, settings=settings, token=bearer_token)
  106. api_key = request.headers.get(settings.api_key_header_name)
  107. if not api_key:
  108. return JSONResponse(
  109. status_code=401,
  110. content={"detail": "missing bearer token or api key"})
  111. db = request.app.state.session_factory()
  112. try:
  113. entity = ApiKeyRepository(db).get_active_by_hash(key_hash=hash_api_key(api_key))
  114. if entity is None:
  115. return JSONResponse(
  116. status_code=401,
  117. content={"detail": "invalid api key"})
  118. if entity.expires_time is not None and entity.expires_time <= datetime.utcnow():
  119. return JSONResponse(
  120. status_code=401,
  121. content={"detail": "api key expired"})
  122. context = get_gateway_request_context(request)
  123. context.api_key_id = entity.id
  124. context.user_id = request.headers.get(settings.user_id_header_name)
  125. permission = derive_gateway_permission(request)
  126. if permission is not None and not api_key_scope_allows(
  127. scopes=entity.scopes,
  128. permission=permission):
  129. return JSONResponse(
  130. status_code=403,
  131. content={"detail": "api key scope denied", "permission": permission})
  132. if settings.authz_required:
  133. if context.user_id is None:
  134. return JSONResponse(
  135. status_code=401,
  136. content={"detail": "missing user id"})
  137. authz_response = check_auth_service_permission(
  138. settings=settings,
  139. user_id=context.user_id,
  140. permission=permission or "gateway:access")
  141. if authz_response is not None:
  142. return authz_response
  143. ApiKeyRepository(db).touch_last_used_time(api_key_id=entity.id)
  144. finally:
  145. db.close()
  146. return None
  147. def is_auth_login_request(request: Request) -> bool:
  148. return request.method.upper() == "POST" and request.url.path == "/gateway/auth/login"
  149. def is_initial_api_key_bootstrap_request(request: Request) -> bool:
  150. if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
  151. return False
  152. db = request.app.state.session_factory()
  153. try:
  154. return not ApiKeyRepository(db).has_any()
  155. finally:
  156. db.close()
  157. def get_bearer_token(request: Request) -> str | None:
  158. authorization = request.headers.get("authorization")
  159. if authorization is None:
  160. return None
  161. scheme, _, token = authorization.partition(" ")
  162. if scheme.lower() != "bearer" or not token.strip():
  163. return None
  164. return token.strip()
  165. def authenticate_bearer_token(
  166. *,
  167. request: Request,
  168. settings: ApiGatewaySettings,
  169. token: str) -> Response | None:
  170. try:
  171. with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
  172. response = client.post(
  173. f"{settings.auth_service_url.rstrip('/')}/auth/tokens/verify",
  174. headers=build_internal_service_headers(settings),
  175. json={"access_token": token})
  176. response.raise_for_status()
  177. payload = response.json()
  178. except (httpx.HTTPError, ValueError) as exc:
  179. return JSONResponse(
  180. status_code=503,
  181. content={"detail": "auth token verification failed", "error": str(exc)})
  182. if payload.get("active") is not True:
  183. reason = payload.get("reason")
  184. return JSONResponse(
  185. status_code=401,
  186. content={
  187. "detail": "invalid bearer token",
  188. "reason": reason if isinstance(reason, str) else "inactive",
  189. })
  190. user_id = payload.get("user_id")
  191. if not isinstance(user_id, str):
  192. return JSONResponse(
  193. status_code=401,
  194. content={"detail": "invalid token identity"})
  195. context = get_gateway_request_context(request)
  196. context.user_id = user_id
  197. if settings.authz_required:
  198. permission = derive_gateway_permission(request) or "gateway:access"
  199. authz_response = check_auth_service_permission(
  200. settings=settings,
  201. user_id=user_id,
  202. permission=permission)
  203. if authz_response is not None:
  204. return authz_response
  205. return None
  206. def derive_gateway_permission(request: Request) -> str | None:
  207. if not request.url.path.startswith("/gateway/"):
  208. return None
  209. path_parts = [part for part in request.url.path.split("/") if part]
  210. if len(path_parts) < 2:
  211. return "gateway:access"
  212. if path_parts[1] in {"services", "api-keys", "audits"}:
  213. resource = path_parts[1]
  214. else:
  215. resource = path_parts[1].replace("_", "-")
  216. action = "read" if request.method.upper() in {"GET", "HEAD", "OPTIONS"} else "write"
  217. return f"gateway:{resource}:{action}"
  218. def api_key_scope_allows(*, scopes: str | None, permission: str) -> bool:
  219. if scopes is None or not scopes.strip():
  220. return True
  221. scope_values = parse_scope_values(scopes)
  222. if "*" in scope_values or permission in scope_values:
  223. return True
  224. resource_prefix = permission.rsplit(":", 1)[0]
  225. return f"{resource_prefix}:*" in scope_values
  226. def parse_scope_values(scopes: str) -> set[str]:
  227. normalized = scopes.replace(",", " ").replace("\n", " ")
  228. return {item.strip() for item in normalized.split(" ") if item.strip()}
  229. def check_auth_service_permission(
  230. *,
  231. settings: ApiGatewaySettings,
  232. user_id: str,
  233. permission: str) -> Response | None:
  234. try:
  235. with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
  236. response = client.post(
  237. f"{settings.auth_service_url.rstrip('/')}/auth/permissions/check",
  238. headers=build_internal_service_headers(settings),
  239. json={
  240. "user_id": user_id,
  241. "permission": permission,
  242. })
  243. response.raise_for_status()
  244. payload = response.json()
  245. except (httpx.HTTPError, ValueError) as exc:
  246. return JSONResponse(
  247. status_code=503,
  248. content={"detail": "auth service permission check failed", "error": str(exc)})
  249. allowed = payload.get("allowed")
  250. if allowed is True:
  251. return None
  252. reason = payload.get("reason")
  253. return JSONResponse(
  254. status_code=403,
  255. content={
  256. "detail": "permission denied",
  257. "permission": permission,
  258. "reason": reason if isinstance(reason, str) else "denied",
  259. })