Coverage for polar/models/product_price.py: 63%
189 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
from decimal import Decimal
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Literal, TypedDict
from uuid import UUID

import stripe as stripe_lib
from babel.numbers import format_currency, format_decimal
from sqlalchemy import (
    Boolean,
    ColumnElement,
    ForeignKey,
    Integer,
    Numeric,
    String,
    Uuid,
    case,
    event,
    func,
    type_coerce,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
    Mapped,
    declared_attr,
    mapped_column,
    object_mapper,
    relationship,
)

from polar.enums import SubscriptionRecurringInterval
from polar.kit.db.models import RecordModel
from polar.kit.extensions.sqlalchemy.types import StringEnum
from polar.kit.math import polar_round

if TYPE_CHECKING:
    from polar.models import CheckoutProduct, Meter, Product
40class ProductPriceType(StrEnum): 1ab
41 one_time = "one_time" 1ab
42 recurring = "recurring" 1ab
44 def as_literal(self) -> Literal["one_time", "recurring"]: 1ab
45 return self.value
48class ProductPriceAmountType(StrEnum): 1ab
49 fixed = "fixed" 1ab
50 custom = "custom" 1ab
51 free = "free" 1ab
52 metered_unit = "metered_unit" 1ab
53 seat_based = "seat_based" 1ab
56class ProductPriceSource(StrEnum): 1ab
57 catalog = "catalog" 1ab
58 ad_hoc = "ad_hoc" 1ab
61class SeatTier(TypedDict): 1ab
62 """A single pricing tier for seat-based pricing."""
64 min_seats: int 1ab
65 max_seats: int | None 1ab
66 price_per_seat: int 1ab
class SeatTiersData(TypedDict):
    """The structure of the seat_tiers JSONB column."""

    # Tiers are scanned in order; the first one whose seat range matches wins.
    tiers: list[SeatTier]
class HasPriceCurrency:
    """Mixin declaring the shared ``price_currency`` column.

    ``use_existing_column=True`` lets several subclasses of the single-table
    ``ProductPrice`` hierarchy declare this column without SQLAlchemy creating
    a duplicate; they all map to the same ``product_prices`` column.
    """

    # Three-character currency code (presumably ISO 4217, upper-cased for
    # display elsewhere — confirm). Nullable at the DB level because not every
    # price type mixes this in (e.g. free prices do not).
    price_currency: Mapped[str] = mapped_column(
        String(3), nullable=True, use_existing_column=True
    )
