repositories.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
from datetime import datetime, timezone

from sqlalchemy import delete, select, text
from sqlalchemy.orm import Session

from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
from core_shared import JSONValue
from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
  7. class KnowledgeBaseRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. code: str,
  14. name: str,
  15. description: str | None,
  16. metadata_json: dict[str, JSONValue] | None) -> KnowledgeBase:
  17. entity = KnowledgeBase(
  18. code=code,
  19. name=name,
  20. description=description,
  21. metadata_json=metadata_json)
  22. self.db.add(entity)
  23. self.db.commit()
  24. self.db.refresh(entity)
  25. return entity
  26. def list_all(self) -> list[KnowledgeBase]:
  27. stmt = (
  28. select(KnowledgeBase)
  29. .order_by(KnowledgeBase.created_time.desc())
  30. )
  31. return list(self.db.scalars(stmt))
  32. def get_by_id(self, *, knowledge_base_id: str) -> KnowledgeBase | None:
  33. stmt = (
  34. select(KnowledgeBase)
  35. .where(KnowledgeBase.id == knowledge_base_id)
  36. )
  37. return self.db.scalar(stmt)
  38. def update_status(
  39. self,
  40. *,
  41. knowledge_base_id: str,
  42. status: KnowledgeBaseStatus) -> KnowledgeBase | None:
  43. entity = self.get_by_id(knowledge_base_id=knowledge_base_id)
  44. if entity is None:
  45. return None
  46. entity.status = status
  47. self.db.commit()
  48. self.db.refresh(entity)
  49. return entity
  50. class KnowledgeDocumentRepository:
  51. def __init__(self, db: Session) -> None:
  52. self.db = db
  53. def create(
  54. self,
  55. *,
  56. knowledge_base_id: str,
  57. title: str,
  58. source_type: str,
  59. source_uri: str | None,
  60. content_text: str,
  61. content_hash: str | None,
  62. metadata_json: dict[str, JSONValue] | None) -> KnowledgeDocument:
  63. entity = KnowledgeDocument(
  64. knowledge_base_id=knowledge_base_id,
  65. title=title,
  66. source_type=source_type,
  67. source_uri=source_uri,
  68. content_text=content_text,
  69. content_hash=content_hash,
  70. metadata_json=metadata_json,
  71. status="draft")
  72. self.db.add(entity)
  73. self.db.commit()
  74. self.db.refresh(entity)
  75. return entity
  76. def list_by_base(
  77. self,
  78. *,
  79. knowledge_base_id: str) -> list[KnowledgeDocument]:
  80. stmt = (
  81. select(KnowledgeDocument)
  82. .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
  83. .order_by(KnowledgeDocument.created_time.desc())
  84. )
  85. return list(self.db.scalars(stmt))
  86. def get_by_id(self, *, document_id: str) -> KnowledgeDocument | None:
  87. stmt = (
  88. select(KnowledgeDocument)
  89. .where(KnowledgeDocument.id == document_id)
  90. )
  91. return self.db.scalar(stmt)
  92. def update_status(
  93. self,
  94. *,
  95. document_id: str,
  96. status: KnowledgeDocumentStatus) -> KnowledgeDocument | None:
  97. entity = self.db.get(KnowledgeDocument, document_id)
  98. if entity is None:
  99. return None
  100. entity.status = status
  101. entity.indexed_time = datetime.utcnow() if status == "indexed" else entity.indexed_time
  102. self.db.commit()
  103. self.db.refresh(entity)
  104. return entity
  105. class KnowledgeChunkRepository:
  106. def __init__(self, db: Session) -> None:
  107. self.db = db
  108. def replace_document_chunks(
  109. self,
  110. *,
  111. knowledge_base_id: str,
  112. document_id: str,
  113. chunks: list[dict[str, JSONValue]]) -> list[KnowledgeChunk]:
  114. self.db.execute(
  115. delete(KnowledgeChunk)
  116. .where(KnowledgeChunk.document_id == document_id)
  117. )
  118. entities: list[KnowledgeChunk] = []
  119. for chunk in chunks:
  120. entity = KnowledgeChunk(
  121. knowledge_base_id=knowledge_base_id,
  122. document_id=document_id,
  123. chunk_index=_read_int(chunk, "chunk_index"),
  124. content_text=_read_string(chunk, "content_text"),
  125. token_count=_read_int(chunk, "token_count"),
  126. embedding_model=_read_optional_string(chunk, "embedding_model"),
  127. embedding_json=_read_float_list(chunk, "embedding_json"),
  128. embedding_vector=_format_vector(_read_float_list(chunk, "embedding_json")),
  129. metadata_json=_read_optional_dict(chunk, "metadata_json"))
  130. self.db.add(entity)
  131. entities.append(entity)
  132. self.db.commit()
  133. for entity in entities:
  134. self.db.refresh(entity)
  135. return entities
  136. def list_by_base(
  137. self,
  138. *,
  139. knowledge_base_id: str) -> list[KnowledgeChunk]:
  140. stmt = (
  141. select(KnowledgeChunk)
  142. .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
  143. .order_by(KnowledgeChunk.created_time.asc())
  144. )
  145. return list(self.db.scalars(stmt))
  146. def search_by_vector(
  147. self,
  148. *,
  149. knowledge_base_id: str,
  150. embedding: list[float],
  151. limit: int) -> list[tuple[KnowledgeChunk, float]]:
  152. if not self._supports_pgvector_search():
  153. return []
  154. vector = _format_vector(embedding)
  155. if vector is None:
  156. return []
  157. stmt = text(
  158. """
  159. SELECT id, 1 - (embedding_vector <=> CAST(:embedding AS vector)) AS score
  160. FROM knowledge_chunk
  161. WHERE knowledge_base_id = :knowledge_base_id
  162. AND embedding_vector IS NOT NULL
  163. ORDER BY embedding_vector <=> CAST(:embedding AS vector)
  164. LIMIT :limit
  165. """
  166. )
  167. rows = self.db.execute(
  168. stmt,
  169. {
  170. "knowledge_base_id": knowledge_base_id,
  171. "embedding": vector,
  172. "limit": limit,
  173. }).all()
  174. if not rows:
  175. return []
  176. chunk_ids = [str(row[0]) for row in rows]
  177. chunks_by_id = {
  178. chunk.id: chunk
  179. for chunk in self.db.scalars(
  180. select(KnowledgeChunk).where(KnowledgeChunk.id.in_(chunk_ids))
  181. )
  182. }
  183. scored: list[tuple[KnowledgeChunk, float]] = []
  184. for row in rows:
  185. chunk = chunks_by_id.get(str(row[0]))
  186. if chunk is not None:
  187. scored.append((chunk, float(row[1] or 0.0)))
  188. return scored
  189. def _supports_pgvector_search(self) -> bool:
  190. return self.db.bind is not None and self.db.bind.dialect.name == "postgresql"
  191. def _read_string(payload: dict[str, JSONValue], key: str) -> str:
  192. value = payload.get(key)
  193. return value if isinstance(value, str) else ""
  194. def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
  195. value = payload.get(key)
  196. return value if isinstance(value, str) else None
  197. def _read_int(payload: dict[str, JSONValue], key: str) -> int:
  198. value = payload.get(key)
  199. return value if isinstance(value, int) and not isinstance(value, bool) else 0
  200. def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
  201. value = payload.get(key)
  202. if not isinstance(value, list):
  203. return None
  204. return [float(item) for item in value if isinstance(item, (int, float))]
  205. def _format_vector(value: list[float] | None) -> str | None:
  206. if not value:
  207. return None
  208. return "[" + ",".join(str(float(item)) for item in value) + "]"
  209. def _read_optional_dict(
  210. payload: dict[str, JSONValue],
  211. key: str) -> dict[str, JSONValue] | None:
  212. value = payload.get(key)
  213. if isinstance(value, dict):
  214. return {str(item_key): item_value for item_key, item_value in value.items()}
  215. return None