Coverage for polar/kit/repository/base.py: 56%

115 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from collections.abc import AsyncGenerator, Sequence 1ab

2from datetime import datetime 1ab

3from enum import StrEnum 1ab

4from typing import Any, Protocol, Self, TypeAlias 1ab

5 

6from sqlalchemy import Select, UnaryExpression, asc, desc, func, over, select 1ab

7from sqlalchemy.orm import Mapped 1ab

8from sqlalchemy.orm.attributes import flag_modified 1ab

9from sqlalchemy.sql.base import ExecutableOption 1ab

10from sqlalchemy.sql.expression import ColumnExpressionArgument 1ab

11 

12from polar.config import settings 1ab

13from polar.kit.db.postgres import AsyncReadSession, AsyncSession 1ab

14from polar.kit.sorting import Sorting 1ab

15from polar.kit.utils import utc_now 1ab

16 

17 

class ModelDeletedAtProtocol(Protocol):
    """Structural type for models exposing an optional `deleted_at` datetime."""

    # `None` is an accepted value; the soft-deletion helpers of this module
    # filter on `deleted_at IS NULL` and write a timestamp into it.
    deleted_at: Mapped[datetime | None]

20 

21 

22class ModelIDProtocol[ID_TYPE](Protocol): 1ab

23 id: Mapped[ID_TYPE] 1ab

24 

25 

26class ModelDeletedAtIDProtocol[ID_TYPE](Protocol): 1ab

27 id: Mapped[ID_TYPE] 1ab

28 deleted_at: Mapped[datetime | None] 1ab

29 

30 

# Sequence of SQLAlchemy executable options; the `get_by_id` helpers below
# unpack it into `Select.options()`.
Options: TypeAlias = Sequence[ExecutableOption]

32 

33 

34class RepositoryProtocol[M](Protocol): 1ab

35 model: type[M] 1ab

36 

37 async def get_one(self, statement: Select[tuple[M]]) -> M: ... 37 ↛ exitline 37 didn't return from function 'get_one' because 1ab

38 

39 async def get_one_or_none(self, statement: Select[tuple[M]]) -> M | None: ... 39 ↛ exitline 39 didn't return from function 'get_one_or_none' because 1ab

40 

41 async def get_all(self, statement: Select[tuple[M]]) -> Sequence[M]: ... 41 ↛ exitline 41 didn't return from function 'get_all' because 1ab

42 

43 async def paginate( 43 ↛ exitline 43 didn't return from function 'paginate' because 1ab

44 self, statement: Select[tuple[M]], *, limit: int, page: int 

45 ) -> tuple[list[M], int]: ... 

46 

47 def get_base_statement(self) -> Select[tuple[M]]: ... 47 ↛ exitline 47 didn't return from function 'get_base_statement' because 1ab

48 

49 async def create(self, object: M, *, flush: bool = False) -> M: ... 49 ↛ exitline 49 didn't return from function 'create' because 1ab

50 

51 async def update( 51 ↛ anywhereline 51 didn't jump anywhere: it always raised an exception.1ab

52 self, 

53 object: M, 

54 *, 

55 update_dict: dict[str, Any] | None = None, 

56 flush: bool = False, 

57 ) -> M: ... 

58 

59 

60class RepositoryBase[M]: 1ab

61 model: type[M] 

62 

63 def __init__(self, session: AsyncSession | AsyncReadSession) -> None: 1ab

64 self.session = session 1c

65 

66 async def get_one(self, statement: Select[tuple[M]]) -> M: 1ab

67 result = await self.session.execute(statement) 

68 return result.unique().scalar_one() 

69 

70 async def get_one_or_none(self, statement: Select[tuple[M]]) -> M | None: 1ab

71 result = await self.session.execute(statement) 1c

72 return result.unique().scalar_one_or_none() 

73 

74 async def get_all(self, statement: Select[tuple[M]]) -> Sequence[M]: 1ab

75 result = await self.session.execute(statement) 

76 return result.scalars().unique().all() 

77 

78 async def stream(self, statement: Select[tuple[M]]) -> AsyncGenerator[M, None]: 1ab

79 """ 

80 Stream results from the database using the given statement. 

81 

82 This is useful for processing large datasets without loading everything 

83 into memory at once. 

84 

85 The caveat is that your statement shouldn't join many-to-one or 

86 many-to-many relationships, as we can't apply ORM's `unique()` method 

87 to the results, which may lead to duplicates. 

88 

89 Args: 

90 statement: The SQLAlchemy select statement to execute. 

91 

92 Yields: 

93 Instances of the model `M` as they are fetched from the database. 

94 """ 

95 results = await self.session.stream_scalars( 

96 statement, 

97 execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER}, 

98 ) 

99 try: 

100 async for result in results: 

101 yield result 

102 finally: 

103 await results.close() 

104 

105 async def paginate( 1ab

106 self, statement: Select[tuple[M]], *, limit: int, page: int 

107 ) -> tuple[list[M], int]: 

108 offset = (page - 1) * limit 1c

109 paginated_statement: Select[tuple[M, int]] = ( 1c

110 statement.add_columns(over(func.count())).limit(limit).offset(offset) 

111 ) 

112 # Streaming can't be applied here, since we need to call ORM's unique() 

113 results = await self.session.execute(paginated_statement) 1c

114 

115 items: list[M] = [] 

116 count = 0 

117 for result in results.unique().all(): 

118 item, count = result._tuple() 

119 items.append(item) 

120 

121 return items, count 

122 

123 def get_base_statement(self) -> Select[tuple[M]]: 1ab

