Coverage for polar/meter/filter.py: 32%

89 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

from enum import StrEnum
from typing import Annotated, Any

from annotated_types import Ge, Le, MaxLen
from pydantic import AfterValidator, BaseModel, ConfigDict
from sqlalchemy import (
    ColumnExpressionArgument,
    Dialect,
    String,
    TypeDecorator,
    and_,
    case,
    false,
    func,
    literal,
    or_,
    true,
)
from sqlalchemy.dialects.postgresql import JSONB

19 

# PostgreSQL int4 range limits: a signed 32-bit integer spans [-2**31, 2**31 - 1].
INT_MIN_VALUE = -(2**31)
INT_MAX_VALUE = 2**31 - 1

# Maximum length accepted for string filter values
MAX_STRING_LENGTH = 1_000

26 

27 

28class FilterOperator(StrEnum): 1ab

29 eq = "eq" 1ab

30 ne = "ne" 1ab

31 gt = "gt" 1ab

32 gte = "gte" 1ab

33 lt = "lt" 1ab

34 lte = "lte" 1ab

35 like = "like" 1ab

36 not_like = "not_like" 1ab

37 

38 

39def _strip_metadata_prefix(value: str) -> str: 1ab

40 prefix = "metadata." 

41 return value[len(prefix) :] if value.startswith(prefix) else value 

42 

43 

