proxy.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from dataclasses import dataclass
  2. from typing import Literal
  3. import httpx
  4. from fastapi import Request, Response
  5. from app.infrastructure.audit import mark_gateway_target
  6. from app.infrastructure.request_context import (
  7. REQUEST_ID_HEADER,
  8. TENANT_ID_HEADER,
  9. get_gateway_request_context,
  10. )
  11. from app.schemas.gateway import DownstreamServiceHealth
  12. ProxyServiceName = Literal[
  13. "workflow-service",
  14. "session-service",
  15. "runtime-service",
  16. "tool-service",
  17. "model-gateway-service",
  18. "code-runner-service",
  19. "agent-service",
  20. "memory-service",
  21. "team-service",
  22. "skill-service",
  23. "human-service",
  24. "knowledge-service",
  25. "event-service",
  26. ]
  27. @dataclass(frozen=True)
  28. class ProxyTarget:
  29. service_name: ProxyServiceName
  30. base_url: str
  31. path_prefix: str
  32. health_path: str
  33. class ServiceProxy:
  34. def __init__(self, *, timeout_seconds: float) -> None:
  35. self.timeout_seconds = timeout_seconds
  36. async def forward(
  37. self,
  38. *,
  39. request: Request,
  40. target: ProxyTarget,
  41. path: str,
  42. ) -> Response:
  43. target_url = build_target_url(target=target, path=path)
  44. mark_gateway_target(
  45. request,
  46. target_service=target.service_name,
  47. target_url=target_url,
  48. )
  49. headers = build_forward_headers(request)
  50. request_context = get_gateway_request_context(request)
  51. headers[REQUEST_ID_HEADER] = request_context.request_id
  52. headers[TENANT_ID_HEADER] = request_context.tenant_id
  53. body = await request.body()
  54. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  55. upstream_response = await client.request(
  56. method=request.method,
  57. url=target_url,
  58. params=request.query_params,
  59. headers=headers,
  60. content=body,
  61. )
  62. return Response(
  63. content=upstream_response.content,
  64. status_code=upstream_response.status_code,
  65. headers=build_response_headers(upstream_response),
  66. media_type=upstream_response.headers.get("content-type"),
  67. )
  68. async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
  69. health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
  70. try:
  71. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  72. response = await client.get(health_url)
  73. except httpx.HTTPError as exc:
  74. return DownstreamServiceHealth(
  75. service=target.service_name,
  76. status="error",
  77. url=health_url,
  78. error_message=str(exc),
  79. )
  80. return DownstreamServiceHealth(
  81. service=target.service_name,
  82. status="ok" if response.is_success else "error",
  83. url=health_url,
  84. status_code=response.status_code,
  85. error_message=None if response.is_success else response.text,
  86. )
  87. def build_target_url(*, target: ProxyTarget, path: str) -> str:
  88. normalized_path = path.strip("/")
  89. if normalized_path:
  90. return f"{target.base_url.rstrip('/')}{target.path_prefix}/{normalized_path}"
  91. return f"{target.base_url.rstrip('/')}{target.path_prefix}"
  92. def build_forward_headers(request: Request) -> dict[str, str]:
  93. skipped_headers = {"host", "content-length", "connection", REQUEST_ID_HEADER, TENANT_ID_HEADER}
  94. return {
  95. key: value
  96. for key, value in request.headers.items()
  97. if key.lower() not in skipped_headers
  98. }
  99. def build_response_headers(response: httpx.Response) -> dict[str, str]:
  100. skipped_headers = {"content-length", "transfer-encoding", "connection"}
  101. return {
  102. key: value
  103. for key, value in response.headers.items()
  104. if key.lower() not in skipped_headers
  105. }