repositories.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from datetime import datetime
  2. from sqlalchemy import select
  3. from sqlalchemy.orm import Session
  4. from core_domain import HumanTaskStatus, HumanTaskType
  5. from core_shared import JSONValue
  6. from app.db.models import HumanTask
  7. class HumanTaskRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. task_type: HumanTaskType,
  15. title: str,
  16. description: str | None,
  17. source_type: str | None,
  18. source_id: str | None,
  19. run_id: str | None,
  20. node_run_id: str | None,
  21. requested_by: str | None,
  22. assigned_to: str | None,
  23. request_payload_json: dict[str, JSONValue],
  24. due_time: datetime | None,
  25. ) -> HumanTask:
  26. entity = HumanTask(
  27. tenant_id=tenant_id,
  28. task_type=task_type,
  29. title=title,
  30. description=description,
  31. source_type=source_type,
  32. source_id=source_id,
  33. run_id=run_id,
  34. node_run_id=node_run_id,
  35. requested_by=requested_by,
  36. assigned_to=assigned_to,
  37. request_payload_json=request_payload_json,
  38. due_time=due_time,
  39. )
  40. self.db.add(entity)
  41. self.db.commit()
  42. self.db.refresh(entity)
  43. return entity
  44. def list_by_scope(
  45. self,
  46. *,
  47. tenant_id: str,
  48. status: HumanTaskStatus | None = None,
  49. assigned_to: str | None = None,
  50. run_id: str | None = None,
  51. limit: int = 100,
  52. ) -> list[HumanTask]:
  53. stmt = select(HumanTask).where(HumanTask.tenant_id == tenant_id)
  54. if status is not None:
  55. stmt = stmt.where(HumanTask.status == status)
  56. if assigned_to is not None:
  57. stmt = stmt.where(HumanTask.assigned_to == assigned_to)
  58. if run_id is not None:
  59. stmt = stmt.where(HumanTask.run_id == run_id)
  60. stmt = stmt.order_by(HumanTask.created_time.desc()).limit(limit)
  61. return list(self.db.scalars(stmt))
  62. def get_by_id(self, *, tenant_id: str, human_task_id: str) -> HumanTask | None:
  63. stmt = (
  64. select(HumanTask)
  65. .where(HumanTask.tenant_id == tenant_id)
  66. .where(HumanTask.id == human_task_id)
  67. )
  68. return self.db.scalar(stmt)
  69. def claim(
  70. self,
  71. *,
  72. tenant_id: str,
  73. human_task_id: str,
  74. claimed_by: str,
  75. ) -> HumanTask | None:
  76. entity = self.get_by_id(tenant_id=tenant_id, human_task_id=human_task_id)
  77. if entity is None:
  78. return None
  79. entity.status = "claimed"
  80. entity.claimed_by = claimed_by
  81. entity.claimed_time = datetime.utcnow()
  82. self.db.commit()
  83. self.db.refresh(entity)
  84. return entity
  85. def complete(
  86. self,
  87. *,
  88. tenant_id: str,
  89. human_task_id: str,
  90. status: HumanTaskStatus,
  91. response_payload_json: dict[str, JSONValue],
  92. ) -> HumanTask | None:
  93. entity = self.get_by_id(tenant_id=tenant_id, human_task_id=human_task_id)
  94. if entity is None:
  95. return None
  96. entity.status = status
  97. entity.response_payload_json = response_payload_json
  98. entity.completed_time = datetime.utcnow()
  99. self.db.commit()
  100. self.db.refresh(entity)
  101. return entity