repositories.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
from datetime import datetime, timezone

from core_events import EventDeliveryStatus
from core_shared import JSONValue
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import EventRecord
  7. class EventRecordRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. event_id: str,
  14. event_type: str,
  15. source_service: str,
  16. aggregate_type: str | None,
  17. aggregate_id: str | None,
  18. correlation_id: str | None,
  19. causation_id: str | None,
  20. payload_json: dict[str, JSONValue],
  21. metadata_json: dict[str, JSONValue],
  22. event_time: datetime) -> EventRecord:
  23. entity = EventRecord(
  24. event_id=event_id,
  25. event_type=event_type,
  26. source_service=source_service,
  27. aggregate_type=aggregate_type,
  28. aggregate_id=aggregate_id,
  29. correlation_id=correlation_id,
  30. causation_id=causation_id,
  31. payload_json=payload_json,
  32. metadata_json=metadata_json,
  33. event_time=event_time,
  34. status="pending")
  35. self.db.add(entity)
  36. self.db.commit()
  37. self.db.refresh(entity)
  38. return entity
  39. def list_by_scope(
  40. self,
  41. *,
  42. event_type: str | None = None,
  43. source_service: str | None = None,
  44. aggregate_type: str | None = None,
  45. aggregate_id: str | None = None,
  46. correlation_id: str | None = None,
  47. status: EventDeliveryStatus | None = None,
  48. limit: int = 100) -> list[EventRecord]:
  49. stmt = select(EventRecord)
  50. if event_type is not None:
  51. stmt = stmt.where(EventRecord.event_type == event_type)
  52. if source_service is not None:
  53. stmt = stmt.where(EventRecord.source_service == source_service)
  54. if aggregate_type is not None:
  55. stmt = stmt.where(EventRecord.aggregate_type == aggregate_type)
  56. if aggregate_id is not None:
  57. stmt = stmt.where(EventRecord.aggregate_id == aggregate_id)
  58. if correlation_id is not None:
  59. stmt = stmt.where(EventRecord.correlation_id == correlation_id)
  60. if status is not None:
  61. stmt = stmt.where(EventRecord.status == status)
  62. stmt = stmt.order_by(EventRecord.event_time.desc()).limit(limit)
  63. return list(self.db.scalars(stmt))
  64. def get_by_id(self, *, event_record_id: str) -> EventRecord | None:
  65. stmt = (
  66. select(EventRecord)
  67. .where(EventRecord.id == event_record_id)
  68. )
  69. return self.db.scalar(stmt)
  70. def claim_pending(self, *, limit: int) -> list[EventRecord]:
  71. stmt = (
  72. select(EventRecord)
  73. .where(EventRecord.status == "pending")
  74. .order_by(EventRecord.event_time.asc())
  75. .limit(limit)
  76. )
  77. entities = list(self.db.scalars(stmt))
  78. for entity in entities:
  79. entity.publish_attempt_count += 1
  80. if entities:
  81. self.db.commit()
  82. for entity in entities:
  83. self.db.refresh(entity)
  84. return entities
  85. def update_delivery_status(
  86. self,
  87. *,
  88. event_record_id: str,
  89. status: EventDeliveryStatus,
  90. last_error_message: str | None = None) -> EventRecord | None:
  91. entity = self.db.get(EventRecord, event_record_id)
  92. if entity is None:
  93. return None
  94. entity.status = status
  95. entity.last_error_message = last_error_message
  96. if status == "published":
  97. entity.published_time = datetime.utcnow()
  98. self.db.commit()
  99. self.db.refresh(entity)
  100. return entity