| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- from dataclasses import dataclass
- from typing import Literal
- import httpx
- from fastapi import Request, Response
- from app.infrastructure.audit import mark_gateway_target
- from app.infrastructure.request_context import REQUEST_ID_HEADER, TENANT_ID_HEADER, get_gateway_request_context
- from app.schemas.gateway import DownstreamServiceHealth
# Closed set of downstream service identifiers this module can target.
# It is the declared type of ``ProxyTarget.service_name``; the same string is
# handed to ``mark_gateway_target`` and reported in health results.
ProxyServiceName = Literal[
    "workflow-service", "session-service", "runtime-service",
    "tool-service", "model-gateway-service", "code-runner-service",
    "agent-service", "memory-service", "team-service",
    "skill-service", "human-service",
]
@dataclass(frozen=True)
class ProxyTarget:
    """Immutable routing entry describing one downstream service.

    ``frozen=True`` makes instances read-only (assignment raises
    ``dataclasses.FrozenInstanceError``) and hashable.
    """

    # Which downstream service this entry describes.
    service_name: ProxyServiceName
    # Root URL of the service; the module's URL builders strip a trailing "/".
    base_url: str
    # Path placed between base_url and the forwarded path
    # (presumably starts with "/" - confirm against configuration).
    path_prefix: str
    # Path appended directly to base_url when probing the service's health.
    health_path: str
class ServiceProxy:
    """Relay gateway requests to downstream services and probe their health.

    A fresh ``httpx.AsyncClient`` is opened for every call, configured with
    ``timeout_seconds``.
    """

    def __init__(self, *, timeout_seconds: float) -> None:
        # Timeout handed to every httpx client this proxy creates.
        self.timeout_seconds = timeout_seconds

    async def forward(
        self,
        *,
        request: Request,
        target: ProxyTarget,
        path: str,
    ) -> Response:
        """Send ``request`` to ``target`` under ``path`` and relay the reply.

        The resolved target URL is recorded on the request with
        ``mark_gateway_target``. The request-id and tenant-id headers are set
        from the gateway request context.

        Args:
            request: Incoming gateway request. Its method, query string,
                filtered headers and body are forwarded.
            target: Downstream service to call.
            path: Path below ``target.path_prefix``.

        Returns:
            A response with the upstream status code, filtered headers and the
            upstream body bytes exactly as sent (still content-encoded).

        Exceptions raised by the upstream call (for example httpx transport
        errors or timeouts) are not caught here.
        """
        target_url = build_target_url(target=target, path=path)
        mark_gateway_target(
            request,
            target_service=target.service_name,
            target_url=target_url,
        )
        headers = build_forward_headers(request)
        request_context = get_gateway_request_context(request)
        headers[REQUEST_ID_HEADER] = request_context.request_id
        headers[TENANT_ID_HEADER] = request_context.tenant_id
        body = await request.body()

        async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
            upstream_request = client.build_request(
                method=request.method,
                url=target_url,
                # multi_items() keeps repeated keys (?tag=a&tag=b). Passing the
                # QueryParams mapping itself would forward only the last value.
                params=request.query_params.multi_items(),
                headers=headers,
                content=body,
            )
            # Stream so the body can be read *undecoded*. ``.content`` would be
            # decompressed by httpx while the upstream Content-Encoding header
            # is still relayed, so the client would receive a corrupt body.
            upstream_response = await client.send(upstream_request, stream=True)
            try:
                raw_body = b"".join(
                    [chunk async for chunk in upstream_response.aiter_raw()]
                )
            finally:
                await upstream_response.aclose()

        response = Response(
            content=raw_body,
            status_code=upstream_response.status_code,
            headers=build_response_headers(upstream_response),
            media_type=upstream_response.headers.get("content-type"),
        )
        # httpx joins repeated headers with ", " when iterating items(). That
        # corrupts Set-Cookie, because cookie attributes (e.g. Expires) contain
        # commas. Re-emit each upstream cookie as its own header line.
        cookies = upstream_response.headers.get_list("set-cookie")
        if len(cookies) > 1:
            del response.headers["set-cookie"]
            for cookie in cookies:
                response.headers.append("set-cookie", cookie)
        return response

    async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
        """GET ``target``'s health endpoint and summarise the outcome.

        An ``httpx.HTTPError`` is reported as status ``"error"`` with the
        exception text, instead of being raised. A non-success reply is also
        ``"error"``, with the response body as the error message.
        """
        health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
        try:
            async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
                response = await client.get(health_url)
        except httpx.HTTPError as exc:
            return DownstreamServiceHealth(
                service=target.service_name,
                status="error",
                url=health_url,
                error_message=str(exc),
            )
        return DownstreamServiceHealth(
            service=target.service_name,
            status="ok" if response.is_success else "error",
            url=health_url,
            status_code=response.status_code,
            error_message=None if response.is_success else response.text,
        )
