Coverage for polar/models/checkout.py: 61%
242 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import uuid 1ab
2from collections.abc import Sequence 1ab
3from datetime import datetime, timedelta 1ab
4from enum import StrEnum 1ab
5from typing import TYPE_CHECKING, Any, TypedDict 1ab
6from uuid import UUID 1ab
8from sqlalchemy import ( 1ab
9 TIMESTAMP,
10 Boolean,
11 ColumnElement,
12 Connection,
13 ForeignKey,
14 Integer,
15 String,
16 Uuid,
17 event,
18)
19from sqlalchemy.dialects.postgresql import JSONB 1ab
20from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy 1ab
21from sqlalchemy.ext.hybrid import hybrid_property 1ab
22from sqlalchemy.orm import Mapped, Mapper, declared_attr, mapped_column, relationship 1ab
24from polar.config import settings 1ab
25from polar.custom_field.data import CustomFieldDataMixin 1ab
26from polar.enums import PaymentProcessor 1ab
27from polar.kit.address import Address, AddressType 1ab
28from polar.kit.db.models import RecordModel 1ab
29from polar.kit.metadata import MetadataColumn, MetadataMixin 1ab
30from polar.kit.tax import TaxID, TaxIDType 1ab
31from polar.kit.trial import TrialConfigurationMixin, TrialInterval 1ab
32from polar.kit.utils import utc_now 1ab
33from polar.product.guard import ( 1ab
34 is_discount_applicable,
35 is_free_price,
36 is_metered_price,
37)
39from .customer import Customer 1ab
40from .discount import Discount 1ab
41from .organization import Organization 1ab
42from .product import Product 1ab
43from .product_price import ProductPrice, ProductPriceSeatUnit 1ab
44from .subscription import Subscription 1ab
46if TYPE_CHECKING: 46 ↛ 47line 46 didn't jump to line 47 because the condition on line 46 was never true1ab
47 from polar.custom_field.attachment import AttachedCustomFieldMixin
49 from .checkout_product import CheckoutProduct
def get_expires_at() -> datetime:
    """Compute the default ``expires_at`` for a newly inserted checkout.

    The deadline is the current time (``utc_now()``) plus
    ``settings.CHECKOUT_TTL_SECONDS`` seconds.
    """
    now = utc_now()
    time_to_live = timedelta(seconds=settings.CHECKOUT_TTL_SECONDS)
    return now + time_to_live
56class CheckoutStatus(StrEnum): 1ab
57 open = "open" 1ab
58 expired = "expired" 1ab
59 confirmed = "confirmed" 1ab
60 succeeded = "succeeded" 1ab
61 failed = "failed" 1ab
class CheckoutCustomerBillingAddressFields(TypedDict):
    """Deprecated boolean flags telling which billing address fields to show.

    Deprecated: Use CheckoutBillingAddressFields instead, which expresses each
    field as a required/optional/disabled mode rather than a boolean.
    """

    # Each flag is True when the corresponding address field is requested.
    country: bool
    state: bool
    city: bool
    postal_code: bool
    line1: bool
    line2: bool
77class BillingAddressFieldMode(StrEnum): 1ab
78 required = "required" 1ab
79 optional = "optional" 1ab
80 disabled = "disabled" 1ab
class CheckoutBillingAddressFields(TypedDict):
    """Per-field presentation mode of the checkout's billing address form.

    Replaces the boolean-based CheckoutCustomerBillingAddressFields.
    """

    # Each entry tells whether the field is required, optional or disabled.
    country: BillingAddressFieldMode
    state: BillingAddressFieldMode
    city: BillingAddressFieldMode
    postal_code: BillingAddressFieldMode
    line1: BillingAddressFieldMode
    line2: BillingAddressFieldMode
