Coverage for polar/models/subscription.py: 61%

211 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from collections.abc import Sequence 1ab

2from datetime import UTC, datetime 1ab

3from enum import StrEnum 1ab

4from typing import TYPE_CHECKING, Self 1ab

5from uuid import UUID 1ab

6 

7from sqlalchemy import ( 1ab

8 TIMESTAMP, 

9 Boolean, 

10 ColumnElement, 

11 ForeignKey, 

12 Integer, 

13 String, 

14 Text, 

15 Uuid, 

16 event, 

17 type_coerce, 

18) 

19from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy 1ab

20from sqlalchemy.ext.hybrid import hybrid_property 1ab

21from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship 1ab

22from sqlalchemy.orm.attributes import OP_BULK_REPLACE, Event 1ab

23 

24from polar.custom_field.data import CustomFieldDataMixin 1ab

25from polar.enums import SubscriptionRecurringInterval 1ab

26from polar.kit.db.models import RecordModel 1ab

27from polar.kit.extensions.sqlalchemy.types import StringEnum 1ab

28from polar.kit.metadata import MetadataMixin 1ab

29from polar.product.guard import is_metered_price 1ab

30 

31from .product_price import HasPriceCurrency 1ab

32from .subscription_meter import SubscriptionMeter 1ab

33 

34if TYPE_CHECKING: 34 ↛ 35line 34 didn't jump to line 35 because the condition on line 34 was never true1ab

35 from . import ( 

36 BenefitGrant, 

37 Checkout, 

38 Customer, 

39 CustomerSeat, 

40 Discount, 

41 Meter, 

42 Organization, 

43 PaymentMethod, 

44 Product, 

45 ProductPrice, 

46 SubscriptionProductPrice, 

47 ) 

48 

49 

50class SubscriptionStatus(StrEnum): 1ab

51 incomplete = "incomplete" 1ab

52 incomplete_expired = "incomplete_expired" 1ab

53 trialing = "trialing" 1ab

54 active = "active" 1ab

55 past_due = "past_due" 1ab

56 canceled = "canceled" 1ab

57 unpaid = "unpaid" 1ab

58 

59 @classmethod 1ab

60 def incomplete_statuses(cls) -> set[Self]: 1ab

61 return {cls.incomplete, cls.incomplete_expired} # type: ignore 

62 

63 @classmethod 1ab

64 def active_statuses(cls) -> set[Self]: 1ab

65 return {cls.trialing, cls.active} # type: ignore 

66 

67 @classmethod 1ab

68 def revoked_statuses(cls) -> set[Self]: 1ab

69 return {cls.past_due, cls.canceled, cls.unpaid} # type: ignore 

70 

71 @classmethod 1ab

72 def billable_statuses(cls) -> set[Self]: 1ab

73 return cls.active_statuses() | {cls.past_due} # type: ignore 

74 

75 @classmethod 1ab

76 def is_incomplete(cls, status: Self) -> bool: 1ab

77 return status in cls.incomplete_statuses() 

78 

79 @classmethod 1ab

80 def is_active(cls, status: Self) -> bool: 1ab

81 return status in cls.active_statuses() 

82 

83 @classmethod 1ab

84 def is_revoked(cls, status: Self) -> bool: 1ab

85 return status in cls.revoked_statuses() 

86 

87 @classmethod 1ab

88 def is_billable(cls, status: Self) -> bool: 1ab

89 return status in cls.billable_statuses() 

90 

91 

92class CustomerCancellationReason(StrEnum): 1ab

93 customer_service = "customer_service" 1ab

94 low_quality = "low_quality" 1ab

95 missing_features = "missing_features" 1ab

96 switched_service = "switched_service" 1ab

97 too_complex = "too_complex" 1ab

98 too_expensive = "too_expensive" 1ab

99 unused = "unused" 1ab

100 other = "other" 1ab

101 

102 

