Coverage for opt/mealie/lib/python3.12/site-packages/mealie/repos/repository_generic.py: 86%
275 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 14:03 +0000
1from __future__ import annotations 1o
3import random 1o
4from collections.abc import Iterable 1o
5from datetime import UTC, datetime 1o
6from math import ceil 1o
7from typing import Any 1o
9from fastapi import HTTPException 1o
10from pydantic import UUID4, BaseModel 1o
11from sqlalchemy import ColumnElement, Select, case, delete, func, nulls_first, nulls_last, select 1o
12from sqlalchemy.orm import InstrumentedAttribute 1o
13from sqlalchemy.orm.session import Session 1o
14from sqlalchemy.sql import sqltypes 1o
16from mealie.core.root_logger import get_logger 1o
17from mealie.db.models._model_base import SqlAlchemyBase 1o
18from mealie.schema._mealie import MealieModel 1o
19from mealie.schema.response.pagination import ( 1o
20 OrderByNullPosition,
21 OrderDirection,
22 PaginationBase,
23 PaginationQuery,
24 RequestQuery,
25)
26from mealie.schema.response.query_filter import QueryFilterBuilder 1o
27from mealie.schema.response.query_search import SearchFilter 1o
29from ._utils import NOT_SET, NotSet 1o
32class RepositoryGeneric[Schema: MealieModel, Model: SqlAlchemyBase]: 1o
33 """A Generic BaseAccess Model method to perform common operations on the database
35 Args:
36 Schema: Represents the Pydantic Model
37 Model: Represents the SqlAlchemyModel Model
38 """
40 session: Session 1o
42 _group_id: UUID4 | None = None 1o
43 _household_id: UUID4 | None = None 1o
45 def __init__( 1o
46 self,
47 session: Session,
48 primary_key: str,
49 sql_model: type[Model],
50 schema: type[Schema],
51 ) -> None:
52 self.session = session 1owvuyxcifdlpghjnkebqrsmta
53 self.primary_key = primary_key 1owvuyxcifdlpghjnkebqrsmta
54 self.model = sql_model 1owvuyxcifdlpghjnkebqrsmta
55 self.schema = schema 1owvuyxcifdlpghjnkebqrsmta
57 self.logger = get_logger() 1owvuyxcifdlpghjnkebqrsmta
59 @property 1o
60 def group_id(self) -> UUID4 | None: 1o
61 return self._group_id 1owvuyxcifdlpghjnkebqrsmta
63 @property 1o
64 def household_id(self) -> UUID4 | None: 1o
65 return self._household_id 1owvuyxcifdlpghjnkebqrsmta
67 @property 1o
68 def column_aliases(self) -> dict[str, ColumnElement]: 1o
69 return {} 1oucifdlpghjnkebqrsmta
71 def _random_seed(self) -> str: 1o
72 return str(datetime.now(tz=UTC)) 1cifdlpghjkebqrma
74 def _log_exception(self, e: Exception) -> None: 1o
75 self.logger.error(f"Error processing query for Repo model={self.model.__name__} schema={self.schema.__name__}") 1cifdpghjneba
76 self.logger.error(e) 1cifdpghjneba
78 def _query(self, override_schema: type[MealieModel] | None = None, with_options=True): 1o
79 q = select(self.model) 1owvuyxcifdlpghjnkebqrsmta
80 if with_options: 1owvuyxcifdlpghjnkebqrsmta
81 schema = override_schema or self.schema 1wvuyxcifdlpghjnkebqrsmta
82 return q.options(*schema.loader_options()) 1wvuyxcifdlpghjnkebqrsmta
83 else:
84 return q 1oucifdlpghjnkebqrsmta
86 def _filter_builder(self, **kwargs) -> dict[str, Any]: 1o
87 dct = {} 1owvuyxcifdlpghjnkebqrsmta
88 if self.group_id: 1owvuyxcifdlpghjnkebqrsmta
89 dct["group_id"] = self.group_id 1wvuxcifdlpghjnkebqrsmta
90 if self.household_id: 1owvuyxcifdlpghjnkebqrsmta
91 dct["household_id"] = self.household_id 1wvuxcifdlpghjnkebqrsmta
93 return {**dct, **kwargs} 1owvuyxcifdlpghjnkebqrsmta
95 def get_all( 1o
96 self,
97 limit: int | None = None,
98 order_by: str | None = None,
99 order_descending: bool = True,
100 override=None,
101 ) -> list[Schema]:
102 pq = PaginationQuery( 1oa
103 per_page=limit or -1,
104 order_by=order_by,
105 order_direction=OrderDirection.desc if order_descending else OrderDirection.asc,
106 page=1,
107 )
109 results = self.page_all(pq, override=override) 1oa
111 return results.items 1oa
113 def multi_query( 1o
114 self,
115 query_by: dict[str, str | bool | int | UUID4],
116 start=0,
117 limit: int | None = None,
118 override_schema=None,
119 order_by: str | None = None,
120 ) -> list[Schema]:
121 # sourcery skip: remove-unnecessary-cast
122 eff_schema = override_schema or self.schema 1wvucifdlpghjnkebqrsmta
124 fltr = self._filter_builder(**query_by) 1wvucifdlpghjnkebqrsmta
125 q = self._query(override_schema=eff_schema).filter_by(**fltr) 1wvucifdlpghjnkebqrsmta
127 if order_by: 127 ↛ 128line 127 didn't jump to line 128 because the condition on line 127 was never true1wvucifdlpghjnkebqrsmta
128 if order_attr := getattr(self.model, str(order_by)):
129 order_attr = order_attr.desc()
130 q = q.order_by(order_attr)
132 q = q.offset(start).limit(limit) 1wvucifdlpghjnkebqrsmta
133 result = self.session.execute(q).unique().scalars().all() 1wvucifdlpghjnkebqrsmta
134 return [eff_schema.model_validate(x) for x in result] 1wvucifdlpghjnkebqrsmta
136 def _query_one(self, match_value: str | int | UUID4, match_key: str | None = None) -> Model: 1o
137 """
138 Query the sql database for one item an return the sql alchemy model
139 object. If no match key is provided the primary_key attribute will be used.
140 """
141 if match_key is None: 1vucifdlghjnkebmta
142 match_key = self.primary_key 1vucifdhebma
144 fltr = self._filter_builder(**{match_key: match_value}) 1vucifdlghjnkebmta
145 return self.session.execute(self._query().filter_by(**fltr)).unique().scalars().one() 1vucifdlghjnkebmta
147 def get_one( 1o
148 self,
149 value: str | int | UUID4,
150 key: str | None = None,
151 any_case=False,
152 override_schema=None,
153 ) -> Schema | None:
154 key = key or self.primary_key 1wvuyxcifdlpghjnkebqrsmta
155 eff_schema = override_schema or self.schema 1wvuyxcifdlpghjnkebqrsmta
157 q = self._query(override_schema=eff_schema) 1wvuyxcifdlpghjnkebqrsmta
159 if any_case: 1wvuyxcifdlpghjnkebqrsmta
160 search_attr = getattr(self.model, key) 1vucifdlpghjnkebqrsmta
161 q = q.where(func.lower(search_attr) == str(value).lower()).filter_by(**self._filter_builder()) 1vucifdlpghjnkebqrsmta
162 else:
163 q = q.filter_by(**self._filter_builder(**{key: value})) 1wvuyxcifdlpghjnkebqrsmta
165 result = self.session.execute(q).unique().scalars().one_or_none() 1wvuyxcifdlpghjnkebqrsmta
167 if not result: 1wvuyxcifdlpghjnkebqrsmta
168 return None 1wvuxcifdlpghjnkebqrsmta
170 return eff_schema.model_validate(result) 1vuyxcifdlpghjnkebqrsmta
172 def create(self, data: Schema | BaseModel | dict) -> Schema: 1o
173 try: 1owucifdlpghjnkebqrsmta
174 data = data if isinstance(data, dict) else data.model_dump() 1owucifdlpghjnkebqrsmta
175 new_document = self.model(session=self.session, **data) 1owucifdlpghjnkebqrsmta
176 self.session.add(new_document) 1owucifdlpghjnkebqrsmta
177 self.session.commit() 1owucifdlpghjnkebqrsmta
178 except Exception: 1wucifdlpghjnkebqrsmta
179 self.session.rollback() 1wucifdlpghjnkebqrsmta
180 raise 1wucifdlpghjnkebqrsmta
182 self.session.refresh(new_document) 1owucifdlpghjnkebqrsmta
184 return self.schema.model_validate(new_document) 1owucifdlpghjnkebqrsmta
186 def create_many(self, data: Iterable[Schema | dict]) -> list[Schema]: 1o
187 new_documents = [] 1cifdlpghjnkebqrsma
188 for document in data: 1cifdlpghjnkebqrsma
189 document = document if isinstance(document, dict) else document.model_dump() 1cifdlpghjnkebqrsma
190 new_document = self.model(session=self.session, **document) 1cifdlpghjnkebqrsma
191 new_documents.append(new_document) 1cifdlpghjnkebqrsma
193 self.session.add_all(new_documents) 1cifdlpghjnkebqrsma
194 self.session.commit() 1cifdlpghjnkebqrsma
196 for created_document in new_documents: 1cifdlpghjnkebqrsma
197 self.session.refresh(created_document) 1cifdlpghjnkebqrsma
199 return [self.schema.model_validate(x) for x in new_documents] 1cifdlpghjnkebqrsma
201 def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema: 1o
202 """Update a database entry.
203 Args:
204 session (Session): Database Session
205 match_value (str): Match "key"
206 new_data (str): Match "value"
208 Returns:
209 dict: Returns a dictionary representation of the database entry
210 """
211 new_data = new_data if isinstance(new_data, dict) else new_data.model_dump() 1vucifdhebma
213 entry = self._query_one(match_value=match_value) 1vucifdhebma
214 entry.update(session=self.session, **new_data) 1vucifdhebma
216 self.session.commit() 1vucifdhebma
217 return self.schema.model_validate(entry) 1vucidebma
219 def update_many(self, data: Iterable[Schema | dict]) -> list[Schema]: 1o
220 document_data_by_id: dict[str, dict] = {}
221 for document in data:
222 document_data = document if isinstance(document, dict) else document.model_dump()
223 document_data_by_id[document_data["id"]] = document_data
225 documents_to_update_query = self._query().filter(self.model.id.in_(list(document_data_by_id.keys())))
226 documents_to_update = self.session.execute(documents_to_update_query).unique().scalars().all()
228 updated_documents = []
229 for document_to_update in documents_to_update: 229 ↛ 230line 229 didn't jump to line 230 because the loop on line 229 never started
230 data = document_data_by_id[document_to_update.id] # type: ignore
231 document_to_update.update(session=self.session, **data) # type: ignore
232 updated_documents.append(document_to_update)
234 self.session.commit()
235 return [self.schema.model_validate(x) for x in updated_documents]
237 def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema: 1o
238 new_data = new_data if isinstance(new_data, dict) else new_data.model_dump() 1eba
240 entry = self._query_one(match_value=match_value) 1eba
242 entry_as_dict = self.schema.model_validate(entry).model_dump() 1eba
243 entry_as_dict.update(new_data) 1eba
245 return self.update(match_value, entry_as_dict) 1eba
247 def delete(self, value, match_key: str | None = None) -> Schema: 1o
248 match_key = match_key or self.primary_key 1vcfdlgjnkbta
250 result = self._query_one(value, match_key) 1vcfdlgjnkbta
251 result_as_model = self.schema.model_validate(result) 1cfdlgjnkbta
253 try: 1cfdlgjnkbta
254 self.session.delete(result) 1cfdlgjnkbta
255 self.session.commit() 1cfdlgjnkbta
256 except Exception as e:
257 self.session.rollback()
258 raise e
260 return result_as_model 1cfdlgjnkbta
262 def delete_many(self, values: Iterable) -> list[Schema]: 1o
263 query = self._query().filter(self.model.id.in_(values)) 1va
264 results = self.session.execute(query).unique().scalars().all() 1va
265 results_as_model = [self.schema.model_validate(result) for result in results] 1va
267 try: 1va
268 # we create a delete statement for each row
269 # we don't delete the whole query in one statement because postgres doesn't cascade correctly
270 for result in results: 270 ↛ 271line 270 didn't jump to line 271 because the loop on line 270 never started1va
271 self.session.delete(result)
273 self.session.commit() 1va
274 except Exception as e:
275 self.session.rollback()
276 raise e
278 return results_as_model 1va
280 def delete_all(self) -> None: 1o
281 delete(self.model)
282 self.session.commit()
284 def count_all(self, match_key=None, match_value=None) -> int: 1o
285 q = select(func.count(self.model.id))
286 if None not in [match_key, match_value]:
287 q = q.filter_by(**{match_key: match_value})
288 return self.session.scalar(q)
290 def _count_attribute( 1o
291 self,
292 attribute_name: str,
293 attr_match: str | None = None,
294 count=True,
295 override_schema=None,
296 ) -> int | list[Schema]: # sourcery skip: assign-if-exp
297 eff_schema = override_schema or self.schema
299 if count:
300 q = select(func.count(self.model.id)).filter(attribute_name == attr_match)
301 return self.session.scalar(q)
302 else:
303 q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
304 return [eff_schema.model_validate(x) for x in self.session.execute(q).scalars().all()]
306 def page_all(self, pagination: PaginationQuery, override=None, search: str | None = None) -> PaginationBase[Schema]: 1o
307 """
308 pagination is a method to interact with the filtered database table and return a paginated result
309 using the PaginationBase that provides several data points that are needed to manage pagination
310 on the client side. This method does utilize the _filter_build method to ensure that the results
311 are filtered by the group id when applicable.
313 NOTE: When you provide an override you'll need to manually type the result of this method
314 as the override, as the type system is not able to infer the result of this method.
315 """
316 eff_schema = override or self.schema 1oucifdlpghjnkebqrsmta
317 # Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
318 pagination_result = pagination.model_copy() 1oucifdlpghjnkebqrsmta
319 q = self._query(override_schema=eff_schema, with_options=False) 1oucifdlpghjnkebqrsmta
321 fltr = self._filter_builder() 1oucifdlpghjnkebqrsmta
322 q = q.filter_by(**fltr) 1oucifdlpghjnkebqrsmta
323 if search: 1oucifdlpghjnkebqrsmta
324 q = self.add_search_to_query(q, eff_schema, search) 1cifdlpghjnkebqrsta
326 if not pagination_result.order_by and not search: 1oucifdlpghjnkebqrsmta
327 # default ordering if not searching
328 pagination_result.order_by = "created_at" 1oucifdlpghjnkebqrsmta
330 q, count, total_pages = self.add_pagination_to_query(q, pagination_result) 1oucifdlpghjnkebqrsmta
332 # Apply options late, so they do not get used for counting
333 q = q.options(*eff_schema.loader_options()) 1oucifdlpghjnkebqrsmta
334 try: 1oucifdlpghjnkebqrsmta
335 data = self.session.execute(q).unique().scalars().all() 1oucifdlpghjnkebqrsmta
336 except Exception as e: 1cifdpghjneba
337 self._log_exception(e) 1cifdpghjneba
338 self.session.rollback() 1cifdpghjneba
339 raise e 1cifdpghjneba
340 return PaginationBase( 1oucifdlpghjnkebqrsmta
341 page=pagination_result.page,
342 per_page=pagination_result.per_page,
343 total=count,
344 total_pages=total_pages,
345 items=[eff_schema.model_validate(s) for s in data],
346 )
348 def add_pagination_to_query(self, query: Select, pagination: PaginationQuery) -> tuple[Select, int, int]: 1o
349 """
350 Adds pagination data to an existing query.
352 :returns:
353 - query - modified query with pagination data
354 - count - total number of records (without pagination)
355 - total_pages - the total number of pages in the query
356 """
358 if pagination.query_filter: 1oucifdlpghjnkebqrsmta
359 try: 1cifdlpghjnkebqrsmta
360 query_filter_builder = QueryFilterBuilder(pagination.query_filter) 1cifdlpghjnkebqrsmta
361 query = query_filter_builder.filter_query(query, model=self.model, column_aliases=self.column_aliases) 1cifdlpghjnkebqrsmta
363 except ValueError as e: 1cfdlghkebqrmta
364 self.logger.error(e) 1cfdlghkeqrmta
365 raise HTTPException(status_code=400, detail=str(e)) from e 1cfdlghkeqrmta
367 count_query = select(func.count()).select_from(query.subquery()) 1oucifdlpghjnkebqrsmta
368 count = self.session.scalar(count_query) 1oucifdlpghjnkebqrsmta
369 if not count: 1oucifdlpghjnkebqrsmta
370 count = 0 1ocifdlpghjnkebqrsmta
372 # interpret -1 as "get_all"
373 if pagination.per_page == -1: 1oucifdlpghjnkebqrsmta
374 pagination.per_page = count 1oucifdlpghjnkebqrsma
376 try: 1oucifdlpghjnkebqrsmta
377 total_pages = ceil(count / pagination.per_page) 1oucifdlpghjnkebqrsmta
378 except ZeroDivisionError: 1oiksma
379 total_pages = 0 1oiksma
381 # interpret -1 as "last page"
382 if pagination.page == -1: 382 ↛ 383line 382 didn't jump to line 383 because the condition on line 382 was never true1oucifdlpghjnkebqrsmta
383 pagination.page = total_pages
385 # failsafe for user input error
386 if pagination.page < 1: 1oucifdlpghjnkebqrsmta
387 pagination.page = 1 1cifdlpghjnkebqrsmta
389 query = self.add_order_by_to_query(query, pagination) 1oucifdlpghjnkebqrsmta
390 return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages 1oucifdlpghjnkebqrsmta
392 def add_order_attr_to_query( 1o
393 self,
394 query: Select,
395 order_attr: InstrumentedAttribute,
396 order_dir: OrderDirection,
397 order_by_null: OrderByNullPosition | None,
398 ) -> Select:
399 order_attr = self.column_aliases.get(order_attr.key, order_attr) 1oucifdlpghjnkebqrsmta
401 # queries handle uppercase and lowercase differently, which is undesirable
402 if isinstance(order_attr.type, sqltypes.String): 1oucifdlpghjnkebqrsmta
403 order_attr = func.lower(order_attr) 1cifdlpghjkebqrsma
405 if order_dir is OrderDirection.asc: 1oucifdlpghjnkebqrsmta
406 order_attr = order_attr.asc() 1cifdlpghjnkebqrsma
407 elif order_dir is OrderDirection.desc: 407 ↛ 410line 407 didn't jump to line 410 because the condition on line 407 was always true1oucifdlpghjnkebqrsmta
408 order_attr = order_attr.desc() 1oucifdlpghjnkebqrsmta
410 if order_by_null is OrderByNullPosition.first: 1oucifdlpghjnkebqrsmta
411 order_attr = nulls_first(order_attr) 1cifdlpghjnkebsma
412 elif order_by_null is OrderByNullPosition.last: 1oucifdlpghjnkebqrsmta
413 order_attr = nulls_last(order_attr) 1cifdlpghjkebqrsmta
415 return query.order_by(order_attr) 1oucifdlpghjnkebqrsmta
417 def add_order_by_to_query(self, query: Select, request_query: RequestQuery) -> Select: 1o
418 if not request_query.order_by: 1oucifdlpghjnkebqrsmta
419 return query 1cfdlpghjnkebqa
421 elif request_query.order_by == "random": 1oucifdlpghjnkebqrsmta
422 # randomize outside of database, since not all db's can set random seeds
423 # this solution is db-independent & stable to paging
424 temp_query = query.with_only_columns(self.model.id) 1cifdlpghjkebqrma
425 allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed 1cifdlpghjkebqrma
426 if not allids: 426 ↛ 429line 426 didn't jump to line 429 because the condition on line 426 was always true1cifdlpghjkebqrma
427 return query 1cifdlpghjkebqrma
429 order = list(range(len(allids)))
430 random.seed(request_query.pagination_seed)
431 random.shuffle(order)
432 random_dict = dict(zip(allids, order, strict=True))
433 case_stmt = case(random_dict, value=self.model.id)
434 return query.order_by(case_stmt)
436 else:
437 for order_by_val in request_query.order_by.split(","): 1oucifdlpghjnkebqrsmta
438 try: 1oucifdlpghjnkebqrsmta
439 order_by_val = order_by_val.strip() 1oucifdlpghjnkebqrsmta
440 if ":" in order_by_val: 1oucifdlpghjnkebqrsmta
441 order_by, order_dir_val = order_by_val.split(":") 1ea
442 order_dir = OrderDirection(order_dir_val) 1ea
443 else:
444 order_by = order_by_val 1oucifdlpghjnkebqrsmta
445 order_dir = request_query.order_direction 1oucifdlpghjnkebqrsmta
447 _, order_attr, query = QueryFilterBuilder.get_model_and_model_attr_from_attr_string( 1oucifdlpghjnkebqrsmta
448 order_by, self.model, query=query
449 )
451 query = self.add_order_attr_to_query( 1oucifdlpghjnkebqrsmta
452 query, order_attr, order_dir, request_query.order_by_null_position
453 )
455 except ValueError as e: 1cifdlpghjnkebqrsmta
456 raise HTTPException( 1cifdlpghjnkebqrsmta
457 status_code=400,
458 detail=f'Invalid order_by statement "{request_query.order_by}": "{order_by_val}" is invalid',
459 ) from e
461 return query 1oucifdlpghjnkebqrsmta
463 def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select: 1o
464 search_filter = SearchFilter(self.session, search, schema._normalize_search) 1cifdlpghjnkebqrsta
465 return search_filter.filter_query_by_search(query, schema, self.model) 1cifdlpghjnkebqrsta
468class GroupRepositoryGeneric[Schema: MealieModel, Model: SqlAlchemyBase](RepositoryGeneric[Schema, Model]): 1o
469 def __init__( 1o
470 self,
471 session: Session,
472 primary_key: str,
473 sql_model: type[Model],
474 schema: type[Schema],
475 *,
476 group_id: UUID4 | None | NotSet,
477 ) -> None:
478 super().__init__(session, primary_key, sql_model, schema) 1owvuyxcifdlpghjnkebqrsmta
479 if group_id is NOT_SET: 479 ↛ 480line 479 didn't jump to line 480 because the condition on line 479 was never true1owvuyxcifdlpghjnkebqrsmta
480 raise ValueError("group_id must be set")
481 self._group_id = group_id if group_id else None 1owvuyxcifdlpghjnkebqrsmta
484class HouseholdRepositoryGeneric[Schema: MealieModel, Model: SqlAlchemyBase](RepositoryGeneric[Schema, Model]): 1o
485 def __init__( 1o
486 self,
487 session: Session,
488 primary_key: str,
489 sql_model: type[Model],
490 schema: type[Schema],
491 *,
492 group_id: UUID4 | None | NotSet,
493 household_id: UUID4 | None | NotSet,
494 ) -> None:
495 super().__init__(session, primary_key, sql_model, schema) 1owvuxcifdlpghjnkebqrsmta
496 if group_id is NOT_SET: 496 ↛ 497line 496 didn't jump to line 497 because the condition on line 496 was never true1owvuxcifdlpghjnkebqrsmta
497 raise ValueError("group_id must be set")
498 self._group_id = group_id if group_id else None 1owvuxcifdlpghjnkebqrsmta
500 if household_id is NOT_SET: 500 ↛ 501line 500 didn't jump to line 501 because the condition on line 500 was never true1owvuxcifdlpghjnkebqrsmta
501 raise ValueError("household_id must be set")
502 self._household_id = household_id if household_id else None 1owvuxcifdlpghjnkebqrsmta