Coverage for polar/subscription/repository.py: 28%
102 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1from collections.abc import Sequence 1a
2from dataclasses import dataclass 1a
3from typing import TYPE_CHECKING 1a
4from uuid import UUID 1a
6from sqlalchemy import Select, case, or_, select 1a
7from sqlalchemy.orm import contains_eager 1a
8from sqlalchemy.orm.strategy_options import joinedload, selectinload 1a
10from polar.auth.models import ( 1a
11 AuthSubject,
12 Organization,
13 User,
14 is_customer,
15 is_organization,
16 is_user,
17)
18from polar.auth.models import ( 1a
19 Customer as AuthCustomer,
20)
21from polar.enums import SubscriptionRecurringInterval 1a
22from polar.kit.repository import ( 1a
23 Options,
24 RepositoryBase,
25 RepositorySoftDeletionIDMixin,
26 RepositorySoftDeletionMixin,
27 RepositorySortingMixin,
28 SortingClause,
29)
30from polar.models import ( 1a
31 Customer,
32 CustomerSeat,
33 Discount,
34 Product,
35 ProductPrice,
36 ProductPriceMeteredUnit,
37 Subscription,
38 SubscriptionMeter,
39 SubscriptionProductPrice,
40 UserOrganization,
41)
42from polar.models.customer_seat import SeatStatus 1a
43from polar.models.subscription import SubscriptionStatus 1a
44from polar.product.guard import is_metered_price 1a
46from .sorting import SubscriptionSortProperty 1a
48if TYPE_CHECKING: 48 ↛ 49line 48 didn't jump to line 49 because the condition on line 48 was never true1a
49 from sqlalchemy.orm.strategy_options import _AbstractLoad
@dataclass
class CustomerSubscriptionProductPrice:
    """
    Outcome of resolving a customer's subscription product price for a meter.

    Pairs the ID of the customer who pays with the matching subscription
    product price. When the queried customer is a seat holder, the paying
    customer may be a different customer than the one that was looked up.

    The paying customer's full object is reachable through
    subscription_product_price.subscription.customer.
    """

    # ID of the paying customer (not necessarily the customer that was queried).
    customer_id: UUID
    # Subscription product price associated with the meter lookup.
    subscription_product_price: SubscriptionProductPrice
