# polar/metrics/service.py

import uuid
from collections.abc import Sequence
from datetime import date, datetime
from zoneinfo import ZoneInfo

import logfire
from sqlalchemy import ColumnElement, FromClause, select, text

from polar.auth.models import AuthSubject
from polar.config import settings
from polar.kit.time_queries import TimeInterval, get_timestamp_series_cte
from polar.models import Organization, User
from polar.models.product import ProductBillingType
from polar.postgres import AsyncReadSession, AsyncSession

from .metrics import METRICS, METRICS_POST_COMPUTE, METRICS_SQL
from .queries import QUERIES
from .schemas import MetricsPeriod, MetricsResponse


class MetricsService:
    async def get_metrics(
        self,
        session: AsyncSession | AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        start_date: date,
        end_date: date,
        timezone: ZoneInfo,
        interval: TimeInterval,
        organization_id: Sequence[uuid.UUID] | None = None,
        product_id: Sequence[uuid.UUID] | None = None,
        billing_type: Sequence[ProductBillingType] | None = None,
        customer_id: Sequence[uuid.UUID] | None = None,
        now: datetime | None = None,
    ) -> MetricsResponse:
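        """Compute metric values over a date range, bucketed by ``interval``.

        Builds a timestamp series CTE, joins one subquery per metric query
        onto it, streams the joined rows, applies post-compute metrics to
        each period, and returns per-period values plus cumulative totals.
        """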

        await session.execute(text(f"SET LOCAL TIME ZONE '{timezone.key}'"))
        start_timestamp = datetime(
            start_date.year, start_date.month, start_date.day, 0, 0, 0, 0, timezone
        )
        end_timestamp = datetime(
            end_date.year, end_date.month, end_date.day, 23, 59, 59, 999999, timezone
        )

        # Store original bounds before truncation for filtering queries
        original_start_timestamp = start_timestamp
        original_end_timestamp = end_timestamp

        # Truncate start_timestamp to the beginning of the interval period
        # This ensures the timestamp series aligns with how daily metrics are grouped
        if interval == TimeInterval.month:
            start_timestamp = start_timestamp.replace(day=1)
        elif interval == TimeInterval.year:
            start_timestamp = start_timestamp.replace(month=1, day=1)
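        # For example: with interval=month, a start_date of 2024-03-15 is
        # truncated to 2024-03-01 00:00:00, so the first series bucket is
        # the same bucket the metric queries group that data into.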

        timestamp_series = get_timestamp_series_cte(
            start_timestamp, end_timestamp, interval
        )
        timestamp_column: ColumnElement[datetime] = timestamp_series.c.timestamp
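        # get_timestamp_series_cte presumably emits one row per interval
        # bucket between the two timestamps; each metric subquery below is
        # built over this series and joined back on the bucket timestamp.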

        queries = [
            query(
                timestamp_series,
                interval,
                auth_subject,
                METRICS_SQL,
                now or datetime.now(tz=timezone),
                bounds=(original_start_timestamp, original_end_timestamp),
                organization_id=organization_id,
                product_id=product_id,
                billing_type=billing_type,
                customer_id=customer_id,
            )
            for query in QUERIES
        ]

        from_query: FromClause = timestamp_series
        for query in queries:
            from_query = from_query.join(
                query,
                onclause=query.c.timestamp == timestamp_column,
            )

        statement = (
            select(
                timestamp_column.label("timestamp"),
                *queries,
            )
            .select_from(from_query)
            .order_by(timestamp_column.asc())
        )
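        # The generated SQL is roughly of this shape (a sketch, with
        # placeholder names q0, q1, ... for the per-query subqueries):
        #
        #   SELECT ts.timestamp, q0.*, q1.*, ...
        #   FROM timestamp_series AS ts
        #   JOIN q0 ON q0.timestamp = ts.timestamp
        #   JOIN q1 ON q1.timestamp = ts.timestamp
        #   ORDER BY ts.timestamp ASC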

        periods: list[MetricsPeriod] = []
        with logfire.span(
            "Stream and process metrics query",
            start_date=str(start_date),
            end_date=str(end_date),
        ):
            result = await session.stream(
                statement,
                execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
            )
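            # yield_per makes SQLAlchemy stream the result with a
            # server-side cursor, fetching rows in batches of
            # DATABASE_STREAM_YIELD_PER instead of loading them all at once.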

            row_count = 0
            with logfire.span("Fetch and process rows"):
                async for row in result:
                    row_count += 1
                    period_dict = row._asdict()

                    # Compute meta metrics with cascading dependencies
                    # Each metric can depend on previously computed metrics
                    temp_period_dict = dict(period_dict)

                    # Initialize all computed metrics to 0 first to satisfy Pydantic schema
                    for meta_metric in METRICS_POST_COMPUTE:
                        temp_period_dict[meta_metric.slug] = 0

                    # Now compute each metric, updating the dict as we go
                    # This allows later metrics to depend on earlier computed metrics
                    for meta_metric in METRICS_POST_COMPUTE:
                        temp_period = MetricsPeriod(**temp_period_dict)
                        computed_value = meta_metric.compute_from_period(temp_period)
                        temp_period_dict[meta_metric.slug] = computed_value
                        period_dict[meta_metric.slug] = computed_value
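                    # E.g. (hypothetical slugs, not from this codebase): a
                    # "net_revenue" post-compute metric evaluated before a
                    # "margin" metric lets "margin" read the freshly
                    # computed net_revenue value from temp_period.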

                    periods.append(MetricsPeriod(**period_dict))

            logfire.info("Processed {row_count} rows", row_count=row_count)

        totals: dict[str, int | float] = {}
        with logfire.span(
            "Get cumulative metrics",
            start_date=str(start_date),
            end_date=str(end_date),
        ):
            for metric in METRICS:
                totals[metric.slug] = metric.get_cumulative(periods)

        return MetricsResponse.model_validate(
            {
                "periods": periods,
                "totals": totals,
                "metrics": {m.slug: m for m in METRICS},
            }
        )


metrics = MetricsService()
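# Usage sketch (illustrative; session and auth_subject come from the
# caller's request context and are not defined in this module):
#
#     response = await metrics.get_metrics(
#         session,
#         auth_subject,
#         start_date=date(2025, 1, 1),
#         end_date=date(2025, 3, 31),
#         timezone=ZoneInfo("UTC"),
#         interval=TimeInterval.month,
#     )
#     response.totals  # {metric_slug: cumulative value}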