Coverage for polar/billing_entry/service.py: 33%
119 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import dataclasses 1a
2import uuid 1a
3from collections.abc import Sequence 1a
4from datetime import datetime 1a
5from typing import cast 1a
7import structlog 1a
8from babel.dates import format_date 1a
9from sqlalchemy.util.typing import Literal 1a
10from typing_extensions import AsyncGenerator 1a
12from polar.event.repository import EventRepository 1a
13from polar.integrations.stripe.service import stripe as stripe_service 1a
14from polar.kit.math import non_negative_running_sum 1a
15from polar.meter.service import meter as meter_service 1a
16from polar.models import BillingEntry, Event, OrderItem, Subscription 1a
17from polar.models.billing_entry import BillingEntryDirection, BillingEntryType 1a
18from polar.models.event import EventSource 1a
19from polar.postgres import AsyncSession 1a
20from polar.product.guard import ( 1a
21 MeteredPrice,
22 StaticPrice,
23 is_metered_price,
24)
25from polar.product.repository import ProductPriceRepository, ProductRepository 1a
26from polar.worker._enqueue import enqueue_job 1a
28from .repository import BillingEntryRepository 1a
# Module-level structlog logger, named after this module.
log = structlog.get_logger(__name__)
@dataclasses.dataclass
class StaticLineItem:
    """Line item for a single pending static (non-metered) billing entry.

    Built by ``BillingEntryService._get_static_price_line_item``.
    """

    price: StaticPrice
    # Amount of the source billing entry, negated when the entry is a credit.
    amount: int
    currency: str
    # Human-readable description, including the product name and the period.
    label: str
    # True when the source billing entry is of type BillingEntryType.proration.
    proration: bool
@dataclasses.dataclass
class MeteredLineItem:
    """Line item aggregating pending metered usage.

    Built per price by ``BillingEntryService._get_metered_line_item`` (summable
    aggregations) or per meter by ``_get_metered_line_item_by_meter``
    (non-summable aggregations).
    """

    price: MeteredPrice
    start_timestamp: datetime
    end_timestamp: datetime
    # Meter quantity computed over the user-sourced events of the pending entries.
    consumed_units: float
    # Non-negative running sum of the "units" metadata of meter-credit events.
    credited_units: int
    # Amount returned by the price for (consumed_units - credited_units).
    amount: int
    currency: str
    # "<meter name> — <amount label>"
    label: str
    # Metered line items are never prorations.
    proration: Literal[False] = False