class FilterClause(BaseModel):
    # A single "<property> <operator> <value>" condition. `property` names an
    # entry of the model's `_filterable_fields`, or otherwise a key of its
    # `user_metadata` JSON column; a leading "metadata." is stripped on input.
    property: Annotated[str, AfterValidator(_strip_metadata_prefix)]
    operator: FilterOperator
    value: (
        Annotated[str, MaxLen(MAX_STRING_LENGTH)]
        | Annotated[int, Ge(INT_MIN_VALUE), Le(INT_MAX_VALUE)]
        | bool
    )

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Build the SQL boolean expression for this clause against ``model``.

        Args:
            model: Mapped class exposing ``_filterable_fields`` (property name ->
                ``(allowed_type, column)``) and a ``user_metadata`` JSON column
                indexable by key.

        Returns:
            A boolean SQL expression. It is ``false()`` whenever the value's
            Python type cannot be compared with the target (including metadata
            keys that are missing or hold a non string/number/boolean value).
        """
        if self.property in model._filterable_fields:
            allowed_type, attr = model._filterable_fields[self.property]
            if not isinstance(self.value, allowed_type):
                return false()
            # bool subclasses int, so isinstance(True, int) passes above, but an
            # integer column cannot be compared with a boolean in PostgreSQL.
            if isinstance(self.value, bool) and allowed_type is int:
                return false()
            # LIKE / NOT LIKE operate on text: cast non-string columns first.
            if self.operator in (FilterOperator.like, FilterOperator.not_like):
                if allowed_type is not str:
                    attr = func.cast(attr, String)
                return self._get_comparison_clause(attr, self._get_str_value())
            return self._get_comparison_clause(attr, self.value)

        attr = model.user_metadata[self.property]

        # LIKE / NOT LIKE: compare the metadata value's text form.
        if self.operator in (FilterOperator.like, FilterOperator.not_like):
            return self._get_comparison_clause(attr.as_string(), self._get_str_value())

        # Otherwise dispatch on the JSON type stored under the key. Every branch
        # is built in Python up front, so each must be valid for any value.
        json_type = func.jsonb_typeof(attr)
        return case(
            # String: compare as text.
            (
                json_type == "string",
                self._get_comparison_clause(attr.as_string(), self._get_str_value()),
            ),
            # Number: compare numerically, only if the filter value is numeric.
            (
                json_type == "number",
                self._get_comparison_clause(attr.as_float(), self._get_number_value())
                if isinstance(self.value, int | float)
                else false(),
            ),
            # Boolean: compare as boolean, only if the filter value is a bool.
            (
                json_type == "boolean",
                self._get_comparison_clause(attr.as_boolean(), self.value)
                if isinstance(self.value, bool)
                else false(),
            ),
            # Missing key, JSON null, arrays and objects never match.
            else_=false(),
        )

    def _get_comparison_clause(self, attr: Any, value: str | int | bool) -> Any:
        """Apply ``self.operator`` to the SQL expression ``attr`` and ``value``.

        ``like`` / ``not_like`` do a substring match on ``%value%``.

        Raises:
            ValueError: if ``self.operator`` is not a supported operator.
        """
        # NOTE(review): `%` and `_` inside `value` are not escaped, so they keep
        # their LIKE wildcard meaning — confirm this is intended.
        operand: Any = value
        if isinstance(value, bool) and self.operator in (
            FilterOperator.gt,
            FilterOperator.gte,
            FilterOperator.lt,
            FilterOperator.lte,
        ):
            # SQLAlchemy accepts a raw Python bool only with ==/!= and raises
            # ArgumentError for ordering operators; binding it as a parameter
            # uses PostgreSQL's boolean ordering (false < true) instead.
            operand = literal(value)

        if self.operator == FilterOperator.eq:
            return attr == operand
        elif self.operator == FilterOperator.ne:
            return attr != operand
        elif self.operator == FilterOperator.gt:
            return attr > operand
        elif self.operator == FilterOperator.gte:
            return attr >= operand
        elif self.operator == FilterOperator.lt:
            return attr < operand
        elif self.operator == FilterOperator.lte:
            return attr <= operand
        elif self.operator == FilterOperator.like:
            return attr.like(f"%{value}%")
        elif self.operator == FilterOperator.not_like:
            return attr.not_like(f"%{value}%")
        raise ValueError(f"Unsupported operator: {self.operator}")

    def _get_str_value(self) -> str:
        """Return the value as text: booleans become ``"t"``/``"f"``, others ``str()``."""
        # NOTE(review): PostgreSQL renders bool::text and JSON booleans read with
        # ->> as 'true'/'false'; 't'/'f' only matches those via LIKE's
        # surrounding wildcards — confirm this is intended for eq/ne.
        if isinstance(self.value, bool):
            return "t" if self.value else "f"
        return str(self.value)

    def _get_number_value(self) -> int:
        """Return the value as an int; booleans map to 1/0.

        Raises:
            ValueError: if the value is a string.
        """
        if isinstance(self.value, str):
            raise ValueError("Cannot convert string to number")
        if isinstance(self.value, bool):
            return 1 if self.value else 0
        return self.value

127 

128 

129class FilterConjunction(StrEnum): 1ab

130 and_ = "and" 1ab

131 or_ = "or" 1ab

132 

133 

class Filter(BaseModel):
    # A tree of conditions: leaf FilterClause items and nested Filter groups,
    # all joined by one conjunction.
    conjunction: FilterConjunction
    clauses: list["FilterClause | Filter"]

    model_config = ConfigDict(
        # IMPORTANT: this ensures FastAPI doesn't generate `-Input` for output schemas
        json_schema_mode_override="serialization",
    )

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Combine every child's SQL expression using this filter's conjunction.

        Children are translated recursively against ``model``. With no
        children the result is ``true()``, i.e. the filter matches everything.
        """
        parts: list[ColumnExpressionArgument[bool]] = []
        for child in self.clauses:
            parts.append(child.get_sql_clause(model))
        if not parts:
            parts.append(true())

        combine = and_ if self.conjunction == FilterConjunction.and_ else or_
        return combine(*parts)

149 

150 

class FilterType(TypeDecorator[Any]):
    """SQLAlchemy column type persisting a :class:`Filter` in a JSONB column."""

    impl = JSONB
    cache_ok = True

    def process_bind_param(self, value: Any, dialect: Dialect) -> Any:
        """Dump a ``Filter`` to a plain dict for storage; pass other values through."""
        if isinstance(value, Filter):
            return value.model_dump()
        return value

    def process_result_value(self, value: Any, dialect: Dialect) -> Any:
        """Validate a loaded JSONB value back into a ``Filter``; ``None`` stays ``None``.

        ``value`` is the JSON document as decoded by the JSONB impl (a mapping
        for stored filters), not a raw string; ``Filter.model_validate`` needs
        that decoded form.
        """
        if value is None:
            return None
        return Filter.model_validate(value)