Coverage for polar/discount/service.py: 16%

188 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import contextlib 1a

2import uuid 1a

3from collections.abc import AsyncIterator, Sequence 1a

4from typing import Any 1a

5 

6from sqlalchemy import Select, UnaryExpression, asc, delete, desc, func, or_, select 1a

7from sqlalchemy.orm import joinedload 1a

8 

9from polar.auth.models import AuthSubject, is_organization, is_user 1a

10from polar.exceptions import PolarError, PolarRequestValidationError 1a

11from polar.integrations.stripe.service import stripe as stripe_service 1a

12from polar.kit.pagination import PaginationParams, paginate 1a

13from polar.kit.services import ResourceServiceReader 1a

14from polar.kit.sorting import Sorting 1a

15from polar.kit.utils import utc_now 1a

16from polar.locker import Locker 1a

17from polar.models import ( 1a

18 Discount, 

19 DiscountProduct, 

20 Organization, 

21 Product, 

22 User, 

23 UserOrganization, 

24) 

25from polar.models.checkout import Checkout 1a

26from polar.models.discount import DiscountFixed 1a

27from polar.models.discount_redemption import DiscountRedemption 1a

28from polar.organization.resolver import get_payload_organization 1a

29from polar.postgres import AsyncSession 1a

30from polar.product.repository import ProductRepository 1a

31 

32from .schemas import DiscountCreate, DiscountUpdate 1a

33from .sorting import DiscountSortProperty 1a

34 

35 

class DiscountError(PolarError):
    """Base exception for discount-related errors in this module."""

37 

38 

class DiscountNotRedeemableError(DiscountError):
    """Error raised when a discount cannot be redeemed."""

    def __init__(self, discount: Discount):
        # The message embeds the discount identifier.
        message = f"Discount {discount.id} is not redeemable."
        super().__init__(message)

42 

43 

44class DiscountService(ResourceServiceReader[Discount]): 1a

async def list(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    *,
    organization_id: Sequence[uuid.UUID] | None = None,
    query: str | None = None,
    pagination: PaginationParams,
    sorting: list[Sorting[DiscountSortProperty]] | None = None,
) -> tuple[Sequence[Discount], int]:
    """Return one page of the discounts readable by ``auth_subject``.

    Args:
        session: Database session used to run the paginated query.
        auth_subject: User or organization the results are restricted to.
        organization_id: When given, keep only discounts belonging to one
            of these organizations.
        query: When given, keep only discounts whose name or code contains
            this text (case-insensitive for both columns).
        pagination: Pagination parameters forwarded to ``paginate``.
        sorting: ``(criterion, is_desc)`` pairs applied in order. Defaults
            to ``created_at`` descending when omitted.

    Returns:
        Whatever ``paginate`` returns for the built statement (annotated as
        a sequence of discounts and an integer).
    """
    # Build the default per call rather than sharing one mutable list
    # object across every invocation.
    if sorting is None:
        sorting = [(DiscountSortProperty.created_at, True)]

    statement = self._get_readable_discount_statement(auth_subject)

    if organization_id is not None:
        statement = statement.where(Discount.organization_id.in_(organization_id))

    if query is not None:
        pattern = f"%{query}%"
        # Both columns use ILIKE so that searching behaves the same way
        # whether the text matches the name or the code.
        statement = statement.where(
            or_(
                Discount.name.ilike(pattern),
                Discount.code.ilike(pattern),
            )
        )

    order_by_clauses: list[UnaryExpression[Any]] = []
    for criterion, is_desc in sorting:
        clause_function = desc if is_desc else asc
        if criterion == DiscountSortProperty.created_at:
            order_by_clauses.append(clause_function(Discount.created_at))
        elif criterion == DiscountSortProperty.discount_name:
            order_by_clauses.append(clause_function(Discount.name))
        elif criterion == DiscountSortProperty.code:
            order_by_clauses.append(clause_function(Discount.code))
        elif criterion == DiscountSortProperty.redemptions_count:
            order_by_clauses.append(clause_function(Discount.redemptions_count))
    statement = statement.order_by(*order_by_clauses)

    return await paginate(session, statement, pagination=pagination)

84 

