Coverage for polar/models/checkout.py: 61%

242 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import uuid 1ab

2from collections.abc import Sequence 1ab

3from datetime import datetime, timedelta 1ab

4from enum import StrEnum 1ab

5from typing import TYPE_CHECKING, Any, TypedDict 1ab

6from uuid import UUID 1ab

7 

8from sqlalchemy import ( 1ab

9 TIMESTAMP, 

10 Boolean, 

11 ColumnElement, 

12 Connection, 

13 ForeignKey, 

14 Integer, 

15 String, 

16 Uuid, 

17 event, 

18) 

19from sqlalchemy.dialects.postgresql import JSONB 1ab

20from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy 1ab

21from sqlalchemy.ext.hybrid import hybrid_property 1ab

22from sqlalchemy.orm import Mapped, Mapper, declared_attr, mapped_column, relationship 1ab

23 

24from polar.config import settings 1ab

25from polar.custom_field.data import CustomFieldDataMixin 1ab

26from polar.enums import PaymentProcessor 1ab

27from polar.kit.address import Address, AddressType 1ab

28from polar.kit.db.models import RecordModel 1ab

29from polar.kit.metadata import MetadataColumn, MetadataMixin 1ab

30from polar.kit.tax import TaxID, TaxIDType 1ab

31from polar.kit.trial import TrialConfigurationMixin, TrialInterval 1ab

32from polar.kit.utils import utc_now 1ab

33from polar.product.guard import ( 1ab

34 is_discount_applicable, 

35 is_free_price, 

36 is_metered_price, 

37) 

38 

39from .customer import Customer 1ab

40from .discount import Discount 1ab

41from .organization import Organization 1ab

42from .product import Product 1ab

43from .product_price import ProductPrice, ProductPriceSeatUnit 1ab

44from .subscription import Subscription 1ab

45 

46if TYPE_CHECKING: 46 ↛ 47line 46 didn't jump to line 47 because the condition on line 46 was never true1ab

47 from polar.custom_field.attachment import AttachedCustomFieldMixin 

48 

49 from .checkout_product import CheckoutProduct 

50 

51 

def get_expires_at() -> datetime:
    """Return the expiration timestamp for a checkout created right now.

    The result is the current UTC time plus the lifetime configured in
    ``settings.CHECKOUT_TTL_SECONDS``.
    """
    created_at = utc_now()
    lifetime = timedelta(seconds=settings.CHECKOUT_TTL_SECONDS)
    return created_at + lifetime

54 

55 

56class CheckoutStatus(StrEnum): 1ab

57 open = "open" 1ab

58 expired = "expired" 1ab

59 confirmed = "confirmed" 1ab

60 succeeded = "succeeded" 1ab

61 failed = "failed" 1ab

62 

63 

class CheckoutCustomerBillingAddressFields(TypedDict):
    """
    Deprecated: Use CheckoutBillingAddressFields instead.

    Boolean flag per billing address field.
    """

    country: bool
    state: bool
    city: bool
    postal_code: bool
    line1: bool
    line2: bool

75 

76 

77class BillingAddressFieldMode(StrEnum): 1ab

78 required = "required" 1ab

79 optional = "optional" 1ab

80 disabled = "disabled" 1ab

81 

82 

class CheckoutBillingAddressFields(TypedDict):
    """Mode of every billing address field, keyed by field name."""

    country: BillingAddressFieldMode
    state: BillingAddressFieldMode
    city: BillingAddressFieldMode
    postal_code: BillingAddressFieldMode
    line1: BillingAddressFieldMode
    line2: BillingAddressFieldMode

90 

91 

