Coverage for polar/product/service.py: 13%

262 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import builtins 1a

2import uuid 1a

3from collections.abc import Sequence 1a

4from typing import Literal 1a

5 

6import stripe 1a

7from sqlalchemy import select 1a

8from sqlalchemy.orm import contains_eager, selectinload 1a

9 

10from polar.auth.models import AuthSubject, is_user 1a

11from polar.benefit.service import benefit as benefit_service 1a

12from polar.checkout_link.repository import CheckoutLinkRepository 1a

13from polar.custom_field.service import custom_field as custom_field_service 1a

14from polar.enums import SubscriptionRecurringInterval 1a

15from polar.exceptions import ( 1a

16 PolarError, 

17 PolarRequestValidationError, 

18 ValidationError, 

19) 

20from polar.file.service import file as file_service 1a

21from polar.integrations.loops.service import loops as loops_service 1a

22from polar.integrations.stripe.service import stripe as stripe_service 1a

23from polar.kit.db.postgres import AsyncReadSession, AsyncSession 1a

24from polar.kit.metadata import MetadataQuery, apply_metadata_clause 1a

25from polar.kit.pagination import PaginationParams 1a

26from polar.kit.sorting import Sorting 1a

27from polar.meter.repository import MeterRepository 1a

28from polar.models import ( 1a

29 Benefit, 

30 Organization, 

31 Product, 

32 ProductBenefit, 

33 ProductMedia, 

34 ProductPrice, 

35 User, 

36) 

37from polar.models.product_custom_field import ProductCustomField 1a

38from polar.models.product_price import HasStripePriceId, ProductPriceSource 1a

39from polar.models.webhook_endpoint import WebhookEventType 1a

40from polar.organization.repository import OrganizationRepository 1a

41from polar.organization.resolver import get_payload_organization 1a

42from polar.product.guard import is_legacy_price, is_metered_price, is_static_price 1a

43from polar.product.repository import ProductRepository 1a

44from polar.webhook.service import webhook as webhook_service 1a

45from polar.worker import enqueue_job 1a

46 

47from .schemas import ( 1a

48 ExistingProductPrice, 

49 ProductCreate, 

50 ProductPriceCreate, 

51 ProductPriceMeteredCreateBase, 

52 ProductUpdate, 

53) 

54from .sorting import ProductSortProperty 1a

55 

56 

class ProductError(PolarError):
    """Base exception for product-related errors, derived from ``PolarError``."""

58 

59 

