from sqlalchemy import func, select
from sqlalchemy.orm import Session

from app.db.models import ToolBinding, ToolDefinition, ToolVersion
from core_shared import JSONValue


def _persist(db: Session, entity: object) -> None:
    """Add ``entity`` to ``db``, commit, and reload it from the database.

    If the commit fails, the session is rolled back before the exception is
    re-raised. Otherwise SQLAlchemy leaves the ``Session`` in a failed
    transaction state that rejects further use until ``rollback()`` is called.
    """
    db.add(entity)
    try:
        db.commit()
    except Exception:
        db.rollback()
        raise
    # Reload so any database-generated column values are present on the
    # object handed back to the caller.
    db.refresh(entity)


class ToolDefinitionRepository:
    """Tenant-scoped persistence operations for ``ToolDefinition`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        plugin_id: str | None,
        code: str,
        name: str,
        tool_type: str,
        description: str | None,
    ) -> ToolDefinition:
        """Insert and commit a new tool definition.

        Returns:
            The persisted ``ToolDefinition``, refreshed after commit.

        Raises:
            Any exception raised by ``Session.commit``; the session is rolled
            back first.
        """
        entity = ToolDefinition(
            tenant_id=tenant_id,
            plugin_id=plugin_id,
            code=code,
            name=name,
            tool_type=tool_type,
            description=description,
        )
        _persist(self.db, entity)
        return entity

    def list_by_tenant(self, tenant_id: str) -> list[ToolDefinition]:
        """Return every tool definition of ``tenant_id``, newest first.

        ``id`` is a secondary sort key so rows with identical ``created_time``
        values are returned in a stable order.
        """
        stmt = (
            select(ToolDefinition)
            .where(ToolDefinition.tenant_id == tenant_id)
            .order_by(ToolDefinition.created_time.desc(), ToolDefinition.id.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, tool_id: str) -> ToolDefinition | None:
        """Return the tool definition ``tool_id`` if it belongs to ``tenant_id``.

        Returns ``None`` when no row matches both the id and the tenant.
        """
        stmt = (
            select(ToolDefinition)
            .where(ToolDefinition.tenant_id == tenant_id)
            .where(ToolDefinition.id == tool_id)
        )
        return self.db.scalar(stmt)


class ToolVersionRepository:
    """Tenant-scoped persistence operations for ``ToolVersion`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        tool_id: str,
        input_schema_json: dict[str, JSONValue] | None,
        output_schema_json: dict[str, JSONValue] | None,
        invoke_config_json: dict[str, JSONValue] | None,
        timeout_ms: int | None,
        retry_policy_json: dict[str, JSONValue] | None,
    ) -> ToolVersion:
        """Insert and commit the next version of tool ``tool_id``.

        ``version_no`` is assigned automatically as one more than the highest
        existing version of the tool (1 for the first version).

        NOTE(review): this method does not check that ``tool_id`` belongs to
        ``tenant_id``; presumably callers verify ownership first. Confirm.

        Returns:
            The persisted ``ToolVersion``, refreshed after commit.

        Raises:
            Any exception raised by ``Session.commit``; the session is rolled
            back first.
        """
        entity = ToolVersion(
            tenant_id=tenant_id,
            tool_id=tool_id,
            version_no=self._next_version_no(tool_id),
            input_schema_json=input_schema_json,
            output_schema_json=output_schema_json,
            invoke_config_json=invoke_config_json,
            timeout_ms=timeout_ms,
            retry_policy_json=retry_policy_json,
        )
        _persist(self.db, entity)
        return entity

    def list_by_tool(self, *, tenant_id: str, tool_id: str) -> list[ToolVersion]:
        """Return all versions of ``tool_id`` within ``tenant_id``, highest ``version_no`` first."""
        stmt = (
            select(ToolVersion)
            .where(ToolVersion.tenant_id == tenant_id)
            .where(ToolVersion.tool_id == tool_id)
            .order_by(ToolVersion.version_no.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, tool_version_id: str) -> ToolVersion | None:
        """Return the version ``tool_version_id`` if it belongs to ``tenant_id``, else ``None``."""
        stmt = (
            select(ToolVersion)
            .where(ToolVersion.tenant_id == tenant_id)
            .where(ToolVersion.id == tool_version_id)
        )
        return self.db.scalar(stmt)

    def _next_version_no(self, tool_id: str) -> int:
        """Return ``max(version_no) + 1`` for ``tool_id``, or 1 if it has no versions.

        The query is not tenant-filtered; this assumes ``tool_id`` values are
        unique across tenants. Verify against the schema.

        NOTE(review): read-max-then-insert is not atomic. Two concurrent
        ``create`` calls for the same tool can compute the same number unless
        a unique constraint on (tool_id, version_no) or a row lock guards it.
        Confirm the schema.
        """
        stmt = select(func.max(ToolVersion.version_no)).where(ToolVersion.tool_id == tool_id)
        current_max = self.db.scalar(stmt)
        return (current_max or 0) + 1


class ToolBindingRepository:
    """Tenant-scoped persistence operations for ``ToolBinding`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        tool_version_id: str,
        credential_id: str | None,
        binding_scope: str,
        enabled: bool,
        config_json: dict[str, JSONValue] | None,
    ) -> ToolBinding:
        """Insert and commit a binding of ``tool_version_id`` to ``app_id``.

        Returns:
            The persisted ``ToolBinding``, refreshed after commit.

        Raises:
            Any exception raised by ``Session.commit``; the session is rolled
            back first.
        """
        entity = ToolBinding(
            tenant_id=tenant_id,
            app_id=app_id,
            tool_version_id=tool_version_id,
            credential_id=credential_id,
            binding_scope=binding_scope,
            enabled=enabled,
            config_json=config_json,
        )
        _persist(self.db, entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, app_id: str | None = None) -> list[ToolBinding]:
        """Return bindings of ``tenant_id``, newest first.

        When ``app_id`` is given, only that app's bindings are returned.
        Otherwise every binding of the tenant is returned. ``id`` is a
        secondary sort key so equal ``created_time`` values sort stably.
        """
        stmt = select(ToolBinding).where(ToolBinding.tenant_id == tenant_id)
        if app_id is not None:
            stmt = stmt.where(ToolBinding.app_id == app_id)
        stmt = stmt.order_by(ToolBinding.created_time.desc(), ToolBinding.id.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, binding_id: str) -> ToolBinding | None:
        """Return the binding ``binding_id`` if it belongs to ``tenant_id``, else ``None``."""
        stmt = (
            select(ToolBinding)
            .where(ToolBinding.tenant_id == tenant_id)
            .where(ToolBinding.id == binding_id)
        )
        return self.db.scalar(stmt)