# request_context.py
from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from uuid import uuid4

import httpx
from core_shared.security import build_internal_service_headers
from fastapi import Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import JSONResponse

from app.bootstrap.settings import ApiGatewaySettings
from app.domain.repositories import ApiKeyRepository
from app.infrastructure.api_keys import hash_api_key
from app.infrastructure.rate_limit import (
    apply_gateway_rate_limit_headers,
    enforce_gateway_rate_limit,
)
  17. REQUEST_ID_HEADER = "x-request-id"
  18. @dataclass
  19. class GatewayRequestContext:
  20. request_id: str
  21. started_perf_counter: float
  22. api_key_id: str | None = None
  23. user_id: str | None = None
  24. target_service: str | None = None
  25. target_url: str | None = None
  26. class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
  27. async def dispatch(
  28. self,
  29. request: Request,
  30. call_next: RequestResponseEndpoint) -> Response:
  31. request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid4())
  32. request.state.gateway_context = GatewayRequestContext(
  33. request_id=request_id,
  34. started_perf_counter=perf_counter())
  35. auth_response = authenticate_gateway_request(request)
  36. if auth_response is not None:
  37. from app.infrastructure.audit import persist_gateway_audit
  38. persist_gateway_audit(
  39. request=request,
  40. session_factory=request.app.state.session_factory,
  41. status_code=auth_response.status_code,
  42. error_message=None)
  43. auth_response.headers[REQUEST_ID_HEADER] = request_id
  44. return auth_response
  45. settings = request.app.state.settings
  46. rate_limit_response = enforce_gateway_rate_limit(
  47. request=request,
  48. settings=settings,
  49. limiter=request.app.state.rate_limiter)
  50. if rate_limit_response is not None:
  51. from app.infrastructure.audit import persist_gateway_audit
  52. persist_gateway_audit(
  53. request=request,
  54. session_factory=request.app.state.session_factory,
  55. status_code=rate_limit_response.status_code,
  56. error_message="rate limit exceeded")
  57. rate_limit_response.headers[REQUEST_ID_HEADER] = request_id
  58. return rate_limit_response
  59. try:
  60. response = await call_next(request)
  61. except Exception as exc:
  62. from app.infrastructure.audit import persist_gateway_audit
  63. persist_gateway_audit(
  64. request=request,
  65. session_factory=request.app.state.session_factory,
  66. status_code=500,
  67. error_message=str(exc))
  68. raise
  69. from app.infrastructure.audit import persist_gateway_audit
  70. persist_gateway_audit(
  71. request=request,
  72. session_factory=request.app.state.session_factory,
  73. status_code=response.status_code)
  74. response.headers[REQUEST_ID_HEADER] = request_id
  75. apply_gateway_rate_limit_headers(
  76. response=response,
  77. request=request,
  78. settings=settings,
  79. limiter=request.app.state.rate_limiter)
  80. return response
  81. def get_gateway_request_context(request: Request) -> GatewayRequestContext:
  82. context = getattr(request.state, "gateway_context", None)
  83. if isinstance(context, GatewayRequestContext):
  84. return context
  85. return GatewayRequestContext(
  86. request_id=str(uuid4()),
  87. started_perf_counter=perf_counter())
  88. def authenticate_gateway_request(request: Request) -> Response | None:
  89. settings = ApiGatewaySettings()
  90. if not settings.auth_required:
  91. return None
  92. if not request.url.path.startswith("/gateway/"):
  93. return None
  94. if request.url.path in {"/gateway/services/health"}:
  95. return None
  96. if request.url.path.startswith("/gateway/openapi/"):
  97. return None
  98. if is_auth_login_request(request):
  99. return None
  100. if is_initial_api_key_bootstrap_request(request):
  101. return None
  102. bearer_token = get_bearer_token(request)
  103. if bearer_token is not None:
  104. return authenticate_bearer_token(request=request, settings=settings, token=bearer_token)
  105. api_key = request.headers.get(settings.api_key_header_name)
  106. if not api_key:
  107. return JSONResponse(
  108. status_code=401,
  109. content={"detail": "missing bearer token or api key"})
  110. db = request.app.state.session_factory()
  111. try:
  112. entity = ApiKeyRepository(db).get_active_by_hash(key_hash=hash_api_key(api_key))
  113. if entity is None:
  114. return JSONResponse(
  115. status_code=401,
  116. content={"detail": "invalid api key"})
  117. if entity.expires_time is not None and entity.expires_time <= datetime.utcnow():
  118. return JSONResponse(
  119. status_code=401,
  120. content={"detail": "api key expired"})
  121. context = get_gateway_request_context(request)
  122. context.api_key_id = entity.id
  123. context.user_id = request.headers.get(settings.user_id_header_name)
  124. permission = derive_gateway_permission(request)
  125. if permission is not None and not api_key_scope_allows(
  126. scopes=entity.scopes,
  127. permission=permission):
  128. return JSONResponse(
  129. status_code=403,
  130. content={"detail": "api key scope denied", "permission": permission})
  131. if settings.authz_required:
  132. if context.user_id is None:
  133. return JSONResponse(
  134. status_code=401,
  135. content={"detail": "missing user id"})
  136. authz_response = check_auth_service_permission(
  137. settings=settings,
  138. user_id=context.user_id,
  139. permission=permission or "gateway:access")
  140. if authz_response is not None:
  141. return authz_response
  142. ApiKeyRepository(db).touch_last_used_time(api_key_id=entity.id)
  143. finally:
  144. db.close()
  145. return None
  146. def is_auth_login_request(request: Request) -> bool:
  147. return request.method.upper() == "POST" and request.url.path == "/gateway/identity/auth/login"
  148. def is_initial_api_key_bootstrap_request(request: Request) -> bool:
  149. if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
  150. return False
  151. db = request.app.state.session_factory()
  152. try:
  153. return not ApiKeyRepository(db).has_any()
  154. finally:
  155. db.close()
  156. def get_bearer_token(request: Request) -> str | None:
  157. authorization = request.headers.get("authorization")
  158. if authorization is None:
  159. return None
  160. scheme, _, token = authorization.partition(" ")
  161. if scheme.lower() != "bearer" or not token.strip():
  162. return None
  163. return token.strip()
  164. def authenticate_bearer_token(
  165. *,
  166. request: Request,
  167. settings: ApiGatewaySettings,
  168. token: str) -> Response | None:
  169. try:
  170. with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
  171. response = client.post(
  172. f"{settings.auth_service_url.rstrip('/')}/identity/auth/tokens/verify",
  173. headers=build_internal_service_headers(settings),
  174. json={"accessToken": token})
  175. response.raise_for_status()
  176. payload = response.json()
  177. except (httpx.HTTPError, ValueError) as exc:
  178. return JSONResponse(
  179. status_code=503,
  180. content={"detail": "auth token verification failed", "error": str(exc)})
  181. data = payload.get("data")
  182. if not isinstance(data, dict):
  183. return JSONResponse(
  184. status_code=401,
  185. content={"detail": "invalid token verification response"})
  186. if data.get("active") is not True:
  187. reason = data.get("reason")
  188. return JSONResponse(
  189. status_code=401,
  190. content={
  191. "detail": "invalid bearer token",
  192. "reason": reason if isinstance(reason, str) else "inactive",
  193. })
  194. user_id = data.get("userId")
  195. if not isinstance(user_id, str):
  196. return JSONResponse(
  197. status_code=401,
  198. content={"detail": "invalid token identity"})
  199. context = get_gateway_request_context(request)
  200. context.user_id = user_id
  201. if settings.authz_required:
  202. permission = derive_gateway_permission(request) or "gateway:access"
  203. authz_response = check_auth_service_permission(
  204. settings=settings,
  205. user_id=user_id,
  206. permission=permission)
  207. if authz_response is not None:
  208. return authz_response
  209. return None
  210. def derive_gateway_permission(request: Request) -> str | None:
  211. if not request.url.path.startswith("/gateway/"):
  212. return None
  213. path_parts = [part for part in request.url.path.split("/") if part]
  214. if len(path_parts) < 2:
  215. return "gateway:access"
  216. if path_parts[1] in {"services", "api-keys", "audits"}:
  217. resource = path_parts[1]
  218. else:
  219. resource = path_parts[1].replace("_", "-")
  220. action = "read" if request.method.upper() in {"GET", "HEAD", "OPTIONS"} else "write"
  221. return f"gateway:{resource}:{action}"
  222. def api_key_scope_allows(*, scopes: str | None, permission: str) -> bool:
  223. if scopes is None or not scopes.strip():
  224. return True
  225. scope_values = parse_scope_values(scopes)
  226. if "*" in scope_values or permission in scope_values:
  227. return True
  228. resource_prefix = permission.rsplit(":", 1)[0]
  229. return f"{resource_prefix}:*" in scope_values
  230. def parse_scope_values(scopes: str) -> set[str]:
  231. normalized = scopes.replace(",", " ").replace("\n", " ")
  232. return {item.strip() for item in normalized.split(" ") if item.strip()}
  233. def check_auth_service_permission(
  234. *,
  235. settings: ApiGatewaySettings,
  236. user_id: str,
  237. permission: str) -> Response | None:
  238. try:
  239. with httpx.Client(timeout=settings.authz_timeout_seconds) as client:
  240. response = client.post(
  241. f"{settings.auth_service_url.rstrip('/')}/identity/permissions/check",
  242. headers=build_internal_service_headers(settings),
  243. json={
  244. "userId": user_id,
  245. "permission": permission,
  246. })
  247. response.raise_for_status()
  248. payload = response.json()
  249. except (httpx.HTTPError, ValueError) as exc:
  250. return JSONResponse(
  251. status_code=503,
  252. content={"detail": "auth service permission check failed", "error": str(exc)})
  253. data = payload.get("data")
  254. if not isinstance(data, dict):
  255. return JSONResponse(
  256. status_code=403,
  257. content={"detail": "invalid permission check response"})
  258. allowed = data.get("allowed")
  259. if allowed is True:
  260. return None
  261. reason = data.get("reason")
  262. return JSONResponse(
  263. status_code=403,
  264. content={
  265. "detail": "permission denied",
  266. "permission": permission,
  267. "reason": reason if isinstance(reason, str) else "denied",
  268. })