Coverage for polar/models/product_price.py: 63%

189 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

from decimal import Decimal
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Literal, TypedDict
from uuid import UUID

import stripe as stripe_lib
from babel.numbers import format_currency, format_decimal
from sqlalchemy import (
    Boolean,
    ColumnElement,
    ForeignKey,
    Integer,
    Numeric,
    String,
    Uuid,
    case,
    event,
    func,
    type_coerce,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
    Mapped,
    declared_attr,
    mapped_column,
    object_mapper,
    relationship,
)

from polar.enums import SubscriptionRecurringInterval
from polar.kit.db.models import RecordModel
from polar.kit.extensions.sqlalchemy.types import StringEnum
from polar.kit.math import polar_round

if TYPE_CHECKING:
    from polar.models import CheckoutProduct, Meter, Product

38 

39 

40class ProductPriceType(StrEnum): 1ab

41 one_time = "one_time" 1ab

42 recurring = "recurring" 1ab

43 

44 def as_literal(self) -> Literal["one_time", "recurring"]: 1ab

45 return self.value 

46 

47 

48class ProductPriceAmountType(StrEnum): 1ab

49 fixed = "fixed" 1ab

50 custom = "custom" 1ab

51 free = "free" 1ab

52 metered_unit = "metered_unit" 1ab

53 seat_based = "seat_based" 1ab

54 

55 

56class ProductPriceSource(StrEnum): 1ab

57 catalog = "catalog" 1ab

58 ad_hoc = "ad_hoc" 1ab

59 

60 

61class SeatTier(TypedDict): 1ab

62 """A single pricing tier for seat-based pricing.""" 

63 

64 min_seats: int 1ab

65 max_seats: int | None 1ab

66 price_per_seat: int 1ab

67 

68 

class SeatTiersData(TypedDict):
    """Shape of the JSONB payload stored in the ``seat_tiers`` column."""

    # Tiers as stored; lookups take the first tier whose range matches.
    tiers: list[SeatTier]

73 

74 

class HasPriceCurrency:
    """Mixin adding the shared three-letter ``price_currency`` column."""

    # Nullable because it lives in the shared price table and not every
    # price type uses it; ``use_existing_column`` lets several subclasses
    # declare it without creating duplicates.
    price_currency: Mapped[str] = mapped_column(
        String(3),
        nullable=True,
        use_existing_column=True,
    )

79 

80 

