Bladeren bron

feat: add gateway api key auth

Jax Docker 2 maanden geleden
bovenliggende
commit
be5d6ab70b

+ 59 - 0
README.md

@@ -460,6 +460,65 @@ Invoke-RestMethod `
   -Headers @{"x-tenant-id"="t1"}
 ```
 
+Gateway API Key auth:
+
+- `AGENT_PLATFORM_AUTH_REQUIRED=false` by default for local development.
+- Set `AGENT_PLATFORM_AUTH_REQUIRED=true` to protect `/gateway/**`, except `/gateway/services/health`.
+- API keys are stored as SHA-256 hashes. The raw key is returned only once, in the creation response, and cannot be retrieved afterwards.
+- When auth is enabled and no API key exists yet, the first `POST /gateway/api-keys` is accepted without an API key so that an initial key can be bootstrapped.
+- API keys can be `active`, `disabled`, or `revoked`; only `active` keys are accepted.
+
+Create an API key:
+
+```powershell
+$body = @{
+  tenant_id = "t1"
+  name = "local-dev"
+} | ConvertTo-Json
+
+$created = Invoke-RestMethod `
+  -Method Post `
+  -Uri "http://127.0.0.1:8000/gateway/api-keys" `
+  -ContentType "application/json" `
+  -Body $body
+
+$created.api_key
+```
+
+Use an API key:
+
+```powershell
+Invoke-RestMethod `
+  -Uri "http://127.0.0.1:8000/gateway/audits?tenant_id=t1" `
+  -Headers @{"x-tenant-id"="t1"; "x-api-key"=$created.api_key}
+```
+
+Disable or revoke an API key:
+
+```powershell
+$body = @{
+  tenant_id = "t1"
+  status = "revoked"
+} | ConvertTo-Json
+
+Invoke-RestMethod `
+  -Method Patch `
+  -Uri "http://127.0.0.1:8000/gateway/api-keys/$($created.id)/status" `
+  -ContentType "application/json" `
+  -Headers @{"x-tenant-id"="t1"; "x-api-key"=$created.api_key} `
+  -Body $body
+```
+
+Run the smoke test through an authenticated gateway:
+
+```powershell
+$env:AGENT_PLATFORM_SMOKE_WORKFLOW_URL="http://127.0.0.1:8000/gateway/workflows"
+$env:AGENT_PLATFORM_SMOKE_RUNTIME_URL="http://127.0.0.1:8000/gateway/runtime"
+$env:AGENT_PLATFORM_SMOKE_TENANT_ID="t1"
+$env:AGENT_PLATFORM_SMOKE_API_KEY=$created.api_key
+.\.venv\Scripts\python scripts\smoke_runtime_no_key.py
+```
+
 HTTP tool node config example:
 
 ```json

+ 1 - 0
deployments/docker/.env.example

@@ -2,3 +2,4 @@ AGENT_PLATFORM_PROVIDER_BASE_URL=https://api.openai.com/v1
 AGENT_PLATFORM_PROVIDER_API_KEY=replace-me
 AGENT_PLATFORM_DEFAULT_MODEL=gpt-4o-mini
 AGENT_PLATFORM_MAX_TIMEOUT_SECONDS=30
+AGENT_PLATFORM_AUTH_REQUIRED=false

+ 1 - 0
deployments/docker/docker-compose.yml

@@ -116,6 +116,7 @@ services:
       AGENT_PLATFORM_TOOL_SERVICE_URL: http://tool-service:8004
       AGENT_PLATFORM_MODEL_GATEWAY_SERVICE_URL: http://model-gateway-service:8005
       AGENT_PLATFORM_CODE_RUNNER_SERVICE_URL: http://code-runner-service:8006
+      AGENT_PLATFORM_AUTH_REQUIRED: ${AGENT_PLATFORM_AUTH_REQUIRED:-false}
     ports:
       - "8003:8003"
     volumes:

+ 6 - 1
scripts/smoke_runtime_no_key.py

@@ -18,6 +18,7 @@ RUNTIME_SERVICE_URL = os.getenv(
     "http://127.0.0.1:8003/runtime",
 )
 TENANT_ID = os.getenv("AGENT_PLATFORM_SMOKE_TENANT_ID", "t-smoke")
