Coverage for polar/meter/service.py: 18%

165 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import uuid 1a

2from collections.abc import Sequence 1a

3from datetime import UTC, datetime 1a

4from typing import Any 1a

5 

6from sqlalchemy import ( 1a

7 ColumnElement, 

8 ColumnExpressionArgument, 

9 Select, 

10 UnaryExpression, 

11 and_, 

12 asc, 

13 cte, 

14 desc, 

15 func, 

16 or_, 

17 select, 

18) 

19from sqlalchemy.orm import joinedload 1a

20 

21from polar.auth.models import AuthSubject, Organization, User 1a

22from polar.billing_entry.repository import BillingEntryRepository 1a

23from polar.config import settings 1a

24from polar.event.repository import EventRepository 1a

25from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError 1a

26from polar.kit.metadata import MetadataQuery, apply_metadata_clause, get_metadata_clause 1a

27from polar.kit.pagination import PaginationParams 1a

28from polar.kit.sorting import Sorting 1a

29from polar.kit.time_queries import TimeInterval, get_timestamp_series_cte 1a

30from polar.meter.aggregation import AggregationFunction 1a

31from polar.models import ( 1a

32 Benefit, 

33 BillingEntry, 

34 Customer, 

35 Event, 

36 Meter, 

37 Product, 

38 ProductPriceMeteredUnit, 

39 SubscriptionProductPrice, 

40) 

41from polar.organization.resolver import get_payload_organization 1a

42from polar.postgres import AsyncReadSession, AsyncSession 1a

43from polar.subscription.repository import ( 1a

44 CustomerSubscriptionProductPrice, 

45 SubscriptionProductPriceRepository, 

46) 

47from polar.worker import enqueue_job 1a

48 

49from .repository import MeterRepository 1a

50from .schemas import MeterCreate, MeterQuantities, MeterQuantity, MeterUpdate 1a

51from .sorting import MeterSortProperty 1a

52 

53 

class MeterError(PolarError):
    """Meter-specific subclass of ``PolarError``."""

55 

56 

