Coverage for polar/models/subscription.py: 61%
211 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1from collections.abc import Sequence 1ab
2from datetime import UTC, datetime 1ab
3from enum import StrEnum 1ab
4from typing import TYPE_CHECKING, Self 1ab
5from uuid import UUID 1ab
7from sqlalchemy import ( 1ab
8 TIMESTAMP,
9 Boolean,
10 ColumnElement,
11 ForeignKey,
12 Integer,
13 String,
14 Text,
15 Uuid,
16 event,
17 type_coerce,
18)
19from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy 1ab
20from sqlalchemy.ext.hybrid import hybrid_property 1ab
21from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship 1ab
22from sqlalchemy.orm.attributes import OP_BULK_REPLACE, Event 1ab
24from polar.custom_field.data import CustomFieldDataMixin 1ab
25from polar.enums import SubscriptionRecurringInterval 1ab
26from polar.kit.db.models import RecordModel 1ab
27from polar.kit.extensions.sqlalchemy.types import StringEnum 1ab
28from polar.kit.metadata import MetadataMixin 1ab
29from polar.product.guard import is_metered_price 1ab
31from .product_price import HasPriceCurrency 1ab
32from .subscription_meter import SubscriptionMeter 1ab
34if TYPE_CHECKING: 34 ↛ 35line 34 didn't jump to line 35 because the condition on line 34 was never true1ab
35 from . import (
36 BenefitGrant,
37 Checkout,
38 Customer,
39 CustomerSeat,
40 Discount,
41 Meter,
42 Organization,
43 PaymentMethod,
44 Product,
45 ProductPrice,
46 SubscriptionProductPrice,
47 )
50class SubscriptionStatus(StrEnum): 1ab
51 incomplete = "incomplete" 1ab
52 incomplete_expired = "incomplete_expired" 1ab
53 trialing = "trialing" 1ab
54 active = "active" 1ab
55 past_due = "past_due" 1ab
56 canceled = "canceled" 1ab
57 unpaid = "unpaid" 1ab
59 @classmethod 1ab
60 def incomplete_statuses(cls) -> set[Self]: 1ab
61 return {cls.incomplete, cls.incomplete_expired} # type: ignore
63 @classmethod 1ab
64 def active_statuses(cls) -> set[Self]: 1ab
65 return {cls.trialing, cls.active} # type: ignore
67 @classmethod 1ab
68 def revoked_statuses(cls) -> set[Self]: 1ab
69 return {cls.past_due, cls.canceled, cls.unpaid} # type: ignore
71 @classmethod 1ab
72 def billable_statuses(cls) -> set[Self]: 1ab
73 return cls.active_statuses() | {cls.past_due} # type: ignore
75 @classmethod 1ab
76 def is_incomplete(cls, status: Self) -> bool: 1ab
77 return status in cls.incomplete_statuses()
79 @classmethod 1ab
80 def is_active(cls, status: Self) -> bool: 1ab
81 return status in cls.active_statuses()
83 @classmethod 1ab
84 def is_revoked(cls, status: Self) -> bool: 1ab
85 return status in cls.revoked_statuses()
87 @classmethod 1ab
88 def is_billable(cls, status: Self) -> bool: 1ab
89 return status in cls.billable_statuses()
class CustomerCancellationReason(StrEnum):
    """Reason values a customer can attach to a subscription cancellation.

    Each member's value is the same string as its name.
    """

    customer_service = "customer_service"
    low_quality = "low_quality"
    missing_features = "missing_features"
    switched_service = "switched_service"
    too_complex = "too_complex"
    too_expensive = "too_expensive"
    unused = "unused"
    other = "other"
