proxy.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from dataclasses import dataclass
  2. from typing import Literal
  3. import httpx
  4. from fastapi import Request, Response
  5. from core_shared.observability import PARENT_SPAN_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER
  6. from core_shared.security import build_internal_service_headers
  7. from app.infrastructure.audit import mark_gateway_target
  8. from app.bootstrap.settings import ApiGatewaySettings
  9. from app.infrastructure.request_context import (
  10. REQUEST_ID_HEADER,
  11. TENANT_ID_HEADER,
  12. get_gateway_request_context,
  13. )
  14. from app.schemas.gateway import DownstreamServiceHealth
  15. ProxyServiceName = Literal[
  16. "workflow-service",
  17. "session-service",
  18. "runtime-service",
  19. "tool-service",
  20. "model-gateway-service",
  21. "code-runner-service",
  22. "agent-service",
  23. "memory-service",
  24. "team-service",
  25. "skill-service",
  26. "human-service",
  27. "knowledge-service",
  28. "event-service",
  29. "auth-service",
  30. "scheduler-service",
  31. ]
  32. @dataclass(frozen=True)
  33. class ProxyTarget:
  34. service_name: ProxyServiceName
  35. base_url: str
  36. path_prefix: str
  37. health_path: str
  38. class ServiceProxy:
  39. def __init__(self, *, settings: ApiGatewaySettings, timeout_seconds: float) -> None:
  40. self.settings = settings
  41. self.timeout_seconds = timeout_seconds
  42. async def forward(
  43. self,
  44. *,
  45. request: Request,
  46. target: ProxyTarget,
  47. path: str,
  48. ) -> Response:
  49. target_url = build_target_url(target=target, path=path)
  50. mark_gateway_target(
  51. request,
  52. target_service=target.service_name,
  53. target_url=target_url,
  54. )
  55. headers = build_forward_headers(request)
  56. request_context = get_gateway_request_context(request)
  57. headers[REQUEST_ID_HEADER] = request_context.request_id
  58. headers[TENANT_ID_HEADER] = request_context.tenant_id
  59. trace_id = getattr(request.state, "trace_id", None)
  60. span_id = getattr(request.state, "span_id", None)
  61. if isinstance(trace_id, str):
  62. headers[TRACE_ID_HEADER] = trace_id
  63. if isinstance(span_id, str):
  64. headers[PARENT_SPAN_ID_HEADER] = span_id
  65. headers.update(build_internal_service_headers(self.settings))
  66. body = await request.body()
  67. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  68. upstream_response = await client.request(
  69. method=request.method,
  70. url=target_url,
  71. params=request.query_params,
  72. headers=headers,
  73. content=body,
  74. )
  75. return Response(
  76. content=upstream_response.content,
  77. status_code=upstream_response.status_code,
  78. headers=build_response_headers(upstream_response),
  79. media_type=upstream_response.headers.get("content-type"),
  80. )
  81. async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
  82. health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
  83. try:
  84. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  85. response = await client.get(
  86. health_url,
  87. headers=build_internal_service_headers(self.settings),
  88. )
  89. except httpx.HTTPError as exc:
  90. return DownstreamServiceHealth(
  91. service=target.service_name,
  92. status="error",
  93. url=health_url,
  94. error_message=str(exc),
  95. )
  96. return DownstreamServiceHealth(
  97. service=target.service_name,
  98. status="ok" if response.is_success else "error",
  99. url=health_url,
  100. status_code=response.status_code,
  101. error_message=None if response.is_success else response.text,
  102. )
  103. def build_target_url(*, target: ProxyTarget, path: str) -> str:
  104. normalized_path = path.strip("/")
  105. if normalized_path:
  106. return f"{target.base_url.rstrip('/')}{target.path_prefix}/{normalized_path}"
  107. return f"{target.base_url.rstrip('/')}{target.path_prefix}"
  108. def build_forward_headers(request: Request) -> dict[str, str]:
  109. skipped_headers = {
  110. "host",
  111. "content-length",
  112. "connection",
  113. REQUEST_ID_HEADER,
  114. TENANT_ID_HEADER,
  115. "x-internal-service-token",
  116. "x-internal-service-name",
  117. TRACE_ID_HEADER,
  118. SPAN_ID_HEADER,
  119. PARENT_SPAN_ID_HEADER,
  120. }
  121. return {
  122. key: value
  123. for key, value in request.headers.items()
  124. if key.lower() not in skipped_headers
  125. }
  126. def build_response_headers(response: httpx.Response) -> dict[str, str]:
  127. skipped_headers = {"content-length", "transfer-encoding", "connection"}
  128. return {
  129. key: value
  130. for key, value in response.headers.items()
  131. if key.lower() not in skipped_headers
  132. }