proxy.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from dataclasses import dataclass
  2. from typing import Literal
  3. import httpx
  4. from core_shared.observability import PARENT_SPAN_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER
  5. from core_shared.security import build_internal_service_headers
  6. from fastapi import Request, Response
  7. from fastapi.responses import StreamingResponse
  8. from app.bootstrap.settings import ApiGatewaySettings
  9. from app.infrastructure.audit import mark_gateway_target
  10. from app.infrastructure.request_context import REQUEST_ID_HEADER, get_gateway_request_context
  11. from app.schemas.gateway import DownstreamServiceHealth
  12. ProxyServiceName = Literal[
  13. "session-service",
  14. "tool-service",
  15. "model-gateway-service",
  16. "model-provider-service",
  17. "code-runner-service",
  18. "agent-service",
  19. "memory-service",
  20. "team-service",
  21. "skill-service",
  22. "human-service",
  23. "knowledge-service",
  24. "event-service",
  25. "identity-service",
  26. "scheduler-service",
  27. ]
  28. @dataclass(frozen=True)
  29. class ProxyTarget:
  30. service_name: ProxyServiceName
  31. base_url: str
  32. path_prefix: str
  33. health_path: str
  34. class ServiceProxy:
  35. def __init__(self, *, settings: ApiGatewaySettings, timeout_seconds: float) -> None:
  36. self.settings = settings
  37. self.timeout_seconds = timeout_seconds
  38. async def forward(
  39. self,
  40. *,
  41. request: Request,
  42. target: ProxyTarget,
  43. path: str) -> Response:
  44. target_url = build_target_url(target=target, path=path)
  45. mark_gateway_target(
  46. request,
  47. target_service=target.service_name,
  48. target_url=target_url)
  49. headers = build_forward_headers(request)
  50. request_context = get_gateway_request_context(request)
  51. headers[REQUEST_ID_HEADER] = request_context.request_id
  52. trace_id = getattr(request.state, "trace_id", None)
  53. span_id = getattr(request.state, "span_id", None)
  54. if isinstance(trace_id, str):
  55. headers[TRACE_ID_HEADER] = trace_id
  56. if isinstance(span_id, str):
  57. headers[PARENT_SPAN_ID_HEADER] = span_id
  58. headers.update(build_internal_service_headers(self.settings))
  59. body = await request.body()
  60. if _expects_stream(request):
  61. client = httpx.AsyncClient(timeout=self.timeout_seconds)
  62. request_builder = client.build_request(
  63. method=request.method,
  64. url=target_url,
  65. params=request.query_params,
  66. headers=headers,
  67. content=body)
  68. upstream_response = await client.send(request_builder, stream=True)
  69. return StreamingResponse(
  70. _stream_upstream_response(client, upstream_response),
  71. status_code=upstream_response.status_code,
  72. headers=build_response_headers(upstream_response),
  73. media_type=upstream_response.headers.get("content-type"))
  74. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  75. upstream_response = await client.request(
  76. method=request.method,
  77. url=target_url,
  78. params=request.query_params,
  79. headers=headers,
  80. content=body)
  81. return Response(
  82. content=upstream_response.content,
  83. status_code=upstream_response.status_code,
  84. headers=build_response_headers(upstream_response),
  85. media_type=upstream_response.headers.get("content-type"))
  86. async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
  87. health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
  88. try:
  89. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  90. response = await client.get(
  91. health_url,
  92. headers=build_internal_service_headers(self.settings))
  93. except httpx.HTTPError as exc:
  94. return DownstreamServiceHealth(
  95. service=target.service_name,
  96. status="error",
  97. url=health_url,
  98. error_message=str(exc))
  99. return DownstreamServiceHealth(
  100. service=target.service_name,
  101. status="ok" if response.is_success else "error",
  102. url=health_url,
  103. status_code=response.status_code,
  104. error_message=None if response.is_success else response.text)
  105. def build_target_url(*, target: ProxyTarget, path: str) -> str:
  106. normalized_path = path.strip("/")
  107. if normalized_path:
  108. return f"{target.base_url.rstrip('/')}{target.path_prefix}/{normalized_path}"
  109. return f"{target.base_url.rstrip('/')}{target.path_prefix}"
  110. def build_forward_headers(request: Request) -> dict[str, str]:
  111. skipped_headers = {
  112. "host",
  113. "content-length",
  114. "connection",
  115. REQUEST_ID_HEADER,
  116. "x-internal-service-token",
  117. "x-internal-service-name",
  118. TRACE_ID_HEADER,
  119. SPAN_ID_HEADER,
  120. PARENT_SPAN_ID_HEADER,
  121. }
  122. return {
  123. key: value
  124. for key, value in request.headers.items()
  125. if key.lower() not in skipped_headers
  126. }
  127. def build_response_headers(response: httpx.Response) -> dict[str, str]:
  128. skipped_headers = {"content-length", "transfer-encoding", "connection"}
  129. return {
  130. key: value
  131. for key, value in response.headers.items()
  132. if key.lower() not in skipped_headers
  133. }
  134. def _expects_stream(request: Request) -> bool:
  135. accept_header = request.headers.get("accept", "")
  136. return "text/event-stream" in accept_header or request.url.path.endswith("stream")
  137. async def _stream_upstream_response(
  138. client: httpx.AsyncClient,
  139. response: httpx.Response):
  140. try:
  141. async for chunk in response.aiter_raw():
  142. yield chunk
  143. finally:
  144. await response.aclose()
  145. await client.aclose()