class Subscription(CustomFieldDataMixin, MetadataMixin, RecordModel):
    """
    A customer's recurring subscription to a product.

    Holds the billing state (amount, currency, recurring interval, status and
    period boundaries) and links to the customer, product, payment method,
    discount and originating checkout. ``amount``/``currency`` and ``meters``
    are derived from the subscribed prices through
    :meth:`update_amount_and_currency` and :meth:`update_meters`.
    """

    __tablename__ = "subscriptions"

    amount: Mapped[int] = mapped_column(Integer, nullable=False)
    currency: Mapped[str] = mapped_column(String(3), nullable=False)
    recurring_interval: Mapped[SubscriptionRecurringInterval] = mapped_column(
        StringEnum(SubscriptionRecurringInterval), nullable=False, index=True
    )
    recurring_interval_count: Mapped[int] = mapped_column(Integer, nullable=False)

    stripe_subscription_id: Mapped[str | None] = mapped_column(
        String, nullable=True, index=True, default=None
    )
    """
    The ID of the subscription in Stripe.

    If set, indicates that the subscription is managed by Stripe Billing.
    """
    legacy_stripe_subscription_id: Mapped[str | None] = mapped_column(
        String, nullable=True, index=True, default=None
    )
    """
    Original ID of the subscription in Stripe.

    If set, indicates that the subscription was originally managed by Stripe Billing,
    but has been migrated to be managed by Polar.
    """

    tax_exempted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    """
    Whether the subscription is tax exempted.

    We use this to disable tax on subscriptions created before we were
    registered in a given country, so we don't surprise customers with
    tax charges.
    """

    status: Mapped[SubscriptionStatus] = mapped_column(
        StringEnum(SubscriptionStatus), nullable=False
    )
    current_period_start: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False
    )
    current_period_end: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )
    trial_start: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )
    trial_end: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )
    cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False)
    canceled_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    started_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    ends_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    ended_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    past_due_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )

    scheduler_locked_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )

    customer_id: Mapped[UUID] = mapped_column(
        Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False, index=True
    )

    @declared_attr
    def customer(cls) -> Mapped["Customer"]:
        return relationship("Customer", lazy="raise")

    payment_method_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("payment_methods.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def payment_method(cls) -> Mapped["PaymentMethod | None"]:
        return relationship("PaymentMethod", lazy="raise")

    product_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("products.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def product(cls) -> Mapped["Product"]:
        return relationship("Product", lazy="raise")

    subscription_product_prices: Mapped[list["SubscriptionProductPrice"]] = (
        relationship(
            "SubscriptionProductPrice",
            back_populates="subscription",
            cascade="all, delete-orphan",
            # Prices are almost always needed, so eager loading makes sense
            lazy="selectin",
        )
    )

    # Read-through view of the ``product_price`` of each subscription price row.
    prices: AssociationProxy[list["ProductPrice"]] = association_proxy(
        "subscription_product_prices", "product_price"
    )

    discount_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("discounts.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def discount(cls) -> Mapped["Discount | None"]:
        return relationship("Discount", lazy="joined")

    meters: Mapped[list[SubscriptionMeter]] = relationship(
        SubscriptionMeter,
        order_by="SubscriptionMeter.created_at",
        back_populates="subscription",
        cascade="all, delete-orphan",
        # Eager load
        lazy="selectin",
    )

    # The organization is reached through the product, not stored here.
    organization: AssociationProxy["Organization"] = association_proxy(
        "product", "organization"
    )

    checkout_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("checkouts.id", ondelete="set null"), nullable=True, index=True
    )

    # NOTE(review): stored with a plain String type rather than StringEnum, so
    # loaded values are presumably plain ``str`` rather than enum members — confirm.
    customer_cancellation_reason: Mapped[CustomerCancellationReason | None] = (
        mapped_column(String, nullable=True)
    )
    customer_cancellation_comment: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )

    seats: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)

    @declared_attr
    def checkout(cls) -> Mapped["Checkout | None"]:
        return relationship(
            "Checkout",
            lazy="raise",
            foreign_keys=[cls.checkout_id],  # type: ignore
        )

    @declared_attr
    def grants(cls) -> Mapped[list["BenefitGrant"]]:
        return relationship(
            "BenefitGrant",
            lazy="raise",
            order_by="BenefitGrant.benefit_id",
            back_populates="subscription",
        )

    @declared_attr
    def customer_seats(cls) -> Mapped[list["CustomerSeat"]]:
        return relationship(
            "CustomerSeat",
            lazy="raise",
            back_populates="subscription",
            cascade="all, delete-orphan",
        )

    def is_incomplete(self) -> bool:
        """Return whether ``status`` is one of the incomplete statuses."""
        return SubscriptionStatus.is_incomplete(self.status)

    @hybrid_property
    def trialing(self) -> bool:
        """Whether ``status`` is ``trialing`` (usable in Python and in SQL)."""
        return self.status == SubscriptionStatus.trialing

    @trialing.inplace.expression
    @classmethod
    def _trialing_expression(cls) -> ColumnElement[bool]:
        return cls.status == SubscriptionStatus.trialing

    @hybrid_property
    def active(self) -> bool:
        """Whether ``status`` is an active status (usable in Python and in SQL)."""
        return SubscriptionStatus.is_active(self.status)

    @active.inplace.expression
    @classmethod
    def _active_expression(cls) -> ColumnElement[bool]:
        # type_coerce keeps the IN clause typed as Boolean for the hybrid.
        return type_coerce(
            cls.status.in_(SubscriptionStatus.active_statuses()),
            Boolean,
        )

    @hybrid_property
    def revoked(self) -> bool:
        """Whether ``status`` is a revoked status (usable in Python and in SQL)."""
        return SubscriptionStatus.is_revoked(self.status)

    @revoked.inplace.expression
    @classmethod
    def _revoked_expression(cls) -> ColumnElement[bool]:
        return type_coerce(
            cls.status.in_(SubscriptionStatus.revoked_statuses()),
            Boolean,
        )

    @hybrid_property
    def canceled(self) -> bool:
        """Whether ``canceled_at`` has been set (usable in Python and in SQL)."""
        return self.canceled_at is not None

    @canceled.inplace.expression
    @classmethod
    def _canceled_expression(cls) -> ColumnElement[bool]:
        return cls.canceled_at.is_not(None)

    @hybrid_property
    def billable(self) -> bool:
        """Whether ``status`` is a billable status (usable in Python and in SQL)."""
        return SubscriptionStatus.is_billable(self.status)

    @billable.inplace.expression
    @classmethod
    def _billable_expression(cls) -> ColumnElement[bool]:
        return type_coerce(
            cls.status.in_(SubscriptionStatus.billable_statuses()),
            Boolean,
        )

    def can_cancel(self, immediately: bool = False) -> bool:
        """
        Return whether the subscription may be canceled.

        The subscription must be in a billable status and must not have ended.
        An immediate cancellation is then always allowed. A cancellation at the
        end of the period is refused when one is already scheduled, i.e. when
        ``cancel_at_period_end`` or ``ends_at`` is set.
        """
        if not SubscriptionStatus.is_billable(self.status):
            return False

        if self.ended_at:
            return False

        if immediately:
            return True

        if self.cancel_at_period_end or self.ends_at:
            return False
        return True

    def can_uncancel(self) -> bool:
        """
        Return whether a scheduled end-of-period cancellation can be reverted.

        Requires ``cancel_at_period_end`` to be set and the status to be billable.
        """
        return self.cancel_at_period_end and SubscriptionStatus.is_billable(
            self.status
        )

    def set_started_at(self) -> None:
        """
        Stores the starting date when the subscription
        becomes active for the first time.

        Does nothing if the subscription is not active or ``started_at`` is
        already set.
        """
        if self.active and self.started_at is None:
            self.started_at = datetime.now(UTC)

    def update_amount_and_currency(
        self, prices: Sequence["SubscriptionProductPrice"], discount: "Discount | None"
    ) -> None:
        """
        Recompute ``amount`` and ``currency`` from ``prices`` and ``discount``.

        ``amount`` is the sum of the price amounts minus the discount computed
        on that sum. ``currency`` is the single currency shared by the prices
        that carry one, or ``"usd"`` when none do.

        Raises:
            ValueError: if the prices carry more than one distinct currency.
        """
        amount = sum(price.amount for price in prices)
        if discount is not None:
            amount -= discount.get_discount_amount(amount)
        self.amount = amount

        # Only some product price types have a currency.
        currencies = {
            price.product_price.price_currency
            for price in prices
            if isinstance(price.product_price, HasPriceCurrency)
        }
        if len(currencies) == 0:
            self.currency = "usd"  # FIXME: Main Polar currency
        elif len(currencies) == 1:
            self.currency = currencies.pop()
        else:
            raise ValueError("Multiple currencies in subscription prices")

    def update_meters(self, prices: Sequence["SubscriptionProductPrice"]) -> None:
        """
        Synchronize ``meters`` with the metered prices in ``prices``.

        Existing ``SubscriptionMeter`` entries whose meter is still referenced
        by a metered price are kept in their current order. A new
        ``SubscriptionMeter`` is appended for every referenced meter that has
        none yet. Entries whose meter is no longer referenced are dropped from
        the collection; the relationship uses a ``delete-orphan`` cascade.
        """
        price_meters = [
            price.product_price.meter
            for price in prices
            if is_metered_price(price.product_price)
        ]

        # Build a fresh list rather than calling ``remove`` on the list being
        # iterated: removing during iteration skips the element that follows
        # each removed one, which left stale meters attached.
        subscription_meters = [
            subscription_meter
            for subscription_meter in (self.meters or [])
            if subscription_meter.meter in price_meters
        ]

        for price_meter in price_meters:
            already_tracked = any(
                subscription_meter.meter == price_meter
                for subscription_meter in subscription_meters
            )
            if not already_tracked:
                subscription_meters.append(SubscriptionMeter(meter=price_meter))

        self.meters = subscription_meters

    def get_meter(self, meter: "Meter") -> SubscriptionMeter | None:
        """Return the ``SubscriptionMeter`` whose ``meter_id`` matches ``meter.id``, or ``None``."""
        for subscription_meter in self.meters:
            if subscription_meter.meter_id == meter.id:
                return subscription_meter
        return None

