Coverage for polar/customer_portal/service/subscription.py: 32%
82 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import uuid 1a
2from collections.abc import Sequence 1a
3from enum import StrEnum 1a
4from typing import Any 1a
6from sqlalchemy import Select, UnaryExpression, asc, desc, select 1a
7from sqlalchemy.orm import contains_eager, joinedload, selectinload 1a
9from polar.auth.models import AuthSubject 1a
10from polar.exceptions import PolarError 1a
11from polar.kit.db.postgres import AsyncSession 1a
12from polar.kit.pagination import PaginationParams, paginate 1a
13from polar.kit.services import ResourceServiceReader 1a
14from polar.kit.sorting import Sorting 1a
15from polar.models import ( 1a
16 Customer,
17 Organization,
18 Product,
19 Subscription,
20 SubscriptionMeter,
21)
22from polar.models.subscription import CustomerCancellationReason 1a
23from polar.subscription.service import subscription as subscription_service 1a
25from ..schemas.subscription import ( 1a
26 CustomerSubscriptionUpdate,
27 CustomerSubscriptionUpdateProduct,
28 CustomerSubscriptionUpdateSeats,
29)
class CustomerSubscriptionError(PolarError):
    """Base error type for the customer subscription service in this module."""
class UpdateSubscriptionNotAllowed(CustomerSubscriptionError):
    """Raised when a subscription update is refused."""

    def __init__(self) -> None:
        # The second argument is presumably the HTTP status code (403) that
        # PolarError exposes to the API layer -- confirm against PolarError.
        message = "Updating subscription is not allowed."
        super().__init__(message, 403)
