# repositories.py
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import ModelDefinition, ModelProviderDefinition
  5. class ModelDefinitionRepository:
  6. def __init__(self, db: Session) -> None:
  7. self.db = db
  8. def create(self, entity: ModelDefinition) -> ModelDefinition:
  9. self.db.add(entity)
  10. self.db.commit()
  11. self.db.refresh(entity)
  12. return entity
  13. def list_all(self) -> list[ModelDefinition]:
  14. stmt = select(ModelDefinition).order_by(
  15. ModelDefinition.status.asc(),
  16. ModelDefinition.name.asc(),
  17. )
  18. return list(self.db.scalars(stmt))
  19. def get_by_id(self, model_id: str) -> ModelDefinition | None:
  20. return self.db.get(ModelDefinition, model_id)
  21. def get_by_code(self, code: str) -> ModelDefinition | None:
  22. stmt = select(ModelDefinition).where(ModelDefinition.code == code).limit(1)
  23. return self.db.scalar(stmt)
  24. def get_by_provider_model(
  25. self,
  26. *,
  27. provider_id: str,
  28. model_name: str) -> ModelDefinition | None:
  29. stmt = (
  30. select(ModelDefinition)
  31. .where(ModelDefinition.provider_id == provider_id)
  32. .where(ModelDefinition.model_name == model_name)
  33. .limit(1)
  34. )
  35. return self.db.scalar(stmt)
  36. def get_active_for_request(self, model: str) -> ModelDefinition | None:
  37. stmt = (
  38. select(ModelDefinition)
  39. .where(ModelDefinition.status == "active")
  40. .where(
  41. (ModelDefinition.code == model)
  42. | (ModelDefinition.model_name == model)
  43. )
  44. .limit(1)
  45. )
  46. return self.db.scalar(stmt)
  47. def update(self, entity: ModelDefinition) -> ModelDefinition:
  48. entity.updated_time = datetime.utcnow()
  49. self.db.commit()
  50. self.db.refresh(entity)
  51. return entity
  52. def delete(self, entity: ModelDefinition) -> None:
  53. self.db.delete(entity)
  54. self.db.commit()
  55. class ModelProviderDefinitionRepository:
  56. def __init__(self, db: Session) -> None:
  57. self.db = db
  58. def create(self, entity: ModelProviderDefinition) -> ModelProviderDefinition:
  59. self.db.add(entity)
  60. self.db.commit()
  61. self.db.refresh(entity)
  62. return entity
  63. def list_all(self) -> list[ModelProviderDefinition]:
  64. stmt = select(ModelProviderDefinition).order_by(
  65. ModelProviderDefinition.updated_time.desc(),
  66. ModelProviderDefinition.name.asc(),
  67. )
  68. return list(self.db.scalars(stmt))
  69. def get_by_id(self, provider_id: str) -> ModelProviderDefinition | None:
  70. return self.db.get(ModelProviderDefinition, provider_id)
  71. def get_by_connection(
  72. self,
  73. *,
  74. provider_type: str,
  75. base_url: str) -> ModelProviderDefinition | None:
  76. stmt = (
  77. select(ModelProviderDefinition)
  78. .where(ModelProviderDefinition.provider_type == provider_type)
  79. .where(ModelProviderDefinition.base_url == base_url)
  80. .limit(1)
  81. )
  82. return self.db.scalar(stmt)
  83. def update(self, entity: ModelProviderDefinition) -> ModelProviderDefinition:
  84. entity.updated_time = datetime.utcnow()
  85. self.db.commit()
  86. self.db.refresh(entity)
  87. return entity
  88. def delete(self, entity: ModelProviderDefinition) -> None:
  89. self.db.delete(entity)
  90. self.db.commit()