Coverage for polar/kit/repository/base.py: 56%
115 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1from collections.abc import AsyncGenerator, Sequence 1ab
2from datetime import datetime 1ab
3from enum import StrEnum 1ab
4from typing import Any, Protocol, Self, TypeAlias 1ab
6from sqlalchemy import Select, UnaryExpression, asc, desc, func, over, select 1ab
7from sqlalchemy.orm import Mapped 1ab
8from sqlalchemy.orm.attributes import flag_modified 1ab
9from sqlalchemy.sql.base import ExecutableOption 1ab
10from sqlalchemy.sql.expression import ColumnExpressionArgument 1ab
12from polar.config import settings 1ab
13from polar.kit.db.postgres import AsyncReadSession, AsyncSession 1ab
14from polar.kit.sorting import Sorting 1ab
15from polar.kit.utils import utc_now 1ab
class ModelDeletedAtProtocol(Protocol):
    """Structural type for models that support soft deletion.

    Any mapped class exposing a nullable ``deleted_at`` attribute satisfies
    this protocol; the soft-deletion repository helpers treat ``None`` as
    "not deleted" and set a timestamp when soft-deleting.
    """

    # Soft-deletion timestamp; None while the row is live.
    deleted_at: Mapped[datetime | None]
22class ModelIDProtocol[ID_TYPE](Protocol): 1ab
23 id: Mapped[ID_TYPE] 1ab
26class ModelDeletedAtIDProtocol[ID_TYPE](Protocol): 1ab
27 id: Mapped[ID_TYPE] 1ab
28 deleted_at: Mapped[datetime | None] 1ab
# Statement options (e.g. ORM loader options) forwarded verbatim to
# ``Select.options(*options)`` by the ``get_by_id`` helpers.
Options: TypeAlias = Sequence[ExecutableOption]
34class RepositoryProtocol[M](Protocol): 1ab
35 model: type[M] 1ab
37 async def get_one(self, statement: Select[tuple[M]]) -> M: ... 37 ↛ exitline 37 didn't return from function 'get_one' because 1ab
39 async def get_one_or_none(self, statement: Select[tuple[M]]) -> M | None: ... 39 ↛ exitline 39 didn't return from function 'get_one_or_none' because 1ab
41 async def get_all(self, statement: Select[tuple[M]]) -> Sequence[M]: ... 41 ↛ exitline 41 didn't return from function 'get_all' because 1ab
43 async def paginate( 43 ↛ exitline 43 didn't return from function 'paginate' because 1ab
44 self, statement: Select[tuple[M]], *, limit: int, page: int
45 ) -> tuple[list[M], int]: ...
47 def get_base_statement(self) -> Select[tuple[M]]: ... 47 ↛ exitline 47 didn't return from function 'get_base_statement' because 1ab
49 async def create(self, object: M, *, flush: bool = False) -> M: ... 49 ↛ exitline 49 didn't return from function 'create' because 1ab
51 async def update( 51 ↛ anywhereline 51 didn't jump anywhere: it always raised an exception.1ab
52 self,
53 object: M,
54 *,
55 update_dict: dict[str, Any] | None = None,
56 flush: bool = False,
57 ) -> M: ...
60class RepositoryBase[M]: 1ab
61 model: type[M]
63 def __init__(self, session: AsyncSession | AsyncReadSession) -> None: 1ab
64 self.session = session 1c
66 async def get_one(self, statement: Select[tuple[M]]) -> M: 1ab
67 result = await self.session.execute(statement)
68 return result.unique().scalar_one()
70 async def get_one_or_none(self, statement: Select[tuple[M]]) -> M | None: 1ab
71 result = await self.session.execute(statement) 1c
72 return result.unique().scalar_one_or_none()
74 async def get_all(self, statement: Select[tuple[M]]) -> Sequence[M]: 1ab
75 result = await self.session.execute(statement)
76 return result.scalars().unique().all()
78 async def stream(self, statement: Select[tuple[M]]) -> AsyncGenerator[M, None]: 1ab
79 """
80 Stream results from the database using the given statement.
82 This is useful for processing large datasets without loading everything
83 into memory at once.
85 The caveat is that your statement shouldn't join many-to-one or
86 many-to-many relationships, as we can't apply ORM's `unique()` method
87 to the results, which may lead to duplicates.
89 Args:
90 statement: The SQLAlchemy select statement to execute.
92 Yields:
93 Instances of the model `M` as they are fetched from the database.
94 """
95 results = await self.session.stream_scalars(
96 statement,
97 execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
98 )
99 try:
100 async for result in results:
101 yield result
102 finally:
103 await results.close()
105 async def paginate( 1ab
106 self, statement: Select[tuple[M]], *, limit: int, page: int
107 ) -> tuple[list[M], int]:
108 offset = (page - 1) * limit 1c
109 paginated_statement: Select[tuple[M, int]] = ( 1c
110 statement.add_columns(over(func.count())).limit(limit).offset(offset)
111 )
112 # Streaming can't be applied here, since we need to call ORM's unique()
113 results = await self.session.execute(paginated_statement) 1c
115 items: list[M] = []
116 count = 0
117 for result in results.unique().all():
118 item, count = result._tuple()
119 items.append(item)
121 return items, count
123 def get_base_statement(self) -> Select[tuple[M]]: 1ab
124 return select(self.model) 1c
126 async def create(self, object: M, *, flush: bool = False) -> M: 1ab
127 self.session.add(object)
129 if flush:
130 await self.session.flush()
132 return object
134 async def update( 1ab
135 self,
136 object: M,
137 *,
138 update_dict: dict[str, Any] | None = None,
139 flush: bool = False,
140 ) -> M:
141 if update_dict is not None:
142 for attr, value in update_dict.items():
143 setattr(object, attr, value)
144 # Always consider that the attribute was modified if it's explictly set
145 # in the update_dict. This forces SQLAlchemy to include it in the
146 # UPDATE statement, even if the value is the same as before.
147 # Ref: https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.attributes.flag_modified
148 try:
149 flag_modified(object, attr)
150 # Don't fail if the attribute is not tracked by SQLAlchemy
151 except KeyError:
152 pass
154 self.session.add(object)
156 if flush:
157 await self.session.flush()
159 return object
161 async def count(self, statement: Select[tuple[M]]) -> int: 1ab
162 count_statement = statement.with_only_columns(func.count())
163 result = await self.session.execute(count_statement)
164 return result.scalar_one()
166 @classmethod 1ab
167 def from_session(cls, session: AsyncSession | AsyncReadSession) -> Self: 1ab
168 return cls(session) 1c
171class RepositorySoftDeletionProtocol[MODEL_DELETED_AT: ModelDeletedAtProtocol]( 1ab
172 RepositoryProtocol[MODEL_DELETED_AT], Protocol
173):
174 def get_base_statement( 174 ↛ exitline 174 didn't return from function 'get_base_statement' because 1ab
175 self, *, include_deleted: bool = False
176 ) -> Select[tuple[MODEL_DELETED_AT]]: ...
178 async def soft_delete( 178 ↛ exitline 178 didn't return from function 'soft_delete' because 1ab
179 self, object: MODEL_DELETED_AT, *, flush: bool = False
180 ) -> MODEL_DELETED_AT: ...
183class RepositorySoftDeletionMixin[MODEL_DELETED_AT: ModelDeletedAtProtocol]: 1ab
184 def get_base_statement( 1ab
185 self: RepositoryProtocol[MODEL_DELETED_AT],
186 *,
187 include_deleted: bool = False,
188 ) -> Select[tuple[MODEL_DELETED_AT]]:
189 statement = super().get_base_statement() # type: ignore[safe-super] 1c
190 if not include_deleted: 190 ↛ 192line 190 didn't jump to line 192 because the condition on line 190 was always true1c
191 statement = statement.where(self.model.deleted_at.is_(None)) 1c
192 return statement 1c
194 async def soft_delete( 1ab
195 self: RepositoryProtocol[MODEL_DELETED_AT],
196 object: MODEL_DELETED_AT,
197 *,
198 flush: bool = False,
199 ) -> MODEL_DELETED_AT:
200 return await self.update(
201 object, update_dict={"deleted_at": utc_now()}, flush=flush
202 )
205class RepositoryIDMixin[MODEL_ID: ModelIDProtocol, ID_TYPE]: # type: ignore[type-arg] 1ab
206 async def get_by_id( 1ab
207 self: RepositoryProtocol[MODEL_ID],
208 id: ID_TYPE,
209 *,
210 options: Options = (),
211 ) -> MODEL_ID | None:
212 statement = (
213 self.get_base_statement().where(self.model.id == id).options(*options)
214 )
215 return await self.get_one_or_none(statement)
218class RepositorySoftDeletionIDMixin[ 1ab
219 MODEL_DELETED_AT_ID: ModelDeletedAtIDProtocol, # type: ignore[type-arg]
220 ID_TYPE,
221]:
222 async def get_by_id( 1ab
223 self: RepositorySoftDeletionProtocol[MODEL_DELETED_AT_ID],
224 id: ID_TYPE,
225 *,
226 options: Options = (),
227 include_deleted: bool = False,
228 ) -> MODEL_DELETED_AT_ID | None:
229 statement = (
230 self.get_base_statement(include_deleted=include_deleted)
231 .where(self.model.id == id)
232 .options(*options)
233 )
234 return await self.get_one_or_none(statement)
# Column expression (optionally already wrapped, e.g. in asc/desc) that
# ``get_sorting_clause`` returns for one sorting criterion.
SortingClause: TypeAlias = ColumnExpressionArgument[Any] | UnaryExpression[Any]
240class RepositorySortingMixin[M, PE: StrEnum]: 1ab
241 sorting_enum: type[PE]
243 def apply_sorting( 1ab
244 self,
245 statement: Select[tuple[M]],
246 sorting: list[Sorting[PE]],
247 ) -> Select[tuple[M]]:
248 order_by_clauses: list[UnaryExpression[Any]] = [] 1c
249 for criterion, is_desc in sorting: 1c
250 clause_function = desc if is_desc else asc 1c
251 order_by_clauses.append(clause_function(self.get_sorting_clause(criterion))) 1c
252 return statement.order_by(*order_by_clauses) 1c
254 def get_sorting_clause(self, property: PE) -> SortingClause: 1ab
255 raise NotImplementedError()