Coverage for polar/customer_seat/service.py: 17%
295 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import secrets 1a
2import uuid 1a
3from collections.abc import Sequence 1a
4from datetime import UTC, datetime, timedelta 1a
5from typing import Any 1a
7import structlog 1a
9from polar.auth.models import AuthSubject 1a
10from polar.customer.repository import CustomerRepository 1a
11from polar.customer_seat.sender import send_seat_invitation_email 1a
12from polar.customer_session.service import ( 1a
13 customer_session as customer_session_service,
14)
15from polar.eventstream.service import publish as eventstream_publish 1a
16from polar.exceptions import PolarError 1a
17from polar.kit.db.postgres import AsyncSession 1a
18from polar.models import ( 1a
19 Customer,
20 CustomerSeat,
21 Order,
22 Organization,
23 Product,
24 Subscription,
25 User,
26)
27from polar.models.customer_seat import SeatStatus 1a
28from polar.models.order import OrderStatus 1a
29from polar.models.webhook_endpoint import WebhookEventType 1a
30from polar.organization.repository import OrganizationRepository 1a
31from polar.postgres import AsyncReadSession 1a
32from polar.webhook.service import webhook as webhook_service 1a
33from polar.worker import enqueue_job 1a
35from .repository import CustomerSeatRepository 1a
# Module-level structlog logger used by the seat service for audit-style log lines.
log = structlog.get_logger()
class SeatError(PolarError):
    """Base class for every error raised by the customer-seat service."""
class SeatNotAvailable(SeatError):
    """Raised (HTTP 400) when no seat can be handed out from a container.

    ``source_id`` is the id of the subscription/order that was asked for a
    seat; ``reason`` overrides the default "no available seats" message.
    """

    def __init__(self, source_id: uuid.UUID, reason: str | None = None) -> None:
        self.source_id = source_id
        super().__init__(reason or f"No available seats for {source_id}", 400)
class InvalidInvitationToken(SeatError):
    """Raised (HTTP 400) when an invitation token cannot be used.

    The offending token is kept on ``self.token`` but deliberately left out
    of the user-facing message.
    """

    def __init__(self, token: str) -> None:
        self.token = token
        super().__init__("Invalid or expired invitation token", 400)
class FeatureNotEnabled(SeatError):
    """Raised (HTTP 403) when the organization has not enabled seat-based pricing."""

    def __init__(self) -> None:
        super().__init__(
            "Seat-based pricing is not enabled for this organization", 403
        )
class SeatAlreadyAssigned(SeatError):
    """Raised (HTTP 400) when the customer already holds a non-revoked seat.

    Despite its name, ``customer_email`` receives whichever identifier the
    caller used (email, external customer id, or stringified customer id).
    """

    def __init__(self, customer_email: str) -> None:
        self.customer_email = customer_email
        super().__init__(f"Seat already assigned to customer {customer_email}", 400)
class SeatNotPending(SeatError):
    """Raised (HTTP 400) when an operation requires a seat in pending status."""

    def __init__(self) -> None:
        super().__init__("Seat is not in pending status", 400)
class InvalidSeatAssignmentRequest(SeatError):
    """Raised (HTTP 400) when a seat assignment does not name exactly one customer identifier."""

    def __init__(self) -> None:
        super().__init__(
            "Exactly one of email, external_customer_id, or customer_id must be provided",
            400,
        )
class CustomerNotFound(SeatError):
    """Raised (HTTP 404) when the identifier used for a seat matches no customer."""

    def __init__(self, customer_identifier: str) -> None:
        self.customer_identifier = customer_identifier
        super().__init__(f"Customer not found: {customer_identifier}", 404)
