Coverage for polar/event/repository.py: 13%
192 statements
coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

from collections.abc import Sequence
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import (
    ColumnElement,
    ColumnExpressionArgument,
    Numeric,
    Select,
    UnaryExpression,
    and_,
    asc,
    cast,
    desc,
    func,
    literal_column,
    or_,
    over,
    select,
    text,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import aliased, joinedload

from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user
from polar.kit.repository import RepositoryBase, RepositoryIDMixin
from polar.kit.repository.base import Options
from polar.kit.utils import generate_uuid
from polar.models import (
    BillingEntry,
    Customer,
    Event,
    EventType,
    Meter,
    UserOrganization,
)
from polar.models.event import EventClosure, EventSource
from polar.models.product_price import ProductPriceMeteredUnit

from .system import SystemEvent


class EventRepository(RepositoryBase[Event], RepositoryIDMixin[Event, UUID]):
    model = Event

    async def get_all_by_name(self, name: str) -> Sequence[Event]:
        statement = self.get_base_statement().where(Event.name == name)
        return await self.get_all(statement)

    async def get_all_by_organization(self, organization_id: UUID) -> Sequence[Event]:
        statement = self.get_base_statement().where(
            Event.organization_id == organization_id
        )
        return await self.get_all(statement)

    async def insert_batch(
        self, events: Sequence[dict[str, Any]]
    ) -> tuple[Sequence[UUID], int]:
        if not events:
            return [], 0

        events_needing_parent_lookup = []

        # Set root_id for root events before insertion
        for event in events:
            if event.get("root_id") is not None:
                continue
            elif event.get("parent_id") is None:
                if event.get("id") is None:
                    event["id"] = generate_uuid()
                event["root_id"] = event["id"]
            else:
                # Child event without root_id: it needs to be looked up from the
                # parent. This is a fail-safe in case root_id was not set before
                # calling insert_batch.
                events_needing_parent_lookup.append(event)

        # Look up root_id from parents for events that need it
        if events_needing_parent_lookup:
            parent_ids = {event["parent_id"] for event in events_needing_parent_lookup}
            result = await self.session.execute(
                select(Event.id, Event.root_id).where(Event.id.in_(parent_ids))
            )
            parent_root_map = {
                parent_id: root_id or parent_id for parent_id, root_id in result
            }

            for event in events_needing_parent_lookup:
                parent_id = event["parent_id"]
                event["root_id"] = parent_root_map.get(parent_id, parent_id)

        statement = (
            insert(Event)
            .on_conflict_do_nothing(index_elements=["external_id"])
            .returning(Event.id)
        )
        result = await self.session.execute(statement, events)
        inserted_ids = [row[0] for row in result.all()]

        duplicates_count = len(events) - len(inserted_ids)

        return inserted_ids, duplicates_count
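
    # Usage sketch for insert_batch (hypothetical event payloads; the exact
    # columns depend on the Event model):
    #
    #     root_id = generate_uuid()
    #     inserted_ids, duplicates = await repository.insert_batch(
    #         [
    #             {"id": root_id, "name": "run.started", "source": EventSource.user},
    #             {"name": "run.step", "parent_id": root_id, "root_id": root_id},
    #         ]
    #     )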

    async def get_latest_meter_reset(
        self, customer: Customer, meter_id: UUID
    ) -> Event | None:
        statement = (
            self.get_base_statement()
            .where(
                Event.customer == customer,
                Event.source == EventSource.system,
                Event.name == SystemEvent.meter_reset,
                Event.user_metadata["meter_id"].as_string() == str(meter_id),
            )
            .order_by(Event.timestamp.desc())
            .limit(1)
        )
        return await self.get_one_or_none(statement)

    def get_event_names_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[str, EventSource, int, datetime, datetime]]:
        return (
            self.get_readable_statement(auth_subject)
            .with_only_columns(
                Event.name,
                Event.source,
                func.count(Event.id).label("occurrences"),
                func.min(Event.timestamp).label("first_seen"),
                func.max(Event.timestamp).label("last_seen"),
            )
            .group_by(Event.name, Event.source)
        )
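
    # Sketch: one row per (name, source) with occurrence counts and first/last
    # seen timestamps (hypothetical caller; `session` is the repository's
    # AsyncSession):
    #
    #     rows = await session.execute(
    #         repository.get_event_names_statement(auth_subject)
    #     )
    #     for name, source, occurrences, first_seen, last_seen in rows:
    #         ...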

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[Event]]:
        statement = self.get_base_statement()

        if is_user(auth_subject):
            user = auth_subject.subject
            statement = statement.where(
                Event.organization_id.in_(
                    select(UserOrganization.organization_id).where(
                        UserOrganization.user_id == user.id,
                        UserOrganization.deleted_at.is_(None),
                    )
                )
            )

        elif is_organization(auth_subject):
            statement = statement.where(
                Event.organization_id == auth_subject.subject.id
            )

        return statement
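
    # Sketch: scoping a query to what the authenticated subject may read, then
    # filtering further (hypothetical filter):
    #
    #     statement = repository.get_readable_statement(auth_subject).where(
    #         Event.name == "run.started"
    #     )
    #     events = await repository.get_all(statement)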

    def get_customer_id_filter_clause(
        self, customer_id: Sequence[UUID]
    ) -> ColumnElement[bool]:
        return or_(
            Event.customer_id.in_(customer_id),
            Event.external_customer_id.in_(
                select(Customer.external_id).where(Customer.id.in_(customer_id))
            ),
        )

    def get_external_customer_id_filter_clause(
        self, external_customer_id: Sequence[str]
    ) -> ColumnElement[bool]:
        return or_(
            Event.external_customer_id.in_(external_customer_id),
            Event.customer_id.in_(
                select(Customer.id).where(
                    Customer.external_id.in_(external_customer_id)
                )
            ),
        )
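
    # Note: the two clauses above are symmetric. Filtering by internal customer
    # IDs also matches events ingested with only an external_customer_id, and
    # vice versa, e.g. (hypothetical IDs):
    #
    #     statement = repository.get_readable_statement(auth_subject).where(
    #         repository.get_customer_id_filter_clause([customer.id])
    #     )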

    def get_meter_clause(self, meter: Meter) -> ColumnExpressionArgument[bool]:
        return and_(
            meter.filter.get_sql_clause(Event),
            # Additional clause to make sure we only operate on rows with the
            # right type for the aggregation
            meter.aggregation.get_sql_clause(Event),
        )

    def get_meter_system_clause(self, meter: Meter) -> ColumnExpressionArgument[bool]:
        return and_(
            Event.source == EventSource.system,
            Event.name.in_((SystemEvent.meter_credited, SystemEvent.meter_reset)),
            Event.user_metadata["meter_id"].as_string() == str(meter.id),
        )

    def get_meter_statement(self, meter: Meter) -> Select[tuple[Event]]:
        return self.get_base_statement().where(
            Event.organization_id == meter.organization_id,
            self.get_meter_clause(meter),
        )
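
    # Sketch: events matched by a meter, newest first (hypothetical ordering):
    #
    #     statement = repository.get_meter_statement(meter).order_by(
    #         Event.timestamp.desc()
    #     )
    #     events = await repository.get_all(statement)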

    def get_by_pending_entries_statement(
        self, subscription: UUID, price: UUID
    ) -> Select[tuple[Event]]:
        return (
            self.get_base_statement()
            .join(BillingEntry, Event.id == BillingEntry.event_id)
            .where(
                BillingEntry.subscription_id == subscription,
                BillingEntry.order_item_id.is_(None),
                BillingEntry.product_price_id == price,
            )
            .order_by(Event.ingested_at.asc())
        )

    def get_by_pending_entries_for_meter_statement(
        self, subscription: UUID, meter: UUID
    ) -> Select[tuple[Event]]:
        """
        Get events for pending billing entries grouped by meter.

        Used for non-summable aggregations, where we need to compute across all
        events in the period, regardless of which price was active when each
        event occurred.
        """
        return (
            self.get_base_statement()
            .join(BillingEntry, Event.id == BillingEntry.event_id)
            .join(
                ProductPriceMeteredUnit,
                BillingEntry.product_price_id == ProductPriceMeteredUnit.id,
            )
            .where(
                BillingEntry.subscription_id == subscription,
                BillingEntry.order_item_id.is_(None),
                ProductPriceMeteredUnit.meter_id == meter,
            )
            .order_by(Event.ingested_at.asc())
        )
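
    # The two statements above differ in scope: the per-price variant fetches
    # entries attached to one specific price, while the per-meter variant
    # gathers entries across every price attached to the meter, e.g.
    # (hypothetical IDs):
    #
    #     statement = repository.get_by_pending_entries_for_meter_statement(
    #         subscription.id, meter.id
    #     )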

    def get_eager_options(self) -> Options:
        return (joinedload(Event.customer), joinedload(Event.event_types))

    async def list_with_closure_table(
        self,
        statement: Select[tuple[Event]],
        limit: int,
        page: int,
        aggregate_fields: Sequence[str] = (),
    ) -> tuple[Sequence[Event], int]:
        """
        List events using the closure table to get a correct children_count.

        Optionally aggregates fields from the descendants' metadata.
        """
        descendant_event = aliased(Event, name="descendant_event")

        # Step 1: Get paginated event IDs with total count
        offset = (page - 1) * limit

        paginated_events_subquery = (
            statement.add_columns(over(func.count()).label("total_count"))
            .limit(limit)
            .offset(offset)
        ).subquery("paginated_events")

        aggregation_columns: list[Any] = [
            EventClosure.ancestor_id,
            (func.count() - 1).label("descendant_count"),
        ]

        field_aggregations = {}
        for field_path in aggregate_fields:
            pg_path = "{" + field_path.replace(".", ",") + "}"
            label = f"agg_{field_path.replace('.', '_')}"

            # Only aggregate numeric fields by summing them.
            # Returns NULL if no values to sum or if all values are NULL.
            numeric_expr = cast(
                descendant_event.user_metadata.op("#>>")(
                    literal_column(f"'{pg_path}'")
                ),
                Numeric,
            )

            aggregation_columns.append(func.sum(numeric_expr).label(label))
            field_aggregations[field_path] = label

        paginated_event_id = paginated_events_subquery.c.id

        aggregations_lateral = (
            select(*aggregation_columns)
            .select_from(EventClosure)
            .join(descendant_event, EventClosure.descendant_id == descendant_event.id)
            .where(EventClosure.ancestor_id == paginated_event_id)
            .group_by(EventClosure.ancestor_id)
        ).lateral("aggregations")

        # Reference user_metadata from the paginated subquery
        paginated_user_metadata = paginated_events_subquery.c.user_metadata

        metadata_expr: Any = paginated_user_metadata
        if aggregate_fields:
            for field_path, label in field_aggregations.items():
                parts = field_path.split(".")
                pg_path = "{" + ",".join(parts) + "}"
                agg_column = getattr(aggregations_lateral.c, label)

                # For nested paths, jsonb_set with create_if_missing doesn't
                # work reliably. Use a deep-merge approach: extract the parent,
                # merge, set it back.
                if len(parts) > 1:
                    # Build the full nested structure.
                    # For "_cost.amount"=7: {"_cost": {"amount": 7}}
                    nested_value = func.to_jsonb(agg_column)
                    for part in reversed(parts):
                        nested_value = func.jsonb_build_object(part, nested_value)

                    # Deep merge: get the existing parent object, merge it with
                    # the new one, set it back
                    parent_key = parts[0]
                    existing_parent = func.coalesce(
                        metadata_expr.op("->")(parent_key), text("'{}'::jsonb")
                    )
                    merged_parent = existing_parent.op("||")(
                        nested_value.op("->")(parent_key)
                    )

                    metadata_expr = func.jsonb_set(
                        metadata_expr,
                        text(f"'{{{parent_key}}}'"),
                        merged_parent,
                        text("true"),
                    )
                else:
                    # Simple top-level key
                    metadata_expr = func.jsonb_set(
                        metadata_expr,
                        text(f"'{pg_path}'"),
                        func.to_jsonb(agg_column),
                        text("true"),
                    )
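
        # Worked example of the deep merge above (hypothetical values): with
        # existing metadata {"_cost": {"currency": "usd"}} and an aggregated
        # "_cost.amount" of 7, nested_value is {"_cost": {"amount": 7}}, the
        # parent merge yields {"currency": "usd", "amount": 7}, and jsonb_set
        # writes back {"_cost": {"currency": "usd", "amount": 7}}.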

        # Step 2: Join back to the Event table to get full ORM objects with
        # relationships
        final_query = (
            select(Event, paginated_events_subquery.c.total_count)
            .select_from(paginated_events_subquery)
            .join(Event, Event.id == paginated_events_subquery.c.id)
            .add_columns(
                func.coalesce(aggregations_lateral.c.descendant_count, 0).label(
                    "child_count"
                ),
                metadata_expr.label("aggregated_metadata"),
            )
            .outerjoin(aggregations_lateral, literal_column("true"))
            .options(*self.get_eager_options())
        )

        result = await self.session.execute(final_query)
        rows = result.all()

        events = []
        total_count = 0
        for row in rows:
            event = row[0]
            event.child_count = row.child_count

            if aggregate_fields:
                aggregated = row.aggregated_metadata
                # If _cost exists but has None/missing fields, clean it up
                if "_cost" in aggregated:
                    cost_obj = aggregated.get("_cost")
                    if cost_obj is None or cost_obj.get("amount") is None:
                        # Remove the incomplete _cost object entirely
                        del aggregated["_cost"]
                    elif "currency" not in cost_obj:
                        # Add the default currency if missing
                        cost_obj["currency"] = "usd"  # FIXME: Main Polar currency

                event.user_metadata = aggregated

            # Expunge the event from the session to prevent the modifications
            # from being persisted: we're only modifying transient display
            # fields (child_count, aggregated metadata)
            self.session.expunge(event)

            events.append(event)
            total_count = row.total_count

        return events, total_count
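
    # Usage sketch (hypothetical pagination and field choices; "_cost.amount"
    # mirrors the cleanup logic above):
    #
    #     statement = repository.get_readable_statement(auth_subject)
    #     events, total = await repository.list_with_closure_table(
    #         statement, limit=50, page=1, aggregate_fields=("_cost.amount",)
    #     )
    #     # Each event carries child_count and, when requested, metadata with
    #     # the summed descendant fields merged in.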

    async def get_hierarchy_stats(
        self,
        statement: Select[tuple[Event]],
        aggregate_fields: Sequence[str] = ("cost.amount",),
        sorting: Sequence[tuple[str, bool]] = (("total", True),),
        timestamp_series: Any = None,
    ) -> Sequence[dict[str, Any]]:
        """
        Get aggregate statistics grouped by root event name across all hierarchies.

        Uses root_id for an efficient rollup and joins with event_types for labels:
        1. Filter root events based on the statement
        2. Roll up costs from all events in each hierarchy (via root_id)
        3. Calculate sum, avg, p50, p95 and p99 on those rolled-up totals
           across root events with the same name
        4. Join with event_types to include labels

        Args:
            statement: Base query for root events to include
            aggregate_fields: List of user_metadata field paths to aggregate
            sorting: List of (property, is_desc) tuples for sorting
            timestamp_series: Optional CTE for time bucketing. If provided,
                stats are grouped by timestamp.

        Returns:
            List of dicts containing name, label, occurrences, and statistics
            for each field. If timestamp_series is provided, each row also
            includes a timestamp.
        """
        root_events_subquery = statement.where(
            and_(Event.parent_id.is_(None), Event.source == EventSource.user)
        ).subquery()

        all_events = aliased(Event, name="all_events")

        per_root_select_exprs: list[ColumnElement[Any]] = [
            literal_column("root_event.id").label("root_id"),
            literal_column("root_event.name").label("root_name"),
            literal_column("root_event.organization_id").label("root_org_id"),
        ]

        if timestamp_series is not None:
            per_root_select_exprs.append(
                literal_column("root_event.timestamp").label("root_timestamp")
            )

        for field_path in aggregate_fields:
            field_parts = field_path.split(".")
            pg_path = "{" + ",".join(field_parts) + "}"
            safe_field_name = field_path.replace(".", "_")

            field_expr = cast(
                all_events.user_metadata.op("#>>")(literal_column(f"'{pg_path}'")),
                Numeric,
            )

            sum_expr = func.sum(field_expr).label(f"{safe_field_name}_total")
            per_root_select_exprs.append(sum_expr)

        group_by_exprs: list[ColumnElement[Any]] = [
            literal_column("root_event.id"),
            literal_column("root_event.name"),
            literal_column("root_event.organization_id"),
        ]
        if timestamp_series is not None:
            group_by_exprs.append(literal_column("root_event.timestamp"))

        per_root_query = (
            select(*per_root_select_exprs)
            .select_from(root_events_subquery.alias("root_event"))
            .join(all_events, all_events.root_id == literal_column("root_event.id"))
            .group_by(*group_by_exprs)
        )

        per_root_subquery = per_root_query.subquery("per_root_totals")

        event_type = aliased(EventType, name="event_type")

        if timestamp_series is not None:
            timestamp_column: ColumnElement[datetime] = timestamp_series.c.timestamp

            timestamp_with_next = (
                select(
                    timestamp_column.label("bucket_start"),
                    func.lead(timestamp_column)
                    .over(order_by=timestamp_column)
                    .label("bucket_end"),
                ).select_from(timestamp_series)
            ).subquery("timestamp_with_next")

            bucketed_columns = [
                timestamp_with_next.c.bucket_start.label("bucket"),
                per_root_subquery.c.root_name,
                per_root_subquery.c.root_org_id,
            ]
            for field_path in aggregate_fields:
                safe_field_name = field_path.replace(".", "_")
                bucketed_columns.append(
                    getattr(per_root_subquery.c, f"{safe_field_name}_total")
                )

            bucketed_subquery = (
                select(*bucketed_columns)
                .select_from(timestamp_with_next)
                .outerjoin(
                    per_root_subquery,
                    and_(
                        per_root_subquery.c.root_timestamp
                        >= timestamp_with_next.c.bucket_start,
                        or_(
                            timestamp_with_next.c.bucket_end.is_(None),
                            per_root_subquery.c.root_timestamp
                            < timestamp_with_next.c.bucket_end,
                        ),
                    ),
                )
            ).subquery("bucketed")
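
            # Bucketing note: lead() produces half-open intervals
            # [bucket_start, bucket_end); the last bucket has a NULL bucket_end
            # and therefore captures every root event from its start onward.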

            aggregation_exprs = []
            for field_path in aggregate_fields:
                safe_field_name = field_path.replace(".", "_")
                total_col: ColumnElement[Any] = getattr(
                    bucketed_subquery.c, f"{safe_field_name}_total"
                )

                aggregation_exprs.extend(
                    [
                        func.sum(total_col).label(f"{safe_field_name}_sum"),
                        func.avg(func.coalesce(total_col, 0)).label(
                            f"{safe_field_name}_avg"
                        ),
                        func.percentile_cont(0.5)
                        .within_group(func.coalesce(total_col, 0))
                        .label(f"{safe_field_name}_p50"),
                        func.percentile_cont(0.95)
                        .within_group(func.coalesce(total_col, 0))
                        .label(f"{safe_field_name}_p95"),
                        func.percentile_cont(0.99)
                        .within_group(func.coalesce(total_col, 0))
                        .label(f"{safe_field_name}_p99"),
                    ]
                )

            stats_query = (
                select(
                    bucketed_subquery.c.bucket.label("timestamp"),
                    bucketed_subquery.c.root_name.label("name"),
                    event_type.id.label("event_type_id"),
                    event_type.label.label("label"),
                    func.count(
                        getattr(
                            bucketed_subquery.c,
                            f"{aggregate_fields[0].replace('.', '_')}_total",
                        )
                    ).label("occurrences"),
                    *aggregation_exprs,
                )
                .select_from(bucketed_subquery)
                .outerjoin(
                    event_type,
                    and_(
                        event_type.name == bucketed_subquery.c.root_name,
                        event_type.organization_id == bucketed_subquery.c.root_org_id,
                    ),
                )
                .group_by(
                    bucketed_subquery.c.bucket,
                    bucketed_subquery.c.root_name,
                    event_type.id,
                    event_type.label,
                )
            )
        else:
            aggregation_exprs = []
            for field_path in aggregate_fields:
                safe_field_name = field_path.replace(".", "_")
                total_col_ref: ColumnElement[Any] = literal_column(
                    f"{safe_field_name}_total"
                )

                aggregation_exprs.extend(
                    [
                        func.sum(total_col_ref).label(f"{safe_field_name}_sum"),
                        func.avg(func.coalesce(total_col_ref, 0)).label(
                            f"{safe_field_name}_avg"
                        ),
                        func.percentile_cont(0.5)
                        .within_group(func.coalesce(total_col_ref, 0))
                        .label(f"{safe_field_name}_p50"),
                        func.percentile_cont(0.95)
                        .within_group(func.coalesce(total_col_ref, 0))
                        .label(f"{safe_field_name}_p95"),
                        func.percentile_cont(0.99)
                        .within_group(func.coalesce(total_col_ref, 0))
                        .label(f"{safe_field_name}_p99"),
                    ]
                )

            stats_query = (
                select(
                    per_root_subquery.c.root_name.label("name"),
                    event_type.id.label("event_type_id"),
                    event_type.label.label("label"),
                    func.count(per_root_subquery.c.root_id).label("occurrences"),
                    *aggregation_exprs,
                )
                .select_from(per_root_subquery)
                .outerjoin(
                    event_type,
                    and_(
                        event_type.name == per_root_subquery.c.root_name,
                        event_type.organization_id == per_root_subquery.c.root_org_id,
                    ),
                )
                .group_by(
                    per_root_subquery.c.root_name, event_type.id, event_type.label
                )
            )

        order_by_clauses: list[UnaryExpression[Any]] = []

        if timestamp_series is not None:
            order_by_clauses.append(asc(text("timestamp")))

        for criterion, is_desc_sort in sorting:
            clause_function = desc if is_desc_sort else asc
            if criterion == "name":
                order_by_clauses.append(clause_function(text("name")))
            elif criterion == "occurrences":
                order_by_clauses.append(clause_function(text("occurrences")))
            elif criterion in ("total", "average", "p95", "p99"):
                if aggregate_fields:
                    safe_field_name = aggregate_fields[0].replace(".", "_")
                    suffix_map = {
                        "total": "sum",
                        "average": "avg",
                        "p95": "p95",
                        "p99": "p99",
                    }
                    suffix = suffix_map[criterion]
                    order_by_clauses.append(
                        clause_function(text(f"{safe_field_name}_{suffix}"))
                    )

        if order_by_clauses:
            stats_query = stats_query.order_by(*order_by_clauses)

        result = await self.session.execute(stats_query)
        rows = result.all()

        result_list = []
        for row in rows:
            row_dict = {
                "name": row.name,
                "label": row.label,
                "event_type_id": row.event_type_id,
                "occurrences": row.occurrences,
                "totals": {
                    field.replace(".", "_"): getattr(
                        row, f"{field.replace('.', '_')}_sum"
                    )
                    or 0
                    for field in aggregate_fields
                },
                "averages": {
                    field.replace(".", "_"): getattr(
                        row, f"{field.replace('.', '_')}_avg"
                    )
                    or 0
                    for field in aggregate_fields
                },
                "p50": {
                    field.replace(".", "_"): getattr(
                        row, f"{field.replace('.', '_')}_p50"
                    )
                    or 0
                    for field in aggregate_fields
                },
                "p95": {
                    field.replace(".", "_"): getattr(
                        row, f"{field.replace('.', '_')}_p95"
                    )
                    or 0
                    for field in aggregate_fields
                },
                "p99": {
                    field.replace(".", "_"): getattr(
                        row, f"{field.replace('.', '_')}_p99"
                    )
                    or 0
                    for field in aggregate_fields
                },
            }

            if timestamp_series is not None:
                row_dict["timestamp"] = row.timestamp

            result_list.append(row_dict)

        return result_list
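
    # Usage sketch for get_hierarchy_stats (hypothetical arguments, using the
    # default "cost.amount" field):
    #
    #     stats = await repository.get_hierarchy_stats(
    #         repository.get_readable_statement(auth_subject),
    #         sorting=(("total", True), ("name", False)),
    #     )
    #     # Each dict carries name, label, event_type_id, occurrences, and
    #     # per-field totals/averages/p50/p95/p99.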