Coverage for polar/meter/service.py: 18%
165 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import uuid 1a
2from collections.abc import Sequence 1a
3from datetime import UTC, datetime 1a
4from typing import Any 1a
6from sqlalchemy import ( 1a
7 ColumnElement,
8 ColumnExpressionArgument,
9 Select,
10 UnaryExpression,
11 and_,
12 asc,
13 cte,
14 desc,
15 func,
16 or_,
17 select,
18)
19from sqlalchemy.orm import joinedload 1a
21from polar.auth.models import AuthSubject, Organization, User 1a
22from polar.billing_entry.repository import BillingEntryRepository 1a
23from polar.config import settings 1a
24from polar.event.repository import EventRepository 1a
25from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError 1a
26from polar.kit.metadata import MetadataQuery, apply_metadata_clause, get_metadata_clause 1a
27from polar.kit.pagination import PaginationParams 1a
28from polar.kit.sorting import Sorting 1a
29from polar.kit.time_queries import TimeInterval, get_timestamp_series_cte 1a
30from polar.meter.aggregation import AggregationFunction 1a
31from polar.models import ( 1a
32 Benefit,
33 BillingEntry,
34 Customer,
35 Event,
36 Meter,
37 Product,
38 ProductPriceMeteredUnit,
39 SubscriptionProductPrice,
40)
41from polar.organization.resolver import get_payload_organization 1a
42from polar.postgres import AsyncReadSession, AsyncSession 1a
43from polar.subscription.repository import ( 1a
44 CustomerSubscriptionProductPrice,
45 SubscriptionProductPriceRepository,
46)
47from polar.worker import enqueue_job 1a
49from .repository import MeterRepository 1a
50from .schemas import MeterCreate, MeterQuantities, MeterQuantity, MeterUpdate 1a
51from .sorting import MeterSortProperty 1a
class MeterError(PolarError):
    """Meter-specific subclass of ``PolarError``."""
