Coverage for polar/checkout/service.py: 11%
829 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import contextlib 1a
2import typing 1a
3import uuid 1a
4from collections.abc import AsyncGenerator, AsyncIterator, Sequence 1a
6import stripe as stripe_lib 1a
7import structlog 1a
8from pydantic import UUID4 1a
9from pydantic import ValidationError as PydanticValidationError 1a
10from sqlalchemy import func, select 1a
11from sqlalchemy.orm import contains_eager, joinedload, selectinload 1a
13from polar.auth.models import Anonymous, AuthSubject 1a
14from polar.checkout.guard import has_product_checkout 1a
15from polar.checkout.schemas import ( 1a
16 CheckoutConfirm,
17 CheckoutCreate,
18 CheckoutCreatePublic,
19 CheckoutPriceCreate,
20 CheckoutProductCreate,
21 CheckoutUpdate,
22 CheckoutUpdatePublic,
23)
24from polar.config import settings 1a
25from polar.custom_field.data import validate_custom_field_data 1a
26from polar.customer.repository import CustomerRepository 1a
27from polar.customer_session.service import customer_session as customer_session_service 1a
28from polar.discount.service import DiscountNotRedeemableError 1a
29from polar.discount.service import discount as discount_service 1a
30from polar.enums import PaymentProcessor, SubscriptionRecurringInterval 1a
31from polar.exceptions import ( 1a
32 BadRequest,
33 NotPermitted,
34 PaymentNotReady,
35 PolarError,
36 PolarRequestValidationError,
37 ResourceNotFound,
38 ValidationError,
39)
40from polar.integrations.stripe.schemas import ProductType 1a
41from polar.integrations.stripe.service import stripe as stripe_service 1a
42from polar.integrations.stripe.utils import get_fingerprint 1a
43from polar.kit.address import AddressInput 1a
44from polar.kit.crypto import generate_token 1a
45from polar.kit.operator import attrgetter 1a
46from polar.kit.pagination import PaginationParams 1a
47from polar.kit.sorting import Sorting 1a
48from polar.kit.tax import ( 1a
49 InvalidTaxID,
50 TaxCalculationError,
51 TaxCode,
52 TaxID,
53 calculate_tax,
54 to_stripe_tax_id,
55 validate_tax_id,
56)
57from polar.kit.utils import utc_now 1a
58from polar.locker import Locker 1a
59from polar.logging import Logger 1a
60from polar.member import member_service 1a
61from polar.models import ( 1a
62 Account,
63 Checkout,
64 CheckoutLink,
65 Customer,
66 Discount,
67 LegacyRecurringProductPriceCustom,
68 LegacyRecurringProductPriceFixed,
69 Organization,
70 Payment,
71 PaymentMethod,
72 Product,
73 ProductPrice,
74 Subscription,
75 User,
76)
77from polar.models.checkout import CheckoutStatus 1a
78from polar.models.checkout_product import CheckoutProduct 1a
79from polar.models.discount import DiscountDuration 1a
80from polar.models.order import OrderBillingReasonInternal 1a
81from polar.models.product_price import ProductPriceAmountType, ProductPriceSource 1a
82from polar.models.webhook_endpoint import WebhookEventType 1a
83from polar.order.service import order as order_service 1a
84from polar.organization.service import organization as organization_service 1a
85from polar.postgres import AsyncReadSession, AsyncSession 1a
86from polar.product.guard import ( 1a
87 is_currency_price,
88 is_custom_price,
89 is_discount_applicable,
90 is_fixed_price,
91 is_seat_price,
92)
93from polar.product.repository import ProductPriceRepository, ProductRepository 1a
94from polar.product.schemas import ProductPriceCreateList 1a
95from polar.product.service import product as product_service 1a
96from polar.subscription.repository import SubscriptionRepository 1a
97from polar.subscription.service import subscription as subscription_service 1a
98from polar.trial_redemption.service import trial_redemption as trial_redemption_service 1a
99from polar.webhook.service import webhook as webhook_service 1a
100from polar.worker import enqueue_job 1a
102from . import ip_geolocation 1a
103from .eventstream import CheckoutEvent, publish_checkout_event 1a
104from .price import get_default_price 1a
105from .repository import CheckoutRepository 1a
106from .sorting import CheckoutSortProperty 1a
# Module-level structlog logger.
log: Logger = structlog.get_logger()
class CheckoutError(PolarError):
    """Common base class for the checkout-specific errors defined in this module."""
class ExpiredCheckoutError(CheckoutError):
    """Error reporting that the checkout session has expired.

    Forwards ``410`` as the second positional argument of the base
    initializer (presumably the HTTP status code -- confirm against
    ``PolarError``).
    """

    def __init__(self) -> None:
        super().__init__("This checkout session has expired.", 410)
class AlreadyActiveSubscriptionError(CheckoutError):
    """Error reporting that the customer already has an active subscription.

    Forwards ``403`` as the second positional argument of the base
    initializer (presumably the HTTP status code -- confirm against
    ``PolarError``).
    """

    def __init__(self) -> None:
        super().__init__("You already have an active subscription.", 403)
class PaymentError(CheckoutError):
    """Error reporting a failed payment attempt for a checkout.

    The checkout, the error type and the error text are kept as attributes.
    The message embeds the error text when one is given and always ends with
    a hint to retry with another payment method. ``400`` is forwarded as the
    second positional argument of the base initializer (presumably the HTTP
    status code -- confirm against ``PolarError``).
    """

    def __init__(
        self, checkout: Checkout, error_type: str | None, error: str | None
    ) -> None:
        self.checkout = checkout
        self.error_type = error_type
        self.error = error

        # A falsy error (None or empty string) collapses to a plain period.
        detail = f": {error}" if error else "."
        super().__init__(
            f"The payment failed{detail} "
            "Please try again with a different payment method.",
            400,
        )
class CheckoutDoesNotExist(CheckoutError):
    """Error reporting that no checkout exists for the given ID.

    The looked-up ID is kept on ``checkout_id``.
    """

    def __init__(self, checkout_id: uuid.UUID) -> None:
        self.checkout_id = checkout_id
        super().__init__(f"Checkout {checkout_id} does not exist.")
class NotOpenCheckout(CheckoutError):
    """Error reporting that a checkout is not in the open status.

    Keeps the checkout and its status at raise time as attributes.
    ``403`` is forwarded as the second positional argument of the base
    initializer (presumably the HTTP status code -- confirm against
    ``PolarError``).
    """

    def __init__(self, checkout: Checkout) -> None:
        self.checkout = checkout
        self.status = checkout.status
        super().__init__(f"Checkout {checkout.id} is not open: {checkout.status}", 403)
class NotConfirmedCheckout(CheckoutError):
    """Error reporting that a checkout is not in the confirmed status.

    Keeps the checkout and its status at raise time as attributes.
    """

    def __init__(self, checkout: Checkout) -> None:
        self.checkout = checkout
        self.status = checkout.status
        super().__init__(f"Checkout {checkout.id} is not confirmed: {checkout.status}")
class PaymentDoesNotExist(CheckoutError):
    """Error reporting that no payment exists for the given ID.

    The looked-up ID is kept on ``payment_id``.
    """

    def __init__(self, payment_id: uuid.UUID) -> None:
        self.payment_id = payment_id
        super().__init__(f"Payment {payment_id} does not exist.")
class ArchivedPriceCheckout(CheckoutError):
    """Error reporting that a checkout points at an archived price.

    Keeps the checkout and its ``product_price`` as attributes; the message
    names the checkout ID and the price ID.
    """

    def __init__(self, checkout: Checkout) -> None:
        self.checkout = checkout
        self.price = checkout.product_price
        super().__init__(
            f"Checkout {checkout.id} has an archived price: {checkout.product_price_id}"
        )
class IntentNotSucceeded(CheckoutError):
    """Error reporting that an intent tied to a checkout is not successful.

    Keeps the checkout and the intent ID as attributes.
    """

    def __init__(self, checkout: Checkout, intent_id: str) -> None:
        self.checkout = checkout
        self.intent_id = intent_id
        super().__init__(f"Intent {intent_id} for {checkout.id} is not successful.")
class NoPaymentMethodOnIntent(CheckoutError):
    """Error reporting that an intent tied to a checkout has no payment method.

    Keeps the checkout and the intent ID as attributes.
    """

    def __init__(self, checkout: Checkout, intent_id: str) -> None:
        self.checkout = checkout
        self.intent_id = intent_id
        super().__init__(
            f"Intent {intent_id} for {checkout.id} has no payment method associated."
        )
class PaymentRequired(CheckoutError):
    """Error reporting that a checkout requires a payment.

    Keeps the checkout as an attribute.
    """

    def __init__(self, checkout: Checkout) -> None:
        self.checkout = checkout
        super().__init__(f"{checkout.id} requires a payment.")
class TrialAlreadyRedeemed(CheckoutError):
    """Error reporting that the customer already used a trial for the product.

    Keeps the checkout as an attribute. ``403`` is forwarded as the second
    positional argument of the base initializer (presumably the HTTP status
    code -- confirm against ``PolarError``).
    """

    def __init__(self, checkout: Checkout) -> None:
        self.checkout = checkout
        super().__init__(
            "You have already used a trial for this product. "
            "Trials can only be used once per customer.",
            403,
        )
