Coverage for polar/product/service.py: 13%
262 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import builtins 1a
2import uuid 1a
3from collections.abc import Sequence 1a
4from typing import Literal 1a
6import stripe 1a
7from sqlalchemy import select 1a
8from sqlalchemy.orm import contains_eager, selectinload 1a
10from polar.auth.models import AuthSubject, is_user 1a
11from polar.benefit.service import benefit as benefit_service 1a
12from polar.checkout_link.repository import CheckoutLinkRepository 1a
13from polar.custom_field.service import custom_field as custom_field_service 1a
14from polar.enums import SubscriptionRecurringInterval 1a
15from polar.exceptions import ( 1a
16 PolarError,
17 PolarRequestValidationError,
18 ValidationError,
19)
20from polar.file.service import file as file_service 1a
21from polar.integrations.loops.service import loops as loops_service 1a
22from polar.integrations.stripe.service import stripe as stripe_service 1a
23from polar.kit.db.postgres import AsyncReadSession, AsyncSession 1a
24from polar.kit.metadata import MetadataQuery, apply_metadata_clause 1a
25from polar.kit.pagination import PaginationParams 1a
26from polar.kit.sorting import Sorting 1a
27from polar.meter.repository import MeterRepository 1a
28from polar.models import ( 1a
29 Benefit,
30 Organization,
31 Product,
32 ProductBenefit,
33 ProductMedia,
34 ProductPrice,
35 User,
36)
37from polar.models.product_custom_field import ProductCustomField 1a
38from polar.models.product_price import HasStripePriceId, ProductPriceSource 1a
39from polar.models.webhook_endpoint import WebhookEventType 1a
40from polar.organization.repository import OrganizationRepository 1a
41from polar.organization.resolver import get_payload_organization 1a
42from polar.product.guard import is_legacy_price, is_metered_price, is_static_price 1a
43from polar.product.repository import ProductRepository 1a
44from polar.webhook.service import webhook as webhook_service 1a
45from polar.worker import enqueue_job 1a
47from .schemas import ( 1a
48 ExistingProductPrice,
49 ProductCreate,
50 ProductPriceCreate,
51 ProductPriceMeteredCreateBase,
52 ProductUpdate,
53)
54from .sorting import ProductSortProperty 1a
class ProductError(PolarError):
    """Base class for product-related errors."""