class MeterService:
    """Service layer for meters: CRUD, archiving, event queries and billing."""

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        metadata: MetadataQuery | None = None,
        query: str | None = None,
        is_archived: bool | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[MeterSortProperty]] | None = None,
    ) -> tuple[Sequence[Meter], int]:
        """List the meters readable by ``auth_subject``, filtered and paginated.

        Args:
            session: Read session used for the query.
            auth_subject: Subject whose readable meters are listed.
            organization_id: Restrict to meters belonging to these organizations.
            metadata: Restrict to meters whose metadata matches this query.
            query: Case-insensitive substring matched against the meter name.
            is_archived: ``True`` for archived meters only, ``False`` for
                non-archived meters only, ``None`` for both.
            pagination: Page number and page size.
            sorting: Sort criteria as ``(property, is_descending)`` pairs.
                ``None`` (the default) sorts by meter name, ascending.

        Returns:
            The ``(meters, count)`` tuple produced by the repository's
            ``paginate`` (count is presumably the total number of matches).
        """
        # Resolved here instead of in the signature so no list object is
        # shared between calls as a mutable default.
        if sorting is None:
            sorting = [(MeterSortProperty.meter_name, False)]

        repository = MeterRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject)

        if organization_id is not None:
            statement = statement.where(Meter.organization_id.in_(organization_id))

        if query is not None:
            statement = statement.where(Meter.name.ilike(f"%{query}%"))

        if is_archived is not None:
            if is_archived:
                statement = statement.where(Meter.archived_at.is_not(None))
            else:
                statement = statement.where(Meter.archived_at.is_(None))

        if metadata is not None:
            statement = apply_metadata_clause(Meter, statement, metadata)

        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            clause_function = desc if is_desc else asc
            if criterion == MeterSortProperty.created_at:
                order_by_clauses.append(clause_function(Meter.created_at))
            elif criterion == MeterSortProperty.meter_name:
                order_by_clauses.append(clause_function(Meter.name))
        statement = statement.order_by(*order_by_clauses)

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Meter | None:
        """Return the meter ``id`` if readable by ``auth_subject``, else ``None``.

        The ``last_billed_event`` relationship is eagerly loaded with a join.
        """
        repository = MeterRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(Meter.id == id)
            .options(joinedload(Meter.last_billed_event))
        )
        return await repository.get_one_or_none(statement)

    async def create(
        self,
        session: AsyncSession,
        meter_create: MeterCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Meter:
        """Create a meter and seed its billing cursor.

        The organization is resolved from the payload and ``auth_subject``.
        After the meter is flushed, ``last_billed_event`` is set to the most
        recently ingested event matching the meter (or ``None`` if none), so
        events that already exist are not billed later.

        Returns:
            The created meter.
        """
        repository = MeterRepository.from_session(session)
        organization = await get_payload_organization(
            session, auth_subject, meter_create
        )

        meter = await repository.create(
            Meter(
                **meter_create.model_dump(
                    by_alias=True, exclude={"filter", "aggregation"}
                ),
                filter=meter_create.filter,
                aggregation=meter_create.aggregation,
                organization=organization,
            ),
            flush=True,
        )

        # Seed the billing cursor with the latest matching event so the meter
        # is only billed for events arriving after its creation.
        # `create_billing_entries` picks up events whose `ingested_at` is
        # strictly greater than `last_billed_event.ingested_at`, so the cursor
        # must be chosen by that same column. Ordering by `timestamp` instead
        # would let pre-existing events with an earlier timestamp, but a later
        # ingestion time, be billed on the first run.
        event_repository = EventRepository.from_session(session)
        statement = (
            event_repository.get_meter_statement(meter)
            .order_by(Event.ingested_at.desc())
            .limit(1)
        )
        last_billed_event = await event_repository.get_one_or_none(statement)
        await repository.update(
            meter, update_dict={"last_billed_event": last_billed_event}
        )

        return meter

    async def update(
        self, session: AsyncSession, meter: Meter, meter_update: MeterUpdate
    ) -> Meter:
        """Apply ``meter_update`` to ``meter``.

        ``filter`` and ``aggregation`` may not be changed once the meter has a
        ``last_billed_event``. ``is_archived`` is delegated to
        :meth:`archive` / :meth:`unarchive`. Other fields explicitly set on
        the payload are written as-is.

        Raises:
            PolarRequestValidationError: If a locked field is set on a meter
                that already has a billed event, or if archiving is rejected
                by :meth:`archive`.
        """
        repository = MeterRepository.from_session(session)

        errors: list[ValidationError] = []
        if meter.last_billed_event is not None:
            # A tuple (not a set) keeps the reported error order deterministic.
            sensitive_fields = ("filter", "aggregation")
            for sensitive_field in sensitive_fields:
                if sensitive_field in meter_update.model_fields_set:
                    errors.append(
                        {
                            "type": "forbidden",
                            "loc": ("body", sensitive_field),
                            "msg": (
                                "This field can't be updated because the meter "
                                "is already aggregating events."
                            ),
                            "input": getattr(meter_update, sensitive_field),
                        }
                    )

        if errors:
            raise PolarRequestValidationError(errors)

        # `filter` and `aggregation` are assigned as model objects rather than
        # their dumped dict form; `is_archived` is handled separately below.
        update_dict = meter_update.model_dump(
            by_alias=True,
            exclude_unset=True,
            exclude={"filter", "aggregation", "is_archived"},
        )
        if meter_update.filter is not None:
            update_dict["filter"] = meter_update.filter
        if meter_update.aggregation is not None:
            update_dict["aggregation"] = meter_update.aggregation

        # Handle archiving/unarchiving
        if meter_update.is_archived is not None:
            if meter_update.is_archived:
                meter = await self.archive(session, meter)
            else:
                meter = await self.unarchive(session, meter)

        return await repository.update(meter, update_dict=update_dict)

    async def archive(self, session: AsyncSession, meter: Meter) -> Meter:
        """Archive ``meter`` by setting ``archived_at`` to the current UTC time.

        Raises:
            PolarRequestValidationError: If the meter is still used by a
                non-archived, non-deleted metered unit price on a
                non-archived product, or by a non-deleted ``meter_credit``
                benefit.
        """
        # Check if meter is attached to any active ProductPriceMeteredUnit
        active_prices = await session.scalar(
            select(func.count(ProductPriceMeteredUnit.id))
            .join(Product)
            .where(
                Product.is_archived.is_(False),
                ProductPriceMeteredUnit.meter_id == meter.id,
                ProductPriceMeteredUnit.is_archived.is_(False),
                ProductPriceMeteredUnit.deleted_at.is_(None),
            )
        )

        if active_prices and active_prices > 0:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "is_archived"),
                        "msg": "Cannot archive meter that is still attached to active products",
                        "input": True,
                    }
                ]
            )

        # Check if meter is referenced by any active Benefits with meter_credit type
        active_benefits = await session.scalar(
            select(func.count(Benefit.id)).where(
                Benefit.type == "meter_credit",
                Benefit.properties["meter_id"].as_string() == str(meter.id),
                Benefit.deleted_at.is_(None),
            )
        )

        if active_benefits and active_benefits > 0:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "is_archived"),
                        "msg": "Cannot archive meter that is still referenced by active benefits",
                        "input": True,
                    }
                ]
            )

        repository = MeterRepository.from_session(session)
        return await repository.update(
            meter, update_dict={"archived_at": datetime.now(UTC)}
        )

    async def unarchive(self, session: AsyncSession, meter: Meter) -> Meter:
        """Clear ``archived_at`` on ``meter``; no usage checks are performed."""
        repository = MeterRepository.from_session(session)
        return await repository.update(meter, update_dict={"archived_at": None})

    async def events(
        self,
        session: AsyncSession,
        meter: Meter,
        *,
        pagination: PaginationParams,
    ) -> tuple[Sequence[Event], int]:
        """Paginate the events matching ``meter``, newest ``timestamp`` first."""
        repository = EventRepository.from_session(session)
        statement = repository.get_meter_statement(meter).order_by(
            Event.timestamp.desc()
        )
        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get_quantities(
        self,
        session: AsyncReadSession,
        meter: Meter,
        *,
        start_timestamp: datetime,
        end_timestamp: datetime,
        interval: TimeInterval,
        customer_id: Sequence[uuid.UUID] | None = None,
        external_customer_id: Sequence[str] | None = None,
        metadata: MetadataQuery | None = None,
        customer_aggregation_function: AggregationFunction | None = None,
    ) -> MeterQuantities:
        """Compute the meter's aggregated quantity for each time bucket.

        A timestamp series from ``start_timestamp`` to ``end_timestamp`` at
        ``interval`` is outer-joined to the meter's events, so empty buckets
        report ``0``.

        Args:
            session: Read session used for the query.
            meter: Meter whose filter and aggregation are applied.
            start_timestamp: Start of the series.
            end_timestamp: End of the series.
            interval: Bucket size; events are matched to buckets by truncating
                both timestamps with ``interval.sql_date_trunc``.
            customer_id: Restrict to events of these customers.
            external_customer_id: Restrict to events of these external customer IDs.
            metadata: Restrict to events whose metadata matches this query.
            customer_aggregation_function: If set, quantities are first
                computed per ``(bucket, resolved customer)`` and then combined
                across customers with this function.

        Returns:
            The per-bucket quantities, ordered by timestamp, and the total
            over the whole range.
        """
        timestamp_series = get_timestamp_series_cte(
            start_timestamp, end_timestamp, interval
        )
        timestamp_column: ColumnElement[datetime] = timestamp_series.c.timestamp

        event_clauses: list[ColumnExpressionArgument[bool]] = [
            Event.organization_id == meter.organization_id,
        ]
        event_repository = EventRepository.from_session(session)
        if customer_id is not None:
            event_clauses.append(
                event_repository.get_customer_id_filter_clause(customer_id)
            )
        if external_customer_id is not None:
            event_clauses.append(
                event_repository.get_external_customer_id_filter_clause(
                    external_customer_id
                )
            )
        if metadata is not None:
            event_clauses.append(get_metadata_clause(Event, metadata))
        event_clauses.append(event_repository.get_meter_clause(meter))

        # The join clause does not restrict events to a bucket. The per-bucket
        # `quantity` and the range-wide `total` are separated by the FILTER
        # clauses on the aggregate instead.
        statement = (
            select(
                timestamp_column.label("timestamp"),
                func.coalesce(
                    meter.aggregation.get_sql_column(Event).filter(
                        interval.sql_date_trunc(Event.timestamp)
                        == interval.sql_date_trunc(timestamp_column),
                    ),
                    0,
                ).label("quantity"),
                func.coalesce(
                    meter.aggregation.get_sql_column(Event).filter(
                        interval.sql_date_trunc(Event.timestamp)
                        >= interval.sql_date_trunc(start_timestamp),
                        interval.sql_date_trunc(Event.timestamp)
                        <= interval.sql_date_trunc(end_timestamp),
                    ),
                    0,
                ).label("total"),
            )
            .join(Event, onclause=and_(*event_clauses), isouter=True)
            .group_by(timestamp_column)
            .order_by(timestamp_column.asc())
        )

        if customer_aggregation_function is not None:
            inner_statement = cte(
                statement.add_columns(Event.resolved_customer_id).group_by(
                    timestamp_column, Event.resolved_customer_id
                )
            )
            statement = (
                select(
                    timestamp_column.label("timestamp"),
                    customer_aggregation_function.get_sql_function(
                        inner_statement.c.quantity
                    ).label("quantity"),
                    customer_aggregation_function.get_sql_function(
                        inner_statement.c.total
                    ).label("total"),
                )
                .join(
                    inner_statement,
                    onclause=inner_statement.c.timestamp == timestamp_column,
                )
                .group_by(timestamp_column)
                .order_by(timestamp_column.asc())
            )

        total = 0.0
        quantities: list[MeterQuantity] = []
        result = await session.stream(
            statement,
            execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
        )
        async for row in result:
            quantities.append(MeterQuantity.model_validate(row))
            # `total` is computed over the whole range rather than per bucket,
            # so the value of the last streamed row is kept.
            total = row.total

        return MeterQuantities(quantities=quantities, total=total)

    async def enqueue_billing(self, session: AsyncSession) -> None:
        """Enqueue a ``meter.billing_entries`` job per meter, oldest meter first."""
        repository = MeterRepository.from_session(session)
        statement = repository.get_base_statement().order_by(Meter.created_at.asc())
        async for meter in repository.stream(statement):
            enqueue_job("meter.billing_entries", meter.id)

    async def _create_subscription_holder_billing_entry(
        self,
        session: AsyncSession,
        event: Event,
        customer: Customer,
        subscription_product_price: SubscriptionProductPrice,
    ) -> BillingEntry:
        """Persist a billing entry for ``event`` charged to ``customer``."""
        billing_entry_repository = BillingEntryRepository.from_session(session)
        return await billing_entry_repository.create(
            BillingEntry.from_metered_event(customer, subscription_product_price, event)
        )

    async def create_billing_entries(
        self, session: AsyncSession, meter: Meter
    ) -> Sequence[BillingEntry]:
        """Create billing entries for events ingested since the meter's cursor.

        Events of the meter's organization that have a customer and either
        match the meter or are system events affecting it are processed in
        ``ingested_at`` order. Only events ingested strictly after
        ``meter.last_billed_event`` are considered. Each event whose customer
        has a subscription price for this meter produces a billing entry
        charged to the subscription's customer. The cursor is then advanced
        to the last processed event. A ``subscription.update_meters`` job is
        enqueued for every subscription that received an entry.

        Returns:
            The created billing entries.
        """
        event_repository = EventRepository.from_session(session)
        statement = (
            event_repository.get_base_statement()
            .where(
                Event.organization_id == meter.organization_id,
                Event.customer.is_not(None),
                or_(
                    # Events matching meter definitions
                    event_repository.get_meter_clause(meter),
                    # System events impacting the meter balance
                    event_repository.get_meter_system_clause(meter),
                ),
            )
            .order_by(Event.ingested_at.asc())
            .options(*event_repository.get_eager_options())
        )
        last_billed_event = meter.last_billed_event
        if last_billed_event is not None:
            # `ingested_at` is the billing cursor; `create` seeds it the same way.
            statement = statement.where(
                Event.ingested_at > last_billed_event.ingested_at
            )

        subscription_product_price_repository = (
            SubscriptionProductPriceRepository.from_session(session)
        )
        # Per-customer cache of the price lookup; `None` is cached too, so
        # customers without a matching price are only queried once.
        customer_price_map: dict[
            uuid.UUID, CustomerSubscriptionProductPrice | None
        ] = {}

        entries: list[BillingEntry] = []
        updated_subscriptions: set[uuid.UUID] = set()
        last_event: Event | None = None
        async for event in event_repository.stream(statement):
            # The cursor advances past every event, including unbilled ones.
            last_event = event
            customer = event.customer
            assert customer is not None

            # Retrieve the paying customer and subscription product price
            try:
                customer_price = customer_price_map[customer.id]
            except KeyError:
                customer_price = await subscription_product_price_repository.get_by_customer_and_meter(
                    customer.id, meter.id
                )
                customer_price_map[customer.id] = customer_price

            if customer_price is None:
                continue

            # Get the paying customer (billing manager) from the subscription
            paying_customer = (
                customer_price.subscription_product_price.subscription.customer
            )

            entry = await self._create_subscription_holder_billing_entry(
                session,
                event,
                paying_customer,
                customer_price.subscription_product_price,
            )
            entries.append(entry)
            if entry.subscription is not None:
                updated_subscriptions.add(entry.subscription.id)

        meter.last_billed_event = (
            last_event if last_event is not None else last_billed_event
        )
        session.add(meter)

        for subscription_id in updated_subscriptions:
            enqueue_job("subscription.update_meters", subscription_id)

        return entries

    async def get_quantity(
        self,
        session: AsyncSession,
        meter: Meter,
        events_statement: Select[tuple[uuid.UUID]],
    ) -> float:
        """Apply the meter's aggregation to the events whose IDs are selected.

        Args:
            session: Session used to run the query.
            meter: Meter whose aggregation is applied.
            events_statement: Statement selecting the event IDs to aggregate.

        Returns:
            The aggregated value, or ``0.0`` when it is null or zero.
        """
        statement = select(
            func.coalesce(meter.aggregation.get_sql_column(Event), 0)
        ).where(Event.id.in_(events_statement))
        result = await session.scalar(statement)
        return result or 0.0

472 

473 

# Module-level instance of the service, shared by importers of this module.
meter: MeterService = MeterService()