class HasStripePriceId:
    """Mixin for prices mirrored in Stripe, adding ``stripe_price_id``."""

    # Shared, nullable column reused across subclasses (use_existing_column).
    stripe_price_id: Mapped[str] = mapped_column(
        String,
        nullable=True,
        use_existing_column=True,
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Return the parameters used to create this price in Stripe.

        Subclasses must override this; the base implementation always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

90 

91 

# Prepended to ``amount_type`` to build the polymorphic identity of legacy
# prices, i.e. rows whose legacy ``type`` column is set.
LEGACY_IDENTITY_PREFIX: str = "legacy_"

93 

94 

class ProductPrice(RecordModel):
    """Base model for every product price, stored in ``product_prices``.

    Concrete subclasses share this table and are selected through the
    ``polymorphic_on`` expression in ``__mapper_args__``:

    * rows whose legacy ``type`` column is NULL map to ``amount_type``;
    * rows with ``type`` set map to ``LEGACY_IDENTITY_PREFIX + amount_type``.
    """

    __tablename__ = "product_prices"

    # Legacy: recurring is now set on product
    type: Mapped[Any] = mapped_column(String, nullable=True, index=True, default=None)
    recurring_interval: Mapped[Any] = mapped_column(
        StringEnum(SubscriptionRecurringInterval),
        nullable=True,
        index=True,
        default=None,
    )

    # Annotated so type checkers know the attribute's Python type.
    source: Mapped[ProductPriceSource] = mapped_column(
        StringEnum(ProductPriceSource),
        nullable=False,
        index=True,
        default=ProductPriceSource.catalog,
    )
    amount_type: Mapped[ProductPriceAmountType] = mapped_column(
        String, nullable=False, index=True
    )
    is_archived: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    product_id: Mapped[UUID] = mapped_column(
        Uuid, ForeignKey("products.id", ondelete="cascade"), nullable=False, index=True
    )

    checkout_product_id: Mapped[UUID | None] = mapped_column(
        Uuid,
        ForeignKey("checkout_products.id", ondelete="set null"),
        nullable=True,
        index=True,
        default=None,
    )
    """
    Foreign key to the CheckoutProduct this price is associated with, if any.

    Used for ad-hoc prices created on-demand for checkout sessions.
    """

    @declared_attr
    def product(cls) -> Mapped["Product"]:
        # raise_on_sql: the product must be loaded explicitly before access.
        return relationship("Product", lazy="raise_on_sql", back_populates="all_prices")

    @declared_attr
    def checkout_product(cls) -> Mapped["CheckoutProduct | None"]:
        # The relationship targets CheckoutProduct, so the annotation does too
        # (it previously claimed ``Product | None``).
        return relationship(
            "CheckoutProduct", lazy="raise_on_sql", back_populates="ad_hoc_prices"
        )

    @hybrid_property
    def is_recurring(self) -> bool:
        """Whether the legacy ``type`` column marks this price as recurring."""
        return self.type == ProductPriceType.recurring

    @is_recurring.inplace.expression
    @classmethod
    def _is_recurring_expression(cls) -> ColumnElement[bool]:
        return type_coerce(cls.type == ProductPriceType.recurring, Boolean)

    @hybrid_property
    def is_static(self) -> bool:
        """Whether this is a fixed, free, custom or seat-based price."""
        return self.amount_type in {
            ProductPriceAmountType.fixed,
            ProductPriceAmountType.free,
            ProductPriceAmountType.custom,
            ProductPriceAmountType.seat_based,
        }

    @is_static.inplace.expression
    @classmethod
    def _is_static_price_expression(cls) -> ColumnElement[bool]:
        return cls.amount_type.in_(
            (
                ProductPriceAmountType.fixed,
                ProductPriceAmountType.free,
                ProductPriceAmountType.custom,
                ProductPriceAmountType.seat_based,
            )
        )

    @hybrid_property
    def is_metered(self) -> bool:
        """Whether this is a metered (usage-based) price."""
        return self.amount_type in {ProductPriceAmountType.metered_unit}

    @is_metered.inplace.expression
    @classmethod
    def _is_metered_price_expression(cls) -> ColumnElement[bool]:
        return cls.amount_type.in_((ProductPriceAmountType.metered_unit,))

    @property
    def legacy_type(self) -> ProductPriceType | None:
        """Legacy price type derived from the parent product.

        Requires ``product`` to be loaded (the relationship is ``raise_on_sql``).
        """
        if self.product.is_recurring:
            return ProductPriceType.recurring
        return ProductPriceType.one_time

    @property
    def legacy_recurring_interval(self) -> SubscriptionRecurringInterval | None:
        """Recurring interval of the parent product (requires it to be loaded)."""
        return self.product.recurring_interval

    # `type` and `amount_type` here are the mapped columns declared above.
    __mapper_args__ = {
        "polymorphic_on": case(
            (type.is_(None), amount_type),
            else_=func.concat(LEGACY_IDENTITY_PREFIX, amount_type),
        )
    }

200 

201 

class LegacyRecurringProductPrice:
    """Mixin for legacy prices that still carry ``type``/``recurring_interval``.

    Narrows the shared legacy columns to their enum types; the columns
    themselves already exist on the table (``use_existing_column``).
    """

    __abstract__ = True

    type: Mapped[ProductPriceType] = mapped_column(
        nullable=True, use_existing_column=True
    )
    recurring_interval: Mapped[SubscriptionRecurringInterval] = mapped_column(
        nullable=True, use_existing_column=True
    )

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_abstract": True,
    }

216 

217 

class NewProductPrice:
    """Mixin for new-style prices, whose legacy columns are always NULL."""

    __abstract__ = True

    # Both legacy columns default to None and are typed as always-None.
    type: Mapped[Literal[None]] = mapped_column(
        default=None, nullable=True, use_existing_column=True
    )
    recurring_interval: Mapped[Literal[None]] = mapped_column(
        default=None, nullable=True, use_existing_column=True
    )

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_abstract": True,
    }

232 

233 