async def get_by_id(
    self,
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    id: uuid.UUID,
) -> Discount | None:
    """Fetch a readable discount by ID, with its organization joined-loaded.

    Returns ``None`` when no discount readable by ``auth_subject`` has
    this ID.
    """
    readable = self._get_readable_discount_statement(auth_subject)
    statement = readable.where(Discount.id == id).options(
        joinedload(Discount.organization)
    )
    return (await session.execute(statement)).scalar_one_or_none()

98 

async def create(
    self,
    session: AsyncSession,
    discount_create: DiscountCreate,
    auth_subject: AuthSubject[User | Organization],
) -> Discount:
    """Create a discount together with its Stripe coupon.

    The organization is resolved from the payload and the auth subject.
    The new discount is added to the session; this method does not flush
    or commit.

    Raises:
        PolarRequestValidationError: If another non-deleted discount of
            the organization already uses the code, or if one of the
            listed products is not found for the organization.
    """
    organization = await get_payload_organization(
        session, auth_subject, discount_create
    )

    # Codes must be unique per organization, including discounts that are
    # not currently redeemable (hence ``redeemable=False``).
    code = discount_create.code
    if code is not None:
        duplicate = await self.get_by_code_and_organization(
            session, code, organization, redeemable=False
        )
        if duplicate is not None:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "code"),
                        "msg": "Discount with this code already exists.",
                        "input": code,
                    }
                ]
            )

    # Resolve each requested product, failing on the first unknown one.
    discount_products: list[DiscountProduct] = []
    if discount_create.products:
        product_repository = ProductRepository.from_session(session)
        for position, product_id in enumerate(discount_create.products):
            found = await product_repository.get_by_id_and_organization(
                product_id, organization.id
            )
            if found is None:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "products", position),
                            "msg": "Product not found.",
                            "input": product_id,
                        }
                    ]
                )
            discount_products.append(DiscountProduct(product=found))

    model_class = discount_create.type.get_model()
    attributes = discount_create.model_dump(
        exclude={"organization_id", "products"}, by_alias=True
    )
    discount = model_class(
        **attributes,
        id=uuid.uuid4(),
        organization=organization,
        discount_products=discount_products,
        discount_redemptions=[],
        redemptions_count=0,
    )

    # The coupon is created on Stripe before the discount joins the session.
    coupon_params = discount.get_stripe_coupon_params()
    stripe_coupon = await stripe_service.create_coupon(**coupon_params)
    discount.stripe_coupon_id = stripe_coupon.id

    session.add(discount)

    return discount

165 

