Explorar o código

feat: add encrypted tool credentials

Jax Docker hai 1 mes
pai
achega
e70785ce6e

+ 1 - 0
deployments/docker/.env.example

@@ -28,6 +28,7 @@ AGENT_PLATFORM_AUTH_REQUIRED=false
 AGENT_PLATFORM_AUTHZ_REQUIRED=false
 AGENT_PLATFORM_INTERNAL_SERVICE_AUTH_REQUIRED=false
 AGENT_PLATFORM_INTERNAL_SERVICE_TOKEN=replace-with-shared-internal-token
+AGENT_PLATFORM_CREDENTIAL_ENCRYPTION_KEY=replace-with-strong-credential-encryption-key
 AGENT_PLATFORM_RATE_LIMIT_ENABLED=false
 AGENT_PLATFORM_TENANT_RATE_LIMIT_PER_MINUTE=600
 AGENT_PLATFORM_API_KEY_RATE_LIMIT_PER_MINUTE=1200

+ 1 - 0
deployments/docker/docker-compose.yml

@@ -1,6 +1,7 @@
 x-agent-platform-common-env: &agent-platform-common-env
   AGENT_PLATFORM_INTERNAL_SERVICE_AUTH_REQUIRED: ${AGENT_PLATFORM_INTERNAL_SERVICE_AUTH_REQUIRED:-false}
   AGENT_PLATFORM_INTERNAL_SERVICE_TOKEN: ${AGENT_PLATFORM_INTERNAL_SERVICE_TOKEN:-}
+  AGENT_PLATFORM_CREDENTIAL_ENCRYPTION_KEY: ${AGENT_PLATFORM_CREDENTIAL_ENCRYPTION_KEY:-local-development-credential-key}
 
 services:
   postgres:

+ 4 - 0
libs/core-domain/src/core_domain/__init__.py