class _ProductPriceFixed(HasStripePriceId, HasPriceCurrency, ProductPrice):
    """Abstract base for fixed-amount prices (new and legacy variants)."""

    price_amount: Mapped[int] = mapped_column(Integer, nullable=True)
    amount_type: Mapped[Literal[ProductPriceAmountType.fixed]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.fixed
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Build Stripe params for ``price_amount`` in ``price_currency``.

        A ``recurring`` entry is added only when ``recurring_interval`` is given.
        """
        stripe_params: stripe_lib.Price.CreateParams = {
            "unit_amount": self.price_amount,
            "currency": self.price_currency,
        }
        if recurring_interval is not None:
            stripe_params["recurring"] = {"interval": recurring_interval.as_literal()}
        return stripe_params

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_abstract": True,
    }

258 

259 

class ProductPriceFixed(NewProductPrice, _ProductPriceFixed):
    """Fixed-amount price without legacy ``type``/``recurring_interval``."""

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_identity": ProductPriceAmountType.fixed,
    }

265 

266 

class LegacyRecurringProductPriceFixed(LegacyRecurringProductPrice, _ProductPriceFixed):
    """Fixed-amount price stored with the legacy type columns set."""

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_identity": f"{LEGACY_IDENTITY_PREFIX}{ProductPriceAmountType.fixed}",
    }

272 

273 

class _ProductPriceCustom(HasStripePriceId, HasPriceCurrency, ProductPrice):
    """Abstract base for pay-what-you-want prices (new and legacy variants)."""

    amount_type: Mapped[Literal[ProductPriceAmountType.custom]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.custom
    )
    minimum_amount: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )
    maximum_amount: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )
    preset_amount: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Build Stripe params for a custom-unit-amount price.

        Only the bounds/preset that are set are forwarded to Stripe.
        """
        unit_amount_config: stripe_lib.Price.CreateParamsCustomUnitAmount = {
            "enabled": True,
        }
        if (minimum := self.minimum_amount) is not None:
            unit_amount_config["minimum"] = minimum
        if (maximum := self.maximum_amount) is not None:
            unit_amount_config["maximum"] = maximum
        if (preset := self.preset_amount) is not None:
            unit_amount_config["preset"] = preset

        # `recurring_interval` is deliberately ignored: Stripe doesn't support
        # PWYW pricing for subscriptions, so ad-hoc prices are created instead.
        stripe_params: stripe_lib.Price.CreateParams = {
            "currency": self.price_currency,
            "custom_unit_amount": unit_amount_config,
        }
        return stripe_params

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_abstract": True,
    }

313 

314 

class ProductPriceCustom(NewProductPrice, _ProductPriceCustom):
    """Pay-what-you-want price without legacy ``type``/``recurring_interval``."""

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_identity": ProductPriceAmountType.custom,
    }

320 

321 

class LegacyRecurringProductPriceCustom(LegacyRecurringProductPrice, _ProductPriceCustom):
    """Pay-what-you-want price stored with the legacy type columns set."""

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_identity": f"{LEGACY_IDENTITY_PREFIX}{ProductPriceAmountType.custom}",
    }

329 

330 

class _ProductPriceFree(HasStripePriceId, ProductPrice):
    """Abstract base for free prices (new and legacy variants)."""

    amount_type: Mapped[Literal[ProductPriceAmountType.free]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.free
    )

    def get_stripe_price_params(
        self, recurring_interval: SubscriptionRecurringInterval | None
    ) -> stripe_lib.Price.CreateParams:
        """Build Stripe params for a zero-amount price.

        Free prices have no currency column, so ``"usd"`` is always sent.
        A ``recurring`` entry is added only when ``recurring_interval`` is given.
        """
        stripe_params: stripe_lib.Price.CreateParams = {
            "unit_amount": 0,
            "currency": "usd",
        }
        if recurring_interval is not None:
            stripe_params["recurring"] = {"interval": recurring_interval.as_literal()}
        return stripe_params

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_abstract": True,
    }

355 

356 

class ProductPriceFree(NewProductPrice, _ProductPriceFree):
    """Free price without legacy ``type``/``recurring_interval``."""

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_identity": ProductPriceAmountType.free,
    }

362 

363 

class LegacyRecurringProductPriceFree(LegacyRecurringProductPrice, _ProductPriceFree):
    """Free price stored with the legacy type columns set."""

    __mapper_args__ = {
        "polymorphic_load": "inline",
        "polymorphic_identity": f"{LEGACY_IDENTITY_PREFIX}{ProductPriceAmountType.free}",
    }

369 

370 

