"""Repositories for team definitions, team versions and team runs."""

from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import TeamRunStatus, TeamStatus, TeamVersionStatus
from core_shared import JSONValue

from app.db.models import TeamDefinition, TeamRun, TeamVersion


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The tzinfo is stripped so the value stays comparable with the
    naive timestamps this module previously produced.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class TeamDefinitionRepository:
    """Persistence operations for ``TeamDefinition`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        code: str,
        name: str,
        description: str | None,
        team_type: str,
        owner_user_id: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> TeamDefinition:
        """Insert a new team definition, commit, and return the refreshed row."""
        entity = TeamDefinition(
            code=code,
            name=name,
            description=description,
            team_type=team_type,
            owner_user_id=owner_user_id,
            metadata_json=metadata_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_all(self) -> list[TeamDefinition]:
        """Return every team definition, newest ``created_time`` first."""
        stmt = select(TeamDefinition).order_by(TeamDefinition.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, team_id: str) -> TeamDefinition | None:
        """Return the team definition with the given id, or ``None``."""
        stmt = select(TeamDefinition).where(TeamDefinition.id == team_id)
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        team_id: str,
        status: TeamStatus,
    ) -> TeamDefinition | None:
        """Set the status of a team definition and commit.

        Returns the refreshed row, or ``None`` when no row has ``team_id``.
        """
        entity = self.get_by_id(team_id=team_id)
        if entity is None:
            return None

        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity


class TeamVersionRepository:
    """Persistence operations for ``TeamVersion`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        team_id: str,
        status: TeamVersionStatus,
        coordination_mode: str,
        objective: str | None,
        member_refs_json: list[dict[str, JSONValue]],
        policy_json: dict[str, JSONValue],
    ) -> TeamVersion:
        """Insert a new version for ``team_id``, commit, and return it.

        The version number is one more than the team's current highest
        version number (1 for the first version). ``published_time`` is set
        to the current UTC time only when ``status`` equals ``"published"``.
        """
        version_no = self._next_version_no(team_id)
        entity = TeamVersion(
            team_id=team_id,
            version_no=version_no,
            status=status,
            coordination_mode=coordination_mode,
            objective=objective,
            member_refs_json=member_refs_json,
            policy_json=policy_json,
            published_time=_utcnow() if status == "published" else None,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_team(self, *, team_id: str) -> list[TeamVersion]:
        """Return all versions of a team, highest ``version_no`` first."""
        stmt = (
            select(TeamVersion)
            .where(TeamVersion.team_id == team_id)
            .order_by(TeamVersion.version_no.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, team_version_id: str) -> TeamVersion | None:
        """Return the team version with the given id, or ``None``."""
        stmt = select(TeamVersion).where(TeamVersion.id == team_version_id)
        return self.db.scalar(stmt)

    def get_latest_published(self, *, team_id: str) -> TeamVersion | None:
        """Return the highest-numbered ``"published"`` version of a team, or ``None``."""
        stmt = (
            select(TeamVersion)
            .where(TeamVersion.team_id == team_id)
            .where(TeamVersion.status == "published")
            .order_by(TeamVersion.version_no.desc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def _next_version_no(self, team_id: str) -> int:
        """Return ``max(version_no) + 1`` for the team, or 1 if it has no versions.

        NOTE(review): the maximum is read without any lock visible here, so
        two concurrent ``create`` calls for the same team could compute the
        same number -- confirm a unique constraint guards this.
        """
        stmt = select(func.max(TeamVersion.version_no)).where(
            TeamVersion.team_id == team_id
        )
        current_max = self.db.scalar(stmt)
        return (current_max or 0) + 1


class TeamRunRepository:
    """Persistence and queue-style state transitions for ``TeamRun`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        team_id: str,
        team_version_id: str,
        session_id: str | None,
        input_text: str | None,
        input_json: dict[str, JSONValue] | None,
    ) -> TeamRun:
        """Insert a new run in ``"queued"`` status, commit, and return it.

        ``queued_time`` is set to the current UTC time.
        """
        now = _utcnow()
        entity = TeamRun(
            team_id=team_id,
            team_version_id=team_version_id,
            session_id=session_id,
            input_text=input_text,
            input_json=input_json,
            status="queued",
            queued_time=now,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        team_id: str | None = None,
        session_id: str | None = None,
    ) -> list[TeamRun]:
        """Return runs, newest ``created_time`` first.

        Each of ``team_id`` and ``session_id`` narrows the result when it is
        not ``None``; with both omitted every run is returned.
        """
        stmt = select(TeamRun)
        if team_id is not None:
            stmt = stmt.where(TeamRun.team_id == team_id)
        if session_id is not None:
            stmt = stmt.where(TeamRun.session_id == session_id)
        stmt = stmt.order_by(TeamRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, team_run_id: str) -> TeamRun | None:
        """Return the run with the given id, or ``None``."""
        stmt = select(TeamRun).where(TeamRun.id == team_run_id)
        return self.db.scalar(stmt)

    def claim_next_queued(
        self,
        *,
        worker_key: str,
        lease_expire_time: datetime,
    ) -> TeamRun | None:
        """Claim the queued run with the oldest ``created_time`` for a worker.

        The row is selected ``FOR UPDATE SKIP LOCKED`` so rows already locked
        by another transaction are passed over. The claimed run is moved to
        ``"running"``, tagged with ``worker_key`` and ``lease_expire_time``,
        and committed. An existing ``started_time`` is kept; otherwise it is
        set to the current UTC time.

        Returns the refreshed run, or ``None`` when nothing could be claimed.

        NOTE(review): ordering uses ``created_time`` while
        ``release_expired_leases`` rewrites ``queued_time`` -- confirm which
        column is meant to drive queue order.
        """
        stmt = (
            select(TeamRun)
            .where(TeamRun.status == "queued")
            .order_by(TeamRun.created_time.asc())
            .with_for_update(skip_locked=True)
            .limit(1)
        )
        entity = self.db.scalar(stmt)
        if entity is None:
            return None

        now = _utcnow()
        entity.status = "running"
        entity.worker_key = worker_key
        entity.started_time = entity.started_time or now
        entity.lease_expire_time = lease_expire_time
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def release_expired_leases(
        self,
        *,
        now_time: datetime,
        max_items: int = 100,
    ) -> int:
        """Re-queue running runs whose lease expired at or before ``now_time``.

        At most ``max_items`` runs are processed, earliest
        ``lease_expire_time`` first. Each one is reset to ``"queued"`` with
        worker, lease, start and finish fields cleared and ``queued_time``
        set to ``now_time``. A commit is issued only when at least one run
        was released.

        Returns the number of runs released.
        """
        stmt = (
            select(TeamRun)
            .where(TeamRun.status == "running")
            .where(TeamRun.lease_expire_time.is_not(None))
            .where(TeamRun.lease_expire_time <= now_time)
            .order_by(TeamRun.lease_expire_time.asc())
            .limit(max_items)
        )
        entities = list(self.db.scalars(stmt))
        for entity in entities:
            entity.status = "queued"
            entity.worker_key = None
            entity.lease_expire_time = None
            entity.queued_time = now_time
            entity.started_time = None
            entity.finished_time = None

        if entities:
            self.db.commit()
        return len(entities)

    def update_status(
        self,
        *,
        team_run_id: str,
        status: TeamRunStatus,
        worker_key: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> TeamRun | None:
        """Write a new status and result fields onto a run, then commit.

        ``worker_key``, ``output_text``, ``output_json``, ``error_code`` and
        ``error_message`` are always assigned, so any argument left at its
        ``None`` default overwrites the stored value with ``None``.
        NOTE(review): confirm callers rely on this clearing behaviour.

        Timestamp handling:
          * ``"running"`` sets ``started_time`` if it is not already set;
          * any status other than ``"running"`` clears ``lease_expire_time``;
          * ``"completed"``, ``"failed"`` and ``"cancelled"`` set
            ``finished_time`` to the current UTC time.

        Returns the refreshed run, or ``None`` when ``team_run_id`` is unknown.
        """
        entity = self.db.get(TeamRun, team_run_id)
        if entity is None:
            return None

        now = _utcnow()
        entity.status = status
        entity.worker_key = worker_key
        entity.output_text = output_text
        entity.output_json = output_json
        entity.error_code = error_code
        entity.error_message = error_message

        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status != "running":
            entity.lease_expire_time = None
        if status in {"completed", "failed", "cancelled"}:
            entity.finished_time = now

        self.db.commit()
        self.db.refresh(entity)
        return entity