Coverage for polar/models/discount.py: 66%

93 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from datetime import datetime 1ab

2from enum import StrEnum 1ab

3from typing import TYPE_CHECKING, Literal, cast 1ab

4from uuid import UUID 1ab

5 

6import stripe as stripe_lib 1ab

7from dateutil.relativedelta import relativedelta 1ab

8from sqlalchemy import ( 1ab

9 TIMESTAMP, 

10 Column, 

11 ForeignKey, 

12 Index, 

13 Integer, 

14 String, 

15 Uuid, 

16 func, 

17 select, 

18) 

19from sqlalchemy.dialects.postgresql import CITEXT 1ab

20from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy 1ab

21from sqlalchemy.orm import ( 1ab

22 Mapped, 

23 column_property, 

24 declared_attr, 

25 mapped_column, 

26 relationship, 

27) 

28 

29from polar.kit.db.models import RecordModel 1ab

30from polar.kit.math import polar_round 1ab

31from polar.kit.metadata import MetadataMixin 1ab

32 

33if TYPE_CHECKING: 33 ↛ 34line 33 didn't jump to line 34 because the condition on line 33 was never true1ab

34 from . import DiscountProduct, DiscountRedemption, Organization, Product 

35 

36 

37class DiscountType(StrEnum): 1ab

38 fixed = "fixed" 1ab

39 percentage = "percentage" 1ab

40 

41 def get_model(self) -> type["Discount"]: 1ab

42 return { 

43 DiscountType.fixed: DiscountFixed, 

44 DiscountType.percentage: DiscountPercentage, 

45 }[self] 

46 

47 

48class DiscountDuration(StrEnum): 1ab

49 once = "once" 1ab

50 forever = "forever" 1ab

51 repeating = "repeating" 1ab

52 

53 

