Coverage for polar/discount/service.py: 16%
188 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import contextlib 1a
2import uuid 1a
3from collections.abc import AsyncIterator, Sequence 1a
4from typing import Any 1a
6from sqlalchemy import Select, UnaryExpression, asc, delete, desc, func, or_, select 1a
7from sqlalchemy.orm import joinedload 1a
9from polar.auth.models import AuthSubject, is_organization, is_user 1a
10from polar.exceptions import PolarError, PolarRequestValidationError 1a
11from polar.integrations.stripe.service import stripe as stripe_service 1a
12from polar.kit.pagination import PaginationParams, paginate 1a
13from polar.kit.services import ResourceServiceReader 1a
14from polar.kit.sorting import Sorting 1a
15from polar.kit.utils import utc_now 1a
16from polar.locker import Locker 1a
17from polar.models import ( 1a
18 Discount,
19 DiscountProduct,
20 Organization,
21 Product,
22 User,
23 UserOrganization,
24)
25from polar.models.checkout import Checkout 1a
26from polar.models.discount import DiscountFixed 1a
27from polar.models.discount_redemption import DiscountRedemption 1a
28from polar.organization.resolver import get_payload_organization 1a
29from polar.postgres import AsyncSession 1a
30from polar.product.repository import ProductRepository 1a
32from .schemas import DiscountCreate, DiscountUpdate 1a
33from .sorting import DiscountSortProperty 1a
class DiscountError(PolarError):
    """Base class for all errors raised by the discount service."""
class DiscountNotRedeemableError(DiscountError):
    """Raised when a discount fails the redeemability check at redemption time."""

    def __init__(self, discount: Discount):
        message = f"Discount {discount.id} is not redeemable."
        super().__init__(message)