async def update(
    self,
    session: AsyncSession,
    discount: Discount,
    discount_update: DiscountUpdate,
) -> Discount:
    """Apply an update payload to a discount and sync its Stripe coupon.

    Rules enforced before anything is applied:

    * ``duration`` and ``type`` cannot be changed.
    * Once ``redemptions_count`` is positive, ``amount``/``currency``
      (fixed discounts) or ``basis_points`` (other discounts) cannot be
      changed.

    When ``products`` is provided, the product links are replaced inside a
    nested transaction that is rolled back if a product is not found.

    If any of the start/end date, max redemptions, duration in months or
    value fields changed, a new Stripe coupon is created and the previous
    one deleted. If only the name changed, the existing coupon is renamed.

    Raises:
        PolarRequestValidationError: If one of the rules above is violated,
            a product is not found, or a coupon-affecting change is made
            while ``ends_at`` is in the past.
    """
    if (
        discount_update.duration is not None
        and discount_update.duration != discount.duration
    ):
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "duration"),
                    "msg": "Duration cannot be changed.",
                    "input": discount_update.duration,
                }
            ]
        )

    if discount_update.type is not None and discount_update.type != discount.type:
        raise PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": ("body", "type"),
                    "msg": "Type cannot be changed.",
                    "input": discount_update.type,
                }
            ]
        )

    if discount.redemptions_count > 0:
        forbidden_fields = (
            {"amount", "currency"}
            if isinstance(discount, DiscountFixed)
            else {"basis_points"}
        )
        for field in forbidden_fields:
            discount_update_value = getattr(discount_update, field, None)
            if (
                discount_update_value is not None
                and discount_update_value != getattr(discount, field, None)
            ):
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", field),
                            "msg": (
                                "This field cannot be changed because "
                                "the discount has already been redeemed."
                            ),
                            # Report the rejected value from the request,
                            # as every other validation error here does,
                            # not the value currently stored.
                            "input": discount_update_value,
                        }
                    ]
                )

    if discount_update.products is not None:
        # The savepoint lets the product replacement be undone if one of
        # the requested products turns out to be unknown.
        nested = await session.begin_nested()
        discount.discount_products = []
        await session.flush()

        product_repository = ProductRepository.from_session(session)
        for index, product_id in enumerate(discount_update.products):
            product = await product_repository.get_by_id_and_organization(
                product_id, discount.organization_id
            )
            if product is None:
                await nested.rollback()
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "products", index),
                            "msg": "Product not found.",
                            "input": product_id,
                        }
                    ]
                )
            discount.discount_products.append(DiscountProduct(product=product))

    # Copy over explicitly-set fields, skipping products (handled above)
    # and the value fields that belong to the other discount kind.
    updated_fields = set()
    exclude = {"products"}
    if isinstance(discount, DiscountFixed):
        exclude.add("basis_points")
    else:
        exclude.add("amount")
        exclude.add("currency")
    for attr, value in discount_update.model_dump(
        exclude_unset=True, exclude=exclude, by_alias=True
    ).items():
        if value != getattr(discount, attr):
            setattr(discount, attr, value)
            updated_fields.add(attr)

    # A change to any of these fields replaces the Stripe coupon.
    sensitive_fields = {
        "starts_at",
        "ends_at",
        "max_redemptions",
        "duration_in_months",
        *(
            {"amount", "currency"}
            if isinstance(discount, DiscountFixed)
            else {"basis_points"}
        ),
    }
    if sensitive_fields.intersection(updated_fields):
        if discount.ends_at is not None and discount.ends_at < utc_now():
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "ends_at"),
                        "msg": "Ends at must be in the future.",
                        "input": discount.ends_at,
                    }
                ]
            )
        # Create the replacement first so the discount never points at a
        # coupon that has already been deleted.
        new_stripe_coupon = await stripe_service.create_coupon(
            **discount.get_stripe_coupon_params()
        )
        await stripe_service.delete_coupon(discount.stripe_coupon_id)
        discount.stripe_coupon_id = new_stripe_coupon.id
    elif "name" in updated_fields:
        await stripe_service.update_coupon(
            discount.stripe_coupon_id,
            name=discount.name[:40],  # Stripe coupon name max length is 40
        )

    session.add(discount)
    await session.flush()
    await session.refresh(discount)

    return discount

302 

async def delete(self, session: AsyncSession, discount: Discount) -> Discount:
    """Mark the discount as deleted and delete its Stripe coupon.

    The discount is flagged through ``set_deleted_at()`` before the Stripe
    call, then added to the session and returned.
    """
    discount.set_deleted_at()

    coupon_id = discount.stripe_coupon_id
    await stripe_service.delete_coupon(coupon_id)

    session.add(discount)
    return discount

310 

async def get_by_id_and_organization(
    self,
    session: AsyncSession,
    id: uuid.UUID,
    organization: Organization,
    *,
    products: Sequence[Product] | None = None,
    redeemable: bool = True,
) -> Discount | None:
    """Look up a non-deleted discount of an organization by its ID.

    Returns ``None`` when no such discount exists, when it is not
    applicable to every product in ``products`` (if given), or when
    ``redeemable`` is true and the discount is not currently redeemable.
    """
    lookup = select(Discount).where(
        Discount.id == id,
        Discount.organization_id == organization.id,
        Discount.deleted_at.is_(None),
    )
    found = (await session.execute(lookup)).scalar_one_or_none()
    if found is None:
        return None

    # Every requested product must be covered by the discount.
    if products is not None and not all(
        found.is_applicable(product) for product in products
    ):
        return None

    if redeemable:
        can_redeem = await self.is_redeemable_discount(session, found)
        if not can_redeem:
            return None

    return found

340 

async def get_by_code_and_organization(
    self,
    session: AsyncSession,
    code: str,
    organization: Organization,
    *,
    redeemable: bool = True,
) -> Discount | None:
    """Look up a non-deleted discount of an organization by its code.

    The code is compared after upper-casing both sides. Returns ``None``
    when nothing matches, or when ``redeemable`` is true and the discount
    is not currently redeemable.
    """
    normalized_code = code.upper()
    lookup = select(Discount).where(
        func.upper(Discount.code) == normalized_code,
        Discount.organization_id == organization.id,
        Discount.deleted_at.is_(None),
    )
    found = (await session.execute(lookup)).scalar_one_or_none()
    if found is None:
        return None

    if redeemable:
        can_redeem = await self.is_redeemable_discount(session, found)
        if not can_redeem:
            return None

    return found