class HasStripePriceId:
    """Mixin for price types that carry a Stripe price ID.

    Declares the shared ``stripe_price_id`` column and the hook that
    subclasses override to describe the Stripe price to create.
    """

    # Shared, nullable column reused across the single-table hierarchy.
    stripe_price_id: Mapped[str] = mapped_column(
        String, nullable=True, use_existing_column=True
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Return the parameters for ``stripe.Price.create`` for this price.

        Args:
            recurring_interval: Billing interval to attach, or ``None`` for a
                one-time price.

        Raises:
            NotImplementedError: Always; concrete price mixins must override.
        """
        raise NotImplementedError()
# Prefix prepended to ``amount_type`` to build the polymorphic identity of rows
# that still carry a legacy per-price ``type`` value (see the ``polymorphic_on``
# expression of ``ProductPrice`` and the ``set_identity`` init listener).
LEGACY_IDENTITY_PREFIX = "legacy_"
class ProductPrice(RecordModel):
    """Base ORM model for every price attached to a product.

    All price kinds share the single ``product_prices`` table. The concrete
    subclass for a row is selected by the ``polymorphic_on`` expression in
    ``__mapper_args__``: when the legacy ``type`` column is NULL the identity
    is ``amount_type`` itself, otherwise it is ``"legacy_" + amount_type``.
    """

    __tablename__ = "product_prices"

    # Legacy: recurring is now set on product
    type: Mapped[Any] = mapped_column(String, nullable=True, index=True, default=None)
    recurring_interval: Mapped[Any] = mapped_column(
        StringEnum(SubscriptionRecurringInterval),
        nullable=True,
        index=True,
        default=None,
    )

    # Annotated like every other column; the explicit StringEnum type and
    # nullable=False still drive the column definition.
    source: Mapped[ProductPriceSource] = mapped_column(
        StringEnum(ProductPriceSource),
        nullable=False,
        index=True,
        default=ProductPriceSource.catalog,
    )
    amount_type: Mapped[ProductPriceAmountType] = mapped_column(
        String, nullable=False, index=True
    )
    is_archived: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    product_id: Mapped[UUID] = mapped_column(
        Uuid, ForeignKey("products.id", ondelete="cascade"), nullable=False, index=True
    )

    checkout_product_id: Mapped[UUID | None] = mapped_column(
        Uuid,
        ForeignKey("checkout_products.id", ondelete="set null"),
        nullable=True,
        index=True,
        default=None,
    )
    """
    Foreign key to the CheckoutProduct this price is associated with, if any.

    Used for ad-hoc prices created on-demand for checkout sessions.
    """

    @declared_attr
    def product(cls) -> Mapped["Product"]:
        # raise_on_sql: callers must load the product explicitly.
        return relationship("Product", lazy="raise_on_sql", back_populates="all_prices")

    @declared_attr
    def checkout_product(cls) -> Mapped["CheckoutProduct | None"]:
        # The annotation must name the relationship target (CheckoutProduct),
        # not Product; the explicit "CheckoutProduct" argument drives mapping.
        return relationship(
            "CheckoutProduct", lazy="raise_on_sql", back_populates="ad_hoc_prices"
        )

    @hybrid_property
    def is_recurring(self) -> bool:
        """Whether the legacy ``type`` column marks this price as recurring."""
        return self.type == ProductPriceType.recurring

    @is_recurring.inplace.expression
    @classmethod
    def _is_recurring_expression(cls) -> ColumnElement[bool]:
        return type_coerce(cls.type == ProductPriceType.recurring, Boolean)

    @hybrid_property
    def is_static(self) -> bool:
        """Whether the amount is known up front (not usage-metered)."""
        return self.amount_type in {
            ProductPriceAmountType.fixed,
            ProductPriceAmountType.free,
            ProductPriceAmountType.custom,
            ProductPriceAmountType.seat_based,
        }

    @is_static.inplace.expression
    @classmethod
    def _is_static_price_expression(cls) -> ColumnElement[bool]:
        return cls.amount_type.in_(
            (
                ProductPriceAmountType.fixed,
                ProductPriceAmountType.free,
                ProductPriceAmountType.custom,
                ProductPriceAmountType.seat_based,
            )
        )

    @hybrid_property
    def is_metered(self) -> bool:
        """Whether the amount is computed from consumed meter units."""
        return self.amount_type in {ProductPriceAmountType.metered_unit}

    @is_metered.inplace.expression
    @classmethod
    def _is_metered_price_expression(cls) -> ColumnElement[bool]:
        return cls.amount_type.in_((ProductPriceAmountType.metered_unit,))

    @property
    def legacy_type(self) -> ProductPriceType | None:
        """Cadence derived from the owning product.

        Requires ``product`` to be loaded already (the relationship is
        ``raise_on_sql``).
        """
        if self.product.is_recurring:
            return ProductPriceType.recurring
        return ProductPriceType.one_time

    @property
    def legacy_recurring_interval(self) -> SubscriptionRecurringInterval | None:
        """The owning product's recurring interval (``product`` must be loaded)."""
        return self.product.recurring_interval

    # Discriminator: NULL ``type`` -> identity is ``amount_type``; otherwise the
    # row is a legacy recurring price and gets the ``legacy_`` prefix.
    __mapper_args__ = {
        "polymorphic_on": case(
            (type.is_(None), amount_type),
            else_=func.concat(LEGACY_IDENTITY_PREFIX, amount_type),
        )
    }
class LegacyRecurringProductPrice:
    """Mixin for price rows that still use the legacy per-price cadence columns.

    Re-declares ``type`` and ``recurring_interval`` against the existing
    ``product_prices`` columns (``use_existing_column=True``) with the precise
    types legacy rows hold.
    """

    __abstract__ = True

    type: Mapped[ProductPriceType] = mapped_column(
        use_existing_column=True, nullable=True
    )
    recurring_interval: Mapped[SubscriptionRecurringInterval] = mapped_column(
        use_existing_column=True, nullable=True
    )

    # NOTE(review): the concrete legacy subclasses in this module define their
    # own ``__mapper_args__``, which shadows this dict through normal attribute
    # lookup — confirm whether these entries are still relied upon.
    __mapper_args__ = {
        "polymorphic_abstract": True,
        "polymorphic_load": "inline",
    }
class NewProductPrice:
    """Mixin for current-model price rows.

    Recurrence now lives on the product, so ``type`` and
    ``recurring_interval`` are re-declared as always ``None`` for these rows.
    """

    __abstract__ = True

    type: Mapped[Literal[None]] = mapped_column(
        use_existing_column=True, nullable=True, default=None
    )
    recurring_interval: Mapped[Literal[None]] = mapped_column(
        use_existing_column=True, nullable=True, default=None
    )

    # NOTE(review): concrete subclasses define their own ``__mapper_args__``,
    # which shadows this dict — confirm whether it is still needed.
    __mapper_args__ = {
        "polymorphic_abstract": True,
        "polymorphic_load": "inline",
    }
class _ProductPriceFixed(HasStripePriceId, HasPriceCurrency, ProductPrice):
    """Shared columns and Stripe mapping for fixed-amount prices.

    Abstract in the polymorphic hierarchy; the concrete classes in this module
    are ``ProductPriceFixed`` and ``LegacyRecurringProductPriceFixed``.
    """

    price_amount: Mapped[int] = mapped_column(Integer, nullable=True)
    amount_type: Mapped[Literal[ProductPriceAmountType.fixed]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.fixed
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Build Stripe ``Price.create`` params charging ``price_amount``.

        A ``recurring`` entry is added only when ``recurring_interval`` is set.
        """
        stripe_params: stripe_lib.Price.CreateParams = {
            "unit_amount": self.price_amount,
            "currency": self.price_currency,
        }
        if recurring_interval is None:
            return stripe_params

        stripe_params["recurring"] = {"interval": recurring_interval.as_literal()}
        return stripe_params

    __mapper_args__ = {
        "polymorphic_abstract": True,
        "polymorphic_load": "inline",
    }
class ProductPriceFixed(NewProductPrice, _ProductPriceFixed):
    """Fixed-amount price for rows whose legacy ``type`` column is NULL.

    Its polymorphic identity is the bare ``amount_type`` value ``"fixed"``.
    """

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.fixed,
        "polymorphic_load": "inline",
    }
class LegacyRecurringProductPriceFixed(LegacyRecurringProductPrice, _ProductPriceFixed):
    """Fixed-amount price for rows that still carry a legacy ``type`` value.

    Its polymorphic identity is ``"legacy_fixed"``.
    """

    __mapper_args__ = {
        "polymorphic_identity": f"{LEGACY_IDENTITY_PREFIX}{ProductPriceAmountType.fixed}",
        "polymorphic_load": "inline",
    }
class _ProductPriceCustom(HasStripePriceId, HasPriceCurrency, ProductPrice):
    """Shared columns and Stripe mapping for pay-what-you-want prices.

    Abstract in the polymorphic hierarchy; ``minimum_amount``,
    ``maximum_amount`` and ``preset_amount`` are all optional.
    """

    amount_type: Mapped[Literal[ProductPriceAmountType.custom]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.custom
    )
    minimum_amount: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )
    maximum_amount: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )
    preset_amount: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Build Stripe ``Price.create`` params with a custom unit amount.

        Only the bounds / preset that are set are forwarded to Stripe.
        """
        # `recurring_interval` is unused because we actually create ad-hoc prices,
        # since Stripe doesn't support PWYW pricing for subscriptions.
        custom_amount: stripe_lib.Price.CreateParamsCustomUnitAmount = {
            "enabled": True,
        }
        if self.minimum_amount is not None:
            custom_amount["minimum"] = self.minimum_amount
        if self.maximum_amount is not None:
            custom_amount["maximum"] = self.maximum_amount
        if self.preset_amount is not None:
            custom_amount["preset"] = self.preset_amount

        stripe_params: stripe_lib.Price.CreateParams = {
            "currency": self.price_currency,
            "custom_unit_amount": custom_amount,
        }
        return stripe_params

    __mapper_args__ = {
        "polymorphic_abstract": True,
        "polymorphic_load": "inline",
    }
class ProductPriceCustom(NewProductPrice, _ProductPriceCustom):
    """Pay-what-you-want price for rows whose legacy ``type`` column is NULL.

    Its polymorphic identity is the bare ``amount_type`` value ``"custom"``.
    """

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.custom,
        "polymorphic_load": "inline",
    }
class LegacyRecurringProductPriceCustom(
    LegacyRecurringProductPrice, _ProductPriceCustom
):
    """Pay-what-you-want price for rows that still carry a legacy ``type`` value.

    Its polymorphic identity is ``"legacy_custom"``.
    """

    __mapper_args__ = {
        "polymorphic_identity": f"{LEGACY_IDENTITY_PREFIX}{ProductPriceAmountType.custom}",
        "polymorphic_load": "inline",
    }
class _ProductPriceFree(HasStripePriceId, ProductPrice):
    """Shared mapping for free prices (abstract in the polymorphic hierarchy).

    This class does not mix in ``HasPriceCurrency``; the Stripe currency is
    hard-coded instead.
    """

    amount_type: Mapped[Literal[ProductPriceAmountType.free]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.free
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Build Stripe ``Price.create`` params for a zero-amount ``usd`` price.

        A ``recurring`` entry is added only when ``recurring_interval`` is set.
        """
        stripe_params: stripe_lib.Price.CreateParams = {
            "unit_amount": 0,
            "currency": "usd",
        }
        if recurring_interval is not None:
            stripe_params["recurring"] = {"interval": recurring_interval.as_literal()}
        return stripe_params

    __mapper_args__ = {
        "polymorphic_abstract": True,
        "polymorphic_load": "inline",
    }
class ProductPriceFree(NewProductPrice, _ProductPriceFree):
    """Free price for rows whose legacy ``type`` column is NULL.

    Its polymorphic identity is the bare ``amount_type`` value ``"free"``.
    """

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.free,
        "polymorphic_load": "inline",
    }
class LegacyRecurringProductPriceFree(LegacyRecurringProductPrice, _ProductPriceFree):
    """Free price for rows that still carry a legacy ``type`` value.

    Its polymorphic identity is ``"legacy_free"``.
    """

    __mapper_args__ = {
        "polymorphic_identity": f"{LEGACY_IDENTITY_PREFIX}{ProductPriceAmountType.free}",
        "polymorphic_load": "inline",
    }
class ProductPriceMeteredUnit(ProductPrice, HasPriceCurrency, NewProductPrice):
    """Usage-based price: a per-unit amount multiplied by consumed units.

    Amounts appear to be in minor currency units (they are divided by 100 for
    display). An optional ``cap_amount`` bounds the computed amount.
    """

    amount_type: Mapped[Literal[ProductPriceAmountType.metered_unit]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.metered_unit
    )
    unit_amount: Mapped[Decimal] = mapped_column(
        Numeric(17, 12),  # 12 decimal places, 17 digits total
        # Polymorphic columns must be nullable, as they don't apply to other types
        nullable=True,
    )
    cap_amount: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
    meter_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("meters.id"),
        # Polymorphic columns must be nullable, as they don't apply to other types
        nullable=True,
        index=True,
    )

    @declared_attr
    def meter(cls) -> Mapped["Meter"]:
        # For convenience, eager load it, as it's embedded in all schemas outputting a price
        return relationship("Meter", lazy="joined")

    def get_amount_and_label(self, units: float) -> tuple[int, str]:
        """Compute the amount due for ``units`` consumed units and a display label.

        Args:
            units: Number of consumed units. Negative values are billed as zero,
                although the label shows the value as given.

        Returns:
            ``(amount, label)``. ``amount`` is ``unit_amount * units`` rounded
            with ``polar_round`` and clamped to ``cap_amount`` when set. The
            ``label`` looks like ``"(10 consumed units) × $0.50"``, with
            ``" — Capped at $4.00"`` appended when the cap applies.
        """
        currency = self.price_currency.upper()
        units_label = format_decimal(units, locale="en_US")
        unit_price_label = format_currency(
            self.unit_amount / 100, currency, locale="en_US"
        )
        label = f"({units_label} consumed units) × {unit_price_label}"

        # Negative usage is never billed.
        billable_units = Decimal(max(0, units))
        amount = polar_round(self.unit_amount * billable_units)

        if self.cap_amount is not None and amount > self.cap_amount:
            amount = self.cap_amount
            cap_label = format_currency(self.cap_amount / 100, currency, locale="en_US")
            # Spaces on both sides of the dash; previously the dash was glued
            # to the unit price ("$0.50— Capped at ...").
            label += f" — Capped at {cap_label}"

        return amount, label

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.metered_unit,
        "polymorphic_load": "inline",
    }
class ProductPriceSeatUnit(NewProductPrice, HasPriceCurrency, ProductPrice):
    """Seat-based price whose per-seat amount depends on the seat count.

    Pricing tiers are stored in the ``seat_tiers`` JSONB column.
    """

    amount_type: Mapped[Literal[ProductPriceAmountType.seat_based]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.seat_based
    )
    # Nullable at the DB level because the column is shared by every price type
    # in the table; presumably always set for seat-based prices.
    seat_tiers: Mapped[SeatTiersData] = mapped_column(
        postgresql.JSONB,
        nullable=True,
    )

    def get_tier_for_seats(self, seats: int) -> SeatTier:
        """Return the first tier whose seat range contains ``seats``.

        A tier matches when ``min_seats <= seats`` and either ``max_seats`` is
        absent/``None`` (open-ended) or ``seats <= max_seats``.

        Raises:
            ValueError: If no tier matches, including when ``seat_tiers`` is
                unset (NULL in the database).
        """
        # The column is nullable, so guard against None instead of failing
        # with an AttributeError on ``.get``.
        seat_tiers: SeatTiersData | None = self.seat_tiers
        tiers = seat_tiers.get("tiers", []) if seat_tiers is not None else []
        for tier in tiers:
            min_seats = tier["min_seats"]
            max_seats = tier.get("max_seats")
            if seats >= min_seats and (max_seats is None or seats <= max_seats):
                return tier
        raise ValueError(f"No tier found for {seats} seats")

    def get_price_per_seat(self, seats: int) -> int:
        """Return the per-seat price of the tier matching ``seats``.

        Raises:
            ValueError: If no tier matches.
        """
        tier = self.get_tier_for_seats(seats)
        return tier["price_per_seat"]

    def calculate_amount(self, seats: int) -> int:
        """Return the total for ``seats`` seats.

        Every seat is charged at the matching tier's rate; this is not a
        graduated calculation across tiers.
        """
        return self.get_price_per_seat(seats) * seats

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.seat_based,
        "polymorphic_load": "inline",
    }
@event.listens_for(ProductPrice, "init", propagate=True)
def set_identity(instance: ProductPrice, *arg: Any, **kw: Any) -> None:
    """Fill the discriminator columns of a newly constructed price.

    Registered on the ORM ``init`` event for ``ProductPrice`` and, through
    ``propagate=True``, all of its subclasses. It derives ``amount_type`` from
    the mapper's polymorphic identity and strips the ``legacy_`` prefix when
    present. For non-legacy identities it also clears ``type`` and
    ``recurring_interval``. When the mapper has no identity, nothing changes.
    """
    identity: str | None = object_mapper(instance).polymorphic_identity
    if identity is None:
        return

    if not identity.startswith(LEGACY_IDENTITY_PREFIX):
        # Current-model prices carry no per-price cadence.
        instance.type = None
        instance.recurring_interval = None
    else:
        identity = identity.removeprefix(LEGACY_IDENTITY_PREFIX)

    instance.amount_type = ProductPriceAmountType(identity)