+SMOKE_API_KEY = os.getenv("AGENT_PLATFORM_SMOKE_API_KEY")
 
 
 @dataclass(frozen=True)
@@ -43,7 +44,11 @@ SCENARIOS = (
 
 def main() -> int:
     unique_suffix = uuid.uuid4().hex[:8]
-    with httpx.Client(timeout=20.0, headers={"x-tenant-id": TENANT_ID}) as client:
+    headers = {"x-tenant-id": TENANT_ID}
+    if SMOKE_API_KEY:
+        headers["x-api-key"] = SMOKE_API_KEY
+
+    with httpx.Client(timeout=20.0, headers=headers) as client:
         app_id = create_app(client, unique_suffix)
         workflow_id = create_workflow(client, app_id, unique_suffix)
 

+ 51 - 0
services/api-gateway/alembic/versions/20260423_0002_add_api_keys.py

@@ -0,0 +1,51 @@
+"""add api keys
+
+Revision ID: 20260423_0002
+Revises: 20260423_0001
+Create Date: 2026-04-23 20:00:00
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# Alembic revision identifiers: this revision and the one it revises.
+revision: str = "20260423_0002"
+down_revision: str | None = "20260423_0001"
+branch_labels: Sequence[str] | None = None
+depends_on: Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    """Create the ``api_key`` table and its four lookup indexes."""
+    op.create_table(
+        "api_key",
+        sa.Column("name", sa.String(length=128), nullable=False),
+        sa.Column("key_prefix", sa.String(length=16), nullable=False),
+        sa.Column("key_hash", sa.String(length=128), nullable=False),
+        sa.Column("status", sa.String(length=32), nullable=False),
+        sa.Column("scopes", sa.Text(), nullable=True),
+        sa.Column("expires_time", sa.DateTime(), nullable=True),
+        sa.Column("last_used_time", sa.DateTime(), nullable=True),
+        sa.Column("id", sa.String(length=36), nullable=False),
+        sa.Column("tenant_id", sa.String(length=36), nullable=False),
+        sa.Column("created_by", sa.String(length=36), nullable=True),
+        sa.Column("updated_by", sa.String(length=36), nullable=True),
+        sa.Column("created_time", sa.DateTime(), nullable=False),
+        sa.Column("updated_time", sa.DateTime(), nullable=False),
+        sa.Column("deleted_time", sa.DateTime(), nullable=True),
+        sa.Column("version", sa.Integer(), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    # Only the key_hash index is unique; prefix, status and tenant indexes are not.
+    op.create_index("ix_api_key_key_prefix", "api_key", ["key_prefix"], unique=False)
+    op.create_index("ix_api_key_key_hash", "api_key", ["key_hash"], unique=True)
+    op.create_index("ix_api_key_status", "api_key", ["status"], unique=False)
+    op.create_index("ix_api_key_tenant_id", "api_key", ["tenant_id"], unique=False)
+
+
+def downgrade() -> None:
+    """Drop the four ``api_key`` indexes and then the ``api_key`` table itself."""
+    op.drop_index("ix_api_key_tenant_id", table_name="api_key")
+    op.drop_index("ix_api_key_status", table_name="api_key")
+    op.drop_index("ix_api_key_key_hash", table_name="api_key")
+    op.drop_index("ix_api_key_key_prefix", table_name="api_key")
+    op.drop_table("api_key")

+ 65 - 3
services/api-gateway/app/api/routes.py

@@ -1,15 +1,23 @@
 import asyncio
 
-from fastapi import APIRouter, Depends, Query, Request, Response
+from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
 from sqlalchemy import text
 from sqlalchemy.orm import Session
 
 from core_domain import ServiceDescriptor, ServiceHealth
 from app.bootstrap.settings import ApiGatewaySettings
 from app.db.session import get_db
-from app.domain.repositories import GatewayRequestAuditRepository
+from app.domain.repositories import ApiKeyRepository, GatewayRequestAuditRepository
+from app.infrastructure.api_keys import generate_api_key, get_api_key_prefix, hash_api_key
 from app.infrastructure.proxy import ProxyServiceName, ProxyTarget, ServiceProxy
-from app.schemas.gateway import GatewayRequestAuditResponse, GatewayServicesHealthResponse
+from app.schemas.gateway import (
+    ApiKeyCreateRequest,
+    ApiKeyCreateResponse,
+    ApiKeyResponse,
+    ApiKeyStatusUpdateRequest,
+    GatewayRequestAuditResponse,
+    GatewayServicesHealthResponse,
+)
 
 router = APIRouter()
 
@@ -26,6 +34,60 @@ def readiness_check(db: Session = Depends(get_db)) -> ServiceHealth:
     return ServiceHealth(service="api-gateway", status="ok", database="ok")
 
 
+@router.post("/gateway/api-keys", response_model=ApiKeyCreateResponse)
+def create_api_key(
+    payload: ApiKeyCreateRequest,
+    db: Session = Depends(get_db),
+    *,
+    request: Request,
+) -> ApiKeyCreateResponse:
+    """Create an API key for ``payload.tenant_id`` and return the raw key once.
+
+    Only a hash and a short prefix of the generated key are persisted, so the
+    ``api_key`` value in this response cannot be read back later.
+
+    Args:
+        payload: Tenant, display name, optional scopes and optional expiry.
+        db: Database session supplied by ``get_db``.
+        request: Incoming request, used to read the gateway request context.
+
+    Raises:
+        HTTPException: 403 when the request was authenticated with an API key
+            whose tenant differs from ``payload.tenant_id``.
+    """
+    # Local import: request_context is not among this module's top-level imports.
+    from app.infrastructure.request_context import get_gateway_request_context
+
+    context = get_gateway_request_context(request)
+    # api_key_id is only populated after a key has been authenticated; with auth
+    # disabled or during the initial bootstrap it is None and no check applies.
+    if context.api_key_id is not None and payload.tenant_id != context.tenant_id:
+        raise HTTPException(status_code=403, detail="api key tenant mismatch")
+
+    api_key = generate_api_key()
+    entity = ApiKeyRepository(db).create(
+        tenant_id=payload.tenant_id,
+        name=payload.name,
+        key_prefix=get_api_key_prefix(api_key),
+        key_hash=hash_api_key(api_key),
+        scopes=payload.scopes,
+        expires_time=payload.expires_time,
+    )
+    return ApiKeyCreateResponse(
+        id=entity.id,
+        tenant_id=entity.tenant_id,
+        name=entity.name,
+        key_prefix=entity.key_prefix,
+        api_key=api_key,
+        status=entity.status,
+        scopes=entity.scopes,
+        expires_time=entity.expires_time,
+        created_time=entity.created_time,
+    )
+
+
+@router.get("/gateway/api-keys", response_model=list[ApiKeyResponse])
+def list_api_keys(
+    tenant_id: str = Query(...),
+    db: Session = Depends(get_db),
+    *,
+    request: Request,
+) -> list[ApiKeyResponse]:
+    """List the API keys of one tenant, newest first, without raw key material.
+
+    Args:
+        tenant_id: Tenant whose keys are listed (required query parameter).
+        db: Database session supplied by ``get_db``.
+        request: Incoming request, used to read the gateway request context.
+
+    Raises:
+        HTTPException: 403 when the request was authenticated with an API key
+            whose tenant differs from ``tenant_id``.
+    """
+    # Local import: request_context is not among this module's top-level imports.
+    from app.infrastructure.request_context import get_gateway_request_context
+
+    context = get_gateway_request_context(request)
+    # api_key_id is only populated after a key has been authenticated; with auth
+    # disabled it is None and the listing stays unrestricted as before.
+    if context.api_key_id is not None and tenant_id != context.tenant_id:
+        raise HTTPException(status_code=403, detail="api key tenant mismatch")
+
+    return [
+        ApiKeyResponse.from_entity(item)
+        for item in ApiKeyRepository(db).list_by_tenant(tenant_id=tenant_id)
+    ]
+
+
+@router.patch("/gateway/api-keys/{api_key_id}/status", response_model=ApiKeyResponse)
+def update_api_key_status(
+    api_key_id: str,
+    payload: ApiKeyStatusUpdateRequest,
+    db: Session = Depends(get_db),
+    *,
+    request: Request,
+) -> ApiKeyResponse:
+    """Set the status of one API key belonging to ``payload.tenant_id``.
+
+    Args:
+        api_key_id: Identifier of the key to update (path parameter).
+        payload: Owning tenant and the new status.
+        db: Database session supplied by ``get_db``.
+        request: Incoming request, used to read the gateway request context.
+
+    Raises:
+        HTTPException: 403 when the request was authenticated with an API key
+            whose tenant differs from ``payload.tenant_id``; 404 when no key
+            with this id exists for that tenant.
+    """
+    # Local import: request_context is not among this module's top-level imports.
+    from app.infrastructure.request_context import get_gateway_request_context
+
+    context = get_gateway_request_context(request)
+    # api_key_id is only populated after a key has been authenticated; with auth
+    # disabled it is None and no tenant check applies.
+    if context.api_key_id is not None and payload.tenant_id != context.tenant_id:
+        raise HTTPException(status_code=403, detail="api key tenant mismatch")
+
+    entity = ApiKeyRepository(db).update_status(
+        tenant_id=payload.tenant_id,
+        api_key_id=api_key_id,
+        status=payload.status,
+    )
+    if entity is None:
+        raise HTTPException(status_code=404, detail=f"api key not found: {api_key_id}")
+    return ApiKeyResponse.from_entity(entity)
+
+
 @router.get("/gateway/audits", response_model=list[GatewayRequestAuditResponse])
 def list_gateway_audits(
     tenant_id: str = Query(...),

+ 2 - 0
services/api-gateway/app/bootstrap/settings.py

@@ -13,3 +13,5 @@ class ApiGatewaySettings(ServiceSettings):
     code_runner_service_url: str = "http://127.0.0.1:8006"
     proxy_timeout_seconds: float = 30.0
     downstream_health_timeout_seconds: float = 2.0
+    auth_required: bool = False
+    api_key_header_name: str = "x-api-key"

+ 2 - 1
services/api-gateway/app/db/models/__init__.py

@@ -1,5 +1,6 @@
 from core_db import Base
 
+from .api_key import ApiKey
 from .gateway_request_audit import GatewayRequestAudit
 
-__all__ = ["Base", "GatewayRequestAudit"]
+__all__ = ["ApiKey", "Base", "GatewayRequestAudit"]

+ 18 - 0
services/api-gateway/app/db/models/api_key.py

@@ -0,0 +1,18 @@
+from datetime import datetime
+
+from sqlalchemy import DateTime, String, Text
+from sqlalchemy.orm import Mapped, mapped_column
+
+from core_db import AuditMixin, Base, TenantMixin, VersionMixin
+
+
+class ApiKey(TenantMixin, AuditMixin, VersionMixin, Base):
+    """ORM model for the ``api_key`` table."""
+
+    __tablename__ = "api_key"
+
+    # Display name, up to 128 characters.
+    name: Mapped[str] = mapped_column(String(128))
+    # Leading slice of the raw key; indexed but not unique.
+    key_prefix: Mapped[str] = mapped_column(String(16), index=True)
+    # Hash of the raw key; unique and indexed.
+    key_hash: Mapped[str] = mapped_column(String(128), unique=True, index=True)
+    # Defaults to "active" when not set explicitly.
+    status: Mapped[str] = mapped_column(String(32), default="active", index=True)
+    # Optional free-form text; no format is enforced by this model.
+    scopes: Mapped[str | None] = mapped_column(Text, nullable=True)
+    # Optional expiry timestamp; NULL when unset.
+    expires_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    # Optional timestamp of last use; NULL until set.
+    last_used_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

+ 84 - 1
services/api-gateway/app/domain/repositories.py

@@ -1,7 +1,9 @@
+from datetime import datetime
+
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
-from app.db.models import GatewayRequestAudit
+from app.db.models import ApiKey, GatewayRequestAudit
 
 
 class GatewayRequestAuditRepository:
@@ -58,3 +60,84 @@ class GatewayRequestAuditRepository:
             stmt = stmt.where(GatewayRequestAudit.target_service == target_service)
         stmt = stmt.order_by(GatewayRequestAudit.created_time.desc()).limit(limit)
         return list(self.db.scalars(stmt))
+
+
+class ApiKeyRepository:
+    """Query and persistence helpers for ``ApiKey`` rows on a SQLAlchemy session."""
+
+    def __init__(self, db: Session) -> None:
+        """Bind the repository to the session used for every operation."""
+        self.db = db
+
+    def create(
+        self,
+        *,
+        tenant_id: str,
+        name: str,
+        key_prefix: str,
+        key_hash: str,
+        scopes: str | None,
+        expires_time: datetime | None,
+    ) -> ApiKey:
+        """Insert a new key with status ``"active"``, commit, and return it refreshed."""
+        entity = ApiKey(
+            tenant_id=tenant_id,
+            name=name,
+            key_prefix=key_prefix,
+            key_hash=key_hash,
+            status="active",
+            scopes=scopes,
+            expires_time=expires_time,
+        )
+        self.db.add(entity)
+        self.db.commit()
+        self.db.refresh(entity)
+        return entity
+
+    def list_by_tenant(self, *, tenant_id: str) -> list[ApiKey]:
+        """Return all keys of ``tenant_id`` ordered by ``created_time`` descending."""
+        stmt = (
+            select(ApiKey)
+            .where(ApiKey.tenant_id == tenant_id)
+            .order_by(ApiKey.created_time.desc())
+        )
+        return list(self.db.scalars(stmt))
+
+    def has_any(self) -> bool:
+        """Return ``True`` when at least one key row exists, for any tenant or status."""
+        stmt = select(ApiKey.id).limit(1)
+        return self.db.scalar(stmt) is not None
+
+    def get_by_id(self, *, tenant_id: str, api_key_id: str) -> ApiKey | None:
+        """Return the key with ``api_key_id`` owned by ``tenant_id``, or ``None``."""
+        stmt = (
+            select(ApiKey)
+            .where(ApiKey.tenant_id == tenant_id)
+            .where(ApiKey.id == api_key_id)
+            .limit(1)
+        )
+        return self.db.scalar(stmt)
+
+    def get_active_by_hash(self, *, key_hash: str) -> ApiKey | None:
+        """Return the key matching ``key_hash`` whose status is ``"active"``, or ``None``.
+
+        Expiry is not checked here; callers compare ``expires_time`` themselves.
+        """
+        stmt = (
+            select(ApiKey)
+            .where(ApiKey.key_hash == key_hash)
+            .where(ApiKey.status == "active")
+            .limit(1)
+        )
+        return self.db.scalar(stmt)
+
+    def touch_last_used_time(self, *, api_key_id: str) -> None:
+        """Set ``last_used_time`` of the given key to the current UTC time and commit.
+
+        Does nothing when no key with ``api_key_id`` exists.
+        """
+        # Local import keeps this block independent of the module's import list.
+        from datetime import timezone
+
+        entity = self.db.get(ApiKey, api_key_id)
+        if entity is None:
+            return
+        # datetime.utcnow() is deprecated since Python 3.12. The column is a
+        # plain DateTime, so the value stays naive UTC by dropping tzinfo.
+        entity.last_used_time = datetime.now(timezone.utc).replace(tzinfo=None)
+        self.db.commit()
+
+    def update_status(
+        self,
+        *,
+        tenant_id: str,
+        api_key_id: str,
+        status: str,
+    ) -> ApiKey | None:
+        """Set ``status`` on the tenant's key, commit, and return it refreshed.
+
+        Returns ``None`` when the tenant has no key with ``api_key_id``.
+        """
+        entity = self.get_by_id(tenant_id=tenant_id, api_key_id=api_key_id)
+        if entity is None:
+            return None
+        entity.status = status
+        self.db.commit()
+        self.db.refresh(entity)
+        return entity

+ 19 - 0
services/api-gateway/app/infrastructure/api_keys.py

@@ -0,0 +1,19 @@
+import hashlib
+import secrets
+
+
+# Fixed leading tag of every generated key, joined to the random part with "_".
+API_KEY_PREFIX = "agp"
+
+
+def generate_api_key() -> str:
+    """Return a new random key: ``API_KEY_PREFIX``, an underscore, then a URL-safe token."""
+    random_part = secrets.token_urlsafe(32)
+    return "_".join((API_KEY_PREFIX, random_part))
+
+
+def hash_api_key(api_key: str) -> str:
+    """Return the hex-encoded SHA-256 digest of the UTF-8 encoded ``api_key``."""
+    digest = hashlib.sha256()
+    digest.update(api_key.encode("utf-8"))
+    return digest.hexdigest()
+
+
+def get_api_key_prefix(api_key: str) -> str:
+    """Return the first 12 characters of ``api_key``, or all of it when shorter."""
+    # Slicing already returns the whole string when it has 12 characters or fewer.
+    return api_key[:12]

+ 84 - 1
services/api-gateway/app/infrastructure/request_context.py

@@ -1,10 +1,16 @@
 from dataclasses import dataclass
+from datetime import datetime
 from time import perf_counter
 from uuid import uuid4
 
 from fastapi import Request, Response
+from starlette.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 
+from app.bootstrap.settings import ApiGatewaySettings
+from app.domain.repositories import ApiKeyRepository
+from app.infrastructure.api_keys import hash_api_key
+
 REQUEST_ID_HEADER = "x-request-id"
 TENANT_ID_HEADER = "x-tenant-id"
 DEFAULT_TENANT_ID = "public"
@@ -15,6 +21,7 @@ class GatewayRequestContext:
     request_id: str
     tenant_id: str
     started_perf_counter: float
+    api_key_id: str | None = None
     target_service: str | None = None
     target_url: str | None = None
 
@@ -32,6 +39,21 @@ class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
             tenant_id=tenant_id,
             started_perf_counter=perf_counter(),
         )
+        auth_response = authenticate_gateway_request(request)
+        if auth_response is not None:
+            from app.infrastructure.audit import persist_gateway_audit
+
+            persist_gateway_audit(
+                request=request,
+                session_factory=request.app.state.session_factory,
+                status_code=auth_response.status_code,
+                error_message=None,
+            )
+            context = get_gateway_request_context(request)
+            auth_response.headers[REQUEST_ID_HEADER] = request_id
+            auth_response.headers[TENANT_ID_HEADER] = context.tenant_id
+            return auth_response
+
         try:
             response = await call_next(request)
         except Exception as exc:
@@ -52,8 +74,9 @@ class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
             session_factory=request.app.state.session_factory,
             status_code=response.status_code,
         )
+        context = get_gateway_request_context(request)
         response.headers[REQUEST_ID_HEADER] = request_id
-        response.headers[TENANT_ID_HEADER] = tenant_id
+        response.headers[TENANT_ID_HEADER] = context.tenant_id
         return response
 
 
@@ -78,3 +101,63 @@ def get_gateway_request_context(request: Request) -> GatewayRequestContext:
         tenant_id=DEFAULT_TENANT_ID,
         started_perf_counter=perf_counter(),
     )
+
+
+def authenticate_gateway_request(request: Request) -> Response | None:
+    """Check the API key of a gateway request.
+
+    Returns ``None`` when the request may proceed: auth is disabled, the path is
+    outside ``/gateway/``, the path is ``/gateway/services/health``, the request
+    is the initial key-creation bootstrap, or a valid key was presented. For a
+    valid key the request context's ``tenant_id`` and ``api_key_id`` are set
+    from the stored key and its last-used time is updated.
+
+    Otherwise returns a JSON error response: 401 for a missing, unknown or
+    non-active, or expired key; 403 when the requested tenant is neither the
+    default tenant nor the key's tenant.
+    """
+    # Local import keeps this block independent of the module's import list.
+    from datetime import timezone
+
+    settings = ApiGatewaySettings()
+    if not settings.auth_required:
+        return None
+    if not request.url.path.startswith("/gateway/"):
+        return None
+    if request.url.path in {"/gateway/services/health"}:
+        return None
+
+    if is_initial_api_key_bootstrap_request(request):
+        return None
+
+    api_key = request.headers.get(settings.api_key_header_name)
+    if not api_key:
+        return JSONResponse(
+            status_code=401,
+            content={"detail": "missing api key"},
+        )
+
+    db = request.app.state.session_factory()
+    try:
+        repository = ApiKeyRepository(db)
+        entity = repository.get_active_by_hash(key_hash=hash_api_key(api_key))
+        if entity is None:
+            return JSONResponse(
+                status_code=401,
+                content={"detail": "invalid api key"},
+            )
+        # datetime.utcnow() is deprecated since Python 3.12; build the same
+        # naive-UTC value explicitly so it compares against the stored expiry.
+        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
+        if entity.expires_time is not None and entity.expires_time <= now_utc:
+            return JSONResponse(
+                status_code=401,
+                content={"detail": "api key expired"},
+            )
+
+        context = get_gateway_request_context(request)
+        requested_tenant_id = resolve_tenant_id(request)
+        if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "api key tenant mismatch"},
+            )
+        # The key's tenant wins over whatever tenant the request asked for.
+        context.tenant_id = entity.tenant_id
+        context.api_key_id = entity.id
+        repository.touch_last_used_time(api_key_id=entity.id)
+    finally:
+        db.close()
+
+    return None
+
+
+def is_initial_api_key_bootstrap_request(request: Request) -> bool:
+    """Return ``True`` for ``POST /gateway/api-keys`` while no API key row exists."""
+    is_create_call = (
+        request.method.upper() == "POST" and request.url.path == "/gateway/api-keys"
+    )
+    if not is_create_call:
+        return False
+
+    session = request.app.state.session_factory()
+    try:
+        any_key_exists = ApiKeyRepository(session).has_any()
+    finally:
+        session.close()
+    return not any_key_exists

+ 45 - 1
services/api-gateway/app/schemas/gateway.py

@@ -1,10 +1,11 @@
 from pydantic import BaseModel
 from datetime import datetime
+from typing import Literal
 
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from app.db.models import GatewayRequestAudit
+    from app.db.models import ApiKey, GatewayRequestAudit
 
 
 class DownstreamServiceHealth(BaseModel):
@@ -40,3 +41,46 @@ class GatewayRequestAuditResponse(BaseModel):
     @classmethod
     def from_entity(cls, entity: "GatewayRequestAudit") -> "GatewayRequestAuditResponse":
         return cls.model_validate(entity, from_attributes=True)
+
+
+# Request body for creating an API key. Comments are used instead of a class
+# docstring so the generated schema description stays unchanged.
+class ApiKeyCreateRequest(BaseModel):
+    # Tenant that will own the new key.
+    tenant_id: str
+    # Display name of the key.
+    name: str
+    # Optional free-form scopes text.
+    scopes: str | None = None
+    # Optional expiry timestamp; omitted means no expiry is stored.
+    expires_time: datetime | None = None
+
+
+# Response body for key creation; the only schema here that carries the raw key.
+class ApiKeyCreateResponse(BaseModel):
+    id: str
+    tenant_id: str
+    name: str
+    key_prefix: str
+    # Raw key value; ApiKeyResponse has no such field.
+    api_key: str
+    status: str
+    scopes: str | None = None
+    expires_time: datetime | None = None
+    created_time: datetime
+
+
+# Key metadata without raw key material or hash.
+class ApiKeyResponse(BaseModel):
+    id: str
+    tenant_id: str
+    name: str
+    key_prefix: str
+    status: str
+    scopes: str | None = None
+    expires_time: datetime | None = None
+    last_used_time: datetime | None = None
+    created_time: datetime
+
+    @classmethod
+    def from_entity(cls, entity: "ApiKey") -> "ApiKeyResponse":
+        """Build the response by reading attributes off an ``ApiKey`` entity."""
+        return cls.model_validate(entity, from_attributes=True)
+
+
+# The closed set of status values accepted by the status-update request.
+ApiKeyStatus = Literal["active", "disabled", "revoked"]
+
+
+# Request body for changing the status of an existing key.
+class ApiKeyStatusUpdateRequest(BaseModel):
+    # Tenant that owns the key being updated.
+    tenant_id: str
+    # New status; validated against the ApiKeyStatus literal values.
+    status: ApiKeyStatus