def build_target_url(*, target: ProxyTarget, path: str) -> str:
    """Join ``target``'s base URL, path prefix and a caller-supplied ``path``.

    Leading and trailing slashes on ``path`` are ignored, and a trailing slash
    on ``base_url`` is removed. When ``path`` is empty (or only slashes), the
    result is ``base_url`` followed by ``path_prefix`` unchanged.

    Raises:
        ValueError: if ``path`` contains a ``..`` segment (literal or
            percent-encoded). URL normalisation (httpx, and many servers)
            resolves dot segments, so such a path could reach routes outside
            ``target.path_prefix`` on the downstream service.
    """
    from urllib.parse import unquote

    base_url = target.base_url.rstrip("/")
    normalized_path = path.strip("/")
    if not normalized_path:
        return f"{base_url}{target.path_prefix}"
    if any(unquote(segment) == ".." for segment in normalized_path.split("/")):
        raise ValueError(f"path must not contain '..' segments: {path!r}")
    # NOTE(review): path is inserted without percent-encoding. If callers pass
    # an already-decoded path, '?' or '#' inside it would be reparsed as
    # query/fragment - confirm what callers pass.
    # Drop a trailing "/" on the prefix so "/api/" + "x" does not yield "//x".
    return f"{base_url}{target.path_prefix.rstrip('/')}/{normalized_path}"
def build_forward_headers(request: Request) -> dict[str, str]:
    """Copy the incoming request headers that should be sent downstream.

    Headers that are dropped:

    * ``host`` and ``content-length``, which the outbound HTTP client derives
      from the target URL and the forwarded body;
    * hop-by-hop headers (RFC 9110 section 7.6.1) and any header named in the
      request's ``Connection`` header, which only apply to the client-gateway
      connection;
    * the request-id and tenant-id headers, so that client-supplied copies
      cannot travel alongside the values the gateway sets.

    Names are compared case-insensitively. The skip set is lowercased because
    the configured header constants may be mixed-case.
    """
    hop_by_hop = {
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
    connection_listed = {
        token.strip().lower()
        for token in request.headers.get("connection", "").split(",")
        if token.strip()
    }
    skipped_headers = (
        hop_by_hop
        | connection_listed
        | {
            "host",
            "content-length",
            REQUEST_ID_HEADER.lower(),
            TENANT_ID_HEADER.lower(),
        }
    )
    return {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in skipped_headers
    }
def build_response_headers(response: httpx.Response) -> dict[str, str]:
    """Select the upstream response headers to relay to the gateway client.

    Headers that are not copied:

    * ``content-length``;
    * hop-by-hop headers (RFC 9110 section 7.6.1) and any header named in the
      upstream ``Connection`` header, because they describe only the
      gateway-service connection.

    Note: ``httpx.Headers.items()`` merges repeated headers into one
    comma-separated value, so each key appears once in the result.
    """
    hop_by_hop = {
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
    connection_listed = {
        token.strip().lower()
        for token in response.headers.get("connection", "").split(",")
        if token.strip()
    }
    skipped_headers = hop_by_hop | connection_listed | {"content-length"}
    return {
        key: value
        for key, value in response.headers.items()
        if key.lower() not in skipped_headers
    }
|