Coverage for polar/metrics/service.py: 25%
57 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import uuid 1a
2from collections.abc import Sequence 1a
3from datetime import date, datetime 1a
4from zoneinfo import ZoneInfo 1a
6import logfire 1a
7from sqlalchemy import ColumnElement, FromClause, select, text 1a
9from polar.auth.models import AuthSubject 1a
10from polar.config import settings 1a
11from polar.kit.time_queries import TimeInterval, get_timestamp_series_cte 1a
12from polar.models import Organization, User 1a
13from polar.models.product import ProductBillingType 1a
14from polar.postgres import AsyncReadSession, AsyncSession 1a
16from .metrics import METRICS, METRICS_POST_COMPUTE, METRICS_SQL 1a
17from .queries import QUERIES 1a
18from .schemas import MetricsPeriod, MetricsResponse 1a
class MetricsService:
    """Service that builds and evaluates the time-bucketed metrics query."""

    async def get_metrics(
        self,
        session: AsyncSession | AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        start_date: date,
        end_date: date,
        timezone: ZoneInfo,
        interval: TimeInterval,
        organization_id: Sequence[uuid.UUID] | None = None,
        product_id: Sequence[uuid.UUID] | None = None,
        billing_type: Sequence[ProductBillingType] | None = None,
        customer_id: Sequence[uuid.UUID] | None = None,
        now: datetime | None = None,
    ) -> MetricsResponse:
        """Compute per-period metrics and their totals over a date range.

        Args:
            session: Database session used to run the metrics statement.
            auth_subject: Caller identity, forwarded to every query builder.
            start_date: First day of the range (inclusive, in ``timezone``).
            end_date: Last day of the range (inclusive, in ``timezone``).
            timezone: Zone used both for the DB session and for the
                range boundaries.
            interval: Bucket size of the generated timestamp series.
            organization_id: Optional organization filter.
            product_id: Optional product filter.
            billing_type: Optional billing-type filter.
            customer_id: Optional customer filter.
            now: Reference "current time" passed to every query. Defaults
                to the current time in ``timezone``.

        Returns:
            A ``MetricsResponse`` holding one ``MetricsPeriod`` per row of
            the timestamp series (ascending), per-metric totals and the
            metric definitions keyed by slug.
        """
        # SET cannot take bind parameters, so the zone name is interpolated.
        # NOTE(review): relies on `timezone.key` being a validated zone name
        # (no quotes); confirm callers always build it via ZoneInfo(key).
        await session.execute(text(f"SET LOCAL TIME ZONE '{timezone.key}'"))

        start_timestamp = datetime(
            start_date.year, start_date.month, start_date.day, 0, 0, 0, 0, timezone
        )
        end_timestamp = datetime(
            end_date.year, end_date.month, end_date.day, 23, 59, 59, 999999, timezone
        )

        # Keep the untruncated bounds: the queries filter on the exact range
        # requested, while the series below may start earlier.
        original_start_timestamp = start_timestamp
        original_end_timestamp = end_timestamp

        # Align the series start with the beginning of its interval period so
        # buckets line up with how daily metrics are grouped.
        if interval == TimeInterval.month:
            start_timestamp = start_timestamp.replace(day=1)
        elif interval == TimeInterval.year:
            start_timestamp = start_timestamp.replace(month=1, day=1)

        timestamp_series = get_timestamp_series_cte(
            start_timestamp, end_timestamp, interval
        )
        timestamp_column: ColumnElement[datetime] = timestamp_series.c.timestamp

        # Resolve "now" once so every query shares the same reference instant.
        # Previously it was re-evaluated for each query in the comprehension.
        reference_now = now if now is not None else datetime.now(tz=timezone)

        queries = [
            query(
                timestamp_series,
                interval,
                auth_subject,
                METRICS_SQL,
                reference_now,
                bounds=(original_start_timestamp, original_end_timestamp),
                organization_id=organization_id,
                product_id=product_id,
                billing_type=billing_type,
                customer_id=customer_id,
            )
            for query in QUERIES
        ]

        # Join every per-query subquery onto the series by timestamp.
        from_query: FromClause = timestamp_series
        for query in queries:
            from_query = from_query.join(
                query,
                onclause=query.c.timestamp == timestamp_column,
            )

        statement = (
            select(
                timestamp_column.label("timestamp"),
                *queries,
            )
            .select_from(from_query)
            .order_by(timestamp_column.asc())
        )

        periods: list[MetricsPeriod] = []
        with logfire.span(
            "Stream and process metrics query",
            start_date=str(start_date),
            end_date=str(end_date),
        ):
            result = await session.stream(
                statement,
                execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
            )

            row_count = 0
            with logfire.span("Fetch and process rows"):
                async for row in result:
                    row_count += 1
                    period_dict = row._asdict()

                    # Seed post-computed metrics with 0 so MetricsPeriod can be
                    # built before they have actual values.
                    for meta_metric in METRICS_POST_COMPUTE:
                        period_dict[meta_metric.slug] = 0

                    # Compute in declaration order. Each metric sees the values
                    # of the metrics computed before it; later ones are still 0.
                    for meta_metric in METRICS_POST_COMPUTE:
                        period_dict[meta_metric.slug] = (
                            meta_metric.compute_from_period(
                                MetricsPeriod(**period_dict)
                            )
                        )

                    periods.append(MetricsPeriod(**period_dict))

            logfire.info("Processed {row_count} rows", row_count=row_count)

        totals: dict[str, int | float] = {}
        with logfire.span(
            "Get cumulative metrics",
            start_date=str(start_date),
            end_date=str(end_date),
        ):
            for metric in METRICS:
                totals[metric.slug] = metric.get_cumulative(periods)

        return MetricsResponse.model_validate(
            {
                "periods": periods,
                "totals": totals,
                "metrics": {m.slug: m for m in METRICS},
            }
        )
# Module-level shared instance. MetricsService defines no __init__ or
# per-instance attributes here, so a single instance can be reused.
metrics = MetricsService()