from datetime import datetime, timezone

from sqlalchemy import delete, select, text
from sqlalchemy.orm import Session

from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
from core_shared import JSONValue

from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument


class KnowledgeBaseRepository:
    """Persistence operations for ``KnowledgeBase`` rows, scoped by tenant."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        code: str,
        name: str,
        description: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> KnowledgeBase:
        """Insert a knowledge base, commit, and return the refreshed entity."""
        entity = KnowledgeBase(
            tenant_id=tenant_id,
            code=code,
            name=name,
            description=description,
            metadata_json=metadata_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_tenant(self, *, tenant_id: str) -> list[KnowledgeBase]:
        """Return the tenant's knowledge bases, newest ``created_time`` first."""
        stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .order_by(KnowledgeBase.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, knowledge_base_id: str) -> KnowledgeBase | None:
        """Return the knowledge base with this id inside the tenant, or ``None``."""
        stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .where(KnowledgeBase.id == knowledge_base_id)
        )
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        status: KnowledgeBaseStatus,
    ) -> KnowledgeBase | None:
        """Set ``status`` on a tenant's knowledge base and commit.

        Returns the refreshed entity, or ``None`` when no matching row exists.
        """
        entity = self.get_by_id(tenant_id=tenant_id, knowledge_base_id=knowledge_base_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity


class KnowledgeDocumentRepository:
    """Persistence operations for ``KnowledgeDocument`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        title: str,
        source_type: str,
        source_uri: str | None,
        content_text: str,
        content_hash: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> KnowledgeDocument:
        """Insert a document with status ``"draft"``, commit, and return it refreshed."""
        entity = KnowledgeDocument(
            tenant_id=tenant_id,
            knowledge_base_id=knowledge_base_id,
            title=title,
            source_type=source_type,
            source_uri=source_uri,
            content_text=content_text,
            content_hash=content_hash,
            metadata_json=metadata_json,
            status="draft",
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_base(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
    ) -> list[KnowledgeDocument]:
        """Return the documents of one knowledge base, newest ``created_time`` first."""
        stmt = (
            select(KnowledgeDocument)
            .where(KnowledgeDocument.tenant_id == tenant_id)
            .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeDocument.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, document_id: str) -> KnowledgeDocument | None:
        """Return the document with this id inside the tenant, or ``None``."""
        stmt = (
            select(KnowledgeDocument)
            .where(KnowledgeDocument.tenant_id == tenant_id)
            .where(KnowledgeDocument.id == document_id)
        )
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        document_id: str,
        status: KnowledgeDocumentStatus,
        tenant_id: str | None = None,
    ) -> KnowledgeDocument | None:
        """Set ``status`` on a document and commit.

        When ``status`` equals ``"indexed"`` the ``indexed_time`` column is
        stamped with the current UTC time (stored as a naive datetime);
        otherwise ``indexed_time`` is left untouched.

        Args:
            document_id: Primary key of the document to update.
            status: New status value.
            tenant_id: Optional tenant scope. When given, the document is only
                updated if it belongs to that tenant; when omitted the lookup
                is by primary key alone (the historical behaviour).

        Returns:
            The refreshed entity, or ``None`` when no matching row exists.
        """
        if tenant_id is None:
            entity = self.db.get(KnowledgeDocument, document_id)
        else:
            entity = self.get_by_id(tenant_id=tenant_id, document_id=document_id)
        if entity is None:
            return None

        entity.status = status
        if status == "indexed":
            # Same naive-UTC value datetime.utcnow() produced, without relying
            # on the API deprecated in Python 3.12.
            entity.indexed_time = datetime.now(timezone.utc).replace(tzinfo=None)
        self.db.commit()
        self.db.refresh(entity)
        return entity


class KnowledgeChunkRepository:
    """Persistence and vector-search operations for ``KnowledgeChunk`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def replace_document_chunks(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        document_id: str,
        chunks: list[dict[str, JSONValue]],
    ) -> list[KnowledgeChunk]:
        """Replace every chunk of a document with ``chunks`` and commit.

        Each input dict is read through the module's ``_read_*`` helpers, so
        missing or wrongly typed fields fall back to those helpers' defaults
        instead of raising.

        Returns:
            The newly inserted chunk entities, refreshed after the commit, in
            the same order as ``chunks``.
        """
        # Build every entity before touching the database so a failure while
        # constructing them cannot leave a pending delete in the session.
        entities: list[KnowledgeChunk] = []
        for chunk in chunks:
            # Parse the embedding once; it feeds both the JSON and vector columns.
            embedding = _read_float_list(chunk, "embedding_json")
            entities.append(
                KnowledgeChunk(
                    tenant_id=tenant_id,
                    knowledge_base_id=knowledge_base_id,
                    document_id=document_id,
                    chunk_index=_read_int(chunk, "chunk_index"),
                    content_text=_read_string(chunk, "content_text"),
                    token_count=_read_int(chunk, "token_count"),
                    embedding_model=_read_optional_string(chunk, "embedding_model"),
                    embedding_json=embedding,
                    embedding_vector=_format_vector(embedding),
                    metadata_json=_read_optional_dict(chunk, "metadata_json"),
                )
            )

        self.db.execute(
            delete(KnowledgeChunk)
            .where(KnowledgeChunk.tenant_id == tenant_id)
            .where(KnowledgeChunk.document_id == document_id)
        )
        self.db.add_all(entities)
        self.db.commit()
        for entity in entities:
            self.db.refresh(entity)
        return entities

    def list_by_base(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
    ) -> list[KnowledgeChunk]:
        """Return the chunks of one knowledge base, oldest ``created_time`` first."""
        stmt = (
            select(KnowledgeChunk)
            .where(KnowledgeChunk.tenant_id == tenant_id)
            .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeChunk.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def search_by_vector(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        embedding: list[float],
        limit: int,
    ) -> list[tuple[KnowledgeChunk, float]]:
        """Return up to ``limit`` chunks ranked by the ``<=>`` distance to ``embedding``.

        The score of each hit is ``1 - (embedding_vector <=> query)``. An empty
        list is returned when the session is not bound to PostgreSQL, when
        ``embedding`` is empty, or when ``limit`` is not positive.
        """
        # A non-positive LIMIT can never yield rows (and a negative one is
        # rejected by the database), so skip the round trip.
        if limit <= 0 or not self._supports_pgvector_search():
            return []
        vector = _format_vector(embedding)
        if vector is None:
            return []

        stmt = text(
            """
            SELECT id, 1 - (embedding_vector <=> CAST(:embedding AS vector)) AS score
            FROM knowledge_chunk
            WHERE tenant_id = :tenant_id
              AND knowledge_base_id = :knowledge_base_id
              AND embedding_vector IS NOT NULL
            ORDER BY embedding_vector <=> CAST(:embedding AS vector)
            LIMIT :limit
            """
        )
        rows = self.db.execute(
            stmt,
            {
                "tenant_id": tenant_id,
                "knowledge_base_id": knowledge_base_id,
                "embedding": vector,
                "limit": limit,
            },
        ).all()
        if not rows:
            return []

        chunk_ids = [str(row[0]) for row in rows]
        # Key by str(id) so the lookup below matches regardless of the Python
        # type the ORM uses for the primary key.
        chunks_by_id = {
            str(chunk.id): chunk
            for chunk in self.db.scalars(
                select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
            )
        }

        # Walk the raw rows (not the dict) to preserve the ranking order.
        scored: list[tuple[KnowledgeChunk, float]] = []
        for row in rows:
            chunk = chunks_by_id.get(str(row[0]))
            if chunk is not None:
                scored.append((chunk, float(row[1] or 0.0)))
        return scored

    def _supports_pgvector_search(self) -> bool:
        """Return ``True`` when the session is bound to a PostgreSQL dialect."""
        return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"


def _read_string(payload: dict[str, JSONValue], key: str) -> str:
    """Return ``payload[key]`` when it is a string, otherwise ``""``."""
    value = payload.get(key)
    return value if isinstance(value, str) else ""


def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
    """Return ``payload[key]`` when it is a string, otherwise ``None``."""
    value = payload.get(key)
    return value if isinstance(value, str) else None
def _read_int(payload: dict[str, JSONValue], key: str) -> int: value = payload.get(key) return value if isinstance(value, int) and not isinstance(value, bool) else 0 def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None: value = payload.get(key) if not isinstance(value, list): return None return [float(item) for item in value if isinstance(item, (int, float))] def _format_vector(value: list[float] | None) -> str | None: if not value: return None return "[" + ",".join(str(float(item)) for item in value) + "]" def _read_optional_dict( payload: dict[str, JSONValue], key: str, ) -> dict[str, JSONValue] | None: value = payload.get(key) if isinstance(value, dict): return {str(item_key): item_value for item_key, item_value in value.items()} return None