import logging
from datetime import datetime, timezone

from sqlalchemy import and_, delete, or_, select, text
from sqlalchemy.orm import Session

from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
from core_shared import JSONValue

from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument

logger = logging.getLogger(__name__)

# Escape character used when turning user keywords into LIKE/ILIKE patterns,
# so a literal "%" or "_" in the keyword is not treated as a wildcard.
_LIKE_ESCAPE = "\\"


class KnowledgeBaseRepository:
    """Persistence helpers for ``KnowledgeBase`` rows.

    Every mutating method commits the session passed to the constructor.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        code: str,
        name: str,
        description: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> KnowledgeBase:
        """Insert a knowledge base, commit, and return the refreshed row."""
        entity = KnowledgeBase(
            code=code,
            name=name,
            description=description,
            metadata_json=metadata_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_all(self) -> list[KnowledgeBase]:
        """Return every knowledge base, newest first."""
        stmt = select(KnowledgeBase).order_by(KnowledgeBase.created_time.desc())
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        keyword: str | None = None,
        status: KnowledgeBaseStatus | None = None,
    ) -> list[KnowledgeBase]:
        """Return knowledge bases matching the optional filters, newest first.

        Args:
            keyword: Case-insensitive literal substring matched against
                ``name`` or ``description``. ``None`` or a blank string
                applies no keyword filter.
            status: Exact status to match, or ``None`` for any status.
        """
        stmt = select(KnowledgeBase)
        if status is not None:
            stmt = stmt.where(KnowledgeBase.status == status)
        pattern = _contains_pattern(keyword)
        if pattern is not None:
            stmt = stmt.where(
                or_(
                    KnowledgeBase.name.ilike(pattern, escape=_LIKE_ESCAPE),
                    KnowledgeBase.description.ilike(pattern, escape=_LIKE_ESCAPE),
                )
            )
        stmt = stmt.order_by(KnowledgeBase.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, knowledge_base_id: str) -> KnowledgeBase | None:
        """Return the knowledge base with this id, or ``None`` if absent."""
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id)
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        knowledge_base_id: str,
        status: KnowledgeBaseStatus,
    ) -> KnowledgeBase | None:
        """Set the status, commit, and return the row (``None`` if missing)."""
        entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def update(
        self,
        *,
        knowledge_base_id: str,
        name: str | None = None,
        description: str | None = None,
        status: KnowledgeBaseStatus | None = None,
        metadata_json: dict[str, JSONValue] | None = None,
    ) -> KnowledgeBase | None:
        """Apply the non-``None`` fields, commit, and return the row.

        A ``None`` argument means "leave unchanged". As a result, this method
        cannot clear ``description`` or ``metadata_json``.
        Returns ``None`` when no row has ``knowledge_base_id``.
        """
        entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
        if entity is None:
            return None
        if name is not None:
            entity.name = name
        if description is not None:
            entity.description = description
        if status is not None:
            entity.status = status
        if metadata_json is not None:
            entity.metadata_json = metadata_json
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def delete(self, *, knowledge_base_id: str) -> bool:
        """Delete the row and commit; return ``False`` if it did not exist."""
        entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
        if entity is None:
            return False
        self.db.delete(entity)
        self.db.commit()
        return True


class KnowledgeDocumentRepository:
    """Persistence helpers for ``KnowledgeDocument`` rows.

    Every mutating method commits the session passed to the constructor.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        knowledge_base_id: str,
        title: str,
        source_type: str,
        source_uri: str | None,
        content_text: str,
        content_hash: str | None,
        metadata_json: dict[str, JSONValue] | None,
        status: KnowledgeDocumentStatus = "draft",
    ) -> KnowledgeDocument:
        """Insert a document (default status ``"draft"``), commit, and return it."""
        entity = KnowledgeDocument(
            knowledge_base_id=knowledge_base_id,
            title=title,
            source_type=source_type,
            source_uri=source_uri,
            content_text=content_text,
            content_hash=content_hash,
            metadata_json=metadata_json,
            status=status,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_base(self, *, knowledge_base_id: str) -> list[KnowledgeDocument]:
        """Return the documents of one knowledge base, newest first."""
        stmt = (
            select(KnowledgeDocument)
            .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeDocument.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        knowledge_base_id: str | None = None,
        keyword: str | None = None,
        status: KnowledgeDocumentStatus | None = None,
        source_type: str | None = None,
    ) -> list[KnowledgeDocument]:
        """Return documents matching every supplied filter, newest first.

        ``keyword`` is a case-insensitive literal substring matched against
        ``title`` or ``source_uri``. A blank keyword applies no filter.
        """
        stmt = select(KnowledgeDocument)
        if knowledge_base_id is not None:
            stmt = stmt.where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
        if status is not None:
            stmt = stmt.where(KnowledgeDocument.status == status)
        if source_type is not None:
            stmt = stmt.where(KnowledgeDocument.source_type == source_type)
        pattern = _contains_pattern(keyword)
        if pattern is not None:
            stmt = stmt.where(
                or_(
                    KnowledgeDocument.title.ilike(pattern, escape=_LIKE_ESCAPE),
                    KnowledgeDocument.source_uri.ilike(pattern, escape=_LIKE_ESCAPE),
                )
            )
        stmt = stmt.order_by(KnowledgeDocument.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_next_pending_indexing(
        self,
        *,
        stale_before: datetime,
    ) -> KnowledgeDocument | None:
        """Return the least-recently-updated document that needs indexing.

        A document qualifies when it is ``"queued"``. It also qualifies when it
        is ``"indexing"`` but was last updated before ``stale_before``, which
        treats it as an abandoned attempt.

        NOTE(review): the row is not locked. Concurrent workers calling this
        may pick the same document.
        """
        stmt = (
            select(KnowledgeDocument)
            .where(
                or_(
                    KnowledgeDocument.status == "queued",
                    and_(
                        KnowledgeDocument.status == "indexing",
                        KnowledgeDocument.updated_time < stale_before,
                    ),
                )
            )
            .order_by(KnowledgeDocument.updated_time.asc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def get_by_id(self, *, document_id: str) -> KnowledgeDocument | None:
        """Return the document with this id, or ``None`` if absent."""
        stmt = select(KnowledgeDocument).where(KnowledgeDocument.id == document_id)
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        document_id: str,
        status: KnowledgeDocumentStatus,
    ) -> KnowledgeDocument | None:
        """Set the status and commit; stamp ``indexed_time`` on ``"indexed"``.

        Returns ``None`` when the document does not exist.
        """
        entity = self.db.get(KnowledgeDocument, document_id)
        if entity is None:
            return None
        _apply_document_status(entity, status)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def update(
        self,
        *,
        document_id: str,
        title: str | None = None,
        source_uri: str | None = None,
        status: KnowledgeDocumentStatus | None = None,
        metadata_json: dict[str, JSONValue] | None = None,
    ) -> KnowledgeDocument | None:
        """Apply the non-``None`` fields, commit, and return the document.

        Setting ``status="indexed"`` also stamps ``indexed_time``.
        Returns ``None`` when the document does not exist.
        """
        entity = self.get_by_id(document_id=document_id)
        if entity is None:
            return None
        if title is not None:
            entity.title = title
        if source_uri is not None:
            entity.source_uri = source_uri
        if status is not None:
            _apply_document_status(entity, status)
        if metadata_json is not None:
            entity.metadata_json = metadata_json
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def delete(self, *, document_id: str) -> KnowledgeDocument | None:
        """Delete the document and commit; return the deleted instance or ``None``."""
        entity = self.get_by_id(document_id=document_id)
        if entity is None:
            return None
        self.db.delete(entity)
        self.db.commit()
        return entity


class KnowledgeChunkRepository:
    """Persistence and vector-search helpers for ``KnowledgeChunk`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def replace_document_chunks(
        self,
        *,
        knowledge_base_id: str,
        document_id: str,
        chunks: list[dict[str, JSONValue]],
    ) -> list[KnowledgeChunk]:
        """Replace all chunks of ``document_id`` with ``chunks`` in one commit.

        Each payload dict is read leniently. Missing or mistyped fields fall
        back to ``""`` / ``0`` / ``None`` via the ``_read_*`` helpers.
        """
        self.db.execute(
            delete(KnowledgeChunk).where(KnowledgeChunk.document_id == document_id)
        )
        entities: list[KnowledgeChunk] = []
        for chunk in chunks:
            # Parse the embedding once; it feeds both the JSON and vector columns.
            embedding = _read_float_list(chunk, "embedding_json")
            entity = KnowledgeChunk(
                knowledge_base_id=knowledge_base_id,
                document_id=document_id,
                chunk_index=_read_int(chunk, "chunk_index"),
                content_text=_read_string(chunk, "content_text"),
                token_count=_read_int(chunk, "token_count"),
                embedding_model=_read_optional_string(chunk, "embedding_model"),
                embedding_json=embedding,
                embedding_vector=_format_vector(embedding),
                metadata_json=_read_optional_dict(chunk, "metadata_json"),
            )
            self.db.add(entity)
            entities.append(entity)
        self.db.commit()
        for entity in entities:
            self.db.refresh(entity)
        return entities

    def list_by_base(self, *, knowledge_base_id: str) -> list[KnowledgeChunk]:
        """Return the chunks of one knowledge base, oldest first.

        ``chunk_index`` breaks ties between chunks created together.
        """
        stmt = (
            select(KnowledgeChunk)
            .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
            .order_by(
                KnowledgeChunk.created_time.asc(),
                KnowledgeChunk.chunk_index.asc(),
            )
        )
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        knowledge_base_id: str | None = None,
        document_id: str | None = None,
        keyword: str | None = None,
    ) -> list[KnowledgeChunk]:
        """Return chunks matching every supplied filter, oldest first.

        ``keyword`` is a case-insensitive literal substring of ``content_text``.
        A blank keyword applies no filter. ``chunk_index`` breaks
        ``created_time`` ties, so one document's chunks come back in order.
        """
        stmt = select(KnowledgeChunk)
        if knowledge_base_id is not None:
            stmt = stmt.where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
        if document_id is not None:
            stmt = stmt.where(KnowledgeChunk.document_id == document_id)
        pattern = _contains_pattern(keyword)
        if pattern is not None:
            stmt = stmt.where(KnowledgeChunk.content_text.ilike(pattern, escape=_LIKE_ESCAPE))
        stmt = stmt.order_by(
            KnowledgeChunk.created_time.asc(),
            KnowledgeChunk.chunk_index.asc(),
        )
        return list(self.db.scalars(stmt))

    def list_by_document(self, *, document_id: str) -> list[KnowledgeChunk]:
        """Return the chunks of one document in creation / index order."""
        return self.list_filtered(document_id=document_id)

    def get_by_id(self, *, chunk_id: str) -> KnowledgeChunk | None:
        """Return the chunk with this id, or ``None`` if absent."""
        stmt = select(KnowledgeChunk).where(KnowledgeChunk.id == chunk_id)
        return self.db.scalar(stmt)

    def delete_by_document(self, *, document_id: str) -> int:
        """Delete all chunks of a document, commit, and return the row count."""
        result = self.db.execute(
            delete(KnowledgeChunk).where(KnowledgeChunk.document_id == document_id)
        )
        self.db.commit()
        return int(result.rowcount or 0)

    def delete_by_base(self, *, knowledge_base_id: str) -> int:
        """Delete all chunks of a knowledge base, commit, and return the row count."""
        result = self.db.execute(
            delete(KnowledgeChunk).where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
        )
        self.db.commit()
        return int(result.rowcount or 0)

    def delete(self, *, chunk_id: str) -> bool:
        """Delete one chunk and commit; return ``False`` if it did not exist."""
        entity = self.get_by_id(chunk_id=chunk_id)
        if entity is None:
            return False
        self.db.delete(entity)
        self.db.commit()
        return True

    def search_by_vector(
        self,
        *,
        knowledge_base_id: str,
        embedding: list[float],
        limit: int,
    ) -> list[tuple[KnowledgeChunk, float]]:
        """Return up to ``limit`` nearest chunks with score ``1 - cosine distance``.

        The search is best-effort. It returns ``[]`` in four cases:
        - the session is not bound to PostgreSQL;
        - the embedding is empty;
        - ``limit`` is not positive;
        - the query fails. In that case the session is rolled back and the
          error is logged.

        Results keep the database's similarity order.
        """
        if limit <= 0 or not self._supports_pgvector_search():
            return []
        vector = _format_vector(embedding)
        if vector is None:
            return []
        # Schema-qualified operator/type so the query works regardless of search_path.
        stmt = text(
            """
            SELECT
                id,
                1 - (
                    embedding_vector OPERATOR(public.<=>) CAST(:embedding AS public.vector)
                ) AS score
            FROM knowledge_chunk
            WHERE knowledge_base_id = :knowledge_base_id
              AND embedding_vector IS NOT NULL
            ORDER BY embedding_vector OPERATOR(public.<=>) CAST(:embedding AS public.vector)
            LIMIT :limit
            """
        )
        try:
            rows = self.db.execute(
                stmt,
                {
                    "knowledge_base_id": knowledge_base_id,
                    "embedding": vector,
                    "limit": limit,
                },
            ).all()
        except Exception:
            # Deliberately best-effort (e.g. pgvector missing). Roll back so the
            # session stays usable, but do not hide the failure.
            self.db.rollback()
            logger.exception(
                "pgvector search failed for knowledge_base_id=%s", knowledge_base_id
            )
            return []
        if not rows:
            return []
        chunk_ids = [str(row[0]) for row in rows]
        # Key by str(id) so the lookup matches the stringified raw-SQL ids.
        chunks_by_id = {
            str(chunk.id): chunk
            for chunk in self.db.scalars(
                select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
            )
        }
        scored: list[tuple[KnowledgeChunk, float]] = []
        for row in rows:
            chunk = chunks_by_id.get(str(row[0]))
            if chunk is not None:
                scored.append((chunk, float(row[1] or 0.0)))
        return scored

    def _supports_pgvector_search(self) -> bool:
        """Return ``True`` when the session is bound to a PostgreSQL engine."""
        return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    This is equivalent to the deprecated ``datetime.utcnow()``, so values stored
    in existing (presumably naive) timestamp columns keep the same meaning.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _apply_document_status(entity: KnowledgeDocument, status: KnowledgeDocumentStatus) -> None:
    """Set ``status`` on the document; stamp ``indexed_time`` when it becomes indexed."""
    entity.status = status
    if status == "indexed":
        entity.indexed_time = _utcnow()


def _contains_pattern(keyword: str | None) -> str | None:
    """Build an escaped ``%keyword%`` LIKE pattern, or ``None`` for a blank keyword.

    ``%``, ``_`` and the escape character itself are escaped, so the keyword
    matches literally. Pair this with ``escape=_LIKE_ESCAPE`` on ``ilike``.
    """
    if keyword is None:
        return None
    term = keyword.strip()
    if not term:
        return None
    escaped = (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


def _read_string(payload: dict[str, JSONValue], key: str) -> str:
    """Return ``payload[key]`` if it is a string, else ``""``."""
    value = payload.get(key)
    return value if isinstance(value, str) else ""


def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
    """Return ``payload[key]`` if it is a string, else ``None``."""
    value = payload.get(key)
    return value if isinstance(value, str) else None


def _read_int(payload: dict[str, JSONValue], key: str) -> int:
    """Return ``payload[key]`` if it is a non-bool int, else ``0``."""
    value = payload.get(key)
    return value if isinstance(value, int) and not isinstance(value, bool) else 0


def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
    """Return the numeric items of ``payload[key]`` as floats.

    Returns ``None`` when the value is not a list. Non-numeric items are
    dropped. Booleans count as non-numeric here, matching ``_read_int``, even
    though ``bool`` subclasses ``int``.
    """
    value = payload.get(key)
    if not isinstance(value, list):
        return None
    return [
        float(item)
        for item in value
        if isinstance(item, (int, float)) and not isinstance(item, bool)
    ]


def _format_vector(value: list[float] | None) -> str | None:
    """Render floats as a pgvector text literal ``"[a,b,...]"``; ``None`` if empty."""
    if not value:
        return None
    return "[" + ",".join(str(float(item)) for item in value) + "]"


def _read_optional_dict(
    payload: dict[str, JSONValue],
    key: str,
) -> dict[str, JSONValue] | None:
    """Return a copy of ``payload[key]`` with string keys if it is a dict, else ``None``."""
    value = payload.get(key)
    if isinstance(value, dict):
        return {str(item_key): item_value for item_key, item_value in value.items()}
    return None