@@ -99,6 +99,8 @@ from .team_contracts import (
 from .tool_contracts import (
     ToolBindingContract,
     ToolBindingDetailContract,
+    ToolCredentialContract,
+    ToolCredentialRevealContract,
     ToolDefinitionContract,
     ToolVersionContract,
 )
@@ -180,6 +182,8 @@ __all__ = [
     "TeamVersionStatus",
     "ToolBindingContract",
     "ToolBindingDetailContract",
+    "ToolCredentialContract",
+    "ToolCredentialRevealContract",
     "ToolDefinitionContract",
     "ToolVersionContract",
     "WorkflowRunStatus",

+ 17 - 1
libs/core-domain/src/core_domain/tool_contracts.py

@@ -1,6 +1,6 @@
 from datetime import datetime
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from core_shared import JSONValue
 
@@ -41,6 +41,22 @@ class ToolBindingContract(BaseModel):
     created_time: datetime
 
 
+class ToolCredentialContract(BaseModel):
+    id: str
+    tenant_id: str
+    name: str
+    credential_type: str
+    secret_fingerprint: str
+    encryption_algorithm: str
+    metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
+    created_time: datetime
+
+
+class ToolCredentialRevealContract(BaseModel):
+    credential: ToolCredentialContract
+    secret_json: dict[str, JSONValue] = Field(default_factory=dict)
+
+
 class ToolBindingDetailContract(BaseModel):
     binding: ToolBindingContract
     tool_version: ToolVersionContract

+ 1 - 0
libs/core-shared/src/core_shared/config.py

@@ -15,6 +15,7 @@ class ServiceSettings(BaseSettings):
     internal_service_token: str | None = Field(default=None)
     internal_service_token_header_name: str = Field(default="x-internal-service-token")
     internal_service_name_header_name: str = Field(default="x-internal-service-name")
+    credential_encryption_key: str = Field(default="local-development-credential-key")
 
     model_config = SettingsConfigDict(
         env_prefix="AGENT_PLATFORM_",

+ 108 - 0
libs/core-shared/src/core_shared/secrets.py

@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import base64
+import hashlib
+import hmac
+import json
+from dataclasses import dataclass
+
+from core_shared.types import JSONValue
+
+
+@dataclass(frozen=True, slots=True)
+class EncryptedSecret:
+    ciphertext: str
+    fingerprint: str
+    algorithm: str
+
+
+class SecretCipher:
+    def __init__(self, *, key: str) -> None:
+        if not key:
+            raise ValueError("secret cipher key is required")
+        self._key = key.encode("utf-8")
+
+    def encrypt_json(self, payload: dict[str, JSONValue]) -> EncryptedSecret:
+        plaintext = json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8")
+        fingerprint = hashlib.sha256(plaintext).hexdigest()
+        fernet_ciphertext = self._try_encrypt_with_fernet(plaintext)
+        if fernet_ciphertext is not None:
+            return EncryptedSecret(
+                ciphertext=fernet_ciphertext,
+                fingerprint=fingerprint,
+                algorithm="fernet-sha256",
+            )
+        return EncryptedSecret(
+            ciphertext=self._encrypt_with_hmac_stream(plaintext),
+            fingerprint=fingerprint,
+            algorithm="hmac-stream-sha256",
+        )
+
+    def decrypt_json(self, encrypted: EncryptedSecret) -> dict[str, JSONValue]:
+        if encrypted.algorithm == "fernet-sha256":
+            plaintext = self._decrypt_with_fernet(encrypted.ciphertext)
+        else:
+            plaintext = self._decrypt_with_hmac_stream(encrypted.ciphertext)
+        fingerprint = hashlib.sha256(plaintext).hexdigest()
+        if not hmac.compare_digest(fingerprint, encrypted.fingerprint):
+            raise ValueError("secret fingerprint mismatch")
+        value = json.loads(plaintext.decode("utf-8"))
+        if not isinstance(value, dict):
+            raise ValueError("secret payload must be a JSON object")
+        return {str(item_key): item_value for item_key, item_value in value.items()}
+
+    def _try_encrypt_with_fernet(self, plaintext: bytes) -> str | None:
+        try:
+            from cryptography.fernet import Fernet
+        except Exception:
+            return None
+        return Fernet(self._fernet_key()).encrypt(plaintext).decode("utf-8")
+
+    def _decrypt_with_fernet(self, ciphertext: str) -> bytes:
+        try:
+            from cryptography.fernet import Fernet
+        except Exception as exc:
+            raise ValueError("cryptography is required to decrypt fernet secrets") from exc
+        return Fernet(self._fernet_key()).decrypt(ciphertext.encode("utf-8"))
+
+    def _fernet_key(self) -> bytes:
+        digest = hashlib.sha256(self._key).digest()
+        return base64.urlsafe_b64encode(digest)
+
+    def _encrypt_with_hmac_stream(self, plaintext: bytes) -> str:
+        nonce = hashlib.sha256(self._key + plaintext).digest()[:16]
+        stream = _build_stream(key=self._key, nonce=nonce, length=len(plaintext))
+        ciphertext = bytes(item ^ stream[index] for index, item in enumerate(plaintext))
+        signature = hmac.new(self._key, nonce + ciphertext, hashlib.sha256).digest()
+        return base64.urlsafe_b64encode(nonce + signature + ciphertext).decode("utf-8")
+
+    def _decrypt_with_hmac_stream(self, ciphertext: str) -> bytes:
+        try:
+            payload = base64.urlsafe_b64decode(ciphertext.encode("utf-8"))
+        except Exception as exc:
+            raise ValueError("invalid encrypted secret encoding") from exc
+        if len(payload) < 48:
+            raise ValueError("invalid encrypted secret payload")
+        nonce = payload[:16]
+        signature = payload[16:48]
+        encrypted_payload = payload[48:]
+        expected_signature = hmac.new(
+            self._key,
+            nonce + encrypted_payload,
+            hashlib.sha256,
+        ).digest()
+        if not hmac.compare_digest(signature, expected_signature):
+            raise ValueError("secret signature mismatch")
+        stream = _build_stream(key=self._key, nonce=nonce, length=len(encrypted_payload))
+        return bytes(item ^ stream[index] for index, item in enumerate(encrypted_payload))
+
+
+def _build_stream(*, key: bytes, nonce: bytes, length: int) -> bytes:
+    chunks: list[bytes] = []
+    counter = 0
+    while sum(len(chunk) for chunk in chunks) < length:
+        chunks.append(
+            hashlib.sha256(key + nonce + counter.to_bytes(8, "big")).digest()
+        )
+        counter += 1
+    return b"".join(chunks)[:length]

+ 50 - 0
services/tool-service/alembic/versions/20260427_0002_add_tool_credentials.py

@@ -0,0 +1,50 @@
+"""add tool credentials
+
+Revision ID: 20260427_0002
+Revises: 20260422_0001
+Create Date: 2026-04-27 00:00:00
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+
+revision: str = "20260427_0002"
+down_revision: str | None = "20260422_0001"
+branch_labels: Sequence[str] | None = None
+depends_on: Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "tool_credential",
+        sa.Column("name", sa.String(length=128), nullable=False),
+        sa.Column("credential_type", sa.String(length=64), nullable=False),
+        sa.Column("encrypted_payload_text", sa.Text(), nullable=False),
+        sa.Column("secret_fingerprint", sa.String(length=64), nullable=False),
+        sa.Column("encryption_algorithm", sa.String(length=64), nullable=False),
+        sa.Column("metadata_json", sa.JSON(), nullable=True),
+        sa.Column("id", sa.String(length=36), nullable=False),
+        sa.Column("tenant_id", sa.String(length=36), nullable=False),
+        sa.Column("created_by", sa.String(length=36), nullable=True),
+        sa.Column("updated_by", sa.String(length=36), nullable=True),
+        sa.Column("created_time", sa.DateTime(), nullable=False),
+        sa.Column("updated_time", sa.DateTime(), nullable=False),
+        sa.Column("deleted_time", sa.DateTime(), nullable=True),
+        sa.Column("version", sa.Integer(), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index("ix_tool_credential_tenant_id", "tool_credential", ["tenant_id"])
+    op.create_index(
+        "ix_tool_credential_secret_fingerprint",
+        "tool_credential",
+        ["secret_fingerprint"],
+    )
+
+
+def downgrade() -> None:
+    op.drop_index("ix_tool_credential_secret_fingerprint", table_name="tool_credential")
+    op.drop_index("ix_tool_credential_tenant_id", table_name="tool_credential")
+    op.drop_table("tool_credential")

+ 62 - 4
services/tool-service/app/api/routes.py

@@ -1,15 +1,25 @@
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from sqlalchemy import text
 from sqlalchemy.orm import Session
 
 from core_domain import ServiceHealth
+from core_shared.secrets import SecretCipher
 from app.application.services import ToolApplicationService
+from app.bootstrap.settings import ToolServiceSettings
 from app.db.session import get_db
-from app.domain.repositories import ToolBindingRepository, ToolDefinitionRepository, ToolVersionRepository
+from app.domain.repositories import (
+    ToolBindingRepository,
+    ToolCredentialRepository,
+    ToolDefinitionRepository,
+    ToolVersionRepository,
+)
 from app.schemas.tool import (
     ToolBindingCreateRequest,
     ToolBindingDetailResponse,
     ToolBindingResponse,
+    ToolCredentialCreateRequest,
+    ToolCredentialResponse,
+    ToolCredentialRevealResponse,
     ToolCreateRequest,
     ToolResponse,
     ToolVersionCreateRequest,
@@ -19,11 +29,17 @@ from app.schemas.tool import (
 router = APIRouter()
 
 
-def get_tool_application_service(db: Session = Depends(get_db)) -> ToolApplicationService:
+def get_tool_application_service(
+    request: Request,
+    db: Session = Depends(get_db),
+) -> ToolApplicationService:
+    settings: ToolServiceSettings = request.app.state.settings
     return ToolApplicationService(
         tool_definition_repository=ToolDefinitionRepository(db),
         tool_version_repository=ToolVersionRepository(db),
         tool_binding_repository=ToolBindingRepository(db),
+        tool_credential_repository=ToolCredentialRepository(db),
+        secret_cipher=SecretCipher(key=settings.credential_encryption_key),
     )
 
 
@@ -76,7 +92,10 @@ def create_tool_binding(
     payload: ToolBindingCreateRequest,
     service: ToolApplicationService = Depends(get_tool_application_service),
 ) -> ToolBindingResponse:
-    entity = service.create_tool_binding(payload)
+    try:
+        entity = service.create_tool_binding(payload)
+    except ValueError as exc:
+        raise HTTPException(status_code=422, detail=str(exc)) from exc
     return ToolBindingResponse.from_entity(entity)
 
 
@@ -108,3 +127,42 @@ def get_tool_binding_detail(
         tool_version=ToolVersionResponse.from_entity(tool_version),
         tool_definition=ToolResponse.from_entity(tool_definition),
     )
+
+
+@router.post("/credentials", response_model=ToolCredentialResponse)
+def create_tool_credential(
+    payload: ToolCredentialCreateRequest,
+    service: ToolApplicationService = Depends(get_tool_application_service),
+) -> ToolCredentialResponse:
+    entity = service.create_tool_credential(payload)
+    return ToolCredentialResponse.from_entity(entity)
+
+
+@router.get("/credentials", response_model=list[ToolCredentialResponse])
+def list_tool_credentials(
+    tenant_id: str = Query(...),
+    service: ToolApplicationService = Depends(get_tool_application_service),
+) -> list[ToolCredentialResponse]:
+    return [
+        ToolCredentialResponse.from_entity(item)
+        for item in service.list_tool_credentials(tenant_id)
+    ]
+
+
+@router.post("/credentials/{credential_id}/reveal", response_model=ToolCredentialRevealResponse)
+def reveal_tool_credential(
+    credential_id: str,
+    tenant_id: str = Query(...),
+    service: ToolApplicationService = Depends(get_tool_application_service),
+) -> ToolCredentialRevealResponse:
+    result = service.reveal_tool_credential(
+        tenant_id=tenant_id,
+        credential_id=credential_id,
+    )
+    if result is None:
+        raise HTTPException(status_code=404, detail=f"tool credential not found: {credential_id}")
+    credential, secret_json = result
+    return ToolCredentialRevealResponse(
+        credential=ToolCredentialResponse.from_entity(credential),
+        secret_json=secret_json,
+    )

+ 58 - 2
services/tool-service/app/application/services.py

@@ -1,7 +1,16 @@
-from app.db.models import ToolBinding, ToolDefinition, ToolVersion
-from app.domain.repositories import ToolBindingRepository, ToolDefinitionRepository, ToolVersionRepository
+from core_shared import JSONValue
+from core_shared.secrets import EncryptedSecret, SecretCipher
+
+from app.db.models import ToolBinding, ToolCredential, ToolDefinition, ToolVersion
+from app.domain.repositories import (
+    ToolBindingRepository,
+    ToolCredentialRepository,
+    ToolDefinitionRepository,
+    ToolVersionRepository,
+)
 from app.schemas.tool import (
     ToolBindingCreateRequest,
+    ToolCredentialCreateRequest,
     ToolCreateRequest,
     ToolVersionCreateRequest,
 )
@@ -13,10 +22,14 @@ class ToolApplicationService:
         tool_definition_repository: ToolDefinitionRepository,
         tool_version_repository: ToolVersionRepository,
         tool_binding_repository: ToolBindingRepository,
+        tool_credential_repository: ToolCredentialRepository,
+        secret_cipher: SecretCipher,
     ) -> None:
         self.tool_definition_repository = tool_definition_repository
         self.tool_version_repository = tool_version_repository
         self.tool_binding_repository = tool_binding_repository
+        self.tool_credential_repository = tool_credential_repository
+        self.secret_cipher = secret_cipher
 
     def create_tool_definition(self, payload: ToolCreateRequest) -> ToolDefinition:
         return self.tool_definition_repository.create(
@@ -46,6 +59,13 @@ class ToolApplicationService:
         return self.tool_version_repository.list_by_tool(tenant_id=tenant_id, tool_id=tool_id)
 
     def create_tool_binding(self, payload: ToolBindingCreateRequest) -> ToolBinding:
+        if payload.credential_id is not None:
+            credential = self.tool_credential_repository.get_by_id(
+                tenant_id=payload.tenant_id,
+                credential_id=payload.credential_id,
+            )
+            if credential is None:
+                raise ValueError(f"tool credential not found: {payload.credential_id}")
         return self.tool_binding_repository.create(
             tenant_id=payload.tenant_id,
             app_id=payload.app_id,
@@ -59,6 +79,42 @@ class ToolApplicationService:
     def list_tool_bindings(self, tenant_id: str, app_id: str | None = None) -> list[ToolBinding]:
         return self.tool_binding_repository.list_by_scope(tenant_id=tenant_id, app_id=app_id)
 
+    def create_tool_credential(self, payload: ToolCredentialCreateRequest) -> ToolCredential:
+        encrypted = self.secret_cipher.encrypt_json(payload.secret_json)
+        return self.tool_credential_repository.create(
+            tenant_id=payload.tenant_id,
+            name=payload.name,
+            credential_type=payload.credential_type,
+            encrypted_payload_text=encrypted.ciphertext,
+            secret_fingerprint=encrypted.fingerprint,
+            encryption_algorithm=encrypted.algorithm,
+            metadata_json=payload.metadata_json,
+        )
+
+    def list_tool_credentials(self, tenant_id: str) -> list[ToolCredential]:
+        return self.tool_credential_repository.list_by_tenant(tenant_id=tenant_id)
+
+    def reveal_tool_credential(
+        self,
+        *,
+        tenant_id: str,
+        credential_id: str,
+    ) -> tuple[ToolCredential, dict[str, JSONValue]] | None:
+        credential = self.tool_credential_repository.get_by_id(
+            tenant_id=tenant_id,
+            credential_id=credential_id,
+        )
+        if credential is None:
+            return None
+        secret_json = self.secret_cipher.decrypt_json(
+            EncryptedSecret(
+                ciphertext=credential.encrypted_payload_text,
+                fingerprint=credential.secret_fingerprint,
+                algorithm=credential.encryption_algorithm,
+            )
+        )
+        return credential, secret_json
+
     def get_tool_binding_detail(
         self,
         *,

+ 2 - 2
services/tool-service/app/db/models/__init__.py

@@ -1,8 +1,8 @@
 from core_db import Base
 
 from .tool_binding import ToolBinding
+from .tool_credential import ToolCredential
 from .tool_definition import ToolDefinition
 from .tool_version import ToolVersion
 
-__all__ = ["Base", "ToolBinding", "ToolDefinition", "ToolVersion"]
-
+__all__ = ["Base", "ToolBinding", "ToolCredential", "ToolDefinition", "ToolVersion"]

+ 17 - 0
services/tool-service/app/db/models/tool_credential.py

@@ -0,0 +1,17 @@
+from sqlalchemy import String, Text
+from sqlalchemy.dialects.sqlite import JSON
+from sqlalchemy.orm import Mapped, mapped_column
+
+from core_db import AuditMixin, Base, TenantMixin, VersionMixin
+from core_shared import JSONValue
+
+
+class ToolCredential(TenantMixin, AuditMixin, VersionMixin, Base):
+    __tablename__ = "tool_credential"
+
+    name: Mapped[str] = mapped_column(String(128))
+    credential_type: Mapped[str] = mapped_column(String(64), default="generic")
+    encrypted_payload_text: Mapped[str] = mapped_column(Text)
+    secret_fingerprint: Mapped[str] = mapped_column(String(64), index=True)
+    encryption_algorithm: Mapped[str] = mapped_column(String(64))
+    metadata_json: Mapped[dict[str, JSONValue]] = mapped_column(JSON, default=dict)

+ 47 - 1
services/tool-service/app/domain/repositories.py

@@ -1,7 +1,7 @@
 from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
-from app.db.models import ToolBinding, ToolDefinition, ToolVersion
+from app.db.models import ToolBinding, ToolCredential, ToolDefinition, ToolVersion
 from core_shared import JSONValue
 
 
@@ -145,3 +145,49 @@ class ToolBindingRepository:
             .where(ToolBinding.id == binding_id)
         )
         return self.db.scalar(stmt)
+
+
+class ToolCredentialRepository:
+    def __init__(self, db: Session) -> None:
+        self.db = db
+
+    def create(
+        self,
+        *,
+        tenant_id: str,
+        name: str,
+        credential_type: str,
+        encrypted_payload_text: str,
+        secret_fingerprint: str,
+        encryption_algorithm: str,
+        metadata_json: dict[str, JSONValue],
+    ) -> ToolCredential:
+        entity = ToolCredential(
+            tenant_id=tenant_id,
+            name=name,
+            credential_type=credential_type,
+            encrypted_payload_text=encrypted_payload_text,
+            secret_fingerprint=secret_fingerprint,
+            encryption_algorithm=encryption_algorithm,
+            metadata_json=metadata_json,
+        )
+        self.db.add(entity)
+        self.db.commit()
+        self.db.refresh(entity)
+        return entity
+
+    def list_by_tenant(self, *, tenant_id: str) -> list[ToolCredential]:
+        stmt = (
+            select(ToolCredential)
+            .where(ToolCredential.tenant_id == tenant_id)
+            .order_by(ToolCredential.created_time.desc())
+        )
+        return list(self.db.scalars(stmt))
+
+    def get_by_id(self, *, tenant_id: str, credential_id: str) -> ToolCredential | None:
+        stmt = (
+            select(ToolCredential)
+            .where(ToolCredential.tenant_id == tenant_id)
+            .where(ToolCredential.id == credential_id)
+        )
+        return self.db.scalar(stmt)

+ 22 - 1
services/tool-service/app/schemas/tool.py

@@ -5,13 +5,15 @@ from pydantic import BaseModel, Field
 from core_domain import (
     ToolBindingContract,
     ToolBindingDetailContract,
+    ToolCredentialContract,
+    ToolCredentialRevealContract,
     ToolDefinitionContract,
     ToolVersionContract,
 )
 from core_shared import JSONValue
 
 if TYPE_CHECKING:
-    from app.db.models import ToolBinding, ToolDefinition, ToolVersion
+    from app.db.models import ToolBinding, ToolCredential, ToolDefinition, ToolVersion
 
 
 class ToolCreateRequest(BaseModel):
@@ -68,3 +70,22 @@ class ToolBindingDetailResponse(ToolBindingDetailContract):
     binding: ToolBindingResponse
     tool_version: ToolVersionResponse
     tool_definition: ToolResponse
+
+
+class ToolCredentialCreateRequest(BaseModel):
+    tenant_id: str
+    name: str
+    credential_type: str = "generic"
+    secret_json: dict[str, JSONValue] = Field(default_factory=dict)
+    metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
+
+
+class ToolCredentialResponse(ToolCredentialContract):
+
+    @classmethod
+    def from_entity(cls, entity: "ToolCredential") -> "ToolCredentialResponse":
+        return cls.model_validate(entity, from_attributes=True)
+
+
+class ToolCredentialRevealResponse(ToolCredentialRevealContract):
+    credential: ToolCredentialResponse

+ 30 - 0
tests/test_secrets.py

@@ -0,0 +1,30 @@
+import pytest
+
+from core_shared.secrets import EncryptedSecret, SecretCipher
+
+
+def test_secret_cipher_round_trips_json_payload() -> None:
+    cipher = SecretCipher(key="test-key")
+    payload = {"api_key": "secret-value", "nested": {"token": "abc"}}
+
+    encrypted = cipher.encrypt_json(payload)
+    decrypted = cipher.decrypt_json(encrypted)
+
+    assert encrypted.ciphertext
+    assert encrypted.ciphertext != "secret-value"
+    assert encrypted.fingerprint
+    assert decrypted == payload
+
+
+def test_secret_cipher_rejects_tampered_payload() -> None:
+    cipher = SecretCipher(key="test-key")
+    encrypted = cipher.encrypt_json({"api_key": "secret-value"})
+
+    tampered = EncryptedSecret(
+        ciphertext=encrypted.ciphertext[:-2] + "AA",
+        fingerprint=encrypted.fingerprint,
+        algorithm=encrypted.algorithm,
+    )
+
+    with pytest.raises(ValueError):
+        cipher.decrypt_json(tampered)