Coverage for polar/kit/pagination.py: 75%

68 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import math 1a

2from collections.abc import Sequence 1a

3from typing import Annotated, Any, NamedTuple, Self, overload 1a

4 

5from fastapi import Depends, Query 1a

6from pydantic import BaseModel, GetCoreSchemaHandler 1a

7from pydantic._internal._repr import display_as_type 1a

8from pydantic_core import CoreSchema 1a

9from sqlalchemy import Select, func, over 1a

10from sqlalchemy.sql._typing import _ColumnsClauseArgument 1a

11 

12from polar.config import settings 1a

13from polar.kit.db.models import RecordModel 1a

14from polar.kit.db.models.base import Model 1a

15from polar.kit.db.postgres import AsyncReadSession 1a

16from polar.kit.schemas import ClassName, Schema 1a

17 

18 

class PaginationParams(NamedTuple):
    """Page number and page size used to paginate a query."""

    # Page number; `paginate` treats it as 1-based (offset = limit * (page - 1)).
    page: int
    # Maximum number of rows on a page.
    limit: int

22 

23 

24@overload 1a

25async def paginate[RM: RecordModel]( 1a

26 session: AsyncReadSession, 

27 statement: Select[tuple[RM]], 

28 *, 

29 pagination: PaginationParams, 

30 count_clause: _ColumnsClauseArgument[Any] | None = None, 

31) -> tuple[Sequence[RM], int]: ... 

32 

33 

34@overload 1a

35async def paginate[M: Model]( 1a

36 session: AsyncReadSession, 

37 statement: Select[tuple[M]], 

38 *, 

39 pagination: PaginationParams, 

40 count_clause: _ColumnsClauseArgument[Any] | None = None, 

41) -> tuple[Sequence[M], int]: ... 

42 

43 

44@overload 1a

45async def paginate[T: Any]( 1a

46 session: AsyncReadSession, 

47 statement: Select[T], 

48 *, 

49 pagination: PaginationParams, 

50 count_clause: _ColumnsClauseArgument[Any] | None = None, 

51) -> tuple[Sequence[T], int]: ... 

52 

53 

54async def paginate( 1a

55 session: AsyncReadSession, 

56 statement: Select[Any], 

57 *, 

58 pagination: PaginationParams, 

59 count_clause: _ColumnsClauseArgument[Any] | None = None, 

60) -> tuple[Sequence[Any], int]: 

61 page, limit = pagination 

62 offset = limit * (page - 1) 

63 statement = statement.offset(offset).limit(limit) 

64 

65 if count_clause is not None: 

66 statement = statement.add_columns(count_clause) 

67 else: 

68 statement = statement.add_columns(over(func.count())) 

69 

70 result = await session.execute(statement) 

71 

72 results: list[Any] = [] 

73 count = 0 

74 for row in result.unique().all(): 

75 (*queried_data, c) = row._tuple() 

76 count = int(c) 

77 if len(queried_data) == 1: 

78 results.append(queried_data[0]) 

79 else: 

80 results.append(queried_data) 

81 

82 return results, count 

83 

84 

async def get_pagination_params(
    page: int = Query(1, description="Page number, defaults to 1.", gt=0),
    limit: int = Query(
        10,
        description=(
            "Size of a page, defaults to 10. "
            f"Maximum is {settings.API_PAGINATION_MAX_LIMIT}."
        ),
        gt=0,
    ),
) -> PaginationParams:
    """
    Build `PaginationParams` from the `page` and `limit` query parameters.

    A `limit` above `settings.API_PAGINATION_MAX_LIMIT` is clamped down to
    that maximum.
    """
    capped_limit = min(settings.API_PAGINATION_MAX_LIMIT, limit)
    return PaginationParams(page, capped_limit)

97 

98 

# Annotated alias: a `PaginationParams` value supplied through the
# `get_pagination_params` dependency.
PaginationParamsQuery = Annotated[PaginationParams, Depends(get_pagination_params)]

100 

101 

class Pagination(Schema):
    # Pagination metadata attached to a `ListResource`.

    # Total count passed to `ListResource.from_paginated_results`.
    total_count: int
    # `total_count` divided by the page limit, rounded up
    # (see `ListResource.from_paginated_results`).
    max_page: int

105 

106 

107class ListResource[T: Any](BaseModel): 1a

108 items: list[T] 1a

109 pagination: Pagination 1a

110 

111 @classmethod 1a

112 def from_paginated_results( 1a

113 cls, items: Sequence[T], total_count: int, pagination_params: PaginationParams 

114 ) -> Self: 

115 return cls( 1b

116 items=list(items), 

117 pagination=Pagination( 

118 total_count=total_count, 

119 max_page=math.ceil(total_count / pagination_params.limit), 

120 ), 

121 ) 

122 

123 @classmethod 1a

124 def model_parametrized_name(cls, params: tuple[type[Any], ...]) -> str: 1a

125 """ 

126 Override default model name implementation to detect `ClassName` metadata. 

127 

128 It's useful to shorten the name when a long union type is used. 

129 """ 

130 param_names = [] 1a

131 for param in params: 1a

132 if hasattr(param, "__metadata__"): 1a

133 for metadata in param.__metadata__: 1a

134 if isinstance(metadata, ClassName): 1a

135 param_names.append(metadata.name) 1a

136 else: 

137 param_names.append(display_as_type(param)) 1a

138 

139 params_component = ", ".join(param_names) 1a

140 return f"{cls.__name__}[{params_component}]" 1a

141 

142 @classmethod 1a

143 def __get_pydantic_core_schema__( 1a

144 cls, source: type[BaseModel], handler: GetCoreSchemaHandler, / 

145 ) -> CoreSchema: 

146 """ 

147 Override the schema to set the `ref` field to the overridden class name. 

148 """ 

149 result = handler(source) 1abc

150 result["ref"] = cls.__name__ # type: ignore 1abc

151 return result 1abc