67class SubscriptionRepository( 1a
68 RepositorySortingMixin[Subscription, SubscriptionSortProperty],
69 RepositorySoftDeletionIDMixin[Subscription, UUID],
70 RepositorySoftDeletionMixin[Subscription],
71 RepositoryBase[Subscription],
72):
73 model = Subscription 1a
75 async def list_active_by_customer( 1a
76 self, customer_id: UUID, *, options: Options = ()
77 ) -> Sequence[Subscription]:
78 statement = (
79 self.get_base_statement()
80 .where(
81 Subscription.customer_id == customer_id,
82 Subscription.active.is_(True),
83 )
84 .options(*options)
85 )
86 return await self.get_all(statement)
88 async def get_by_id_and_organization( 1a
89 self,
90 id: UUID,
91 organization_id: UUID,
92 *,
93 options: Options = (),
94 ) -> Subscription | None:
95 statement = (
96 self.get_base_statement()
97 .join(Product)
98 .where(
99 Subscription.id == id,
100 Product.organization_id == organization_id,
101 )
102 .options(contains_eager(Subscription.product), *options)
103 )
104 return await self.get_one_or_none(statement)
106 async def get_by_checkout_id( 1a
107 self, checkout_id: UUID, *, options: Options = ()
108 ) -> Subscription | None:
109 statement = (
110 self.get_base_statement()
111 .where(Subscription.checkout_id == checkout_id)
112 .options(*options)
113 )
114 result = await self.session.execute(statement)
115 return result.scalar_one_or_none()
117 async def get_by_stripe_subscription_id( 1a
118 self, stripe_subscription_id: str, *, options: Options = ()
119 ) -> Subscription | None:
120 statement = (
121 self.get_base_statement()
122 .where(
123 or_(
124 Subscription.stripe_subscription_id == stripe_subscription_id,
125 Subscription.legacy_stripe_subscription_id
126 == stripe_subscription_id,
127 )
128 )
129 .options(*options)
130 )
131 return await self.get_one_or_none(statement)
133 def get_eager_options( 1a
134 self, *, product_load: "_AbstractLoad | None" = None
135 ) -> Options:
136 if product_load is None:
137 product_load = joinedload(Subscription.product)
138 return (
139 joinedload(Subscription.customer),
140 product_load.options(
141 joinedload(Product.organization),
142 selectinload(Product.product_medias),
143 selectinload(Product.attached_custom_fields),
144 ),
145 selectinload(Subscription.meters).joinedload(SubscriptionMeter.meter),
146 )
148 def get_readable_statement( 1a
149 self, auth_subject: AuthSubject[User | Organization | Customer]
150 ) -> Select[tuple[Subscription]]:
151 statement = self.get_base_statement().join(Product)
153 if is_user(auth_subject):
154 user = auth_subject.subject
155 statement = statement.where(
156 Product.organization_id.in_(
157 select(UserOrganization.organization_id).where(
158 UserOrganization.user_id == user.id,
159 UserOrganization.deleted_at.is_(None),
160 )
161 )
162 )
163 elif is_organization(auth_subject):
164 statement = statement.where(
165 Product.organization_id == auth_subject.subject.id,
166 )
167 elif is_customer(auth_subject):
168 customer = auth_subject.subject
169 statement = statement.where(
170 Subscription.customer_id == customer.id,
171 Subscription.deleted_at.is_(None),
172 )
174 return statement
176 def get_claimed_subscriptions_statement( 1a
177 self, auth_subject: AuthSubject[AuthCustomer]
178 ) -> Select[tuple[Subscription]]:
179 """Get subscriptions where the customer has a claimed seat."""
180 customer = auth_subject.subject
182 statement = (
183 self.get_base_statement()
184 .join(CustomerSeat, CustomerSeat.subscription_id == Subscription.id)
185 .where(
186 CustomerSeat.customer_id == customer.id,
187 CustomerSeat.status == SeatStatus.claimed,
188 )
189 )
191 return statement
193 def get_sorting_clause(self, property: SubscriptionSortProperty) -> SortingClause: 1a
194 match property:
195 case SubscriptionSortProperty.customer:
196 return Customer.email
197 case SubscriptionSortProperty.status:
198 return case(
199 (Subscription.status == SubscriptionStatus.incomplete, 1),
200 (
201 Subscription.status == SubscriptionStatus.incomplete_expired,
202 2,
203 ),
204 (Subscription.status == SubscriptionStatus.trialing, 3),
205 (
206 Subscription.status == SubscriptionStatus.active,
207 case(
208 (Subscription.cancel_at_period_end.is_(False), 4),
209 (Subscription.cancel_at_period_end.is_(True), 5),
210 ),
211 ),
212 (Subscription.status == SubscriptionStatus.past_due, 6),
213 (Subscription.status == SubscriptionStatus.canceled, 7),
214 (Subscription.status == SubscriptionStatus.unpaid, 8),
215 )
216 case SubscriptionSortProperty.started_at:
217 return Subscription.started_at
218 case SubscriptionSortProperty.current_period_end:
219 return Subscription.current_period_end
220 case SubscriptionSortProperty.amount:
221 return case(
222 (
223 Subscription.recurring_interval
224 == SubscriptionRecurringInterval.year,
225 Subscription.amount / 12,
226 ),
227 (
228 Subscription.recurring_interval
229 == SubscriptionRecurringInterval.month,
230 Subscription.amount,
231 ),
232 (
233 Subscription.recurring_interval
234 == SubscriptionRecurringInterval.week,
235 Subscription.amount * 4,
236 ),
237 (
238 Subscription.recurring_interval
239 == SubscriptionRecurringInterval.day,
240 Subscription.amount * 30,
241 ),
242 )
243 case SubscriptionSortProperty.product:
244 return Product.name
245 case SubscriptionSortProperty.discount:
246 return Discount.name
class SubscriptionProductPriceRepository(
    RepositorySoftDeletionIDMixin[SubscriptionProductPrice, UUID],
    RepositorySoftDeletionMixin[SubscriptionProductPrice],
    RepositoryBase[SubscriptionProductPrice],
):
    """Repository of `SubscriptionProductPrice` queries."""

    model = SubscriptionProductPrice

    async def get_by_customer_and_meter(
        self, customer_id: UUID, meter_id: UUID
    ) -> CustomerSubscriptionProductPrice | None:
        """
        Get the paying customer and subscription product price for a customer and meter.

        If the customer has a direct subscription with the meter, returns that.
        If the customer is a seat holder, returns the billing manager's subscription.

        Args:
            customer_id: ID of the customer to resolve.
            meter_id: ID of the meter the price must be attached to.

        Returns:
            The paying customer ID with the matching price, or `None` when
            neither lookup finds one.
        """
        # The direct subscription takes precedence over a claimed seat.
        result = await self._get_direct_subscription_price(customer_id, meter_id)
        if result is not None:
            return result

        return await self._get_seat_subscription_price(customer_id, meter_id)

    async def _get_direct_subscription_price(
        self, customer_id: UUID, meter_id: UUID
    ) -> CustomerSubscriptionProductPrice | None:
        """
        Get the metered price from a subscription owned by the customer itself.

        Only billable subscriptions of the customer are considered, and only
        metered prices attached to the given meter. The product price and the
        subscription are populated from the joins; the subscription's
        customer is loaded alongside.

        Returns:
            The customer ID of the subscription with the matching price, or
            `None` when the customer has no such price.
        """
        statement = (
            self.get_base_statement()
            .join(
                ProductPrice,
                SubscriptionProductPrice.product_price_id == ProductPrice.id,
            )
            .join(
                Subscription,
                Subscription.id == SubscriptionProductPrice.subscription_id,
            )
            .where(
                ProductPrice.is_metered.is_(True),
                ProductPriceMeteredUnit.meter_id == meter_id,
                Subscription.billable.is_(True),
                Subscription.customer_id == customer_id,
            )
            # In case customer has several subscriptions, take the earliest one
            .order_by(Subscription.started_at.asc())
            .limit(1)
            .options(
                contains_eager(SubscriptionProductPrice.product_price),
                contains_eager(SubscriptionProductPrice.subscription).joinedload(
                    Subscription.customer
                ),
            )
        )

        subscription_product_price = await self.get_one_or_none(statement)
        if subscription_product_price is None:
            return None

        return CustomerSubscriptionProductPrice(
            customer_id=subscription_product_price.subscription.customer_id,
            subscription_product_price=subscription_product_price,
        )

    async def _get_seat_subscription_price(
        self, customer_id: UUID, meter_id: UUID
    ) -> CustomerSubscriptionProductPrice | None:
        """
        Get subscription product price for a customer who is a seat holder.

        Returns the billing manager's subscription if the seat holder has access
        to a metered price for the specified meter.

        Returns:
            The customer ID of the seat's subscription with the matching
            price, or `None` when the customer has no claimed seat, the seat
            has no subscription, or the subscription has no metered price for
            the meter.

        NOTE(review): unlike the direct lookup, this path does not filter on
        `Subscription.billable` — confirm whether that is intended.
        """
        seat = await self._get_active_seat_for_customer(customer_id)
        if seat is None:
            return None

        # Bind the relationship once: the None check below then covers every
        # later use, without a separate assert.
        subscription = seat.subscription
        if subscription is None:
            return None

        # Find matching metered price in billing manager's subscription
        metered_price = self._find_metered_price_in_subscription(
            subscription, meter_id
        )
        if metered_price is None:
            return None

        return CustomerSubscriptionProductPrice(
            customer_id=subscription.customer_id,
            subscription_product_price=metered_price,
        )

    async def _get_active_seat_for_customer(
        self, customer_id: UUID
    ) -> CustomerSeat | None:
        """
        Get the active seat for a customer, with subscription data eagerly loaded.

        "Active" here means a seat whose status is `SeatStatus.claimed`. The
        seat's subscription is loaded together with its customer and its
        subscription product prices (each with its product price).

        NOTE(review): the statement has `limit(1)` but no ordering, so which
        seat is returned when a customer has several claimed seats is left to
        the database — confirm whether that case can occur.
        """
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.customer_id == customer_id,
                CustomerSeat.status == SeatStatus.claimed,
            )
            .options(
                joinedload(CustomerSeat.subscription).options(
                    joinedload(Subscription.customer),
                    joinedload(Subscription.subscription_product_prices).options(
                        joinedload(SubscriptionProductPrice.product_price),
                        # Load the back-reference to satisfy lazy='raise_on_sql'
                        # This points to the same Subscription already being loaded above
                        joinedload(SubscriptionProductPrice.subscription),
                    ),
                )
            )
            .limit(1)
        )
        return await self.session.scalar(statement)

    def _find_metered_price_in_subscription(
        self, subscription: Subscription, meter_id: UUID
    ) -> SubscriptionProductPrice | None:
        """
        Find a metered price for the given meter in a subscription.

        Iterates `subscription.subscription_product_prices` in order and
        returns the first entry whose product price is metered and attached
        to `meter_id`.

        Returns None if no matching metered price is found.
        """
        for spp in subscription.subscription_product_prices:
            if (
                is_metered_price(spp.product_price)
                and spp.product_price.meter_id == meter_id
            ):
                return spp
        return None