class Discount(MetadataMixin, RecordModel):
    """Base model for a discount that belongs to an organization.

    All discount kinds share the ``discounts`` table; the ``type`` column is the
    polymorphic discriminator that selects the concrete subclass.
    """

    __tablename__ = "discounts"

    @declared_attr.directive
    def __table_args__(cls) -> tuple[Index]:
        # During tests this function is called multiple times, which would add the
        # index multiple times and lead to errors. Memoize the result on the class
        # so the index is only ever created once.
        if not hasattr(cls, "_memoized_indexes"):
            _deleted_at_column = cast(
                Column[datetime | None], cls.deleted_at
            )  # cast to satisfy mypy
            cls._memoized_indexes = (
                Index(
                    "ix_discounts_code_uniqueness",
                    "organization_id",
                    # Codes are unique per organization, case-insensitively.
                    func.lower(cls.code),
                    unique=True,
                    # Partial index: only non-deleted discounts take part.
                    postgresql_where=(_deleted_at_column.is_(None)),
                ),
            )
        return cls._memoized_indexes

    name: Mapped[str] = mapped_column(CITEXT, nullable=False)
    type: Mapped[DiscountType] = mapped_column(String, nullable=False)
    code: Mapped[str | None] = mapped_column(String, nullable=True, index=True)

    starts_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True
    )
    ends_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True
    )
    max_redemptions: Mapped[int | None] = mapped_column(Integer, nullable=True)

    duration: Mapped[DiscountDuration] = mapped_column(String, nullable=False)
    # Only meaningful when ``duration`` is ``repeating``.
    duration_in_months: Mapped[int | None] = mapped_column(Integer, nullable=True)

    stripe_coupon_id: Mapped[str] = mapped_column(
        String, nullable=False, unique=True, index=True
    )

    organization_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("organizations.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def organization(cls) -> Mapped["Organization"]:
        return relationship("Organization", lazy="raise")

    discount_redemptions: Mapped[list["DiscountRedemption"]] = relationship(
        "DiscountRedemption", back_populates="discount", lazy="raise"
    )

    discount_products: Mapped[list["DiscountProduct"]] = relationship(
        "DiscountProduct",
        back_populates="discount",
        cascade="all, delete-orphan",
        # Products are almost always needed, so eager loading makes sense
        lazy="selectin",
    )

    products: AssociationProxy[list["Product"]] = association_proxy(
        "discount_products", "product"
    )

    @declared_attr
    def redemptions_count(cls) -> Mapped[int]:
        # Imported locally to avoid a circular import between the model modules.
        from .discount_redemption import DiscountRedemption

        # Correlated scalar subquery counting this discount's redemptions.
        return column_property(
            select(func.count(DiscountRedemption.id))
            .where(DiscountRedemption.discount_id == cls.id)
            .correlate_except(DiscountRedemption)
            .scalar_subquery()
        )

    def get_discount_amount(self, amount: int) -> int:
        """Return the part of ``amount`` to take off.

        Every concrete discount subclass must implement this.
        """
        raise NotImplementedError()

    def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams:
        """Build the Stripe coupon parameters shared by every discount kind.

        Subclasses extend the result with their amount/percentage fields.
        ``self.organization`` is configured with ``lazy="raise"``, so it must
        already be loaded when this is called.
        """
        params: stripe_lib.Coupon.CreateParams = {
            # Truncated to 40 characters (Stripe's coupon name limit).
            "name": self.name[:40],
            "duration": cast(Literal["once", "forever", "repeating"], self.duration),
            "metadata": {
                "discount_id": str(self.id),
                "organization_id": str(self.organization.id),
            },
        }
        # Stripe uses ``duration_in_months`` only for repeating coupons; don't
        # send a stale value alongside ``once``/``forever``.
        if (
            self.duration == DiscountDuration.repeating
            and self.duration_in_months is not None
        ):
            params["duration_in_months"] = self.duration_in_months
        return params

    def is_applicable(self, product: "Product") -> bool:
        """Return whether this discount can be applied to ``product``.

        A discount that is not restricted to any product applies to all of them.
        """
        if not self.products:
            return True
        return product in self.products

    def is_repetition_expired(
        self,
        started_at: datetime,
        current_period_start: datetime,
        trial_ended: bool = False,
    ) -> bool:
        """Return whether the discount no longer applies to the current period.

        :param started_at: When the discount started applying (presumably the
            subscription's discount start; confirm with callers).
        :param current_period_start: Start of the billing period being evaluated.
        :param trial_ended: Whether this period is the transition out of a trial.
        :return: ``True`` if the discount should no longer be applied.
        """
        if self.duration == DiscountDuration.once:
            # If transitioning from trial to active, this is the first billed cycle
            # so the discount should still apply
            return not trial_ended
        if self.duration == DiscountDuration.forever:
            return False
        if self.duration_in_months is None:
            return False

        # -1 because the first month counts as a first repetition
        end_at = started_at + relativedelta(months=self.duration_in_months - 1)
        return current_period_start > end_at

    __mapper_args__ = {
        "polymorphic_on": "type",
    }

178 

179 

class DiscountFixed(Discount):
    """Discount that takes a fixed ``amount`` off, in ``currency``."""

    type: Mapped[Literal[DiscountType.fixed]] = mapped_column(use_existing_column=True)
    amount: Mapped[int] = mapped_column(Integer, nullable=True)
    currency: Mapped[str] = mapped_column(
        String(3), nullable=True, use_existing_column=True
    )

    def get_discount_amount(self, amount: int) -> int:
        """Return the discount, capped so it never exceeds ``amount``."""
        if amount < self.amount:
            return amount
        return self.amount

    def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams:
        """Return the shared coupon params plus ``amount_off`` and ``currency``."""
        coupon_params = super().get_stripe_coupon_params()
        # The base class builds a fresh dict on each call, so extending it in
        # place is safe.
        coupon_params["amount_off"] = self.amount
        coupon_params["currency"] = self.currency
        return coupon_params

    __mapper_args__ = {
        "polymorphic_identity": DiscountType.fixed,
        "polymorphic_load": "inline",
    }

202 

203 

class DiscountPercentage(Discount):
    """Discount that takes a percentage off, stored in basis points.

    ``basis_points`` are hundredths of a percent: 10_000 basis points == 100 %.
    """

    type: Mapped[Literal[DiscountType.percentage]] = mapped_column(
        use_existing_column=True
    )
    basis_points: Mapped[int] = mapped_column(Integer, nullable=True)

    def get_discount_amount(self, amount: int) -> int:
        """Return ``amount * basis_points / 10_000`` rounded with ``polar_round``."""
        # Multiply before dividing. ``amount * basis_points`` is an exact integer,
        # so the single true division is correctly rounded and exact ties stay
        # exact. Computing ``basis_points / 10_000`` first adds binary error that
        # can nudge ties off .5. For example, 50 at 2_900 bp should be 14.5, but
        # 50 * 0.29 == 14.499999999999998, which defeats polar_round's tie rule.
        discount_amount_float = amount * self.basis_points / 10_000
        return polar_round(discount_amount_float)

    def get_stripe_coupon_params(self) -> stripe_lib.Coupon.CreateParams:
        """Return the shared coupon params plus ``percent_off`` (basis points / 100)."""
        params = super().get_stripe_coupon_params()
        return {
            **params,
            "percent_off": self.basis_points / 100,
        }

    __mapper_args__ = {
        "polymorphic_identity": DiscountType.percentage,
        "polymorphic_load": "inline",
    }