Coverage for polar/meter/aggregation.py: 40%

77 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from enum import StrEnum 1ab

2from typing import Annotated, Any, Literal 1ab

3 

4from pydantic import AfterValidator, BaseModel, Discriminator, TypeAdapter 1ab

5from sqlalchemy import ( 1ab

6 ColumnExpressionArgument, 

7 Dialect, 

8 Float, 

9 TypeDecorator, 

10 false, 

11 func, 

12 true, 

13) 

14from sqlalchemy.dialects.postgresql import JSONB 1ab

15 

16 

17class AggregationFunction(StrEnum): 1ab

18 cnt = "count" # `count` is a reserved keyword, so we use `cnt` as key 1ab

19 sum = "sum" 1ab

20 max = "max" 1ab

21 min = "min" 1ab

22 avg = "avg" 1ab

23 unique = "unique" 1ab

24 

25 def get_sql_function(self, attr: Any) -> Any: 1ab

26 match self: 

27 case AggregationFunction.cnt: 

28 return func.count(attr) 

29 case AggregationFunction.sum: 

30 return func.sum(attr) 

31 case AggregationFunction.max: 

32 return func.max(attr) 

33 case AggregationFunction.min: 

34 return func.min(attr) 

35 case AggregationFunction.avg: 

36 return func.avg(attr) 

37 case AggregationFunction.unique: 

38 return func.count(func.distinct(attr)) 

39 

40 

class CountAggregation(BaseModel):
    """Aggregation that counts rows, discriminated by ``func == "count"``."""

    func: Literal[AggregationFunction.cnt] = AggregationFunction.cnt

    def get_sql_column(self, model: type[Any]) -> Any:
        """Return the SQL count expression over ``model.id``."""
        counted = model.id
        return self.func.get_sql_function(counted)

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Return the row filter for this aggregation.

        Counting places no restriction on rows, so this is always SQL ``true``.
        """
        return true()

    def is_summable(self) -> bool:
        """
        Tell whether partial results computed per price group can simply be
        added together to obtain the overall result. This holds for counts.
        """
        return True

56 

57 

58def _strip_metadata_prefix(value: str) -> str: 1ab

59 prefix = "metadata." 

60 return value[len(prefix) :] if value.startswith(prefix) else value 

61 

62 

class PropertyAggregation(BaseModel):
    """Numeric aggregation (sum, max, min or avg) over a single property.

    ``property`` is looked up first in the model's ``_filterable_fields`` and
    otherwise treated as a key of its ``user_metadata``. A leading
    ``metadata.`` is stripped from the name during validation.
    """

    func: Literal[
        AggregationFunction.sum,
        AggregationFunction.max,
        AggregationFunction.min,
        AggregationFunction.avg,
    ]
    property: Annotated[str, AfterValidator(_strip_metadata_prefix)]

    def get_sql_column(self, model: type[Any]) -> Any:
        """Return the SQL aggregate expression over the selected property.

        The aggregated value is always made a float: model fields are cast to
        ``Float`` and metadata entries are read with ``as_float()``.
        """
        if self.property not in model._filterable_fields:
            # Not a model field: read the value out of the metadata column.
            metadata_value = model.user_metadata[self.property].as_float()
            return self.func.get_sql_function(metadata_value)

        _, column = model._filterable_fields[self.property]
        return self.func.get_sql_function(func.cast(column, Float))

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Return the row filter that keeps only aggregable values.

        For a model field the clause is SQL ``true`` when the field's declared
        type is ``int`` and SQL ``false`` otherwise. For a metadata key, only
        rows whose stored JSON value is a number are kept.
        """
        if self.property not in model._filterable_fields:
            return func.jsonb_typeof(model.user_metadata[self.property]) == "number"

        allowed_type, _ = model._filterable_fields[self.property]
        if allowed_type is int:
            return true()
        return false()

    def is_summable(self) -> bool:
        """
        Tell whether partial results computed per group can simply be added
        together to obtain the overall result. That is the case for SUM only;
        MAX, MIN and AVG do not combine by addition.
        """
        return self.func == AggregationFunction.sum

94 

95 

class UniqueAggregation(BaseModel):
    """Aggregation counting the distinct values of a metadata property.

    A leading ``metadata.`` is stripped from ``property`` during validation.
    """

    func: Literal[AggregationFunction.unique] = AggregationFunction.unique
    property: Annotated[str, AfterValidator(_strip_metadata_prefix)]

    def get_sql_column(self, model: type[Any]) -> Any:
        """Return the SQL expression counting distinct metadata values."""
        metadata_value = model.user_metadata[self.property]
        return self.func.get_sql_function(metadata_value)

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Return the row filter for this aggregation: always SQL ``true``."""
        return true()

    def is_summable(self) -> bool:
        """
        Tell whether partial results computed per group can simply be added
        together to obtain the overall result. Distinct counts cannot: one
        value may show up in several groups and would be counted repeatedly.
        """
        return False

114 

115 

# Plain union of the concrete aggregation models; also used for isinstance checks.
_Aggregation = CountAggregation | PropertyAggregation | UniqueAggregation

# Same union, tagged so pydantic selects the model from the `func` field.
Aggregation = Annotated[_Aggregation, Discriminator("func")]

# Shared adapter for validating raw data into one of the aggregation models.
AggregationTypeAdapter: TypeAdapter[Aggregation] = TypeAdapter(Aggregation)

119 

120 

class AggregationType(TypeDecorator[Any]):
    """SQLAlchemy column type that stores an aggregation model as JSONB."""

    impl = JSONB
    cache_ok = True

    def process_bind_param(self, value: Any, dialect: Dialect) -> Any:
        """Convert an aggregation model to a plain dict before it is bound.

        Anything that is not one of the aggregation models (including
        ``None``) is passed through untouched.
        """
        if not isinstance(value, _Aggregation):
            return value
        return value.model_dump()

    def process_result_value(self, value: str | None, dialect: Dialect) -> Any:
        """Validate a loaded value back into an aggregation model.

        ``None`` is returned as-is; any other value goes through
        ``AggregationTypeAdapter.validate_python``.
        """
        # NOTE(review): the value is fed to validate_python, so it is presumably
        # already-decoded JSON data rather than a `str` — confirm the annotation.
        if value is None:
            return None
        return AggregationTypeAdapter.validate_python(value)