# Prefix handed to generate_token() when minting a checkout's client_secret.
CHECKOUT_CLIENT_SECRET_PREFIX = "polar_c_"
218class CheckoutService: 1a
async def list(
    self,
    session: AsyncReadSession,
    auth_subject: AuthSubject[User | Organization],
    *,
    organization_id: Sequence[uuid.UUID] | None = None,
    product_id: Sequence[uuid.UUID] | None = None,
    customer_id: Sequence[uuid.UUID] | None = None,
    status: Sequence[CheckoutStatus] | None = None,
    query: str | None = None,
    pagination: PaginationParams,
    sorting: list[Sorting[CheckoutSortProperty]] | None = None,
) -> tuple[Sequence[Checkout], int]:
    """Return one page of checkouts readable by ``auth_subject``.

    Each filter is applied only when its argument is not ``None``:

    * ``organization_id``, ``product_id``, ``customer_id`` and ``status``
      restrict the matching column to the given values (``IN``).
    * ``query`` is matched case-insensitively as a substring of the
      customer email.

    ``sorting`` defaults to ``[(CheckoutSortProperty.created_at, True)]``
    when omitted. The default is built per call rather than stored as a
    mutable default argument, so it can never be shared between calls.

    Returns whatever ``repository.paginate`` yields, annotated as the page
    of checkouts together with an integer count.
    """
    if sorting is None:
        sorting = [(CheckoutSortProperty.created_at, True)]

    repository = CheckoutRepository.from_session(session)
    statement = repository.get_readable_statement(auth_subject).options(
        *repository.get_eager_options()
    )

    if organization_id is not None:
        statement = statement.where(Checkout.organization_id.in_(organization_id))

    if product_id is not None:
        statement = statement.where(Checkout.product_id.in_(product_id))

    if customer_id is not None:
        statement = statement.where(Checkout.customer_id.in_(customer_id))

    if status is not None:
        statement = statement.where(Checkout.status.in_(status))

    if query is not None:
        statement = statement.where(Checkout.customer_email.ilike(f"%{query}%"))

    statement = repository.apply_sorting(statement, sorting)

    return await repository.paginate(
        statement, limit=pagination.limit, page=pagination.page
    )
async def get_by_id(
    self,
    session: AsyncReadSession,
    auth_subject: AuthSubject[User | Organization],
    id: uuid.UUID,
) -> Checkout | None:
    """Fetch a single checkout readable by ``auth_subject``.

    Returns ``None`` when no readable checkout has this ID.

    Raises:
        NotPermitted: the checkout exists but its organization is blocked.
    """
    repository = CheckoutRepository.from_session(session)

    readable = repository.get_readable_statement(auth_subject)
    by_id = readable.where(Checkout.id == id)
    statement = by_id.options(*repository.get_eager_options())

    found = await repository.get_one_or_none(statement)
    if found is not None and found.organization.is_blocked():
        raise NotPermitted()

    return found
async def create(
    self,
    session: AsyncSession,
    checkout_create: CheckoutCreate,
    auth_subject: AuthSubject[User | Organization],
    ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None,
) -> Checkout:
    """Create a checkout from a ``CheckoutCreate`` payload.

    Steps, in order:

    1. Resolve the candidate products, the selected product and its price
       from the payload variant (price ID, product ID, or product list
       with optional ad-hoc prices).
    2. Validate the custom amount, discount, tax ID and seat count.
    3. Link an existing subscription or customer when the payload refers
       to one.
    4. Derive the amount and currency from the price type.
    5. Build the ``Checkout``, prefill customer fields, attach Stripe
       metadata, add it to the session, then update IP geolocation, trial
       end and tax.
    6. Flush, attach ad-hoc prices to the checkout products, flush again
       and call ``_after_checkout_created``.

    Raises:
        NotPermitted: the product's organization is blocked.
        PolarRequestValidationError: a tax ID is given without a billing
            address, the tax ID is invalid, seats are missing/invalid for
            a seat-based price or set for another price type, or
            ``customer_id`` does not match a customer of the organization.

    Returns:
        The newly created checkout.
    """
    # Ad-hoc prices are only populated by the multi-product branch below.
    ad_hoc_prices: dict[Product, Sequence[ProductPrice]] = {}
    if isinstance(checkout_create, CheckoutPriceCreate):
        products, product, price = await self._get_validated_price(
            session, auth_subject, checkout_create.product_price_id
        )
    elif isinstance(checkout_create, CheckoutProductCreate):
        products, product, price = await self._get_validated_product(
            session, auth_subject, checkout_create.product_id
        )
    else:
        products = await self._get_validated_products(
            session, auth_subject, checkout_create.products
        )
        if checkout_create.prices:
            ad_hoc_prices = await self._get_validated_prices(
                session, auth_subject, products, checkout_create.prices
            )

        # The first product is the one selected by default; prefer its
        # ad-hoc prices when some were supplied, else its catalog prices.
        product = products[0]
        try:
            price = get_default_price(ad_hoc_prices[product])
        except KeyError:
            price = get_default_price(product.prices)

    if product.organization.is_blocked():
        raise NotPermitted()

    if checkout_create.amount is not None and is_custom_price(price):
        self._validate_custom_price_amount(price, checkout_create.amount)

    discount: Discount | None = None
    if checkout_create.discount_id is not None:
        discount = await self._get_validated_discount(
            session,
            product.organization,
            product,
            price,
            discount_id=checkout_create.discount_id,
        )

    customer_tax_id: TaxID | None = None
    if checkout_create.customer_tax_id is not None:
        # The billing address country is needed by validate_tax_id below.
        if checkout_create.customer_billing_address is None:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "missing",
                        "loc": ("body", "customer_billing_address"),
                        "msg": "Country is required to validate tax ID.",
                        "input": None,
                    }
                ]
            )
        try:
            customer_tax_id = validate_tax_id(
                checkout_create.customer_tax_id,
                checkout_create.customer_billing_address.country,
            )
        except InvalidTaxID as e:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "customer_tax_id"),
                        "msg": "Invalid tax ID.",
                        "input": checkout_create.customer_tax_id,
                    }
                ]
            ) from e

    # Validate seats for seat-based pricing
    if is_seat_price(price):
        if checkout_create.seats is None or checkout_create.seats < 1:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "missing",
                        "loc": ("body", "seats"),
                        "msg": "Seats is required for seat-based pricing.",
                        "input": checkout_create.seats,
                    }
                ]
            )
    elif checkout_create.seats is not None:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "seats"),
                    "msg": "Seats can only be set for seat-based pricing.",
                    "input": checkout_create.seats,
                }
            ]
        )

    product = await self._eager_load_product(session, product)

    # Precedence: subscription_id, then customer_id, then external_customer_id.
    subscription: Subscription | None = None
    customer: Customer | None = None
    customer_repository = CustomerRepository.from_session(session)
    if checkout_create.subscription_id is not None:
        subscription, customer = await self._get_validated_subscription(
            session, checkout_create.subscription_id, product.organization_id
        )
    elif checkout_create.customer_id is not None:
        customer = await customer_repository.get_by_id_and_organization(
            checkout_create.customer_id, product.organization_id
        )
        if customer is None:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "customer_id"),
                        "msg": "Customer does not exist.",
                        "input": checkout_create.customer_id,
                    }
                ]
            )
    elif checkout_create.external_customer_id is not None:
        # Link customer by external ID, if it exists.
        # If not, that's fine: we'll create a new customer on confirm.
        customer = await customer_repository.get_by_external_id_and_organization(
            checkout_create.external_customer_id, product.organization_id
        )

    # Derive amount and currency from the price type. Any price that is
    # neither fixed, custom nor seat-based ends up with an amount of 0.
    amount = checkout_create.amount
    currency = None
    if is_fixed_price(price):
        amount = price.price_amount
        currency = price.price_currency
    elif is_custom_price(price):
        currency = price.price_currency
        if amount is None:
            amount = (
                price.preset_amount
                or price.minimum_amount
                or settings.CUSTOM_PRICE_PRESET_FALLBACK
            )
    elif is_seat_price(price):
        # Calculate amount based on seat count
        seats = checkout_create.seats or 1
        amount = price.calculate_amount(seats)
        currency = price.price_currency
    else:
        amount = 0
        currency = price.price_currency if is_currency_price(price) else "usd"

    # Required fields are not enforced at creation time.
    custom_field_data = validate_custom_field_data(
        product.attached_custom_fields,
        checkout_create.custom_field_data,
        validate_required=False,
    )

    # Ad-hoc prices are attached after the first flush, further below.
    checkout_products = [
        CheckoutProduct(product=product, order=i, ad_hoc_prices=[])
        for i, product in enumerate(products)
    ]

    # Force the full billing address form when the prefilled address already
    # has address details, or has a state for a country other than US/CA.
    require_billing_address = checkout_create.require_billing_address
    customer_billing_address = checkout_create.customer_billing_address
    if customer_billing_address is not None and any(
        (
            customer_billing_address.has_address(),
            customer_billing_address.has_state()
            and customer_billing_address.country not in {"US", "CA"},
        )
    ):
        require_billing_address = True

    # Fields resolved explicitly above are excluded from the dump so they
    # are not passed twice to the constructor.
    checkout = Checkout(
        payment_processor=PaymentProcessor.stripe,
        client_secret=generate_token(prefix=CHECKOUT_CLIENT_SECRET_PREFIX),
        amount=amount,
        currency=currency,
        organization=product.organization,
        checkout_products=checkout_products,
        product=product,
        product_price=price,
        discount=discount,
        customer_billing_address=customer_billing_address,
        require_billing_address=require_billing_address,
        customer_tax_id=customer_tax_id,
        subscription=subscription,
        customer=customer,
        custom_field_data=custom_field_data,
        **checkout_create.model_dump(
            exclude={
                "product_price_id",
                "product_id",
                "products",
                "prices",
                "amount",
                "require_billing_address",
                "customer_billing_address",
                "customer_tax_id",
                "subscription_id",
                "custom_field_data",
            },
            by_alias=True,
        ),
    )

    if checkout.customer is not None:
        # Copy customer values onto the checkout's customer_* fields, but
        # only where the payload left them unset.
        prefill_attributes = (
            "email",
            "name",
            "billing_name",
            "billing_address",
            "tax_id",
        )
        for attribute in prefill_attributes:
            checkout_attribute = f"customer_{attribute}"
            if getattr(checkout, checkout_attribute) is None:
                setattr(
                    checkout,
                    checkout_attribute,
                    getattr(checkout.customer, attribute),
                )

        # Auto-select business customer if they have both a billing name (without the fallback to customer.name)
        # and a billing address since that means they've previously checked the is_business_customer checkbox
        # Only auto-select if is_business_customer wasn't explicitly set in the request
        if (
            "is_business_customer" not in checkout_create.model_fields_set
            and checkout.customer.actual_billing_name is not None
            and checkout.customer.billing_address is not None
            and checkout.customer.billing_address.has_address()
        ):
            checkout.is_business_customer = True

    if checkout.payment_processor == PaymentProcessor.stripe:
        # The metadata dict is rebuilt and reassigned rather than mutated
        # in place.
        checkout.payment_processor_metadata = {
            **(checkout.payment_processor_metadata or {}),
            "publishable_key": settings.STRIPE_PUBLISHABLE_KEY,
        }
        if checkout.customer and checkout.customer.stripe_customer_id is not None:
            stripe_customer_session = await stripe_service.create_customer_session(
                checkout.customer.stripe_customer_id
            )
            checkout.payment_processor_metadata = {
                **(checkout.payment_processor_metadata or {}),
                "customer_session_client_secret": stripe_customer_session.client_secret,
            }

    session.add(checkout)

    checkout = await self._update_checkout_ip_geolocation(
        session, checkout, ip_geolocation_client
    )
    checkout = await self._update_trial_end(checkout)

    try:
        checkout = await self._update_checkout_tax(session, checkout)
    # Swallow incomplete tax calculation error: require it only on confirm
    except TaxCalculationError:
        pass

    await session.flush()

    # Attach the ad-hoc prices only after the checkout has been flushed.
    # NOTE(review): presumably the checkout products must be persisted
    # first -- confirm before reordering.
    if ad_hoc_prices:
        for checkout_product in checkout.checkout_products:
            checkout_product.ad_hoc_prices = ad_hoc_prices.get(
                checkout_product.product, []
            )
            session.add(checkout_product)
        await session.flush()

    await self._after_checkout_created(session, checkout)

    return checkout
