"""Reverse-proxy support for forwarding gateway requests to downstream services."""

from dataclasses import dataclass
from typing import Literal

import httpx
from fastapi import Request, Response

from core_shared.observability import PARENT_SPAN_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER
from core_shared.security import build_internal_service_headers
from app.infrastructure.audit import mark_gateway_target
from app.bootstrap.settings import ApiGatewaySettings
from app.infrastructure.request_context import (
    REQUEST_ID_HEADER,
    TENANT_ID_HEADER,
    get_gateway_request_context,
)
from app.schemas.gateway import DownstreamServiceHealth

# Closed set of downstream services the gateway is allowed to proxy to.
ProxyServiceName = Literal[
    "workflow-service",
    "session-service",
    "runtime-service",
    "tool-service",
    "model-gateway-service",
    "code-runner-service",
    "agent-service",
    "memory-service",
    "team-service",
    "skill-service",
    "human-service",
    "knowledge-service",
    "event-service",
    "auth-service",
    "scheduler-service",
]

# Connection-scoped ("hop-by-hop") headers (RFC 9110, section 7.6.1). They describe a
# single connection, so a proxy must not relay them in either direction.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


@dataclass(frozen=True)
class ProxyTarget:
    """Static routing information for one downstream service."""

    # Identifier of the downstream service; also recorded for auditing in `forward`.
    service_name: ProxyServiceName
    # Root URL of the service; trailing slashes are stripped before paths are appended.
    base_url: str
    # Appended directly after the stripped base_url for proxied calls.
    # NOTE(review): presumably starts with "/" and has no trailing slash — not enforced.
    path_prefix: str
    # Appended directly after the stripped base_url for health probes.
    health_path: str


class ServiceProxy:
    """Forwards gateway requests to downstream services and probes their health."""

    def __init__(self, *, settings: ApiGatewaySettings, timeout_seconds: float) -> None:
        """Store the gateway settings and the upstream timeout.

        Args:
            settings: Passed to ``build_internal_service_headers`` to build the
                service-to-service headers attached to every upstream call.
            timeout_seconds: Timeout given to each ``httpx.AsyncClient``.
        """
        self.settings = settings
        self.timeout_seconds = timeout_seconds

    async def forward(
        self,
        *,
        request: Request,
        target: ProxyTarget,
        path: str,
    ) -> Response:
        """Relay ``request`` to ``target`` and return the upstream reply.

        Client headers are filtered by ``build_forward_headers``. The gateway then
        sets its own request id, tenant id, trace/parent-span ids (when present on
        ``request.state``) and internal service headers.

        Args:
            request: Incoming gateway request; its body is read in full.
            target: Downstream service to call.
            path: Path below ``target.path_prefix`` to call.

        Returns:
            A buffered ``Response`` with the upstream status code, body and the
            headers allowed by ``build_response_headers``.

        Raises:
            httpx.HTTPError: transport failures and timeouts are not caught here.
        """
        target_url = build_target_url(target=target, path=path)
        mark_gateway_target(
            request,
            target_service=target.service_name,
            target_url=target_url,
        )

        headers = build_forward_headers(request)
        request_context = get_gateway_request_context(request)
        headers[REQUEST_ID_HEADER] = request_context.request_id
        headers[TENANT_ID_HEADER] = request_context.tenant_id

        trace_id = getattr(request.state, "trace_id", None)
        span_id = getattr(request.state, "span_id", None)
        if isinstance(trace_id, str):
            headers[TRACE_ID_HEADER] = trace_id
        if isinstance(span_id, str):
            # The gateway's span id is sent as the parent of the downstream span.
            headers[PARENT_SPAN_ID_HEADER] = span_id
        headers.update(build_internal_service_headers(self.settings))

        body = await request.body()
        async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
            upstream_response = await client.request(
                method=request.method,
                url=target_url,
                # multi_items() keeps repeated keys (?tag=a&tag=b). Passing the
                # QueryParams mapping itself forwards only the last value per key.
                params=request.query_params.multi_items(),
                headers=headers,
                content=body,
            )

        response = Response(
            content=upstream_response.content,
            status_code=upstream_response.status_code,
            headers=build_response_headers(upstream_response),
            media_type=upstream_response.headers.get("content-type"),
        )
        _relay_set_cookie_headers(response, upstream_response)
        return response

    async def check_health(self, target: ProxyTarget) -> DownstreamServiceHealth:
        """Probe ``target``'s health endpoint with a GET request.

        Returns:
            ``status="ok"`` for a 2xx reply. Otherwise ``status="error"``, with the
            response text (non-2xx) or the exception text (transport failure) as
            ``error_message``. This method does not raise ``httpx.HTTPError``.
        """
        health_url = f"{target.base_url.rstrip('/')}{target.health_path}"
        try:
            async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
                response = await client.get(
                    health_url,
                    headers=build_internal_service_headers(self.settings),
                )
        except httpx.HTTPError as exc:
            return DownstreamServiceHealth(
                service=target.service_name,
                status="error",
                url=health_url,
                error_message=str(exc),
            )

        return DownstreamServiceHealth(
            service=target.service_name,
            status="ok" if response.is_success else "error",
            url=health_url,
            status_code=response.status_code,
            error_message=None if response.is_success else response.text,
        )