364 

async def get_by_code_and_product(
    self,
    session: AsyncSession,
    code: str,
    organization: Organization,
    product: Product,
    *,
    redeemable: bool = True,
) -> Discount | None:
    """Look up a discount by code and check that it covers ``product``.

    A discount with an empty ``products`` collection is accepted for any
    product; otherwise ``product`` must be one of them. Returns ``None``
    when the code lookup fails or the product is not covered.
    """
    found = await self.get_by_code_and_organization(
        session, code, organization, redeemable=redeemable
    )
    if found is None:
        return None

    if len(found.products) == 0 or product in found.products:
        return found

    return None

385 

async def get_by_stripe_coupon_id(
    self, session: AsyncSession, stripe_coupon_id: str
) -> Discount | None:
    """Return the discount with this Stripe coupon ID, or ``None``.

    No filter on deletion state is applied by this query.
    """
    lookup = select(Discount).where(Discount.stripe_coupon_id == stripe_coupon_id)
    return (await session.execute(lookup)).scalar_one_or_none()

394 

async def is_redeemable_discount(
    self, session: AsyncSession, discount: Discount
) -> bool:
    """Tell whether the discount can be redeemed right now.

    A discount is not redeemable when it has not started yet, when it has
    already ended, or when the number of stored redemptions has reached
    ``max_redemptions``.
    """
    starts_at = discount.starts_at
    if starts_at is not None and starts_at > utc_now():
        return False

    ends_at = discount.ends_at
    if ends_at is not None and ends_at < utc_now():
        return False

    limit = discount.max_redemptions
    if limit is None:
        return True

    # Count redemption rows directly instead of relying on a cached value.
    count_statement = select(func.count(DiscountRedemption.id)).where(
        DiscountRedemption.discount_id == discount.id
    )
    used = (await session.execute(count_statement)).scalar_one()
    return used < limit

413 

@contextlib.asynccontextmanager
async def redeem_discount(
    self, session: AsyncSession, locker: Locker, discount: Discount
) -> AsyncIterator[DiscountRedemption]:
    """Yield a new redemption for the discount while holding its lock.

    The redemption is only added to the session and flushed after the
    caller's ``async with`` body finishes without raising.

    Raises:
        DiscountNotRedeemableError: If the discount is not redeemable once
            the lock has been acquired.
    """
    # The timeout is purposely set to 10 seconds, a high value.
    # We've seen in the past Stripe payment requests taking more than 5 seconds,
    # causing the lock to expire while waiting for the payment to complete.
    lock_name = f"discount:{discount.id}"
    async with locker.lock(lock_name, timeout=10, blocking_timeout=10):
        can_redeem = await self.is_redeemable_discount(session, discount)
        if not can_redeem:
            raise DiscountNotRedeemableError(discount)

        discount_redemption = DiscountRedemption(discount=discount)

        yield discount_redemption

        session.add(discount_redemption)
        await session.flush()
        # Reload the counter attribute now that a redemption was flushed.
        await session.refresh(discount, {"redemptions_count"})

434 

async def remove_checkout_redemption(
    self, session: AsyncSession, checkout: Checkout
) -> None:
    """Delete the discount redemptions attached to the given checkout."""
    removal = delete(DiscountRedemption).where(
        DiscountRedemption.checkout_id == checkout.id
    )
    await session.execute(removal)

442 

def _get_readable_discount_statement(
    self, auth_subject: AuthSubject[User | Organization]
) -> Select[tuple[Discount]]:
    """Build the base SELECT of non-deleted discounts the subject may read.

    For a user, discounts are limited to organizations the user has a
    non-deleted membership in. For an organization, they are limited to
    that organization's own discounts.
    """
    base = select(Discount).where(Discount.deleted_at.is_(None))

    if is_user(auth_subject):
        user = auth_subject.subject
        member_organization_ids = select(UserOrganization.organization_id).where(
            UserOrganization.user_id == user.id,
            UserOrganization.deleted_at.is_(None),
        )
        return base.where(Discount.organization_id.in_(member_organization_ids))

    if is_organization(auth_subject):
        return base.where(
            Discount.organization_id == auth_subject.subject.id,
        )

    return base

464 

465 

# Module-level service instance, constructed with the Discount model.
discount = DiscountService(Discount)