Coverage for polar/kit/pagination.py: 75%
68 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import math 1a
2from collections.abc import Sequence 1a
3from typing import Annotated, Any, NamedTuple, Self, overload 1a
5from fastapi import Depends, Query 1a
6from pydantic import BaseModel, GetCoreSchemaHandler 1a
7from pydantic._internal._repr import display_as_type 1a
8from pydantic_core import CoreSchema 1a
9from sqlalchemy import Select, func, over 1a
10from sqlalchemy.sql._typing import _ColumnsClauseArgument 1a
12from polar.config import settings 1a
13from polar.kit.db.models import RecordModel 1a
14from polar.kit.db.models.base import Model 1a
15from polar.kit.db.postgres import AsyncReadSession 1a
16from polar.kit.schemas import ClassName, Schema 1a
class PaginationParams(NamedTuple):
    """Page-based pagination request: which page to fetch and how large it is."""

    # 1-indexed page number (``paginate`` skips ``limit * (page - 1)`` rows).
    page: int
    # Maximum number of items returned for the page.
    limit: int
24@overload 1a
25async def paginate[RM: RecordModel]( 1a
26 session: AsyncReadSession,
27 statement: Select[tuple[RM]],
28 *,
29 pagination: PaginationParams,
30 count_clause: _ColumnsClauseArgument[Any] | None = None,
31) -> tuple[Sequence[RM], int]: ...
34@overload 1a
35async def paginate[M: Model]( 1a
36 session: AsyncReadSession,
37 statement: Select[tuple[M]],
38 *,
39 pagination: PaginationParams,
40 count_clause: _ColumnsClauseArgument[Any] | None = None,
41) -> tuple[Sequence[M], int]: ...
44@overload 1a
45async def paginate[T: Any]( 1a
46 session: AsyncReadSession,
47 statement: Select[T],
48 *,
49 pagination: PaginationParams,
50 count_clause: _ColumnsClauseArgument[Any] | None = None,
51) -> tuple[Sequence[T], int]: ...
54async def paginate( 1a
55 session: AsyncReadSession,
56 statement: Select[Any],
57 *,
58 pagination: PaginationParams,
59 count_clause: _ColumnsClauseArgument[Any] | None = None,
60) -> tuple[Sequence[Any], int]:
61 page, limit = pagination
62 offset = limit * (page - 1)
63 statement = statement.offset(offset).limit(limit)
65 if count_clause is not None:
66 statement = statement.add_columns(count_clause)
67 else:
68 statement = statement.add_columns(over(func.count()))
70 result = await session.execute(statement)
72 results: list[Any] = []
73 count = 0
74 for row in result.unique().all():
75 (*queried_data, c) = row._tuple()
76 count = int(c)
77 if len(queried_data) == 1:
78 results.append(queried_data[0])
79 else:
80 results.append(queried_data)
82 return results, count
async def get_pagination_params(
    page: int = Query(default=1, description="Page number, defaults to 1.", gt=0),
    limit: int = Query(
        default=10,
        gt=0,
        description=(
            "Size of a page, defaults to 10. "
            f"Maximum is {settings.API_PAGINATION_MAX_LIMIT}."
        ),
    ),
) -> PaginationParams:
    """FastAPI dependency that reads ``page``/``limit`` query parameters.

    Both must be strictly positive (enforced by ``Query(gt=0)``). A ``limit``
    above ``settings.API_PAGINATION_MAX_LIMIT`` is silently clamped down to
    that maximum rather than rejected.
    """
    capped_limit = min(settings.API_PAGINATION_MAX_LIMIT, limit)
    return PaginationParams(page=page, limit=capped_limit)


# Annotated alias so endpoints can declare `pagination: PaginationParamsQuery`
# and have FastAPI resolve it through `get_pagination_params`.
PaginationParamsQuery = Annotated[PaginationParams, Depends(get_pagination_params)]
class Pagination(Schema):
    """Pagination metadata attached to a list response."""

    # Total number of matching records across all pages.
    total_count: int
    # Last page number; ListResource.from_paginated_results sets it to
    # ceil(total_count / limit), i.e. 0 when there are no records.
    max_page: int
107class ListResource[T: Any](BaseModel): 1a
108 items: list[T] 1a
109 pagination: Pagination 1a
111 @classmethod 1a
112 def from_paginated_results( 1a
113 cls, items: Sequence[T], total_count: int, pagination_params: PaginationParams
114 ) -> Self:
115 return cls( 1b
116 items=list(items),
117 pagination=Pagination(
118 total_count=total_count,
119 max_page=math.ceil(total_count / pagination_params.limit),
120 ),
121 )
123 @classmethod 1a
124 def model_parametrized_name(cls, params: tuple[type[Any], ...]) -> str: 1a
125 """
126 Override default model name implementation to detect `ClassName` metadata.
128 It's useful to shorten the name when a long union type is used.
129 """
130 param_names = [] 1a
131 for param in params: 1a
132 if hasattr(param, "__metadata__"): 1a
133 for metadata in param.__metadata__: 1a
134 if isinstance(metadata, ClassName): 1a
135 param_names.append(metadata.name) 1a
136 else:
137 param_names.append(display_as_type(param)) 1a
139 params_component = ", ".join(param_names) 1a
140 return f"{cls.__name__}[{params_component}]" 1a
142 @classmethod 1a
143 def __get_pydantic_core_schema__( 1a
144 cls, source: type[BaseModel], handler: GetCoreSchemaHandler, /
145 ) -> CoreSchema:
146 """
147 Override the schema to set the `ref` field to the overridden class name.
148 """
149 result = handler(source) 1abc
150 result["ref"] = cls.__name__ # type: ignore 1abc
151 return result 1abc