| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import ModelDefinition, ModelProviderDefinition
class ModelDefinitionRepository:
    """Data-access helpers for ``ModelDefinition`` rows.

    Every mutating method commits immediately. If a commit fails, the
    session is rolled back before the exception propagates.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for the deprecated ``datetime.utcnow()``. The
        tzinfo is stripped so the value has the same naive-UTC shape the
        original call produced. NOTE(review): presumably ``updated_time``
        is a naive column; confirm before switching to aware datetimes.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _commit(self) -> None:
        """Commit the session, rolling back and re-raising on failure."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session's transaction
            # inactive; roll back so the session can be used again.
            self.db.rollback()
            raise

    def create(self, entity: ModelDefinition) -> ModelDefinition:
        """Add *entity*, commit, and return it refreshed from the database."""
        self.db.add(entity)
        self._commit()
        self.db.refresh(entity)
        return entity

    def list_all(self) -> list[ModelDefinition]:
        """Return all model definitions ordered by status, then name."""
        stmt = select(ModelDefinition).order_by(
            ModelDefinition.status.asc(),
            ModelDefinition.name.asc(),
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, model_id: str) -> ModelDefinition | None:
        """Return the definition with primary key *model_id*, or ``None``."""
        return self.db.get(ModelDefinition, model_id)

    def get_by_code(self, code: str) -> ModelDefinition | None:
        """Return the first definition whose ``code`` equals *code*, or ``None``."""
        stmt = select(ModelDefinition).where(ModelDefinition.code == code).limit(1)
        return self.db.scalar(stmt)

    def get_by_provider_model(
        self,
        *,
        provider_id: str,
        model_name: str,
    ) -> ModelDefinition | None:
        """Return the definition for *model_name* under *provider_id*, or ``None``."""
        stmt = (
            select(ModelDefinition)
            .where(ModelDefinition.provider_id == provider_id)
            .where(ModelDefinition.model_name == model_name)
            .limit(1)
        )
        return self.db.scalar(stmt)

    def get_active_for_request(self, model: str) -> ModelDefinition | None:
        """Resolve a requested *model* string to an active definition.

        An active row whose ``code`` equals *model* takes precedence. Only
        if none exists is an active row with a matching ``model_name``
        returned. Previously both were OR-ed under an unordered ``LIMIT 1``,
        so the result depended on the database's row order.

        Returns ``None`` when no active definition matches.
        """
        active = select(ModelDefinition).where(ModelDefinition.status == "active")

        by_code = self.db.scalar(
            active.where(ModelDefinition.code == model).limit(1)
        )
        if by_code is not None:
            return by_code

        # model_name is presumably unique only per provider (see
        # get_by_provider_model), so several rows may match. Order by name
        # so the pick is stable instead of arbitrary.
        return self.db.scalar(
            active.where(ModelDefinition.model_name == model)
            .order_by(ModelDefinition.name.asc())
            .limit(1)
        )

    def update(self, entity: ModelDefinition) -> ModelDefinition:
        """Stamp ``updated_time``, commit pending changes, and return *entity* refreshed."""
        entity.updated_time = self._utcnow()
        self._commit()
        self.db.refresh(entity)
        return entity

    def delete(self, entity: ModelDefinition) -> None:
        """Delete *entity* and commit."""
        self.db.delete(entity)
        self._commit()
class ModelProviderDefinitionRepository:
    """Data-access helpers for ``ModelProviderDefinition`` rows.

    Every mutating method commits immediately. If a commit fails, the
    session is rolled back before the exception propagates.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for the deprecated ``datetime.utcnow()``. The
        tzinfo is stripped so the value has the same naive-UTC shape the
        original call produced. NOTE(review): presumably ``updated_time``
        is a naive column; confirm before switching to aware datetimes.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _commit(self) -> None:
        """Commit the session, rolling back and re-raising on failure."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session's transaction
            # inactive; roll back so the session can be used again.
            self.db.rollback()
            raise

    def create(self, entity: ModelProviderDefinition) -> ModelProviderDefinition:
        """Add *entity*, commit, and return it refreshed from the database."""
        self.db.add(entity)
        self._commit()
        self.db.refresh(entity)
        return entity

    def list_all(self) -> list[ModelProviderDefinition]:
        """Return all providers, most recently updated first, then by name."""
        stmt = select(ModelProviderDefinition).order_by(
            ModelProviderDefinition.updated_time.desc(),
            ModelProviderDefinition.name.asc(),
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, provider_id: str) -> ModelProviderDefinition | None:
        """Return the provider with primary key *provider_id*, or ``None``."""
        return self.db.get(ModelProviderDefinition, provider_id)

    def get_by_connection(
        self,
        *,
        provider_type: str,
        base_url: str,
    ) -> ModelProviderDefinition | None:
        """Return the first provider with this *provider_type* and *base_url*, or ``None``."""
        stmt = (
            select(ModelProviderDefinition)
            .where(ModelProviderDefinition.provider_type == provider_type)
            .where(ModelProviderDefinition.base_url == base_url)
            .limit(1)
        )
        return self.db.scalar(stmt)

    def update(self, entity: ModelProviderDefinition) -> ModelProviderDefinition:
        """Stamp ``updated_time``, commit pending changes, and return *entity* refreshed."""
        entity.updated_time = self._utcnow()
        self._commit()
        self.db.refresh(entity)
        return entity

    def delete(self, entity: ModelProviderDefinition) -> None:
        """Delete *entity* and commit."""
        self.db.delete(entity)
        self._commit()
|