tokens.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
import base64
import hashlib
import hmac
import json
import secrets
from datetime import datetime, timedelta, timezone
  7. TOKEN_PREFIX = "apt"
  8. class TokenError(ValueError):
  9. pass
  10. def issue_access_token(
  11. *,
  12. user_id: str,
  13. secret: str,
  14. ttl_minutes: int = 480) -> tuple[str, datetime]:
  15. expires_time = datetime.utcnow() + timedelta(minutes=ttl_minutes)
  16. payload = {
  17. "user_id": user_id,
  18. "expires_time": expires_time.isoformat(timespec="seconds") + "Z",
  19. "nonce": secrets.token_urlsafe(16),
  20. }
  21. encoded_payload = _urlsafe_b64encode(
  22. json.dumps(payload, separators=(",", ":")).encode("utf-8")
  23. )
  24. signature = _sign(encoded_payload, secret)
  25. return f"{TOKEN_PREFIX}_{encoded_payload}.{signature}", expires_time
  26. def verify_access_token(token: str, *, secret: str) -> dict[str, str]:
  27. encoded_payload, signature = _split_token(token)
  28. expected_signature = _sign(encoded_payload, secret)
  29. if not hmac.compare_digest(signature, expected_signature):
  30. raise TokenError("invalid token signature")
  31. try:
  32. payload = json.loads(_urlsafe_b64decode(encoded_payload).decode("utf-8"))
  33. except (json.JSONDecodeError, UnicodeDecodeError) as exc:
  34. raise TokenError("invalid token payload") from exc
  35. user_id = payload.get("user_id")
  36. expires_time_raw = payload.get("expires_time")
  37. if not isinstance(user_id, str) or not user_id:
  38. raise TokenError("missing user_id")
  39. if not isinstance(expires_time_raw, str) or not expires_time_raw:
  40. raise TokenError("missing expires_time")
  41. expires_time = datetime.fromisoformat(expires_time_raw.removesuffix("Z"))
  42. if expires_time <= datetime.utcnow():
  43. raise TokenError("token expired")
  44. return {
  45. "user_id": user_id,
  46. "expires_time": expires_time_raw,
  47. }
  48. def _split_token(token: str) -> tuple[str, str]:
  49. if not token.startswith(f"{TOKEN_PREFIX}_"):
  50. raise TokenError("invalid token prefix")
  51. token_body = token.removeprefix(f"{TOKEN_PREFIX}_")
  52. try:
  53. encoded_payload, signature = token_body.split(".", 1)
  54. except ValueError as exc:
  55. raise TokenError("invalid token format") from exc
  56. if not encoded_payload or not signature:
  57. raise TokenError("invalid token format")
  58. return encoded_payload, signature
  59. def _sign(encoded_payload: str, secret: str) -> str:
  60. digest = hmac.new(
  61. secret.encode("utf-8"),
  62. encoded_payload.encode("ascii"),
  63. hashlib.sha256).digest()
  64. return _urlsafe_b64encode(digest)
  65. def _urlsafe_b64encode(value: bytes) -> str:
  66. return base64.urlsafe_b64encode(value).decode("ascii").rstrip("=")
  67. def _urlsafe_b64decode(value: str) -> bytes:
  68. padding = "=" * (-len(value) % 4)
  69. return base64.urlsafe_b64decode(f"{value}{padding}".encode("ascii"))