Coverage for polar/customer_seat/repository.py: 31%
103 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1from collections.abc import Sequence 1a
2from uuid import UUID 1a
4from sqlalchemy import Select, func, select 1a
5from sqlalchemy.orm import joinedload 1a
7from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a
8from polar.kit.repository import RepositoryBase 1a
9from polar.kit.repository.base import Options 1a
10from polar.models import ( 1a
11 Customer,
12 CustomerSeat,
13 Order,
14 Product,
15 Subscription,
16 UserOrganization,
17)
18from polar.models.customer_seat import SeatStatus 1a
19from polar.order.repository import OrderRepository 1a
20from polar.subscription.repository import SubscriptionRepository 1a
# Objects that can hold customer seats. Repository methods taking a
# ``SeatContainer`` dispatch with ``isinstance(container, Subscription)`` and
# treat any other value as an ``Order``.
SeatContainer = Subscription | Order
class CustomerSeatRepository(RepositoryBase[CustomerSeat]):
    """Data access for ``CustomerSeat`` rows attached to subscriptions or orders."""

    model = CustomerSeat

    # Seat statuses that occupy one of a container's seats. Seats in any other
    # status (e.g. ``SeatStatus.revoked``) are not counted against the total.
    _OCCUPYING_STATUSES = (SeatStatus.pending, SeatStatus.claimed)

    async def list_by_container(
        self, container: SeatContainer, *, options: Options = ()
    ) -> Sequence[CustomerSeat]:
        """List seats for a subscription or order.

        A ``Subscription`` is matched on ``subscription_id``; any other
        container is treated as an ``Order`` and matched on ``order_id``.
        """
        if isinstance(container, Subscription):
            return await self.list_by_subscription_id(container.id, options=options)
        return await self.list_by_order_id(container.id, options=options)

    async def get_available_seats_count_for_container(
        self, container: SeatContainer
    ) -> int:
        """Get available seats count for a subscription or order."""
        if isinstance(container, Subscription):
            return await self.get_available_seats_count(container.id)
        return await self.get_available_seats_count_for_order(container.id)

    async def get_by_container_and_customer(
        self,
        container: SeatContainer,
        customer_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get the seat held by ``customer_id`` in a subscription or order, if any."""
        if isinstance(container, Subscription):
            return await self.get_by_subscription_and_customer(
                container.id, customer_id, options=options
            )
        return await self.get_by_order_and_customer(
            container.id, customer_id, options=options
        )

    async def get_revoked_seat_by_container(
        self,
        container: SeatContainer,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get one revoked seat of a subscription or order, if any exists."""
        if isinstance(container, Subscription):
            return await self.get_revoked_seat_by_subscription(
                container.id, options=options
            )
        return await self.get_revoked_seat_by_order(container.id, options=options)

    async def list_by_subscription_id(
        self, subscription_id: UUID, *, options: Options = ()
    ) -> Sequence[CustomerSeat]:
        """List all seats (any status) belonging to a subscription."""
        statement = (
            select(CustomerSeat)
            .where(CustomerSeat.subscription_id == subscription_id)
            .options(*options)
        )
        return await self.get_all(statement)

    async def list_by_order_id(
        self, order_id: UUID, *, options: Options = ()
    ) -> Sequence[CustomerSeat]:
        """List all seats (any status) belonging to an order."""
        statement = (
            select(CustomerSeat)
            .where(CustomerSeat.order_id == order_id)
            .options(*options)
        )
        return await self.get_all(statement)

    async def get_by_invitation_token(
        self, token: str, *, options: Options = ()
    ) -> CustomerSeat | None:
        """Get the seat whose ``invitation_token`` equals ``token``, if any."""
        statement = (
            select(CustomerSeat)
            .where(CustomerSeat.invitation_token == token)
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def count_assigned_seats_for_subscription(self, subscription_id: UUID) -> int:
        """Count the pending or claimed seats of a subscription (SQL ``COUNT``)."""
        statement = select(func.count(CustomerSeat.id)).where(
            CustomerSeat.subscription_id == subscription_id,
            CustomerSeat.status.in_(self._OCCUPYING_STATUSES),
        )
        result = await self.session.execute(statement)
        return result.scalar_one()

    async def count_assigned_seats_for_order(self, order_id: UUID) -> int:
        """Count the pending or claimed seats of an order (SQL ``COUNT``)."""
        statement = select(func.count(CustomerSeat.id)).where(
            CustomerSeat.order_id == order_id,
            CustomerSeat.status.in_(self._OCCUPYING_STATUSES),
        )
        result = await self.session.execute(statement)
        return result.scalar_one()

    async def get_available_seats_count(self, subscription_id: UUID) -> int:
        """Return how many seats of a subscription are still unassigned.

        Returns 0 when the subscription does not exist or has no ``seats``
        value. Otherwise returns ``seats`` minus the pending/claimed seats,
        floored at 0.
        """
        subscription_repository = SubscriptionRepository.from_session(self.session)
        subscription = await subscription_repository.get_one_or_none(
            select(Subscription).where(Subscription.id == subscription_id)
        )
        if subscription is None or subscription.seats is None:
            return 0

        # Count in the database rather than loading every seat row.
        assigned = await self.count_assigned_seats_for_subscription(subscription_id)
        return max(0, subscription.seats - assigned)

    async def get_available_seats_count_for_order(self, order_id: UUID) -> int:
        """Return how many seats of an order are still unassigned.

        Returns 0 when the order does not exist or has no ``seats`` value.
        Otherwise returns ``seats`` minus the pending/claimed seats, floored
        at 0.
        """
        order_repository = OrderRepository.from_session(self.session)
        order = await order_repository.get_one_or_none(
            select(Order).where(Order.id == order_id)
        )
        if order is None or order.seats is None:
            return 0

        # Count in the database rather than loading every seat row.
        assigned = await self.count_assigned_seats_for_order(order_id)
        return max(0, order.seats - assigned)

    async def list_by_customer_id(
        self, customer_id: UUID, *, options: Options = ()
    ) -> Sequence[CustomerSeat]:
        """List all seats (any status) held by a customer."""
        statement = (
            select(CustomerSeat)
            .where(CustomerSeat.customer_id == customer_id)
            .options(*options)
        )
        return await self.get_all(statement)

    async def get_by_subscription_and_customer(
        self,
        subscription_id: UUID,
        customer_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get the seat of ``customer_id`` in a subscription, if any."""
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.subscription_id == subscription_id,
                CustomerSeat.customer_id == customer_id,
            )
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_by_order_and_customer(
        self,
        order_id: UUID,
        customer_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get the seat of ``customer_id`` in an order, if any."""
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.order_id == order_id,
                CustomerSeat.customer_id == customer_id,
            )
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_revoked_seat_by_subscription(
        self,
        subscription_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get a revoked seat for a subscription that can be reused.

        ``LIMIT 1`` without ``ORDER BY``: when several revoked seats exist,
        which one is returned is up to the database.
        """
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.subscription_id == subscription_id,
                CustomerSeat.status == SeatStatus.revoked,
            )
            .options(*options)
            .limit(1)
        )
        return await self.get_one_or_none(statement)

    async def get_revoked_seat_by_order(
        self,
        order_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get a revoked seat for an order that can be reused.

        ``LIMIT 1`` without ``ORDER BY``: when several revoked seats exist,
        which one is returned is up to the database.
        """
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.order_id == order_id,
                CustomerSeat.status == SeatStatus.revoked,
            )
            .options(*options)
            .limit(1)
        )
        return await self.get_one_or_none(statement)

    async def get_by_id(
        self,
        seat_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get a seat by ID (no authorization or status filtering)."""
        statement = (
            select(CustomerSeat).where(CustomerSeat.id == seat_id).options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_by_id_for_customer(
        self,
        seat_id: UUID,
        customer_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get a seat by ID, provided its container is owned by ``customer_id``.

        Ownership is checked through the subscription's or the order's
        ``customer_id``. Both are outer-joined so that a seat attached to
        either kind of container can match.
        """
        statement = (
            select(CustomerSeat)
            .outerjoin(Subscription, CustomerSeat.subscription_id == Subscription.id)
            .outerjoin(Order, CustomerSeat.order_id == Order.id)
            .where(
                CustomerSeat.id == seat_id,
                (
                    (Subscription.customer_id == customer_id)
                    | (Order.customer_id == customer_id)
                ),
            )
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[CustomerSeat]]:
        """
        Get a statement filtered by authorization.

        Seats are readable by users/organizations who have access to the product's organization.
        Handles both subscription-based and order-based seats.
        """
        # The product is reached through whichever container the seat uses,
        # so both containers are outer-joined and Product is joined on either.
        statement = (
            self.get_base_statement()
            .outerjoin(Subscription, CustomerSeat.subscription_id == Subscription.id)
            .outerjoin(Order, CustomerSeat.order_id == Order.id)
            .outerjoin(
                Product,
                (Subscription.product_id == Product.id)
                | (Order.product_id == Product.id),
            )
        )

        if is_user(auth_subject):
            # Organizations the user is an active (non-deleted) member of.
            user_org_ids = select(UserOrganization.organization_id).where(
                UserOrganization.user_id == auth_subject.subject.id,
                UserOrganization.deleted_at.is_(None),
            )
            statement = statement.where(Product.organization_id.in_(user_org_ids))
        elif is_organization(auth_subject):
            statement = statement.where(
                Product.organization_id == auth_subject.subject.id
            )
        # NOTE(review): no filter is added for any other subject kind; the
        # annotation restricts callers to User | Organization subjects.

        return statement

    async def get_by_id_and_auth_subject(
        self,
        auth_subject: AuthSubject[User | Organization],
        seat_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get a seat by ID filtered by auth subject."""
        statement = (
            self.get_readable_statement(auth_subject)
            .where(CustomerSeat.id == seat_id)
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_by_subscription_and_auth_subject(
        self,
        auth_subject: AuthSubject[User | Organization],
        seat_id: UUID,
        subscription_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """Get a seat by ID and subscription ID filtered by auth subject."""
        statement = (
            self.get_readable_statement(auth_subject)
            .where(
                CustomerSeat.id == seat_id,
                CustomerSeat.subscription_id == subscription_id,
            )
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_active_seat_for_customer(
        self,
        customer_id: UUID,
        *,
        options: Options = (),
    ) -> CustomerSeat | None:
        """
        Get an active (claimed) seat for a customer.

        Used to determine if a customer is a seat holder and should have
        their usage charges routed to the billing manager's subscription.
        If the customer holds several claimed seats, ``LIMIT 1`` without
        ``ORDER BY`` leaves the choice to the database.
        """
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.customer_id == customer_id,
                CustomerSeat.status == SeatStatus.claimed,
            )
            .options(*options)
            .limit(1)
        )
        return await self.get_one_or_none(statement)

    def get_eager_options(self) -> Options:
        """Loader options that join-load a seat's container, product and customers.

        These cover:
        - the subscription, with its product, that product's organization and
          the subscription's customer;
        - the order, with its product, and the order's customer together with
          that customer's organization;
        - the seat's own customer.
        """
        return (
            joinedload(CustomerSeat.subscription).options(
                joinedload(Subscription.product).joinedload(Product.organization),
                joinedload(Subscription.customer),
            ),
            joinedload(CustomerSeat.order).options(
                joinedload(Order.product),
                joinedload(Order.customer).joinedload(Customer.organization),
            ),
            joinedload(CustomerSeat.customer),
        )

    def get_eager_options_with_prices(self) -> Options:
        """``get_eager_options`` plus the subscription's product prices.

        ``CustomerSeat.subscription`` appears twice in the result. This
        presumably relies on SQLAlchemy combining loader options that share
        a path (TODO confirm).
        """
        return (
            *self.get_eager_options(),
            joinedload(CustomerSeat.subscription).joinedload(
                Subscription.subscription_product_prices
            ),
        )