conftest.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from __future__ import annotations
  2. import atexit
  3. import hashlib
  4. import os
  5. import sys
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Any
  9. from sqlalchemy import create_engine, text
  10. from sqlalchemy.engine import Engine, make_url
  # Parent of the directory that contains this conftest.py (i.e. the repository root
  # that holds the ``libs/`` and ``services/`` trees used below).
  REPO_ROOT: Path = Path(__file__).resolve().parents[1]

  # Fallback database used when neither AGENT_PLATFORM_TEST_DATABASE_URL nor
  # AGENT_PLATFORM_DATABASE_URL is set in the environment.
  # NOTE(review): this literal embeds a username/password and a remote host. Credentials
  # committed to source should be rotated and supplied via the environment instead.
  DEFAULT_TEST_DATABASE_URL = (
      "postgresql+psycopg://admin:hFOvG5UBeK5KIGhz5cQH@git.newpoint.work:5432/vectordb"
  )

  # Schema names created during this interpreter session; the exit hook drops them.
  _CREATED_TEST_SCHEMAS: set[str] = set()
  @dataclass(frozen=True)
  class ServiceImportConfig:
      """Immutable recipe describing how to make one service importable in tests.

      Attributes:
          service_name: Directory name under ``services/`` that gets placed on
              ``sys.path`` so its ``app`` package can be imported.
          libs: Directory names under ``libs/`` whose ``src`` folders are placed
              on ``sys.path`` before the service directory.
      """

      service_name: str
      libs: tuple[str, ...]
  # Registry of every service a test can import, keyed by service directory name.
  # Built from (name, libs) pairs so each key and its ``service_name`` are always equal.
  SERVICE_IMPORT_CONFIGS: dict[str, ServiceImportConfig] = {
      name: ServiceImportConfig(service_name=name, libs=libs)
      for name, libs in (
          ("agent-service", ("core-domain", "core-shared", "core-db", "core-events")),
          ("api-gateway", ("core-domain", "core-shared", "core-db")),
          ("auth-service", ("core-domain", "core-shared", "core-db")),
          ("event-service", ("core-domain", "core-shared", "core-db", "core-events")),
          ("human-service", ("core-domain", "core-shared", "core-db")),
          ("knowledge-service", ("core-domain", "core-shared", "core-db")),
          ("memory-service", ("core-domain", "core-shared", "core-db")),
          ("model-gateway-service", ("core-domain", "core-shared", "core-db")),
          ("scheduler-service", ("core-domain", "core-shared", "core-db")),
          ("session-service", ("core-domain", "core-shared", "core-db")),
          ("skill-service", ("core-domain", "core-shared", "core-db")),
          ("tool-service", ("core-domain", "core-shared", "core-db")),
          ("team-service", ("core-domain", "core-shared", "core-db", "core-events")),
      )
  }
  def prepare_service_import(
      service_name: str,
      *,
      libs: tuple[str, ...],
  ) -> None:
      """Reset import state so that ``import app`` resolves to *service_name*.

      Steps, in order:
        1. Evict every cached ``app`` / ``app.*`` module from ``sys.modules``.
        2. Reset the shared core-db SQLAlchemy metadata (no-op if not importable).
        3. Push ``libs/<lib>/src`` to the front of ``sys.path`` for each lib, in
           order (so the last lib listed ends up ahead of earlier ones).
        4. Push ``services/<service_name>`` to the very front of ``sys.path``.

      Args:
          service_name: Directory name under ``services/``.
          libs: Directory names under ``libs/`` whose ``src`` dirs are added.
      """
      # Snapshot the keys first so deletions never race with dict iteration.
      cached_names = list(sys.modules)
      for cached_name in cached_names:
          # "app" itself or any "app.<sub>" module, but not e.g. "application".
          if cached_name.partition(".")[0] == "app":
              del sys.modules[cached_name]

      _clear_shared_sqlalchemy_metadata()

      libs_root = REPO_ROOT / "libs"
      for lib in libs:
          _prepend_sys_path(libs_root / lib / "src")
      _prepend_sys_path(REPO_ROOT / "services" / service_name)
  def prepare_known_service_import(service_name: str) -> None:
      """Prepare imports for a service registered in ``SERVICE_IMPORT_CONFIGS``.

      Looks up the service's config and delegates to ``prepare_service_import``
      with its registered libs.

      Args:
          service_name: Key in ``SERVICE_IMPORT_CONFIGS`` (e.g. ``"auth-service"``).

      Raises:
          KeyError: If *service_name* is not registered. The message lists the
              known service names to make typos easy to spot.
      """
      try:
          config = SERVICE_IMPORT_CONFIGS[service_name]
      except KeyError:
          known = ", ".join(sorted(SERVICE_IMPORT_CONFIGS))
          raise KeyError(
              f"Unknown service {service_name!r}; known services: {known}"
          ) from None
      prepare_service_import(config.service_name, libs=config.libs)
  def build_postgres_database_url(tmp_path: Path, filename: str) -> str:
      """Create an isolated schema for one test and return a URL pinned to it.

      The schema name is derived deterministically from ``tmp_path / filename``.
      The schema (and the ``vector`` extension) is created on the base test
      database. The returned URL carries a libpq ``options`` query parameter that
      sets ``search_path`` to that schema.

      Any ``options`` already present on the base URL are preserved. The
      ``search_path`` setting is appended after them, so it takes precedence if
      the base options also set ``search_path``.

      NOTE(review): ``search_path`` contains only the test schema. Objects
      installed in other schemas (e.g. wherever ``CREATE EXTENSION vector``
      placed its types) are not visible unqualified. Confirm that is intended.

      Args:
          tmp_path: Per-test temporary directory (used only to derive the name).
          filename: Logical database file name (used only to derive the name).

      Returns:
          The URL rendered as a string with the password included, suitable for
          ``create_engine``.
      """
      schema_name = _build_test_schema_name(tmp_path=tmp_path, filename=filename)
      base_url = _base_test_database_url()
      _create_postgres_schema(base_url=base_url, schema_name=schema_name)

      url = make_url(base_url)
      query = dict(url.query)
      search_path_option = f"-csearch_path={schema_name}"

      existing_options = query.get("options")
      if isinstance(existing_options, (tuple, list)):
          # Repeated ``options`` keys come back as a sequence; libpq wants one string.
          existing_options = " ".join(existing_options)
      query["options"] = (
          f"{existing_options} {search_path_option}"
          if existing_options
          else search_path_option
      )
      return url.set(query=query).render_as_string(hide_password=False)
  def build_postgres_engine(database_url: str) -> Engine:
      """Return a SQLAlchemy Engine for *database_url* with ``pool_pre_ping`` on.

      Pre-ping makes the pool test each connection on checkout, which avoids
      handing out connections the server has already closed.
      """
      engine_options = {"pool_pre_ping": True}
      return create_engine(database_url, **engine_options)
  86. def _prepend_sys_path(path: Path) -> None:
  87. path_text = str(path)
  88. if path_text in sys.path:
  89. sys.path.remove(path_text)
  90. sys.path.insert(0, path_text)
  91. def _clear_shared_sqlalchemy_metadata() -> None:
  92. try:
  93. from core_db import Base
  94. except ImportError:
  95. return
  96. Base.registry.dispose()
  97. Base.metadata.clear()
  98. def _base_test_database_url() -> str:
  99. return (
  100. os.getenv("AGENT_PLATFORM_TEST_DATABASE_URL")
  101. or os.getenv("AGENT_PLATFORM_DATABASE_URL")
  102. or DEFAULT_TEST_DATABASE_URL
  103. )
  104. def _build_test_schema_name(*, tmp_path: Path, filename: str) -> str:
  105. digest = hashlib.sha1(str(tmp_path / filename).encode("utf-8")).hexdigest()[:16]
  106. return f"test_{digest}"
  def _create_postgres_schema(*, base_url: str, schema_name: str) -> None:
      """Ensure the ``vector`` extension and schema *schema_name* exist.

      Uses a short-lived AUTOCOMMIT engine on *base_url*, so each DDL statement
      commits immediately. On success, the schema is recorded in
      ``_CREATED_TEST_SCHEMAS`` so the exit hook can drop it. The engine is
      disposed even if a statement fails, so no pooled connection is leaked.

      Args:
          base_url: SQLAlchemy URL of the database to create the schema in.
          schema_name: Schema identifier. It is double-quoted, with embedded
              quotes escaped.
      """
      quoted_schema = '"' + schema_name.replace('"', '""') + '"'
      engine = create_engine(base_url, isolation_level="AUTOCOMMIT", pool_pre_ping=True)
      try:
          with engine.connect() as connection:
              connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
              connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}"))
              _CREATED_TEST_SCHEMAS.add(schema_name)
      finally:
          engine.dispose()
  @atexit.register
  def _drop_created_postgres_schemas() -> None:
      """Drop every schema recorded in ``_CREATED_TEST_SCHEMAS`` (runs at exit).

      Registered with ``atexit`` via the decorator, at the same point in module
      import as before. Each schema is dropped with ``CASCADE`` and removed from
      the registry once dropped, so calling this again is a no-op. The engine is
      disposed even if a drop fails; the error itself still propagates.

      NOTE(review): the URL is re-read from the environment at exit time. If it
      differs from the one used when the schemas were created, they will not be
      found on this database.
      """
      if not _CREATED_TEST_SCHEMAS:
          return
      engine = create_engine(_base_test_database_url(), isolation_level="AUTOCOMMIT")
      try:
          with engine.connect() as connection:
              # Iterate over a sorted copy so discarding from the set is safe.
              for schema_name in sorted(_CREATED_TEST_SCHEMAS):
                  quoted_schema = '"' + schema_name.replace('"', '""') + '"'
                  connection.execute(text(f"DROP SCHEMA IF EXISTS {quoted_schema} CASCADE"))
                  _CREATED_TEST_SCHEMAS.discard(schema_name)
      finally:
          engine.dispose()
  def build_fastapi_test_client(app: Any) -> Any:
      """Return a FastAPI ``TestClient`` wrapping *app*.

      The httpx compatibility shim is applied first, so the client can be
      constructed even when the installed httpx rejects an ``app`` argument.
      """
      _patch_httpx_testclient_compatibility()
      from fastapi.testclient import TestClient as FastAPITestClient

      client = FastAPITestClient(app)
      return client
  def _patch_httpx_testclient_compatibility() -> None:
      """Make ``httpx.Client.__init__`` tolerate (and discard) an ``app`` keyword.

      Does nothing if the installed ``httpx.Client.__init__`` still declares an
      ``app`` parameter, or if this shim was already installed (tracked with
      the ``_agent_platform_patched`` marker attribute).

      NOTE(review): presumably needed because newer httpx releases dropped
      ``app=`` while the installed TestClient still passes it. Dropping it
      assumes the client also receives its own transport. Confirm against the
      pinned versions.
      """
      import inspect

      import httpx

      current_init = httpx.Client.__init__
      accepts_app = "app" in inspect.signature(current_init).parameters
      already_patched = getattr(current_init, "_agent_platform_patched", False)
      if accepts_app or already_patched:
          return

      def patched_init(self: httpx.Client, *args: Any, **kwargs: Any) -> None:
          # Swallow the obsolete keyword, then defer to the real initializer.
          kwargs.pop("app", None)
          current_init(self, *args, **kwargs)

      patched_init._agent_platform_patched = True
      httpx.Client.__init__ = patched_init