| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
import logging
from datetime import datetime

from sqlalchemy import delete, or_, select, text
from sqlalchemy.orm import Session

from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
from core_shared import JSONValue

from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
class KnowledgeBaseRepository:
    """Persistence operations for ``KnowledgeBase`` rows.

    Every mutating method commits the session passed to ``__init__``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _contains_pattern(keyword: str) -> str:
        """Return an ILIKE pattern that matches *keyword* as a literal substring.

        ``/`` is the escape character (pair with ``ilike(..., escape="/")``) so
        user input such as ``"50%"`` or ``"snake_case"`` is matched literally
        instead of being interpreted as SQL wildcards.
        """
        escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    def create(
        self,
        *,
        code: str,
        name: str,
        description: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> KnowledgeBase:
        """Insert a new knowledge base, commit, and return the refreshed entity."""
        entity = KnowledgeBase(
            code=code,
            name=name,
            description=description,
            metadata_json=metadata_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_all(self) -> list[KnowledgeBase]:
        """Return every knowledge base, newest ``created_time`` first."""
        stmt = select(KnowledgeBase).order_by(KnowledgeBase.created_time.desc())
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        keyword: str | None = None,
        status: KnowledgeBaseStatus | None = None,
    ) -> list[KnowledgeBase]:
        """Return knowledge bases matching the optional filters, newest first.

        Args:
            keyword: Case-insensitive literal substring matched against
                ``name`` or ``description``. Surrounding whitespace is ignored
                and a blank keyword applies no filter.
            status: When given, only knowledge bases in this status are returned.
        """
        stmt = select(KnowledgeBase)
        if status is not None:
            stmt = stmt.where(KnowledgeBase.status == status)
        term = keyword.strip() if keyword else ""
        if term:
            pattern = self._contains_pattern(term)
            stmt = stmt.where(
                or_(
                    KnowledgeBase.name.ilike(pattern, escape="/"),
                    KnowledgeBase.description.ilike(pattern, escape="/"),
                )
            )
        stmt = stmt.order_by(KnowledgeBase.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, knowledge_base_id: str) -> KnowledgeBase | None:
        """Return the knowledge base with this id, or ``None`` if it does not exist."""
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id)
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        knowledge_base_id: str,
        status: KnowledgeBaseStatus,
    ) -> KnowledgeBase | None:
        """Set the status of a knowledge base and commit.

        Returns the refreshed entity, or ``None`` if the id is unknown.
        """
        entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def update(
        self,
        *,
        knowledge_base_id: str,
        name: str | None = None,
        description: str | None = None,
        status: KnowledgeBaseStatus | None = None,
        metadata_json: dict[str, JSONValue] | None = None,
    ) -> KnowledgeBase | None:
        """Apply a partial update and commit.

        Arguments left as ``None`` are not touched, which also means this method
        cannot clear a field to NULL. Returns the refreshed entity, or ``None``
        if the id is unknown.
        """
        entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
        if entity is None:
            return None
        if name is not None:
            entity.name = name
        if description is not None:
            entity.description = description
        if status is not None:
            entity.status = status
        if metadata_json is not None:
            entity.metadata_json = metadata_json
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def delete(self, *, knowledge_base_id: str) -> bool:
        """Delete the knowledge base and commit; return ``False`` if it did not exist."""
        entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
        if entity is None:
            return False
        self.db.delete(entity)
        self.db.commit()
        return True
class KnowledgeDocumentRepository:
    """Persistence operations for ``KnowledgeDocument`` rows.

    Every mutating method commits the session passed to ``__init__``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _contains_pattern(keyword: str) -> str:
        """Return an ILIKE pattern that matches *keyword* as a literal substring.

        ``/`` is the escape character (pair with ``ilike(..., escape="/")``) so
        user input such as ``"50%"`` or ``"snake_case"`` is matched literally
        instead of being interpreted as SQL wildcards.
        """
        escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    @staticmethod
    def _apply_status(
        entity: KnowledgeDocument,
        status: KnowledgeDocumentStatus,
    ) -> None:
        """Set *status* on *entity*, stamping ``indexed_time`` when it is ``"indexed"``.

        Any other status leaves ``indexed_time`` unchanged.
        """
        entity.status = status
        if status == "indexed":
            # Naive UTC timestamp (datetime.utcnow() carries no tzinfo).
            # NOTE(review): utcnow() is deprecated since Python 3.12; move to an
            # aware datetime once the column's timezone handling is confirmed.
            entity.indexed_time = datetime.utcnow()

    def create(
        self,
        *,
        knowledge_base_id: str,
        title: str,
        source_type: str,
        source_uri: str | None,
        content_text: str,
        content_hash: str | None,
        metadata_json: dict[str, JSONValue] | None,
        status: KnowledgeDocumentStatus = "draft",
    ) -> KnowledgeDocument:
        """Insert a new document (``"draft"`` by default), commit, and return it refreshed."""
        entity = KnowledgeDocument(
            knowledge_base_id=knowledge_base_id,
            title=title,
            source_type=source_type,
            source_uri=source_uri,
            content_text=content_text,
            content_hash=content_hash,
            metadata_json=metadata_json,
            status=status,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_base(
        self,
        *,
        knowledge_base_id: str,
    ) -> list[KnowledgeDocument]:
        """Return all documents of one knowledge base, newest ``created_time`` first."""
        stmt = (
            select(KnowledgeDocument)
            .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeDocument.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        knowledge_base_id: str | None = None,
        keyword: str | None = None,
        status: KnowledgeDocumentStatus | None = None,
        source_type: str | None = None,
    ) -> list[KnowledgeDocument]:
        """Return documents matching every given filter, newest first.

        Args:
            knowledge_base_id: Restrict to one knowledge base.
            keyword: Case-insensitive literal substring matched against
                ``title`` or ``source_uri``. Surrounding whitespace is ignored
                and a blank keyword applies no filter.
            status: Restrict to one document status.
            source_type: Restrict to one source type (exact match).
        """
        stmt = select(KnowledgeDocument)
        if knowledge_base_id is not None:
            stmt = stmt.where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
        if status is not None:
            stmt = stmt.where(KnowledgeDocument.status == status)
        if source_type is not None:
            stmt = stmt.where(KnowledgeDocument.source_type == source_type)
        term = keyword.strip() if keyword else ""
        if term:
            pattern = self._contains_pattern(term)
            stmt = stmt.where(
                or_(
                    KnowledgeDocument.title.ilike(pattern, escape="/"),
                    KnowledgeDocument.source_uri.ilike(pattern, escape="/"),
                )
            )
        stmt = stmt.order_by(KnowledgeDocument.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_next_pending_indexing(
        self,
        *,
        stale_before: datetime,
    ) -> KnowledgeDocument | None:
        """Return the next document that needs (re)indexing, or ``None``.

        A document qualifies when its status is ``"queued"``, or when it is
        ``"indexing"`` but was last updated before *stale_before* (presumably
        a run that never finished). The least recently updated candidate is
        returned, and ``id`` breaks ties so the pick is deterministic.

        NOTE(review): the row is not locked, so concurrent pollers can pick the
        same document; consider ``with_for_update(skip_locked=True)`` on
        PostgreSQL if several workers run at once.
        """
        stmt = (
            select(KnowledgeDocument)
            .where(
                or_(
                    KnowledgeDocument.status == "queued",
                    (KnowledgeDocument.status == "indexing")
                    & (KnowledgeDocument.updated_time < stale_before),
                )
            )
            .order_by(KnowledgeDocument.updated_time.asc(), KnowledgeDocument.id.asc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def get_by_id(self, *, document_id: str) -> KnowledgeDocument | None:
        """Return the document with this id, or ``None`` if it does not exist."""
        stmt = select(KnowledgeDocument).where(KnowledgeDocument.id == document_id)
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        document_id: str,
        status: KnowledgeDocumentStatus,
    ) -> KnowledgeDocument | None:
        """Set a document's status (stamping ``indexed_time`` for ``"indexed"``) and commit.

        Returns the refreshed entity, or ``None`` if the id is unknown.
        """
        entity = self.get_by_id(document_id=document_id)
        if entity is None:
            return None
        self._apply_status(entity, status)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def update(
        self,
        *,
        document_id: str,
        title: str | None = None,
        source_uri: str | None = None,
        status: KnowledgeDocumentStatus | None = None,
        metadata_json: dict[str, JSONValue] | None = None,
    ) -> KnowledgeDocument | None:
        """Apply a partial update and commit.

        Arguments left as ``None`` are not touched, so fields cannot be cleared
        to NULL here. A status of ``"indexed"`` also stamps ``indexed_time``.
        Returns the refreshed entity, or ``None`` if the id is unknown.
        """
        entity = self.get_by_id(document_id=document_id)
        if entity is None:
            return None
        if title is not None:
            entity.title = title
        if source_uri is not None:
            entity.source_uri = source_uri
        if status is not None:
            self._apply_status(entity, status)
        if metadata_json is not None:
            entity.metadata_json = metadata_json
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def delete(self, *, document_id: str) -> KnowledgeDocument | None:
        """Delete the document and commit.

        Returns the (now deleted) entity, or ``None`` if the id is unknown.
        """
        entity = self.get_by_id(document_id=document_id)
        if entity is None:
            return None
        self.db.delete(entity)
        self.db.commit()
        return entity
class KnowledgeChunkRepository:
    """Persistence and vector-search operations for ``KnowledgeChunk`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _contains_pattern(keyword: str) -> str:
        """Return an ILIKE pattern that matches *keyword* as a literal substring.

        ``/`` is the escape character (pair with ``ilike(..., escape="/")``) so
        user input such as ``"50%"`` or ``"snake_case"`` is matched literally
        instead of being interpreted as SQL wildcards.
        """
        escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    def replace_document_chunks(
        self,
        *,
        knowledge_base_id: str,
        document_id: str,
        chunks: list[dict[str, JSONValue]],
    ) -> list[KnowledgeChunk]:
        """Replace every chunk of *document_id* with *chunks* in a single commit.

        Each payload is read leniently: ``chunk_index``/``token_count`` default
        to ``0``, ``content_text`` to ``""``, and ``embedding_model``,
        ``embedding_json`` and ``metadata_json`` to ``None`` when missing or of
        the wrong type. ``embedding_vector`` is derived from ``embedding_json``.

        If anything fails before the commit completes, the session is rolled
        back (so it remains usable) and the exception is re-raised.
        Returns the new entities, refreshed after the commit.
        """
        entities: list[KnowledgeChunk] = []
        try:
            self.db.execute(
                delete(KnowledgeChunk).where(KnowledgeChunk.document_id == document_id)
            )
            for chunk in chunks:
                # Parse once and reuse for both the JSON and vector columns.
                embedding = _read_float_list(chunk, "embedding_json")
                entity = KnowledgeChunk(
                    knowledge_base_id=knowledge_base_id,
                    document_id=document_id,
                    chunk_index=_read_int(chunk, "chunk_index"),
                    content_text=_read_string(chunk, "content_text"),
                    token_count=_read_int(chunk, "token_count"),
                    embedding_model=_read_optional_string(chunk, "embedding_model"),
                    embedding_json=embedding,
                    embedding_vector=_format_vector(embedding),
                    metadata_json=_read_optional_dict(chunk, "metadata_json"),
                )
                self.db.add(entity)
                entities.append(entity)
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        for entity in entities:
            self.db.refresh(entity)
        return entities

    def list_by_base(
        self,
        *,
        knowledge_base_id: str,
    ) -> list[KnowledgeChunk]:
        """Return all chunks of one knowledge base, oldest first.

        Chunks written in one commit may share ``created_time``, so
        ``document_id`` and ``chunk_index`` break ties deterministically.
        """
        stmt = (
            select(KnowledgeChunk)
            .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
            .order_by(
                KnowledgeChunk.created_time.asc(),
                KnowledgeChunk.document_id.asc(),
                KnowledgeChunk.chunk_index.asc(),
            )
        )
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        knowledge_base_id: str | None = None,
        document_id: str | None = None,
        keyword: str | None = None,
    ) -> list[KnowledgeChunk]:
        """Return chunks matching every given filter, oldest first.

        Args:
            knowledge_base_id: Restrict to one knowledge base.
            document_id: Restrict to one document.
            keyword: Case-insensitive literal substring matched against
                ``content_text``. Surrounding whitespace is ignored and a blank
                keyword applies no filter.
        """
        stmt = select(KnowledgeChunk)
        if knowledge_base_id is not None:
            stmt = stmt.where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
        if document_id is not None:
            stmt = stmt.where(KnowledgeChunk.document_id == document_id)
        term = keyword.strip() if keyword else ""
        if term:
            stmt = stmt.where(
                KnowledgeChunk.content_text.ilike(self._contains_pattern(term), escape="/")
            )
        # Chunks written in one commit may share created_time; the extra keys
        # keep the order deterministic (document order within a document).
        stmt = stmt.order_by(
            KnowledgeChunk.created_time.asc(),
            KnowledgeChunk.document_id.asc(),
            KnowledgeChunk.chunk_index.asc(),
        )
        return list(self.db.scalars(stmt))

    def list_by_document(self, *, document_id: str) -> list[KnowledgeChunk]:
        """Return all chunks of one document, ordered as ``list_filtered`` orders them."""
        return self.list_filtered(document_id=document_id)

    def get_by_id(self, *, chunk_id: str) -> KnowledgeChunk | None:
        """Return the chunk with this id, or ``None`` if it does not exist."""
        stmt = select(KnowledgeChunk).where(KnowledgeChunk.id == chunk_id)
        return self.db.scalar(stmt)

    def delete_by_document(self, *, document_id: str) -> int:
        """Bulk-delete every chunk of a document, commit, and return the row count."""
        result = self.db.execute(
            delete(KnowledgeChunk).where(KnowledgeChunk.document_id == document_id)
        )
        self.db.commit()
        return int(result.rowcount or 0)

    def delete_by_base(self, *, knowledge_base_id: str) -> int:
        """Bulk-delete every chunk of a knowledge base, commit, and return the row count."""
        result = self.db.execute(
            delete(KnowledgeChunk).where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
        )
        self.db.commit()
        return int(result.rowcount or 0)

    def delete(self, *, chunk_id: str) -> bool:
        """Delete one chunk and commit; return ``False`` if it did not exist."""
        entity = self.get_by_id(chunk_id=chunk_id)
        if entity is None:
            return False
        self.db.delete(entity)
        self.db.commit()
        return True

    def search_by_vector(
        self,
        *,
        knowledge_base_id: str,
        embedding: list[float],
        limit: int,
    ) -> list[tuple[KnowledgeChunk, float]]:
        """Return up to *limit* ``(chunk, score)`` pairs closest to *embedding*.

        The score is ``1 - (embedding_vector <=> embedding)`` (pgvector cosine
        distance), so higher means closer; results keep the SQL ranking order.
        Chunks without a stored vector are ignored.

        Returns ``[]`` without querying when the session is not bound to
        PostgreSQL, *embedding* is empty, or *limit* is not positive. Query
        errors (e.g. pgvector not installed, dimension mismatch) are logged,
        the session is rolled back, and ``[]`` is returned.
        """
        if limit <= 0 or not self._supports_pgvector_search():
            return []
        vector = _format_vector(embedding)
        if vector is None:
            return []
        stmt = text(
            """
            SELECT id, 1 - (
                embedding_vector OPERATOR(public.<=>) CAST(:embedding AS public.vector)
            ) AS score
            FROM knowledge_chunk
            WHERE knowledge_base_id = :knowledge_base_id
              AND embedding_vector IS NOT NULL
            ORDER BY embedding_vector OPERATOR(public.<=>) CAST(:embedding AS public.vector)
            LIMIT :limit
            """
        )
        try:
            rows = self.db.execute(
                stmt,
                {
                    "knowledge_base_id": knowledge_base_id,
                    "embedding": vector,
                    "limit": limit,
                },
            ).all()
        except Exception:
            # Best-effort search: report the failure instead of hiding it, then
            # roll back so the session is usable after the failed statement.
            # NOTE(review): the rollback also discards any uncommitted changes
            # pending in this session.
            logging.getLogger(__name__).warning(
                "Vector search failed for knowledge base %s; returning no results",
                knowledge_base_id,
                exc_info=True,
            )
            self.db.rollback()
            return []
        if not rows:
            return []
        chunk_ids = [str(row[0]) for row in rows]
        # Key by str() so the lookup matches whether ids load as str or e.g. UUID.
        chunks_by_id = {
            str(chunk.id): chunk
            for chunk in self.db.scalars(
                select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
            )
        }
        scored: list[tuple[KnowledgeChunk, float]] = []
        for row in rows:
            chunk = chunks_by_id.get(str(row[0]))
            if chunk is not None:
                scored.append((chunk, float(row[1] or 0.0)))
        return scored

    def _supports_pgvector_search(self) -> bool:
        """Return True when the session is bound to a PostgreSQL dialect.

        Only the dialect name is checked; whether the pgvector extension is
        actually installed is discovered when the search query runs.
        """
        return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"
- def _read_string(payload: dict[str, JSONValue], key: str) -> str:
- value = payload.get(key)
- return value if isinstance(value, str) else ""
- def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
- value = payload.get(key)
- return value if isinstance(value, str) else None
- def _read_int(payload: dict[str, JSONValue], key: str) -> int:
- value = payload.get(key)
- return value if isinstance(value, int) and not isinstance(value, bool) else 0
- def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
- value = payload.get(key)
- if not isinstance(value, list):
- return None
- return [float(item) for item in value if isinstance(item, (int, float))]
- def _format_vector(value: list[float] | None) -> str | None:
- if not value:
- return None
- return "[" + ",".join(str(float(item)) for item in value) + "]"
- def _read_optional_dict(
- payload: dict[str, JSONValue],
- key: str) -> dict[str, JSONValue] | None:
- value = payload.get(key)
- if isinstance(value, dict):
- return {str(item_key): item_value for item_key, item_value in value.items()}
- return None
|