Coverage for polar/models/discount.py: 66%
93 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1from datetime import datetime 1ab
2from enum import StrEnum 1ab
3from typing import TYPE_CHECKING, Literal, cast 1ab
4from uuid import UUID 1ab
6import stripe as stripe_lib 1ab
7from dateutil.relativedelta import relativedelta 1ab
8from sqlalchemy import ( 1ab
9 TIMESTAMP,
10 Column,
11 ForeignKey,
12 Index,
13 Integer,
14 String,
15 Uuid,
16 func,
17 select,
18)
19from sqlalchemy.dialects.postgresql import CITEXT 1ab
20from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy 1ab
21from sqlalchemy.orm import ( 1ab
22 Mapped,
23 column_property,
24 declared_attr,
25 mapped_column,
26 relationship,
27)
29from polar.kit.db.models import RecordModel 1ab
30from polar.kit.math import polar_round 1ab
31from polar.kit.metadata import MetadataMixin 1ab
33if TYPE_CHECKING: 33 ↛ 34line 33 didn't jump to line 34 because the condition on line 33 was never true1ab
34 from . import DiscountProduct, DiscountRedemption, Organization, Product
37class DiscountType(StrEnum): 1ab
38 fixed = "fixed" 1ab
39 percentage = "percentage" 1ab
41 def get_model(self) -> type["Discount"]: 1ab
42 return {
43 DiscountType.fixed: DiscountFixed,
44 DiscountType.percentage: DiscountPercentage,
45 }[self]
48class DiscountDuration(StrEnum): 1ab
49 once = "once" 1ab
50 forever = "forever" 1ab
51 repeating = "repeating" 1ab
54class Discount(MetadataMixin, RecordModel): 1ab
55 __tablename__ = "discounts" 1ab
57 @declared_attr.directive 1ab
58 def __table_args__(cls) -> tuple[Index]: 1ab
59 # During tests this function is called multiple times which ends up adding the index
60 # multiple times -- leading to errors. We memoize this function to ensure we end up with
61 # the index just once.
62 if not hasattr(cls, "_memoized_indexes"): 1ab
63 _deleted_at_column = cast( 1ab
64 Column[datetime | None], cls.deleted_at
65 ) # cast to satisfy mypy
66 cls._memoized_indexes = ( 1ab
67 Index(
68 "ix_discounts_code_uniqueness",
69 "organization_id",
70 func.lower(cls.code),
71 unique=True,
72 # partial index
73 postgresql_where=(_deleted_at_column.is_(None)),
74 ),
75 )
76 return cls._memoized_indexes 1ab
78 name: Mapped[str] = mapped_column(CITEXT, nullable=False) 1ab
79 type: Mapped[DiscountType] = mapped_column(String, nullable=False) 1ab
80 code: Mapped[str | None] = mapped_column(String, nullable=True, index=True) 1ab
82 starts_at: Mapped[datetime | None] = mapped_column( 1ab
83 TIMESTAMP(timezone=True), nullable=True
84 )
85 ends_at: Mapped[datetime | None] = mapped_column( 1ab
86 TIMESTAMP(timezone=True), nullable=True
87 )
88 max_redemptions: Mapped[int | None] = mapped_column(Integer, nullable=True) 1ab
90 duration: Mapped[DiscountDuration] = mapped_column(String, nullable=False) 1ab
91 duration_in_months: Mapped[int | None] = mapped_column(Integer, nullable=True) 1ab
93 stripe_coupon_id: Mapped[str] = mapped_column( 1ab
94 String, nullable=False, unique=True, index=True
95 )
97 organization_id: Mapped[UUID] = mapped_column( 1ab
98 Uuid,
99 ForeignKey("organizations.id", ondelete="cascade"),
100 nullable=False,
101 index=True,
102 )
104 @declared_attr 1ab
105 def organization(cls) -> Mapped["Organization"]: 1ab
106 return relationship("Organization", lazy="raise") 1ab
108 discount_redemptions: Mapped[list["DiscountRedemption"]] = relationship( 1ab
109 "DiscountRedemption", back_populates="discount", lazy="raise"
110 )
112 discount_products: Mapped[list["DiscountProduct"]] = relationship( 1ab
113 "DiscountProduct",
114 back_populates="discount",
115 cascade="all, delete-orphan",
116 # Products are almost always needed, so eager loading makes sense
117 lazy="selectin",
118 )
120 products: AssociationProxy[list["Product"]] = association_proxy( 1ab
121 "discount_products", "product"
122 )
124 @declared_attr 1ab
125 def redemptions_count(cls) -> Mapped[int]: 1ab
126 from .discount_redemption import DiscountRedemption 1ab
128 return column_property( 1ab
129 select(func.count(DiscountRedemption.id))
130 .where(DiscountRedemption.discount_id == cls.id)
131 .correlate_except(DiscountRedemption)
132 .scalar_subquery()
133 )
135 def get_discount_amount(self, amount: int) -> int: 1ab
136 raise NotImplementedError()
138 def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams: 1ab
139 params: stripe_lib.Coupon.CreateParams = {
140 "name": self.name[:40],
141 "duration": cast(Literal["once", "forever", "repeating"], self.duration),
142 "metadata": {
143 "discount_id": str(self.id),
144 "organization_id": str(self.organization.id),
145 },
146 }
147 if self.duration_in_months is not None:
148 params["duration_in_months"] = self.duration_in_months
149 return params
151 def is_applicable(self, product: "Product") -> bool: 1ab
152 if len(self.products) == 0:
153 return True
154 return product in self.products
156 def is_repetition_expired( 1ab
157 self,
158 started_at: datetime,
159 current_period_start: datetime,
160 trial_ended: bool = False,
161 ) -> bool:
162 if self.duration == DiscountDuration.once:
163 # If transitioning from trial to active, this is the first billed cycle
164 # so the discount should still apply
165 return not trial_ended
166 if self.duration == DiscountDuration.forever:
167 return False
168 if self.duration_in_months is None:
169 return False
171 # -1 because the first month counts as a first repetition
172 end_at = started_at + relativedelta(months=self.duration_in_months - 1)
173 return current_period_start > end_at
175 __mapper_args__ = { 1ab
176 "polymorphic_on": "type",
177 }
class DiscountFixed(Discount):
    """Discount that removes a fixed ``amount`` in ``currency`` from the price."""

    # Re-declares the shared discriminator column with a narrowed type.
    type: Mapped[Literal[DiscountType.fixed]] = mapped_column(use_existing_column=True)
    # Nullable in the database, presumably because these columns live in the shared
    # single ``discounts`` table alongside other discount types.
    amount: Mapped[int] = mapped_column(Integer, nullable=True)
    currency: Mapped[str] = mapped_column(
        String(3), nullable=True, use_existing_column=True
    )

    def get_discount_amount(self, amount: int) -> int:
        """Return the fixed discount, capped so it never exceeds ``amount``."""
        return amount if amount < self.amount else self.amount

    def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams:
        """Extend the base coupon parameters with ``amount_off`` and ``currency``."""
        # The base implementation returns a freshly built dict, so extending it
        # in place does not affect any other caller.
        coupon_params = super().get_stripe_coupon_params()
        coupon_params["amount_off"] = self.amount
        coupon_params["currency"] = self.currency
        return coupon_params

    __mapper_args__ = {
        "polymorphic_identity": DiscountType.fixed,
        "polymorphic_load": "inline",
    }