# Seats are attached either to a subscription or to a one-off order; the
# SeatService helpers accept either kind of container.
SeatContainer = Subscription | Order
class SeatService:
    """Service layer for seat-based pricing.

    A seat belongs to a *container* (a ``Subscription`` or an ``Order``).
    The methods below move it between ``SeatStatus.pending`` (invited),
    ``SeatStatus.claimed`` and ``SeatStatus.revoked``. Most public entry
    points verify via ``check_seat_feature_enabled`` that the owning
    organization has seat-based pricing turned on.
    """

    # Lifetime of an invitation token issued for a pending seat.
    INVITATION_TOKEN_TTL = timedelta(days=1)

    def _get_customer_id(self, container: SeatContainer) -> uuid.UUID:
        return container.customer_id

    def _get_product(self, container: SeatContainer) -> Product | None:
        return container.product

    def _get_organization_id(self, container: SeatContainer) -> uuid.UUID:
        return container.organization.id

    def _get_seats_count(self, container: SeatContainer) -> int:
        return container.seats or 0

    def _get_container_id(self, container: SeatContainer) -> uuid.UUID:
        return container.id

    def _is_subscription(self, container: SeatContainer) -> bool:
        return isinstance(container, Subscription)

    def _new_invitation_token(self) -> tuple[str, datetime]:
        """Return a fresh URL-safe invitation token and its expiry timestamp."""
        return (
            secrets.token_urlsafe(32),
            datetime.now(UTC) + self.INVITATION_TOKEN_TTL,
        )

    def _is_invitation_expired(self, seat: CustomerSeat) -> bool:
        """Return True if the seat's invitation expiry lies in the past.

        A seat without an expiry timestamp is treated as non-expiring.
        """
        expires_at = seat.invitation_token_expires_at
        return expires_at is not None and expires_at < datetime.now(UTC)

    def _enqueue_benefits_job(
        self,
        seat: CustomerSeat,
        *,
        task: str,
        customer_id: uuid.UUID | None,
        product_id: uuid.UUID,
    ) -> None:
        """Enqueue a ``benefit.enqueue_benefits_grants`` job for the seat's container.

        Subscription seats are scoped by ``subscription_id``; every other seat
        is scoped by ``order_id``.
        """
        container_kwargs: dict[str, uuid.UUID | None]
        if seat.subscription_id:
            container_kwargs = {"subscription_id": seat.subscription_id}
        else:
            container_kwargs = {"order_id": seat.order_id}
        enqueue_job(
            "benefit.enqueue_benefits_grants",
            task=task,
            customer_id=customer_id,
            product_id=product_id,
            **container_kwargs,
        )

    async def _enqueue_benefit_grant(
        self, seat: CustomerSeat, product_id: uuid.UUID
    ) -> None:
        """Enqueue benefit grant job for a claimed seat."""
        self._enqueue_benefits_job(
            seat, task="grant", customer_id=seat.customer_id, product_id=product_id
        )

    async def _publish_seat_claimed_event(
        self, seat: CustomerSeat, product_id: uuid.UUID
    ) -> None:
        """Publish eventstream event for seat claimed."""
        await eventstream_publish(
            "customer_seat.claimed",
            {
                "seat_id": str(seat.id),
                "subscription_id": str(seat.subscription_id)
                if seat.subscription_id
                else None,
                "order_id": str(seat.order_id) if seat.order_id else None,
                "product_id": str(product_id),
            },
            customer_id=seat.customer_id,
        )

    async def _send_seat_claimed_webhook(
        self, session: AsyncSession, organization_id: uuid.UUID, seat: CustomerSeat
    ) -> None:
        """Send the ``customer_seat_claimed`` webhook if the organization exists."""
        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(organization_id)
        if organization:
            await webhook_service.send(
                session,
                organization,
                WebhookEventType.customer_seat_claimed,
                seat,
            )

    async def check_seat_feature_enabled(
        self, session: AsyncReadSession, organization_id: uuid.UUID
    ) -> None:
        """Raise ``FeatureNotEnabled`` unless the organization exists and has
        ``seat_based_pricing_enabled`` set in its ``feature_settings``."""
        # OrganizationRepository is already imported at module level; the
        # previous function-local re-import was redundant.
        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(organization_id)
        if not organization or not organization.feature_settings.get(
            "seat_based_pricing_enabled", False
        ):
            raise FeatureNotEnabled()

    async def list_seats(
        self,
        session: AsyncReadSession,
        container: SeatContainer,
    ) -> Sequence[CustomerSeat]:
        """List all seats of ``container`` (eager-loaded) after the feature check."""
        await self.check_seat_feature_enabled(
            session, self._get_organization_id(container)
        )
        repository = CustomerSeatRepository.from_session(session)
        return await repository.list_by_container(
            container,
            options=repository.get_eager_options(),
        )

    async def get_available_seats_count(
        self,
        session: AsyncReadSession,
        container: SeatContainer,
    ) -> int:
        """Return how many seats of ``container`` are still assignable."""
        await self.check_seat_feature_enabled(
            session, self._get_organization_id(container)
        )
        repository = CustomerSeatRepository.from_session(session)
        return await repository.get_available_seats_count_for_container(container)

    async def count_assigned_seats_for_subscription(
        self,
        session: AsyncReadSession,
        subscription: Subscription,
    ) -> int:
        """Return the repository's assigned-seat count for ``subscription``."""
        repository = CustomerSeatRepository.from_session(session)
        return await repository.count_assigned_seats_for_subscription(subscription.id)

    async def assign_seat(
        self,
        session: AsyncSession,
        container: SeatContainer,
        *,
        email: str | None = None,
        external_customer_id: str | None = None,
        customer_id: uuid.UUID | None = None,
        metadata: dict[str, Any] | None = None,
        immediate_claim: bool = False,
    ) -> CustomerSeat:
        """Assign a seat of ``container`` to a customer.

        Exactly one of ``email``, ``external_customer_id`` or ``customer_id``
        must identify the customer. An ``email`` with no matching customer
        creates one in the container's organization. A revoked seat of the
        container is reused when available; otherwise a new seat row is created.

        With ``immediate_claim`` the seat is stored already claimed, benefit
        grants are enqueued and ``customer_seat_claimed`` is sent. Otherwise
        the seat is pending with a fresh invitation token; an invitation email
        and ``customer_seat_assigned`` are sent.

        Raises:
            SeatNotAvailable: no product, unpaid (pending) order, or no free seat.
            FeatureNotEnabled: seat-based pricing is disabled.
            InvalidSeatAssignmentRequest: not exactly one usable identifier.
            CustomerNotFound: ``external_customer_id``/``customer_id`` unknown.
            SeatAlreadyAssigned: the customer already holds a non-revoked seat.
        """
        product = self._get_product(container)
        source_id = self._get_container_id(container)

        if product is None:
            raise SeatNotAvailable(
                source_id,
                "Container has no associated product",
            )

        organization_id = self._get_organization_id(container)
        billing_manager_customer = container.customer
        is_subscription = self._is_subscription(container)

        await self.check_seat_feature_enabled(session, organization_id)

        # Seats on an order can only be handed out once the order is paid.
        if isinstance(container, Order) and container.status == OrderStatus.pending:
            raise SeatNotAvailable(
                source_id, "Order must be paid before assigning seats"
            )

        repository = CustomerSeatRepository.from_session(session)

        available_seats = await repository.get_available_seats_count_for_container(
            container
        )
        if available_seats <= 0:
            raise SeatNotAvailable(source_id)

        customer = await self._find__or_create_customer(
            session,
            organization_id,
            email,
            external_customer_id,
            customer_id,
        )

        existing_seat = await repository.get_by_container_and_customer(
            container, customer.id
        )
        if existing_seat and not existing_seat.is_revoked():
            identifier = email or external_customer_id or str(customer_id)
            raise SeatAlreadyAssigned(identifier)

        # Only the standard (invitation) flow needs a token.
        invitation_token: str | None
        token_expires_at: datetime | None
        if immediate_claim:
            invitation_token, token_expires_at = None, None
        else:
            invitation_token, token_expires_at = self._new_invitation_token()

        status = SeatStatus.claimed if immediate_claim else SeatStatus.pending
        claimed_at = datetime.now(UTC) if immediate_claim else None

        revoked_seat = await repository.get_revoked_seat_by_container(container)
        if revoked_seat:
            # Recycle a previously revoked seat instead of creating a new row.
            seat = revoked_seat
            seat.status = status
            seat.invitation_token = invitation_token
            seat.invitation_token_expires_at = token_expires_at
            seat.customer_id = customer.id
            seat.seat_metadata = metadata or {}
            seat.revoked_at = None
            seat.claimed_at = claimed_at
        else:
            seat_data: dict[str, Any] = {
                "status": status,
                "invitation_token": invitation_token,
                "invitation_token_expires_at": token_expires_at,
                "customer_id": customer.id,
                "seat_metadata": metadata or {},
                "claimed_at": claimed_at,
            }
            if is_subscription:
                seat_data["subscription_id"] = source_id
            else:
                seat_data["order_id"] = source_id

            seat = CustomerSeat(**seat_data)
            session.add(seat)

        await session.flush()

        if immediate_claim:
            # Immediate claim flow: grant benefits and trigger claimed webhook
            log.info(
                "Seat immediately claimed",
                subscription_id=seat.subscription_id,
                order_id=seat.order_id,
                email=email,
                customer_id=customer.id,
            )

            await self._publish_seat_claimed_event(seat, product.id)
            await self._enqueue_benefit_grant(seat, product.id)
            await self._send_seat_claimed_webhook(session, organization_id, seat)
        else:
            # Standard flow: send invitation email and trigger assigned webhook.
            # The invitation token is a bearer credential (claim_seat exchanges
            # it for a customer session), so it is intentionally NOT logged.
            log.info(
                "Seat assigned",
                subscription_id=seat.subscription_id,
                order_id=seat.order_id,
                email=email,
                customer_id=customer.id,
            )

            organization_repository = OrganizationRepository.from_session(session)
            organization = await organization_repository.get_by_id(organization_id)
            if organization:
                send_seat_invitation_email(
                    customer_email=customer.email,
                    seat=seat,
                    organization=organization,
                    product_name=product.name,
                    billing_manager_email=billing_manager_customer.email,
                )

                await webhook_service.send(
                    session,
                    organization,
                    WebhookEventType.customer_seat_assigned,
                    seat,
                )

        return seat

    async def get_seat_by_token(
        self,
        session: AsyncReadSession,
        invitation_token: str,
    ) -> CustomerSeat | None:
        """Return the pending seat for ``invitation_token``, or None if the
        token is unknown, the seat is revoked/claimed, or the token expired."""
        repository = CustomerSeatRepository.from_session(session)
        seat = await repository.get_by_invitation_token(
            invitation_token,
            options=repository.get_eager_options(),
        )

        if not seat or seat.is_revoked() or seat.is_claimed():
            return None
        if self._is_invitation_expired(seat):
            return None
        return seat

    async def claim_seat(
        self,
        session: AsyncSession,
        invitation_token: str,
        request_metadata: dict[str, Any] | None = None,
    ) -> tuple[CustomerSeat, str]:
        """Claim a pending seat using its invitation token.

        On success the seat becomes claimed and the token is consumed. A
        claimed-seat event is published, benefit grants are enqueued, a
        customer session is created for the holder, and
        ``customer_seat_claimed`` is sent.

        Args:
            request_metadata: Extra fields for the "Seat claimed" log line;
                they never override the seat identifiers logged there.

        Returns:
            The claimed seat and the new customer session token.

        Raises:
            InvalidInvitationToken: unknown, revoked, expired or already-claimed
                token, or a seat whose product or customer cannot be resolved.
            FeatureNotEnabled: seat-based pricing is disabled.
        """
        repository = CustomerSeatRepository.from_session(session)

        seat = await repository.get_by_invitation_token(
            invitation_token,
            options=repository.get_eager_options(),
        )

        if not seat or seat.is_revoked():
            raise InvalidInvitationToken(invitation_token)

        if self._is_invitation_expired(seat):
            raise InvalidInvitationToken(invitation_token)

        # Reject already-claimed tokens for security
        if seat.is_claimed():
            raise InvalidInvitationToken(invitation_token)

        # Resolve the product from whichever container owns the seat. An order
        # without a product is rejected like any unusable token rather than via
        # ``assert`` (which is stripped under ``python -O``).
        if seat.subscription_id and seat.subscription:
            product = seat.subscription.product
        elif seat.order_id and seat.order and seat.order.product is not None:
            product = seat.order.product
        else:
            raise InvalidInvitationToken(invitation_token)
        organization_id = product.organization_id
        product_id = product.id

        await self.check_seat_feature_enabled(session, organization_id)

        if not seat.customer_id or not seat.customer:
            raise InvalidInvitationToken(invitation_token)

        seat.status = SeatStatus.claimed
        seat.claimed_at = datetime.now(UTC)
        seat.invitation_token = None  # Single-use token

        await session.flush()

        await self._publish_seat_claimed_event(seat, product_id)
        await self._enqueue_benefit_grant(seat, product_id)
        session_token, _ = await customer_session_service.create_customer_session(
            session, seat.customer
        )

        # Merge caller metadata first so the seat identifiers always win.
        # Splatting it next to explicit keywords would raise TypeError on any
        # overlapping key, including structlog's positional ``event``.
        log_context: dict[str, Any] = {
            **(request_metadata or {}),
            "seat_id": seat.id,
            "customer_id": seat.customer_id,
            "subscription_id": seat.subscription_id,
            "order_id": seat.order_id,
        }
        log_context.pop("event", None)
        log.info("Seat claimed", **log_context)

        await self._send_seat_claimed_webhook(session, organization_id, seat)

        return seat, session_token

    async def revoke_seat(
        self,
        session: AsyncSession,
        seat: CustomerSeat,
    ) -> CustomerSeat:
        """Revoke ``seat`` and take its benefits away from the current holder.

        If the seat has a holder, a benefit revocation job is enqueued for that
        holder. The seat is then marked revoked, detached from the customer and
        its invitation token cleared, and ``customer_seat_revoked`` is sent.

        Raises:
            ValueError: the seat has neither a subscription nor an order with
                a product.
            FeatureNotEnabled: seat-based pricing is disabled.
        """
        if seat.subscription_id and seat.subscription:
            organization_id = seat.subscription.product.organization_id
            product_id = seat.subscription.product_id
        elif seat.order_id and seat.order and seat.order.product_id:
            organization_id = seat.order.organization.id
            product_id = seat.order.product_id
        else:
            raise ValueError("Seat must have either subscription or order")

        await self.check_seat_feature_enabled(session, organization_id)

        # Capture the holder before clearing customer_id: the revoke job must
        # target the customer who actually received the benefits.
        original_customer_id = seat.customer_id
        if original_customer_id:
            self._enqueue_benefits_job(
                seat,
                task="revoke",
                customer_id=original_customer_id,
                product_id=product_id,
            )

        seat.status = SeatStatus.revoked
        seat.revoked_at = datetime.now(UTC)
        seat.customer_id = None
        seat.invitation_token = None

        await session.flush()

        log.info(
            "Seat revoked",
            seat_id=seat.id,
            subscription_id=seat.subscription_id,
            order_id=seat.order_id,
        )

        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(organization_id)
        if organization:
            await webhook_service.send(
                session,
                organization,
                WebhookEventType.customer_seat_revoked,
                seat,
            )

        return seat

    async def get_seat(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        seat_id: uuid.UUID,
    ) -> CustomerSeat | None:
        """Fetch a seat by id for ``auth_subject``.

        Returns None if the seat does not exist, has no resolvable container,
        or (for an Organization subject) belongs to another organization.

        Raises:
            FeatureNotEnabled: seat-based pricing is disabled.
        """
        repository = CustomerSeatRepository.from_session(session)

        seat = await repository.get_by_id(
            seat_id,
            options=repository.get_eager_options(),
        )
        if not seat:
            return None

        if seat.subscription_id and seat.subscription:
            organization_id = seat.subscription.product.organization_id
        elif seat.order_id and seat.order:
            organization_id = seat.order.organization.id
        else:
            return None

        if isinstance(auth_subject.subject, Organization):
            if organization_id != auth_subject.subject.id:
                return None
        elif isinstance(auth_subject.subject, User):
            # NOTE(review): no ownership check is performed for User subjects
            # here; confirm the caller restricts access to the user's
            # organizations, otherwise any user can read any seat by id.
            pass

        await self.check_seat_feature_enabled(session, organization_id)
        return seat

    async def resend_invitation(
        self,
        session: AsyncSession,
        seat: CustomerSeat,
    ) -> CustomerSeat:
        """Re-send the invitation email for a pending seat.

        If the seat's invitation token has expired, a fresh token and expiry
        are issued first. ``claim_seat`` rejects expired tokens, so resending
        the stale token would deliver an unusable invitation.

        Raises:
            ValueError: the seat has neither a subscription nor an order with
                a product.
            FeatureNotEnabled: seat-based pricing is disabled.
            SeatNotPending: the seat is not pending.
            InvalidInvitationToken: the seat has no customer or no token.
        """
        if seat.subscription_id and seat.subscription and seat.subscription.product:
            organization_id = seat.subscription.product.organization_id
            product_name = seat.subscription.product.name
            billing_manager_email = seat.subscription.customer.email
        elif seat.order_id and seat.order and seat.order.product:
            organization_id = seat.order.product.organization_id
            product_name = seat.order.product.name
            billing_manager_email = seat.order.customer.email
        else:
            raise ValueError("Seat must have either subscription or order")

        await self.check_seat_feature_enabled(session, organization_id)

        if not seat.is_pending():
            raise SeatNotPending()

        if not seat.customer or not seat.invitation_token:
            raise InvalidInvitationToken(seat.invitation_token or "")

        if self._is_invitation_expired(seat):
            # The email is built from ``seat`` (presumably its invitation
            # token); rotate the expired token so the resent link works.
            (
                seat.invitation_token,
                seat.invitation_token_expires_at,
            ) = self._new_invitation_token()
            await session.flush()

        log.info(
            "Resending seat invitation",
            seat_id=seat.id,
            customer_id=seat.customer_id,
            subscription_id=seat.subscription_id,
            order_id=seat.order_id,
        )

        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(organization_id)
        if organization:
            send_seat_invitation_email(
                customer_email=seat.customer.email,
                seat=seat,
                organization=organization,
                product_name=product_name,
                billing_manager_email=billing_manager_email,
            )

        return seat

    async def get_seat_for_customer(
        self,
        session: AsyncReadSession,
        customer: Customer,
        seat_id: uuid.UUID,
    ) -> CustomerSeat | None:
        """Get a seat and verify it belongs to a subscription or order owned by the customer."""
        repository = CustomerSeatRepository.from_session(session)
        return await repository.get_by_id_for_customer(
            seat_id,
            customer.id,
            options=repository.get_eager_options(),
        )

    async def revoke_all_seats_for_subscription(
        self,
        session: AsyncSession,
        subscription: Subscription,
    ) -> int:
        """
        Revoke all non-revoked seats for a subscription.

        This is typically called when a subscription is cancelled to ensure
        all seat holders lose access to their benefits. Each seat goes through
        ``revoke_seat`` and therefore through its feature check.

        Returns the number of seats revoked.
        """
        repository = CustomerSeatRepository.from_session(session)

        all_seats = await repository.list_by_subscription_id(
            subscription.id,
            options=repository.get_eager_options(),
        )

        active_seats = [seat for seat in all_seats if not seat.is_revoked()]
        for seat in active_seats:
            await self.revoke_seat(session, seat)
        revoked_count = len(active_seats)

        if revoked_count > 0:
            await session.flush()
            log.info(
                "Revoked all seats for subscription",
                subscription_id=subscription.id,
                seats_revoked=revoked_count,
            )

        return revoked_count

    async def _find__or_create_customer(
        self,
        session: AsyncSession,
        organization_id: uuid.UUID,
        email: str | None,
        external_customer_id: str | None,
        customer_id: uuid.UUID | None,
    ) -> Customer:
        """Resolve the seat holder from exactly one identifier.

        An unknown ``email`` creates (and flushes) a new customer in the
        organization. Unknown ``external_customer_id``/``customer_id`` values
        raise ``CustomerNotFound``.

        Raises:
            InvalidSeatAssignmentRequest: zero or several identifiers were
                given, or the single one is a blank string.
            CustomerNotFound: no customer matches the id-based identifier.
        """
        provided = [
            identifier
            for identifier in (email, external_customer_id, customer_id)
            if identifier is not None
        ]
        # A blank string is not a usable identifier. Previously an empty email
        # passed this count, skipped every truthiness-based lookup and ended in
        # CustomerNotFound("None").
        if len(provided) != 1 or (
            isinstance(provided[0], str) and not provided[0].strip()
        ):
            raise InvalidSeatAssignmentRequest()

        customer_repository = CustomerRepository.from_session(session)
        customer: Customer | None = None

        # Dispatch with ``is not None`` to match the validation above.
        if email is not None:
            customer = await customer_repository.get_by_email_and_organization(
                email, organization_id
            )
        elif external_customer_id is not None:
            customer = await customer_repository.get_by_external_id_and_organization(
                external_customer_id, organization_id
            )
        elif customer_id is not None:
            customer = await customer_repository.get_by_id_and_organization(
                customer_id, organization_id
            )

        if customer is not None:
            return customer

        if email is None:
            raise CustomerNotFound(external_customer_id or str(customer_id))

        customer = Customer(
            organization_id=organization_id,
            email=email,
        )
        session.add(customer)
        await session.flush()
        return customer
# Shared module-level SeatService instance imported by callers of this module.
seat_service = SeatService()