async def client_create(
    self,
    session: AsyncSession,
    checkout_create: CheckoutCreatePublic,
    auth_subject: AuthSubject[User | Anonymous],
    ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None,
    ip_address: str | None = None,
) -> Checkout:
    """Create a checkout for a single product from a public payload.

    The product is loaded by ID, its price is picked (static price first,
    else the first listed price), seats are validated, and the amount and
    currency are derived from the price type. The checkout is created
    without a customer, discount or subscription and with
    ``allow_trial=True``.

    Raises:
        PolarRequestValidationError: the product does not exist, is
            archived, or belongs to a blocked organization; or seats are
            missing/invalid for a seat-based price or set for another
            price type.

    Returns:
        The newly created checkout.
    """
    product_repository = ProductRepository.from_session(session)
    product = await product_repository.get_by_id(
        checkout_create.product_id, options=product_repository.get_eager_options()
    )

    if product is None:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "product_id"),
                    "msg": "Product does not exist.",
                    "input": checkout_create.product_id,
                }
            ]
        )

    if product.is_archived:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "product_id"),
                    "msg": "Product is archived.",
                    "input": checkout_create.product_id,
                }
            ]
        )

    if product.organization.blocked_at is not None:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "product_id"),
                    "msg": "Organization is blocked.",
                    "input": checkout_create.product_id,
                }
            ]
        )

    # Select the static price in priority, as it determines the amount and specific behavior
    # NOTE(review): product.prices[0] raises IndexError for a product with
    # no prices -- presumably impossible; confirm.
    price = product.get_static_price() or product.prices[0]

    # Validate seats for seat-based pricing
    if is_seat_price(price):
        if checkout_create.seats is None or checkout_create.seats < 1:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "missing",
                        "loc": ("body", "seats"),
                        "msg": "Seats is required for seat-based pricing.",
                        "input": checkout_create.seats,
                    }
                ]
            )
    elif checkout_create.seats is not None:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "seats"),
                    "msg": "Seats can only be set for seat-based pricing.",
                    "input": checkout_create.seats,
                }
            ]
        )

    # Defaults used when no branch below matches the price type.
    amount = 0
    currency = "usd"
    if is_fixed_price(price):
        amount = price.price_amount
        currency = price.price_currency
    elif is_custom_price(price):
        currency = price.price_currency
        amount = (
            price.preset_amount
            or price.minimum_amount
            or settings.CUSTOM_PRICE_PRESET_FALLBACK
        )
    elif is_seat_price(price):
        # Calculate amount based on seat count
        seats = checkout_create.seats or 1
        amount = price.calculate_amount(seats)
        currency = price.price_currency
    elif is_currency_price(price):
        currency = price.price_currency

    checkout = Checkout(
        payment_processor=PaymentProcessor.stripe,
        client_secret=generate_token(prefix=CHECKOUT_CLIENT_SECRET_PREFIX),
        amount=amount,
        currency=currency,
        seats=checkout_create.seats,
        allow_trial=True,
        organization=product.organization,
        checkout_products=[
            CheckoutProduct(product=product, order=0, ad_hoc_prices=[])
        ],
        product=product,
        product_price=price,
        discount=None,
        customer=None,
        subscription=None,
        customer_email=checkout_create.customer_email,
    )

    if checkout.payment_processor == PaymentProcessor.stripe:
        checkout.payment_processor_metadata = {
            **(checkout.payment_processor_metadata or {}),
            "publishable_key": settings.STRIPE_PUBLISHABLE_KEY,
        }
        # NOTE(review): customer=None is passed to the constructor above, so
        # this branch looks unreachable in this method -- confirm.
        if checkout.customer and checkout.customer.stripe_customer_id is not None:
            stripe_customer_session = await stripe_service.create_customer_session(
                checkout.customer.stripe_customer_id
            )
            checkout.payment_processor_metadata = {
                **(checkout.payment_processor_metadata or {}),
                "customer_session_client_secret": stripe_customer_session.client_secret,
            }

    # Set the IP address before the geolocation update.
    # NOTE(review): presumably the geolocation helper reads it -- confirm.
    checkout.customer_ip_address = ip_address
    checkout = await self._update_checkout_ip_geolocation(
        session, checkout, ip_geolocation_client
    )
    checkout = await self._update_trial_end(checkout)

    try:
        checkout = await self._update_checkout_tax(session, checkout)
    # Swallow incomplete tax calculation error: require it only on confirm
    except TaxCalculationError:
        pass

    session.add(checkout)

    await session.flush()
    await self._after_checkout_created(session, checkout)

    return checkout