class DiscountPercentage(Discount):
    """Discount that removes a percentage, expressed in basis points, from the price.

    One basis point is 1/100 of a percent, so ``basis_points == 10_000`` is 100%
    (Stripe's ``percent_off`` is ``basis_points / 100``).
    """

    # Re-declares the shared discriminator column with a narrowed type.
    type: Mapped[Literal[DiscountType.percentage]] = mapped_column(
        use_existing_column=True
    )
    # Nullable in the database, presumably because the column lives in the shared
    # single ``discounts`` table alongside other discount types.
    basis_points: Mapped[int] = mapped_column(Integer, nullable=True)

    def get_discount_amount(self, amount: int) -> int:
        """Return ``amount * basis_points / 10_000`` rounded with ``polar_round``.

        The integer product is computed first and divided once. Python's int/int
        true division is correctly rounded, so an exact half-unit quotient such as
        ``28.5`` reaches ``polar_round`` exactly. Multiplying by a pre-divided rate
        accumulates representation error instead: ``100 * (2850 / 10_000)`` yields
        ``28.499999999999996``, which can change the rounded result.
        """
        return polar_round(amount * self.basis_points / 10_000)

    def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams:
        """Extend the base coupon parameters with ``percent_off`` (a percentage)."""
        params = super().get_stripe_coupon_params()
        return {
            **params,
            "percent_off": self.basis_points / 100,
        }

    __mapper_args__ = {
        "polymorphic_identity": DiscountType.percentage,
        "polymorphic_load": "inline",
    }