repositories.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
from datetime import datetime, timezone

from core_domain import HumanTaskStatus, HumanTaskType
from core_shared import JSONValue
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import HumanTask


  7. class HumanTaskRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. task_type: HumanTaskType,
  14. title: str,
  15. description: str | None,
  16. source_type: str | None,
  17. source_id: str | None,
  18. run_id: str | None,
  19. node_run_id: str | None,
  20. requested_by: str | None,
  21. assigned_to: str | None,
  22. request_payload_json: dict[str, JSONValue],
  23. due_time: datetime | None) -> HumanTask:
  24. entity = HumanTask(
  25. task_type=task_type,
  26. title=title,
  27. description=description,
  28. source_type=source_type,
  29. source_id=source_id,
  30. run_id=run_id,
  31. node_run_id=node_run_id,
  32. requested_by=requested_by,
  33. assigned_to=assigned_to,
  34. request_payload_json=request_payload_json,
  35. due_time=due_time)
  36. self.db.add(entity)
  37. self.db.commit()
  38. self.db.refresh(entity)
  39. return entity
  40. def list_by_scope(
  41. self,
  42. *,
  43. status: HumanTaskStatus | None = None,
  44. assigned_to: str | None = None,
  45. run_id: str | None = None,
  46. limit: int = 100) -> list[HumanTask]:
  47. stmt = select(HumanTask)
  48. if status is not None:
  49. stmt = stmt.where(HumanTask.status == status)
  50. if assigned_to is not None:
  51. stmt = stmt.where(HumanTask.assigned_to == assigned_to)
  52. if run_id is not None:
  53. stmt = stmt.where(HumanTask.run_id == run_id)
  54. stmt = stmt.order_by(HumanTask.created_time.desc()).limit(limit)
  55. return list(self.db.scalars(stmt))
  56. def get_by_id(self, *, human_task_id: str) -> HumanTask | None:
  57. stmt = (
  58. select(HumanTask)
  59. .where(HumanTask.id == human_task_id)
  60. )
  61. return self.db.scalar(stmt)
  62. def claim(
  63. self,
  64. *,
  65. human_task_id: str,
  66. claimed_by: str) -> HumanTask | None:
  67. entity = self.get_by_id(human_task_id=human_task_id)
  68. if entity is None:
  69. return None
  70. entity.status = "claimed"
  71. entity.claimed_by = claimed_by
  72. entity.claimed_time = datetime.utcnow()
  73. self.db.commit()
  74. self.db.refresh(entity)
  75. return entity
  76. def complete(
  77. self,
  78. *,
  79. human_task_id: str,
  80. status: HumanTaskStatus,
  81. response_payload_json: dict[str, JSONValue]) -> HumanTask | None:
  82. entity = self.get_by_id(human_task_id=human_task_id)
  83. if entity is None:
  84. return None
  85. entity.status = status
  86. entity.response_payload_json = response_payload_json
  87. entity.completed_time = datetime.utcnow()
  88. self.db.commit()
  89. self.db.refresh(entity)
  90. return entity