# repositories.py (5.0 KB)
from datetime import datetime, timezone

from sqlalchemy import case, func, select
from sqlalchemy.orm import Session

from app.db.models import ApiKey, GatewayRequestAudit
  5. class GatewayRequestAuditRepository:
  6. def __init__(self, db: Session) -> None:
  7. self.db = db
  8. def create(
  9. self,
  10. *,
  11. tenant_id: str,
  12. request_id: str,
  13. method: str,
  14. path: str,
  15. query_string: str | None,
  16. target_service: str | None,
  17. target_url: str | None,
  18. status_code: int | None,
  19. duration_ms: int,
  20. client_host: str | None,
  21. user_agent: str | None,
  22. error_message: str | None,
  23. ) -> GatewayRequestAudit:
  24. entity = GatewayRequestAudit(
  25. tenant_id=tenant_id,
  26. request_id=request_id,
  27. method=method,
  28. path=path,
  29. query_string=query_string,
  30. target_service=target_service,
  31. target_url=target_url,
  32. status_code=status_code,
  33. duration_ms=duration_ms,
  34. client_host=client_host,
  35. user_agent=user_agent,
  36. error_message=error_message,
  37. )
  38. self.db.add(entity)
  39. self.db.commit()
  40. self.db.refresh(entity)
  41. return entity
  42. def list_by_scope(
  43. self,
  44. *,
  45. tenant_id: str,
  46. request_id: str | None = None,
  47. target_service: str | None = None,
  48. limit: int = 100,
  49. ) -> list[GatewayRequestAudit]:
  50. stmt = select(GatewayRequestAudit).where(GatewayRequestAudit.tenant_id == tenant_id)
  51. if request_id is not None:
  52. stmt = stmt.where(GatewayRequestAudit.request_id == request_id)
  53. if target_service is not None:
  54. stmt = stmt.where(GatewayRequestAudit.target_service == target_service)
  55. stmt = stmt.order_by(GatewayRequestAudit.created_time.desc()).limit(limit)
  56. return list(self.db.scalars(stmt))
  57. def stats_by_service(self, *, tenant_id: str) -> list[tuple[str, int, int, float]]:
  58. target_service = func.coalesce(GatewayRequestAudit.target_service, "api-gateway")
  59. error_count = func.sum(
  60. case(
  61. (GatewayRequestAudit.status_code >= 400, 1),
  62. else_=0,
  63. )
  64. )
  65. stmt = (
  66. select(
  67. target_service.label("target_service"),
  68. func.count(GatewayRequestAudit.id),
  69. error_count,
  70. func.avg(GatewayRequestAudit.duration_ms),
  71. )
  72. .where(GatewayRequestAudit.tenant_id == tenant_id)
  73. .group_by(target_service)
  74. .order_by(target_service.asc())
  75. )
  76. rows = self.db.execute(stmt).all()
  77. return [
  78. (
  79. str(row[0]),
  80. int(row[1] or 0),
  81. int(row[2] or 0),
  82. float(row[3] or 0.0),
  83. )
  84. for row in rows
  85. ]
  86. class ApiKeyRepository:
  87. def __init__(self, db: Session) -> None:
  88. self.db = db
  89. def create(
  90. self,
  91. *,
  92. tenant_id: str,
  93. name: str,
  94. key_prefix: str,
  95. key_hash: str,
  96. scopes: str | None,
  97. expires_time: datetime | None,
  98. ) -> ApiKey:
  99. entity = ApiKey(
  100. tenant_id=tenant_id,
  101. name=name,
  102. key_prefix=key_prefix,
  103. key_hash=key_hash,
  104. status="active",
  105. scopes=scopes,
  106. expires_time=expires_time,
  107. )
  108. self.db.add(entity)
  109. self.db.commit()
  110. self.db.refresh(entity)
  111. return entity
  112. def list_by_tenant(self, *, tenant_id: str) -> list[ApiKey]:
  113. stmt = (
  114. select(ApiKey)
  115. .where(ApiKey.tenant_id == tenant_id)
  116. .order_by(ApiKey.created_time.desc())
  117. )
  118. return list(self.db.scalars(stmt))
  119. def has_any(self) -> bool:
  120. stmt = select(ApiKey.id).limit(1)
  121. return self.db.scalar(stmt) is not None
  122. def get_by_id(self, *, tenant_id: str, api_key_id: str) -> ApiKey | None:
  123. stmt = (
  124. select(ApiKey)
  125. .where(ApiKey.tenant_id == tenant_id)
  126. .where(ApiKey.id == api_key_id)
  127. .limit(1)
  128. )
  129. return self.db.scalar(stmt)
  130. def get_active_by_hash(self, *, key_hash: str) -> ApiKey | None:
  131. stmt = (
  132. select(ApiKey)
  133. .where(ApiKey.key_hash == key_hash)
  134. .where(ApiKey.status == "active")
  135. .limit(1)
  136. )
  137. return self.db.scalar(stmt)
  138. def touch_last_used_time(self, *, api_key_id: str) -> None:
  139. entity = self.db.get(ApiKey, api_key_id)
  140. if entity is None:
  141. return
  142. entity.last_used_time = datetime.utcnow()
  143. self.db.commit()
  144. def update_status(
  145. self,
  146. *,
  147. tenant_id: str,
  148. api_key_id: str,
  149. status: str,
  150. ) -> ApiKey | None:
  151. entity = self.get_by_id(tenant_id=tenant_id, api_key_id=api_key_id)
  152. if entity is None:
  153. return None
  154. entity.status = status
  155. self.db.commit()
  156. self.db.refresh(entity)
  157. return entity