class Checkout(
    TrialConfigurationMixin, CustomFieldDataMixin, MetadataMixin, RecordModel
):
    """A checkout session, persisted in the ``checkouts`` table.

    Relationships are declared with ``lazy="raise"``, so any property below that
    touches ``product``, ``discount``, ``organization``, ``checkout_products``
    (etc.) requires that relationship to have been loaded beforehand.
    Amounts are integers in ``currency``; presumably the smallest currency unit
    (e.g. cents) — NOTE(review): confirm against callers.
    """

    __tablename__ = "checkouts"

    payment_processor: Mapped[PaymentProcessor] = mapped_column(
        String, nullable=False, default=PaymentProcessor.stripe, index=True
    )
    status: Mapped[CheckoutStatus] = mapped_column(
        String, nullable=False, default=CheckoutStatus.open, index=True
    )
    client_secret: Mapped[str] = mapped_column(
        String, index=True, nullable=False, unique=True
    )
    expires_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), index=True, nullable=False, default=get_expires_at
    )
    payment_processor_metadata: Mapped[dict[str, Any]] = mapped_column(
        JSONB, nullable=False, default=dict
    )
    return_url: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    # Raw value of the "success_url" column; read it through the `success_url`
    # property, which supplies a fallback and substitutes `{CHECKOUT_ID}`.
    _success_url: Mapped[str | None] = mapped_column(
        "success_url", String, nullable=True, default=None
    )
    embed_origin: Mapped[str | None] = mapped_column(String, nullable=True)
    allow_discount_codes: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=True
    )
    require_billing_address: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )

    amount: Mapped[int] = mapped_column(Integer, nullable=False)
    currency: Mapped[str] = mapped_column(String(3), nullable=False)
    seats: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)

    tax_amount: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
    tax_processor_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )

    # TODO: proper data migration to make it non-nullable
    allow_trial: Mapped[bool | None] = mapped_column(
        Boolean, nullable=True, default=True
    )
    trial_end: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )

    organization_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("organizations.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def organization(cls) -> Mapped["Organization"]:
        return relationship("Organization", lazy="raise")

    product_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("products.id", ondelete="cascade"), nullable=True
    )

    @declared_attr
    def product(cls) -> Mapped[Product | None]:
        return relationship(Product, lazy="raise")

    product_price_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("product_prices.id", ondelete="cascade"), nullable=True
    )

    @declared_attr
    def product_price(cls) -> Mapped[ProductPrice | None]:
        return relationship(ProductPrice, lazy="raise")

    checkout_products: Mapped[list["CheckoutProduct"]] = relationship(
        "CheckoutProduct",
        back_populates="checkout",
        cascade="all, delete-orphan",
        order_by="CheckoutProduct.order",
        lazy="raise",
    )

    # The `product` of every entry in `checkout_products`, in the same order.
    products: AssociationProxy[list["Product"]] = association_proxy(
        "checkout_products", "product"
    )

    discount_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("discounts.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def discount(cls) -> Mapped[Discount | None]:
        return relationship(Discount, lazy="raise")

    customer_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("customers.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def customer(cls) -> Mapped[Customer | None]:
        return relationship(Customer, lazy="raise")

    is_business_customer: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )
    external_customer_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    customer_name: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    customer_email: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    # Raw value of the "customer_ip_address" column; the `customer_ip_address`
    # hybrid property stringifies values on assignment.
    _customer_ip_address: Mapped[str | None] = mapped_column(
        "customer_ip_address", String, nullable=True, default=None
    )
    customer_billing_name: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    customer_billing_address: Mapped[Address | None] = mapped_column(
        AddressType, nullable=True, default=None
    )
    customer_tax_id: Mapped[TaxID | None] = mapped_column(
        TaxIDType, nullable=True, default=None
    )
    customer_metadata: Mapped[MetadataColumn]

    # Only set when a checkout is attached to an existing subscription (free-to-paid upgrades).
    # For subscriptions created by the checkout itself, see `Subscription.checkout_id`.
    subscription_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("subscriptions.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def subscription(cls) -> Mapped[Subscription | None]:
        return relationship(
            Subscription,
            lazy="raise",
            # Explicit, because this relationship must go through
            # `subscription_id` (see the comment on that column).
            foreign_keys=[cls.subscription_id],  # type: ignore
        )

    @hybrid_property
    def is_expired(self) -> bool:
        """Whether ``expires_at`` lies in the past (instance-level check)."""
        return self.expires_at < utc_now()

    @is_expired.inplace.expression
    @classmethod
    def _is_expired_expression(cls) -> ColumnElement[bool]:
        # SQL form of `is_expired`; `utc_now()` is evaluated when the query
        # expression is built.
        return cls.expires_at < utc_now()

    @hybrid_property
    def customer_ip_address(self) -> str | None:
        """The customer's IP address as stored (a string or ``None``)."""
        return self._customer_ip_address

    @customer_ip_address.inplace.setter
    def _customer_ip_address_setter(self, value: Any | None) -> None:
        # Accept any object (e.g. an ipaddress instance) and persist its str().
        self._customer_ip_address = str(value) if value is not None else None

    @property
    def success_url(self) -> str:
        """URL to redirect the customer to after a successful checkout.

        When no URL was configured, the hosted confirmation page for this
        checkout is returned. Otherwise the stored URL is treated as a
        ``str.format`` template in which ``{CHECKOUT_ID}`` is replaced by the
        checkout id. If the template cannot be formatted (unknown placeholder,
        positional ``{}``, unbalanced braces, invalid field access or format
        spec), it is returned verbatim instead of raising.
        """
        if self._success_url is None:
            return settings.generate_frontend_url(
                f"/checkout/{self.client_secret}/confirmation"
            )
        try:
            return self._success_url.format(CHECKOUT_ID=self.id)
        except (KeyError, IndexError, ValueError, AttributeError, TypeError):
            # `str.format` raises KeyError for unknown names, IndexError for
            # positional fields, ValueError for stray braces / bad conversions,
            # AttributeError / TypeError for invalid attribute, index or format
            # spec on the UUID. Any URL that is not a valid template is used
            # as-is rather than breaking the checkout.
            return self._success_url

    @success_url.setter
    def success_url(self, value: str | None) -> None:
        # Stored as str so URL objects can be assigned directly.
        self._success_url = str(value) if value is not None else None

    @property
    def customer_tax_id_number(self) -> str | None:
        """First element of ``customer_tax_id``, or ``None`` when unset.

        NOTE(review): assumes ``TaxID`` is a tuple whose first item is the
        tax number — confirm against ``polar.kit.tax``.
        """
        return self.customer_tax_id[0] if self.customer_tax_id is not None else None

    @property
    def discount_amount(self) -> int:
        """Discount applied to ``amount``; 0 when no discount is attached."""
        return self.discount.get_discount_amount(self.amount) if self.discount else 0

    @property
    def net_amount(self) -> int:
        """``amount`` minus ``discount_amount``, before tax."""
        return self.amount - self.discount_amount

    @property
    def total_amount(self) -> int:
        """``net_amount`` plus ``tax_amount`` (an unknown tax counts as 0)."""
        return self.net_amount + (self.tax_amount or 0)

    @property
    def is_discount_applicable(self) -> bool:
        """Whether at least one of ``product_prices`` accepts a discount."""
        if self.product_prices is None:
            return False
        # Inside the method body this name resolves to the module-level guard
        # function imported from `polar.product.guard`, not to this property
        # (class scope is not visible from method bodies).
        return any(is_discount_applicable(price) for price in self.product_prices)

    @property
    def is_free_product_price(self) -> bool:
        """Whether every price in ``product_prices`` is free."""
        if self.product_prices is None:
            return False
        return all(is_free_price(price) for price in self.product_prices)

    @property
    def has_metered_prices(self) -> bool:
        """Whether any price in ``product_prices`` is metered."""
        if self.product_prices is None:
            return False
        return any(is_metered_price(price) for price in self.product_prices)

    @property
    def is_payment_required(self) -> bool:
        """A charge is due now: positive total and no trial end set."""
        return self.total_amount > 0 and self.trial_end is None

    @property
    def is_payment_setup_required(self) -> bool:
        """Whether a payment method must be set up for a recurring, non-free product."""
        if self.product is None:
            return False
        return self.product.is_recurring and not self.is_free_product_price

    @property
    def should_save_payment_method(self) -> bool:
        """Whether the product is recurring, so the payment method is kept."""
        return self.product is not None and self.product.is_recurring

    @property
    def is_payment_form_required(self) -> bool:
        """Whether the payment form must be shown (charge or setup needed)."""
        return self.is_payment_required or self.is_payment_setup_required

    @property
    def url(self) -> str:
        """Frontend URL of this checkout page."""
        return settings.generate_frontend_url(f"/checkout/{self.client_secret}")

    @property
    def customer_session_token(self) -> str | None:
        """Transient (not a mapped column) token; ``None`` until assigned."""
        return getattr(self, "_customer_session_token", None)

    @customer_session_token.setter
    def customer_session_token(self, value: str) -> None:
        self._customer_session_token = value

    # Custom fields attached to `product`, proxied through that relationship.
    attached_custom_fields: AssociationProxy[
        Sequence["AttachedCustomFieldMixin"] | None
    ] = association_proxy("product", "attached_custom_fields")

    def _billing_address_requirements(self) -> tuple[str | None, bool]:
        """Return ``(country, requires_full_address)`` for the billing form.

        ``country`` comes from ``customer_billing_address`` (``None`` when no
        address is set). A full address is required when the checkout asks for
        it explicitly, when the customer is a business, or for US addresses.
        """
        address = self.customer_billing_address
        country = address.country if address else None
        require_billing_address = (
            self.require_billing_address
            or self.is_business_customer
            or country == "US"
        )
        return country, require_billing_address

    @property
    def customer_billing_address_fields(self) -> CheckoutCustomerBillingAddressFields:
        """Legacy boolean flags for the billing address form.

        Returns the deprecated ``CheckoutCustomerBillingAddressFields`` shape;
        ``billing_address_fields`` is its mode-based replacement.
        """
        country, require_billing_address = self._billing_address_requirements()
        return {
            "country": True,
            "state": country in {"US", "CA"},
            "line1": require_billing_address,
            "line2": False,
            "city": require_billing_address,
            "postal_code": require_billing_address,
        }

    @property
    def billing_address_fields(self) -> CheckoutBillingAddressFields:
        """Presentation mode of each billing address field.

        ``country`` is always required and ``state`` is required for US/CA.
        Otherwise ``state`` and ``line2`` are optional when a full address is
        required, while ``line1``, ``city`` and ``postal_code`` are required;
        all of them are disabled when a full address is not required.
        """
        country, require_billing_address = self._billing_address_requirements()
        # Shared mode of the fields that are mandatory in a full address.
        full_address_mode = (
            BillingAddressFieldMode.required
            if require_billing_address
            else BillingAddressFieldMode.disabled
        )
        # Shared mode of the fields that are optional in a full address.
        optional_address_mode = (
            BillingAddressFieldMode.optional
            if require_billing_address
            else BillingAddressFieldMode.disabled
        )
        return {
            "country": BillingAddressFieldMode.required,
            "state": BillingAddressFieldMode.required
            if country in {"US", "CA"}
            else optional_address_mode,
            "line1": full_address_mode,
            "line2": optional_address_mode,
            "city": full_address_mode,
            "postal_code": full_address_mode,
        }

    @property
    def active_trial_interval(self) -> TrialInterval | None:
        """Trial interval in effect: the checkout's own, else the product's.

        ``None`` when trials are not allowed or there is no product.
        """
        if not self.allow_trial:
            return None
        if self.product is None:
            return None
        return self.trial_interval or self.product.trial_interval

    @property
    def active_trial_interval_count(self) -> int | None:
        """Trial interval count in effect: the checkout's own, else the product's.

        ``None`` when trials are not allowed or there is no product.
        """
        if not self.allow_trial:
            return None
        if self.product is None:
            return None
        return self.trial_interval_count or self.product.trial_interval_count

    @property
    def price_per_seat(self) -> int | None:
        """Per-seat price for seat-based prices; ``None`` otherwise or without seats."""
        if not isinstance(self.product_price, ProductPriceSeatUnit):
            return None

        if self.seats is None:
            return None

        return self.product_price.get_price_per_seat(self.seats)

    @property
    def description(self) -> str:
        """``"<organization name> — <product name>"``.

        Raises:
            NotImplementedError: when the checkout has no ``product``.
        """
        if self.product is not None:
            return f"{self.organization.name} — {self.product.name}"
        raise NotImplementedError()

    @property
    def prices(self) -> dict[uuid.UUID, list[ProductPrice]]:
        """Map each checkout product id to the prices offered for it.

        Ad-hoc prices attached to the checkout product take precedence; when
        there are none, the product's own prices are used.
        """
        return {
            checkout_product.product_id: (
                checkout_product.ad_hoc_prices or checkout_product.product.prices
            )
            for checkout_product in self.checkout_products
        }

    @property
    def product_prices(self) -> list[ProductPrice] | None:
        """Prices offered for the selected ``product_id``; ``None`` without one.

        Raises ``KeyError`` if ``product_id`` is not among ``checkout_products``.
        """
        if self.product_id is None:
            return None
        return self.prices[self.product_id]
@event.listens_for(Checkout, "before_update")
def check_expiration(
    mapper: Mapper[Any], connection: Connection, target: Checkout
) -> None:
    """Mark an open checkout as expired when it is flushed past its deadline.

    Runs as a SQLAlchemy ``before_update`` hook on ``Checkout``; only the
    status of ``target`` is changed, and only from ``open`` to ``expired``.
    """
    deadline_passed = target.expires_at < utc_now()
    if not deadline_passed:
        return
    if target.status != CheckoutStatus.open:
        return
    target.status = CheckoutStatus.expired