"""Issue and verify HMAC-signed, self-contained access tokens.

Token layout::

    apt_<base64url(JSON payload)>.<base64url(HMAC-SHA256 signature)>

The JSON payload carries ``user_id``, ``expires_time`` (ISO-8601, UTC, with a
trailing ``Z``) and a random ``nonce``. The signature is computed over the
base64url-encoded payload text. Base64 padding (``=``) is stripped from both
parts.
"""

import base64
import hashlib
import hmac
import json
import secrets
from datetime import datetime, timedelta, timezone

TOKEN_PREFIX = "apt"


class TokenError(ValueError):
    """Raised when a token is malformed, tampered with, or expired."""


def issue_access_token(
    *, user_id: str, secret: str, ttl_minutes: int = 480
) -> tuple[str, datetime]:
    """Create a signed access token for *user_id*.

    Args:
        user_id: Identifier stored in the token payload.
        secret: Key used to compute the HMAC-SHA256 signature.
        ttl_minutes: Lifetime of the token, in minutes, counted from now.

    Returns:
        A ``(token, expires_time)`` pair. ``expires_time`` is a naive
        ``datetime`` expressed in UTC. It keeps sub-second precision, whereas
        the expiry embedded in the token is truncated to whole seconds.
    """
    expires_time = _utcnow() + timedelta(minutes=ttl_minutes)
    payload = {
        "user_id": user_id,
        "expires_time": expires_time.isoformat(timespec="seconds") + "Z",
        # Random nonce so two tokens issued in the same second still differ.
        "nonce": secrets.token_urlsafe(16),
    }
    encoded_payload = _urlsafe_b64encode(
        json.dumps(payload, separators=(",", ":")).encode("utf-8")
    )
    signature = _sign(encoded_payload, secret)
    return f"{TOKEN_PREFIX}_{encoded_payload}.{signature}", expires_time


def verify_access_token(token: str, *, secret: str) -> dict[str, str]:
    """Validate *token* and return its claims.

    Args:
        token: Token string as produced by :func:`issue_access_token`.
        secret: Key the token is expected to have been signed with.

    Returns:
        A dict with the ``user_id`` and the raw ``expires_time`` string taken
        from the payload.

    Raises:
        TokenError: If the token is not a well-formed ``apt_...`` string, the
            signature does not match, the payload cannot be decoded into a
            JSON object, a required claim is missing or invalid, or the token
            has expired.
    """
    encoded_payload, signature = _split_token(token)

    expected_signature = _sign(encoded_payload, secret)
    # Compare as bytes: compare_digest() raises TypeError for str arguments
    # that contain non-ASCII characters. The comparison is constant-time.
    if not hmac.compare_digest(
        signature.encode("utf-8"), expected_signature.encode("utf-8")
    ):
        raise TokenError("invalid token signature")

    try:
        payload = json.loads(_urlsafe_b64decode(encoded_payload).decode("utf-8"))
    except ValueError as exc:
        # binascii.Error (bad base64), UnicodeDecodeError and
        # json.JSONDecodeError are all ValueError subclasses.
        raise TokenError("invalid token payload") from exc
    if not isinstance(payload, dict):
        raise TokenError("invalid token payload")

    user_id = payload.get("user_id")
    expires_time_raw = payload.get("expires_time")
    if not isinstance(user_id, str) or not user_id:
        raise TokenError("missing user_id")
    if not isinstance(expires_time_raw, str) or not expires_time_raw:
        raise TokenError("missing expires_time")

    try:
        expires_time = datetime.fromisoformat(expires_time_raw.removesuffix("Z"))
    except ValueError as exc:
        raise TokenError("invalid expires_time") from exc
    if expires_time.tzinfo is not None:
        # Normalise an explicit UTC offset to naive UTC so the comparison
        # below never mixes naive and aware datetimes.
        expires_time = expires_time.astimezone(timezone.utc).replace(tzinfo=None)

    if expires_time <= _utcnow():
        raise TokenError("token expired")

    return {
        "user_id": user_id,
        "expires_time": expires_time_raw,
    }


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``."""
    # Equivalent to the deprecated datetime.utcnow().
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _split_token(token: str) -> tuple[str, str]:
    """Split *token* into ``(encoded_payload, signature)``.

    Raises:
        TokenError: If the token is not a str, lacks the ``apt_`` prefix,
            contains non-ASCII characters, has no ``.`` separator, or either
            part is empty.
    """
    if not isinstance(token, str):
        raise TokenError("invalid token format")
    prefix = f"{TOKEN_PREFIX}_"
    if not token.startswith(prefix):
        raise TokenError("invalid token prefix")
    if not token.isascii():
        # base64url text is pure ASCII; rejecting anything else here keeps
        # _sign() from raising UnicodeEncodeError on attacker-supplied input.
        raise TokenError("invalid token format")

    # Split on the first "." only; any further dots stay in the signature
    # part and simply fail the signature comparison.
    encoded_payload, separator, signature = token.removeprefix(prefix).partition(".")
    if not separator or not encoded_payload or not signature:
        raise TokenError("invalid token format")
    return encoded_payload, signature


def _sign(encoded_payload: str, secret: str) -> str:
    """Return the unpadded base64url HMAC-SHA256 of *encoded_payload*."""
    digest = hmac.new(
        secret.encode("utf-8"),
        encoded_payload.encode("ascii"),
        hashlib.sha256,
    ).digest()
    return _urlsafe_b64encode(digest)


def _urlsafe_b64encode(value: bytes) -> str:
    """Encode *value* as base64url text with the ``=`` padding removed."""
    return base64.urlsafe_b64encode(value).decode("ascii").rstrip("=")


def _urlsafe_b64decode(value: str) -> bytes:
    """Decode unpadded base64url text, restoring the stripped padding."""
    padding = "=" * (-len(value) % 4)
    return base64.urlsafe_b64decode(f"{value}{padding}".encode("ascii"))