repositories.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  from datetime import datetime, timezone

  from sqlalchemy import case, select
  from sqlalchemy.orm import Session

  from app.db.models import ModelDefinition, ModelProviderDefinition
  class ModelDefinitionRepository:
      """Persistence operations for ``ModelDefinition`` rows.

      Every mutating method commits the wrapped session immediately. If a
      commit fails, the session is rolled back before the exception is
      re-raised, so the caller's session stays usable.
      """

      def __init__(self, db: Session) -> None:
          """Wrap an existing SQLAlchemy session.

          Args:
              db: Session used for every query and commit issued by this
                  repository.
          """
          self.db = db

      def _commit(self) -> None:
          """Commit the session, rolling it back if the commit fails."""
          try:
              self.db.commit()
          except Exception:
              # After a failed flush/commit SQLAlchemy refuses further work on
              # the session until rollback() is called; do it here so callers
              # that catch the error (e.g. IntegrityError) can carry on.
              self.db.rollback()
              raise

      @staticmethod
      def _utcnow() -> datetime:
          """Return the current UTC time as a naive ``datetime``.

          Drop-in replacement for the deprecated ``datetime.utcnow()``: the
          value is kept naive (tzinfo stripped), matching what this
          repository stored before.
          """
          return datetime.now(timezone.utc).replace(tzinfo=None)

      def create(self, entity: ModelDefinition) -> ModelDefinition:
          """Insert *entity*, commit, and return it refreshed from the database.

          Raises:
              Whatever ``Session.commit`` raises; the session is rolled back
              first.
          """
          self.db.add(entity)
          self._commit()
          self.db.refresh(entity)
          return entity

      def list_all(self) -> list[ModelDefinition]:
          """Return all definitions ordered by ``status`` then ``name`` (ascending)."""
          stmt = select(ModelDefinition).order_by(
              ModelDefinition.status.asc(),
              ModelDefinition.name.asc(),
          )
          return list(self.db.scalars(stmt))

      def get_by_id(self, model_id: str) -> ModelDefinition | None:
          """Return the definition with primary key *model_id*, or ``None``."""
          return self.db.get(ModelDefinition, model_id)

      def get_by_code(self, code: str) -> ModelDefinition | None:
          """Return the first definition whose ``code`` equals *code*, or ``None``."""
          stmt = select(ModelDefinition).where(ModelDefinition.code == code).limit(1)
          return self.db.scalar(stmt)

      def get_active_for_request(self, model: str) -> ModelDefinition | None:
          """Resolve a requested model identifier to an active definition.

          Only rows with ``status == "active"`` are considered. *model* may
          match either a row's ``code`` or its ``model_name``. When several
          rows match, a ``code`` match is preferred over a ``model_name``
          match, and remaining ties are broken by ``name`` ascending, so the
          choice no longer depends on the database's row order.

          Returns:
              The selected definition, or ``None`` when nothing matches.
          """
          # 0 for an exact code match, 1 otherwise -> code matches sort first.
          code_match_first = case((ModelDefinition.code == model, 0), else_=1)
          stmt = (
              select(ModelDefinition)
              .where(ModelDefinition.status == "active")
              .where(
                  (ModelDefinition.code == model)
                  | (ModelDefinition.model_name == model)
              )
              .order_by(code_match_first, ModelDefinition.name.asc())
              .limit(1)
          )
          return self.db.scalar(stmt)

      def update(self, entity: ModelDefinition) -> ModelDefinition:
          """Stamp ``updated_time``, commit pending changes, and return *entity* refreshed.

          Raises:
              Whatever ``Session.commit`` raises; the session is rolled back
              first.
          """
          entity.updated_time = self._utcnow()
          self._commit()
          self.db.refresh(entity)
          return entity

      def delete(self, entity: ModelDefinition) -> None:
          """Delete *entity* and commit.

          Raises:
              Whatever ``Session.commit`` raises; the session is rolled back
              first.
          """
          self.db.delete(entity)
          self._commit()
  class ModelProviderDefinitionRepository:
      """Persistence operations for ``ModelProviderDefinition`` rows.

      Every mutating method commits the wrapped session immediately. If a
      commit fails, the session is rolled back before the exception is
      re-raised, so the caller's session stays usable.
      """

      def __init__(self, db: Session) -> None:
          """Wrap an existing SQLAlchemy session.

          Args:
              db: Session used for every query and commit issued by this
                  repository.
          """
          self.db = db

      def _commit(self) -> None:
          """Commit the session, rolling it back if the commit fails."""
          try:
              self.db.commit()
          except Exception:
              # After a failed flush/commit SQLAlchemy refuses further work on
              # the session until rollback() is called; do it here so callers
              # that catch the error (e.g. IntegrityError) can carry on.
              self.db.rollback()
              raise

      @staticmethod
      def _utcnow() -> datetime:
          """Return the current UTC time as a naive ``datetime``.

          Drop-in replacement for the deprecated ``datetime.utcnow()``: the
          value is kept naive (tzinfo stripped), matching what this
          repository stored before.
          """
          return datetime.now(timezone.utc).replace(tzinfo=None)

      def create(self, entity: ModelProviderDefinition) -> ModelProviderDefinition:
          """Insert *entity*, commit, and return it refreshed from the database.

          Raises:
              Whatever ``Session.commit`` raises; the session is rolled back
              first.
          """
          self.db.add(entity)
          self._commit()
          self.db.refresh(entity)
          return entity

      def list_all(self) -> list[ModelProviderDefinition]:
          """Return all providers, most recently updated first, then by ``name``."""
          stmt = select(ModelProviderDefinition).order_by(
              ModelProviderDefinition.updated_time.desc(),
              ModelProviderDefinition.name.asc(),
          )
          return list(self.db.scalars(stmt))

      def get_by_id(self, provider_id: str) -> ModelProviderDefinition | None:
          """Return the provider with primary key *provider_id*, or ``None``."""
          return self.db.get(ModelProviderDefinition, provider_id)

      def get_by_connection(
          self,
          *,
          provider_type: str,
          base_url: str,
      ) -> ModelProviderDefinition | None:
          """Return the first provider matching both *provider_type* and *base_url*.

          Both arguments are keyword-only. Returns ``None`` when no row
          matches.
          """
          stmt = (
              select(ModelProviderDefinition)
              .where(ModelProviderDefinition.provider_type == provider_type)
              .where(ModelProviderDefinition.base_url == base_url)
              .limit(1)
          )
          return self.db.scalar(stmt)

      def update(self, entity: ModelProviderDefinition) -> ModelProviderDefinition:
          """Stamp ``updated_time``, commit pending changes, and return *entity* refreshed.

          Raises:
              Whatever ``Session.commit`` raises; the session is rolled back
              first.
          """
          entity.updated_time = self._utcnow()
          self._commit()
          self.db.refresh(entity)
          return entity

      def delete(self, entity: ModelProviderDefinition) -> None:
          """Delete *entity* and commit.

          Raises:
              Whatever ``Session.commit`` raises; the session is rolled back
              first.
          """
          self.db.delete(entity)
          self._commit()