class DiscountService(ResourceServiceReader[Discount]):
    """Service layer for discounts.

    Lists and reads the discounts visible to an auth subject. Creates, updates
    and soft-deletes discounts while keeping the mirrored Stripe coupon
    (``Discount.stripe_coupon_id``) in sync. Looks discounts up by ID, code or
    Stripe coupon, and records redemptions.
    """

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        query: str | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[DiscountSortProperty]] | None = None,
    ) -> tuple[Sequence[Discount], int]:
        """List the non-deleted discounts readable by ``auth_subject``.

        :param organization_id: If given, restrict to these organizations.
        :param query: If given, case-insensitive substring match on the
            discount name or code.
        :param pagination: Forwarded to ``paginate``.
        :param sorting: ``(property, is_descending)`` pairs. Defaults to
            ``created_at`` descending (newest first).
        :return: Whatever ``paginate`` returns: the page of discounts and an
            int (presumably the total count; confirm in polar.kit.pagination).
        """
        if sorting is None:
            # A fresh list per call instead of a shared mutable default.
            sorting = [(DiscountSortProperty.created_at, True)]

        statement = self._get_readable_discount_statement(auth_subject)

        if organization_id is not None:
            statement = statement.where(Discount.organization_id.in_(organization_id))

        if query is not None:
            # Match both columns case-insensitively so the search behaves the
            # same whether the user typed part of the name or of the code.
            statement = statement.where(
                or_(
                    Discount.name.ilike(f"%{query}%"),
                    Discount.code.ilike(f"%{query}%"),
                )
            )

        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            clause_function = desc if is_desc else asc
            if criterion == DiscountSortProperty.created_at:
                order_by_clauses.append(clause_function(Discount.created_at))
            elif criterion == DiscountSortProperty.discount_name:
                order_by_clauses.append(clause_function(Discount.name))
            elif criterion == DiscountSortProperty.code:
                order_by_clauses.append(clause_function(Discount.code))
            elif criterion == DiscountSortProperty.redemptions_count:
                order_by_clauses.append(clause_function(Discount.redemptions_count))
        statement = statement.order_by(*order_by_clauses)

        return await paginate(session, statement, pagination=pagination)

    async def get_by_id(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Discount | None:
        """Return the readable, non-deleted discount with ``id``, or ``None``.

        The discount's organization is eagerly loaded with a joined load.
        """
        statement = (
            self._get_readable_discount_statement(auth_subject)
            .where(Discount.id == id)
            .options(joinedload(Discount.organization))
        )
        result = await session.execute(statement)
        return result.scalar_one_or_none()

    async def create(
        self,
        session: AsyncSession,
        discount_create: DiscountCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Discount:
        """Create a discount and its Stripe coupon, and add it to the session.

        The discount is added but not flushed.

        :raises PolarRequestValidationError: If the code is already used by a
            non-deleted discount of the organization (case-insensitive), or if
            a product does not belong to the organization.
        """
        organization = await get_payload_organization(
            session, auth_subject, discount_create
        )

        if discount_create.code is not None:
            existing_discount = await self.get_by_code_and_organization(
                session, discount_create.code, organization, redeemable=False
            )
            if existing_discount is not None:
                raise self._value_error(
                    ("body", "code"),
                    "Discount with this code already exists.",
                    discount_create.code,
                )

        discount_products: list[DiscountProduct] = []
        if discount_create.products:
            product_repository = ProductRepository.from_session(session)
            for index, product_id in enumerate(discount_create.products):
                product = await product_repository.get_by_id_and_organization(
                    product_id, organization.id
                )
                if product is None:
                    raise self._value_error(
                        ("body", "products", index), "Product not found.", product_id
                    )
                discount_products.append(DiscountProduct(product=product))

        # The concrete Discount subclass (e.g. fixed vs percentage) depends on
        # the requested type.
        discount_model = discount_create.type.get_model()
        discount = discount_model(
            **discount_create.model_dump(
                exclude={"organization_id", "products"}, by_alias=True
            ),
            id=uuid.uuid4(),
            organization=organization,
            discount_products=discount_products,
            discount_redemptions=[],
            redemptions_count=0,
        )
        stripe_coupon = await stripe_service.create_coupon(
            **discount.get_stripe_coupon_params()
        )
        discount.stripe_coupon_id = stripe_coupon.id

        session.add(discount)

        return discount

    async def update(
        self,
        session: AsyncSession,
        discount: Discount,
        discount_update: DiscountUpdate,
    ) -> Discount:
        """Apply ``discount_update`` to ``discount`` and sync its Stripe coupon.

        If a field that feeds the Stripe coupon changes (dates, max
        redemptions, duration in months, amount fields), a new coupon is
        created and the old one deleted. A name-only change updates the
        existing coupon in place.

        :raises PolarRequestValidationError: If ``duration`` or ``type`` would
            change. If the discount has already been redeemed and one of its
            amount fields would change. If a product does not belong to the
            discount's organization. If a Stripe-relevant field changed while
            ``ends_at`` is in the past.
        """
        if (
            discount_update.duration is not None
            and discount_update.duration != discount.duration
        ):
            raise self._value_error(
                ("body", "duration"),
                "Duration cannot be changed.",
                discount_update.duration,
            )

        if discount_update.type is not None and discount_update.type != discount.type:
            raise self._value_error(
                ("body", "type"), "Type cannot be changed.", discount_update.type
            )

        # Fields holding the discount value. Which ones exist depends on the
        # discount subtype.
        amount_fields = (
            {"amount", "currency"}
            if isinstance(discount, DiscountFixed)
            else {"basis_points"}
        )

        if discount.redemptions_count > 0:
            for field in amount_fields:
                new_value = getattr(discount_update, field, None)
                if new_value is not None and new_value != getattr(
                    discount, field, None
                ):
                    raise self._value_error(
                        ("body", field),
                        (
                            "This field cannot be changed because "
                            "the discount has already been redeemed."
                        ),
                        # Report the value the client submitted, not the
                        # stored one.
                        new_value,
                    )

        if discount_update.products is not None:
            # Replace the product links inside a SAVEPOINT so that an unknown
            # product rolls back the clearing of the existing links.
            nested = await session.begin_nested()
            discount.discount_products = []
            await session.flush()

            product_repository = ProductRepository.from_session(session)
            for index, product_id in enumerate(discount_update.products):
                product = await product_repository.get_by_id_and_organization(
                    product_id, discount.organization_id
                )
                if product is None:
                    await nested.rollback()
                    raise self._value_error(
                        ("body", "products", index), "Product not found.", product_id
                    )
                discount.discount_products.append(DiscountProduct(product=product))

        # Amount fields of the *other* subtype do not exist on this model.
        exclude = {"products"}
        if isinstance(discount, DiscountFixed):
            exclude.add("basis_points")
        else:
            exclude.update({"amount", "currency"})

        updated_fields: set[str] = set()
        for attr, value in discount_update.model_dump(
            exclude_unset=True, exclude=exclude, by_alias=True
        ).items():
            if value != getattr(discount, attr):
                setattr(discount, attr, value)
                updated_fields.add(attr)

        # When any of these change, the Stripe coupon is replaced instead of
        # being updated in place.
        coupon_recreate_fields = {
            "starts_at",
            "ends_at",
            "max_redemptions",
            "duration_in_months",
            *amount_fields,
        }
        if coupon_recreate_fields & updated_fields:
            if discount.ends_at is not None and discount.ends_at < utc_now():
                raise self._value_error(
                    ("body", "ends_at"),
                    "Ends at must be in the future.",
                    discount.ends_at,
                )
            # Create the replacement before deleting the old coupon.
            new_stripe_coupon = await stripe_service.create_coupon(
                **discount.get_stripe_coupon_params()
            )
            await stripe_service.delete_coupon(discount.stripe_coupon_id)
            discount.stripe_coupon_id = new_stripe_coupon.id
        elif "name" in updated_fields:
            await stripe_service.update_coupon(
                discount.stripe_coupon_id,
                name=discount.name[:40],  # Stripe coupon name max length is 40
            )

        session.add(discount)
        await session.flush()
        await session.refresh(discount)

        return discount

    async def delete(self, session: AsyncSession, discount: Discount) -> Discount:
        """Soft-delete ``discount`` and delete its Stripe coupon.

        The discount is marked via ``set_deleted_at()`` and added to the
        session, but not flushed.
        """
        discount.set_deleted_at()

        await stripe_service.delete_coupon(discount.stripe_coupon_id)

        session.add(discount)
        return discount

    async def get_by_id_and_organization(
        self,
        session: AsyncSession,
        id: uuid.UUID,
        organization: Organization,
        *,
        products: Sequence[Product] | None = None,
        redeemable: bool = True,
    ) -> Discount | None:
        """Return the non-deleted discount ``id`` of ``organization``.

        :param products: If given, return ``None`` unless the discount is
            applicable to every one of them.
        :param redeemable: If true, return ``None`` when the discount is not
            currently redeemable.
        """
        statement = select(Discount).where(
            Discount.id == id,
            Discount.organization_id == organization.id,
            Discount.deleted_at.is_(None),
        )
        result = await session.execute(statement)
        discount = result.scalar_one_or_none()

        if discount is None:
            return None

        if products is not None and not all(
            discount.is_applicable(product) for product in products
        ):
            return None

        if redeemable and not await self.is_redeemable_discount(session, discount):
            return None

        return discount

    async def get_by_code_and_organization(
        self,
        session: AsyncSession,
        code: str,
        organization: Organization,
        *,
        redeemable: bool = True,
    ) -> Discount | None:
        """Return the non-deleted discount of ``organization`` with ``code``.

        The code is compared case-insensitively (both sides upper-cased).

        :param redeemable: If true, return ``None`` when the discount is not
            currently redeemable.
        """
        statement = select(Discount).where(
            func.upper(Discount.code) == code.upper(),
            Discount.organization_id == organization.id,
            Discount.deleted_at.is_(None),
        )
        result = await session.execute(statement)
        discount = result.scalar_one_or_none()

        if discount is None:
            return None

        if redeemable and not await self.is_redeemable_discount(session, discount):
            return None

        return discount

    async def get_by_code_and_product(
        self,
        session: AsyncSession,
        code: str,
        organization: Organization,
        product: Product,
        *,
        redeemable: bool = True,
    ) -> Discount | None:
        """Like ``get_by_code_and_organization``, restricted to ``product``.

        A discount with no linked products matches any product. Otherwise
        ``product`` must be among its products.
        """
        discount = await self.get_by_code_and_organization(
            session, code, organization, redeemable=redeemable
        )

        if discount is None:
            return None

        if len(discount.products) > 0 and product not in discount.products:
            return None

        return discount

    async def get_by_stripe_coupon_id(
        self, session: AsyncSession, stripe_coupon_id: str
    ) -> Discount | None:
        """Return the discount mirrored by ``stripe_coupon_id``, or ``None``.

        Unlike the other lookups, this does not exclude soft-deleted discounts.
        """
        statement = select(Discount).where(
            Discount.stripe_coupon_id == stripe_coupon_id
        )
        result = await session.execute(statement)
        return result.scalar_one_or_none()

    async def is_redeemable_discount(
        self, session: AsyncSession, discount: Discount
    ) -> bool:
        """Return whether ``discount`` can be redeemed right now.

        The discount must be within its ``starts_at``/``ends_at`` window.
        If ``max_redemptions`` is set, the number of stored redemptions
        (counted in the database) must be below it.
        """
        now = utc_now()
        if discount.starts_at is not None and discount.starts_at > now:
            return False

        if discount.ends_at is not None and discount.ends_at < now:
            return False

        if discount.max_redemptions is not None:
            statement = select(func.count(DiscountRedemption.id)).where(
                DiscountRedemption.discount_id == discount.id
            )
            result = await session.execute(statement)
            redemptions_count = result.scalar_one()
            return redemptions_count < discount.max_redemptions

        return True

    @contextlib.asynccontextmanager
    async def redeem_discount(
        self, session: AsyncSession, locker: Locker, discount: Discount
    ) -> AsyncIterator[DiscountRedemption]:
        """Redeem ``discount`` under a per-discount lock.

        Yields a new, not-yet-persisted ``DiscountRedemption``. If the ``with``
        body completes without raising, the redemption is added and flushed,
        and ``discount.redemptions_count`` is refreshed. If the body raises,
        nothing is persisted.

        :raises DiscountNotRedeemableError: If the discount is not redeemable
            once the lock is held.
        """
        # The timeout is purposely set to 10 seconds, a high value.
        # We've seen in the past Stripe payment requests taking more than 5 seconds,
        # causing the lock to expire while waiting for the payment to complete.
        async with locker.lock(
            f"discount:{discount.id}", timeout=10, blocking_timeout=10
        ):
            # Checked under the lock so concurrent redemptions cannot both
            # pass a max_redemptions limit.
            if not await self.is_redeemable_discount(session, discount):
                raise DiscountNotRedeemableError(discount)

            discount_redemption = DiscountRedemption(discount=discount)

            yield discount_redemption

            session.add(discount_redemption)
            await session.flush()
            await session.refresh(discount, {"redemptions_count"})

    async def remove_checkout_redemption(
        self, session: AsyncSession, checkout: Checkout
    ) -> None:
        """Delete every discount redemption linked to ``checkout``."""
        statement = delete(DiscountRedemption).where(
            DiscountRedemption.checkout_id == checkout.id
        )
        await session.execute(statement)

    @staticmethod
    def _value_error(
        loc: tuple[str | int, ...], msg: str, input_value: Any
    ) -> PolarRequestValidationError:
        """Build a single ``value_error`` request-validation error."""
        return PolarRequestValidationError(
            [
                {
                    "type": "value_error",
                    "loc": loc,
                    "msg": msg,
                    "input": input_value,
                }
            ]
        )

    def _get_readable_discount_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[Discount]]:
        """Build a SELECT of the non-deleted discounts ``auth_subject`` may read.

        A user sees the discounts of organizations they are an active member
        of. An organization sees only its own discounts.
        """
        statement = select(Discount).where(Discount.deleted_at.is_(None))

        if is_user(auth_subject):
            user = auth_subject.subject
            statement = statement.where(
                Discount.organization_id.in_(
                    select(UserOrganization.organization_id).where(
                        UserOrganization.user_id == user.id,
                        UserOrganization.deleted_at.is_(None),
                    )
                )
            )
        elif is_organization(auth_subject):
            statement = statement.where(
                Discount.organization_id == auth_subject.subject.id,
            )

        return statement
# Shared, module-level instance of the discount service.
discount: DiscountService = DiscountService(Discount)