Coverage for polar/billing_entry/service.py: 33%

119 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import dataclasses 1a

2import uuid 1a

3from collections.abc import Sequence 1a

4from datetime import datetime 1a

5from typing import cast 1a

6 

7import structlog 1a

8from babel.dates import format_date 1a

9from sqlalchemy.util.typing import Literal 1a

10from typing_extensions import AsyncGenerator 1a

11 

12from polar.event.repository import EventRepository 1a

13from polar.integrations.stripe.service import stripe as stripe_service 1a

14from polar.kit.math import non_negative_running_sum 1a

15from polar.meter.service import meter as meter_service 1a

16from polar.models import BillingEntry, Event, OrderItem, Subscription 1a

17from polar.models.billing_entry import BillingEntryDirection, BillingEntryType 1a

18from polar.models.event import EventSource 1a

19from polar.postgres import AsyncSession 1a

20from polar.product.guard import ( 1a

21 MeteredPrice, 

22 StaticPrice, 

23 is_metered_price, 

24) 

25from polar.product.repository import ProductPriceRepository, ProductRepository 1a

26from polar.worker._enqueue import enqueue_job 1a

27 

28from .repository import BillingEntryRepository 1a

29 

# Module-level structlog logger, named after this module.
log = structlog.get_logger(__name__)

31 

32 

@dataclasses.dataclass
class StaticLineItem:
    """Billable line derived from a single pending static-price billing entry.

    Instances are produced by
    ``BillingEntryService._get_static_price_line_item``.
    """

    # Static price the billing entry was recorded against.
    price: StaticPrice
    # Signed amount: the entry amount, negated when the entry is a credit.
    amount: int
    # Currency copied from the billing entry.
    currency: str
    # Human-readable description (product name plus the covered date range).
    label: str
    # True when the underlying billing entry is of the proration type.
    proration: bool

40 

41 

@dataclasses.dataclass
class MeteredLineItem:
    """Billable line aggregating the pending usage of a metered price.

    Instances are produced by the ``_get_metered_line_item*`` methods of
    ``BillingEntryService``.
    """

    # Metered price used to compute the amount.
    price: MeteredPrice
    # Bounds of the period the pending billing entries cover.
    start_timestamp: datetime
    end_timestamp: datetime
    # Quantity returned by the meter service for the user-sourced events.
    consumed_units: float
    # Running sum of the units carried by the meter-credit events.
    credited_units: int
    # Amount computed by the price from (consumed_units - credited_units).
    amount: int
    # Currency of the price.
    currency: str
    # Meter name followed by the price-provided amount label.
    label: str
    # Metered line items are never prorations.
    proration: Literal[False] = False

53 

54 

