|
|
@@ -1,6 +1,6 @@
|
|
|
from datetime import datetime
|
|
|
|
|
|
-from sqlalchemy import delete, select
|
|
|
+from sqlalchemy import delete, select, text
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
|
|
|
@@ -164,6 +164,7 @@ class KnowledgeChunkRepository:
|
|
|
token_count=_read_int(chunk, "token_count"),
|
|
|
embedding_model=_read_optional_string(chunk, "embedding_model"),
|
|
|
embedding_json=_read_float_list(chunk, "embedding_json"),
|
|
|
+ embedding_vector=_format_vector(_read_float_list(chunk, "embedding_json")),
|
|
|
metadata_json=_read_optional_dict(chunk, "metadata_json"),
|
|
|
)
|
|
|
self.db.add(entity)
|
|
|
@@ -187,6 +188,58 @@ class KnowledgeChunkRepository:
|
|
|
)
|
|
|
return list(self.db.scalars(stmt))
|
|
|
|
|
|
+ def search_by_vector(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ tenant_id: str,
|
|
|
+ knowledge_base_id: str,
|
|
|
+ embedding: list[float],
|
|
|
+ limit: int,
|
|
|
+ ) -> list[tuple[KnowledgeChunk, float]]:
|
|
|
+ if not self._supports_pgvector_search():
|
|
|
+ return []
|
|
|
+ vector = _format_vector(embedding)
|
|
|
+ if vector is None:
|
|
|
+ return []
|
|
|
+ stmt = text(
|
|
|
+ """
|
|
|
+ SELECT id, 1 - (embedding_vector <=> CAST(:embedding AS vector)) AS score
|
|
|
+ FROM knowledge_chunk
|
|
|
+ WHERE tenant_id = :tenant_id
|
|
|
+ AND knowledge_base_id = :knowledge_base_id
|
|
|
+ AND embedding_vector IS NOT NULL
|
|
|
+ ORDER BY embedding_vector <=> CAST(:embedding AS vector)
|
|
|
+ LIMIT :limit
|
|
|
+ """
|
|
|
+ )
|
|
|
+ rows = self.db.execute(
|
|
|
+ stmt,
|
|
|
+ {
|
|
|
+ "tenant_id": tenant_id,
|
|
|
+ "knowledge_base_id": knowledge_base_id,
|
|
|
+ "embedding": vector,
|
|
|
+ "limit": limit,
|
|
|
+ },
|
|
|
+ ).all()
|
|
|
+ if not rows:
|
|
|
+ return []
|
|
|
+ chunk_ids = [str(row[0]) for row in rows]
|
|
|
+ chunks_by_id = {
|
|
|
+ chunk.id: chunk
|
|
|
+ for chunk in self.db.scalars(
|
|
|
+ select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
|
|
|
+ )
|
|
|
+ }
|
|
|
+ scored: list[tuple[KnowledgeChunk, float]] = []
|
|
|
+ for row in rows:
|
|
|
+ chunk = chunks_by_id.get(str(row[0]))
|
|
|
+ if chunk is not None:
|
|
|
+ scored.append((chunk, float(row[1] or 0.0)))
|
|
|
+ return scored
|
|
|
+
|
|
|
+ def _supports_pgvector_search(self) -> bool:
|
|
|
+ return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"
|
|
|
+
|
|
|
|
|
|
def _read_string(payload: dict[str, JSONValue], key: str) -> str:
|
|
|
value = payload.get(key)
|
|
|
@@ -210,6 +263,12 @@ def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | N
|
|
|
return [float(item) for item in value if isinstance(item, (int, float))]
|
|
|
|
|
|
|
|
|
+def _format_vector(value: list[float] | None) -> str | None:
|
|
|
+ if not value:
|
|
|
+ return None
|
|
|
+ return "[" + ",".join(str(float(item)) for item in value) + "]"
|
|
|
+
|
|
|
+
|
|
|
def _read_optional_dict(
|
|
|
payload: dict[str, JSONValue],
|
|
|
key: str,
|