|
@@ -0,0 +1,203 @@
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+import json
|
|
|
|
|
+import logging
|
|
|
|
|
+from collections import defaultdict
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
|
|
+from time import perf_counter
|
|
|
|
|
+from threading import Lock
|
|
|
|
|
+from typing import Any, Awaitable, Callable, Iterable, Protocol
|
|
|
|
|
+
|
|
|
|
|
+from starlette.datastructures import Headers
|
|
|
|
|
+from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
+from starlette.requests import Request
|
|
|
|
|
+from starlette.responses import PlainTextResponse, Response
|
|
|
|
|
+from starlette.types import ASGIApp
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+_METRICS_CONTENT_TYPE = "text/plain; version=0.0.4; charset=utf-8"
|
|
|
|
|
+_DURATION_BUCKETS = (0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0)
|
|
|
|
|
+RouteDecorator = Callable[[Callable[..., Awaitable[Response]]], Callable[..., Awaitable[Response]]]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ObservableApp(Protocol):
|
|
|
|
|
+ state: Any
|
|
|
|
|
+
|
|
|
|
|
+ def add_middleware(self, middleware_class: type[BaseHTTPMiddleware], **options: Any) -> None: ...
|
|
|
|
|
+
|
|
|
|
|
+ def get(
|
|
|
|
|
+ self,
|
|
|
|
|
+ path: str,
|
|
|
|
|
+ **options: Any,
|
|
|
|
|
+ ) -> RouteDecorator: ...
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@dataclass(frozen=True, slots=True)
|
|
|
|
|
+class HttpMetricLabels:
|
|
|
|
|
+ service: str
|
|
|
|
|
+ method: str
|
|
|
|
|
+ path: str
|
|
|
|
|
+ status_code: str
|
|
|
|
|
+
|
|
|
|
|
+ def as_pairs(self) -> tuple[tuple[str, str], ...]:
|
|
|
|
|
+ return (
|
|
|
|
|
+ ("service", self.service),
|
|
|
|
|
+ ("method", self.method),
|
|
|
|
|
+ ("path", self.path),
|
|
|
|
|
+ ("status_code", self.status_code),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class MetricsRegistry:
|
|
|
|
|
+ """Small in-process Prometheus registry for HTTP service telemetry."""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, service_name: str) -> None:
|
|
|
|
|
+ self._service_name = service_name
|
|
|
|
|
+ self._lock = Lock()
|
|
|
|
|
+ self._request_counts: defaultdict[HttpMetricLabels, int] = defaultdict(int)
|
|
|
|
|
+ self._duration_sums: defaultdict[HttpMetricLabels, float] = defaultdict(float)
|
|
|
|
|
+ self._duration_buckets: defaultdict[tuple[HttpMetricLabels, str], int] = defaultdict(int)
|
|
|
|
|
+
|
|
|
|
|
+ def observe_http_request(self, labels: HttpMetricLabels, duration_seconds: float) -> None:
|
|
|
|
|
+ with self._lock:
|
|
|
|
|
+ self._request_counts[labels] += 1
|
|
|
|
|
+ self._duration_sums[labels] += duration_seconds
|
|
|
|
|
+ for bucket in _DURATION_BUCKETS:
|
|
|
|
|
+ if duration_seconds <= bucket:
|
|
|
|
|
+ self._duration_buckets[(labels, _format_bucket(bucket))] += 1
|
|
|
|
|
+ self._duration_buckets[(labels, "+Inf")] += 1
|
|
|
|
|
+
|
|
|
|
|
+ def render_prometheus(self) -> str:
|
|
|
|
|
+ with self._lock:
|
|
|
|
|
+ request_counts = dict(self._request_counts)
|
|
|
|
|
+ duration_sums = dict(self._duration_sums)
|
|
|
|
|
+ duration_buckets = dict(self._duration_buckets)
|
|
|
|
|
+
|
|
|
|
|
+ lines: list[str] = [
|
|
|
|
|
+ "# HELP agent_platform_service_info Static service metadata.",
|
|
|
|
|
+ "# TYPE agent_platform_service_info gauge",
|
|
|
|
|
+ f'agent_platform_service_info{{service="{_escape_label(self._service_name)}"}} 1',
|
|
|
|
|
+ "# HELP agent_platform_http_requests_total Total HTTP requests by service, route, method and status.",
|
|
|
|
|
+ "# TYPE agent_platform_http_requests_total counter",
|
|
|
|
|
+ ]
|
|
|
|
|
+ for labels, value in sorted(request_counts.items(), key=lambda item: item[0].as_pairs()):
|
|
|
|
|
+ lines.append(f"agent_platform_http_requests_total{{{_render_labels(labels.as_pairs())}}} {value}")
|
|
|
|
|
+
|
|
|
|
|
+ lines.extend(
|
|
|
|
|
+ [
|
|
|
|
|
+ "# HELP agent_platform_http_request_duration_seconds HTTP request duration histogram.",
|
|
|
|
|
+ "# TYPE agent_platform_http_request_duration_seconds histogram",
|
|
|
|
|
+ ]
|
|
|
|
|
+ )
|
|
|
|
|
+ for labels, count in sorted(request_counts.items(), key=lambda item: item[0].as_pairs()):
|
|
|
|
|
+ for bucket in (*(_format_bucket(bucket) for bucket in _DURATION_BUCKETS), "+Inf"):
|
|
|
|
|
+ bucket_count = duration_buckets.get((labels, bucket), 0)
|
|
|
|
|
+ label_pairs = (*labels.as_pairs(), ("le", bucket))
|
|
|
|
|
+ lines.append(
|
|
|
|
|
+ "agent_platform_http_request_duration_seconds_bucket"
|
|
|
|
|
+ f"{{{_render_labels(label_pairs)}}} {bucket_count}"
|
|
|
|
|
+ )
|
|
|
|
|
+ lines.append(
|
|
|
|
|
+ "agent_platform_http_request_duration_seconds_sum"
|
|
|
|
|
+ f"{{{_render_labels(labels.as_pairs())}}} {duration_sums.get(labels, 0.0):.9f}"
|
|
|
|
|
+ )
|
|
|
|
|
+ lines.append(
|
|
|
|
|
+ "agent_platform_http_request_duration_seconds_count"
|
|
|
|
|
+ f"{{{_render_labels(labels.as_pairs())}}} {count}"
|
|
|
|
|
+ )
|
|
|
|
|
+ return "\n".join(lines) + "\n"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ObservabilityMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
+ def __init__(self, app: ASGIApp, service_name: str, registry: MetricsRegistry) -> None:
|
|
|
|
|
+ super().__init__(app)
|
|
|
|
|
+ self._service_name = service_name
|
|
|
|
|
+ self._registry = registry
|
|
|
|
|
+ self._logger = logging.getLogger("agent_platform.access")
|
|
|
|
|
+
|
|
|
|
|
+ async def dispatch(
|
|
|
|
|
+ self,
|
|
|
|
|
+ request: Request,
|
|
|
|
|
+ call_next: Callable[[Request], Awaitable[Response]],
|
|
|
|
|
+ ) -> Response:
|
|
|
|
|
+ if request.url.path == "/metrics":
|
|
|
|
|
+ return await call_next(request)
|
|
|
|
|
+
|
|
|
|
|
+ started_at_monotonic = perf_counter()
|
|
|
|
|
+ status_code = 500
|
|
|
|
|
+ try:
|
|
|
|
|
+ response = await call_next(request)
|
|
|
|
|
+ status_code = response.status_code
|
|
|
|
|
+ return response
|
|
|
|
|
+ finally:
|
|
|
|
|
+ duration_seconds = perf_counter() - started_at_monotonic
|
|
|
|
|
+ path = _route_path(request)
|
|
|
|
|
+ labels = HttpMetricLabels(
|
|
|
|
|
+ service=self._service_name,
|
|
|
|
|
+ method=request.method,
|
|
|
|
|
+ path=path,
|
|
|
|
|
+ status_code=str(status_code),
|
|
|
|
|
+ )
|
|
|
|
|
+ self._registry.observe_http_request(labels, duration_seconds)
|
|
|
|
|
+ self._log_request(request, path, status_code, duration_seconds)
|
|
|
|
|
+
|
|
|
|
|
+ def _log_request(
|
|
|
|
|
+ self,
|
|
|
|
|
+ request: Request,
|
|
|
|
|
+ path: str,
|
|
|
|
|
+ status_code: int,
|
|
|
|
|
+ duration_seconds: float,
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ headers = request.headers
|
|
|
|
|
+ payload = {
|
|
|
|
|
+ "event": "http_request",
|
|
|
|
|
+ "service": self._service_name,
|
|
|
|
|
+ "method": request.method,
|
|
|
|
|
+ "path": path,
|
|
|
|
|
+ "status_code": status_code,
|
|
|
|
|
+ "duration_ms": round(duration_seconds * 1000, 3),
|
|
|
|
|
+ "request_id": _header(headers, "x-request-id"),
|
|
|
|
|
+ "tenant_id": _header(headers, "x-tenant-id"),
|
|
|
|
|
+ }
|
|
|
|
|
+ self._logger.info(json.dumps(payload, ensure_ascii=False, separators=(",", ":")))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def add_observability(app: ObservableApp, service_name: str) -> MetricsRegistry:
|
|
|
|
|
+ """Attach request metrics and a /metrics endpoint to a FastAPI app."""
|
|
|
|
|
+
|
|
|
|
|
+ registry = MetricsRegistry(service_name=service_name)
|
|
|
|
|
+ setattr(app.state, "metrics_registry", registry)
|
|
|
|
|
+ app.add_middleware(ObservabilityMiddleware, service_name=service_name, registry=registry)
|
|
|
|
|
+
|
|
|
|
|
+ @app.get("/metrics", include_in_schema=False)
|
|
|
|
|
+ async def metrics() -> PlainTextResponse:
|
|
|
|
|
+ return PlainTextResponse(registry.render_prometheus(), media_type=_METRICS_CONTENT_TYPE)
|
|
|
|
|
+
|
|
|
|
|
+ return registry
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _route_path(request: Request) -> str:
|
|
|
|
|
+ route = request.scope.get("route")
|
|
|
|
|
+ path = getattr(route, "path", None)
|
|
|
|
|
+ if isinstance(path, str):
|
|
|
|
|
+ return path
|
|
|
|
|
+ return request.url.path
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _header(headers: Headers, name: str) -> str | None:
|
|
|
|
|
+ value = headers.get(name)
|
|
|
|
|
+ if value is None or value == "":
|
|
|
|
|
+ return None
|
|
|
|
|
+ return value
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _format_bucket(value: float) -> str:
|
|
|
|
|
+ return f"{value:g}"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _render_labels(pairs: Iterable[tuple[str, str]]) -> str:
|
|
|
|
|
+ return ",".join(f'{key}="{_escape_label(value)}"' for key, value in pairs)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _escape_label(value: str) -> str:
|
|
|
|
|
+ return value.replace("\\", "\\\\").replace("\n", "\\n").replace('"', '\\"')
|