repositories.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. from datetime import datetime
  2. from sqlalchemy import delete, or_, select, text
  3. from sqlalchemy.orm import Session
  4. from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
  5. from core_shared import JSONValue
  6. from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
  class KnowledgeBaseRepository:
      """Persistence helpers for ``KnowledgeBase`` rows on a SQLAlchemy session.

      Each mutating method (``create``, ``update_status``, ``update``,
      ``delete``) commits on the session passed to the constructor.
      """

      def __init__(self, db: Session) -> None:
          """Keep the session that every query of this repository runs on."""
          self.db = db

      @staticmethod
      def _like_pattern(keyword: str | None) -> str | None:
          """Build an escaped ``%keyword%`` pattern for ``ilike``.

          Returns ``None`` when *keyword* is ``None``, empty, or only
          whitespace, so callers can skip the filter entirely. LIKE
          metacharacters typed by the user (``%``, ``_``) and the escape
          character itself are backslash-escaped so they match literally;
          the pattern must be used together with ``escape="\\"``.
          """
          needle = keyword.strip() if keyword else ""
          if not needle:
              return None
          # Escape the backslash first so the escapes added below stay intact.
          escaped = (
              needle.replace("\\", "\\\\")
              .replace("%", "\\%")
              .replace("_", "\\_")
          )
          return f"%{escaped}%"

      def create(
          self,
          *,
          code: str,
          name: str,
          description: str | None,
          metadata_json: dict[str, JSONValue] | None,
      ) -> KnowledgeBase:
          """Insert a new knowledge base, commit, and return the refreshed row."""
          entity = KnowledgeBase(
              code=code,
              name=name,
              description=description,
              metadata_json=metadata_json,
          )
          self.db.add(entity)
          self.db.commit()
          # Reload so values assigned during the flush/commit are populated.
          self.db.refresh(entity)
          return entity

      def list_all(self) -> list[KnowledgeBase]:
          """Return every knowledge base, newest ``created_time`` first."""
          stmt = select(KnowledgeBase).order_by(KnowledgeBase.created_time.desc())
          return list(self.db.scalars(stmt))

      def list_filtered(
          self,
          *,
          keyword: str | None = None,
          status: KnowledgeBaseStatus | None = None,
      ) -> list[KnowledgeBase]:
          """Return knowledge bases matching the optional filters, newest first.

          Args:
              keyword: Case-insensitive substring matched against ``name`` or
                  ``description``. Blank or whitespace-only values are ignored.
              status: When given, only rows with exactly this status.
          """
          stmt = select(KnowledgeBase)
          if status is not None:
              stmt = stmt.where(KnowledgeBase.status == status)
          pattern = self._like_pattern(keyword)
          if pattern is not None:
              stmt = stmt.where(
                  or_(
                      KnowledgeBase.name.ilike(pattern, escape="\\"),
                      KnowledgeBase.description.ilike(pattern, escape="\\"),
                  )
              )
          stmt = stmt.order_by(KnowledgeBase.created_time.desc())
          return list(self.db.scalars(stmt))

      def get_by_id(self, *, knowledge_base_id: str) -> KnowledgeBase | None:
          """Return the knowledge base with this id, or ``None`` if absent."""
          stmt = select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id)
          return self.db.scalar(stmt)

      def update_status(
          self,
          *,
          knowledge_base_id: str,
          status: KnowledgeBaseStatus,
      ) -> KnowledgeBase | None:
          """Set the status of one knowledge base and commit.

          Returns the refreshed row, or ``None`` when the id does not exist
          (in which case nothing is committed).
          """
          entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
          if entity is None:
              return None
          entity.status = status
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def update(
          self,
          *,
          knowledge_base_id: str,
          name: str | None = None,
          description: str | None = None,
          status: KnowledgeBaseStatus | None = None,
          metadata_json: dict[str, JSONValue] | None = None,
      ) -> KnowledgeBase | None:
          """Apply a partial update and commit.

          Arguments left as ``None`` are not applied, so this method cannot
          reset a field to ``None``. Returns the refreshed row, or ``None``
          when the id does not exist.
          """
          entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
          if entity is None:
              return None
          if name is not None:
              entity.name = name
          if description is not None:
              entity.description = description
          if status is not None:
              entity.status = status
          if metadata_json is not None:
              entity.metadata_json = metadata_json
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def delete(self, *, knowledge_base_id: str) -> bool:
          """Delete one knowledge base and commit.

          Returns ``True`` when a row was deleted, ``False`` when the id does
          not exist.
          """
          entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
          if entity is None:
              return False
          self.db.delete(entity)
          self.db.commit()
          return True
  class KnowledgeDocumentRepository:
      """Persistence helpers for ``KnowledgeDocument`` rows on a SQLAlchemy session.

      Each mutating method (``create``, ``update_status``, ``update``,
      ``delete``) commits on the session passed to the constructor.
      """

      def __init__(self, db: Session) -> None:
          """Keep the session that every query of this repository runs on."""
          self.db = db

      @staticmethod
      def _like_pattern(keyword: str | None) -> str | None:
          """Build an escaped ``%keyword%`` pattern for ``ilike``.

          Returns ``None`` when *keyword* is ``None``, empty, or only
          whitespace, so callers can skip the filter entirely. LIKE
          metacharacters typed by the user (``%``, ``_``) and the escape
          character itself are backslash-escaped so they match literally;
          the pattern must be used together with ``escape="\\"``.
          """
          needle = keyword.strip() if keyword else ""
          if not needle:
              return None
          # Escape the backslash first so the escapes added below stay intact.
          escaped = (
              needle.replace("\\", "\\\\")
              .replace("%", "\\%")
              .replace("_", "\\_")
          )
          return f"%{escaped}%"

      @staticmethod
      def _utc_now() -> datetime:
          """Return the current UTC time as a naive ``datetime``.

          Produces the same value ``datetime.utcnow()`` did, without calling
          that API, which is deprecated since Python 3.12.
          """
          from datetime import timezone

          return datetime.now(timezone.utc).replace(tzinfo=None)

      @staticmethod
      def _apply_status(
          entity: KnowledgeDocument,
          status: KnowledgeDocumentStatus,
      ) -> None:
          """Set ``entity.status``; stamp ``indexed_time`` when it is "indexed".

          For any other status ``indexed_time`` is left as it was.
          """
          entity.status = status
          if status == "indexed":
              entity.indexed_time = KnowledgeDocumentRepository._utc_now()

      def create(
          self,
          *,
          knowledge_base_id: str,
          title: str,
          source_type: str,
          source_uri: str | None,
          content_text: str,
          content_hash: str | None,
          metadata_json: dict[str, JSONValue] | None,
          status: KnowledgeDocumentStatus = "draft",
      ) -> KnowledgeDocument:
          """Insert a new document, commit, and return the refreshed row.

          ``status`` defaults to ``"draft"``.
          """
          entity = KnowledgeDocument(
              knowledge_base_id=knowledge_base_id,
              title=title,
              source_type=source_type,
              source_uri=source_uri,
              content_text=content_text,
              content_hash=content_hash,
              metadata_json=metadata_json,
              status=status,
          )
          self.db.add(entity)
          self.db.commit()
          # Reload so values assigned during the flush/commit are populated.
          self.db.refresh(entity)
          return entity

      def list_by_base(
          self,
          *,
          knowledge_base_id: str,
      ) -> list[KnowledgeDocument]:
          """Return the documents of one knowledge base, newest first."""
          stmt = (
              select(KnowledgeDocument)
              .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
              .order_by(KnowledgeDocument.created_time.desc())
          )
          return list(self.db.scalars(stmt))

      def list_filtered(
          self,
          *,
          knowledge_base_id: str | None = None,
          keyword: str | None = None,
          status: KnowledgeDocumentStatus | None = None,
          source_type: str | None = None,
      ) -> list[KnowledgeDocument]:
          """Return documents matching the optional filters, newest first.

          Args:
              knowledge_base_id: When given, only documents of this base.
              keyword: Case-insensitive substring matched against ``title`` or
                  ``source_uri``. Blank or whitespace-only values are ignored.
              status: When given, only rows with exactly this status.
              source_type: When given, only rows with exactly this source type.
          """
          stmt = select(KnowledgeDocument)
          if knowledge_base_id is not None:
              stmt = stmt.where(
                  KnowledgeDocument.knowledge_base_id == knowledge_base_id
              )
          if status is not None:
              stmt = stmt.where(KnowledgeDocument.status == status)
          if source_type is not None:
              stmt = stmt.where(KnowledgeDocument.source_type == source_type)
          pattern = self._like_pattern(keyword)
          if pattern is not None:
              stmt = stmt.where(
                  or_(
                      KnowledgeDocument.title.ilike(pattern, escape="\\"),
                      KnowledgeDocument.source_uri.ilike(pattern, escape="\\"),
                  )
              )
          stmt = stmt.order_by(KnowledgeDocument.created_time.desc())
          return list(self.db.scalars(stmt))

      def get_next_pending_indexing(
          self,
          *,
          stale_before: datetime,
      ) -> KnowledgeDocument | None:
          """Return the next document that should be (re)indexed, if any.

          A document qualifies when its status is ``"queued"``, or when it is
          ``"indexing"`` and its ``updated_time`` is earlier than
          *stale_before*. Among the candidates the one with the oldest
          ``updated_time`` is returned.

          NOTE(review): the query takes no row lock, so two concurrent callers
          can receive the same document.
          """
          is_queued = KnowledgeDocument.status == "queued"
          is_stale_indexing = (KnowledgeDocument.status == "indexing") & (
              KnowledgeDocument.updated_time < stale_before
          )
          stmt = (
              select(KnowledgeDocument)
              .where(or_(is_queued, is_stale_indexing))
              .order_by(KnowledgeDocument.updated_time.asc())
              .limit(1)
          )
          return self.db.scalar(stmt)

      def get_by_id(self, *, document_id: str) -> KnowledgeDocument | None:
          """Return the document with this id, or ``None`` if absent."""
          stmt = select(KnowledgeDocument).where(KnowledgeDocument.id == document_id)
          return self.db.scalar(stmt)

      def update_status(
          self,
          *,
          document_id: str,
          status: KnowledgeDocumentStatus,
      ) -> KnowledgeDocument | None:
          """Set the status of one document and commit.

          ``indexed_time`` is set to the current UTC time when the new status
          is ``"indexed"``. Returns the refreshed row, or ``None`` when the id
          does not exist.
          """
          entity = self.db.get(KnowledgeDocument, document_id)
          if entity is None:
              return None
          self._apply_status(entity, status)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def update(
          self,
          *,
          document_id: str,
          title: str | None = None,
          source_uri: str | None = None,
          status: KnowledgeDocumentStatus | None = None,
          metadata_json: dict[str, JSONValue] | None = None,
      ) -> KnowledgeDocument | None:
          """Apply a partial update and commit.

          Arguments left as ``None`` are not applied, so this method cannot
          reset a field to ``None``. A status of ``"indexed"`` also stamps
          ``indexed_time``. Returns the refreshed row, or ``None`` when the id
          does not exist.
          """
          entity = self.get_by_id(document_id=document_id)
          if entity is None:
              return None
          if title is not None:
              entity.title = title
          if source_uri is not None:
              entity.source_uri = source_uri
          if status is not None:
              self._apply_status(entity, status)
          if metadata_json is not None:
              entity.metadata_json = metadata_json
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def delete(self, *, document_id: str) -> KnowledgeDocument | None:
          """Delete one document and commit.

          Returns the deleted entity object, or ``None`` when the id does not
          exist.
          """
          entity = self.get_by_id(document_id=document_id)
          if entity is None:
              return None
          self.db.delete(entity)
          self.db.commit()
          return entity
  class KnowledgeChunkRepository:
      """Persistence and vector-search helpers for ``KnowledgeChunk`` rows.

      Each mutating method commits on the session passed to the constructor.
      """

      def __init__(self, db: Session) -> None:
          """Keep the session that every query of this repository runs on."""
          self.db = db

      @staticmethod
      def _like_pattern(keyword: str | None) -> str | None:
          """Build an escaped ``%keyword%`` pattern for ``ilike``.

          Returns ``None`` when *keyword* is ``None``, empty, or only
          whitespace, so callers can skip the filter entirely. LIKE
          metacharacters typed by the user (``%``, ``_``) and the escape
          character itself are backslash-escaped so they match literally;
          the pattern must be used together with ``escape="\\"``.
          """
          needle = keyword.strip() if keyword else ""
          if not needle:
              return None
          # Escape the backslash first so the escapes added below stay intact.
          escaped = (
              needle.replace("\\", "\\\\")
              .replace("%", "\\%")
              .replace("_", "\\_")
          )
          return f"%{escaped}%"

      def replace_document_chunks(
          self,
          *,
          knowledge_base_id: str,
          document_id: str,
          chunks: list[dict[str, JSONValue]],
      ) -> list[KnowledgeChunk]:
          """Replace every chunk of a document with *chunks* and commit.

          The existing rows for *document_id* are deleted and the new ones
          added before a single commit. Each chunk dict is read through the
          module's ``_read_*`` helpers, so missing or mistyped fields fall
          back to those helpers' defaults instead of raising.

          Returns the newly created entities, refreshed after the commit.
          """
          self.db.execute(
              delete(KnowledgeChunk).where(KnowledgeChunk.document_id == document_id)
          )
          entities: list[KnowledgeChunk] = []
          for chunk in chunks:
              # Parse the embedding once; it feeds both the JSON copy and the
              # textual vector column.
              embedding = _read_float_list(chunk, "embedding_json")
              entity = KnowledgeChunk(
                  knowledge_base_id=knowledge_base_id,
                  document_id=document_id,
                  chunk_index=_read_int(chunk, "chunk_index"),
                  content_text=_read_string(chunk, "content_text"),
                  token_count=_read_int(chunk, "token_count"),
                  embedding_model=_read_optional_string(chunk, "embedding_model"),
                  embedding_json=embedding,
                  embedding_vector=_format_vector(embedding),
                  metadata_json=_read_optional_dict(chunk, "metadata_json"),
              )
              self.db.add(entity)
              entities.append(entity)
          self.db.commit()
          for entity in entities:
              self.db.refresh(entity)
          return entities

      def list_by_base(
          self,
          *,
          knowledge_base_id: str,
      ) -> list[KnowledgeChunk]:
          """Return the chunks of one knowledge base, oldest first."""
          stmt = (
              select(KnowledgeChunk)
              .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
              .order_by(KnowledgeChunk.created_time.asc())
          )
          return list(self.db.scalars(stmt))

      def list_filtered(
          self,
          *,
          knowledge_base_id: str | None = None,
          document_id: str | None = None,
          keyword: str | None = None,
      ) -> list[KnowledgeChunk]:
          """Return chunks matching the optional filters, oldest first.

          Args:
              knowledge_base_id: When given, only chunks of this base.
              document_id: When given, only chunks of this document.
              keyword: Case-insensitive substring matched against
                  ``content_text``. Blank or whitespace-only values are
                  ignored.
          """
          stmt = select(KnowledgeChunk)
          if knowledge_base_id is not None:
              stmt = stmt.where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
          if document_id is not None:
              stmt = stmt.where(KnowledgeChunk.document_id == document_id)
          pattern = self._like_pattern(keyword)
          if pattern is not None:
              stmt = stmt.where(KnowledgeChunk.content_text.ilike(pattern, escape="\\"))
          stmt = stmt.order_by(KnowledgeChunk.created_time.asc())
          return list(self.db.scalars(stmt))

      def list_by_document(self, *, document_id: str) -> list[KnowledgeChunk]:
          """Return the chunks of one document, oldest first."""
          return self.list_filtered(document_id=document_id)

      def get_by_id(self, *, chunk_id: str) -> KnowledgeChunk | None:
          """Return the chunk with this id, or ``None`` if absent."""
          stmt = select(KnowledgeChunk).where(KnowledgeChunk.id == chunk_id)
          return self.db.scalar(stmt)

      def delete_by_document(self, *, document_id: str) -> int:
          """Delete every chunk of a document, commit, and return the row count."""
          result = self.db.execute(
              delete(KnowledgeChunk).where(KnowledgeChunk.document_id == document_id)
          )
          self.db.commit()
          # A falsy rowcount (None or 0) is reported as 0.
          return int(result.rowcount or 0)

      def delete_by_base(self, *, knowledge_base_id: str) -> int:
          """Delete every chunk of a knowledge base, commit, and return the row count."""
          result = self.db.execute(
              delete(KnowledgeChunk).where(
                  KnowledgeChunk.knowledge_base_id == knowledge_base_id
              )
          )
          self.db.commit()
          return int(result.rowcount or 0)

      def delete(self, *, chunk_id: str) -> bool:
          """Delete one chunk and commit.

          Returns ``True`` when a row was deleted, ``False`` when the id does
          not exist.
          """
          entity = self.get_by_id(chunk_id=chunk_id)
          if entity is None:
              return False
          self.db.delete(entity)
          self.db.commit()
          return True

      def search_by_vector(
          self,
          *,
          knowledge_base_id: str,
          embedding: list[float],
          limit: int,
      ) -> list[tuple[KnowledgeChunk, float]]:
          """Return the chunks of a knowledge base nearest to *embedding*.

          The search is best effort: it returns ``[]`` when *limit* is not
          positive, when the session is not bound to PostgreSQL, when
          *embedding* is empty, or when the query fails (the failure is
          logged and the session rolled back).

          Returns:
              ``(chunk, score)`` pairs in the order produced by the query,
              i.e. ascending ``<=>`` distance. ``score`` is ``1 - distance``
              (presumably pgvector cosine distance, which would make it a
              cosine similarity - confirm against the installed extension).
          """
          import logging

          # A non-positive limit can never yield rows; answering here avoids a
          # database error on negative values and the rollback that follows it.
          if limit <= 0 or not self._supports_pgvector_search():
              return []
          vector = _format_vector(embedding)
          if vector is None:
              return []
          stmt = text(
              """
              SELECT id, 1 - (
                  embedding_vector OPERATOR(public.<=>) CAST(:embedding AS public.vector)
              ) AS score
              FROM knowledge_chunk
              WHERE knowledge_base_id = :knowledge_base_id
                AND embedding_vector IS NOT NULL
              ORDER BY embedding_vector OPERATOR(public.<=>) CAST(:embedding AS public.vector)
              LIMIT :limit
              """
          )
          params = {
              "knowledge_base_id": knowledge_base_id,
              "embedding": vector,
              "limit": limit,
          }
          try:
              rows = self.db.execute(stmt, params).all()
          except Exception:
              # Deliberately broad: callers get an empty result instead of an
              # error, but the failure is no longer invisible.
              logging.getLogger(__name__).exception(
                  "Vector search failed for knowledge base %s", knowledge_base_id
              )
              self.db.rollback()
              return []
          if not rows:
              return []
          chunk_ids = [str(row[0]) for row in rows]
          # Key by str(id) so the lookup below matches regardless of the Python
          # type the ORM returns for the primary key.
          chunks_by_id = {
              str(chunk.id): chunk
              for chunk in self.db.scalars(
                  select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
              )
          }
          scored: list[tuple[KnowledgeChunk, float]] = []
          for row in rows:
              chunk = chunks_by_id.get(str(row[0]))
              if chunk is not None:
                  # A NULL score is reported as 0.0.
                  scored.append((chunk, float(row[1] or 0.0)))
          return scored

      def _supports_pgvector_search(self) -> bool:
          """Return whether the session is bound to a PostgreSQL dialect."""
          bind = self.db.bind
          return bind is not None and bind.dialect.name == "postgresql"
  349. def _read_string(payload: dict[str, JSONValue], key: str) -> str:
  350. value = payload.get(key)
  351. return value if isinstance(value, str) else ""
  352. def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
  353. value = payload.get(key)
  354. return value if isinstance(value, str) else None
  355. def _read_int(payload: dict[str, JSONValue], key: str) -> int:
  356. value = payload.get(key)
  357. return value if isinstance(value, int) and not isinstance(value, bool) else 0
  358. def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
  359. value = payload.get(key)
  360. if not isinstance(value, list):
  361. return None
  362. return [float(item) for item in value if isinstance(item, (int, float))]
  363. def _format_vector(value: list[float] | None) -> str | None:
  364. if not value:
  365. return None
  366. return "[" + ",".join(str(float(item)) for item in value) + "]"
  367. def _read_optional_dict(
  368. payload: dict[str, JSONValue],
  369. key: str) -> dict[str, JSONValue] | None:
  370. value = payload.get(key)
  371. if isinstance(value, dict):
  372. return {str(item_key): item_value for item_key, item_value in value.items()}
  373. return None