async def checkout_link_create(
    self,
    session: AsyncSession,
    checkout_link: CheckoutLink,
    embed_origin: str | None = None,
    ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None,
    ip_address: str | None = None,
    query_prefill: dict[str, str | UUID4 | dict[str, str] | None] | None = None,
    **query_metadata: str | None,
) -> Checkout:
    """Create a checkout from a checkout link, with optional query prefill.

    Archived products of the link are dropped. The first remaining product
    is selected unless ``query_prefill["product_id"]`` is a UUID matching
    one of them. Amount and currency are derived from the selected price.

    Query prefill is best-effort. An invalid amount, discount code or
    custom field value is silently ignored instead of failing the request.
    This includes non-finite or overflowing amounts such as ``"inf"``,
    which previously escaped as an uncaught ``OverflowError``.

    Extra keyword arguments (``query_metadata``) are merged into the
    checkout's ``user_metadata`` without overriding keys already set from
    the link.

    Raises:
        PolarRequestValidationError: every product of the link is archived.

    Returns:
        The newly created checkout.
    """
    products: list[Product] = [
        candidate
        for candidate in checkout_link.products
        if not candidate.is_archived
    ]

    if not products:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "products"),
                    "msg": "No valid products.",
                    "input": checkout_link.products,
                }
            ]
        )

    # Pre-select product if product_id is provided and matches a configured product
    product = products[0]
    query_product_id = query_prefill.get("product_id") if query_prefill else None
    if isinstance(query_product_id, uuid.UUID):
        product = next(
            (p for p in products if p.id == query_product_id), product
        )

    # Select the static price in priority, as it determines the amount and specific behavior
    price = product.get_static_price() or product.prices[0]

    amount = 0
    currency = "usd"
    seats = None
    if is_fixed_price(price):
        amount = price.price_amount
        currency = price.price_currency
    elif is_custom_price(price):
        currency = price.price_currency
        query_amount_str = query_prefill.get("amount") if query_prefill else None

        # Try to parse and validate query amount
        valid_query_amount = None
        if isinstance(query_amount_str, str):
            try:
                query_amount_int = int(float(query_amount_str))
                self._validate_custom_price_amount(price, query_amount_int)
                valid_query_amount = query_amount_int
            except (
                ValueError,
                TypeError,
                # int(float("inf")) and int(float("1e999")) raise
                # OverflowError, which is not a ValueError subclass.
                OverflowError,
                PolarRequestValidationError,
            ):
                pass

        amount = (
            valid_query_amount
            or price.preset_amount
            or price.minimum_amount
            or settings.CUSTOM_PRICE_PRESET_FALLBACK
        )
    elif is_seat_price(price):
        # Default to 1 seat for checkout links with seat-based pricing
        seats = 1
        amount = price.calculate_amount(seats)
        currency = price.price_currency
    elif is_currency_price(price):
        currency = price.price_currency

    discount: Discount | None = None
    if checkout_link.discount_id is not None:
        try:
            discount = await self._get_validated_discount(
                session,
                product.organization,
                product,
                price,
                discount_id=checkout_link.discount_id,
            )
        # If the discount is not valid, just ignore it
        except PolarRequestValidationError:
            pass

    checkout = Checkout(
        client_secret=generate_token(prefix=CHECKOUT_CLIENT_SECRET_PREFIX),
        amount=amount,
        currency=currency,
        seats=seats,
        trial_interval=checkout_link.trial_interval,
        trial_interval_count=checkout_link.trial_interval_count,
        allow_discount_codes=checkout_link.allow_discount_codes,
        allow_trial=True,
        require_billing_address=checkout_link.require_billing_address,
        organization=checkout_link.organization,
        checkout_products=[
            CheckoutProduct(product=p, order=i, ad_hoc_prices=[])
            for i, p in enumerate(products)
        ],
        product=product,
        product_price=price,
        discount=discount,
        embed_origin=embed_origin,
        customer_ip_address=ip_address,
        payment_processor=checkout_link.payment_processor,
        success_url=checkout_link.success_url,
        user_metadata=checkout_link.user_metadata,
    )

    # Handle query parameter prefill
    if query_prefill:
        customer_email = query_prefill.get("customer_email")
        if isinstance(customer_email, str):
            checkout.customer_email = customer_email

        customer_name = query_prefill.get("customer_name")
        if isinstance(customer_name, str):
            checkout.customer_name = customer_name

        discount_code = query_prefill.get("discount_code")
        if isinstance(discount_code, str):
            # A valid discount code from the query replaces the link's discount.
            try:
                discount = await self._get_validated_discount(
                    session,
                    product.organization,
                    product,
                    price,
                    discount_code=discount_code,
                )
                checkout.discount = discount
            except PolarRequestValidationError:
                pass

        custom_field_data_value = query_prefill.get("custom_field_data")
        if isinstance(custom_field_data_value, dict):
            # Keep only the slugs of custom fields attached to the product.
            valid_slugs = {
                cf.custom_field.slug for cf in product.attached_custom_fields
            }

            filtered_data = {
                slug: value
                for slug, value in custom_field_data_value.items()
                if slug in valid_slugs
            }

            if filtered_data:
                try:
                    validated_data = validate_custom_field_data(
                        product.attached_custom_fields,
                        filtered_data,
                        validate_required=False,
                    )
                    checkout.custom_field_data = {
                        **(checkout.custom_field_data or {}),
                        **validated_data,
                    }
                except PolarRequestValidationError:
                    # If validation fails, just ignore the custom field data
                    pass

    for key, value in query_metadata.items():
        # Treat missing metadata as empty so the membership test cannot fail
        # on None. Keys already set from the link are never overridden.
        current_metadata = checkout.user_metadata or {}
        if value is not None and key not in current_metadata:
            checkout.user_metadata = {
                **current_metadata,
                key: value,
            }

    if checkout.payment_processor == PaymentProcessor.stripe:
        checkout.payment_processor_metadata = {
            **(checkout.payment_processor_metadata or {}),
            "publishable_key": settings.STRIPE_PUBLISHABLE_KEY,
        }

    session.add(checkout)

    checkout = await self._update_checkout_ip_geolocation(
        session, checkout, ip_geolocation_client
    )
    checkout = await self._update_trial_end(checkout)

    try:
        checkout = await self._update_checkout_tax(session, checkout)
    # Swallow incomplete tax calculation error: require it only on confirm
    except TaxCalculationError:
        pass

    await session.flush()
    await self._after_checkout_created(session, checkout)

    return checkout
async def update(
    self,
    session: AsyncSession,
    locker: Locker,
    checkout: Checkout,
    checkout_update: CheckoutUpdate | CheckoutUpdatePublic,
    ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None,
) -> Checkout:
    """Apply ``checkout_update`` to a checkout under the checkout update lock.

    After the update, tax is recomputed on a best-effort basis and
    ``_after_checkout_updated`` is called. Returns the updated checkout.
    """
    async with self._lock_checkout_update(session, locker, checkout) as locked:
        updated = await self._update_checkout(
            session, locked, checkout_update, ip_geolocation_client
        )

        # Swallow incomplete tax calculation error: require it only on confirm
        with contextlib.suppress(TaxCalculationError):
            updated = await self._update_checkout_tax(session, updated)

        await self._after_checkout_updated(session, updated)
        return updated
async def confirm(
    self,
    session: AsyncSession,
    locker: Locker,
    auth_subject: AuthSubject[User | Anonymous],
    checkout: Checkout,
    checkout_confirm: CheckoutConfirm,
) -> Checkout:
    """Confirm a checkout under the checkout update lock.

    The confirmation payload is first applied through ``_update_checkout``.
    When the checkout carries a discount, ``_confirm_inner`` runs inside
    ``discount_service.redeem_discount`` and the resulting redemption is
    linked to the checkout. Otherwise ``_confirm_inner`` is called directly.

    Raises:
        PolarRequestValidationError: the discount is no longer redeemable
            (translated from ``DiscountNotRedeemableError``).

    Returns:
        The checkout returned by ``_confirm_inner``.
    """
    async with self._lock_checkout_update(session, locker, checkout) as checkout:
        checkout = await self._update_checkout(session, checkout, checkout_confirm)
        # When redeeming a discount, we need to lock the discount to prevent concurrent redemptions
        if checkout.discount is not None:
            try:
                async with discount_service.redeem_discount(
                    session, locker, checkout.discount
                ) as discount_redemption:
                    discount_redemption.checkout = checkout
                    # Confirm while still inside the redemption context.
                    return await self._confirm_inner(
                        session, auth_subject, checkout, checkout_confirm
                    )
            except DiscountNotRedeemableError as e:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "discount_id"),
                            "msg": "Discount is no longer redeemable.",
                            "input": checkout.discount.id,
                        }
                    ]
                ) from e

        # No discount: nothing to redeem, confirm directly.
        return await self._confirm_inner(
            session, auth_subject, checkout, checkout_confirm
        )
async def _confirm_inner(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Anonymous],
    checkout: Checkout,
    checkout_confirm: CheckoutConfirm,
) -> Checkout:
    """
    Validate a checkout and confirm it with the payment processor.

    Validation problems (tax calculation failure, archived price, missing
    required fields, incomplete billing address, missing confirmation token)
    are accumulated and raised together. Once validation passes and the
    processor is Stripe, the customer is created or updated and, when a
    payment form is required, a PaymentIntent (payment required) or a
    SetupIntent (no payment required) is created and confirmed.

    Returns:
        The checkout, with status `confirmed` and a customer session token set.

    Raises:
        PaymentNotReady: payment is required but the organization is not
            ready to accept payments.
        PolarRequestValidationError: at least one validation error was found.
        BadRequest: a trial is set while the organization does not use the
            subscriptions billing engine.
        PaymentError: Stripe raised an error while creating the intent.
        TrialAlreadyRedeemed: trial abuse prevention is enabled and the trial
            was already redeemed.
    """
    errors: list[ValidationError] = []
    try:
        checkout = await self._update_checkout_tax(session, checkout)
    except TaxCalculationError as e:
        # Reported against the billing address, together with other errors.
        errors.append(
            {
                "type": "value_error",
                "loc": ("body", "customer_billing_address"),
                "msg": e.message,
                "input": None,
            }
        )

    # Case where the price was archived after the checkout was created
    if has_product_checkout(checkout) and checkout.product_price.is_archived:
        errors.append(
            {
                "type": "value_error",
                "loc": ("body", "product_price_id"),
                "msg": "Price is archived.",
                "input": checkout.product_price_id,
            }
        )

    # Check if organization can accept payments (only block paid transactions)
    if (
        checkout.is_payment_required
        and not await organization_service.is_organization_ready_for_payment(
            session, checkout.organization
        )
    ):
        raise PaymentNotReady()

    # A field is only missing if neither the checkout nor the payload has it.
    required_fields = self._get_required_confirm_fields(checkout)
    for required_field in required_fields:
        if (
            attrgetter(checkout, required_field) is None
            and attrgetter(checkout_confirm, required_field) is None
        ):
            errors.append(
                {
                    "type": "missing",
                    "loc": ("body", *required_field),
                    "msg": "Field is required.",
                    "input": None,
                }
            )

    if checkout.require_billing_address or checkout.is_business_customer:
        if (
            checkout.customer_billing_address is None
            or not checkout.customer_billing_address.has_address()
        ):
            errors.append(
                {
                    "type": "value_error",
                    "loc": ("body", "customer_billing_address"),
                    "msg": "Full billing address is required.",
                    "input": checkout.customer_billing_address,
                }
            )

    if (
        checkout.is_payment_form_required
        and checkout_confirm.confirmation_token_id is None
    ):
        errors.append(
            {
                "type": "missing",
                "loc": ("body", "confirmation_token_id"),
                "msg": "Confirmation token is required.",
                "input": None,
            }
        )

    if errors:
        raise PolarRequestValidationError(errors)

    if (
        checkout.trial_end is not None
        and not checkout.organization.subscriptions_billing_engine
    ):
        raise BadRequest(
            "Trials are not supported on subscriptions managed by Stripe."
        )

    if checkout.payment_processor == PaymentProcessor.stripe:
        async with self._create_or_update_customer(
            session, auth_subject, checkout
        ) as customer:
            checkout.customer = customer
            stripe_customer_id = customer.stripe_customer_id
            assert stripe_customer_id is not None
            checkout.payment_processor_metadata = {
                **checkout.payment_processor_metadata,
                "customer_id": stripe_customer_id,
            }

            intent: stripe_lib.PaymentIntent | stripe_lib.SetupIntent | None = None
            if checkout.is_payment_form_required:
                assert checkout_confirm.confirmation_token_id is not None
                assert checkout.customer_billing_address is not None
                intent_metadata: dict[str, str] = {
                    "checkout_id": str(checkout.id),
                    "type": ProductType.product,
                    "tax_amount": str(checkout.tax_amount),
                    "tax_country": checkout.customer_billing_address.country,
                }
                if (
                    state
                    := checkout.customer_billing_address.get_unprefixed_state()
                ) is not None:
                    intent_metadata["tax_state"] = state

                try:
                    if checkout.is_payment_required:
                        payment_intent_params: stripe_lib.PaymentIntent.CreateParams = {
                            "amount": checkout.total_amount,
                            "currency": checkout.currency,
                            "automatic_payment_methods": {"enabled": True},
                            "confirm": True,
                            "confirmation_token": checkout_confirm.confirmation_token_id,
                            "customer": stripe_customer_id,
                            "statement_descriptor_suffix": checkout.organization.statement_descriptor(),
                            "description": checkout.description,
                            "metadata": intent_metadata,
                            "return_url": settings.generate_frontend_url(
                                f"/checkout/{checkout.client_secret}/confirmation"
                            ),
                            "expand": ["payment_method"],
                        }
                        if checkout.should_save_payment_method:
                            payment_intent_params["setup_future_usage"] = (
                                "off_session"
                            )
                        intent = await stripe_service.create_payment_intent(
                            **payment_intent_params
                        )
                    else:
                        # Nothing to charge now: only set up the payment method.
                        setup_intent_params: stripe_lib.SetupIntent.CreateParams = {
                            "automatic_payment_methods": {"enabled": True},
                            "confirm": True,
                            "confirmation_token": checkout_confirm.confirmation_token_id,
                            "customer": stripe_customer_id,
                            "description": checkout.description,
                            "metadata": intent_metadata,
                            "return_url": settings.generate_frontend_url(
                                f"/checkout/{checkout.client_secret}/confirmation"
                            ),
                            "expand": ["payment_method"],
                        }
                        intent = await stripe_service.create_setup_intent(
                            **setup_intent_params
                        )
                except stripe_lib.StripeError as e:
                    error = e.error
                    error_type = error.type if error is not None else None
                    error_message = error.message if error is not None else None
                    raise PaymentError(checkout, error_type, error_message) from e
                else:
                    checkout.payment_processor_metadata = {
                        **checkout.payment_processor_metadata,
                        "intent_client_secret": intent.client_secret,
                        "intent_status": intent.status,
                    }

            # Check for trial abuse
            if (
                checkout.trial_end is not None
                and checkout.organization.prevent_trial_abuse
            ):
                trial_already_redeemed = (
                    await trial_redemption_service.check_trial_already_redeemed(
                        session,
                        checkout.organization,
                        customer=customer,
                        payment_method_fingerprint=get_fingerprint(
                            typing.cast(
                                stripe_lib.PaymentMethod, intent.payment_method
                            )
                        )
                        if (intent and intent.payment_method)
                        else None,
                    )
                )
                if trial_already_redeemed:
                    raise TrialAlreadyRedeemed(checkout)

    # No payment form means no processor callback: enqueue the success job here.
    if not checkout.is_payment_form_required:
        enqueue_job("checkout.handle_free_success", checkout_id=checkout.id)

    checkout.status = CheckoutStatus.confirmed
    session.add(checkout)

    await self._after_checkout_updated(session, checkout)

    assert checkout.customer is not None
    (
        customer_session_token,
        _,
    ) = await customer_session_service.create_customer_session(
        session, checkout.customer
    )
    checkout.customer_session_token = customer_session_token

    return checkout
async def handle_success(
    self,
    session: AsyncSession,
    checkout: Checkout,
    payment: Payment | None = None,
    payment_method: PaymentMethod | None = None,
) -> Checkout:
    """
    Finalize a confirmed checkout.

    Creates the subscription and/or order for the checkout's product, records
    a trial redemption when the checkout has a trial end, then marks the
    checkout as succeeded.

    Raises:
        NotConfirmedCheckout: the checkout is not in the `confirmed` status.
        NotImplementedError: the checkout is not a product checkout.
        ArchivedPriceCheckout: the checkout's price is archived.
    """
    if checkout.status != CheckoutStatus.confirmed:
        raise NotConfirmedCheckout(checkout)

    if not has_product_checkout(checkout):
        raise NotImplementedError()

    if checkout.product_price.is_archived:
        raise ArchivedPriceCheckout(checkout)

    product = checkout.product
    if not product.is_recurring:
        await order_service.create_from_checkout_one_time(session, checkout, payment)
    elif checkout.organization.subscriptions_billing_engine:
        # This path creates the order right after the subscription.
        (
            subscription,
            was_created,
        ) = await subscription_service.create_or_update_from_checkout(
            session, checkout, payment_method
        )
        if was_created:
            billing_reason = OrderBillingReasonInternal.subscription_create
        else:
            billing_reason = OrderBillingReasonInternal.subscription_update
        await order_service.create_from_checkout_subscription(
            session, checkout, subscription, billing_reason, payment
        )
    else:
        # Stripe path: only the subscription is created or updated here.
        (
            _subscription,
            _,
        ) = await subscription_service.create_or_update_from_checkout_stripe(
            session, checkout, payment, payment_method
        )

    # Create trial redemption record if this checkout had a trial period
    if checkout.trial_end is not None:
        assert checkout.customer is not None
        fingerprint = payment_method.fingerprint if payment_method else None
        await trial_redemption_service.create_trial_redemption(
            session,
            customer=checkout.customer,
            product=product,
            payment_method_fingerprint=fingerprint,
        )

    repository = CheckoutRepository.from_session(session)
    succeeded_metadata = {
        **checkout.payment_processor_metadata,
        "intent_status": "succeeded",
    }
    checkout = await repository.update(
        checkout,
        update_dict={
            "status": CheckoutStatus.succeeded,
            "payment_processor_metadata": succeeded_metadata,
        },
    )

    await self._after_checkout_updated(session, checkout)

    return checkout
async def handle_failure(
    self, session: AsyncSession, checkout: Checkout, payment: Payment | None = None
) -> Checkout:
    """
    Reopen a checkout after a failed payment.

    Checkouts that are expired, succeeded or failed are returned untouched.
    Otherwise the checkout goes back to `open`, the intent keys are dropped
    from the processor metadata and the discount redemption is released.
    """
    # Checkout is in an unrecoverable status: do nothing
    unrecoverable = (
        CheckoutStatus.expired,
        CheckoutStatus.succeeded,
        CheckoutStatus.failed,
    )
    if checkout.status in unrecoverable:
        return checkout

    # Put back checkout in open state so the customer can try another payment method
    checkout.status = CheckoutStatus.open
    metadata = dict(checkout.payment_processor_metadata)
    for intent_key in ("intent_status", "intent_client_secret"):
        metadata.pop(intent_key, None)
    checkout.payment_processor_metadata = metadata
    session.add(checkout)

    # The Discount Redemption is saved when *confirming* the Checkout to avoid
    # race conditions, so a failure has to free it up again.
    await discount_service.remove_checkout_redemption(session, checkout)

    await self._after_checkout_updated(session, checkout)

    return checkout
async def get_by_client_secret(
    self, session: AsyncSession, client_secret: str
) -> Checkout:
    """
    Load a checkout by its client secret, with the repository's eager options.

    Raises:
        ResourceNotFound: no checkout matches the client secret.
        ExpiredCheckoutError: the checkout exists but is expired.
    """
    repo = CheckoutRepository.from_session(session)
    eager_options = repo.get_eager_options()
    found = await repo.get_by_client_secret(client_secret, options=eager_options)
    if found is None:
        raise ResourceNotFound()
    if found.is_expired:
        raise ExpiredCheckoutError()
    return found
async def _get_validated_price(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    product_price_id: uuid.UUID,
) -> tuple[Sequence[Product], Product, ProductPrice]:
    """
    Resolve a price readable by the auth subject and check it can be sold.

    Returns:
        A tuple of (single-item product list, product, price).

    Raises:
        PolarRequestValidationError: the price does not exist, the price is
            archived, or its product is archived.
    """

    def rejected(message: str) -> PolarRequestValidationError:
        # All failures share the same location and input.
        return PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "product_price_id"),
                    "msg": message,
                    "input": product_price_id,
                }
            ]
        )

    eager_options = (
        contains_eager(ProductPrice.product).options(
            joinedload(Product.organization)
            .joinedload(Organization.account)
            .joinedload(Account.admin),
            selectinload(Product.prices),
        ),
    )
    price_repository = ProductPriceRepository.from_session(session)
    price = await price_repository.get_readable_by_id(
        product_price_id, auth_subject, options=eager_options
    )

    if price is None:
        raise rejected("Price does not exist.")

    if price.is_archived:
        raise rejected("Price is archived.")

    product = price.product
    if product.is_archived:
        raise rejected("Product is archived.")

    return [product], product, price
async def _get_validated_product(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    product_id: uuid.UUID,
) -> tuple[Sequence[Product], Product, ProductPrice]:
    """
    Resolve a product and pick the price to use for it.

    Returns:
        A tuple of (single-item product list, product, selected price).

    Raises:
        PolarRequestValidationError: the product does not exist or is archived.
    """

    def rejected(message: str) -> PolarRequestValidationError:
        return PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "product_id"),
                    "msg": message,
                    "input": product_id,
                }
            ]
        )

    product = await product_service.get(session, auth_subject, product_id)

    if product is None:
        raise rejected("Product does not exist.")

    if product.is_archived:
        raise rejected("Product is archived.")

    # Select the static price in priority, as it determines the amount and specific behavior, like PWYW
    selected_price = product.get_static_price() or product.prices[0]

    return [product], product, selected_price
async def _get_validated_products(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    product_ids: Sequence[uuid.UUID],
) -> Sequence[Product]:
    """
    Resolve a list of product IDs, collecting every problem before failing.

    Returns:
        The products, in the order of `product_ids`.

    Raises:
        PolarRequestValidationError: a product does not exist, a product is
            archived, or the products belong to more than one organization.
    """
    valid_products: list[Product] = []
    errors: list[ValidationError] = []

    for position, product_id in enumerate(product_ids):
        product = await product_service.get(session, auth_subject, product_id)

        if product is None:
            problem = "Product does not exist."
        elif product.is_archived:
            problem = "Product is archived."
        else:
            valid_products.append(product)
            continue

        errors.append(
            {
                "type": "value_error",
                "loc": ("body", "products", position),
                "msg": problem,
                "input": product_id,
            }
        )

    distinct_organizations = {p.organization_id for p in valid_products}
    if len(distinct_organizations) > 1:
        errors.append(
            {
                "type": "value_error",
                "loc": ("body", "products"),
                "msg": "Products must all belong to the same organization.",
                "input": valid_products,
            }
        )

    if errors:
        raise PolarRequestValidationError(errors)

    return valid_products
async def _get_validated_prices(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    products: Sequence[Product],
    prices: dict[uuid.UUID, ProductPriceCreateList],
) -> dict[Product, Sequence[ProductPrice]]:
    """
    Validate ad-hoc prices keyed by product ID against the given products.

    Returns:
        A mapping from each matched product to its validated prices.

    Raises:
        PolarRequestValidationError: a key is not one of `products`, or the
            product service reported errors for the submitted prices.
    """
    result: dict[Product, Sequence[ProductPrice]] = {}
    errors: list[ValidationError] = []

    for product_id, price_payloads in prices.items():
        product = next((p for p in products if p.id == product_id), None)
        if product is None:
            errors.append(
                {
                    "type": "value_error",
                    "loc": ("body", "prices", str(product_id)),
                    "msg": "Product is not set on that checkout.",
                    "input": product_id,
                }
            )
            continue

        error_prefix = ("body", "prices", str(product_id))
        (
            product_valid_prices,
            _,
            _,
            price_errors,
        ) = await product_service.get_validated_prices(
            session,
            price_payloads,
            product.recurring_interval,
            product,
            auth_subject,
            source=ProductPriceSource.ad_hoc,
            error_prefix=error_prefix,
        )
        errors.extend(price_errors)
        result[product] = product_valid_prices

    if errors:
        raise PolarRequestValidationError(errors)

    return result
@typing.overload
async def _get_validated_discount(
    self,
    session: AsyncSession,
    organization: Organization,
    product: Product,
    price: ProductPrice,
    *,
    discount_id: uuid.UUID,
) -> Discount: ...


@typing.overload
async def _get_validated_discount(
    self,
    session: AsyncSession,
    organization: Organization,
    product: Product,
    price: ProductPrice,
    *,
    discount_code: str,
) -> Discount: ...


async def _get_validated_discount(
    self,
    session: AsyncSession,
    organization: Organization,
    product: Product,
    price: ProductPrice,
    *,
    discount_id: uuid.UUID | None = None,
    discount_code: str | None = None,
) -> Discount:
    """
    Resolve a discount, by ID or by code, and check it applies to the product.

    Exactly one of `discount_id` or `discount_code` is expected (see the
    overloads). Errors are located on whichever field was used and echo the
    value submitted for that field.

    Raises:
        PolarRequestValidationError: no price of the product accepts
            discounts, the discount does not exist, or a repeating discount
            is used on a product without a recurring interval whose price is
            not a legacy recurring price.
    """
    if discount_id is not None:
        loc_field = "discount_id"
        error_input: uuid.UUID | str | None = discount_id
    else:
        # Report the submitted code, not the (absent) discount ID.
        loc_field = "discount_code"
        error_input = discount_code

    def rejected(message: str) -> PolarRequestValidationError:
        return PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", loc_field),
                    "msg": message,
                    "input": error_input,
                }
            ]
        )

    # Checked across all of the product's prices, not only `price`.
    if not any(is_discount_applicable(p) for p in product.prices):
        raise rejected("Discounts are not applicable to this product.")

    discount: Discount | None = None
    if discount_id is not None:
        discount = await discount_service.get_by_id_and_organization(
            session, discount_id, organization, products=[product]
        )
    elif discount_code is not None:
        discount = await discount_service.get_by_code_and_product(
            session, discount_code, organization, product
        )

    if discount is None:
        raise rejected("Discount does not exist.")

    if (
        product.recurring_interval is None
        and not isinstance(
            price,
            LegacyRecurringProductPriceFixed | LegacyRecurringProductPriceCustom,
        )
        and discount.duration == DiscountDuration.repeating
    ):
        raise rejected("Discount is not applicable to this product.")

    return discount
async def _get_validated_subscription(
    self,
    session: AsyncSession,
    subscription_id: uuid.UUID,
    organization_id: uuid.UUID,
) -> tuple[Subscription, Customer]:
    """
    Load a subscription of the organization that is eligible for an upgrade.

    Returns:
        The subscription and its customer.

    Raises:
        PolarRequestValidationError: the subscription does not exist, or one
            of its prices is not free.
    """

    def rejected(message: str) -> PolarRequestValidationError:
        return PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "subscription_id"),
                    "msg": message,
                    "input": subscription_id,
                }
            ]
        )

    repository = SubscriptionRepository.from_session(session)
    subscription = await repository.get_by_id_and_organization(
        subscription_id,
        organization_id,
        options=(joinedload(Subscription.customer),),
    )

    if subscription is None:
        raise rejected("Subscription does not exist.")

    has_paid_price = any(
        price.amount_type != ProductPriceAmountType.free
        for price in subscription.prices
    )
    if has_paid_price:
        raise rejected("Only free subscriptions can be upgraded.")

    return subscription, subscription.customer
@contextlib.asynccontextmanager
async def _lock_checkout_update(
    self, session: AsyncSession, locker: Locker, checkout: Checkout
) -> AsyncIterator[Checkout]:
    """
    Hold a per-checkout lock so it cannot be updated while being confirmed.

    This was seen in the wild: someone switched pricing while the payment was
    being made.

    The timeout is purposely high (10 seconds): Stripe payment requests have
    taken more than 5 seconds in the past, which let the lock expire while
    waiting for the payment to complete.

    Yields:
        The checkout, re-read from the database after acquiring the lock.
    """
    lock_name = f"checkout:{checkout.id}"
    async with locker.lock(lock_name, timeout=10, blocking_timeout=10):
        # Refresh the checkout: it may have changed while waiting for the lock
        repository = CheckoutRepository.from_session(session)
        refreshed = await repository.get_by_id(
            checkout.id, options=repository.get_eager_options()
        )
        yield typing.cast(Checkout, refreshed)

        # 🚨 It's not a mistake: we do explicitly commit here before releasing the lock.
        # Otherwise a waiting update could take the lock and refresh the Checkout
        # _before_ this operation had the chance to commit.
        # See: https://github.com/polarsource/polar/issues/7260
        await session.commit()