class BillingEntryService:
    """Turn a subscription's pending billing entries into order line items."""

    async def create_order_items_from_pending(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        stripe_invoice_id: str | None = None,
        stripe_customer_id: str | None = None,
    ) -> Sequence[OrderItem]:
        """
        Build one OrderItem per pending line item of *subscription*.

        When both ``stripe_invoice_id`` and ``stripe_customer_id`` are given
        (legacy subscriptions managed by Stripe), a matching invoice item is also
        created on Stripe for each line item; only metered line items are
        expected in that case.

        The billing entries backing each item are linked to it asynchronously
        through the ``billing_entry.set_order_item`` job.

        :param session: Database session.
        :param subscription: Subscription whose pending entries are billed.
        :param stripe_invoice_id: Stripe invoice to attach invoice items to.
        :param stripe_customer_id: Stripe customer owning that invoice.
        :return: The new OrderItem objects, not yet attached to an Order.
        """
        items: list[OrderItem] = []
        async for line_item, entries in self.compute_pending_subscription_line_items(
            session, subscription
        ):
            order_item_id = uuid.uuid4()

            # For legacy subscriptions managed by Stripe, we create invoice items on Stripe
            if stripe_invoice_id and stripe_customer_id:
                assert isinstance(line_item, MeteredLineItem)
                price = line_item.price
                await stripe_service.create_invoice_item(
                    customer=stripe_customer_id,
                    invoice=stripe_invoice_id,
                    amount=line_item.amount,
                    currency=line_item.currency,
                    description=line_item.label,
                    metadata={
                        "order_item_id": str(order_item_id),
                        "product_price_id": str(price.id),
                        "meter_id": str(price.meter_id),
                        "units": str(line_item.consumed_units),
                        "credited_units": str(line_item.credited_units),
                        "unit_amount": str(price.unit_amount),
                        "cap_amount": str(price.cap_amount),
                    },
                )

            order_item = OrderItem(
                id=order_item_id,
                label=line_item.label,
                amount=line_item.amount,
                tax_amount=0,
                proration=line_item.proration,
                product_price=line_item.price,
            )
            items.append(order_item)

            # Do it asynchronously to avoid issues with DB flush, since we're
            # generating the OrderItem without attaching it to an Order yet.
            enqueue_job("billing_entry.set_order_item", entries, order_item.id)

        return items

    async def compute_pending_subscription_line_items(
        self, session: AsyncSession, subscription: Subscription
    ) -> AsyncGenerator[tuple[StaticLineItem | MeteredLineItem, Sequence[uuid.UUID]]]:
        """
        Yield each pending line item of *subscription* with the IDs of the
        billing entries it covers.

        Static entries come first, one line item per entry. Metered entries
        follow:

        - one line item per price for summable aggregations;
        - one line item per meter for non-summable aggregations, billed at the
          subscription's currently active price for that meter. Meters with no
          active price are logged and skipped.
        """
        repository = BillingEntryRepository.from_session(session)

        async for entry in repository.get_static_pending_by_subscription(
            subscription.id
        ):
            static_price = cast(StaticPrice, entry.product_price)
            static_line_item = await self._get_static_price_line_item(
                session, static_price, entry
            )
            yield static_line_item, [entry.id]

        # 👋 Reading the code below, you might wonder:
        # "Why is this so complex?"
        # "Why are there so many queries?"
        # Well, if you look at the previous implementation, it was much more readable
        # but it involved loading lots of BillingEntry in memory, which was causing
        # performance issues and even OOM on large subscriptions.
        product_price_repository = ProductPriceRepository.from_session(session)

        # Meters using a non-summable aggregation (max, min, avg, unique) are billed
        # once across ALL their events, not per price: once such a meter has been
        # handled, its remaining (price, meter) tuples are skipped.
        processed_meters: set[uuid.UUID] = set()

        async for (
            product_price_id,
            meter_id,
            start_timestamp,
            end_timestamp,
        ) in repository.get_pending_metered_by_subscription_tuples(subscription.id):
            # Only non-summable meters are ever added to the set, so the skip can
            # happen before loading the price, saving a query per duplicate tuple.
            if meter_id in processed_meters:
                continue

            metered_price = cast(
                MeteredPrice, await product_price_repository.get_by_id(product_price_id)
            )

            if metered_price.meter.aggregation.is_summable():
                metered_line_item = await self._get_metered_line_item(
                    session, metered_price, subscription, start_timestamp, end_timestamp
                )
                pending_entries_ids = (
                    await repository.get_pending_ids_by_subscription_and_price(
                        subscription.id, product_price_id
                    )
                )
            else:
                # Non-summable aggregations must be computed across ALL events in
                # the period, not per price. For example:
                # - MAX(3 servers on priceA, 2 servers on priceB) = 3 servers (not 3+2=5)
                # - We bill this at the currently active price from subscription
                processed_meters.add(meter_id)

                # The subscription's current price for this meter is the source of
                # truth: even if all billing entries used priceA, if the customer
                # changed to priceB, we bill at priceB.
                active_price = self._get_active_metered_price(subscription, meter_id)
                if active_price is None:
                    log.info(
                        f"No active price found for meter {meter_id} in subscription {subscription.id}"
                    )
                    continue

                # NOTE(review): the period comes from the first (price, meter)
                # tuple seen for this meter; tuples of other prices on the same
                # meter may span a different period — confirm this is intended.
                metered_line_item = await self._get_metered_line_item_by_meter(
                    session, active_price, subscription, start_timestamp, end_timestamp
                )
                pending_entries_ids = (
                    await repository.get_pending_ids_by_subscription_and_meter(
                        subscription.id, meter_id
                    )
                )

            yield metered_line_item, pending_entries_ids

    @staticmethod
    def _get_active_metered_price(
        subscription: Subscription, meter_id: uuid.UUID
    ) -> MeteredPrice | None:
        """Return the subscription's first metered price on *meter_id*, or None."""
        for subscription_product_price in subscription.subscription_product_prices:
            product_price = subscription_product_price.product_price
            if is_metered_price(product_price) and product_price.meter_id == meter_id:
                return product_price
        return None

    async def _get_static_price_line_item(
        self, session: AsyncSession, price: StaticPrice, entry: BillingEntry
    ) -> StaticLineItem:
        """
        Build the line item for a single static-price billing entry.

        Credit entries are labelled as remaining time and have their amount
        negated; debit entries keep their amount as-is.

        :raises ValueError: If the entry direction is neither credit nor debit.
        """
        assert entry.amount is not None
        assert entry.currency is not None

        product_repository = ProductRepository.from_session(session)
        product = await product_repository.get_by_id(price.product_id)
        assert product is not None

        start = format_date(entry.start_timestamp.date(), locale="en_US")
        end = format_date(entry.end_timestamp.date(), locale="en_US")

        if entry.direction == BillingEntryDirection.credit:
            label = f"Remaining time on {product.name} — From {start} to {end}"
            amount = -entry.amount
        elif entry.direction == BillingEntryDirection.debit:
            label = f"{product.name} — From {start} to {end}"
            amount = entry.amount
        else:
            # Fail loudly instead of continuing with `label` unbound.
            raise ValueError(f"Unsupported billing entry direction: {entry.direction}")

        return StaticLineItem(
            price=price,
            amount=amount,
            currency=entry.currency,
            label=label,
            proration=entry.type == BillingEntryType.proration,
        )

    async def _get_metered_line_item(
        self,
        session: AsyncSession,
        price: MeteredPrice,
        subscription: Subscription,
        start_timestamp: datetime,
        end_timestamp: datetime,
    ) -> MeteredLineItem:
        """
        Compute a metered line item for a specific price.
        Used for summable aggregations (sum, count) where we can group by price.
        """
        event_repository = EventRepository.from_session(session)
        events_statement = event_repository.get_by_pending_entries_statement(
            subscription.id, price.id
        )
        return await self._build_metered_line_item(
            session,
            event_repository,
            events_statement,
            price,
            start_timestamp,
            end_timestamp,
        )

    async def _get_metered_line_item_by_meter(
        self,
        session: AsyncSession,
        price: MeteredPrice,
        subscription: Subscription,
        start_timestamp: datetime,
        end_timestamp: datetime,
    ) -> MeteredLineItem:
        """
        Compute a metered line item grouped by meter.
        Used for non-summable aggregations (max, min, avg, unique) where we must
        compute across ALL events for the meter, regardless of which price was active.
        Uses the provided price for billing (should be the most recent/current price).
        """
        event_repository = EventRepository.from_session(session)
        # Get events across ALL prices for this meter
        events_statement = event_repository.get_by_pending_entries_for_meter_statement(
            subscription.id, price.meter.id
        )
        return await self._build_metered_line_item(
            session,
            event_repository,
            events_statement,
            price,
            start_timestamp,
            end_timestamp,
        )

    async def _build_metered_line_item(
        self,
        session: AsyncSession,
        event_repository: EventRepository,
        events_statement,
        price: MeteredPrice,
        start_timestamp: datetime,
        end_timestamp: datetime,
    ) -> MeteredLineItem:
        """
        Price the usage selected by *events_statement* and wrap it in a line item.

        Consumed units are the meter quantity over user-sourced events; credited
        units are the non-negative running sum of the meter-credit events'
        ``units`` metadata. The billed quantity is their difference.

        :param events_statement: Event select statement (supports
            ``.with_only_columns()`` and ``.where()``) covering the pending
            entries being billed.
        """
        meter = price.meter
        units = await meter_service.get_quantity(
            session,
            meter,
            events_statement.with_only_columns(Event.id).where(
                Event.source == EventSource.user
            ),
        )
        credit_events = await event_repository.get_all(
            events_statement.where(Event.is_meter_credit.is_(True))
        )
        credited_units = non_negative_running_sum(
            event.user_metadata["units"] for event in credit_events
        )
        amount, amount_label = price.get_amount_and_label(units - credited_units)

        return MeteredLineItem(
            price=price,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
            consumed_units=units,
            credited_units=credited_units,
            amount=amount,
            currency=price.price_currency,
            label=f"{meter.name} — {amount_label}",
        )
# Shared module-level instance of the billing entry service.
billing_entry = BillingEntryService()