from datetime import datetime, timezone

from sqlalchemy import case, func, select
from sqlalchemy.orm import Session

from app.db.models import ApiKey, GatewayRequestAudit


class GatewayRequestAuditRepository:
    """Data-access helpers for ``GatewayRequestAudit`` rows.

    Every write commits the session immediately.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        request_id: str,
        method: str,
        path: str,
        query_string: str | None,
        target_service: str | None,
        target_url: str | None,
        status_code: int | None,
        duration_ms: int,
        client_host: str | None,
        user_agent: str | None,
        error_message: str | None,
    ) -> GatewayRequestAudit:
        """Insert one audit row, commit, and return it refreshed from the DB."""
        entity = GatewayRequestAudit(
            tenant_id=tenant_id,
            request_id=request_id,
            method=method,
            path=path,
            query_string=query_string,
            target_service=target_service,
            target_url=target_url,
            status_code=status_code,
            duration_ms=duration_ms,
            client_host=client_host,
            user_agent=user_agent,
            error_message=error_message,
        )
        self.db.add(entity)
        self.db.commit()
        # Reload server-populated columns (e.g. id / created_time) after commit.
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        request_id: str | None = None,
        target_service: str | None = None,
        limit: int = 100,
    ) -> list[GatewayRequestAudit]:
        """Return a tenant's audit rows, newest first.

        Args:
            tenant_id: Tenant whose rows are returned (always applied).
            request_id: If given, only rows with this exact request id.
            target_service: If given, only rows with this exact target service.
            limit: Maximum number of rows returned.
        """
        stmt = select(GatewayRequestAudit).where(
            GatewayRequestAudit.tenant_id == tenant_id
        )
        if request_id is not None:
            stmt = stmt.where(GatewayRequestAudit.request_id == request_id)
        if target_service is not None:
            stmt = stmt.where(GatewayRequestAudit.target_service == target_service)
        stmt = stmt.order_by(GatewayRequestAudit.created_time.desc()).limit(limit)
        return list(self.db.scalars(stmt))

    def stats_by_service(self, *, tenant_id: str) -> list[tuple[str, int, int, float]]:
        """Aggregate a tenant's audit rows per target service.

        Returns:
            One ``(target_service, request_count, error_count, avg_duration_ms)``
            tuple per service, ordered by service name ascending. Rows with a
            NULL ``target_service`` are grouped under ``"api-gateway"``.
        """
        # NULL target_service is bucketed under the gateway itself so those
        # requests still show up in the stats.
        target_service = func.coalesce(GatewayRequestAudit.target_service, "api-gateway")
        # Rows with status_code >= 400 count as errors; NULL status_code
        # evaluates the condition as unknown and falls to else_=0.
        error_count = func.sum(
            case(
                (GatewayRequestAudit.status_code >= 400, 1),
                else_=0,
            )
        )
        stmt = (
            select(
                target_service.label("target_service"),
                func.count(GatewayRequestAudit.id),
                error_count,
                func.avg(GatewayRequestAudit.duration_ms),
            )
            .where(GatewayRequestAudit.tenant_id == tenant_id)
            .group_by(target_service)
            .order_by(target_service.asc())
        )
        rows = self.db.execute(stmt).all()
        # Normalise driver-specific numeric types (e.g. Decimal from AVG) and
        # guard against NULL aggregates.
        return [
            (
                str(row[0]),
                int(row[1] or 0),
                int(row[2] or 0),
                float(row[3] or 0.0),
            )
            for row in rows
        ]


class ApiKeyRepository:
    """Data-access helpers for ``ApiKey`` rows.

    Every write commits the session immediately.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        name: str,
        key_prefix: str,
        key_hash: str,
        scopes: str | None,
        expires_time: datetime | None,
    ) -> ApiKey:
        """Insert a new key with status ``"active"``, commit, and return it."""
        entity = ApiKey(
            tenant_id=tenant_id,
            name=name,
            key_prefix=key_prefix,
            key_hash=key_hash,
            status="active",
            scopes=scopes,
            expires_time=expires_time,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_tenant(self, *, tenant_id: str) -> list[ApiKey]:
        """Return all keys of a tenant, newest first."""
        stmt = (
            select(ApiKey)
            .where(ApiKey.tenant_id == tenant_id)
            .order_by(ApiKey.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def has_any(self) -> bool:
        """Return True if at least one key exists for any tenant."""
        stmt = select(ApiKey.id).limit(1)
        return self.db.scalar(stmt) is not None

    def get_by_id(self, *, tenant_id: str, api_key_id: str) -> ApiKey | None:
        """Return the key with this id if it belongs to ``tenant_id``, else None."""
        stmt = (
            select(ApiKey)
            .where(ApiKey.tenant_id == tenant_id)
            .where(ApiKey.id == api_key_id)
            .limit(1)
        )
        return self.db.scalar(stmt)

    def get_active_by_hash(self, *, key_hash: str) -> ApiKey | None:
        """Return a key with this hash whose status is ``"active"``, else None.

        NOTE(review): ``expires_time`` is not checked here; presumably the
        caller rejects expired keys — confirm.
        """
        stmt = (
            select(ApiKey)
            .where(ApiKey.key_hash == key_hash)
            .where(ApiKey.status == "active")
            .limit(1)
        )
        return self.db.scalar(stmt)

    def touch_last_used_time(self, *, api_key_id: str) -> None:
        """Set ``last_used_time`` to the current UTC time and commit.

        Silently does nothing if no key with ``api_key_id`` exists. The lookup
        is by primary key only (not tenant-scoped).
        """
        entity = self.db.get(ApiKey, api_key_id)
        if entity is None:
            return
        # datetime.utcnow() is deprecated since Python 3.12. This yields the
        # same naive UTC value, so stored timestamps stay comparable.
        entity.last_used_time = datetime.now(timezone.utc).replace(tzinfo=None)
        self.db.commit()

    def update_status(
        self,
        *,
        tenant_id: str,
        api_key_id: str,
        status: str,
    ) -> ApiKey | None:
        """Set a tenant's key status, commit, and return the refreshed key.

        Returns None if the key does not exist or belongs to another tenant.
        """
        entity = self.get_by_id(tenant_id=tenant_id, api_key_id=api_key_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity