proxy.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from dataclasses import dataclass
  2. from typing import Literal
  3. import httpx
  4. from fastapi import Request, Response
  5. from app.infrastructure.audit import mark_gateway_target
  6. from app.infrastructure.request_context import REQUEST_ID_HEADER, TENANT_ID_HEADER, get_gateway_request_context
  7. from app.schemas.gateway import DownstreamServiceHealth
  8. ProxyServiceName = Literal[
  9. "workflow-service",
  10. "session-service",
  11. "runtime-service",
  12. "tool-service",
  13. "model-gateway-service",
  14. "code-runner-service",
  15. ]
  16. @dataclass(frozen=True)
  17. class ProxyTarget:
  18. service_name: ProxyServiceName
  19. base_url: str
  20. path_prefix: str
  21. health_path: str
  22. class ServiceProxy:
  23. def __init__(self, *, timeout_seconds: float) -> None:
  24. self.timeout_seconds = timeout_seconds
  25. async def forward(
  26. self,
  27. *,
  28. request: Request,
  29. target: ProxyTarget,
  30. path: str,
  31. ) -> Response:
  32. target_url = build_target_url(target=target, path=path)
  33. mark_gateway_target(
  34. request,
  35. target_service=target.service_name,
  36. target_url=target_url,
  37. )
  38. headers = build_forward_headers(request)
  39. request_context = get_gateway_request_context(request)
  40. headers[REQUEST_ID_HEADER] = request_context.request_id
  41. headers[TENANT_ID_HEADER] = request_context.tenant_id
  42. body = await request.body()
  43. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  44. upstream_response = await client.request(
  45. method=request.method,
  46. url=target_url,
  47. params=request.query_params,
  48. headers=headers,
  49. content=body,
  50. )
  51. return Response(
  52. content=upstream_response.content,
  53. status_code=upstream_response.status_code,
  54. headers=build_response_headers(upstream_response),
  55. media_type=upstream_response.headers.get("content-type"),
  56. )
  57. async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
  58. health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
  59. try:
  60. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  61. response = await client.get(health_url)
  62. except httpx.HTTPError as exc:
  63. return DownstreamServiceHealth(
  64. service=target.service_name,
  65. status="error",
  66. url=health_url,
  67. error_message=str(exc),
  68. )
  69. return DownstreamServiceHealth(
  70. service=target.service_name,
  71. status="ok" if response.is_success else "error",
  72. url=health_url,
  73. status_code=response.status_code,
  74. error_message=None if response.is_success else response.text,
  75. )
  76. def build_target_url(*, target: ProxyTarget, path: str) -> str:
  77. normalized_path = path.strip("/")
  78. if normalized_path:
  79. return f"{target.base_url.rstrip('/')}{target.path_prefix}/{normalized_path}"
  80. return f"{target.base_url.rstrip('/')}{target.path_prefix}"
  81. def build_forward_headers(request: Request) -> dict[str, str]:
  82. skipped_headers = {"host", "content-length", "connection", REQUEST_ID_HEADER, TENANT_ID_HEADER}
  83. return {
  84. key: value
  85. for key, value in request.headers.items()
  86. if key.lower() not in skipped_headers
  87. }
  88. def build_response_headers(response: httpx.Response) -> dict[str, str]:
  89. skipped_headers = {"content-length", "transfer-encoding", "connection"}
  90. return {
  91. key: value
  92. for key, value in response.headers.items()
  93. if key.lower() not in skipped_headers
  94. }