# Apply an update payload (or a confirmation payload) to an open checkout.
#
# Steps, in the order they appear below: reject checkouts that are not open;
# switch product and price when `product_id` is given; apply a custom amount;
# apply seats; resolve the discount (by ID for `CheckoutUpdate`, by code for
# `CheckoutUpdatePublic` when discount codes are allowed); update the billing
# address and tax ID; validate custom field data; run the IP geolocation
# update; copy the remaining explicitly-set fields; recompute the trial end;
# check subscription uniqueness.
#
# Raises NotOpenCheckout when the checkout is not open, and
# PolarRequestValidationError for invalid input.
#
# NOTE(review): indentation is lost in this listing, so the exact nesting of
# some steps (e.g. the discount removal after a product change) must be
# confirmed against the real source file.
1656 async def _update_checkout( 1a
1657 self,
1658 session: AsyncSession,
1659 checkout: Checkout,
1660 checkout_update: CheckoutUpdate | CheckoutUpdatePublic | CheckoutConfirm,
1661 ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None,
1662 ) -> Checkout:
1663 if checkout.status != CheckoutStatus.open:
1664 raise NotOpenCheckout(checkout)
# Product switch: the product is looked up by both product ID and checkout ID.
1666 if checkout_update.product_id is not None:
1667 product_repository = ProductRepository.from_session(session)
1668 product = await product_repository.get_by_id_and_checkout(
1669 checkout_update.product_id,
1670 checkout.id,
1671 options=product_repository.get_eager_options(),
1672 )
1674 if product is None:
1675 raise PolarRequestValidationError(
1676 [
1677 {
1678 "type": "value_error",
1679 "loc": ("body", "product_id"),
1680 "msg": "Product is not available in this checkout.",
1681 "input": checkout_update.product_id,
1682 }
1683 ]
1684 )
1686 if product.is_archived:
1687 raise PolarRequestValidationError(
1688 [
1689 {
1690 "type": "value_error",
1691 "loc": ("body", "product_id"),
1692 "msg": "Product is archived.",
1693 "input": checkout_update.product_id,
1694 }
1695 ]
1696 )
1698 checkout.product = product
# Use the requested price if it is among the checkout's prices for this
# product; with no requested price, fall back to the default price.
1700 if checkout_update.product_price_id is not None:
1701 try:
1702 price = next(
1703 p
1704 for p in checkout.prices[product.id]
1705 if p.id == checkout_update.product_price_id
1706 )
1707 except StopIteration as e:
1708 raise PolarRequestValidationError(
1709 [
1710 {
1711 "type": "value_error",
1712 "loc": ("body", "product_price_id"),
1713 "msg": "Price is not available in this checkout.",
1714 "input": checkout_update.product_price_id,
1715 }
1716 ]
1717 ) from e
1718 else:
1719 price = get_default_price(checkout.prices[product.id])
# Reset amount and currency, then derive amount, currency and seats from
# the kind of price that was selected.
1721 checkout.product_price = price
1722 checkout.amount = 0
1723 checkout.currency = "usd"
1724 if is_fixed_price(price):
1725 checkout.amount = price.price_amount
1726 checkout.currency = price.price_currency
1727 checkout.seats = None
1728 elif is_custom_price(price):
1729 checkout.amount = (
1730 price.preset_amount
1731 or price.minimum_amount
1732 or settings.CUSTOM_PRICE_PRESET_FALLBACK
1733 )
1734 checkout.currency = price.price_currency
1735 checkout.seats = None
1736 elif is_seat_price(price):
1737 seats = checkout.seats or checkout_update.seats or 1
1738 checkout.seats = seats
1739 checkout.amount = price.calculate_amount(seats)
1740 checkout.currency = price.price_currency
1741 elif is_currency_price(price):
1742 checkout.currency = price.price_currency
1743 checkout.seats = None
1745 # When changing product, remove the discount if it's not applicable
1746 if (
1747 has_product_checkout(checkout)
1748 and checkout.discount is not None
1749 and not checkout.discount.is_applicable(checkout.product)
1750 ):
1751 checkout.discount = None
# Custom amount: applied only when the current price is a custom price;
# the amount is checked by _validate_custom_price_amount first.
1753 if (
1754 has_product_checkout(checkout)
1755 and checkout_update.amount is not None
1756 and is_custom_price(checkout.product_price)
1757 ):
1758 self._validate_custom_price_amount(
1759 checkout.product_price, checkout_update.amount
1760 )
1761 checkout.amount = checkout_update.amount
1763 # Handle seat updates for seat-based pricing
1764 if (
1765 has_product_checkout(checkout)
1766 and checkout_update.seats is not None
1767 and is_seat_price(checkout.product_price)
1768 ):
1769 checkout.seats = checkout_update.seats
1770 checkout.amount = checkout.product_price.calculate_amount(
1771 checkout_update.seats
1772 )
1773 elif checkout_update.seats is not None:
1774 # Seats provided for non-seat-based pricing
1775 raise PolarRequestValidationError(
1776 [
1777 {
1778 "type": "value_error",
1779 "loc": ("body", "seats"),
1780 "msg": "Seats can only be set for seat-based pricing.",
1781 "input": checkout_update.seats,
1782 }
1783 ]
1784 )
# Discount: `CheckoutUpdate` payloads set it by ID; `CheckoutUpdatePublic`
# payloads set it by code, and only if the checkout allows discount codes.
1786 if isinstance(checkout_update, CheckoutUpdate):
1787 if (
1788 has_product_checkout(checkout)
1789 and checkout_update.discount_id is not None
1790 ):
1791 checkout.discount = await self._get_validated_discount(
1792 session,
1793 checkout.organization,
1794 checkout.product,
1795 checkout.product_price,
1796 discount_id=checkout_update.discount_id,
1797 )
1798 # User explicitly removed the discount
1799 elif "discount_id" in checkout_update.model_fields_set:
1800 checkout.discount = None
1801 elif (
1802 isinstance(checkout_update, CheckoutUpdatePublic)
1803 and checkout.allow_discount_codes
1804 ):
1805 if (
1806 has_product_checkout(checkout)
1807 and checkout_update.discount_code is not None
1808 ):
1809 discount = await self._get_validated_discount(
1810 session,
1811 checkout.organization,
1812 checkout.product,
1813 checkout.product_price,
1814 discount_code=checkout_update.discount_code,
1815 )
1816 checkout.discount = discount
1817 # User explicitly removed the discount
1818 elif "discount_code" in checkout_update.model_fields_set:
1819 checkout.discount = None
1821 if checkout_update.customer_billing_address:
1822 checkout.customer_billing_address = checkout_update.customer_billing_address
# Tax ID: an explicit None clears it. Otherwise the submitted tax ID, or the
# checkout's existing tax ID number, is validated against the country of
# the submitted or existing billing address.
1824 if (
1825 checkout_update.customer_tax_id is None
1826 and "customer_tax_id" in checkout_update.model_fields_set
1827 ):
1828 checkout.customer_tax_id = None
1829 else:
1830 customer_tax_id_number = (
1831 checkout_update.customer_tax_id or checkout.customer_tax_id_number
1832 )
1833 if customer_tax_id_number is not None:
1834 customer_billing_address = (
1835 checkout_update.customer_billing_address
1836 or checkout.customer_billing_address
1837 )
1838 if customer_billing_address is None:
1839 raise PolarRequestValidationError(
1840 [
1841 {
1842 "type": "missing",
1843 "loc": ("body", "customer_billing_address"),
1844 "msg": "Country is required to validate tax ID.",
1845 "input": None,
1846 }
1847 ]
1848 )
1849 try:
1850 checkout.customer_tax_id = validate_tax_id(
1851 customer_tax_id_number, customer_billing_address.country
1852 )
1853 except InvalidTaxID as e:
1854 raise PolarRequestValidationError(
1855 [
1856 {
1857 "type": "value_error",
1858 "loc": ("body", "customer_tax_id"),
1859 "msg": "Invalid tax ID.",
1860 "input": customer_tax_id_number,
1861 }
1862 ]
1863 ) from e
# Custom fields: `validate_required` is True only for CheckoutConfirm payloads.
1865 if (
1866 has_product_checkout(checkout)
1867 and checkout_update.custom_field_data is not None
1868 ):
1869 custom_field_data = validate_custom_field_data(
1870 checkout.product.attached_custom_fields,
1871 checkout_update.custom_field_data,
1872 validate_required=isinstance(checkout_update, CheckoutConfirm),
1873 )
1874 checkout.custom_field_data = custom_field_data
1876 checkout = await self._update_checkout_ip_geolocation(
1877 session, checkout, ip_geolocation_client
1878 )
# These fields were handled explicitly above, so the generic copy below
# skips them.
1880 exclude = {
1881 "product_id",
1882 "product_price_id",
1883 "amount",
1884 "customer_billing_address",
1885 "customer_tax_id",
1886 "custom_field_data",
1887 }
# `customer_email` is not copied once the checkout has a customer ID.
1889 if checkout.customer_id is not None:
1890 exclude.add("customer_email")
# Copy every remaining field that was explicitly set in the payload.
1892 for attr, value in checkout_update.model_dump(
1893 exclude_unset=True, exclude=exclude, by_alias=True
1894 ).items():
1895 setattr(checkout, attr, value)
1897 checkout = await self._update_trial_end(checkout)
1899 session.add(checkout)
1901 await self._validate_subscription_uniqueness(session, checkout)
1903 return checkout
async def _update_checkout_tax(
    self, session: AsyncSession, checkout: Checkout
) -> Checkout:
    """
    Recompute the tax amount and tax processor ID of a checkout.

    Tax is set to 0 when no payment form is required or when tax is not
    applicable. When there is no billing address, the checkout is returned
    unchanged.

    Raises:
        TaxCalculationError: the calculation failed; the tax amount and
            processor ID are reset to None before re-raising.
    """
    if has_product_checkout(checkout):
        taxable = checkout.product.is_tax_applicable
        tax_code = checkout.product.tax_code
    else:
        taxable = True
        tax_code = TaxCode.general_electronically_supplied_services

    if not checkout.is_payment_form_required or not taxable:
        checkout.tax_amount = 0
        checkout.tax_processor_id = None
        return checkout

    if checkout.customer_billing_address is None:
        return checkout

    try:
        if checkout.customer_tax_id is not None:
            tax_ids = [checkout.customer_tax_id]
        else:
            tax_ids = []
        calculation = await calculate_tax(
            checkout.id,
            checkout.currency,
            checkout.net_amount,
            tax_code,
            checkout.customer_billing_address,
            tax_ids,
            customer_exempt=False,
        )
        checkout.tax_amount = calculation["amount"]
        checkout.tax_processor_id = calculation["processor_id"]
    except TaxCalculationError:
        checkout.tax_amount = None
        checkout.tax_processor_id = None
        raise
    finally:
        # Added to the session whether the calculation succeeded or failed.
        session.add(checkout)

    return checkout
async def _update_checkout_ip_geolocation(
    self,
    session: AsyncSession,
    checkout: Checkout,
    ip_geolocation_client: ip_geolocation.IPGeolocationClient | None,
) -> Checkout:
    """
    Set the billing address country from the customer's IP address.

    Nothing changes when there is no geolocation client, no customer IP
    address, an existing billing address, no country found for the IP, or a
    country that fails `AddressInput` validation.
    """
    if (
        ip_geolocation_client is None
        or checkout.customer_ip_address is None
        or checkout.customer_billing_address is not None
    ):
        return checkout

    country = ip_geolocation.get_ip_country(
        ip_geolocation_client, checkout.customer_ip_address
    )
    if country is None:
        return checkout

    try:
        guessed_address = AddressInput.model_validate({"country": country})
    except PydanticValidationError:
        return checkout
    else:
        checkout.customer_billing_address = guessed_address
        session.add(checkout)
        return checkout
async def _update_trial_end(self, checkout: Checkout) -> Checkout:
    """
    Recompute `checkout.trial_end`.

    The trial end is set only for a product checkout with a recurring product
    that has both an active trial interval and an interval count. It is then
    computed from the current UTC time. In every other case it is None.
    """
    trial_end = None
    if has_product_checkout(checkout) and checkout.product.is_recurring:
        interval = checkout.active_trial_interval
        interval_count = checkout.active_trial_interval_count
        if interval is not None and interval_count is not None:
            trial_end = interval.get_end(utc_now(), interval_count)

    checkout.trial_end = trial_end
    return checkout