class MeterService:
    """Service operations for meters: listing, CRUD, archiving, quantity
    queries and billing-entry creation."""

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        metadata: MetadataQuery | None = None,
        query: str | None = None,
        is_archived: bool | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[MeterSortProperty]] | None = None,
    ) -> tuple[Sequence[Meter], int]:
        """List meters readable by ``auth_subject``, filtered, sorted and paginated.

        Args:
            session: Read session used for the query.
            auth_subject: Subject whose readable meters are listed.
            organization_id: Restrict to meters belonging to these organizations.
            metadata: Metadata filter applied with ``apply_metadata_clause``.
            query: Matched case-insensitively (``ILIKE``) as a substring of the
                meter name. ``%`` and ``_`` inside it are not escaped and act
                as LIKE wildcards.
            is_archived: ``True`` keeps only archived meters, ``False`` only
                non-archived ones, ``None`` applies no archive filter.
            pagination: Page size (``limit``) and page number.
            sorting: ``(property, is_descending)`` pairs. When omitted, meters
                are sorted by name ascending.

        Returns:
            The ``(items, count)`` tuple produced by ``repository.paginate``.
        """
        if sorting is None:
            # Built per call rather than shared as a mutable default argument.
            sorting = [(MeterSortProperty.meter_name, False)]

        repository = MeterRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject)

        if organization_id is not None:
            statement = statement.where(Meter.organization_id.in_(organization_id))

        if query is not None:
            statement = statement.where(Meter.name.ilike(f"%{query}%"))

        if is_archived is not None:
            if is_archived:
                statement = statement.where(Meter.archived_at.is_not(None))
            else:
                statement = statement.where(Meter.archived_at.is_(None))

        if metadata is not None:
            statement = apply_metadata_clause(Meter, statement, metadata)

        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            clause_function = desc if is_desc else asc
            if criterion == MeterSortProperty.created_at:
                order_by_clauses.append(clause_function(Meter.created_at))
            elif criterion == MeterSortProperty.meter_name:
                order_by_clauses.append(clause_function(Meter.name))
        statement = statement.order_by(*order_by_clauses)

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Meter | None:
        """Return the meter with ``id`` if readable by ``auth_subject``, else ``None``.

        ``last_billed_event`` is eager-loaded with a joined load.
        """
        repository = MeterRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(Meter.id == id)
            .options(joinedload(Meter.last_billed_event))
        )
        return await repository.get_one_or_none(statement)

    async def create(
        self,
        session: AsyncSession,
        meter_create: MeterCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Meter:
        """Create a meter in the organization resolved from ``meter_create``.

        After the meter is flushed, the most recent event (by ``timestamp``)
        matching the meter is stored as its ``last_billed_event``, or ``None``
        if there is none.
        """
        repository = MeterRepository.from_session(session)
        organization = await get_payload_organization(
            session, auth_subject, meter_create
        )

        meter = await repository.create(
            Meter(
                **meter_create.model_dump(
                    by_alias=True, exclude={"filter", "aggregation"}
                ),
                filter=meter_create.filter,
                aggregation=meter_create.aggregation,
                organization=organization,
            ),
            flush=True,
        )

        # Retrieve the latest matching event for the meter and set it as the last billed event
        # This is done to ensure that the meter is billed from the last event onwards
        event_repository = EventRepository.from_session(session)
        statement = (
            event_repository.get_meter_statement(meter)
            .order_by(Event.timestamp.desc())
            .limit(1)
        )
        last_billed_event = await event_repository.get_one_or_none(statement)
        await repository.update(
            meter, update_dict={"last_billed_event": last_billed_event}
        )

        return meter

    async def update(
        self, session: AsyncSession, meter: Meter, meter_update: MeterUpdate
    ) -> Meter:
        """Apply ``meter_update`` to ``meter``.

        Only fields explicitly set on ``meter_update`` are written.
        ``is_archived`` is delegated to :meth:`archive` / :meth:`unarchive`
        before the remaining fields are applied.

        Raises:
            PolarRequestValidationError: If ``filter`` or ``aggregation`` is
                set while the meter already has a ``last_billed_event``. It is
                also raised by :meth:`archive` when archiving is refused.
        """
        repository = MeterRepository.from_session(session)

        errors: list[ValidationError] = []
        if meter.last_billed_event is not None:
            # A tuple (not a set) keeps the error order deterministic across
            # processes; set iteration order of str depends on hash seeding.
            for sensitive_field in ("filter", "aggregation"):
                if sensitive_field in meter_update.model_fields_set:
                    errors.append(
                        {
                            "type": "forbidden",
                            "loc": ("body", sensitive_field),
                            "msg": (
                                "This field can't be updated because the meter "
                                "is already aggregating events."
                            ),
                            "input": getattr(meter_update, sensitive_field),
                        }
                    )

        if errors:
            raise PolarRequestValidationError(errors)

        update_dict = meter_update.model_dump(
            by_alias=True,
            exclude_unset=True,
            exclude={"filter", "aggregation", "is_archived"},
        )
        # filter/aggregation are passed as model objects, not dumped dicts.
        if meter_update.filter is not None:
            update_dict["filter"] = meter_update.filter
        if meter_update.aggregation is not None:
            update_dict["aggregation"] = meter_update.aggregation

        # Handle archiving/unarchiving
        if meter_update.is_archived is not None:
            if meter_update.is_archived:
                meter = await self.archive(session, meter)
            else:
                meter = await self.unarchive(session, meter)

        return await repository.update(meter, update_dict=update_dict)

    async def archive(self, session: AsyncSession, meter: Meter) -> Meter:
        """Set ``archived_at`` on ``meter`` to the current UTC time.

        Raises:
            PolarRequestValidationError: If the meter is used by a
                non-archived, non-deleted ``ProductPriceMeteredUnit`` of a
                non-archived product, or by a non-deleted benefit of type
                ``meter_credit`` whose ``properties["meter_id"]`` is this meter.
        """
        # Check if meter is attached to any active ProductPriceMeteredUnit
        active_prices = await session.scalar(
            select(func.count(ProductPriceMeteredUnit.id))
            .join(Product)
            .where(
                Product.is_archived.is_(False),
                ProductPriceMeteredUnit.meter_id == meter.id,
                ProductPriceMeteredUnit.is_archived.is_(False),
                ProductPriceMeteredUnit.deleted_at.is_(None),
            )
        )

        # A COUNT is never negative, so truthiness is the "> 0" check.
        if active_prices:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "is_archived"),
                        "msg": "Cannot archive meter that is still attached to active products",
                        "input": True,
                    }
                ]
            )

        # Check if meter is referenced by any active Benefits with meter_credit type
        active_benefits = await session.scalar(
            select(func.count(Benefit.id)).where(
                Benefit.type == "meter_credit",
                Benefit.properties["meter_id"].as_string() == str(meter.id),
                Benefit.deleted_at.is_(None),
            )
        )

        if active_benefits:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "is_archived"),
                        "msg": "Cannot archive meter that is still referenced by active benefits",
                        "input": True,
                    }
                ]
            )

        repository = MeterRepository.from_session(session)
        return await repository.update(
            meter, update_dict={"archived_at": datetime.now(UTC)}
        )

    async def unarchive(self, session: AsyncSession, meter: Meter) -> Meter:
        """Clear ``archived_at`` on ``meter``. No usage checks are performed."""
        repository = MeterRepository.from_session(session)
        return await repository.update(meter, update_dict={"archived_at": None})

    async def events(
        self,
        session: AsyncSession,
        meter: Meter,
        *,
        pagination: PaginationParams,
    ) -> tuple[Sequence[Event], int]:
        """Paginate events matching ``meter``, most recent ``timestamp`` first.

        Returns the ``(items, count)`` tuple produced by ``repository.paginate``.
        """
        repository = EventRepository.from_session(session)
        statement = repository.get_meter_statement(meter).order_by(
            Event.timestamp.desc()
        )
        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get_quantities(
        self,
        session: AsyncReadSession,
        meter: Meter,
        *,
        start_timestamp: datetime,
        end_timestamp: datetime,
        interval: TimeInterval,
        customer_id: Sequence[uuid.UUID] | None = None,
        external_customer_id: Sequence[str] | None = None,
        metadata: MetadataQuery | None = None,
        customer_aggregation_function: AggregationFunction | None = None,
    ) -> MeterQuantities:
        """Compute the meter's aggregated quantity per ``interval`` bucket.

        A timestamp series from ``start_timestamp`` to ``end_timestamp`` is
        outer-joined with the organization's events matching the meter and
        the optional customer / external-customer / metadata filters.

        For each bucket:

        - ``quantity`` aggregates events whose truncated timestamp equals the
          bucket.
        - ``total`` aggregates events whose truncated timestamp falls within
          the truncated ``[start_timestamp, end_timestamp]`` range.

        Both default to 0 when there is no value.

        Args:
            customer_aggregation_function: When given, quantities and totals
                are first computed per ``Event.resolved_customer_id`` and
                bucket. They are then combined across customers, per bucket,
                with this function.

        Returns:
            One ``MeterQuantity`` per result row, ordered by timestamp. The
            overall total is taken from the last row, or 0.0 when there are
            no rows.
        """
        timestamp_series = get_timestamp_series_cte(
            start_timestamp, end_timestamp, interval
        )
        timestamp_column: ColumnElement[datetime] = timestamp_series.c.timestamp

        event_clauses: list[ColumnExpressionArgument[bool]] = [
            Event.organization_id == meter.organization_id,
        ]
        event_repository = EventRepository.from_session(session)
        if customer_id is not None:
            event_clauses.append(
                event_repository.get_customer_id_filter_clause(customer_id)
            )
        if external_customer_id is not None:
            event_clauses.append(
                event_repository.get_external_customer_id_filter_clause(
                    external_customer_id
                )
            )
        if metadata is not None:
            event_clauses.append(get_metadata_clause(Event, metadata))
        event_clauses.append(event_repository.get_meter_clause(meter))

        statement = (
            select(
                timestamp_column.label("timestamp"),
                func.coalesce(
                    meter.aggregation.get_sql_column(Event).filter(
                        interval.sql_date_trunc(Event.timestamp)
                        == interval.sql_date_trunc(timestamp_column),
                    ),
                    0,
                ).label("quantity"),
                func.coalesce(
                    meter.aggregation.get_sql_column(Event).filter(
                        interval.sql_date_trunc(Event.timestamp)
                        >= interval.sql_date_trunc(start_timestamp),
                        interval.sql_date_trunc(Event.timestamp)
                        <= interval.sql_date_trunc(end_timestamp),
                    ),
                    0,
                ).label("total"),
            )
            # The join condition does not involve the bucket timestamp; the
            # per-bucket restriction is done by the aggregate FILTERs above.
            .join(Event, onclause=and_(*event_clauses), isouter=True)
            .group_by(timestamp_column)
            .order_by(timestamp_column.asc())
        )

        if customer_aggregation_function is not None:
            inner_statement = cte(
                statement.add_columns(Event.resolved_customer_id).group_by(
                    timestamp_column, Event.resolved_customer_id
                )
            )
            statement = (
                select(
                    timestamp_column.label("timestamp"),
                    customer_aggregation_function.get_sql_function(
                        inner_statement.c.quantity
                    ).label("quantity"),
                    customer_aggregation_function.get_sql_function(
                        inner_statement.c.total
                    ).label("total"),
                )
                .join(
                    inner_statement,
                    onclause=inner_statement.c.timestamp == timestamp_column,
                )
                .group_by(timestamp_column)
                .order_by(timestamp_column.asc())
            )

        total = 0.0
        quantities: list[MeterQuantity] = []
        result = await session.stream(
            statement,
            execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
        )
        async for row in result:
            quantities.append(MeterQuantity.model_validate(row))
            # Each row's "total" is computed over the whole requested range,
            # so keeping the last one is sufficient.
            total = row.total

        return MeterQuantities(quantities=quantities, total=total)

    async def enqueue_billing(self, session: AsyncSession) -> None:
        """Enqueue one ``meter.billing_entries`` job per meter, oldest meter first."""
        repository = MeterRepository.from_session(session)
        statement = repository.get_base_statement().order_by(Meter.created_at.asc())
        async for meter in repository.stream(statement):
            enqueue_job("meter.billing_entries", meter.id)

    async def _create_subscription_holder_billing_entry(
        self,
        session: AsyncSession,
        event: Event,
        customer: Customer,
        subscription_product_price: SubscriptionProductPrice,
    ) -> BillingEntry:
        """Persist a ``BillingEntry`` built with ``BillingEntry.from_metered_event``."""
        billing_entry_repository = BillingEntryRepository.from_session(session)
        return await billing_entry_repository.create(
            BillingEntry.from_metered_event(customer, subscription_product_price, event)
        )

    async def create_billing_entries(
        self, session: AsyncSession, meter: Meter
    ) -> Sequence[BillingEntry]:
        """Create billing entries for the meter's not-yet-billed events.

        Events considered:

        - belong to the meter's organization and have a customer;
        - match the meter clause or the meter system clause;
        - were ingested after the meter's ``last_billed_event`` (if any);
        - are processed in ``ingested_at`` ascending order.

        For each event, the customer's subscription product price for this
        meter is looked up (cached per customer for this call). Events without
        one are skipped. Otherwise an entry is created for the subscription's
        customer.

        Afterwards, ``meter.last_billed_event`` is advanced to the last
        streamed event (including skipped ones), the meter is added to the
        session, and a ``subscription.update_meters`` job is enqueued per
        subscription that received an entry.

        Returns:
            The billing entries created during this call.
        """
        event_repository = EventRepository.from_session(session)
        statement = (
            event_repository.get_base_statement()
            .where(
                Event.organization_id == meter.organization_id,
                Event.customer.is_not(None),
                or_(
                    # Events matching meter definitions
                    event_repository.get_meter_clause(meter),
                    # System events impacting the meter balance
                    event_repository.get_meter_system_clause(meter),
                ),
            )
            .order_by(Event.ingested_at.asc())
            .options(*event_repository.get_eager_options())
        )
        last_billed_event = meter.last_billed_event
        if last_billed_event is not None:
            # NOTE(review): this cursor is a strict ">" on ingested_at alone.
            # An event whose ingested_at is <= the last billed event's but
            # that becomes visible after the previous run would never be
            # billed. Confirm this cannot happen in practice.
            statement = statement.where(
                Event.ingested_at > last_billed_event.ingested_at
            )

        subscription_product_price_repository = (
            SubscriptionProductPriceRepository.from_session(session)
        )
        # customer id -> resolved price. ``None`` is cached as well, to avoid
        # re-querying for customers without a matching price.
        customer_price_map: dict[
            uuid.UUID, CustomerSubscriptionProductPrice | None
        ] = {}

        entries: list[BillingEntry] = []
        updated_subscriptions: set[uuid.UUID] = set()
        last_event: Event | None = None
        async for event in event_repository.stream(statement):
            last_event = event
            customer = event.customer
            # Guaranteed by the ``Event.customer.is_not(None)`` filter above.
            assert customer is not None

            # Retrieve the paying customer and subscription product price
            try:
                customer_price = customer_price_map[customer.id]
            except KeyError:
                customer_price = await subscription_product_price_repository.get_by_customer_and_meter(
                    customer.id, meter.id
                )
                customer_price_map[customer.id] = customer_price

            if customer_price is None:
                continue

            # Get the paying customer (billing manager) from the subscription
            paying_customer = (
                customer_price.subscription_product_price.subscription.customer
            )

            entry = await self._create_subscription_holder_billing_entry(
                session,
                event,
                paying_customer,
                customer_price.subscription_product_price,
            )
            entries.append(entry)
            if entry.subscription is not None:
                updated_subscriptions.add(entry.subscription.id)

        meter.last_billed_event = (
            last_event if last_event is not None else last_billed_event
        )
        session.add(meter)

        for subscription_id in updated_subscriptions:
            enqueue_job("subscription.update_meters", subscription_id)

        return entries

    async def get_quantity(
        self,
        session: AsyncSession,
        meter: Meter,
        events_statement: Select[tuple[uuid.UUID]],
    ) -> float:
        """Aggregate, with the meter's aggregation, the events whose ids are
        selected by ``events_statement``.

        A ``NULL`` or otherwise falsy result is returned as ``0.0``.
        """
        statement = select(
            func.coalesce(meter.aggregation.get_sql_column(Event), 0)
        ).where(Event.id.in_(events_statement))
        result = await session.scalar(statement)
        return result or 0.0
# Module-level MeterService instance used by importers of this module.
meter: MeterService = MeterService()