conftest.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from __future__ import annotations
  2. import sys
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import Any
  # Parent of the directory that holds this file; the helpers below expect
  # ``libs/`` and ``services/`` to live directly underneath it.
  REPO_ROOT: Path = Path(__file__).resolve().parents[1]
  @dataclass(frozen=True, eq=True)
  class ServiceImportConfig:
      """Immutable recipe for putting one service and its libraries on ``sys.path``.

      Instances are unpacked by ``prepare_known_service_import`` and passed on
      to ``prepare_service_import``.
      """

      # Directory name under ``REPO_ROOT / "services"``.
      service_name: str
      # Library directory names under ``REPO_ROOT / "libs"``; each one's ``src``
      # folder is prepended to ``sys.path``.
      libs: tuple[str, ...]
  # Known services and the shared libraries each one needs on ``sys.path``.
  # The dict is built from a sequence of configs so every key is taken from the
  # config's own ``service_name`` and cannot drift out of sync with it.
  SERVICE_IMPORT_CONFIGS: dict[str, ServiceImportConfig] = {
      config.service_name: config
      for config in (
          ServiceImportConfig(
              service_name="agent-service",
              libs=("core-domain", "core-shared", "core-db", "core-events"),
          ),
          ServiceImportConfig(
              service_name="knowledge-service",
              libs=("core-domain", "core-shared", "core-db"),
          ),
          ServiceImportConfig(
              service_name="runtime-service",
              libs=("core-domain", "core-shared", "core-db", "core-events", "core-dsl"),
          ),
          ServiceImportConfig(
              service_name="workflow-service",
              libs=("core-domain", "core-shared", "core-db", "core-dsl"),
          ),
      )
  }
  def prepare_service_import(
      service_name: str,
      *,
      libs: tuple[str, ...],
  ) -> None:
      """Reset cached ``app`` imports and put a service plus its libs on ``sys.path``.

      Every module named ``app`` or ``app.<anything>`` is evicted from
      ``sys.modules``. Presumably each service directory provides its own
      top-level ``app`` package, so this lets the next ``import app`` resolve
      against *service_name* rather than a previously imported service.

      Args:
          service_name: Directory name under ``REPO_ROOT / "services"``.
          libs: Directory names under ``REPO_ROOT / "libs"``; each one's
              ``src`` folder is prepended to ``sys.path``.
      """
      # Snapshot first: deleting while iterating sys.modules directly would fail.
      stale_modules = [
          name for name in tuple(sys.modules) if name.partition(".")[0] == "app"
      ]
      for name in stale_modules:
          del sys.modules[name]

      # Each entry is inserted at the front, so later libs land ahead of earlier
      # ones, and the service directory (inserted last) outranks all of them.
      for lib in libs:
          _prepend_sys_path(REPO_ROOT.joinpath("libs", lib, "src"))
      _prepend_sys_path(REPO_ROOT.joinpath("services", service_name))
  def prepare_known_service_import(service_name: str) -> None:
      """Prepare imports for a service registered in ``SERVICE_IMPORT_CONFIGS``.

      Looks up the registered config and forwards its ``service_name`` and
      ``libs`` to ``prepare_service_import``.

      Raises:
          KeyError: if *service_name* is not registered. The message lists the
              registered service names so a typo is easy to spot.
      """
      try:
          config = SERVICE_IMPORT_CONFIGS[service_name]
      except KeyError:
          known = ", ".join(sorted(SERVICE_IMPORT_CONFIGS))
          # Same exception type as a plain dict lookup, but with actionable context.
          raise KeyError(
              f"unknown service {service_name!r}; expected one of: {known}"
          ) from None
      prepare_service_import(config.service_name, libs=config.libs)
  def build_sqlite_database_url(tmp_path: Path, filename: str) -> str:
      """Return a ``sqlite:///`` URL pointing at *filename* inside *tmp_path*.

      For an absolute POSIX directory the result has four slashes after
      ``sqlite:`` (e.g. ``sqlite:////tmp/x/app.db``).
      """
      database_file = tmp_path / filename
      return "sqlite:///" + str(database_file)
  41. def _prepend_sys_path(path: Path) -> None:
  42. path_text = str(path)
  43. if path_text in sys.path:
  44. sys.path.remove(path_text)
  45. sys.path.insert(0, path_text)
  def build_fastapi_test_client(app: Any) -> Any:
      """Return a ``fastapi.testclient.TestClient`` wrapping *app*.

      The httpx compatibility shim is installed first, so the client can be
      built even when the installed ``httpx.Client.__init__`` has no ``app``
      parameter.
      """
      _patch_httpx_testclient_compatibility()
      # Imported lazily, after the shim is in place.
      from fastapi import testclient

      return testclient.TestClient(app)
  50. def _patch_httpx_testclient_compatibility() -> None:
  51. import inspect
  52. import httpx
  53. if "app" in inspect.signature(httpx.Client.__init__).parameters:
  54. return
  55. if getattr(httpx.Client.__init__, "_agent_platform_patched", False):
  56. return
  57. original_init = httpx.Client.__init__
  58. def patched_init(self: httpx.Client, *args: Any, **kwargs: Any) -> None:
  59. kwargs.pop("app", None)
  60. original_init(self, *args, **kwargs)
  61. patched_init._agent_platform_patched = True
  62. httpx.Client.__init__ = patched_init