瀏覽代碼

feat: add gateway rate limiting

Jax Docker 1 月之前
父節點
當前提交
705a2cb7cf

+ 3 - 0
deployments/docker/.env.example

@@ -28,6 +28,9 @@ AGENT_PLATFORM_AUTH_REQUIRED=false
 AGENT_PLATFORM_AUTHZ_REQUIRED=false
 AGENT_PLATFORM_INTERNAL_SERVICE_AUTH_REQUIRED=false
 AGENT_PLATFORM_INTERNAL_SERVICE_TOKEN=replace-with-shared-internal-token
+AGENT_PLATFORM_RATE_LIMIT_ENABLED=false
+AGENT_PLATFORM_TENANT_RATE_LIMIT_PER_MINUTE=600
+AGENT_PLATFORM_API_KEY_RATE_LIMIT_PER_MINUTE=1200
 AGENT_PLATFORM_WORKER_POLL_INTERVAL_SECONDS=1
 AGENT_PLATFORM_WORKER_LEASE_SECONDS=300
 AGENT_PLATFORM_SCHEDULER_WORKER_CLAIM_LIMIT=20

+ 3 - 0
deployments/docker/docker-compose.yml

@@ -617,6 +617,9 @@ services:
       AGENT_PLATFORM_SCHEDULER_SERVICE_URL: http://scheduler-service:8015
       AGENT_PLATFORM_AUTH_REQUIRED: ${AGENT_PLATFORM_AUTH_REQUIRED:-false}
       AGENT_PLATFORM_AUTHZ_REQUIRED: ${AGENT_PLATFORM_AUTHZ_REQUIRED:-false}
+      AGENT_PLATFORM_RATE_LIMIT_ENABLED: ${AGENT_PLATFORM_RATE_LIMIT_ENABLED:-false}
+      AGENT_PLATFORM_TENANT_RATE_LIMIT_PER_MINUTE: ${AGENT_PLATFORM_TENANT_RATE_LIMIT_PER_MINUTE:-600}
+      AGENT_PLATFORM_API_KEY_RATE_LIMIT_PER_MINUTE: ${AGENT_PLATFORM_API_KEY_RATE_LIMIT_PER_MINUTE:-1200}
     ports:
       - "8000:8000"
     volumes:

+ 122 - 0
libs/core-shared/src/core_shared/rate_limit.py

@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from threading import Lock
+from typing import TYPE_CHECKING, Protocol
+
+if TYPE_CHECKING:
+    from redis import Redis
+
+
+@dataclass(frozen=True, slots=True)
+class RateLimitDecision:
+    allowed: bool
+    limit: int
+    remaining: int
+    reset_epoch_seconds: int
+    current_count: int
+
+
+class RateLimiter(Protocol):
+    def check(
+        self,
+        *,
+        key: str,
+        limit: int,
+        window_seconds: int,
+        now_epoch_seconds: int | None = None,
+    ) -> RateLimitDecision: ...
+
+
+class InMemoryFixedWindowRateLimiter:
+    def __init__(self) -> None:
+        self._lock = Lock()
+        self._counts: dict[tuple[str, int], int] = {}
+
+    def check(
+        self,
+        *,
+        key: str,
+        limit: int,
+        window_seconds: int,
+        now_epoch_seconds: int | None = None,
+    ) -> RateLimitDecision:
+        now = now_epoch_seconds or int(time.time())
+        window = _window_start(now, window_seconds)
+        bucket_key = (key, window)
+        with self._lock:
+            self._cleanup_expired(now=now, window_seconds=window_seconds)
+            current_count = self._counts.get(bucket_key, 0) + 1
+            self._counts[bucket_key] = current_count
+        return _decision(
+            current_count=current_count,
+            limit=limit,
+            reset_epoch_seconds=window + window_seconds,
+        )
+
+    def _cleanup_expired(self, *, now: int, window_seconds: int) -> None:
+        cutoff = _window_start(now, window_seconds) - window_seconds
+        expired_keys = [
+            bucket_key for bucket_key in self._counts if bucket_key[1] < cutoff
+        ]
+        for bucket_key in expired_keys:
+            self._counts.pop(bucket_key, None)
+
+
+class RedisFixedWindowRateLimiter:
+    def __init__(self, *, client: "Redis", prefix: str = "rate-limit") -> None:
+        self._client = client
+        self._prefix = prefix
+
+    def check(
+        self,
+        *,
+        key: str,
+        limit: int,
+        window_seconds: int,
+        now_epoch_seconds: int | None = None,
+    ) -> RateLimitDecision:
+        now = now_epoch_seconds or int(time.time())
+        window = _window_start(now, window_seconds)
+        redis_key = f"{self._prefix}:{key}:{window}"
+        pipeline = self._client.pipeline()
+        pipeline.incr(redis_key)
+        pipeline.expire(redis_key, window_seconds * 2)
+        result = pipeline.execute()
+        current_count = int(result[0])
+        return _decision(
+            current_count=current_count,
+            limit=limit,
+            reset_epoch_seconds=window + window_seconds,
+        )
+
+
+def build_rate_limiter(redis_client: "Redis | None") -> RateLimiter:
+    if redis_client is None:
+        return InMemoryFixedWindowRateLimiter()
+    try:
+        return RedisFixedWindowRateLimiter(client=redis_client)
+    except Exception:
+        return InMemoryFixedWindowRateLimiter()
+
+
+def _window_start(now_epoch_seconds: int, window_seconds: int) -> int:
+    safe_window_seconds = max(window_seconds, 1)
+    return now_epoch_seconds - (now_epoch_seconds % safe_window_seconds)
+
+
+def _decision(
+    *,
+    current_count: int,
+    limit: int,
+    reset_epoch_seconds: int,
+) -> RateLimitDecision:
+    remaining = max(limit - current_count, 0)
+    return RateLimitDecision(
+        allowed=current_count <= limit,
+        limit=limit,
+        remaining=remaining,
+        reset_epoch_seconds=reset_epoch_seconds,
+        current_count=current_count,
+    )

