浏览代码

feat: add memory hybrid retrieval

Jax Docker 1 月之前
父节点
当前提交
fd85c8ac52

+ 7 - 0
README.md

@@ -264,6 +264,13 @@ Through `api-gateway`, use `/gateway/agents/**`.
 
 `memory-service` stores scoped memories for tenants, users, sessions, agents, and teams. The first version uses database text search so it works without vector infrastructure; pgvector can be added later behind the same API.
 
+Memory search now stores a local, deterministic hash embedding with each memory and reranks candidates with a hybrid score. Each search result reports the following in `score_json`:
+
+- `keyword_score`: token overlap and frequency
+- `vector_score`: cosine similarity over local hash embeddings
+- `importance_score`: the memory's importance, which contributes a normalized boost to the final score
+- `rerank_mode`: `hybrid-local`
+
 Create a memory:
 
 ```powershell

+ 3 - 0
libs/core-domain/src/core_domain/memory_contracts.py

@@ -19,6 +19,8 @@ class MemoryItemContract(BaseModel):
     content_text: str
     content_json: dict[str, JSONValue] | None = None
     metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
+    embedding_model: str | None = None
+    embedding_json: list[float] | None = None
     owner_agent_id: str | None = None
     user_id: str | None = None
     session_id: str | None = None
@@ -60,3 +62,4 @@ class MemorySearchRequestContract(BaseModel):
 class MemorySearchResultContract(BaseModel):
     item: MemoryItemContract
     score: float
+    score_json: dict[str, JSONValue] = Field(default_factory=dict)

+ 34 - 0
services/memory-service/alembic/versions/20260425_0002_add_memory_embeddings.py

@@ -0,0 +1,34 @@
+"""add memory embeddings
+
+Revision ID: 20260425_0002
+Revises: 20260425_0001
+Create Date: 2026-04-25 17:40:00
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+
+revision: str = "20260425_0002"
+down_revision: str | None = "20260425_0001"
+branch_labels: Sequence[str] | None = None
+depends_on: Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    op.add_column("memory_item", sa.Column("embedding_model", sa.String(length=64), nullable=True))
+    op.add_column("memory_item", sa.Column("embedding_json", sa.JSON(), nullable=True))
+    op.create_index(
+        "ix_memory_item_embedding_model",
+        "memory_item",
+        ["embedding_model"],
+        unique=False,
+    )
+
+
+def downgrade() -> None:
+    op.drop_index("ix_memory_item_embedding_model", table_name="memory_item")
+    op.drop_column("memory_item", "embedding_json")
+    op.drop_column("memory_item", "embedding_model")

+ 2 - 2
services/memory-service/app/api/routes.py

@@ -64,8 +64,8 @@ def search_memories(
     service: MemoryApplicationService = Depends(get_memory_application_service),
 ) -> list[MemorySearchResultResponse]:
     return [
-        MemorySearchResultResponse.from_entity(item, score=score)
-        for item, score in service.search_memories(payload)
+        MemorySearchResultResponse.from_entity(item, score=score, score_json=score_json)
+        for item, score, score_json in service.search_memories(payload)
     ]
 
 

+ 59 - 0
services/memory-service/app/application/retrieval.py

@@ -0,0 +1,59 @@
+import hashlib
+import math
+import re
+from collections import Counter
+
+TOKEN_PATTERN = re.compile(r"[\w\u4e00-\u9fff]+", re.UNICODE)
+
+
+def tokenize(text: str) -> list[str]:
+    return [item.lower() for item in TOKEN_PATTERN.findall(text)]
+
+
+def build_hash_embedding(text: str, *, dimensions: int) -> list[float]:
+    vector = [0.0 for _ in range(dimensions)]
+    tokens = tokenize(text)
+    if not tokens:
+        return vector
+    for token in tokens:
+        digest = hashlib.sha256(token.encode("utf-8")).digest()
+        index = int.from_bytes(digest[:4], "big") % dimensions
+        sign = 1.0 if digest[4] % 2 == 0 else -1.0
+        vector[index] += sign
+    norm = math.sqrt(sum(item * item for item in vector))
+    if norm == 0:
+        return vector
+    return [round(item / norm, 6) for item in vector]
+
+
+def cosine_similarity(left: list[float] | None, right: list[float] | None) -> float:
+    if not left or not right or len(left) != len(right):
+        return 0.0
+    left_norm = math.sqrt(sum(item * item for item in left))
+    right_norm = math.sqrt(sum(item * item for item in right))
+    if left_norm == 0 or right_norm == 0:
+        return 0.0
+    return sum(a * b for a, b in zip(left, right, strict=True)) / (left_norm * right_norm)
+
+
+def keyword_score(query: str, text: str) -> float:
+    query_tokens = tokenize(query)
+    if not query_tokens:
+        return 0.0
+    text_counts = Counter(tokenize(text))
+    if not text_counts:
+        return 0.0
+    unique_query_tokens = set(query_tokens)
+    matched = sum(1 for token in unique_query_tokens if token in text_counts)
+    frequency = sum(text_counts.get(token, 0) for token in query_tokens)
+    return matched / len(unique_query_tokens) + min(frequency / 20.0, 1.0)
+
+
+def rerank_score(
+    *,
+    keyword: float,
+    vector: float,
+    importance_score: int,
+) -> float:
+    importance = min(max(importance_score, 0), 100) / 100
+    return round(keyword * 0.45 + vector * 0.45 + importance * 0.10, 6)

+ 58 - 12
services/memory-service/app/application/services.py

@@ -2,6 +2,13 @@ from datetime import datetime
 
 from core_domain import MemoryScopeType, MemoryStatus
 
+from app.application.retrieval import (
+    build_hash_embedding,
+    cosine_similarity,
+    keyword_score,
+    rerank_score,
+)
+from app.bootstrap.settings import MemoryServiceSettings
 from app.db.models import MemoryItem
 from app.domain.repositories import MemoryItemRepository
 from app.schemas.memory import (
@@ -12,10 +19,20 @@ from app.schemas.memory import (
 
 
 class MemoryApplicationService:
-    def __init__(self, *, memory_repository: MemoryItemRepository) -> None:
+    def __init__(
+        self,
+        *,
+        memory_repository: MemoryItemRepository,
+        settings: MemoryServiceSettings | None = None,
+    ) -> None:
+        # `settings` is optional: when it is omitted (or falsy) a default
+        # MemoryServiceSettings() is constructed below.
         self.memory_repository = memory_repository
+        self.settings = settings or MemoryServiceSettings()
 
     def create_memory(self, payload: MemoryCreateRequest) -> MemoryItem:
+        embedding_json = build_hash_embedding(
+            payload.content_text,
+            dimensions=self.settings.embedding_dimensions,
+        )
         return self.memory_repository.create(
             tenant_id=payload.tenant_id,
             scope_type=payload.scope_type,
@@ -24,6 +41,8 @@ class MemoryApplicationService:
             content_text=payload.content_text,
             content_json=payload.content_json,
             metadata_json=payload.metadata_json,
+            embedding_model=self.settings.embedding_model,
+            embedding_json=embedding_json,
             owner_agent_id=payload.owner_agent_id,
             user_id=payload.user_id,
             session_id=payload.session_id,
@@ -49,20 +68,32 @@ class MemoryApplicationService:
             limit=limit,
         )
 
-    def search_memories(self, payload: MemorySearchRequest) -> list[tuple[MemoryItem, float]]:
-        items = self.memory_repository.search(
+    def search_memories(
+        self,
+        payload: MemorySearchRequest,
+    ) -> list[tuple[MemoryItem, float, dict[str, float | str]]]:
+        """Return the best-scoring memories for ``payload.query``.
+
+        Candidates are fetched from the repository without a text filter,
+        scored locally with ``_score``, sorted by score (highest first) and
+        cut to ``payload.limit``.
+
+        Returns:
+            ``(item, score, score_json)`` tuples, best match first.
+        """
+        query_embedding = build_hash_embedding(
+            payload.query,
+            dimensions=self.settings.embedding_dimensions,
+        )
+        candidates = self.memory_repository.search_candidates(
             tenant_id=payload.tenant_id,
-            query=payload.query,
             scope_type=payload.scope_type,
             scope_id=payload.scope_id,
             owner_agent_id=payload.owner_agent_id,
             user_id=payload.user_id,
             session_id=payload.session_id,
-            limit=payload.limit,
+            # Over-fetch ten times the requested size so reranking has a wider pool.
+            # NOTE(review): matches outside this candidate window can never be returned.
+            limit=max(payload.limit * 10, payload.limit),
         )
+        scored_items = [
+            self._score(item=item, query=payload.query, query_embedding=query_embedding)
+            for item in candidates
+        ]
+        # Highest combined score first.
+        scored_items.sort(key=lambda item: item[1], reverse=True)
+        items = [item for item, _, _ in scored_items[: payload.limit]]
         now = datetime.utcnow()
+        # Only the returned items are passed to touch_many, not every scored candidate.
         self.memory_repository.touch_many(memory_ids=[item.id for item in items], accessed_time=now)
-        return [(item, self._score(item=item, query=payload.query)) for item in items]
+        return scored_items[: payload.limit]
 
     def update_memory_status(
         self,
@@ -76,9 +107,24 @@ class MemoryApplicationService:
             status=payload.status,
         )
 
-    def _score(self, *, item: MemoryItem, query: str) -> float:
-        lowered_content = item.content_text.lower()
-        lowered_query = query.lower()
-        exact_bonus = 1.0 if lowered_query in lowered_content else 0.0
-        importance_bonus = min(item.importance_score, 100) / 100
-        return round(exact_bonus + importance_bonus, 4)
+    def _score(
+        self,
+        *,
+        item: MemoryItem,
+        query: str,
+        query_embedding: list[float],
+    ) -> tuple[MemoryItem, float, dict[str, float | str]]:
+        keyword = keyword_score(query, item.content_text)
+        vector = cosine_similarity(query_embedding, item.embedding_json)
+        score = rerank_score(
+            keyword=keyword,
+            vector=vector,
+            importance_score=item.importance_score,
+        )
+        return item, score, {
+            "keyword_score": round(keyword, 6),
+            "vector_score": round(vector, 6),
+            "importance_score": float(item.importance_score),
+            "embedding_model": item.embedding_model or self.settings.embedding_model,
+            "rerank_mode": "hybrid-local",
+        }

+ 2 - 0
services/memory-service/app/bootstrap/settings.py

@@ -6,3 +6,5 @@ class MemoryServiceSettings(ServiceSettings):
     service_port: int = 8008
     database_url: str = "sqlite:///./memory_service.db"
     default_search_limit: int = 8
+    embedding_dimensions: int = 32
+    embedding_model: str = "local-hash-v1"

+ 2 - 0
services/memory-service/app/db/models/memory_item.py

@@ -17,6 +17,8 @@ class MemoryItem(TenantMixin, AuditMixin, VersionMixin, Base):
     content_text: Mapped[str] = mapped_column(Text)
     content_json: Mapped[dict[str, JSONValue] | None] = mapped_column(JSON, nullable=True)
     metadata_json: Mapped[dict[str, JSONValue]] = mapped_column(JSON, default=dict)
+    embedding_model: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
+    embedding_json: Mapped[list[float] | None] = mapped_column(JSON, nullable=True)
     owner_agent_id: Mapped[str | None] = mapped_column(String(36), nullable=True, index=True)
     user_id: Mapped[str | None] = mapped_column(String(36), nullable=True, index=True)
     session_id: Mapped[str | None] = mapped_column(String(36), nullable=True, index=True)

+ 9 - 7
services/memory-service/app/domain/repositories.py

@@ -23,6 +23,8 @@ class MemoryItemRepository:
         content_text: str,
         content_json: dict[str, JSONValue] | None,
         metadata_json: dict[str, JSONValue],
+        embedding_model: str | None,
+        embedding_json: list[float] | None,
         owner_agent_id: str | None,
         user_id: str | None,
         session_id: str | None,
@@ -38,6 +40,8 @@ class MemoryItemRepository:
             content_text=content_text,
             content_json=content_json,
             metadata_json=metadata_json,
+            embedding_model=embedding_model,
+            embedding_json=embedding_json,
             owner_agent_id=owner_agent_id,
             user_id=user_id,
             session_id=session_id,
@@ -70,11 +74,10 @@ class MemoryItemRepository:
         stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
         return list(self.db.scalars(stmt))
 
-    def search(
+    def search_candidates(
         self,
         *,
         tenant_id: str,
-        query: str,
         scope_type: MemoryScopeType | None,
         scope_id: str | None,
         owner_agent_id: str | None,
@@ -83,13 +86,11 @@ class MemoryItemRepository:
         limit: int,
     ) -> list[MemoryItem]:
         now = datetime.utcnow()
-        pattern = f"%{query}%"
         stmt = (
             select(MemoryItem)
             .where(MemoryItem.tenant_id == tenant_id)
             .where(MemoryItem.status == "active")
             .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
-            .where(MemoryItem.content_text.like(pattern))
         )
         if scope_type is not None:
             stmt = stmt.where(MemoryItem.scope_type == scope_type)
@@ -101,9 +102,10 @@ class MemoryItemRepository:
             stmt = stmt.where(MemoryItem.user_id == user_id)
         if session_id is not None:
             stmt = stmt.where(MemoryItem.session_id == session_id)
-        stmt = stmt.order_by(MemoryItem.importance_score.desc(), MemoryItem.created_time.desc()).limit(
-            limit
-        )
+        stmt = stmt.order_by(
+            MemoryItem.importance_score.desc(),
+            MemoryItem.created_time.desc(),
+        ).limit(limit)
         return list(self.db.scalars(stmt))
 
     def get_by_id(self, *, tenant_id: str, memory_id: str) -> MemoryItem | None:

+ 12 - 2
services/memory-service/app/schemas/memory.py

@@ -9,6 +9,7 @@ from core_domain import (
     MemorySearchResultContract,
     MemoryStatus,
 )
+from core_shared import JSONValue
 
 if TYPE_CHECKING:
     from app.db.models import MemoryItem
@@ -37,5 +38,14 @@ class MemorySearchResultResponse(MemorySearchResultContract):
     item: MemoryResponse
 
     @classmethod
-    def from_entity(cls, entity: "MemoryItem", score: float) -> "MemorySearchResultResponse":
-        return cls(item=MemoryResponse.from_entity(entity), score=score)
+    def from_entity(
+        cls,
+        entity: "MemoryItem",
+        score: float,
+        score_json: dict[str, JSONValue],
+    ) -> "MemorySearchResultResponse":
+        return cls(
+            item=MemoryResponse.from_entity(entity),
+            score=score,
+            score_json=score_json,
+        )