411 

412 

@event.listens_for(Subscription.subscription_product_prices, "bulk_replace")
def _prices_replaced(
    target: Subscription, values: list["SubscriptionProductPrice"], initiator: Event
) -> None:
    """
    Recompute derived fields when the whole price collection is replaced.

    Both ``amount``/``currency`` and the meters are rebuilt from the incoming
    ``values``, combined with the subscription's current discount.
    """
    current_discount = target.discount
    target.update_amount_and_currency(values, current_discount)
    target.update_meters(values)

419 

420 

@event.listens_for(Subscription.subscription_product_prices, "append")
def _price_appended(
    target: Subscription, value: "SubscriptionProductPrice", initiator: Event
) -> None:
    """
    Recompute derived fields when a single price is appended.

    The recomputation uses the current collection plus the incoming ``value``.
    """
    # Appends that are part of a bulk replace are skipped: the bulk_replace
    # listener recomputes everything once from the final list. Handling each
    # append here could drop a meter after the first item that a later item
    # in the same replacement still needs.
    part_of_bulk_replace = initiator is not None and initiator.op is OP_BULK_REPLACE
    if part_of_bulk_replace:
        return

    prospective_prices = [*target.subscription_product_prices, value]
    target.update_amount_and_currency(prospective_prices, target.discount)
    target.update_meters(prospective_prices)

436 

437 

@event.listens_for(Subscription.discount, "set")
def _discount_set(
    target: Subscription,
    value: "Discount | None",
    oldvalue: "Discount | None",
    initiator: Event,
) -> None:
    """
    Recompute ``amount``/``currency`` when the discount is assigned.

    The new discount is taken from the event's ``value`` argument rather than
    from ``target.discount``.
    """
    current_prices = target.subscription_product_prices
    target.update_amount_and_currency(current_prices, value)