+ 3 - 0
services/api-gateway/app/bootstrap/app.py

@@ -1,5 +1,7 @@
 from fastapi import FastAPI
 from core_shared.observability import add_observability
+from core_shared import try_build_redis_client
+from core_shared.rate_limit import build_rate_limiter
 
 from app.api.routes import router
 from app.bootstrap.settings import ApiGatewaySettings
@@ -15,6 +17,7 @@ def create_app() -> FastAPI:
     )
     app.state.settings = settings
     app.state.session_factory = build_session_factory(settings)
+    app.state.rate_limiter = build_rate_limiter(try_build_redis_client(settings.redis_url))
     add_observability(app, settings.service_name)
     app.add_middleware(GatewayRequestContextMiddleware)
     app.include_router(router)

+ 3 - 0
services/api-gateway/app/bootstrap/settings.py

@@ -27,3 +27,6 @@ class ApiGatewaySettings(ServiceSettings):
     api_key_header_name: str = "x-api-key"
     user_id_header_name: str = "x-user-id"
     authz_timeout_seconds: float = 2.0
+    rate_limit_enabled: bool = False
+    tenant_rate_limit_per_minute: int = 600
+    api_key_rate_limit_per_minute: int = 1200

+ 90 - 0
services/api-gateway/app/infrastructure/rate_limit.py

@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+from fastapi import Request, Response
+from starlette.responses import JSONResponse
+
+from core_shared.rate_limit import RateLimitDecision, RateLimiter
+
+from app.bootstrap.settings import ApiGatewaySettings
+
+
+RATE_LIMIT_WINDOW_SECONDS = 60
+RATE_LIMIT_LIMIT_HEADER = "x-ratelimit-limit"
+RATE_LIMIT_REMAINING_HEADER = "x-ratelimit-remaining"
+RATE_LIMIT_RESET_HEADER = "x-ratelimit-reset"
+
+
+def enforce_gateway_rate_limit(
+    *,
+    request: Request,
+    settings: ApiGatewaySettings,
+    limiter: RateLimiter,
+) -> Response | None:
+    if not settings.rate_limit_enabled:
+        return None
+    if not request.url.path.startswith("/gateway/"):
+        return None
+
+    from app.infrastructure.request_context import get_gateway_request_context
+
+    context = get_gateway_request_context(request)
+    checks: list[RateLimitDecision] = []
+    tenant_limit = max(settings.tenant_rate_limit_per_minute, 1)
+    checks.append(
+        limiter.check(
+            key=f"tenant:{context.tenant_id}",
+            limit=tenant_limit,
+            window_seconds=RATE_LIMIT_WINDOW_SECONDS,
+        )
+    )
+    if context.api_key_id is not None:
+        api_key_limit = max(settings.api_key_rate_limit_per_minute, 1)
+        checks.append(
+            limiter.check(
+                key=f"api-key:{context.api_key_id}",
+                limit=api_key_limit,
+                window_seconds=RATE_LIMIT_WINDOW_SECONDS,
+            )
+        )
+
+    denied = next((item for item in checks if not item.allowed), None)
+    if denied is None:
+        if checks:
+            request.state.gateway_rate_limit_decision = min(
+                checks,
+                key=lambda item: item.remaining,
+            )
+        return None
+
+    response = JSONResponse(
+        status_code=429,
+        content={
+            "detail": "rate limit exceeded",
+            "limit": denied.limit,
+            "reset_epoch_seconds": denied.reset_epoch_seconds,
+        },
+    )
+    apply_rate_limit_headers(response, denied)
+    return response
+
+
+def apply_gateway_rate_limit_headers(
+    *,
+    response: Response,
+    request: Request,
+    settings: ApiGatewaySettings,
+    limiter: RateLimiter,
+) -> None:
+    if not settings.rate_limit_enabled:
+        return
+    if not request.url.path.startswith("/gateway/"):
+        return
+    decision = getattr(request.state, "gateway_rate_limit_decision", None)
+    if isinstance(decision, RateLimitDecision):
+        apply_rate_limit_headers(response, decision)
+
+
+def apply_rate_limit_headers(response: Response, decision: RateLimitDecision) -> None:
+    response.headers[RATE_LIMIT_LIMIT_HEADER] = str(decision.limit)
+    response.headers[RATE_LIMIT_REMAINING_HEADER] = str(decision.remaining)
+    response.headers[RATE_LIMIT_RESET_HEADER] = str(decision.reset_epoch_seconds)

