| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279 |
- from datetime import datetime
- from sqlalchemy import delete, select, text
- from sqlalchemy.orm import Session
- from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
- from core_shared import JSONValue
- from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
class KnowledgeBaseRepository:
    """Data-access helper for ``KnowledgeBase`` rows.

    All reads are filtered by ``tenant_id``. Write methods commit the
    session themselves and refresh the affected entity before returning it.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        code: str,
        name: str,
        description: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> KnowledgeBase:
        """Insert a new knowledge base, commit, and return the refreshed row."""
        fields = {
            "tenant_id": tenant_id,
            "code": code,
            "name": name,
            "description": description,
            "metadata_json": metadata_json,
        }
        record = KnowledgeBase(**fields)
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return record

    def list_by_tenant(self, *, tenant_id: str) -> list[KnowledgeBase]:
        """Return the tenant's knowledge bases ordered by ``created_time`` descending."""
        query = select(KnowledgeBase).where(KnowledgeBase.tenant_id == tenant_id)
        query = query.order_by(KnowledgeBase.created_time.desc())
        return list(self.db.scalars(query))

    def get_by_id(self, *, tenant_id: str, knowledge_base_id: str) -> KnowledgeBase | None:
        """Return the knowledge base with this id for the tenant, or ``None``."""
        query = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
            KnowledgeBase.id == knowledge_base_id,
        )
        return self.db.scalar(query)

    def update_status(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        status: KnowledgeBaseStatus,
    ) -> KnowledgeBase | None:
        """Set ``status`` on the tenant's knowledge base and commit.

        Returns the refreshed row, or ``None`` (without committing) when no
        matching knowledge base exists.
        """
        record = self.get_by_id(tenant_id=tenant_id, knowledge_base_id=knowledge_base_id)
        if record is not None:
            record.status = status
            self.db.commit()
            self.db.refresh(record)
        return record
