Coverage for polar/meter/aggregation.py: 40%
77 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1from enum import StrEnum 1ab
2from typing import Annotated, Any, Literal 1ab
4from pydantic import AfterValidator, BaseModel, Discriminator, TypeAdapter 1ab
5from sqlalchemy import ( 1ab
6 ColumnExpressionArgument,
7 Dialect,
8 Float,
9 TypeDecorator,
10 false,
11 func,
12 true,
13)
14from sqlalchemy.dialects.postgresql import JSONB 1ab
17class AggregationFunction(StrEnum): 1ab
18 cnt = "count" # `count` is a reserved keyword, so we use `cnt` as key 1ab
19 sum = "sum" 1ab
20 max = "max" 1ab
21 min = "min" 1ab
22 avg = "avg" 1ab
23 unique = "unique" 1ab
25 def get_sql_function(self, attr: Any) -> Any: 1ab
26 match self:
27 case AggregationFunction.cnt:
28 return func.count(attr)
29 case AggregationFunction.sum:
30 return func.sum(attr)
31 case AggregationFunction.max:
32 return func.max(attr)
33 case AggregationFunction.min:
34 return func.min(attr)
35 case AggregationFunction.avg:
36 return func.avg(attr)
37 case AggregationFunction.unique:
38 return func.count(func.distinct(attr))
class CountAggregation(BaseModel):
    """Aggregation that counts rows.

    The ``func`` discriminator is fixed to ``AggregationFunction.cnt``.
    """

    func: Literal[AggregationFunction.cnt] = AggregationFunction.cnt

    def get_sql_column(self, model: type[Any]) -> Any:
        """Build ``COUNT(<model>.id)``.

        :param model: Mapped class exposing an ``id`` column.
        """
        id_column = model.id
        return self.func.get_sql_function(id_column)

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Return the row filter for this aggregation: always TRUE.

        Counting places no requirement on any column's value, so no row is
        excluded.
        """
        return true()

    def is_summable(self) -> bool:
        """
        Whether this aggregation can be computed separately across different price groups
        and then summed together. Per-group counts add up to the total, so yes.
        """
        summable = True
        return summable
58def _strip_metadata_prefix(value: str) -> str: 1ab
59 prefix = "metadata."
60 return value[len(prefix) :] if value.startswith(prefix) else value
class PropertyAggregation(BaseModel):
    """Numeric aggregation (sum / max / min / avg) over one property.

    ``property`` is either a key of ``model._filterable_fields``, whose entries
    are ``(allowed_type, column)`` pairs, or else a key inside the model's
    ``user_metadata``. A leading ``"metadata."`` is stripped on validation.
    """

    func: Literal[
        AggregationFunction.sum,
        AggregationFunction.max,
        AggregationFunction.min,
        AggregationFunction.avg,
    ]
    property: Annotated[str, AfterValidator(_strip_metadata_prefix)]

    def get_sql_column(self, model: type[Any]) -> Any:
        """Build the aggregate expression over the property, coerced to float."""
        filterable = model._filterable_fields
        if self.property not in filterable:
            # Metadata value: read the JSONB element as a float.
            metadata_value = model.user_metadata[self.property].as_float()
            return self.func.get_sql_function(metadata_value)
        _, column = filterable[self.property]
        return self.func.get_sql_function(func.cast(column, Float))

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Return the filter selecting rows whose property value is numeric."""
        filterable = model._filterable_fields
        if self.property not in filterable:
            # Only rows where the metadata value is a JSON number qualify.
            return func.jsonb_typeof(model.user_metadata[self.property]) == "number"
        allowed_type, _ = filterable[self.property]
        # NOTE(review): only `int` columns count as numeric here; a column whose
        # allowed type is `float` would match no rows — confirm that's intended.
        if allowed_type is int:
            return true()
        return false()

    def is_summable(self) -> bool:
        """
        Whether this aggregation can be computed separately across different groups
        and then summed together. Only SUM is summable; MAX, MIN, AVG are not.
        """
        summable_functions = (AggregationFunction.sum,)
        return any(self.func == candidate for candidate in summable_functions)
class UniqueAggregation(BaseModel):
    """Count of distinct values of a ``user_metadata`` key.

    A leading ``"metadata."`` on ``property`` is stripped on validation.
    """

    func: Literal[AggregationFunction.unique] = AggregationFunction.unique
    property: Annotated[str, AfterValidator(_strip_metadata_prefix)]

    def get_sql_column(self, model: type[Any]) -> Any:
        """Build ``COUNT(DISTINCT user_metadata[property])``."""
        # NOTE(review): unlike PropertyAggregation, this never consults
        # model._filterable_fields, so a property naming a real column is still
        # looked up inside user_metadata — confirm that is intended.
        metadata_value = model.user_metadata[self.property]
        return self.func.get_sql_function(metadata_value)

    def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
        """Return the row filter for this aggregation: always TRUE (no extra filter)."""
        return true()

    def is_summable(self) -> bool:
        """
        Whether this aggregation can be computed separately across different groups
        and then summed together. Never: the same distinct value may occur in
        several groups and would be double-counted.
        """
        summable = False
        return summable
# Runtime union of every concrete aggregation model. It is kept as a real
# `X | Y` union object so it can also be passed to isinstance() (see
# AggregationType.process_bind_param).
_Aggregation = CountAggregation | PropertyAggregation | UniqueAggregation
# Tagged union: pydantic picks the concrete model from its `func` field.
Aggregation = Annotated[_Aggregation, Discriminator("func")]
# Module-level adapter used to validate raw Python data into an Aggregation.
AggregationTypeAdapter: TypeAdapter[Aggregation] = TypeAdapter(Aggregation)
class AggregationType(TypeDecorator[Any]):
    """SQLAlchemy column type persisting an Aggregation model as JSONB.

    Models are dumped to JSON-compatible dicts when bound and re-validated
    into the matching Aggregation model when read back.
    """

    impl = JSONB
    cache_ok = True

    def process_bind_param(self, value: Any, dialect: Dialect) -> Any:
        """Serialize an Aggregation model for storage; pass other values through.

        :param value: An Aggregation model, or an already-serialized value/None.
        :return: A JSON-compatible dict for models, otherwise ``value`` unchanged.
        """
        if isinstance(value, _Aggregation):
            # mode="json" converts enum members (e.g. `func`) into plain str,
            # so the dict is JSON-native whatever JSON serializer the engine uses.
            return value.model_dump(mode="json")
        return value

    def process_result_value(self, value: Any, dialect: Dialect) -> Any:
        """Re-hydrate a stored JSONB payload into its Aggregation model.

        The JSONB impl has already decoded the payload, so ``value`` is a Python
        object (presumably a dict), not a JSON string. SQL NULL (``None``) is
        returned unchanged.

        :raises pydantic.ValidationError: if the payload is not a valid Aggregation.
        """
        if value is None:
            return None
        return AggregationTypeAdapter.validate_python(value)