|
|
@@ -0,0 +1,122 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import time
|
|
|
+from dataclasses import dataclass
|
|
|
+from threading import Lock
|
|
|
+from typing import TYPE_CHECKING, Protocol
|
|
|
+
|
|
|
+if TYPE_CHECKING:
|
|
|
+ from redis import Redis
|
|
|
+
|
|
|
+
|
|
|
+@dataclass(frozen=True, slots=True)
|
|
|
+class RateLimitDecision:
|
|
|
+ allowed: bool
|
|
|
+ limit: int
|
|
|
+ remaining: int
|
|
|
+ reset_epoch_seconds: int
|
|
|
+ current_count: int
|
|
|
+
|
|
|
+
|
|
|
+class RateLimiter(Protocol):
|
|
|
+ def check(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ key: str,
|
|
|
+ limit: int,
|
|
|
+ window_seconds: int,
|
|
|
+ now_epoch_seconds: int | None = None,
|
|
|
+ ) -> RateLimitDecision: ...
|
|
|
+
|
|
|
+
|
|
|
+class InMemoryFixedWindowRateLimiter:
|
|
|
+ def __init__(self) -> None:
|
|
|
+ self._lock = Lock()
|
|
|
+ self._counts: dict[tuple[str, int], int] = {}
|
|
|
+
|
|
|
+ def check(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ key: str,
|
|
|
+ limit: int,
|
|
|
+ window_seconds: int,
|
|
|
+ now_epoch_seconds: int | None = None,
|
|
|
+ ) -> RateLimitDecision:
|
|
|
+ now = now_epoch_seconds or int(time.time())
|
|
|
+ window = _window_start(now, window_seconds)
|
|
|
+ bucket_key = (key, window)
|
|
|
+ with self._lock:
|
|
|
+ self._cleanup_expired(now=now, window_seconds=window_seconds)
|
|
|
+ current_count = self._counts.get(bucket_key, 0) + 1
|
|
|
+ self._counts[bucket_key] = current_count
|
|
|
+ return _decision(
|
|
|
+ current_count=current_count,
|
|
|
+ limit=limit,
|
|
|
+ reset_epoch_seconds=window + window_seconds,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _cleanup_expired(self, *, now: int, window_seconds: int) -> None:
|
|
|
+ cutoff = _window_start(now, window_seconds) - window_seconds
|
|
|
+ expired_keys = [
|
|
|
+ bucket_key for bucket_key in self._counts if bucket_key[1] < cutoff
|
|
|
+ ]
|
|
|
+ for bucket_key in expired_keys:
|
|
|
+ self._counts.pop(bucket_key, None)
|
|
|
+
|
|
|
+
|
|
|
+class RedisFixedWindowRateLimiter:
|
|
|
+ def __init__(self, *, client: "Redis", prefix: str = "rate-limit") -> None:
|
|
|
+ self._client = client
|
|
|
+ self._prefix = prefix
|
|
|
+
|
|
|
+ def check(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ key: str,
|
|
|
+ limit: int,
|
|
|
+ window_seconds: int,
|
|
|
+ now_epoch_seconds: int | None = None,
|
|
|
+ ) -> RateLimitDecision:
|
|
|
+ now = now_epoch_seconds or int(time.time())
|
|
|
+ window = _window_start(now, window_seconds)
|
|
|
+ redis_key = f"{self._prefix}:{key}:{window}"
|
|
|
+ pipeline = self._client.pipeline()
|
|
|
+ pipeline.incr(redis_key)
|
|
|
+ pipeline.expire(redis_key, window_seconds * 2)
|
|
|
+ result = pipeline.execute()
|
|
|
+ current_count = int(result[0])
|
|
|
+ return _decision(
|
|
|
+ current_count=current_count,
|
|
|
+ limit=limit,
|
|
|
+ reset_epoch_seconds=window + window_seconds,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def build_rate_limiter(redis_client: "Redis | None") -> RateLimiter:
|
|
|
+ if redis_client is None:
|
|
|
+ return InMemoryFixedWindowRateLimiter()
|
|
|
+ try:
|
|
|
+ return RedisFixedWindowRateLimiter(client=redis_client)
|
|
|
+ except Exception:
|
|
|
+ return InMemoryFixedWindowRateLimiter()
|
|
|
+
|
|
|
+
|
|
|
+def _window_start(now_epoch_seconds: int, window_seconds: int) -> int:
|
|
|
+ safe_window_seconds = max(window_seconds, 1)
|
|
|
+ return now_epoch_seconds - (now_epoch_seconds % safe_window_seconds)
|
|
|
+
|
|
|
+
|
|
|
+def _decision(
|
|
|
+ *,
|
|
|
+ current_count: int,
|
|
|
+ limit: int,
|
|
|
+ reset_epoch_seconds: int,
|
|
|
+) -> RateLimitDecision:
|
|
|
+ remaining = max(limit - current_count, 0)
|
|
|
+ return RateLimitDecision(
|
|
|
+ allowed=current_count <= limit,
|
|
|
+ limit=limit,
|
|
|
+ remaining=remaining,
|
|
|
+ reset_epoch_seconds=reset_epoch_seconds,
|
|
|
+ current_count=current_count,
|
|
|
+ )
|