repositories.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. from datetime import datetime
  2. from sqlalchemy import delete, select, text
  3. from sqlalchemy.orm import Session
  4. from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
  5. from core_shared import JSONValue
  6. from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
  7. class KnowledgeBaseRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. code: str,
  15. name: str,
  16. description: str | None,
  17. metadata_json: dict[str, JSONValue] | None,
  18. ) -> KnowledgeBase:
  19. entity = KnowledgeBase(
  20. tenant_id=tenant_id,
  21. code=code,
  22. name=name,
  23. description=description,
  24. metadata_json=metadata_json,
  25. )
  26. self.db.add(entity)
  27. self.db.commit()
  28. self.db.refresh(entity)
  29. return entity
  30. def list_by_tenant(self, *, tenant_id: str) -> list[KnowledgeBase]:
  31. stmt = (
  32. select(KnowledgeBase)
  33. .where(KnowledgeBase.tenant_id == tenant_id)
  34. .order_by(KnowledgeBase.created_time.desc())
  35. )
  36. return list(self.db.scalars(stmt))
  37. def get_by_id(self, *, tenant_id: str, knowledge_base_id: str) -> KnowledgeBase | None:
  38. stmt = (
  39. select(KnowledgeBase)
  40. .where(KnowledgeBase.tenant_id == tenant_id)
  41. .where(KnowledgeBase.id == knowledge_base_id)
  42. )
  43. return self.db.scalar(stmt)
  44. def update_status(
  45. self,
  46. *,
  47. tenant_id: str,
  48. knowledge_base_id: str,
  49. status: KnowledgeBaseStatus,
  50. ) -> KnowledgeBase | None:
  51. entity = self.get_by_id(tenant_id=tenant_id, knowledge_base_id=knowledge_base_id)
  52. if entity is None:
  53. return None
  54. entity.status = status
  55. self.db.commit()
  56. self.db.refresh(entity)
  57. return entity
  58. class KnowledgeDocumentRepository:
  59. def __init__(self, db: Session) -> None:
  60. self.db = db
  61. def create(
  62. self,
  63. *,
  64. tenant_id: str,
  65. knowledge_base_id: str,
  66. title: str,
  67. source_type: str,
  68. source_uri: str | None,
  69. content_text: str,
  70. content_hash: str | None,
  71. metadata_json: dict[str, JSONValue] | None,
  72. ) -> KnowledgeDocument:
  73. entity = KnowledgeDocument(
  74. tenant_id=tenant_id,
  75. knowledge_base_id=knowledge_base_id,
  76. title=title,
  77. source_type=source_type,
  78. source_uri=source_uri,
  79. content_text=content_text,
  80. content_hash=content_hash,
  81. metadata_json=metadata_json,
  82. status="draft",
  83. )
  84. self.db.add(entity)
  85. self.db.commit()
  86. self.db.refresh(entity)
  87. return entity
  88. def list_by_base(
  89. self,
  90. *,
  91. tenant_id: str,
  92. knowledge_base_id: str,
  93. ) -> list[KnowledgeDocument]:
  94. stmt = (
  95. select(KnowledgeDocument)
  96. .where(KnowledgeDocument.tenant_id == tenant_id)
  97. .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
  98. .order_by(KnowledgeDocument.created_time.desc())
  99. )
  100. return list(self.db.scalars(stmt))
  101. def get_by_id(self, *, tenant_id: str, document_id: str) -> KnowledgeDocument | None:
  102. stmt = (
  103. select(KnowledgeDocument)
  104. .where(KnowledgeDocument.tenant_id == tenant_id)
  105. .where(KnowledgeDocument.id == document_id)
  106. )
  107. return self.db.scalar(stmt)
  108. def update_status(
  109. self,
  110. *,
  111. document_id: str,
  112. status: KnowledgeDocumentStatus,
  113. ) -> KnowledgeDocument | None:
  114. entity = self.db.get(KnowledgeDocument, document_id)
  115. if entity is None:
  116. return None
  117. entity.status = status
  118. entity.indexed_time = datetime.utcnow() if status == "indexed" else entity.indexed_time
  119. self.db.commit()
  120. self.db.refresh(entity)
  121. return entity
  122. class KnowledgeChunkRepository:
  123. def __init__(self, db: Session) -> None:
  124. self.db = db
  125. def replace_document_chunks(
  126. self,
  127. *,
  128. tenant_id: str,
  129. knowledge_base_id: str,
  130. document_id: str,
  131. chunks: list[dict[str, JSONValue]],
  132. ) -> list[KnowledgeChunk]:
  133. self.db.execute(
  134. delete(KnowledgeChunk)
  135. .where(KnowledgeChunk.tenant_id == tenant_id)
  136. .where(KnowledgeChunk.document_id == document_id)
  137. )
  138. entities: list[KnowledgeChunk] = []
  139. for chunk in chunks:
  140. entity = KnowledgeChunk(
  141. tenant_id=tenant_id,
  142. knowledge_base_id=knowledge_base_id,
  143. document_id=document_id,
  144. chunk_index=_read_int(chunk, "chunk_index"),
  145. content_text=_read_string(chunk, "content_text"),
  146. token_count=_read_int(chunk, "token_count"),
  147. embedding_model=_read_optional_string(chunk, "embedding_model"),
  148. embedding_json=_read_float_list(chunk, "embedding_json"),
  149. embedding_vector=_format_vector(_read_float_list(chunk, "embedding_json")),
  150. metadata_json=_read_optional_dict(chunk, "metadata_json"),
  151. )
  152. self.db.add(entity)
  153. entities.append(entity)
  154. self.db.commit()
  155. for entity in entities:
  156. self.db.refresh(entity)
  157. return entities
  158. def list_by_base(
  159. self,
  160. *,
  161. tenant_id: str,
  162. knowledge_base_id: str,
  163. ) -> list[KnowledgeChunk]:
  164. stmt = (
  165. select(KnowledgeChunk)
  166. .where(KnowledgeChunk.tenant_id == tenant_id)
  167. .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
  168. .order_by(KnowledgeChunk.created_time.asc())
  169. )
  170. return list(self.db.scalars(stmt))
  171. def search_by_vector(
  172. self,
  173. *,
  174. tenant_id: str,
  175. knowledge_base_id: str,
  176. embedding: list[float],
  177. limit: int,
  178. ) -> list[tuple[KnowledgeChunk, float]]:
  179. if not self._supports_pgvector_search():
  180. return []
  181. vector = _format_vector(embedding)
  182. if vector is None:
  183. return []
  184. stmt = text(
  185. """
  186. SELECT id, 1 - (embedding_vector <=> CAST(:embedding AS vector)) AS score
  187. FROM knowledge_chunk
  188. WHERE tenant_id = :tenant_id
  189. AND knowledge_base_id = :knowledge_base_id
  190. AND embedding_vector IS NOT NULL
  191. ORDER BY embedding_vector <=> CAST(:embedding AS vector)
  192. LIMIT :limit
  193. """
  194. )
  195. rows = self.db.execute(
  196. stmt,
  197. {
  198. "tenant_id": tenant_id,
  199. "knowledge_base_id": knowledge_base_id,
  200. "embedding": vector,
  201. "limit": limit,
  202. },
  203. ).all()
  204. if not rows:
  205. return []
  206. chunk_ids = [str(row[0]) for row in rows]
  207. chunks_by_id = {
  208. chunk.id: chunk
  209. for chunk in self.db.scalars(
  210. select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
  211. )
  212. }
  213. scored: list[tuple[KnowledgeChunk, float]] = []
  214. for row in rows:
  215. chunk = chunks_by_id.get(str(row[0]))
  216. if chunk is not None:
  217. scored.append((chunk, float(row[1] or 0.0)))
  218. return scored
  219. def _supports_pgvector_search(self) -> bool:
  220. return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"
  221. def _read_string(payload: dict[str, JSONValue], key: str) -> str:
  222. value = payload.get(key)
  223. return value if isinstance(value, str) else ""
  224. def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
  225. value = payload.get(key)
  226. return value if isinstance(value, str) else None
  227. def _read_int(payload: dict[str, JSONValue], key: str) -> int:
  228. value = payload.get(key)
  229. return value if isinstance(value, int) and not isinstance(value, bool) else 0
  230. def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
  231. value = payload.get(key)
  232. if not isinstance(value, list):
  233. return None
  234. return [float(item) for item in value if isinstance(item, (int, float))]
  235. def _format_vector(value: list[float] | None) -> str | None:
  236. if not value:
  237. return None
  238. return "[" + ",".join(str(float(item)) for item in value) + "]"
  239. def _read_optional_dict(
  240. payload: dict[str, JSONValue],
  241. key: str,
  242. ) -> dict[str, JSONValue] | None:
  243. value = payload.get(key)
  244. if isinstance(value, dict):
  245. return {str(item_key): item_value for item_key, item_value in value.items()}
  246. return None