class ProductService:
    """Service layer for listing, creating, updating and archiving products.

    Products are persisted through ``ProductRepository`` and mirrored to Stripe:
    a Stripe product per product, and a Stripe price for every price that
    implements ``HasStripePriceId``.
    """

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        id: Sequence[uuid.UUID] | None = None,
        organization_id: Sequence[uuid.UUID] | None = None,
        query: str | None = None,
        is_archived: bool | None = None,
        is_recurring: bool | None = None,
        benefit_id: Sequence[uuid.UUID] | None = None,
        metadata: MetadataQuery | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[ProductSortProperty]] = [
            (ProductSortProperty.created_at, True)
        ],
    ) -> tuple[Sequence[Product], int]:
        """List the products readable by ``auth_subject``, filtered and paginated.

        Every filter is applied only when it is not ``None``.

        :param id: Restrict to these product ids.
        :param organization_id: Restrict to these organizations.
        :param query: Case-insensitive substring matched against the name.
        :param is_archived: Filter on the archived flag.
        :param is_recurring: Filter on ``Product.is_recurring``.
        :param benefit_id: Keep products having at least one of these benefits;
            the matching ``product_benefits`` are eagerly loaded.
        :param metadata: Metadata filter applied via ``apply_metadata_clause``.
        :param pagination: Page size and page number.
        :param sorting: Sort criteria passed to ``repository.apply_sorting``.
        :return: The page of products and an integer count as returned by
            ``repository.paginate`` (presumably the total match count).
        """
        repository = ProductRepository.from_session(session)
        # Outer-join each product to its earliest-created price that is neither
        # archived nor deleted. NOTE(review): presumably consumed by
        # price-based sorting in ``apply_sorting`` — confirm.
        statement = repository.get_readable_statement(auth_subject).join(
            ProductPrice,
            onclause=(
                ProductPrice.id
                == select(ProductPrice)
                .correlate(Product)
                .with_only_columns(ProductPrice.id)
                .where(
                    ProductPrice.product_id == Product.id,
                    ProductPrice.is_archived.is_(False),
                    ProductPrice.deleted_at.is_(None),
                )
                .order_by(ProductPrice.created_at.asc())
                .limit(1)
                .scalar_subquery()
            ),
            isouter=True,
        )

        if id is not None:
            statement = statement.where(Product.id.in_(id))

        if organization_id is not None:
            statement = statement.where(Product.organization_id.in_(organization_id))

        if query is not None:
            statement = statement.where(Product.name.ilike(f"%{query}%"))

        if is_archived is not None:
            statement = statement.where(Product.is_archived.is_(is_archived))

        if is_recurring is not None:
            statement = statement.where(Product.is_recurring.is_(is_recurring))

        if benefit_id is not None:
            statement = (
                statement.join(Product.product_benefits)
                .where(ProductBenefit.benefit_id.in_(benefit_id))
                .options(contains_eager(Product.product_benefits))
            )

        if metadata is not None:
            statement = apply_metadata_clause(Product, statement, metadata)

        statement = repository.apply_sorting(statement, sorting)

        statement = statement.options(
            selectinload(Product.product_medias),
            selectinload(Product.attached_custom_fields),
        )

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Product | None:
        """Return product ``id`` if readable by ``auth_subject``, else ``None``.

        The repository's default eager-loading options are applied.
        """
        repository = ProductRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(Product.id == id)
            .options(*repository.get_eager_options())
        )
        return await repository.get_one_or_none(statement)

    async def get_embed(
        self, session: AsyncReadSession, id: uuid.UUID
    ) -> Product | None:
        """Return the non-archived product ``id`` with its medias loaded.

        No auth subject is involved: the lookup uses the repository's base
        statement rather than a readable-by-subject statement.
        """
        repository = ProductRepository.from_session(session)
        statement = (
            repository.get_base_statement()
            .where(Product.id == id, Product.is_archived.is_(False))
            .options(selectinload(Product.product_medias))
        )
        return await repository.get_one_or_none(statement)

    async def create(
        self,
        session: AsyncSession,
        create_schema: ProductCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Product:
        """Create a product with its prices, medias and custom fields.

        The product row is created and flushed first so it has an id. Medias
        and attached custom fields are then resolved, and all validation
        errors (prices, medias, custom fields) are raised together before
        anything is sent to Stripe. On success, a Stripe product is created,
        plus a Stripe price for each price implementing ``HasStripePriceId``,
        and the post-creation hooks run (webhook, and Loops for user subjects).

        :raises PolarRequestValidationError: If any price, media or attached
            custom field is invalid.
        """
        repository = ProductRepository.from_session(session)
        organization = await get_payload_organization(
            session, auth_subject, create_schema
        )

        errors: list[ValidationError] = []
        prices, _, _, prices_errors = await self.get_validated_prices(
            session,
            create_schema.prices,
            create_schema.recurring_interval,
            None,
            auth_subject,
        )
        errors.extend(prices_errors)

        product = await repository.create(
            Product(
                organization=organization,
                prices=prices,
                all_prices=prices,
                product_benefits=[],
                product_medias=[],
                attached_custom_fields=[],
                **create_schema.model_dump(
                    exclude={
                        "organization_id",
                        "prices",
                        "medias",
                        "attached_custom_fields",
                    },
                    by_alias=True,
                ),
            ),
            flush=True,
        )
        assert product.id is not None

        if create_schema.medias is not None:
            for order, file_id in enumerate(create_schema.medias):
                file = await file_service.get_selectable_product_media_file(
                    session, file_id, organization_id=product.organization_id
                )
                if file is None:
                    errors.append(
                        {
                            "type": "value_error",
                            "loc": ("body", "medias", order),
                            "msg": "File does not exist or is not yet uploaded.",
                            "input": file_id,
                        }
                    )
                    # Don't attach a media without a file. The request fails
                    # below anyway, and a pending ProductMedia(file=None) could
                    # be autoflushed by the next lookup and fail with a DB
                    # error instead of this validation error.
                    continue
                product.product_medias.append(ProductMedia(file=file, order=order))

        for order, attached_custom_field in enumerate(
            create_schema.attached_custom_fields
        ):
            custom_field = await custom_field_service.get_by_organization_and_id(
                session,
                attached_custom_field.custom_field_id,
                organization.id,
            )
            if custom_field is None:
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "attached_custom_fields", order),
                        "msg": "Custom field does not exist.",
                        "input": attached_custom_field.custom_field_id,
                    }
                )
                # Same reasoning as for medias: never attach a null custom field.
                continue
            product.attached_custom_fields.append(
                ProductCustomField(
                    custom_field=custom_field,
                    order=order,
                    required=attached_custom_field.required,
                )
            )

        if errors:
            raise PolarRequestValidationError(errors)

        metadata: dict[str, str] = {"product_id": str(product.id)}
        metadata["organization_id"] = str(organization.id)
        metadata["organization_name"] = organization.slug

        stripe_product = await stripe_service.create_product(
            product.get_stripe_name(),
            description=product.description,
            metadata=metadata,
        )
        product.stripe_product_id = stripe_product.id

        # Only prices that carry a Stripe price id are mirrored to Stripe.
        for price in product.all_prices:
            if isinstance(price, HasStripePriceId):
                stripe_price = await stripe_service.create_price_for_product(
                    stripe_product.id,
                    price.get_stripe_price_params(product.recurring_interval),
                )
                price.stripe_price_id = stripe_price.id
                session.add(price)

        await session.flush()

        await self._after_product_created(session, auth_subject, product)

        return product

    async def update(
        self,
        session: AsyncSession,
        product: Product,
        update_schema: ProductUpdate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Product:
        """Apply ``update_schema`` to ``product`` and propagate changes to Stripe.

        All validation (prices, recurring interval, trial configuration,
        medias, attached custom fields) runs first, and the errors are raised
        together. Medias and custom fields are replaced inside a SAVEPOINT
        that is rolled back if any of their entries is invalid. Removed prices
        are archived (on Stripe and locally), and added prices get a Stripe
        price when they implement ``HasStripePriceId``.

        :raises PolarRequestValidationError: On any validation error.
        """
        errors: list[ValidationError] = []

        # Validate prices
        existing_prices = set(product.prices)
        added_prices: list[ProductPrice] = []
        if update_schema.prices is not None:
            (
                _,
                existing_prices,
                added_prices,
                prices_errors,
            ) = await self.get_validated_prices(
                session,
                update_schema.prices,
                product.recurring_interval,
                product,
                auth_subject,
            )
            errors.extend(prices_errors)

        # Prevent non-legacy products from changing their recurring interval
        if (
            update_schema.recurring_interval is not None
            and (
                update_schema.recurring_interval != product.recurring_interval
                or update_schema.recurring_interval_count
                != product.recurring_interval_count
            )
            and not all(is_legacy_price(price) for price in product.prices)
        ):
            errors.append(
                {
                    "type": "value_error",
                    "loc": ("body", "recurring_interval"),
                    "msg": "Recurring interval cannot be changed.",
                    "input": update_schema.recurring_interval,
                }
            )

        # Prevent trying to add trial configuration to non-recurring products
        if (
            update_schema.trial_interval is not None
            or update_schema.trial_interval_count is not None
        ) and product.recurring_interval is None:
            errors.extend(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "trial_interval"),
                        "msg": "Trial configuration is only supported on recurring products.",
                        "input": update_schema.trial_interval,
                    },
                    {
                        "type": "value_error",
                        "loc": ("body", "trial_interval_count"),
                        "msg": "Trial configuration is only supported on recurring products.",
                        "input": update_schema.trial_interval_count,
                    },
                ]
            )

        if update_schema.medias is not None:
            medias_errors: list[ValidationError] = []
            # Replace medias inside a savepoint so an invalid entry can undo
            # the removal of the previous ones.
            nested = await session.begin_nested()
            product.medias = []
            await session.flush()

            for order, file_id in enumerate(update_schema.medias):
                file = await file_service.get_selectable_product_media_file(
                    session, file_id, organization_id=product.organization_id
                )
                if file is None:
                    medias_errors.append(
                        {
                            "type": "value_error",
                            "loc": ("body", "medias", order),
                            "msg": "File does not exist or is not yet uploaded.",
                            "input": file_id,
                        }
                    )
                    continue
                product.product_medias.append(ProductMedia(file=file, order=order))

            if medias_errors:
                await nested.rollback()
                errors.extend(medias_errors)

        if update_schema.attached_custom_fields is not None:
            attached_custom_fields_errors: list[ValidationError] = []
            # Same savepoint strategy as for medias.
            nested = await session.begin_nested()
            product.attached_custom_fields = []
            await session.flush()

            for order, attached_custom_field in enumerate(
                update_schema.attached_custom_fields
            ):
                custom_field = await custom_field_service.get_by_organization_and_id(
                    session,
                    attached_custom_field.custom_field_id,
                    product.organization_id,
                )
                if custom_field is None:
                    attached_custom_fields_errors.append(
                        {
                            "type": "value_error",
                            "loc": ("body", "attached_custom_fields", order),
                            "msg": "Custom field does not exist.",
                            "input": attached_custom_field.custom_field_id,
                        }
                    )
                    continue
                product.attached_custom_fields.append(
                    ProductCustomField(
                        custom_field=custom_field,
                        order=order,
                        required=attached_custom_field.required,
                    )
                )

            if attached_custom_fields_errors:
                await nested.rollback()
                errors.extend(attached_custom_fields_errors)

        if errors:
            raise PolarRequestValidationError(errors)

        if product.is_archived and update_schema.is_archived is False:
            product = await self._unarchive(product)

        product_update: stripe.Product.ModifyParams = {}
        if update_schema.name is not None and update_schema.name != product.name:
            product.name = update_schema.name
            product_update["name"] = product.get_stripe_name()
        if (
            update_schema.description is not None
            and update_schema.description != product.description
        ):
            product.description = update_schema.description
            product_update["description"] = update_schema.description

        if product_update and product.stripe_product_id is not None:
            await stripe_service.update_product(
                product.stripe_product_id, **product_update
            )

        if update_schema.recurring_interval is not None:
            product.recurring_interval = update_schema.recurring_interval

        # Prices no longer referenced by the update are archived.
        deleted_prices = set(product.prices) - existing_prices
        for deleted_price in deleted_prices:
            assert product.stripe_product_id is not None
            if isinstance(deleted_price, HasStripePriceId):
                # Clear the Stripe default price first. NOTE(review): presumably
                # because Stripe rejects archiving a product's default price.
                await stripe_service.update_product(
                    product.stripe_product_id, default_price=""
                )
                await stripe_service.archive_price(deleted_price.stripe_price_id)
            deleted_price.is_archived = True

        for price in added_prices:
            if isinstance(price, HasStripePriceId):
                assert product.stripe_product_id is not None
                stripe_price = await stripe_service.create_price_for_product(
                    product.stripe_product_id,
                    price.get_stripe_price_params(product.recurring_interval),
                )
                price.stripe_price_id = stripe_price.id

        if update_schema.is_archived:
            product = await self._archive(session, product)

        for attr, value in update_schema.model_dump(
            exclude_unset=True,
            exclude={"prices", "medias", "attached_custom_fields"},
            by_alias=True,
        ).items():
            setattr(product, attr, value)

        session.add(product)
        await session.flush()

        # Reload price collections so archived/added prices are reflected.
        await session.refresh(product, {"prices", "all_prices"})

        await self._after_product_updated(session, product)

        return product

    async def update_benefits(
        self,
        session: AsyncSession,
        product: Product,
        benefits: Sequence[uuid.UUID],
        auth_subject: AuthSubject[User | Organization],
    ) -> tuple[Product, set[Benefit], set[Benefit]]:
        """Replace the benefits of ``product`` with ``benefits``, in that order.

        All checks run before the session is modified. When the set of
        benefits changes, background jobs are enqueued to update benefit
        grants for subscriptions and orders of this product.

        :return: The product, the benefits added and the benefits removed.
        :raises PolarRequestValidationError: If a benefit doesn't exist, a
            non-selectable benefit is newly added, or a non-selectable benefit
            would be removed.
        """
        previous_benefits = set(product.benefits)
        new_benefits: set[Benefit] = set()

        new_product_benefits: list[ProductBenefit] = []
        for order, benefit_id in enumerate(benefits):
            benefit = await benefit_service.get(session, auth_subject, benefit_id)
            if benefit is None:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "benefits", order),
                            "msg": "Benefit does not exist.",
                            "input": benefit_id,
                        }
                    ]
                )
            # Non-selectable benefits may be kept, but not newly attached.
            if not benefit.selectable and benefit not in previous_benefits:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "benefits", order),
                            "msg": "Benefit is not selectable.",
                            "input": benefit_id,
                        }
                    ]
                )
            new_benefits.add(benefit)
            new_product_benefits.append(ProductBenefit(benefit=benefit, order=order))

        added_benefits = new_benefits - previous_benefits
        deleted_benefits = previous_benefits - new_benefits

        # Non-selectable benefits may not be removed either. Check this before
        # touching the session so a rejected request doesn't leave the
        # product's benefit rows cleared and flushed.
        for deleted_benefit in deleted_benefits:
            if not deleted_benefit.selectable:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": (
                                "body",
                                "benefits",
                            ),
                            "msg": "Benefit is not selectable.",
                            "input": deleted_benefit.id,
                        }
                    ]
                )

        # Remove all previous benefits: flush to actually remove them
        product.product_benefits = []
        session.add(product)
        await session.flush()

        # Set the new benefits
        product.product_benefits = new_product_benefits

        if added_benefits or deleted_benefits:
            enqueue_job(
                "subscription.subscription.update_product_benefits_grants", product.id
            )
            enqueue_job("order.update_product_benefits_grants", product.id)

        await self._after_product_updated(session, product)

        return product, added_benefits, deleted_benefits

    async def get_validated_prices(
        self,
        session: AsyncSession,
        prices_schema: Sequence[ExistingProductPrice | ProductPriceCreate],
        recurring_interval: SubscriptionRecurringInterval | None,
        product: Product | None,
        auth_subject: AuthSubject[User | Organization],
        source: ProductPriceSource = ProductPriceSource.catalog,
        error_prefix: tuple[str, ...] = ("body", "prices"),
    ) -> tuple[
        builtins.list[ProductPrice],
        builtins.set[ProductPrice],
        builtins.list[ProductPrice],
        builtins.list[ValidationError],
    ]:
        """Resolve and validate a list of price references/creations.

        ``ExistingProductPrice`` entries are looked up on ``product``. Other
        entries are instantiated as new price models bound to ``product``. For
        metered prices, this also checks the product is recurring, that each
        meter is used only once, and that the meter is readable by
        ``auth_subject``. It also enforces at least one price, and at most one
        static price unless all static prices are legacy.

        ``builtins.list`` is used in the signature because the ``list`` method
        above shadows the builtin inside the class body.

        :param product: The product being updated, or ``None`` on creation (in
            which case existing-price references are reported as unknown).
        :param error_prefix: ``loc`` prefix used for every reported error.
        :return: ``(prices, existing_prices, added_prices, errors)``. Invalid
            entries are left out of the price collections.
        """
        meter_repository = MeterRepository.from_session(session)
        prices: list[ProductPrice] = []
        existing_prices: set[ProductPrice] = set()
        added_prices: list[ProductPrice] = []
        errors: list[ValidationError] = []
        meters: set[uuid.UUID] = set()
        for index, price_schema in enumerate(prices_schema):
            if isinstance(price_schema, ExistingProductPrice):
                # Without a product (creation) there are no prices to refer
                # to: report a validation error rather than failing an assert.
                price = (
                    product.get_price(price_schema.id) if product is not None else None
                )
                if price is None:
                    errors.append(
                        {
                            "type": "value_error",
                            "loc": (*error_prefix, index),
                            "msg": "Price does not exist.",
                            "input": price_schema.id,
                        }
                    )
                    continue
                existing_prices.add(price)
            else:
                model_class = price_schema.get_model_class()
                price = model_class(
                    product=product, source=source, **price_schema.model_dump()
                )
                if is_metered_price(price) and isinstance(
                    price_schema, ProductPriceMeteredCreateBase
                ):
                    if recurring_interval is None:
                        errors.append(
                            {
                                "type": "value_error",
                                "loc": (*error_prefix, index),
                                "msg": "Metered pricing is not supported on one-time products.",
                                "input": price_schema,
                            }
                        )
                        continue

                    if price_schema.meter_id in meters:
                        errors.append(
                            {
                                "type": "value_error",
                                "loc": (*error_prefix, index, "meter_id"),
                                "msg": "Meter is already used for another price.",
                                "input": price_schema.meter_id,
                            }
                        )
                        continue

                    price.meter = await meter_repository.get_readable_by_id(
                        price_schema.meter_id, auth_subject
                    )
                    if price.meter is None:
                        errors.append(
                            {
                                "type": "value_error",
                                "loc": (*error_prefix, index, "meter_id"),
                                "msg": "Meter does not exist.",
                                "input": price_schema.meter_id,
                            }
                        )
                        continue
                    meters.add(price_schema.meter_id)
                added_prices.append(price)
            prices.append(price)

        if len(prices) < 1:
            errors.append(
                {
                    "type": "too_short",
                    "loc": error_prefix,
                    "msg": "At least one price is required.",
                    "input": prices_schema,
                }
            )

        static_prices = [p for p in prices if is_static_price(p)]
        if len(static_prices) > 1:
            # Bypass that rule for legacy recurring products
            if not all(is_legacy_price(p) for p in static_prices):
                errors.append(
                    {
                        "type": "value_error",
                        "loc": error_prefix,
                        "msg": "Only one static price is allowed.",
                        "input": prices_schema,
                    }
                )

        return prices, existing_prices, added_prices, errors

    async def _archive(self, session: AsyncSession, product: Product) -> Product:
        """Archive ``product`` on Stripe (if synced) and locally.

        Also delegates to ``CheckoutLinkRepository.archive_product`` for this
        product's id.
        """
        if product.stripe_product_id is not None:
            await stripe_service.archive_product(product.stripe_product_id)

        product.is_archived = True

        checkout_link_repository = CheckoutLinkRepository.from_session(session)
        await checkout_link_repository.archive_product(product.id)

        return product

    async def _unarchive(self, product: Product) -> Product:
        """Unarchive ``product`` on Stripe (if synced) and locally."""
        if product.stripe_product_id is not None:
            await stripe_service.unarchive_product(product.stripe_product_id)

        product.is_archived = False

        return product

    async def _after_product_created(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        product: Product,
    ) -> None:
        """Send the ``product_created`` webhook; notify Loops for user subjects."""
        await self._send_webhook(session, product, WebhookEventType.product_created)
        if is_user(auth_subject):
            user = auth_subject.subject
            await loops_service.user_created_product(user)

    async def _after_product_updated(
        self, session: AsyncSession, product: Product
    ) -> None:
        """Send the ``product_updated`` webhook."""
        await self._send_webhook(session, product, WebhookEventType.product_updated)

    async def _send_webhook(
        self,
        session: AsyncSession,
        product: Product,
        event_type: Literal[
            WebhookEventType.product_created, WebhookEventType.product_updated
        ],
    ) -> None:
        """Send ``event_type`` for ``product`` to its organization's webhooks.

        Nothing is sent if the product's organization cannot be loaded.
        """
        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(product.organization_id)
        if organization is not None:
            await webhook_service.send(session, organization, event_type, product)
# Module-level service instance imported by the rest of the codebase.
product: ProductService = ProductService()