class Subscription(CustomFieldDataMixin, MetadataMixin, RecordModel):
    """ORM model mapped to the ``subscriptions`` table.

    Holds the billed amount and currency, the recurring interval, the current
    status and lifecycle timestamps, and relationships to the customer,
    product, prices, discount, meters, checkout, benefit grants and seats.

    ``amount``, ``currency`` and ``meters`` are derived from the attached
    prices and discount through :meth:`update_amount_and_currency` and
    :meth:`update_meters`.
    """

    __tablename__ = "subscriptions"

    amount: Mapped[int] = mapped_column(Integer, nullable=False)
    currency: Mapped[str] = mapped_column(String(3), nullable=False)
    recurring_interval: Mapped[SubscriptionRecurringInterval] = mapped_column(
        StringEnum(SubscriptionRecurringInterval), nullable=False, index=True
    )
    recurring_interval_count: Mapped[int] = mapped_column(Integer, nullable=False)

    stripe_subscription_id: Mapped[str | None] = mapped_column(
        String, nullable=True, index=True, default=None
    )
    """
    The ID of the subscription in Stripe.

    If set, indicates that the subscription is managed by Stripe Billing.
    """
    legacy_stripe_subscription_id: Mapped[str | None] = mapped_column(
        String, nullable=True, index=True, default=None
    )
    """
    Original ID of the subscription in Stripe.

    If set, indicates that the subscription was originally managed by Stripe Billing,
    but has been migrated to be managed by Polar.
    """

    tax_exempted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    """
    Whether the subscription is tax exempted.

    We use this to disable tax on subscriptions created before we were
    registered in a given country, so we don't surprise customers with
    tax charges.
    """

    status: Mapped[SubscriptionStatus] = mapped_column(
        StringEnum(SubscriptionStatus), nullable=False
    )
    current_period_start: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False
    )
    current_period_end: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )
    trial_start: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )
    trial_end: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )
    cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False)
    canceled_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    started_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    ends_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    ended_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    past_due_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )

    scheduler_locked_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )

    customer_id: Mapped[UUID] = mapped_column(
        Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False, index=True
    )

    @declared_attr
    def customer(cls) -> Mapped["Customer"]:
        """Relationship to the owning ``Customer`` (``lazy="raise"``)."""
        return relationship("Customer", lazy="raise")

    payment_method_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("payment_methods.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def payment_method(cls) -> Mapped["PaymentMethod | None"]:
        """Relationship to the optional ``PaymentMethod`` (``lazy="raise"``)."""
        return relationship("PaymentMethod", lazy="raise")

    product_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("products.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def product(cls) -> Mapped["Product"]:
        """Relationship to the subscribed ``Product`` (``lazy="raise"``)."""
        return relationship("Product", lazy="raise")

    subscription_product_prices: Mapped[list["SubscriptionProductPrice"]] = (
        relationship(
            "SubscriptionProductPrice",
            back_populates="subscription",
            cascade="all, delete-orphan",
            # Prices are almost always needed, so eager loading makes sense
            lazy="selectin",
        )
    )

    # Proxy exposing the ``product_price`` of each SubscriptionProductPrice.
    prices: AssociationProxy[list["ProductPrice"]] = association_proxy(
        "subscription_product_prices", "product_price"
    )

    discount_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("discounts.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def discount(cls) -> Mapped["Discount | None"]:
        """Relationship to the optional ``Discount`` (``lazy="joined"``)."""
        return relationship("Discount", lazy="joined")

    meters: Mapped[list[SubscriptionMeter]] = relationship(
        SubscriptionMeter,
        order_by="SubscriptionMeter.created_at",
        back_populates="subscription",
        cascade="all, delete-orphan",
        # Eager load
        lazy="selectin",
    )

    # Proxy exposing the ``organization`` attribute of the related product.
    organization: AssociationProxy["Organization"] = association_proxy(
        "product", "organization"
    )

    checkout_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("checkouts.id", ondelete="set null"), nullable=True, index=True
    )

    customer_cancellation_reason: Mapped[CustomerCancellationReason | None] = (
        mapped_column(String, nullable=True)
    )
    customer_cancellation_comment: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )

    seats: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)

    @declared_attr
    def checkout(cls) -> Mapped["Checkout | None"]:
        """Relationship to the optional ``Checkout``, joined on ``checkout_id``."""
        return relationship(
            "Checkout",
            lazy="raise",
            foreign_keys=[cls.checkout_id],  # type: ignore
        )

    @declared_attr
    def grants(cls) -> Mapped[list["BenefitGrant"]]:
        """Relationship to ``BenefitGrant`` rows, ordered by ``benefit_id``."""
        return relationship(
            "BenefitGrant",
            lazy="raise",
            order_by="BenefitGrant.benefit_id",
            back_populates="subscription",
        )

    @declared_attr
    def customer_seats(cls) -> Mapped[list["CustomerSeat"]]:
        """Relationship to ``CustomerSeat`` rows, cascading deletes to orphans."""
        return relationship(
            "CustomerSeat",
            lazy="raise",
            back_populates="subscription",
            cascade="all, delete-orphan",
        )

    def is_incomplete(self) -> bool:
        """Return whether the status is one of the incomplete statuses."""
        return SubscriptionStatus.is_incomplete(self.status)

    @hybrid_property
    def trialing(self) -> bool:
        """Whether the status is ``trialing``; also usable in SQL expressions."""
        return self.status == SubscriptionStatus.trialing

    @trialing.inplace.expression
    @classmethod
    def _trialing_expression(cls) -> ColumnElement[bool]:
        return cls.status == SubscriptionStatus.trialing

    @hybrid_property
    def active(self) -> bool:
        """Whether the status is one of the active statuses."""
        return SubscriptionStatus.is_active(self.status)

    @active.inplace.expression
    @classmethod
    def _active_expression(cls) -> ColumnElement[bool]:
        return type_coerce(
            cls.status.in_(SubscriptionStatus.active_statuses()),
            Boolean,
        )

    @hybrid_property
    def revoked(self) -> bool:
        """Whether the status is one of the revoked statuses."""
        return SubscriptionStatus.is_revoked(self.status)

    @revoked.inplace.expression
    @classmethod
    def _revoked_expression(cls) -> ColumnElement[bool]:
        return type_coerce(
            cls.status.in_(SubscriptionStatus.revoked_statuses()),
            Boolean,
        )

    @hybrid_property
    def canceled(self) -> bool:
        """Whether ``canceled_at`` is set (independent of ``status``)."""
        return self.canceled_at is not None

    @canceled.inplace.expression
    @classmethod
    def _canceled_expression(cls) -> ColumnElement[bool]:
        return cls.canceled_at.is_not(None)

    @hybrid_property
    def billable(self) -> bool:
        """Whether the status is one of the billable statuses."""
        return SubscriptionStatus.is_billable(self.status)

    @billable.inplace.expression
    @classmethod
    def _billable_expression(cls) -> ColumnElement[bool]:
        return type_coerce(
            cls.status.in_(SubscriptionStatus.billable_statuses()),
            Boolean,
        )

    def can_cancel(self, immediately: bool = False) -> bool:
        """Return whether the subscription can be canceled.

        A subscription whose status is not billable, or whose ``ended_at`` is
        set, can never be canceled. Otherwise an immediate cancellation is
        always allowed, while a non-immediate one is refused when
        ``cancel_at_period_end`` or ``ends_at`` is already set.

        Args:
            immediately: Whether the cancellation takes effect right away.
        """
        if not SubscriptionStatus.is_billable(self.status):
            return False

        if self.ended_at:
            return False

        if immediately:
            return True

        if self.cancel_at_period_end or self.ends_at:
            return False
        return True

    def can_uncancel(self) -> bool:
        """Return whether a pending cancellation can be reverted.

        True when ``cancel_at_period_end`` is set and the status is billable.
        """
        return (
            self.cancel_at_period_end
            and self.status in SubscriptionStatus.billable_statuses()
        )

    def set_started_at(self) -> None:
        """
        Stores the starting date when the subscription
        becomes active for the first time.
        """
        if self.active and self.started_at is None:
            self.started_at = datetime.now(UTC)

    def update_amount_and_currency(
        self, prices: Sequence["SubscriptionProductPrice"], discount: "Discount | None"
    ) -> None:
        """Recompute ``amount`` and ``currency`` from prices and a discount.

        ``amount`` becomes the sum of the price amounts, minus
        ``discount.get_discount_amount(total)`` when a discount is given.
        ``currency`` is taken from the prices whose ``product_price`` is a
        ``HasPriceCurrency``; it falls back to ``"usd"`` when there is none.

        Args:
            prices: The subscription prices to aggregate.
            discount: Optional discount applied to the summed amount.

        Raises:
            ValueError: If the prices carry more than one distinct currency.
        """
        amount = sum(price.amount for price in prices)
        if discount is not None:
            amount -= discount.get_discount_amount(amount)
        self.amount = amount

        currencies = {
            price.product_price.price_currency
            for price in prices
            if isinstance(price.product_price, HasPriceCurrency)
        }
        if not currencies:
            self.currency = "usd"  # FIXME: Main Polar currency
        elif len(currencies) == 1:
            self.currency = currencies.pop()
        else:
            raise ValueError("Multiple currencies in subscription prices")

    def update_meters(self, prices: Sequence["SubscriptionProductPrice"]) -> None:
        """Synchronize ``meters`` with the meters of the given metered prices.

        A ``SubscriptionMeter`` is added for every metered price whose meter
        is not tracked yet, and every tracked meter that no longer backs any
        of the prices is removed.

        Args:
            prices: The subscription prices the meters must reflect.
        """
        subscription_meters = self.meters or []

        # Add new ones
        price_meters = [
            price.product_price.meter
            for price in prices
            if is_metered_price(price.product_price)
        ]
        for price_meter in price_meters:
            # Only create a SubscriptionMeter if none tracks this meter yet
            if not any(sm.meter == price_meter for sm in subscription_meters):
                subscription_meters.append(SubscriptionMeter(meter=price_meter))

        # Remove old ones
        # Iterate over a snapshot: removing from the list being iterated would
        # skip the element following each removal and leave stale meters behind.
        for subscription_meter in list(subscription_meters):
            if subscription_meter.meter not in price_meters:
                subscription_meters.remove(subscription_meter)

        self.meters = subscription_meters

    def get_meter(self, meter: "Meter") -> SubscriptionMeter | None:
        """Return the ``SubscriptionMeter`` whose ``meter_id`` matches ``meter.id``.

        Returns ``None`` when the subscription does not track that meter.
        """
        for subscription_meter in self.meters:
            if subscription_meter.meter_id == meter.id:
                return subscription_meter
        return None