class ProductService:
    """Product operations (list/get/create/update/archive), kept in sync with Stripe.

    Created and updated products are mirrored to Stripe through
    ``stripe_service`` and announced through webhooks
    (``WebhookEventType.product_created`` / ``WebhookEventType.product_updated``).
    """

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        id: Sequence[uuid.UUID] | None = None,
        organization_id: Sequence[uuid.UUID] | None = None,
        query: str | None = None,
        is_archived: bool | None = None,
        is_recurring: bool | None = None,
        benefit_id: Sequence[uuid.UUID] | None = None,
        metadata: MetadataQuery | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[ProductSortProperty]] | None = None,
    ) -> tuple[Sequence[Product], int]:
        """Return one page of the products readable by ``auth_subject``.

        Every filter left as ``None`` is not applied.

        Args:
            session: Read-only database session.
            auth_subject: Subject whose readable products are listed.
            id: Restrict to these product ids.
            organization_id: Restrict to products of these organizations.
            query: Case-insensitive substring matched against the product name.
            is_archived: Filter on ``Product.is_archived``.
            is_recurring: Filter on ``Product.is_recurring``.
            benefit_id: Restrict to products linked to at least one of these
                benefits. The joined ``product_benefits`` rows are loaded with
                ``contains_eager``, so only the matching rows are populated on
                the returned products.
            metadata: Metadata filter applied with ``apply_metadata_clause``.
            pagination: Page number and page size.
            sorting: Sort criteria; ``None`` means
                ``[(ProductSortProperty.created_at, True)]``.

        Returns:
            ``(products, count)`` as returned by ``repository.paginate``.
        """
        # A fresh list per call instead of a shared mutable default argument.
        if sorting is None:
            sorting = [(ProductSortProperty.created_at, True)]

        repository = ProductRepository.from_session(session)
        # Outer-join each product with its earliest-created active price
        # (non-archived, not deleted). The correlated scalar subquery yields at
        # most one price id per product, so this join never multiplies product
        # rows. NOTE(review): presumably used by price-based sort properties —
        # confirm in ProductRepository.apply_sorting.
        statement = repository.get_readable_statement(auth_subject).join(
            ProductPrice,
            onclause=(
                ProductPrice.id
                == select(ProductPrice)
                .correlate(Product)
                .with_only_columns(ProductPrice.id)
                .where(
                    ProductPrice.product_id == Product.id,
                    ProductPrice.is_archived.is_(False),
                    ProductPrice.deleted_at.is_(None),
                )
                .order_by(ProductPrice.created_at.asc())
                .limit(1)
                .scalar_subquery()
            ),
            isouter=True,
        )

        if id is not None:
            statement = statement.where(Product.id.in_(id))

        if organization_id is not None:
            statement = statement.where(Product.organization_id.in_(organization_id))

        if query is not None:
            statement = statement.where(Product.name.ilike(f"%{query}%"))

        if is_archived is not None:
            statement = statement.where(Product.is_archived.is_(is_archived))

        if is_recurring is not None:
            statement = statement.where(Product.is_recurring.is_(is_recurring))

        if benefit_id is not None:
            statement = (
                statement.join(Product.product_benefits)
                .where(ProductBenefit.benefit_id.in_(benefit_id))
                .options(contains_eager(Product.product_benefits))
            )

        if metadata is not None:
            statement = apply_metadata_clause(Product, statement, metadata)

        statement = repository.apply_sorting(statement, sorting)

        statement = statement.options(
            selectinload(Product.product_medias),
            selectinload(Product.attached_custom_fields),
        )

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Product | None:
        """Return the product ``id`` if ``auth_subject`` can read it, else ``None``.

        Relationships are loaded with ``repository.get_eager_options()``.
        """
        repository = ProductRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(Product.id == id)
            .options(*repository.get_eager_options())
        )
        return await repository.get_one_or_none(statement)

    async def get_embed(
        self, session: AsyncReadSession, id: uuid.UUID
    ) -> Product | None:
        """Return the non-archived product ``id`` without any auth filtering.

        Only ``product_medias`` is eagerly loaded. Returns ``None`` if the
        product does not exist or is archived.
        """
        repository = ProductRepository.from_session(session)
        statement = (
            repository.get_base_statement()
            .where(Product.id == id, Product.is_archived.is_(False))
            .options(selectinload(Product.product_medias))
        )
        return await repository.get_one_or_none(statement)

    async def create(
        self,
        session: AsyncSession,
        create_schema: ProductCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Product:
        """Create a product with its prices, medias and custom fields, then mirror it to Stripe.

        The product row is flushed before medias and custom fields are
        validated. Validation errors (prices, medias, custom fields) are
        collected and raised together before anything is sent to Stripe.

        Args:
            session: Writable database session.
            create_schema: Creation payload.
            auth_subject: Subject performing the creation. It is used to
                resolve the target organization and to look up readable meters
                for metered prices.

        Returns:
            The created product. ``stripe_product_id`` is set, as is the
            ``stripe_price_id`` of every Stripe-backed price.

        Raises:
            PolarRequestValidationError: If any price, media file or custom
                field is invalid. NOTE(review): the product has already been
                flushed at that point; this relies on the caller rolling back
                the transaction — confirm.
        """
        repository = ProductRepository.from_session(session)
        organization = await get_payload_organization(
            session, auth_subject, create_schema
        )

        errors: list[ValidationError] = []
        prices, _, _, prices_errors = await self.get_validated_prices(
            session,
            create_schema.prices,
            create_schema.recurring_interval,
            None,
            auth_subject,
        )
        errors.extend(prices_errors)

        product = await repository.create(
            Product(
                organization=organization,
                prices=prices,
                all_prices=prices,
                product_benefits=[],
                product_medias=[],
                attached_custom_fields=[],
                **create_schema.model_dump(
                    exclude={
                        "organization_id",
                        "prices",
                        "medias",
                        "attached_custom_fields",
                    },
                    by_alias=True,
                ),
            ),
            flush=True,
        )
        assert product.id is not None

        if create_schema.medias is not None:
            for order, file_id in enumerate(create_schema.medias):
                file = await file_service.get_selectable_product_media_file(
                    session, file_id, organization_id=product.organization_id
                )
                if file is None:
                    errors.append(
                        {
                            "type": "value_error",
                            "loc": ("body", "medias", order),
                            "msg": "File does not exist or is not yet uploaded.",
                            "input": file_id,
                        }
                    )
                    # Don't attach a media without a file. The next query in
                    # this loop could autoflush that incomplete row, surfacing
                    # a DB error instead of the validation error collected here.
                    continue
                product.product_medias.append(ProductMedia(file=file, order=order))

        for order, attached_custom_field in enumerate(
            create_schema.attached_custom_fields
        ):
            custom_field = await custom_field_service.get_by_organization_and_id(
                session,
                attached_custom_field.custom_field_id,
                organization.id,
            )
            if custom_field is None:
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "attached_custom_fields", order),
                        "msg": "Custom field does not exist.",
                        "input": attached_custom_field.custom_field_id,
                    }
                )
                # Same reasoning as for medias: never attach a missing field.
                continue
            product.attached_custom_fields.append(
                ProductCustomField(
                    custom_field=custom_field,
                    order=order,
                    required=attached_custom_field.required,
                )
            )

        if errors:
            raise PolarRequestValidationError(errors)

        metadata: dict[str, str] = {
            "product_id": str(product.id),
            "organization_id": str(organization.id),
            "organization_name": organization.slug,
        }

        stripe_product = await stripe_service.create_product(
            product.get_stripe_name(),
            description=product.description,
            metadata=metadata,
        )
        product.stripe_product_id = stripe_product.id

        # Only prices that carry a Stripe price id get a Stripe counterpart.
        for price in product.all_prices:
            if isinstance(price, HasStripePriceId):
                stripe_price = await stripe_service.create_price_for_product(
                    stripe_product.id,
                    price.get_stripe_price_params(product.recurring_interval),
                )
                price.stripe_price_id = stripe_price.id
                session.add(price)

        await session.flush()

        await self._after_product_created(session, auth_subject, product)

        return product

    async def update(
        self,
        session: AsyncSession,
        product: Product,
        update_schema: ProductUpdate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Product:
        """Apply ``update_schema`` to ``product`` and sync the changes to Stripe.

        All validation happens first; nothing is sent to Stripe unless it
        passes. Media and custom-field replacements run inside a savepoint
        that is rolled back if any of their entries is invalid.

        Args:
            session: Writable database session.
            product: Product to update.
            update_schema: Partial update; unset fields are left untouched.
            auth_subject: Subject performing the update (used for meter lookup
                when validating new prices).

        Returns:
            The updated product, with ``prices`` and ``all_prices`` refreshed.

        Raises:
            PolarRequestValidationError: On invalid prices, medias or custom
                fields, or if any of the following would happen:
                - the recurring interval (or interval count) changes on a
                  product that has a non-legacy price;
                - a trial configuration is set on a product without a
                  recurring interval.
        """
        errors: list[ValidationError] = []

        # Validate prices. When prices are not part of the update, every
        # current price counts as "existing", so none will be archived below.
        existing_prices = set(product.prices)
        added_prices: list[ProductPrice] = []
        if update_schema.prices is not None:
            (
                _,
                existing_prices,
                added_prices,
                prices_errors,
            ) = await self.get_validated_prices(
                session,
                update_schema.prices,
                product.recurring_interval,
                product,
                auth_subject,
            )
            errors.extend(prices_errors)

        # Prevent non-legacy products from changing their recurring interval
        if (
            update_schema.recurring_interval is not None
            and (
                update_schema.recurring_interval != product.recurring_interval
                or update_schema.recurring_interval_count
                != product.recurring_interval_count
            )
            and not all(is_legacy_price(price) for price in product.prices)
        ):
            errors.append(
                {
                    "type": "value_error",
                    "loc": ("body", "recurring_interval"),
                    "msg": "Recurring interval cannot be changed.",
                    "input": update_schema.recurring_interval,
                }
            )

        # Prevent trying to add trial configuration to non-recurring products
        if (
            update_schema.trial_interval is not None
            or update_schema.trial_interval_count is not None
        ) and product.recurring_interval is None:
            errors.extend(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "trial_interval"),
                        "msg": "Trial configuration is only supported on recurring products.",
                        "input": update_schema.trial_interval,
                    },
                    {
                        "type": "value_error",
                        "loc": ("body", "trial_interval_count"),
                        "msg": "Trial configuration is only supported on recurring products.",
                        "input": update_schema.trial_interval_count,
                    },
                ]
            )

        if update_schema.medias is not None:
            medias_errors: list[ValidationError] = []
            # Savepoint: clear and rebuild the medias, undo it all on error.
            nested = await session.begin_nested()
            product.medias = []
            await session.flush()

            for order, file_id in enumerate(update_schema.medias):
                file = await file_service.get_selectable_product_media_file(
                    session, file_id, organization_id=product.organization_id
                )
                if file is None:
                    medias_errors.append(
                        {
                            "type": "value_error",
                            "loc": ("body", "medias", order),
                            "msg": "File does not exist or is not yet uploaded.",
                            "input": file_id,
                        }
                    )
                    continue
                product.product_medias.append(ProductMedia(file=file, order=order))

            if medias_errors:
                await nested.rollback()
                errors.extend(medias_errors)

        if update_schema.attached_custom_fields is not None:
            attached_custom_fields_errors: list[ValidationError] = []
            # Savepoint: clear and rebuild the fields, undo it all on error.
            nested = await session.begin_nested()
            product.attached_custom_fields = []
            await session.flush()

            for order, attached_custom_field in enumerate(
                update_schema.attached_custom_fields
            ):
                custom_field = await custom_field_service.get_by_organization_and_id(
                    session,
                    attached_custom_field.custom_field_id,
                    product.organization_id,
                )
                if custom_field is None:
                    attached_custom_fields_errors.append(
                        {
                            "type": "value_error",
                            "loc": ("body", "attached_custom_fields", order),
                            "msg": "Custom field does not exist.",
                            "input": attached_custom_field.custom_field_id,
                        }
                    )
                    continue
                product.attached_custom_fields.append(
                    ProductCustomField(
                        custom_field=custom_field,
                        order=order,
                        required=attached_custom_field.required,
                    )
                )

            if attached_custom_fields_errors:
                await nested.rollback()
                errors.extend(attached_custom_fields_errors)

        if errors:
            raise PolarRequestValidationError(errors)

        # Unarchive only on an explicit `is_archived=False`.
        if product.is_archived and update_schema.is_archived is False:
            product = await self._unarchive(product)

        product_update: stripe.Product.ModifyParams = {}
        if update_schema.name is not None and update_schema.name != product.name:
            product.name = update_schema.name
            product_update["name"] = product.get_stripe_name()
        if (
            update_schema.description is not None
            and update_schema.description != product.description
        ):
            product.description = update_schema.description
            product_update["description"] = update_schema.description

        if product_update and product.stripe_product_id is not None:
            await stripe_service.update_product(
                product.stripe_product_id, **product_update
            )

        if update_schema.recurring_interval is not None:
            product.recurring_interval = update_schema.recurring_interval

        # Prices no longer referenced by the update are archived.
        deleted_prices = set(product.prices) - existing_prices
        for deleted_price in deleted_prices:
            if isinstance(deleted_price, HasStripePriceId):
                # Only Stripe-backed prices need the Stripe product id; the
                # assertion lives here, as it does in the added-prices loop.
                assert product.stripe_product_id is not None
                # Unset the product's default price first — NOTE(review):
                # presumably so Stripe accepts archiving a price that may be
                # the default one; confirm.
                await stripe_service.update_product(
                    product.stripe_product_id, default_price=""
                )
                await stripe_service.archive_price(deleted_price.stripe_price_id)
            deleted_price.is_archived = True

        for price in added_prices:
            if isinstance(price, HasStripePriceId):
                assert product.stripe_product_id is not None
                stripe_price = await stripe_service.create_price_for_product(
                    product.stripe_product_id,
                    price.get_stripe_price_params(product.recurring_interval),
                )
                price.stripe_price_id = stripe_price.id

        if update_schema.is_archived:
            product = await self._archive(session, product)

        # Copy every remaining explicitly-set scalar field onto the model.
        for attr, value in update_schema.model_dump(
            exclude_unset=True,
            exclude={"prices", "medias", "attached_custom_fields"},
            by_alias=True,
        ).items():
            setattr(product, attr, value)

        session.add(product)
        await session.flush()

        await session.refresh(product, {"prices", "all_prices"})

        await self._after_product_updated(session, product)

        return product

    async def update_benefits(
        self,
        session: AsyncSession,
        product: Product,
        benefits: Sequence[uuid.UUID],
        auth_subject: AuthSubject[User | Organization],
    ) -> tuple[Product, set[Benefit], set[Benefit]]:
        """Replace the product's benefits with ``benefits``, in the given order.

        A non-selectable benefit may be kept if the product already had it, but
        it can neither be newly attached nor removed. All checks run before the
        product's benefits are modified. When the set of benefits changes, jobs
        are enqueued to update subscription and order benefit grants.

        Args:
            session: Writable database session.
            product: Product whose benefits are replaced.
            benefits: Ids of the new benefits; their position becomes ``order``.
            auth_subject: Subject used to look up the benefits.

        Returns:
            ``(product, added_benefits, deleted_benefits)``.

        Raises:
            PolarRequestValidationError: If a benefit does not exist, is not
                selectable and not already attached, or is a non-selectable
                benefit that would be removed.
        """
        previous_benefits = set(product.benefits)
        new_benefits: set[Benefit] = set()

        new_product_benefits: list[ProductBenefit] = []
        for order, benefit_id in enumerate(benefits):
            benefit = await benefit_service.get(session, auth_subject, benefit_id)
            if benefit is None:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "benefits", order),
                            "msg": "Benefit does not exist.",
                            "input": benefit_id,
                        }
                    ]
                )
            if not benefit.selectable and benefit not in previous_benefits:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "benefits", order),
                            "msg": "Benefit is not selectable.",
                            "input": benefit_id,
                        }
                    ]
                )
            new_benefits.add(benefit)
            new_product_benefits.append(ProductBenefit(benefit=benefit, order=order))

        added_benefits = new_benefits - previous_benefits
        deleted_benefits = previous_benefits - new_benefits

        # Reject removing a non-selectable benefit *before* touching
        # `product_benefits`, so nothing is cleared or flushed on rejection.
        for deleted_benefit in deleted_benefits:
            if not deleted_benefit.selectable:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": (
                                "body",
                                "benefits",
                            ),
                            "msg": "Benefit is not selectable.",
                            "input": deleted_benefit.id,
                        }
                    ]
                )

        # Remove all previous benefits: flush to actually remove them
        product.product_benefits = []
        session.add(product)
        await session.flush()

        # Set the new benefits
        product.product_benefits = new_product_benefits
        session.add(product)

        if added_benefits or deleted_benefits:
            enqueue_job(
                "subscription.subscription.update_product_benefits_grants", product.id
            )
            enqueue_job("order.update_product_benefits_grants", product.id)

        await self._after_product_updated(session, product)

        return product, added_benefits, deleted_benefits

    async def get_validated_prices(
        self,
        session: AsyncSession,
        prices_schema: Sequence[ExistingProductPrice | ProductPriceCreate],
        recurring_interval: SubscriptionRecurringInterval | None,
        product: Product | None,
        auth_subject: AuthSubject[User | Organization],
        source: ProductPriceSource = ProductPriceSource.catalog,
        error_prefix: tuple[str, ...] = ("body", "prices"),
    ) -> tuple[
        builtins.list[ProductPrice],
        builtins.set[ProductPrice],
        builtins.list[ProductPrice],
        builtins.list[ValidationError],
    ]:
        """Resolve and validate a list of price payloads.

        The following rules are checked; violations are collected as
        errors, never raised:
        - an ``ExistingProductPrice`` must reference a price of ``product``;
        - metered prices need a ``recurring_interval``, a distinct meter per
          price, and a meter readable by ``auth_subject``;
        - at least one valid price is required;
        - at most one static price is allowed, unless all static prices are
          legacy prices.

        Args:
            session: Database session (used for meter lookups).
            prices_schema: Existing-price references and new-price payloads.
            recurring_interval: Interval of the product the prices belong to.
            product: Owning product; must not be ``None`` when
                ``prices_schema`` contains an ``ExistingProductPrice``.
            auth_subject: Subject used to check meter readability.
            source: ``source`` assigned to newly built prices.
            error_prefix: ``loc`` prefix for reported errors.

        Returns:
            ``(prices, existing_prices, added_prices, errors)``.
            - ``prices``: every valid price, in payload order;
            - ``existing_prices``: the referenced existing prices;
            - ``added_prices``: the newly built prices;
            - ``errors``: the collected validation errors.
        """
        meter_repository = MeterRepository.from_session(session)
        prices: list[ProductPrice] = []
        existing_prices: set[ProductPrice] = set()
        added_prices: list[ProductPrice] = []
        errors: list[ValidationError] = []
        meters: set[uuid.UUID] = set()
        for index, price_schema in enumerate(prices_schema):
            if isinstance(price_schema, ExistingProductPrice):
                assert product is not None
                price = product.get_price(price_schema.id)
                if price is None:
                    errors.append(
                        {
                            "type": "value_error",
                            "loc": (*error_prefix, index),
                            "msg": "Price does not exist.",
                            "input": price_schema.id,
                        }
                    )
                    continue
                existing_prices.add(price)
            else:
                model_class = price_schema.get_model_class()
                price = model_class(
                    product=product, source=source, **price_schema.model_dump()
                )
                if is_metered_price(price) and isinstance(
                    price_schema, ProductPriceMeteredCreateBase
                ):
                    if recurring_interval is None:
                        errors.append(
                            {
                                "type": "value_error",
                                "loc": (*error_prefix, index),
                                "msg": "Metered pricing is not supported on one-time products.",
                                "input": price_schema,
                            }
                        )
                        continue

                    if price_schema.meter_id in meters:
                        errors.append(
                            {
                                "type": "value_error",
                                "loc": (*error_prefix, index, "meter_id"),
                                "msg": "Meter is already used for another price.",
                                "input": price_schema.meter_id,
                            }
                        )
                        continue

                    price.meter = await meter_repository.get_readable_by_id(
                        price_schema.meter_id, auth_subject
                    )
                    if price.meter is None:
                        errors.append(
                            {
                                "type": "value_error",
                                "loc": (*error_prefix, index, "meter_id"),
                                "msg": "Meter does not exist.",
                                "input": price_schema.meter_id,
                            }
                        )
                        continue
                    meters.add(price_schema.meter_id)
                added_prices.append(price)
            prices.append(price)

        if len(prices) < 1:
            errors.append(
                {
                    "type": "too_short",
                    "loc": error_prefix,
                    "msg": "At least one price is required.",
                    "input": prices_schema,
                }
            )

        static_prices = [p for p in prices if is_static_price(p)]
        if len(static_prices) > 1:
            # Bypass that rule for legacy recurring products
            if not all(is_legacy_price(p) for p in static_prices):
                errors.append(
                    {
                        "type": "value_error",
                        "loc": error_prefix,
                        "msg": "Only one static price is allowed.",
                        "input": prices_schema,
                    }
                )

        return prices, existing_prices, added_prices, errors

    async def _archive(self, session: AsyncSession, product: Product) -> Product:
        """Archive ``product`` and, via ``CheckoutLinkRepository``, its checkout links.

        The Stripe product is archived too when ``stripe_product_id`` is set.
        """
        if product.stripe_product_id is not None:
            await stripe_service.archive_product(product.stripe_product_id)

        product.is_archived = True

        checkout_link_repository = CheckoutLinkRepository.from_session(session)
        await checkout_link_repository.archive_product(product.id)

        return product

    async def _unarchive(self, product: Product) -> Product:
        """Unarchive ``product`` (and its Stripe product when one is linked)."""
        if product.stripe_product_id is not None:
            await stripe_service.unarchive_product(product.stripe_product_id)

        product.is_archived = False

        return product

    async def _after_product_created(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        product: Product,
    ) -> None:
        """Send the ``product_created`` webhook; notify Loops when a user created it."""
        await self._send_webhook(session, product, WebhookEventType.product_created)
        if is_user(auth_subject):
            user = auth_subject.subject
            await loops_service.user_created_product(user)

    async def _after_product_updated(
        self, session: AsyncSession, product: Product
    ) -> None:
        """Send the ``product_updated`` webhook for ``product``."""
        await self._send_webhook(session, product, WebhookEventType.product_updated)

    async def _send_webhook(
        self,
        session: AsyncSession,
        product: Product,
        event_type: Literal[
            WebhookEventType.product_created, WebhookEventType.product_updated
        ],
    ) -> None:
        """Send ``event_type`` for ``product`` to its organization's webhooks.

        Nothing is sent if the product's organization cannot be loaded.
        """
        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(product.organization_id)
        if organization is not None:
            await webhook_service.send(session, organization, event_type, product)

697 

698 

# Module-level instance of the product service. NOTE(review): presumably
# imported elsewhere under an alias (e.g. `product as product_service`),
# mirroring how this file imports `benefit`, `file` and `stripe` services.
product: ProductService = ProductService()