+ 30 - 0
services/api-gateway/app/infrastructure/request_context.py

@@ -12,6 +12,10 @@ from core_shared.security import build_internal_service_headers
 from app.bootstrap.settings import ApiGatewaySettings
 from app.domain.repositories import ApiKeyRepository
 from app.infrastructure.api_keys import hash_api_key
+from app.infrastructure.rate_limit import (
+    apply_gateway_rate_limit_headers,
+    enforce_gateway_rate_limit,
+)
 
 REQUEST_ID_HEADER = "x-request-id"
 TENANT_ID_HEADER = "x-tenant-id"
@@ -57,6 +61,26 @@ class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
             auth_response.headers[TENANT_ID_HEADER] = context.tenant_id
             return auth_response
 
+        settings = request.app.state.settings
+        rate_limit_response = enforce_gateway_rate_limit(
+            request=request,
+            settings=settings,
+            limiter=request.app.state.rate_limiter,
+        )
+        if rate_limit_response is not None:
+            from app.infrastructure.audit import persist_gateway_audit
+
+            persist_gateway_audit(
+                request=request,
+                session_factory=request.app.state.session_factory,
+                status_code=rate_limit_response.status_code,
+                error_message="rate limit exceeded",
+            )
+            context = get_gateway_request_context(request)
+            rate_limit_response.headers[REQUEST_ID_HEADER] = request_id
+            rate_limit_response.headers[TENANT_ID_HEADER] = context.tenant_id
+            return rate_limit_response
+
         try:
             response = await call_next(request)
         except Exception as exc:
@@ -80,6 +104,12 @@ class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
         context = get_gateway_request_context(request)
         response.headers[REQUEST_ID_HEADER] = request_id
         response.headers[TENANT_ID_HEADER] = context.tenant_id
+        apply_gateway_rate_limit_headers(
+            response=response,
+            request=request,
+            settings=settings,
+            limiter=request.app.state.rate_limiter,
+        )
         return response
 
 

+ 38 - 0
tests/test_rate_limit.py

@@ -0,0 +1,38 @@
+from core_shared.rate_limit import InMemoryFixedWindowRateLimiter
+
+
+def test_in_memory_fixed_window_rate_limiter_blocks_after_limit() -> None:
+    limiter = InMemoryFixedWindowRateLimiter()
+
+    first = limiter.check(
+        key="tenant:t1",
+        limit=2,
+        window_seconds=60,
+        now_epoch_seconds=120,
+    )
+    second = limiter.check(
+        key="tenant:t1",
+        limit=2,
+        window_seconds=60,
+        now_epoch_seconds=121,
+    )
+    third = limiter.check(
+        key="tenant:t1",
+        limit=2,
+        window_seconds=60,
+        now_epoch_seconds=122,
+    )
+    next_window = limiter.check(
+        key="tenant:t1",
+        limit=2,
+        window_seconds=60,
+        now_epoch_seconds=180,
+    )
+
+    assert first.allowed is True
+    assert first.remaining == 1
+    assert second.allowed is True
+    assert second.remaining == 0
+    assert third.allowed is False
+    assert third.reset_epoch_seconds == 180
+    assert next_window.allowed is True