|
@@ -1,10 +1,16 @@
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
|
|
+from datetime import datetime
|
|
|
from time import perf_counter
|
|
from time import perf_counter
|
|
|
from uuid import uuid4
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
from fastapi import Request, Response
|
|
from fastapi import Request, Response
|
|
|
|
|
+from starlette.responses import JSONResponse
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
|
|
|
|
|
|
|
|
+from app.bootstrap.settings import ApiGatewaySettings
|
|
|
|
|
+from app.domain.repositories import ApiKeyRepository
|
|
|
|
|
+from app.infrastructure.api_keys import hash_api_key
|
|
|
|
|
+
|
|
|
REQUEST_ID_HEADER = "x-request-id"
|
|
REQUEST_ID_HEADER = "x-request-id"
|
|
|
TENANT_ID_HEADER = "x-tenant-id"
|
|
TENANT_ID_HEADER = "x-tenant-id"
|
|
|
DEFAULT_TENANT_ID = "public"
|
|
DEFAULT_TENANT_ID = "public"
|
|
@@ -15,6 +21,7 @@ class GatewayRequestContext:
|
|
|
request_id: str
|
|
request_id: str
|
|
|
tenant_id: str
|
|
tenant_id: str
|
|
|
started_perf_counter: float
|
|
started_perf_counter: float
|
|
|
|
|
+ api_key_id: str | None = None
|
|
|
target_service: str | None = None
|
|
target_service: str | None = None
|
|
|
target_url: str | None = None
|
|
target_url: str | None = None
|
|
|
|
|
|
|
@@ -32,6 +39,21 @@ class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
|
|
|
tenant_id=tenant_id,
|
|
tenant_id=tenant_id,
|
|
|
started_perf_counter=perf_counter(),
|
|
started_perf_counter=perf_counter(),
|
|
|
)
|
|
)
|
|
|
|
|
+ auth_response = authenticate_gateway_request(request)
|
|
|
|
|
+ if auth_response is not None:
|
|
|
|
|
+ from app.infrastructure.audit import persist_gateway_audit
|
|
|
|
|
+
|
|
|
|
|
+ persist_gateway_audit(
|
|
|
|
|
+ request=request,
|
|
|
|
|
+ session_factory=request.app.state.session_factory,
|
|
|
|
|
+ status_code=auth_response.status_code,
|
|
|
|
|
+ error_message=None,
|
|
|
|
|
+ )
|
|
|
|
|
+ context = get_gateway_request_context(request)
|
|
|
|
|
+ auth_response.headers[REQUEST_ID_HEADER] = request_id
|
|
|
|
|
+ auth_response.headers[TENANT_ID_HEADER] = context.tenant_id
|
|
|
|
|
+ return auth_response
|
|
|
|
|
+
|
|
|
try:
|
|
try:
|
|
|
response = await call_next(request)
|
|
response = await call_next(request)
|
|
|
except Exception as exc:
|
|
except Exception as exc:
|
|
@@ -52,8 +74,9 @@ class GatewayRequestContextMiddleware(BaseHTTPMiddleware):
|
|
|
session_factory=request.app.state.session_factory,
|
|
session_factory=request.app.state.session_factory,
|
|
|
status_code=response.status_code,
|
|
status_code=response.status_code,
|
|
|
)
|
|
)
|
|
|
|
|
+ context = get_gateway_request_context(request)
|
|
|
response.headers[REQUEST_ID_HEADER] = request_id
|
|
response.headers[REQUEST_ID_HEADER] = request_id
|
|
|
- response.headers[TENANT_ID_HEADER] = tenant_id
|
|
|
|
|
|
|
+ response.headers[TENANT_ID_HEADER] = context.tenant_id
|
|
|
return response
|
|
return response
|
|
|
|
|
|
|
|
|
|
|
|
@@ -78,3 +101,63 @@ def get_gateway_request_context(request: Request) -> GatewayRequestContext:
|
|
|
tenant_id=DEFAULT_TENANT_ID,
|
|
tenant_id=DEFAULT_TENANT_ID,
|
|
|
started_perf_counter=perf_counter(),
|
|
started_perf_counter=perf_counter(),
|
|
|
)
|
|
)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def authenticate_gateway_request(request: Request) -> Response | None:
|
|
|
|
|
+ settings = ApiGatewaySettings()
|
|
|
|
|
+ if not settings.auth_required:
|
|
|
|
|
+ return None
|
|
|
|
|
+ if not request.url.path.startswith("/gateway/"):
|
|
|
|
|
+ return None
|
|
|
|
|
+ if request.url.path in {"/gateway/services/health"}:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ if is_initial_api_key_bootstrap_request(request):
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ api_key = request.headers.get(settings.api_key_header_name)
|
|
|
|
|
+ if not api_key:
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ status_code=401,
|
|
|
|
|
+ content={"detail": "missing api key"},
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ db = request.app.state.session_factory()
|
|
|
|
|
+ try:
|
|
|
|
|
+ entity = ApiKeyRepository(db).get_active_by_hash(key_hash=hash_api_key(api_key))
|
|
|
|
|
+ if entity is None:
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ status_code=401,
|
|
|
|
|
+ content={"detail": "invalid api key"},
|
|
|
|
|
+ )
|
|
|
|
|
+ if entity.expires_time is not None and entity.expires_time <= datetime.utcnow():
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ status_code=401,
|
|
|
|
|
+ content={"detail": "api key expired"},
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ context = get_gateway_request_context(request)
|
|
|
|
|
+ requested_tenant_id = resolve_tenant_id(request)
|
|
|
|
|
+ if requested_tenant_id not in {DEFAULT_TENANT_ID, entity.tenant_id}:
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ status_code=403,
|
|
|
|
|
+ content={"detail": "api key tenant mismatch"},
|
|
|
|
|
+ )
|
|
|
|
|
+ context.tenant_id = entity.tenant_id
|
|
|
|
|
+ context.api_key_id = entity.id
|
|
|
|
|
+ ApiKeyRepository(db).touch_last_used_time(api_key_id=entity.id)
|
|
|
|
|
+ finally:
|
|
|
|
|
+ db.close()
|
|
|
|
|
+
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def is_initial_api_key_bootstrap_request(request: Request) -> bool:
|
|
|
|
|
+ if request.method.upper() != "POST" or request.url.path != "/gateway/api-keys":
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ db = request.app.state.session_factory()
|
|
|
|
|
+ try:
|
|
|
|
|
+ return not ApiKeyRepository(db).has_any()
|
|
|
|
|
+ finally:
|
|
|
|
|
+ db.close()
|