def build_target_url(*, target: ProxyTarget, path: str) -> str:
    """Join ``target``'s base URL, path prefix and ``path`` into an upstream URL.

    Leading/trailing slashes on ``path`` are ignored. An empty (or all-slash)
    ``path`` maps to the bare prefix.
    """
    # NOTE(review): `path` comes from the client URL and ".." segments are not
    # rejected here. Confirm the router/HTTP stack prevents them from escaping
    # `path_prefix` on the downstream service.
    base = f"{target.base_url.rstrip('/')}{target.path_prefix}"
    normalized_path = path.strip("/")
    if normalized_path:
        return f"{base}/{normalized_path}"
    return base


def build_forward_headers(request: Request) -> dict[str, str]:
    """Return the client headers that may be relayed to a downstream service.

    Removed:
    * hop-by-hop headers, ``host`` and ``content-length`` (the HTTP client sets these);
    * ``accept-encoding``, so httpx advertises only encodings it can decode (it
      always decodes the body; see ``build_response_headers``);
    * headers the gateway sets itself (request/tenant/trace/span ids and internal
      service credentials), so a client cannot spoof them.
    """
    gateway_owned = (
        REQUEST_ID_HEADER,
        TENANT_ID_HEADER,
        TRACE_ID_HEADER,
        SPAN_ID_HEADER,
        PARENT_SPAN_ID_HEADER,
        "x-internal-service-token",
        "x-internal-service-name",
    )
    # Lower-case both sides: the header-name constants may be mixed case. A
    # non-lowered constant would never match `key.lower()`, letting the client's
    # value through next to the gateway's own.
    skipped_headers = (
        _HOP_BY_HOP_HEADERS
        | {"host", "content-length", "accept-encoding"}
        | {name.lower() for name in gateway_owned}
    )
    return {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in skipped_headers
    }


def build_response_headers(response: httpx.Response) -> dict[str, str]:
    """Return the upstream headers that may be copied onto the gateway response.

    Removed:
    * hop-by-hop headers;
    * ``content-length`` — the outgoing ``Response`` sizes the re-sent body itself;
    * ``content-encoding`` — httpx has already decompressed ``response.content``, so
      keeping it would make clients try to decode plain bytes a second time.

    Note that repeated headers arrive comma-joined here (see
    ``_relay_set_cookie_headers`` for how ``Set-Cookie`` is repaired).
    """
    skipped_headers = _HOP_BY_HOP_HEADERS | {"content-length", "content-encoding"}
    return {
        key: value
        for key, value in response.headers.items()
        if key.lower() not in skipped_headers
    }


def _relay_set_cookie_headers(response: Response, upstream_response: httpx.Response) -> None:
    """Emit each upstream ``Set-Cookie`` as its own header line on ``response``.

    ``httpx.Headers.items()`` merges repeated headers into one comma-joined value.
    That corrupts cookies, whose ``Expires`` attribute itself contains a comma.
    """
    cookies = upstream_response.headers.get_list("set-cookie")
    if len(cookies) <= 1:
        return  # a single cookie was copied intact by build_response_headers
    del response.headers["set-cookie"]
    for cookie in cookies:
        response.headers.append("set-cookie", cookie)