from datetime import datetime, timezone

from core_events import EventDeliveryStatus
from core_shared import JSONValue
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import EventRecord


class EventRecordRepository:
    """Data-access helpers for ``EventRecord`` rows on a SQLAlchemy ``Session``.

    The mutating methods (``create``, ``claim_pending`` and
    ``update_delivery_status``) commit the session before returning.
    """

    def __init__(self, db: Session) -> None:
        """Wrap *db*; this repository never opens or closes the session."""
        self.db = db

    def create(
        self,
        *,
        event_id: str,
        event_type: str,
        source_service: str,
        aggregate_type: str | None,
        aggregate_id: str | None,
        correlation_id: str | None,
        causation_id: str | None,
        payload_json: dict[str, JSONValue],
        metadata_json: dict[str, JSONValue],
        event_time: datetime,
    ) -> EventRecord:
        """Insert a new event record with status ``"pending"`` and return it.

        The new row is committed immediately and then refreshed from the
        database, so the returned instance reflects the stored row.
        """
        entity = EventRecord(
            event_id=event_id,
            event_type=event_type,
            source_service=source_service,
            aggregate_type=aggregate_type,
            aggregate_id=aggregate_id,
            correlation_id=correlation_id,
            causation_id=causation_id,
            payload_json=payload_json,
            metadata_json=metadata_json,
            event_time=event_time,
            status="pending",
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        event_type: str | None = None,
        source_service: str | None = None,
        aggregate_type: str | None = None,
        aggregate_id: str | None = None,
        correlation_id: str | None = None,
        status: EventDeliveryStatus | None = None,
        limit: int = 100,
    ) -> list[EventRecord]:
        """Return up to *limit* records, newest ``event_time`` first.

        Each filter argument that is not ``None`` adds an equality condition.
        All supplied conditions must hold together. A non-positive *limit*
        returns an empty list.
        """
        if limit <= 0:
            # Backends disagree on a negative LIMIT (SQLite: unbounded,
            # PostgreSQL: error); answer "nothing requested" consistently.
            return []

        filters = (
            (EventRecord.event_type, event_type),
            (EventRecord.source_service, source_service),
            (EventRecord.aggregate_type, aggregate_type),
            (EventRecord.aggregate_id, aggregate_id),
            (EventRecord.correlation_id, correlation_id),
            (EventRecord.status, status),
        )
        stmt = select(EventRecord)
        for column, value in filters:
            if value is not None:
                stmt = stmt.where(column == value)
        stmt = stmt.order_by(EventRecord.event_time.desc()).limit(limit)
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, event_record_id: str) -> EventRecord | None:
        """Return the record whose primary key is *event_record_id*, or ``None``."""
        # Same primary-key lookup that update_delivery_status uses.
        return self.db.get(EventRecord, event_record_id)

    def claim_pending(self, *, limit: int) -> list[EventRecord]:
        """Select up to *limit* pending records and bump their attempt counter.

        Records are taken oldest ``event_time`` first. Each selected record's
        ``publish_attempt_count`` is incremented, the session is committed,
        and every record is refreshed before being returned. Returns ``[]``
        (without committing) when nothing is pending or *limit* is not
        positive.

        NOTE(review): rows are neither locked (no FOR UPDATE) nor moved out
        of ``"pending"``, so concurrent callers can select and increment the
        same rows. Confirm that only one dispatcher runs at a time, or add a
        locking/status-transition strategy.
        """
        if limit <= 0:
            # A negative LIMIT is unbounded on SQLite and would claim every
            # pending row; a non-positive limit requests nothing.
            return []

        stmt = (
            select(EventRecord)
            .where(EventRecord.status == "pending")
            .order_by(EventRecord.event_time.asc())
            .limit(limit)
        )
        entities = list(self.db.scalars(stmt))
        if not entities:
            return entities

        for entity in entities:
            entity.publish_attempt_count += 1
        self.db.commit()
        for entity in entities:
            self.db.refresh(entity)
        return entities

    def update_delivery_status(
        self,
        *,
        event_record_id: str,
        status: EventDeliveryStatus,
        last_error_message: str | None = None,
    ) -> EventRecord | None:
        """Set the delivery status of one record and commit.

        ``last_error_message`` is always written, so omitting it clears any
        previous error. When *status* is ``"published"``, ``published_time``
        is set to the current UTC time. Returns the refreshed record, or
        ``None`` if no record has that id.
        """
        entity = self.db.get(EventRecord, event_record_id)
        if entity is None:
            return None

        entity.status = status
        entity.last_error_message = last_error_message
        if status == "published":
            # Naive UTC value, identical to the deprecated datetime.utcnow().
            # NOTE(review): kept naive to match previously stored values;
            # confirm the column is not timezone-aware.
            entity.published_time = datetime.now(timezone.utc).replace(tzinfo=None)
        self.db.commit()
        self.db.refresh(entity)
        return entity