class Checkout(
    TrialConfigurationMixin, CustomFieldDataMixin, MetadataMixin, RecordModel
):
    """
    ORM model for a checkout, stored in the ``checkouts`` table.

    Every relationship on this model is declared with ``lazy="raise"``, so the
    related objects must be loaded explicitly by the query before any property
    that reads them (``discount_amount``, ``prices``, ``description``, ...) is
    accessed.
    """

    __tablename__ = "checkouts"

    # --- Processor, lifecycle and redirect settings ---
    payment_processor: Mapped[PaymentProcessor] = mapped_column(
        String, nullable=False, default=PaymentProcessor.stripe, index=True
    )
    status: Mapped[CheckoutStatus] = mapped_column(
        String, nullable=False, default=CheckoutStatus.open, index=True
    )
    # Unique per checkout; used to build the frontend URLs below.
    client_secret: Mapped[str] = mapped_column(
        String, index=True, nullable=False, unique=True
    )
    expires_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), index=True, nullable=False, default=get_expires_at
    )
    payment_processor_metadata: Mapped[dict[str, Any]] = mapped_column(
        JSONB, nullable=False, default=dict
    )
    return_url: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    # Raw value of the "success_url" column; read it through `success_url`.
    _success_url: Mapped[str | None] = mapped_column(
        "success_url", String, nullable=True, default=None
    )
    embed_origin: Mapped[str | None] = mapped_column(String, nullable=True)
    allow_discount_codes: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=True
    )
    require_billing_address: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )

    # --- Amounts ---
    amount: Mapped[int] = mapped_column(Integer, nullable=False)
    currency: Mapped[str] = mapped_column(String(3), nullable=False)
    seats: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)

    tax_amount: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
    tax_processor_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )

    # --- Trial ---
    # TODO: proper data migration to make it non-nullable
    allow_trial: Mapped[bool | None] = mapped_column(
        Boolean, nullable=True, default=True
    )
    trial_end: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None
    )

    # --- Organization, product, price ---
    organization_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("organizations.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def organization(cls) -> Mapped["Organization"]:
        return relationship("Organization", lazy="raise")

    product_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("products.id", ondelete="cascade"), nullable=True
    )

    @declared_attr
    def product(cls) -> Mapped[Product | None]:
        return relationship(Product, lazy="raise")

    product_price_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("product_prices.id", ondelete="cascade"), nullable=True
    )

    @declared_attr
    def product_price(cls) -> Mapped[ProductPrice | None]:
        return relationship(ProductPrice, lazy="raise")

    checkout_products: Mapped[list["CheckoutProduct"]] = relationship(
        "CheckoutProduct",
        back_populates="checkout",
        cascade="all, delete-orphan",
        order_by="CheckoutProduct.order",
        lazy="raise",
    )

    # Products reachable through `checkout_products`, in the same order.
    products: AssociationProxy[list["Product"]] = association_proxy(
        "checkout_products", "product"
    )

    # --- Discount ---
    discount_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("discounts.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def discount(cls) -> Mapped[Discount | None]:
        return relationship(Discount, lazy="raise")

    # --- Customer ---
    customer_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("customers.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def customer(cls) -> Mapped[Customer | None]:
        return relationship(Customer, lazy="raise")

    is_business_customer: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )
    external_customer_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    customer_name: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    customer_email: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    # Raw value of the "customer_ip_address" column; the public accessor is
    # the `customer_ip_address` hybrid property.
    _customer_ip_address: Mapped[str | None] = mapped_column(
        "customer_ip_address", String, nullable=True, default=None
    )
    customer_billing_name: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )
    customer_billing_address: Mapped[Address | None] = mapped_column(
        AddressType, nullable=True, default=None
    )
    customer_tax_id: Mapped[TaxID | None] = mapped_column(
        TaxIDType, nullable=True, default=None
    )
    customer_metadata: Mapped[MetadataColumn]

    # Only set when a checkout is attached to an existing subscription (free-to-paid upgrades).
    # For subscriptions created by the checkout itself, see `Subscription.checkout_id`.
    subscription_id: Mapped[UUID | None] = mapped_column(
        Uuid, ForeignKey("subscriptions.id", ondelete="set null"), nullable=True
    )

    @declared_attr
    def subscription(cls) -> Mapped[Subscription | None]:
        return relationship(
            Subscription,
            lazy="raise",
            foreign_keys=[cls.subscription_id],  # type: ignore
        )

    @hybrid_property
    def is_expired(self) -> bool:
        """Whether ``expires_at`` is earlier than the current UTC time."""
        return self.expires_at < utc_now()

    @is_expired.inplace.expression
    @classmethod
    def _is_expired_expression(cls) -> ColumnElement[bool]:
        # SQL-side variant: `utc_now()` is evaluated in Python when the
        # expression is built and compared against the column.
        return cls.expires_at < utc_now()

    @hybrid_property
    def customer_ip_address(self) -> str | None:
        """Customer IP address as stored, or ``None`` when not set."""
        return self._customer_ip_address

    @customer_ip_address.inplace.setter
    def _customer_ip_address_setter(self, value: Any | None) -> None:
        # Any non-None value is stored as its `str()` form.
        self._customer_ip_address = str(value) if value is not None else None

    @property
    def success_url(self) -> str:
        """
        URL to use after a successful checkout.

        Falls back to the frontend confirmation page when no custom URL is
        stored. A custom URL is passed through ``str.format`` with
        ``CHECKOUT_ID`` set to this checkout's ID; if the stored URL cannot be
        formatted, it is returned unchanged.
        """
        if self._success_url is None:
            return settings.generate_frontend_url(
                f"/checkout/{self.client_secret}/confirmation"
            )
        try:
            return self._success_url.format(CHECKOUT_ID=self.id)
        except (KeyError, IndexError, ValueError, AttributeError):
            # KeyError: unknown named placeholder. IndexError: positional
            # placeholder such as "{}". ValueError: unbalanced brace.
            # AttributeError: failed attribute lookup on the placeholder.
            # In every case the URL is not a usable template, so return it raw.
            return self._success_url

    @success_url.setter
    def success_url(self, value: str | None) -> None:
        # Any non-None value is stored as its `str()` form.
        self._success_url = str(value) if value is not None else None

    @property
    def customer_tax_id_number(self) -> str | None:
        """First element of ``customer_tax_id``, or ``None`` when unset."""
        return self.customer_tax_id[0] if self.customer_tax_id is not None else None

    @property
    def discount_amount(self) -> int:
        """Discount computed on ``amount``; 0 when no discount is attached."""
        return self.discount.get_discount_amount(self.amount) if self.discount else 0

    @property
    def net_amount(self) -> int:
        """``amount`` minus ``discount_amount``."""
        return self.amount - self.discount_amount

    @property
    def total_amount(self) -> int:
        """``net_amount`` plus ``tax_amount`` (a missing tax amount counts as 0)."""
        return self.net_amount + (self.tax_amount or 0)

    @property
    def is_discount_applicable(self) -> bool:
        """Whether at least one price of the selected product accepts a discount."""
        if self.product_prices is None:
            return False
        return any(is_discount_applicable(price) for price in self.product_prices)

    @property
    def is_free_product_price(self) -> bool:
        """Whether every price of the selected product is free."""
        if self.product_prices is None:
            return False
        return all(is_free_price(price) for price in self.product_prices)

    @property
    def has_metered_prices(self) -> bool:
        """Whether at least one price of the selected product is metered."""
        if self.product_prices is None:
            return False
        return any(is_metered_price(price) for price in self.product_prices)

    @property
    def is_payment_required(self) -> bool:
        """Whether there is a positive total to pay and no trial end is set."""
        return self.total_amount > 0 and self.trial_end is None

    @property
    def is_payment_setup_required(self) -> bool:
        """Whether the product is recurring and not entirely free."""
        if self.product is None:
            return False
        return self.product.is_recurring and not self.is_free_product_price

    @property
    def should_save_payment_method(self) -> bool:
        """Whether a product is selected and it is recurring."""
        return self.product is not None and self.product.is_recurring

    @property
    def is_payment_form_required(self) -> bool:
        """Whether either a payment or a payment setup is required."""
        return self.is_payment_required or self.is_payment_setup_required

    @property
    def url(self) -> str:
        """Frontend URL of this checkout, built from ``client_secret``."""
        return settings.generate_frontend_url(f"/checkout/{self.client_secret}")

    @property
    def customer_session_token(self) -> str | None:
        """Token held on the instance only (not a mapped column); ``None`` if never set."""
        return getattr(self, "_customer_session_token", None)

    @customer_session_token.setter
    def customer_session_token(self, value: str) -> None:
        self._customer_session_token = value

    # Custom fields attached to the selected product, proxied through `product`.
    attached_custom_fields: AssociationProxy[
        Sequence["AttachedCustomFieldMixin"] | None
    ] = association_proxy("product", "attached_custom_fields")

    @property
    def customer_billing_address_fields(self) -> CheckoutCustomerBillingAddressFields:
        """
        Boolean flag per billing address field (deprecated shape).

        ``country`` is always ``True`` and ``line2`` always ``False``.
        ``state`` is ``True`` for US and CA addresses. The remaining fields are
        ``True`` when a billing address is required, the customer is a
        business, or the address country is US.
        """
        address = self.customer_billing_address
        country = address.country if address else None
        is_us = country == "US"
        require_billing_address = (
            self.require_billing_address or self.is_business_customer or is_us
        )
        return {
            "country": True,
            "state": country in {"US", "CA"},
            "line1": require_billing_address,
            "line2": False,
            "city": require_billing_address,
            "postal_code": require_billing_address,
        }

    @property
    def billing_address_fields(self) -> CheckoutBillingAddressFields:
        """
        Mode of each billing address field.

        ``country`` is always required. A full address is requested when a
        billing address is required, the customer is a business, or the address
        country is US: then ``line1``, ``city`` and ``postal_code`` are
        required and ``line2`` is optional; otherwise all four are disabled.
        ``state`` is required for US and CA addresses, otherwise optional when
        a full address is requested and disabled when it is not.
        """
        required = BillingAddressFieldMode.required
        optional = BillingAddressFieldMode.optional
        disabled = BillingAddressFieldMode.disabled

        address = self.customer_billing_address
        country = address.country if address else None
        is_us = country == "US"
        require_billing_address = (
            self.require_billing_address or self.is_business_customer or is_us
        )

        if country in {"US", "CA"}:
            state = required
        elif require_billing_address:
            state = optional
        else:
            state = disabled

        return {
            "country": required,
            "state": state,
            "line1": required if require_billing_address else disabled,
            "line2": optional if require_billing_address else disabled,
            "city": required if require_billing_address else disabled,
            "postal_code": required if require_billing_address else disabled,
        }

    @property
    def active_trial_interval(self) -> TrialInterval | None:
        """
        Trial interval in effect: the checkout's own value, else the product's.

        ``None`` when trials are not allowed or no product is selected.
        """
        if not self.allow_trial:
            return None
        if self.product is None:
            return None
        return self.trial_interval or self.product.trial_interval

    @property
    def active_trial_interval_count(self) -> int | None:
        """
        Trial interval count in effect: the checkout's own value, else the product's.

        ``None`` when trials are not allowed or no product is selected.
        """
        if not self.allow_trial:
            return None
        if self.product is None:
            return None
        return self.trial_interval_count or self.product.trial_interval_count

    @property
    def price_per_seat(self) -> int | None:
        """
        Per-seat price for the current seat count.

        ``None`` unless the selected price is a ``ProductPriceSeatUnit`` and
        ``seats`` is set.
        """
        if not isinstance(self.product_price, ProductPriceSeatUnit):
            return None

        if self.seats is None:
            return None

        return self.product_price.get_price_per_seat(self.seats)

    @property
    def description(self) -> str:
        """
        Description built from the organization name and the product name.

        Raises:
            NotImplementedError: when no product is selected.
        """
        if self.product is not None:
            # NOTE(review): the two names are joined without any separator;
            # confirm this is intended.
            return f"{self.organization.name}{self.product.name}"
        raise NotImplementedError()

    @property
    def prices(self) -> dict[uuid.UUID, list[ProductPrice]]:
        """
        Prices available for each product of the checkout, keyed by product ID.

        A product's ad-hoc prices take precedence; its catalog prices are used
        only when it has no ad-hoc price.
        """
        return {
            checkout_product.product_id: (
                checkout_product.ad_hoc_prices or checkout_product.product.prices
            )
            for checkout_product in self.checkout_products
        }

    @property
    def product_prices(self) -> list[ProductPrice] | None:
        """
        Prices of the currently selected product, or ``None`` when none is selected.

        Raises:
            KeyError: when ``product_id`` is not among ``checkout_products``.
        """
        if self.product_id is None:
            return None
        return self.prices[self.product_id]

431 

432 

@event.listens_for(Checkout, "before_update")
def check_expiration(
    mapper: Mapper[Any], connection: Connection, target: Checkout
) -> None:
    """Mark a checkout as expired just before it is updated.

    Registered as a ``before_update`` listener on ``Checkout``. The status is
    switched to ``expired`` only when ``expires_at`` is earlier than the
    current UTC time and the status is still ``open``; any other status is
    left untouched.
    """
    deadline_passed = target.expires_at < utc_now()
    if deadline_passed and target.status == CheckoutStatus.open:
        target.status = CheckoutStatus.expired