repositories.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
from datetime import datetime, timezone

from sqlalchemy import delete, select
from sqlalchemy.orm import Session

from core_domain import KnowledgeBaseStatus, KnowledgeDocumentStatus
from core_shared import JSONValue
from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
  7. class KnowledgeBaseRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. code: str,
  15. name: str,
  16. description: str | None,
  17. metadata_json: dict[str, JSONValue] | None,
  18. ) -> KnowledgeBase:
  19. entity = KnowledgeBase(
  20. tenant_id=tenant_id,
  21. code=code,
  22. name=name,
  23. description=description,
  24. metadata_json=metadata_json,
  25. )
  26. self.db.add(entity)
  27. self.db.commit()
  28. self.db.refresh(entity)
  29. return entity
  30. def list_by_tenant(self, *, tenant_id: str) -> list[KnowledgeBase]:
  31. stmt = (
  32. select(KnowledgeBase)
  33. .where(KnowledgeBase.tenant_id == tenant_id)
  34. .order_by(KnowledgeBase.created_time.desc())
  35. )
  36. return list(self.db.scalars(stmt))
  37. def get_by_id(self, *, tenant_id: str, knowledge_base_id: str) -> KnowledgeBase | None:
  38. stmt = (
  39. select(KnowledgeBase)
  40. .where(KnowledgeBase.tenant_id == tenant_id)
  41. .where(KnowledgeBase.id == knowledge_base_id)
  42. )
  43. return self.db.scalar(stmt)
  44. def update_status(
  45. self,
  46. *,
  47. tenant_id: str,
  48. knowledge_base_id: str,
  49. status: KnowledgeBaseStatus,
  50. ) -> KnowledgeBase | None:
  51. entity = self.get_by_id(tenant_id=tenant_id, knowledge_base_id=knowledge_base_id)
  52. if entity is None:
  53. return None
  54. entity.status = status
  55. self.db.commit()
  56. self.db.refresh(entity)
  57. return entity
  58. class KnowledgeDocumentRepository:
  59. def __init__(self, db: Session) -> None:
  60. self.db = db
  61. def create(
  62. self,
  63. *,
  64. tenant_id: str,
  65. knowledge_base_id: str,
  66. title: str,
  67. source_type: str,
  68. source_uri: str | None,
  69. content_text: str,
  70. content_hash: str | None,
  71. metadata_json: dict[str, JSONValue] | None,
  72. ) -> KnowledgeDocument:
  73. entity = KnowledgeDocument(
  74. tenant_id=tenant_id,
  75. knowledge_base_id=knowledge_base_id,
  76. title=title,
  77. source_type=source_type,
  78. source_uri=source_uri,
  79. content_text=content_text,
  80. content_hash=content_hash,
  81. metadata_json=metadata_json,
  82. status="draft",
  83. )
  84. self.db.add(entity)
  85. self.db.commit()
  86. self.db.refresh(entity)
  87. return entity
  88. def list_by_base(
  89. self,
  90. *,
  91. tenant_id: str,
  92. knowledge_base_id: str,
  93. ) -> list[KnowledgeDocument]:
  94. stmt = (
  95. select(KnowledgeDocument)
  96. .where(KnowledgeDocument.tenant_id == tenant_id)
  97. .where(KnowledgeDocument.knowledge_base_id == knowledge_base_id)
  98. .order_by(KnowledgeDocument.created_time.desc())
  99. )
  100. return list(self.db.scalars(stmt))
  101. def get_by_id(self, *, tenant_id: str, document_id: str) -> KnowledgeDocument | None:
  102. stmt = (
  103. select(KnowledgeDocument)
  104. .where(KnowledgeDocument.tenant_id == tenant_id)
  105. .where(KnowledgeDocument.id == document_id)
  106. )
  107. return self.db.scalar(stmt)
  108. def update_status(
  109. self,
  110. *,
  111. document_id: str,
  112. status: KnowledgeDocumentStatus,
  113. ) -> KnowledgeDocument | None:
  114. entity = self.db.get(KnowledgeDocument, document_id)
  115. if entity is None:
  116. return None
  117. entity.status = status
  118. entity.indexed_time = datetime.utcnow() if status == "indexed" else entity.indexed_time
  119. self.db.commit()
  120. self.db.refresh(entity)
  121. return entity
  122. class KnowledgeChunkRepository:
  123. def __init__(self, db: Session) -> None:
  124. self.db = db
  125. def replace_document_chunks(
  126. self,
  127. *,
  128. tenant_id: str,
  129. knowledge_base_id: str,
  130. document_id: str,
  131. chunks: list[dict[str, JSONValue]],
  132. ) -> list[KnowledgeChunk]:
  133. self.db.execute(
  134. delete(KnowledgeChunk)
  135. .where(KnowledgeChunk.tenant_id == tenant_id)
  136. .where(KnowledgeChunk.document_id == document_id)
  137. )
  138. entities: list[KnowledgeChunk] = []
  139. for chunk in chunks:
  140. entity = KnowledgeChunk(
  141. tenant_id=tenant_id,
  142. knowledge_base_id=knowledge_base_id,
  143. document_id=document_id,
  144. chunk_index=_read_int(chunk, "chunk_index"),
  145. content_text=_read_string(chunk, "content_text"),
  146. token_count=_read_int(chunk, "token_count"),
  147. embedding_model=_read_optional_string(chunk, "embedding_model"),
  148. embedding_json=_read_float_list(chunk, "embedding_json"),
  149. metadata_json=_read_optional_dict(chunk, "metadata_json"),
  150. )
  151. self.db.add(entity)
  152. entities.append(entity)
  153. self.db.commit()
  154. for entity in entities:
  155. self.db.refresh(entity)
  156. return entities
  157. def list_by_base(
  158. self,
  159. *,
  160. tenant_id: str,
  161. knowledge_base_id: str,
  162. ) -> list[KnowledgeChunk]:
  163. stmt = (
  164. select(KnowledgeChunk)
  165. .where(KnowledgeChunk.tenant_id == tenant_id)
  166. .where(KnowledgeChunk.knowledge_base_id == knowledge_base_id)
  167. .order_by(KnowledgeChunk.created_time.asc())
  168. )
  169. return list(self.db.scalars(stmt))
  170. def _read_string(payload: dict[str, JSONValue], key: str) -> str:
  171. value = payload.get(key)
  172. return value if isinstance(value, str) else ""
  173. def _read_optional_string(payload: dict[str, JSONValue], key: str) -> str | None:
  174. value = payload.get(key)
  175. return value if isinstance(value, str) else None
  176. def _read_int(payload: dict[str, JSONValue], key: str) -> int:
  177. value = payload.get(key)
  178. return value if isinstance(value, int) and not isinstance(value, bool) else 0
  179. def _read_float_list(payload: dict[str, JSONValue], key: str) -> list[float] | None:
  180. value = payload.get(key)
  181. if not isinstance(value, list):
  182. return None
  183. return [float(item) for item in value if isinstance(item, (int, float))]
  184. def _read_optional_dict(
  185. payload: dict[str, JSONValue],
  186. key: str,
  187. ) -> dict[str, JSONValue] | None:
  188. value = payload.get(key)
  189. if isinstance(value, dict):
  190. return {str(item_key): item_value for item_key, item_value in value.items()}
  191. return None