@event.listens_for(Subscription.subscription_product_prices, "bulk_replace")
def _prices_replaced(
    target: Subscription, values: list["SubscriptionProductPrice"], initiator: Event
) -> None:
    """Refresh amount, currency and meters when the price list is bulk-replaced.

    Registered for the ``bulk_replace`` event of
    ``Subscription.subscription_product_prices``; ``values`` is the new list.
    """
    current_discount = target.discount
    target.update_amount_and_currency(values, current_discount)
    target.update_meters(values)
@event.listens_for(Subscription.subscription_product_prices, "append")
def _price_appended(
    target: Subscription, value: "SubscriptionProductPrice", initiator: Event
) -> None:
    """Refresh amount, currency and meters when a single price is appended.

    Registered for the ``append`` event of
    ``Subscription.subscription_product_prices``.
    """
    # Appends that are part of a bulk replace are skipped: the bulk replace
    # handler updates everything as a whole. Handling them one by one could
    # delete, on the first append, a meter that later appends still need.
    triggered_by_bulk_replace = (
        initiator is not None and initiator.op is OP_BULK_REPLACE
    )
    if triggered_by_bulk_replace:
        return

    # The appended value is added explicitly to the current prices.
    prices_after_append = [*target.subscription_product_prices, value]
    target.update_amount_and_currency(prices_after_append, target.discount)
    target.update_meters(prices_after_append)
@event.listens_for(Subscription.discount, "set")
def _discount_set(
    target: Subscription,
    value: "Discount | None",
    oldvalue: "Discount | None",
    initiator: Event,
) -> None:
    """Refresh amount and currency when the subscription's discount is set.

    Registered for the ``set`` event of ``Subscription.discount``; the
    incoming ``value`` is applied to the subscription's current prices.
    """
    current_prices = target.subscription_product_prices
    target.update_amount_and_currency(current_prices, value)