class ProductPriceMeteredUnit(ProductPrice, HasPriceCurrency, NewProductPrice):
    """Usage-based price charging ``unit_amount`` per consumed meter unit.

    ``unit_amount`` may be fractional (``Numeric(17, 12)``). It and
    ``cap_amount`` appear to be in minor currency units, since both are
    divided by 100 for display. When ``cap_amount`` is set, it bounds the
    computed amount.
    """

    amount_type: Mapped[Literal[ProductPriceAmountType.metered_unit]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.metered_unit
    )
    unit_amount: Mapped[Decimal] = mapped_column(
        Numeric(17, 12),  # 12 decimal places, 17 digits total
        # Polymorphic columns must be nullable, as they don't apply to other types
        nullable=True,
    )
    cap_amount: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
    meter_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("meters.id"),
        # Polymorphic columns must be nullable, as they don't apply to other types
        nullable=True,
        index=True,
    )

    @declared_attr
    def meter(cls) -> Mapped["Meter"]:
        # For convenience, eager load it, as it's embedded in all schemas outputting a price
        return relationship("Meter", lazy="joined")

    def get_amount_and_label(self, units: float) -> tuple[int, str]:
        """Compute the amount due for ``units`` consumed units, plus a label.

        Negative ``units`` are billed as zero. If ``cap_amount`` is set and the
        rounded amount exceeds it, the cap is charged instead and noted in the
        label.

        Returns:
            ``(amount, label)``. ``amount`` is ``unit_amount * units``, rounded
            with ``polar_round`` and possibly capped.
        """
        currency = self.price_currency.upper()
        label = (
            f"({format_decimal(units, locale='en_US')} consumed units) × "
            f"{format_currency(self.unit_amount / 100, currency, locale='en_US')}"
        )

        billable_units = Decimal(max(0, units))
        raw_amount = self.unit_amount * billable_units
        amount = polar_round(raw_amount)

        if self.cap_amount is not None and amount > self.cap_amount:
            amount = self.cap_amount
            # Leading space separates the cap note from the unit price; it was
            # previously glued on ("$0.50— Capped at ...").
            label += f" — Capped at {format_currency(self.cap_amount / 100, currency, locale='en_US')}"

        return amount, label

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.metered_unit,
        "polymorphic_load": "inline",
    }

413 

414 

class ProductPriceSeatUnit(NewProductPrice, HasPriceCurrency, ProductPrice):
    """Seat-based price whose per-seat rate depends on the number of seats."""

    amount_type: Mapped[Literal[ProductPriceAmountType.seat_based]] = mapped_column(
        use_existing_column=True, default=ProductPriceAmountType.seat_based
    )
    seat_tiers: Mapped[SeatTiersData] = mapped_column(
        postgresql.JSONB,
        nullable=True,
    )

    def get_tier_for_seats(self, seats: int) -> SeatTier:
        """Return the first tier whose seat range contains ``seats``.

        A tier matches when ``min_seats <= seats`` and ``max_seats`` is either
        ``None`` (unbounded) or ``>= seats``.

        Raises:
            ValueError: If no tier matches, including when ``seat_tiers`` is
                NULL or has no ``tiers`` entry.
        """
        # The column is nullable, so guard against None rather than letting
        # `.get` raise AttributeError; callers expect ValueError for "no tier".
        seat_tiers = self.seat_tiers
        tiers = seat_tiers.get("tiers", []) if seat_tiers else []
        for tier in tiers:
            min_seats = tier["min_seats"]
            max_seats = tier.get("max_seats")
            if seats >= min_seats and (max_seats is None or seats <= max_seats):
                return tier
        raise ValueError(f"No tier found for {seats} seats")

    def get_price_per_seat(self, seats: int) -> int:
        """Return the per-seat price of the tier matching ``seats``.

        Raises:
            ValueError: If no tier matches.
        """
        tier = self.get_tier_for_seats(seats)
        return tier["price_per_seat"]

    def calculate_amount(self, seats: int) -> int:
        """Return the total for ``seats``: all seats at the matching tier's rate.

        Raises:
            ValueError: If no tier matches.
        """
        return self.get_price_per_seat(seats) * seats

    __mapper_args__ = {
        "polymorphic_identity": ProductPriceAmountType.seat_based,
        "polymorphic_load": "inline",
    }

443 

444 

@event.listens_for(ProductPrice, "init", propagate=True)
def set_identity(instance: ProductPrice, *arg: Any, **kw: Any) -> None:
    """Align discriminator columns with the instantiated subclass.

    Runs on construction of ``ProductPrice`` and its subclasses. If the
    subclass has a polymorphic identity, ``amount_type`` is set from it, with
    the legacy prefix stripped. For non-legacy identities, the legacy
    ``type`` and ``recurring_interval`` are reset to None.
    """
    polymorphic_identity: str | None = object_mapper(instance).polymorphic_identity
    if polymorphic_identity is None:
        # Abstract/base mappers have no identity; nothing to sync.
        return

    if not polymorphic_identity.startswith(LEGACY_IDENTITY_PREFIX):
        instance.type = None
        instance.recurring_interval = None

    # No-op for non-legacy identities, which never start with the prefix.
    instance.amount_type = ProductPriceAmountType(
        polymorphic_identity.removeprefix(LEGACY_IDENTITY_PREFIX)
    )