proxy.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from dataclasses import dataclass
  2. from typing import Literal
  3. import httpx
  4. from fastapi import Request, Response
  5. from core_shared.security import build_internal_service_headers
  6. from app.infrastructure.audit import mark_gateway_target
  7. from app.bootstrap.settings import ApiGatewaySettings
  8. from app.infrastructure.request_context import (
  9. REQUEST_ID_HEADER,
  10. TENANT_ID_HEADER,
  11. get_gateway_request_context,
  12. )
  13. from app.schemas.gateway import DownstreamServiceHealth
  14. ProxyServiceName = Literal[
  15. "workflow-service",
  16. "session-service",
  17. "runtime-service",
  18. "tool-service",
  19. "model-gateway-service",
  20. "code-runner-service",
  21. "agent-service",
  22. "memory-service",
  23. "team-service",
  24. "skill-service",
  25. "human-service",
  26. "knowledge-service",
  27. "event-service",
  28. "auth-service",
  29. "scheduler-service",
  30. ]
  31. @dataclass(frozen=True)
  32. class ProxyTarget:
  33. service_name: ProxyServiceName
  34. base_url: str
  35. path_prefix: str
  36. health_path: str
  37. class ServiceProxy:
  38. def __init__(self, *, settings: ApiGatewaySettings, timeout_seconds: float) -> None:
  39. self.settings = settings
  40. self.timeout_seconds = timeout_seconds
  41. async def forward(
  42. self,
  43. *,
  44. request: Request,
  45. target: ProxyTarget,
  46. path: str,
  47. ) -> Response:
  48. target_url = build_target_url(target=target, path=path)
  49. mark_gateway_target(
  50. request,
  51. target_service=target.service_name,
  52. target_url=target_url,
  53. )
  54. headers = build_forward_headers(request)
  55. request_context = get_gateway_request_context(request)
  56. headers[REQUEST_ID_HEADER] = request_context.request_id
  57. headers[TENANT_ID_HEADER] = request_context.tenant_id
  58. headers.update(build_internal_service_headers(self.settings))
  59. body = await request.body()
  60. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  61. upstream_response = await client.request(
  62. method=request.method,
  63. url=target_url,
  64. params=request.query_params,
  65. headers=headers,
  66. content=body,
  67. )
  68. return Response(
  69. content=upstream_response.content,
  70. status_code=upstream_response.status_code,
  71. headers=build_response_headers(upstream_response),
  72. media_type=upstream_response.headers.get("content-type"),
  73. )
  74. async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
  75. health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
  76. try:
  77. async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
  78. response = await client.get(
  79. health_url,
  80. headers=build_internal_service_headers(self.settings),
  81. )
  82. except httpx.HTTPError as exc:
  83. return DownstreamServiceHealth(
  84. service=target.service_name,
  85. status="error",
  86. url=health_url,
  87. error_message=str(exc),
  88. )
  89. return DownstreamServiceHealth(
  90. service=target.service_name,
  91. status="ok" if response.is_success else "error",
  92. url=health_url,
  93. status_code=response.status_code,
  94. error_message=None if response.is_success else response.text,
  95. )
  96. def build_target_url(*, target: ProxyTarget, path: str) -> str:
  97. normalized_path = path.strip("/")
  98. if normalized_path:
  99. return f"{target.base_url.rstrip('/')}{target.path_prefix}/{normalized_path}"
  100. return f"{target.base_url.rstrip('/')}{target.path_prefix}"
  101. def build_forward_headers(request: Request) -> dict[str, str]:
  102. skipped_headers = {
  103. "host",
  104. "content-length",
  105. "connection",
  106. REQUEST_ID_HEADER,
  107. TENANT_ID_HEADER,
  108. "x-internal-service-token",
  109. "x-internal-service-name",
  110. }
  111. return {
  112. key: value
  113. for key, value in request.headers.items()
  114. if key.lower() not in skipped_headers
  115. }
  116. def build_response_headers(response: httpx.Response) -> dict[str, str]:
  117. skipped_headers = {"content-length", "transfer-encoding", "connection"}
  118. return {
  119. key: value
  120. for key, value in response.headers.items()
  121. if key.lower() not in skipped_headers
  122. }