class BillingEntryService:
    """Turn the pending billing entries of a subscription into line items."""

    async def create_order_items_from_pending(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        stripe_invoice_id: str | None = None,
        stripe_customer_id: str | None = None,
    ) -> Sequence[OrderItem]:
        """
        Build one ``OrderItem`` per pending line item of ``subscription``.

        When both ``stripe_invoice_id`` and ``stripe_customer_id`` are given,
        a matching invoice item is also created on Stripe for each line item.

        The returned order items are not attached to an order; linking the
        billing entries to each order item is delegated to the
        ``billing_entry.set_order_item`` job.
        """
        items: list[OrderItem] = []
        async for line_item, entries in self.compute_pending_subscription_line_items(
            session, subscription
        ):
            order_item_id = uuid.uuid4()

            # For legacy subscriptions managed by Stripe, we create invoice items on Stripe
            if stripe_invoice_id and stripe_customer_id:
                # The metadata below reads metered-price attributes
                # (meter_id, unit_amount, cap_amount), so only metered line
                # items are expected on this path.
                assert isinstance(line_item, MeteredLineItem)
                price = line_item.price
                await stripe_service.create_invoice_item(
                    customer=stripe_customer_id,
                    invoice=stripe_invoice_id,
                    amount=line_item.amount,
                    currency=line_item.currency,
                    description=line_item.label,
                    metadata={
                        "order_item_id": str(order_item_id),
                        "product_price_id": str(price.id),
                        "meter_id": str(price.meter_id),
                        "units": str(line_item.consumed_units),
                        "credited_units": str(line_item.credited_units),
                        "unit_amount": str(price.unit_amount),
                        "cap_amount": str(price.cap_amount),
                    },
                )

            order_item = OrderItem(
                id=order_item_id,
                label=line_item.label,
                amount=line_item.amount,
                tax_amount=0,
                proration=line_item.proration,
                product_price=line_item.price,
            )
            items.append(order_item)

            # Do it asynchronously to avoid issues with DB flush, since we're
            # generating OrderItem without attached to an Order yet.
            enqueue_job("billing_entry.set_order_item", entries, order_item.id)

        return items

    async def compute_pending_subscription_line_items(
        self, session: AsyncSession, subscription: Subscription
    ) -> AsyncGenerator[tuple[StaticLineItem | MeteredLineItem, Sequence[uuid.UUID]]]:
        """
        Yield ``(line_item, billing_entry_ids)`` pairs for ``subscription``.

        Static entries come first, one line item per pending entry. Metered
        entries follow, one line item per price for summable aggregations and
        one per meter for non-summable ones.
        """
        repository = BillingEntryRepository.from_session(session)

        async for entry in repository.get_static_pending_by_subscription(
            subscription.id
        ):
            static_price = cast(StaticPrice, entry.product_price)
            static_line_item = await self._get_static_price_line_item(
                session, static_price, entry
            )
            yield static_line_item, [entry.id]

        # 👋 Reading the code below, you might wonder:
        # "Why is this so complex?"
        # "Why are there so many queries?"
        # Well, the previous implementation was much more readable, but it
        # loaded a lot of BillingEntry rows in memory, which caused
        # performance issues and even OOM on large subscriptions.
        product_price_repository = ProductPriceRepository.from_session(session)

        # Track which meters we've already processed to avoid duplicates.
        # For non-summable aggregations (max, min, avg, unique), we process each
        # meter only once (even if there are billing entries with multiple
        # prices) because these aggregations must be computed across ALL
        # events, not per-price.
        processed_meters: set[uuid.UUID] = set()

        async for (
            product_price_id,
            meter_id,
            start_timestamp,
            end_timestamp,
        ) in repository.get_pending_metered_by_subscription_tuples(subscription.id):
            # NOTE(review): the cast hides a possible None from get_by_id;
            # confirm a pending entry can never reference a missing price.
            metered_price = cast(
                MeteredPrice, await product_price_repository.get_by_id(product_price_id)
            )

            # Non-summable aggregations (max, min, avg, unique) must be computed
            # across ALL events in the period, not per-price. For example:
            # - MAX(3 servers on priceA, 2 servers on priceB) = 3 servers (not 3+2=5)
            # - We bill this at the currently active price from subscription
            if not metered_price.meter.aggregation.is_summable():
                if meter_id in processed_meters:
                    continue
                processed_meters.add(meter_id)

                # Find the currently active price for this meter from the
                # subscription. This is the source of truth: even if all
                # billing entries used priceA, if the customer changed to
                # priceB, we bill at priceB.
                active_price = None
                for spp in subscription.subscription_product_prices:
                    if (
                        is_metered_price(spp.product_price)
                        and spp.product_price.meter_id == meter_id
                    ):
                        active_price = spp.product_price
                        break

                if active_price is None:
                    log.info(
                        f"No active price found for meter {meter_id} in subscription {subscription.id}"
                    )
                    continue

                metered_line_item = await self._get_metered_line_item_by_meter(
                    session, active_price, subscription, start_timestamp, end_timestamp
                )
                pending_entries_ids = (
                    await repository.get_pending_ids_by_subscription_and_meter(
                        subscription.id, meter_id
                    )
                )
            else:
                metered_line_item = await self._get_metered_line_item(
                    session, metered_price, subscription, start_timestamp, end_timestamp
                )
                pending_entries_ids = (
                    await repository.get_pending_ids_by_subscription_and_price(
                        subscription.id, product_price_id
                    )
                )

            yield metered_line_item, pending_entries_ids

    async def _get_static_price_line_item(
        self, session: AsyncSession, price: StaticPrice, entry: BillingEntry
    ) -> StaticLineItem:
        """
        Build the line item for a single static-price billing entry.

        Credit entries produce a negative amount and a "Remaining time on"
        label; debit entries keep the entry amount as-is.

        Raises:
            ValueError: if the entry direction is neither credit nor debit.
        """
        assert entry.amount is not None
        assert entry.currency is not None

        product_repository = ProductRepository.from_session(session)
        product = await product_repository.get_by_id(price.product_id)
        assert product is not None

        start = format_date(entry.start_timestamp.date(), locale="en_US")
        end = format_date(entry.end_timestamp.date(), locale="en_US")

        if entry.direction == BillingEntryDirection.credit:
            label = f"Remaining time on {product.name} — From {start} to {end}"
            amount = -entry.amount
        elif entry.direction == BillingEntryDirection.debit:
            label = f"{product.name} — From {start} to {end}"
            amount = entry.amount
        else:
            # Fail explicitly instead of hitting an unbound `label` below.
            raise ValueError(
                f"Unexpected billing entry direction: {entry.direction!r}"
            )

        return StaticLineItem(
            price=price,
            amount=amount,
            currency=entry.currency,
            label=label,
            proration=entry.type == BillingEntryType.proration,
        )

    async def _get_metered_line_item(
        self,
        session: AsyncSession,
        price: MeteredPrice,
        subscription: Subscription,
        start_timestamp: datetime,
        end_timestamp: datetime,
    ) -> MeteredLineItem:
        """
        Compute a metered line item for a specific price.
        Used for summable aggregations (sum, count) where we can group by price.
        """
        return await self._compute_metered_line_item(
            session,
            price,
            subscription,
            start_timestamp,
            end_timestamp,
            across_all_prices=False,
        )

    async def _get_metered_line_item_by_meter(
        self,
        session: AsyncSession,
        price: MeteredPrice,
        subscription: Subscription,
        start_timestamp: datetime,
        end_timestamp: datetime,
    ) -> MeteredLineItem:
        """
        Compute a metered line item grouped by meter.
        Used for non-summable aggregations (max, min, avg, unique) where we must
        compute across ALL events for the meter, regardless of which price was active.
        Uses the provided price for billing (should be the most recent/current price).
        """
        return await self._compute_metered_line_item(
            session,
            price,
            subscription,
            start_timestamp,
            end_timestamp,
            across_all_prices=True,
        )

    async def _compute_metered_line_item(
        self,
        session: AsyncSession,
        price: MeteredPrice,
        subscription: Subscription,
        start_timestamp: datetime,
        end_timestamp: datetime,
        *,
        across_all_prices: bool,
    ) -> MeteredLineItem:
        """
        Shared implementation behind the two ``_get_metered_line_item*`` methods.

        ``across_all_prices`` selects which pending events are considered:
        every pending event of the price's meter when True, only the pending
        events recorded against ``price`` when False.
        """
        event_repository = EventRepository.from_session(session)
        meter = price.meter

        if across_all_prices:
            # Get events across ALL prices for this meter
            events_statement = (
                event_repository.get_by_pending_entries_for_meter_statement(
                    subscription.id, meter.id
                )
            )
        else:
            events_statement = event_repository.get_by_pending_entries_statement(
                subscription.id, price.id
            )

        # Consumption is measured on user-sourced events only.
        units = await meter_service.get_quantity(
            session,
            meter,
            events_statement.with_only_columns(Event.id).where(
                Event.source == EventSource.user
            ),
        )

        # Credits are summed separately from the events flagged as meter credits.
        credit_events_statement = events_statement.where(
            Event.is_meter_credit.is_(True)
        )
        credit_events = await event_repository.get_all(credit_events_statement)
        credited_units = non_negative_running_sum(
            event.user_metadata["units"] for event in credit_events
        )

        amount, amount_label = price.get_amount_and_label(units - credited_units)
        label = f"{meter.name}{amount_label}"

        return MeteredLineItem(
            price=price,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
            consumed_units=units,
            credited_units=credited_units,
            amount=amount,
            currency=price.price_currency,
            label=label,
        )

315 

316 

# Module-level BillingEntryService instance, created at import time.
billing_entry = BillingEntryService()