class KnowledgeDocumentRepository:
    """Data-access helper for ``KnowledgeDocument`` rows.

    Write methods commit the session themselves and refresh the affected
    entity before returning it.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        title: str,
        source_type: str,
        source_uri: str | None,
        content_text: str,
        content_hash: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> KnowledgeDocument:
        """Insert a new document with status ``"draft"``, commit, and return it."""
        entity = KnowledgeDocument(
            tenant_id=tenant_id,
            knowledge_base_id=knowledge_base_id,
            title=title,
            source_type=source_type,
            source_uri=source_uri,
            content_text=content_text,
            content_hash=content_hash,
            metadata_json=metadata_json,
            status="draft",
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_base(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
    ) -> list[KnowledgeDocument]:
        """Return the documents of one knowledge base, newest ``created_time`` first."""
        stmt = (
            select(KnowledgeDocument)
            .where(KnowledgeDocument.tenant_id == tenant_id)
            .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeDocument.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, document_id: str) -> KnowledgeDocument | None:
        """Return the document with this id for the tenant, or ``None``."""
        stmt = (
            select(KnowledgeDocument)
            .where(KnowledgeDocument.tenant_id == tenant_id)
            .where(KnowledgeDocument.id == document_id)
        )
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        document_id: str,
        status: KnowledgeDocumentStatus,
        tenant_id: str | None = None,
    ) -> KnowledgeDocument | None:
        """Set ``status`` on a document and commit.

        When ``status`` equals ``"indexed"``, ``indexed_time`` is set to the
        current UTC time (stored as a naive datetime); otherwise it is left
        untouched.

        Args:
            document_id: Primary key of the document to update.
            status: New status value.
            tenant_id: Optional tenant guard. When given, a document owned by
                a different tenant is treated as not found. ``None`` (the
                default) keeps the historical behaviour of looking the
                document up by primary key only.

        Returns:
            The refreshed document, or ``None`` (without committing) when the
            document does not exist or fails the tenant guard.
        """
        # Local import: the module-level import only brings in ``datetime``.
        from datetime import timezone

        entity = self.db.get(KnowledgeDocument, document_id)
        if entity is None:
            return None
        if tenant_id is not None and entity.tenant_id != tenant_id:
            # Never modify another tenant's document.
            return None

        entity.status = status
        if status == "indexed":
            # Same naive-UTC value as the deprecated ``datetime.utcnow()``.
            entity.indexed_time = datetime.now(timezone.utc).replace(tzinfo=None)
        self.db.commit()
        self.db.refresh(entity)
        return entity
class KnowledgeChunkRepository:
    """Data-access helper for ``KnowledgeChunk`` rows, including vector search."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def replace_document_chunks(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        document_id: str,
        chunks: list[dict[str, JSONValue]],
    ) -> list[KnowledgeChunk]:
        """Replace all chunks of a document with ``chunks`` and commit.

        Existing chunks matching ``tenant_id`` and ``document_id`` are deleted
        and the new ones inserted before a single commit.

        Args:
            tenant_id: Owner of the chunks.
            knowledge_base_id: Knowledge base the new chunks are attached to.
            document_id: Document whose chunks are replaced.
            chunks: One payload dict per chunk. The keys read are
                ``chunk_index``, ``content_text``, ``token_count``,
                ``embedding_model``, ``embedding_json`` and ``metadata_json``;
                missing or mistyped values fall back to the defaults of the
                module-level ``_read_*`` helpers.

        Returns:
            The newly created chunk entities, refreshed after the commit.
        """
        self.db.execute(
            delete(KnowledgeChunk)
            .where(KnowledgeChunk.tenant_id == tenant_id)
            .where(KnowledgeChunk.document_id == document_id)
        )
        entities: list[KnowledgeChunk] = []
        for chunk in chunks:
            # Parse the embedding once; it feeds both the JSON column and the
            # textual vector representation.
            embedding = _read_float_list(chunk, "embedding_json")
            entity = KnowledgeChunk(
                tenant_id=tenant_id,
                knowledge_base_id=knowledge_base_id,
                document_id=document_id,
                chunk_index=_read_int(chunk, "chunk_index"),
                content_text=_read_string(chunk, "content_text"),
                token_count=_read_int(chunk, "token_count"),
                embedding_model=_read_optional_string(chunk, "embedding_model"),
                embedding_json=embedding,
                embedding_vector=_format_vector(embedding),
                metadata_json=_read_optional_dict(chunk, "metadata_json"),
            )
            self.db.add(entity)
            entities.append(entity)
        self.db.commit()
        for entity in entities:
            self.db.refresh(entity)
        return entities

    def list_by_base(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
    ) -> list[KnowledgeChunk]:
        """Return the chunks of one knowledge base, oldest ``created_time`` first."""
        stmt = (
            select(KnowledgeChunk)
            .where(KnowledgeChunk.tenant_id == tenant_id)
            .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeChunk.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def search_by_vector(
        self,
        *,
        tenant_id: str,
        knowledge_base_id: str,
        embedding: list[float],
        limit: int,
    ) -> list[tuple[KnowledgeChunk, float]]:
        """Return the chunks nearest to ``embedding`` with a similarity score.

        The search runs as raw SQL using the ``<=>`` distance operator
        (cosine distance in pgvector); the score is ``1 - distance``.

        Args:
            tenant_id: Tenant whose chunks are searched.
            knowledge_base_id: Knowledge base to search within.
            embedding: Query vector.
            limit: Maximum number of results; non-positive values yield ``[]``.

        Returns:
            ``(chunk, score)`` pairs ordered from nearest to farthest. Empty
            when the session is not bound to PostgreSQL, the embedding is
            empty, ``limit`` is not positive, or nothing matches.
        """
        # A negative LIMIT is rejected by PostgreSQL and zero can never match,
        # so short-circuit before touching the database.
        if limit <= 0 or not self._supports_pgvector_search():
            return []
        vector = _format_vector(embedding)
        if vector is None:
            return []
        stmt = text(
            """
            SELECT id, 1 - (embedding_vector <=> CAST(:embedding AS vector)) AS score
            FROM knowledge_chunk
            WHERE tenant_id = :tenant_id
            AND knowledge_base_id = :knowledge_base_id
            AND embedding_vector IS NOT NULL
            ORDER BY embedding_vector <=> CAST(:embedding AS vector)
            LIMIT :limit
            """
        )
        rows = self.db.execute(
            stmt,
            {
                "tenant_id": tenant_id,
                "knowledge_base_id": knowledge_base_id,
                "embedding": vector,
                "limit": limit,
            },
        ).all()
        if not rows:
            return []
        chunk_ids = [str(row[0]) for row in rows]
        # Normalise keys with ``str`` so they match the lookup below even if
        # the mapped ``id`` attribute is not a plain string.
        chunks_by_id = {
            str(chunk.id): chunk
            for chunk in self.db.scalars(
                select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
            )
        }
        scored: list[tuple[KnowledgeChunk, float]] = []
        # Iterate the raw rows to preserve the distance ordering from SQL.
        for row in rows:
            chunk = chunks_by_id.get(str(row[0]))
            if chunk is not None:
                scored.append((chunk, float(row[1] or 0.0)))
        return scored

    def _supports_pgvector_search(self) -> bool:
        """Return whether the session is bound to a PostgreSQL dialect."""
        return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"
- def _read_string(payload: dict[str, JSONValue], key: str) -> str:
- value = payload.get(key)
- return value if isinstance(value, str) else ""
- def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
- value = payload.get(key)
- return value if isinstance(value, str) else None
- def _read_int(payload: dict[str, JSONValue], key: str) -> int:
- value = payload.get(key)
- return value if isinstance(value, int) and not isinstance(value, bool) else 0
- def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
- value = payload.get(key)
- if not isinstance(value, list):
- return None
- return [float(item) for item in value if isinstance(item, (int, float))]
- def _format_vector(value: list[float] | None) -> str | None:
- if not value:
- return None
- return "[" + ",".join(str(float(item)) for item in value) + "]"
- def _read_optional_dict(
- payload: dict[str, JSONValue],
- key: str,
- ) -> dict[str, JSONValue] | None:
- value = payload.get(key)
- if isinstance(value, dict):
- return {str(item_key): item_value for item_key, item_value in value.items()}
- return None
|