conftest.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. import sys
  4. from pathlib import Path
  5. from typing import Any
  6. REPO_ROOT = Path(__file__).resolve().parents[1]
  7. @dataclass(frozen=True)
  8. class ServiceImportConfig:
  9. service_name: str
  10. libs: tuple[str, ...]
  11. SERVICE_IMPORT_CONFIGS: dict[str, ServiceImportConfig] = {
  12. "agent-service": ServiceImportConfig(
  13. service_name="agent-service",
  14. libs=("core-domain", "core-shared", "core-db", "core-events"),
  15. ),
  16. "knowledge-service": ServiceImportConfig(
  17. service_name="knowledge-service",
  18. libs=("core-domain", "core-shared", "core-db"),
  19. ),
  20. "runtime-service": ServiceImportConfig(
  21. service_name="runtime-service",
  22. libs=("core-domain", "core-shared", "core-db", "core-events", "core-dsl"),
  23. ),
  24. "workflow-service": ServiceImportConfig(
  25. service_name="workflow-service",
  26. libs=("core-domain", "core-shared", "core-db", "core-dsl"),
  27. ),
  28. }
  29. def prepare_service_import(
  30. service_name: str,
  31. *,
  32. libs: tuple[str, ...],
  33. ) -> None:
  34. for module_name in list(sys.modules):
  35. if module_name == "app" or module_name.startswith("app."):
  36. del sys.modules[module_name]
  37. for lib_name in libs:
  38. lib_path = REPO_ROOT / "libs" / lib_name / "src"
  39. _prepend_sys_path(lib_path)
  40. _prepend_sys_path(REPO_ROOT / "services" / service_name)
  41. def prepare_known_service_import(service_name: str) -> None:
  42. config = SERVICE_IMPORT_CONFIGS[service_name]
  43. prepare_service_import(config.service_name, libs=config.libs)
  44. def build_sqlite_database_url(tmp_path: Path, filename: str) -> str:
  45. return f"sqlite:///{tmp_path / filename}"
  46. def _prepend_sys_path(path: Path) -> None:
  47. path_text = str(path)
  48. if path_text in sys.path:
  49. sys.path.remove(path_text)
  50. sys.path.insert(0, path_text)
  51. def build_fastapi_test_client(app: Any) -> Any:
  52. _patch_httpx_testclient_compatibility()
  53. from fastapi.testclient import TestClient
  54. return TestClient(app)
  55. def _patch_httpx_testclient_compatibility() -> None:
  56. import inspect
  57. import httpx
  58. if "app" in inspect.signature(httpx.Client.__init__).parameters:
  59. return
  60. if getattr(httpx.Client.__init__, "_agent_platform_patched", False):
  61. return
  62. original_init = httpx.Client.__init__
  63. def patched_init(self: httpx.Client, *args: Any, **kwargs: Any) -> None:
  64. kwargs.pop("app", None)
  65. original_init(self, *args, **kwargs)
  66. setattr(patched_init, "_agent_platform_patched", True)
  67. httpx.Client.__init__ = patched_init