repositories.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
from datetime import datetime, timezone

from core_domain import RoleAssignmentStatus, RoleStatus, UserStatus
from core_shared import JSONValue
from sqlalchemy import func, select
from sqlalchemy.orm import Session

from app.db.models import ApiKey, Role, RoleAssignment, RolePermissionBinding, User
  7. class UserRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. username: str,
  14. password_hash: str,
  15. display_name: str | None,
  16. email: str | None,
  17. metadata_json: dict[str, JSONValue]) -> User:
  18. entity = User(
  19. username=username,
  20. password_hash=password_hash,
  21. display_name=display_name,
  22. email=email,
  23. metadata_json=metadata_json)
  24. self.db.add(entity)
  25. self.db.commit()
  26. self.db.refresh(entity)
  27. return entity
  28. def list_all(self) -> list[User]:
  29. stmt = select(User).order_by(User.created_time.desc())
  30. return list(self.db.scalars(stmt))
  31. def has_any(self) -> bool:
  32. stmt = select(User.id).limit(1)
  33. return self.db.scalar(stmt) is not None
  34. def list_page(self, *, offset: int, limit: int, keyword: str | None) -> tuple[list[User], int]:
  35. stmt = select(User)
  36. if keyword:
  37. like = f"%{keyword}%"
  38. stmt = stmt.where(
  39. (User.username.like(like))
  40. | (User.display_name.like(like))
  41. | (User.email.like(like))
  42. )
  43. total = len(list(self.db.scalars(stmt)))
  44. items_stmt = stmt.order_by(User.created_time.desc()).offset(offset).limit(limit)
  45. return list(self.db.scalars(items_stmt)), total
  46. def get_by_id(self, *, user_id: str) -> User | None:
  47. return self.db.get(User, user_id)
  48. def get_by_username(self, *, username: str) -> User | None:
  49. stmt = select(User).where(User.username == username)
  50. return self.db.scalar(stmt)
  51. def touch_last_login_time(self, *, user_id: str) -> None:
  52. entity = self.db.get(User, user_id)
  53. if entity is None:
  54. return
  55. entity.last_login_time = datetime.utcnow()
  56. self.db.commit()
  57. self.db.refresh(entity)
  58. def update_status(self, *, user_id: str, status: UserStatus) -> User | None:
  59. entity = self.get_by_id(user_id=user_id)
  60. if entity is None:
  61. return None
  62. entity.status = status
  63. self.db.commit()
  64. self.db.refresh(entity)
  65. return entity
  66. class RoleRepository:
  67. def __init__(self, db: Session) -> None:
  68. self.db = db
  69. def create(
  70. self,
  71. *,
  72. code: str,
  73. name: str,
  74. description: str | None,
  75. permissions_json: list[str]) -> Role:
  76. entity = Role(
  77. code=code,
  78. name=name,
  79. description=description,
  80. permissions_json=permissions_json)
  81. self.db.add(entity)
  82. self.db.commit()
  83. self.db.refresh(entity)
  84. return entity
  85. def list_all(self) -> list[Role]:
  86. stmt = select(Role).order_by(Role.created_time.desc())
  87. return list(self.db.scalars(stmt))
  88. def get_by_name(self, *, name: str) -> Role | None:
  89. stmt = select(Role).where(Role.name == name)
  90. return self.db.scalar(stmt)
  91. def list_page(self, *, offset: int, limit: int, keyword: str | None) -> tuple[list[Role], int]:
  92. stmt = select(Role)
  93. if keyword:
  94. like = f"%{keyword}%"
  95. stmt = stmt.where((Role.name.like(like)) | (Role.description.like(like)))
  96. total = len(list(self.db.scalars(stmt)))
  97. items_stmt = stmt.order_by(Role.created_time.desc()).offset(offset).limit(limit)
  98. return list(self.db.scalars(items_stmt)), total
  99. def get_by_id(self, *, role_id: str) -> Role | None:
  100. return self.db.get(Role, role_id)
  101. def update_status(self, *, role_id: str, status: RoleStatus) -> Role | None:
  102. entity = self.get_by_id(role_id=role_id)
  103. if entity is None:
  104. return None
  105. entity.status = status
  106. self.db.commit()
  107. self.db.refresh(entity)
  108. return entity
  109. class RoleAssignmentRepository:
  110. def __init__(self, db: Session) -> None:
  111. self.db = db
  112. def create(
  113. self,
  114. *,
  115. user_id: str,
  116. role_id: str,
  117. scope_type: str | None,
  118. scope_id: str | None,
  119. expires_time: datetime | None) -> RoleAssignment:
  120. entity = RoleAssignment(
  121. user_id=user_id,
  122. role_id=role_id,
  123. scope_type=scope_type,
  124. scope_id=scope_id,
  125. expires_time=expires_time)
  126. self.db.add(entity)
  127. self.db.commit()
  128. self.db.refresh(entity)
  129. return entity
  130. def list_by_user(self, *, user_id: str) -> list[RoleAssignment]:
  131. stmt = (
  132. select(RoleAssignment)
  133. .where(RoleAssignment.user_id == user_id)
  134. .order_by(RoleAssignment.created_time.desc())
  135. )
  136. return list(self.db.scalars(stmt))
  137. def get_by_id(
  138. self,
  139. *,
  140. assignment_id: str) -> RoleAssignment | None:
  141. return self.db.get(RoleAssignment, assignment_id)
  142. def update_status(
  143. self,
  144. *,
  145. assignment_id: str,
  146. status: RoleAssignmentStatus) -> RoleAssignment | None:
  147. entity = self.get_by_id(assignment_id=assignment_id)
  148. if entity is None:
  149. return None
  150. entity.status = status
  151. self.db.commit()
  152. self.db.refresh(entity)
  153. return entity
  154. class RolePermissionBindingRepository:
  155. def __init__(self, db: Session) -> None:
  156. self.db = db
  157. def create(
  158. self,
  159. *,
  160. role_id: str,
  161. permission: str,
  162. scope_type: str | None,
  163. scope_id: str | None) -> RolePermissionBinding:
  164. entity = RolePermissionBinding(
  165. role_id=role_id,
  166. permission=permission,
  167. scope_type=scope_type,
  168. scope_id=scope_id)
  169. self.db.add(entity)
  170. self.db.commit()
  171. self.db.refresh(entity)
  172. return entity
  173. def list_by_role(
  174. self,
  175. *,
  176. role_id: str,
  177. offset: int = 0,
  178. limit: int = 100) -> tuple[list[RolePermissionBinding], int]:
  179. stmt = select(RolePermissionBinding).where(RolePermissionBinding.role_id == role_id)
  180. total = len(list(self.db.scalars(stmt)))
  181. items_stmt = (
  182. stmt.order_by(RolePermissionBinding.created_time.desc())
  183. .offset(offset)
  184. .limit(limit)
  185. )
  186. return list(self.db.scalars(items_stmt)), total
  187. def list_all_by_role(self, *, role_id: str) -> list[RolePermissionBinding]:
  188. stmt = (
  189. select(RolePermissionBinding)
  190. .where(RolePermissionBinding.role_id == role_id)
  191. .order_by(RolePermissionBinding.created_time.desc())
  192. )
  193. return list(self.db.scalars(stmt))
  194. def delete(self, *, binding_id: str) -> bool:
  195. entity = self.db.get(RolePermissionBinding, binding_id)
  196. if entity is None:
  197. return False
  198. self.db.delete(entity)
  199. self.db.commit()
  200. return True
  201. class ApiKeyRepository:
  202. def __init__(self, db: Session) -> None:
  203. self.db = db
  204. def create(
  205. self,
  206. *,
  207. name: str,
  208. key_prefix: str,
  209. key_hash: str,
  210. scopes: str | None,
  211. expires_time: datetime | None) -> ApiKey:
  212. entity = ApiKey(
  213. name=name,
  214. key_prefix=key_prefix,
  215. key_hash=key_hash,
  216. scopes=scopes,
  217. expires_time=expires_time)
  218. self.db.add(entity)
  219. self.db.commit()
  220. self.db.refresh(entity)
  221. return entity
  222. def list_page(
  223. self,
  224. *,
  225. offset: int,
  226. limit: int,
  227. keyword: str | None) -> tuple[list[ApiKey], int]:
  228. stmt = select(ApiKey)
  229. if keyword:
  230. stmt = stmt.where(ApiKey.name.like(f"%{keyword}%"))
  231. total = len(list(self.db.scalars(stmt)))
  232. items_stmt = stmt.order_by(ApiKey.created_time.desc()).offset(offset).limit(limit)
  233. return list(self.db.scalars(items_stmt)), total
  234. def get_by_id(self, *, api_key_id: str) -> ApiKey | None:
  235. return self.db.get(ApiKey, api_key_id)
  236. def revoke(self, *, api_key_id: str) -> ApiKey | None:
  237. entity = self.get_by_id(api_key_id=api_key_id)
  238. if entity is None:
  239. return None
  240. entity.revoked_time = datetime.utcnow()
  241. self.db.commit()
  242. self.db.refresh(entity)
  243. return entity