repositories.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
from datetime import datetime, timezone

from core_events import EventDeliveryStatus
from core_shared import JSONValue
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import EventRecord
  7. class EventRecordRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. event_id: str,
  15. event_type: str,
  16. source_service: str,
  17. aggregate_type: str | None,
  18. aggregate_id: str | None,
  19. correlation_id: str | None,
  20. causation_id: str | None,
  21. payload_json: dict[str, JSONValue],
  22. metadata_json: dict[str, JSONValue],
  23. event_time: datetime,
  24. ) -> EventRecord:
  25. entity = EventRecord(
  26. tenant_id=tenant_id,
  27. event_id=event_id,
  28. event_type=event_type,
  29. source_service=source_service,
  30. aggregate_type=aggregate_type,
  31. aggregate_id=aggregate_id,
  32. correlation_id=correlation_id,
  33. causation_id=causation_id,
  34. payload_json=payload_json,
  35. metadata_json=metadata_json,
  36. event_time=event_time,
  37. status="pending",
  38. )
  39. self.db.add(entity)
  40. self.db.commit()
  41. self.db.refresh(entity)
  42. return entity
  43. def list_by_scope(
  44. self,
  45. *,
  46. tenant_id: str,
  47. event_type: str | None = None,
  48. source_service: str | None = None,
  49. aggregate_type: str | None = None,
  50. aggregate_id: str | None = None,
  51. correlation_id: str | None = None,
  52. status: EventDeliveryStatus | None = None,
  53. limit: int = 100,
  54. ) -> list[EventRecord]:
  55. stmt = select(EventRecord).where(EventRecord.tenant_id == tenant_id)
  56. if event_type is not None:
  57. stmt = stmt.where(EventRecord.event_type == event_type)
  58. if source_service is not None:
  59. stmt = stmt.where(EventRecord.source_service == source_service)
  60. if aggregate_type is not None:
  61. stmt = stmt.where(EventRecord.aggregate_type == aggregate_type)
  62. if aggregate_id is not None:
  63. stmt = stmt.where(EventRecord.aggregate_id == aggregate_id)
  64. if correlation_id is not None:
  65. stmt = stmt.where(EventRecord.correlation_id == correlation_id)
  66. if status is not None:
  67. stmt = stmt.where(EventRecord.status == status)
  68. stmt = stmt.order_by(EventRecord.event_time.desc()).limit(limit)
  69. return list(self.db.scalars(stmt))
  70. def get_by_id(self, *, tenant_id: str, event_record_id: str) -> EventRecord | None:
  71. stmt = (
  72. select(EventRecord)
  73. .where(EventRecord.tenant_id == tenant_id)
  74. .where(EventRecord.id == event_record_id)
  75. )
  76. return self.db.scalar(stmt)
  77. def claim_pending(self, *, tenant_id: str, limit: int) -> list[EventRecord]:
  78. stmt = (
  79. select(EventRecord)
  80. .where(EventRecord.tenant_id == tenant_id)
  81. .where(EventRecord.status == "pending")
  82. .order_by(EventRecord.event_time.asc())
  83. .limit(limit)
  84. )
  85. entities = list(self.db.scalars(stmt))
  86. for entity in entities:
  87. entity.publish_attempt_count += 1
  88. if entities:
  89. self.db.commit()
  90. for entity in entities:
  91. self.db.refresh(entity)
  92. return entities
  93. def update_delivery_status(
  94. self,
  95. *,
  96. event_record_id: str,
  97. status: EventDeliveryStatus,
  98. last_error_message: str | None = None,
  99. ) -> EventRecord | None:
  100. entity = self.db.get(EventRecord, event_record_id)
  101. if entity is None:
  102. return None
  103. entity.status = status
  104. entity.last_error_message = last_error_message
  105. if status == "published":
  106. entity.published_time = datetime.utcnow()
  107. self.db.commit()
  108. self.db.refresh(entity)
  109. return entity