async def _validate_subscription_uniqueness(
    self, session: AsyncSession, checkout: Checkout
) -> None:
    """
    Reject the checkout if its customer already has a billable subscription.

    The check looks for billable subscriptions on any product of the
    checkout's organization, matched by customer ID or else by customer
    email (case-insensitive, non-deleted customers).

    Raises:
        AlreadyActiveSubscriptionError: a matching subscription exists.
    """
    organization = checkout.organization

    check_not_needed = (
        # No product checkout
        not has_product_checkout(checkout)
        # Multiple subscriptions allowed
        or organization.allow_multiple_subscriptions
        # One-time purchase
        or not checkout.product.is_recurring
        # Subscription upgrade
        or checkout.subscription is not None
        # No information yet to check customer subscription uniqueness
        or (checkout.customer_id is None and checkout.customer_email is None)
    )
    if check_not_needed:
        return

    query = (
        select(Subscription)
        .join(Product, onclause=Product.id == Subscription.product_id)
        .where(
            Product.organization_id == organization.id,
            Subscription.billable.is_(True),
        )
    )
    if checkout.customer is not None:
        query = query.where(Subscription.customer_id == checkout.customer_id)
    elif checkout.customer_email is not None:
        query = query.join(
            Customer, onclause=Customer.id == Subscription.customer_id
        ).where(
            func.lower(Customer.email) == checkout.customer_email.lower(),
            Customer.deleted_at.is_(None),
        )

    rows = await session.execute(query)
    if rows.scalars().all():
        raise AlreadyActiveSubscriptionError()
def _validate_custom_price_amount(
    self,
    price: ProductPrice,
    amount: int,
    loc: tuple[str, ...] = ("body", "amount"),
) -> None:
    """
    Check that `amount` is within the bounds of a custom price.

    Prices that are not custom prices are accepted without any check. A bound
    set to None is not enforced.

    Raises:
        PolarRequestValidationError: the amount is below the minimum or above
            the maximum; the error is located at `loc`.
    """
    if not is_custom_price(price):
        return

    def out_of_bounds(
        error_type: str, message: str, ctx: dict[str, int]
    ) -> PolarRequestValidationError:
        return PolarRequestValidationError(
            [
                {
                    "type": error_type,
                    "loc": loc,
                    "msg": message,
                    "input": amount,
                    "ctx": ctx,
                }
            ]
        )

    lower_bound = price.minimum_amount
    if lower_bound is not None and amount < lower_bound:
        raise out_of_bounds(
            "greater_than_equal", "Amount is below minimum.", {"ge": lower_bound}
        )

    upper_bound = price.maximum_amount
    if upper_bound is not None and amount > upper_bound:
        raise out_of_bounds(
            "less_than_equal", "Amount is above maximum.", {"le": upper_bound}
        )
2081 def _get_required_confirm_fields(self, checkout: Checkout) -> set[tuple[str, ...]]: 1a
2082 fields: set[tuple[str, ...]] = {("customer_email",)}
2083 if checkout.is_payment_form_required:
2084 fields.update({("customer_name",), ("customer_billing_address",)})
2085 for (
2086 address_field,
2087 required,
2088 ) in checkout.customer_billing_address_fields.items():
2089 if required:
2090 fields.add(("customer_billing_address", address_field))
2091 if checkout.is_business_customer:
2092 fields.update({("customer_billing_name",), ("customer_billing_address",)})
2093 return fields
@contextlib.asynccontextmanager
async def _create_or_update_customer(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Anonymous],
    checkout: Checkout,
) -> AsyncGenerator[Customer]:
    """
    Resolve the customer for a checkout, sync it with Stripe, and yield it.

    The customer is taken from `checkout.customer` when set; otherwise it is
    looked up by the checkout email within the checkout's organization, and
    a new `Customer` is built when no match exists.

    A Stripe customer is created when the customer has no
    `stripe_customer_id`, and updated otherwise. The local record then
    receives the non-None name, billing name, billing address and tax ID
    from the checkout, the Stripe customer ID, and the checkout's customer
    metadata merged over its existing metadata.

    A newly built customer is persisted through the repository's
    `create_context`, inside which an owner member is created before the
    customer is yielded. An existing customer is yielded after
    `repository.update`.

    `auth_subject` is accepted but not read by this method.
    """
    customer_repository = CustomerRepository.from_session(session)

    is_new = False
    customer = checkout.customer
    if customer is None:
        assert checkout.customer_email is not None
        customer = await customer_repository.get_by_email_and_organization(
            checkout.customer_email, checkout.organization.id
        )
        if customer is None:
            is_new = True
            customer = Customer(
                external_id=checkout.external_customer_id,
                email=checkout.customer_email,
                email_verified=False,
                stripe_customer_id=None,
                organization=checkout.organization,
                user_metadata={},
            )

    customer_name = checkout.customer_name
    billing_name = checkout.customer_billing_name
    billing_address = checkout.customer_billing_address
    tax_id = checkout.customer_tax_id
    # On Stripe, the billing name takes precedence over the customer name.
    stripe_name = billing_name if billing_name is not None else customer_name

    stripe_customer_id = customer.stripe_customer_id
    if stripe_customer_id is not None:
        update_params: stripe_lib.Customer.ModifyParams = {"email": customer.email}
        if stripe_name is not None:
            update_params["name"] = stripe_name
        if billing_address is not None:
            update_params["address"] = billing_address.to_dict()  # type: ignore
        await stripe_service.update_customer(
            stripe_customer_id,
            tax_id=None if tax_id is None else to_stripe_tax_id(tax_id),
            **update_params,
        )
    else:
        create_params: stripe_lib.Customer.CreateParams = {"email": customer.email}
        if stripe_name is not None:
            create_params["name"] = stripe_name
        if billing_address is not None:
            create_params["address"] = billing_address.to_dict()  # type: ignore
        if tax_id is not None:
            create_params["tax_id_data"] = [to_stripe_tax_id(tax_id)]
        stripe_customer = await stripe_service.create_customer(**create_params)
        stripe_customer_id = stripe_customer.id

    # Mirror the checkout data on the local record; None values never
    # overwrite what the customer already has.
    if customer_name is not None:
        customer.name = customer_name
    if billing_name is not None:
        customer.billing_name = billing_name
    if billing_address is not None:
        customer.billing_address = billing_address
    if tax_id is not None:
        customer.tax_id = tax_id

    customer.stripe_customer_id = stripe_customer_id
    # Checkout metadata wins over existing keys on the customer.
    customer.user_metadata = {
        **customer.user_metadata,
        **checkout.customer_metadata,
    }

    if not is_new:
        yield await customer_repository.update(customer, flush=True)
        return

    async with customer_repository.create_context(
        customer, flush=False
    ) as new_customer:
        await member_service.create_owner_member(
            session, new_customer, checkout.organization
        )
        yield new_customer
async def _create_ad_hoc_custom_price(
    self, checkout: Checkout, *, idempotency_key: str | None = None
) -> stripe_lib.Price:
    """
    Create a Stripe price for the checkout's amount and currency.

    The price is attached to the Stripe product of the checkout's product
    and carries the checkout's `product_price_id` in its metadata. For a
    recurring product, the interval comes from the price itself when it is
    a `LegacyRecurringProductPriceCustom`, and from the product otherwise.

    `idempotency_key` is forwarded unchanged to the Stripe service.
    """
    assert has_product_checkout(checkout)
    product = checkout.product
    assert product.stripe_product_id is not None

    params: stripe_lib.Price.CreateParams = {
        "unit_amount": checkout.amount,
        "currency": checkout.currency,
        "metadata": {
            "product_price_id": str(checkout.product_price_id),
        },
    }

    if product.is_recurring:
        interval: SubscriptionRecurringInterval
        price = checkout.product_price
        if not isinstance(price, LegacyRecurringProductPriceCustom):
            assert product.recurring_interval is not None
            interval = product.recurring_interval
        else:
            interval = price.recurring_interval
        params["recurring"] = {"interval": interval.as_literal()}

    return await stripe_service.create_price_for_product(
        product.stripe_product_id,
        params,
        idempotency_key=idempotency_key,
    )
async def _after_checkout_created(
    self, session: AsyncSession, checkout: Checkout
) -> None:
    """
    Send the `WebhookEventType.checkout_created` webhook for the checkout
    to the checkout's organization.
    """
    organization = checkout.organization
    event_type = WebhookEventType.checkout_created
    await webhook_service.send(session, organization, event_type, checkout)
async def _after_checkout_updated(
    self, session: AsyncSession, checkout: Checkout
) -> None:
    """
    Publish the checkout's updated status and send the matching webhook.

    A `CheckoutEvent.updated` event carrying the current status is
    published under the checkout's client secret, then the
    `WebhookEventType.checkout_updated` webhook is sent to the checkout's
    organization. When that send returns no events, the
    `CheckoutEvent.webhook_event_delivered` event is published right away.
    """
    await publish_checkout_event(
        checkout.client_secret, CheckoutEvent.updated, {"status": checkout.status}
    )

    webhook_events = await webhook_service.send(
        session, checkout.organization, WebhookEventType.checkout_updated, checkout
    )
    if len(webhook_events) > 0:
        return

    # No webhook to send, publish the webhook_event immediately
    await publish_checkout_event(
        checkout.client_secret,
        CheckoutEvent.webhook_event_delivered,
        {"status": checkout.status},
    )
2232 async def _eager_load_product( 1a
2233 self, session: AsyncSession, product: Product
2234 ) -> Product:
2235 await session.refresh(
2236 product,
2237 {"organization", "prices", "product_medias", "attached_custom_fields"},
2238 )
2239 return product
# Module-level CheckoutService instance, created once at import time.
checkout = CheckoutService()