40class CustomerSubscriptionSortProperty(StrEnum): 1a
41 started_at = "started_at" 1a
42 amount = "amount" 1a
43 status = "status" 1a
44 organization = "organization" 1a
45 product = "product" 1a
class CustomerSubscriptionService(ResourceServiceReader[Subscription]):
    """Subscription operations exposed to an authenticated customer.

    Read queries are built on ``_get_readable_subscription_statement``, which
    restricts results to the calling customer's non-deleted subscriptions.
    Mutations are delegated to ``subscription_service``.
    """

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        *,
        product_id: Sequence[uuid.UUID] | None = None,
        active: bool | None = None,
        query: str | None = None,
        pagination: PaginationParams,
        # Immutable default: a list literal here would be a single object
        # shared by every call that omits ``sorting``.
        sorting: Sequence[Sorting[CustomerSubscriptionSortProperty]] = (
            (CustomerSubscriptionSortProperty.started_at, True),
        ),
    ) -> tuple[Sequence[Subscription], int]:
        """List the customer's started subscriptions, filtered and sorted.

        Args:
            session: Database session the query is executed on.
            auth_subject: Authenticated customer whose subscriptions are read.
            product_id: When given, keep only subscriptions whose product ID
                is in this collection.
            active: ``True`` keeps rows where ``Subscription.active`` is true;
                ``False`` keeps rows where ``Subscription.revoked`` is true;
                ``None`` applies no filter.
            query: When given, keep only subscriptions whose product name
                contains this text (case-insensitive ``ILIKE``).
            pagination: Pagination parameters forwarded to ``paginate``.
            sorting: ``(property, is_descending)`` pairs applied in order.
                Defaults to ``started_at`` descending.

        Returns:
            The result of ``paginate`` for the built statement.
        """
        statement = (
            self._get_readable_subscription_statement(auth_subject)
            # Only subscriptions that have actually started are listed.
            .where(Subscription.started_at.is_not(None))
            .join(Product, onclause=Subscription.product_id == Product.id)
            .join(Organization, onclause=Product.organization_id == Organization.id)
            .options(
                joinedload(Subscription.customer),
                # Product and Organization are already joined above (they are
                # needed for filtering/sorting), so reuse those joins to
                # populate the relationships.
                contains_eager(Subscription.product).options(
                    selectinload(Product.product_medias),
                    contains_eager(Product.organization),
                ),
                selectinload(Subscription.meters).joinedload(SubscriptionMeter.meter),
            )
        )

        if product_id is not None:
            statement = statement.where(Subscription.product_id.in_(product_id))

        if active is not None:
            flag_column = Subscription.active if active else Subscription.revoked
            statement = statement.where(flag_column.is_(True))

        if query is not None:
            # NOTE(review): ``query`` is placed in the pattern unescaped, so
            # any ``%`` or ``_`` it contains acts as a LIKE wildcard.
            statement = statement.where(Product.name.ilike(f"%{query}%"))

        sort_columns: dict[CustomerSubscriptionSortProperty, Any] = {
            CustomerSubscriptionSortProperty.started_at: Subscription.started_at,
            CustomerSubscriptionSortProperty.amount: Subscription.amount,
            CustomerSubscriptionSortProperty.status: Subscription.status,
            CustomerSubscriptionSortProperty.organization: Organization.slug,
            CustomerSubscriptionSortProperty.product: Product.name,
        }
        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            column = sort_columns.get(criterion)
            if column is None:
                # Unknown criteria are skipped rather than rejected.
                continue
            order_by_clauses.append(desc(column) if is_desc else asc(column))
        statement = statement.order_by(*order_by_clauses)

        return await paginate(session, statement, pagination=pagination)

    async def get_by_id(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        id: uuid.UUID,
    ) -> Subscription | None:
        """Fetch one of the customer's subscriptions by ID.

        The customer, product (with its medias and organization) and meters
        are eagerly loaded.

        Args:
            session: Database session the query is executed on.
            auth_subject: Authenticated customer whose subscriptions are read.
            id: ID of the subscription to fetch.

        Returns:
            The matching subscription, or ``None`` when the readable
            statement yields no row for this ID.
        """
        statement = (
            self._get_readable_subscription_statement(auth_subject)
            .where(Subscription.id == id)
            .options(
                joinedload(Subscription.customer),
                joinedload(Subscription.product).options(
                    selectinload(Product.product_medias),
                    joinedload(Product.organization),
                ),
                selectinload(Subscription.meters).joinedload(SubscriptionMeter.meter),
            )
        )

        result = await session.execute(statement)
        return result.scalar_one_or_none()

    async def update(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        updates: CustomerSubscriptionUpdate,
    ) -> Subscription:
        """Apply a customer-requested update to ``subscription``.

        The concrete type of ``updates`` selects the operation: a product
        change, a seat change, or a cancel/uncancel request.

        Args:
            session: Database session passed to the underlying operation.
            subscription: Subscription to update.
            updates: The update payload.

        Returns:
            The subscription returned by the underlying operation, or
            ``subscription`` itself when ``cancel_at_period_end`` is neither
            ``True`` nor ``False``.

        Raises:
            UpdateSubscriptionNotAllowed: On a product change when the
                product's organization has ``allow_customer_updates`` unset.
        """
        if isinstance(updates, CustomerSubscriptionUpdateProduct):
            if not subscription.product.organization.allow_customer_updates:
                raise UpdateSubscriptionNotAllowed()

            return await self.update_product(
                session,
                subscription,
                product_id=updates.product_id,
            )

        if isinstance(updates, CustomerSubscriptionUpdateSeats):
            # NOTE(review): unlike the product branch, no
            # ``allow_customer_updates`` check is made here. The original
            # code fetched the organization but never used it -- confirm
            # whether a check was intended.
            return await subscription_service.update_seats(
                session,
                subscription,
                seats=updates.seats,
                proration_behavior=updates.proration_behavior,
            )

        # Identity comparisons on purpose: any value other than the literal
        # booleans (e.g. ``None``) means "no change requested".
        if updates.cancel_at_period_end is True:
            return await self.cancel(
                session,
                subscription,
                reason=updates.cancellation_reason,
                comment=updates.cancellation_comment,
            )

        if updates.cancel_at_period_end is False:
            return await self.uncancel(session, subscription)

        return subscription

    async def update_product(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        product_id: uuid.UUID,
    ) -> Subscription:
        """Switch ``subscription`` to ``product_id`` via ``subscription_service``."""
        return await subscription_service.update_product(
            session, subscription, product_id=product_id
        )

    async def uncancel(
        self,
        session: AsyncSession,
        subscription: Subscription,
    ) -> Subscription:
        """Delegate to ``subscription_service.uncancel`` for ``subscription``."""
        return await subscription_service.uncancel(
            session,
            subscription,
        )

    async def cancel(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        reason: CustomerCancellationReason | None = None,
        comment: str | None = None,
    ) -> Subscription:
        """Delegate to ``subscription_service.cancel`` for ``subscription``.

        ``reason`` and ``comment`` are forwarded as ``customer_reason`` and
        ``customer_comment`` respectively.
        """
        return await subscription_service.cancel(
            session,
            subscription,
            customer_reason=reason,
            customer_comment=comment,
        )

    def _get_readable_subscription_statement(
        self, auth_subject: AuthSubject[Customer]
    ) -> Select[tuple[Subscription]]:
        """Build the base query for subscriptions the customer may read.

        Selects subscriptions that are not soft-deleted (``deleted_at`` is
        null) and whose ``customer_id`` is the authenticated subject's ID.
        """
        return select(Subscription).where(
            Subscription.deleted_at.is_(None),
            Subscription.customer_id == auth_subject.subject.id,
        )
# Module-level service instance, constructed with the Subscription model.
customer_subscription = CustomerSubscriptionService(Subscription)