124 return select(self.model) 1c

125 

126 async def create(self, object: M, *, flush: bool = False) -> M: 1ab

127 self.session.add(object) 

128 

129 if flush: 

130 await self.session.flush() 

131 

132 return object 

133 

134 async def update( 1ab

135 self, 

136 object: M, 

137 *, 

138 update_dict: dict[str, Any] | None = None, 

139 flush: bool = False, 

140 ) -> M: 

141 if update_dict is not None: 

142 for attr, value in update_dict.items(): 

143 setattr(object, attr, value) 

144 # Always consider that the attribute was modified if it's explictly set 

145 # in the update_dict. This forces SQLAlchemy to include it in the 

146 # UPDATE statement, even if the value is the same as before. 

147 # Ref: https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.attributes.flag_modified 

148 try: 

149 flag_modified(object, attr) 

150 # Don't fail if the attribute is not tracked by SQLAlchemy 

151 except KeyError: 

152 pass 

153 

154 self.session.add(object) 

155 

156 if flush: 

157 await self.session.flush() 

158 

159 return object 

160 

161 async def count(self, statement: Select[tuple[M]]) -> int: 1ab

162 count_statement = statement.with_only_columns(func.count()) 

163 result = await self.session.execute(count_statement) 

164 return result.scalar_one() 

165 

166 @classmethod 1ab

167 def from_session(cls, session: AsyncSession | AsyncReadSession) -> Self: 1ab

168 return cls(session) 1c

169 

170 

171class RepositorySoftDeletionProtocol[MODEL_DELETED_AT: ModelDeletedAtProtocol]( 1ab

172 RepositoryProtocol[MODEL_DELETED_AT], Protocol 

173): 

174 def get_base_statement( 174 ↛ exitline 174 didn't return from function 'get_base_statement' because 1ab

175 self, *, include_deleted: bool = False 

176 ) -> Select[tuple[MODEL_DELETED_AT]]: ... 

177 

178 async def soft_delete( 178 ↛ exitline 178 didn't return from function 'soft_delete' because 1ab

179 self, object: MODEL_DELETED_AT, *, flush: bool = False 

180 ) -> MODEL_DELETED_AT: ... 

181 

182 

183class RepositorySoftDeletionMixin[MODEL_DELETED_AT: ModelDeletedAtProtocol]: 1ab

184 def get_base_statement( 1ab

185 self: RepositoryProtocol[MODEL_DELETED_AT], 

186 *, 

187 include_deleted: bool = False, 

188 ) -> Select[tuple[MODEL_DELETED_AT]]: 

189 statement = super().get_base_statement() # type: ignore[safe-super] 1c

190 if not include_deleted: 190 ↛ 192line 190 didn't jump to line 192 because the condition on line 190 was always true1c

191 statement = statement.where(self.model.deleted_at.is_(None)) 1c

192 return statement 1c

193 

194 async def soft_delete( 1ab

195 self: RepositoryProtocol[MODEL_DELETED_AT], 

196 object: MODEL_DELETED_AT, 

197 *, 

198 flush: bool = False, 

199 ) -> MODEL_DELETED_AT: 

200 return await self.update( 

201 object, update_dict={"deleted_at": utc_now()}, flush=flush 

202 ) 

203 

204 

205class RepositoryIDMixin[MODEL_ID: ModelIDProtocol, ID_TYPE]: # type: ignore[type-arg] 1ab

206 async def get_by_id( 1ab

207 self: RepositoryProtocol[MODEL_ID], 

208 id: ID_TYPE, 

209 *, 

210 options: Options = (), 

211 ) -> MODEL_ID | None: 

212 statement = ( 

213 self.get_base_statement().where(self.model.id == id).options(*options) 

214 ) 

215 return await self.get_one_or_none(statement) 

216 

217 

218class RepositorySoftDeletionIDMixin[ 1ab

219 MODEL_DELETED_AT_ID: ModelDeletedAtIDProtocol, # type: ignore[type-arg] 

220 ID_TYPE, 

221]: 

222 async def get_by_id( 1ab

223 self: RepositorySoftDeletionProtocol[MODEL_DELETED_AT_ID], 

224 id: ID_TYPE, 

225 *, 

226 options: Options = (), 

227 include_deleted: bool = False, 

228 ) -> MODEL_DELETED_AT_ID | None: 

229 statement = ( 

230 self.get_base_statement(include_deleted=include_deleted) 

231 .where(self.model.id == id) 

232 .options(*options) 

233 ) 

234 return await self.get_one_or_none(statement) 

235 

236 

# Expression a sorting criterion resolves to; `apply_sorting()` wraps it
# with `asc()` or `desc()`.
SortingClause: TypeAlias = ColumnExpressionArgument[Any] | UnaryExpression[Any]

238 

239 

240class RepositorySortingMixin[M, PE: StrEnum]: 1ab

241 sorting_enum: type[PE] 

242 

243 def apply_sorting( 1ab

244 self, 

245 statement: Select[tuple[M]], 

246 sorting: list[Sorting[PE]], 

247 ) -> Select[tuple[M]]: 

248 order_by_clauses: list[UnaryExpression[Any]] = [] 1c

249 for criterion, is_desc in sorting: 1c

250 clause_function = desc if is_desc else asc 1c

251 order_by_clauses.append(clause_function(self.get_sorting_clause(criterion))) 1c

252 return statement.order_by(*order_by_clauses) 1c

253 

254 def get_sorting_clause(self, property: PE) -> SortingClause: 1ab

255 raise NotImplementedError()