Coverage for polar/customer_seat/service.py: 17%

295 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import secrets 1a

2import uuid 1a

3from collections.abc import Sequence 1a

4from datetime import UTC, datetime, timedelta 1a

5from typing import Any 1a

6 

7import structlog 1a

8 

9from polar.auth.models import AuthSubject 1a

10from polar.customer.repository import CustomerRepository 1a

11from polar.customer_seat.sender import send_seat_invitation_email 1a

12from polar.customer_session.service import ( 1a

13 customer_session as customer_session_service, 

14) 

15from polar.eventstream.service import publish as eventstream_publish 1a

16from polar.exceptions import PolarError 1a

17from polar.kit.db.postgres import AsyncSession 1a

18from polar.models import ( 1a

19 Customer, 

20 CustomerSeat, 

21 Order, 

22 Organization, 

23 Product, 

24 Subscription, 

25 User, 

26) 

27from polar.models.customer_seat import SeatStatus 1a

28from polar.models.order import OrderStatus 1a

29from polar.models.webhook_endpoint import WebhookEventType 1a

30from polar.organization.repository import OrganizationRepository 1a

31from polar.postgres import AsyncReadSession 1a

32from polar.webhook.service import webhook as webhook_service 1a

33from polar.worker import enqueue_job 1a

34 

35from .repository import CustomerSeatRepository 1a

36 

37log = structlog.get_logger() 1a

38 

39 

40class SeatError(PolarError): ... 1a

41 

42 

43class SeatNotAvailable(SeatError): 1a

44 def __init__(self, source_id: uuid.UUID, reason: str | None = None) -> None: 1a

45 self.source_id = source_id 

46 message = reason or f"No available seats for {source_id}" 

47 super().__init__(message, 400) 

48 

49 

50class InvalidInvitationToken(SeatError): 1a

51 def __init__(self, token: str) -> None: 1a

52 self.token = token 

53 message = "Invalid or expired invitation token" 

54 super().__init__(message, 400) 

55 

56 

57class FeatureNotEnabled(SeatError): 1a

58 def __init__(self) -> None: 1a

59 message = "Seat-based pricing is not enabled for this organization" 

60 super().__init__(message, 403) 

61 

62 

63class SeatAlreadyAssigned(SeatError): 1a

64 def __init__(self, customer_email: str) -> None: 1a

65 self.customer_email = customer_email 

66 message = f"Seat already assigned to customer {customer_email}" 

67 super().__init__(message, 400) 

68 

69 

70class SeatNotPending(SeatError): 1a

71 def __init__(self) -> None: 1a

72 message = "Seat is not in pending status" 

73 super().__init__(message, 400) 

74 

75 

76class InvalidSeatAssignmentRequest(SeatError): 1a

77 def __init__(self) -> None: 1a

78 message = "Exactly one of email, external_customer_id, or customer_id must be provided" 

79 super().__init__(message, 400) 

80 

81 

82class CustomerNotFound(SeatError): 1a

83 def __init__(self, customer_identifier: str) -> None: 1a

84 self.customer_identifier = customer_identifier 

85 message = f"Customer not found: {customer_identifier}" 

86 super().__init__(message, 404) 

87 

88 

89SeatContainer = Subscription | Order 1a

90 

91 

92class SeatService: 1a

93 def _get_customer_id(self, container: SeatContainer) -> uuid.UUID: 1a

94 return container.customer_id 

95 

96 def _get_product(self, container: SeatContainer) -> Product | None: 1a

97 return container.product 

98 

99 def _get_organization_id(self, container: SeatContainer) -> uuid.UUID: 1a

100 return container.organization.id 

101 

102 def _get_seats_count(self, container: SeatContainer) -> int: 1a

103 return container.seats or 0 

104 

105 def _get_container_id(self, container: SeatContainer) -> uuid.UUID: 1a

106 return container.id 

107 

108 def _is_subscription(self, container: SeatContainer) -> bool: 1a

109 return isinstance(container, Subscription) 

110 

111 async def _enqueue_benefit_grant( 1a

112 self, seat: CustomerSeat, product_id: uuid.UUID 

113 ) -> None: 

114 """Enqueue benefit grant job for a claimed seat.""" 

115 if seat.subscription_id: 

116 enqueue_job( 

117 "benefit.enqueue_benefits_grants", 

118 task="grant", 

119 customer_id=seat.customer_id, 

120 product_id=product_id, 

121 subscription_id=seat.subscription_id, 

122 ) 

123 else: 

124 enqueue_job( 

125 "benefit.enqueue_benefits_grants", 

126 task="grant", 

127 customer_id=seat.customer_id, 

128 product_id=product_id, 

129 order_id=seat.order_id, 

130 ) 

131 

132 async def _publish_seat_claimed_event( 1a

133 self, seat: CustomerSeat, product_id: uuid.UUID 

134 ) -> None: 

135 """Publish eventstream event for seat claimed.""" 

136 await eventstream_publish( 

137 "customer_seat.claimed", 

138 { 

139 "seat_id": str(seat.id), 

140 "subscription_id": str(seat.subscription_id) 

141 if seat.subscription_id 

142 else None, 

143 "order_id": str(seat.order_id) if seat.order_id else None, 

144 "product_id": str(product_id), 

145 }, 

146 customer_id=seat.customer_id, 

147 ) 

148 

149 async def _send_seat_claimed_webhook( 1a

150 self, session: AsyncSession, organization_id: uuid.UUID, seat: CustomerSeat 

151 ) -> None: 

152 """Send webhook for seat claimed.""" 

153 organization_repository = OrganizationRepository.from_session(session) 

154 organization = await organization_repository.get_by_id(organization_id) 

155 if organization: 

156 await webhook_service.send( 

157 session, 

158 organization, 

159 WebhookEventType.customer_seat_claimed, 

160 seat, 

161 ) 

162 

163 async def check_seat_feature_enabled( 1a

164 self, session: AsyncReadSession, organization_id: uuid.UUID 

165 ) -> None: 

166 from polar.organization.repository import OrganizationRepository 

167 

168 organization_repository = OrganizationRepository.from_session(session) 

169 organization = await organization_repository.get_by_id(organization_id) 

170 if not organization: 

171 raise FeatureNotEnabled() 

172 if not organization.feature_settings.get("seat_based_pricing_enabled", False): 

173 raise FeatureNotEnabled() 

174 

175 async def list_seats( 1a

176 self, 

177 session: AsyncReadSession, 

178 container: SeatContainer, 

179 ) -> Sequence[CustomerSeat]: 

180 await self.check_seat_feature_enabled( 

181 session, self._get_organization_id(container) 

182 ) 

183 repository = CustomerSeatRepository.from_session(session) 

184 return await repository.list_by_container( 

185 container, 

186 options=repository.get_eager_options(), 

187 ) 

188 

189 async def get_available_seats_count( 1a

190 self, 

191 session: AsyncReadSession, 

192 container: SeatContainer, 

193 ) -> int: 

194 await self.check_seat_feature_enabled( 

195 session, self._get_organization_id(container) 

196 ) 

197 repository = CustomerSeatRepository.from_session(session) 

198 return await repository.get_available_seats_count_for_container(container) 

199 

200 async def count_assigned_seats_for_subscription( 1a

201 self, 

202 session: AsyncReadSession, 

203 subscription: Subscription, 

204 ) -> int: 

205 repository = CustomerSeatRepository.from_session(session) 

206 return await repository.count_assigned_seats_for_subscription(subscription.id) 

207 

208 async def assign_seat( 1a

209 self, 

210 session: AsyncSession, 

211 container: SeatContainer, 

212 *, 

213 email: str | None = None, 

214 external_customer_id: str | None = None, 

215 customer_id: uuid.UUID | None = None, 

216 metadata: dict[str, Any] | None = None, 

217 immediate_claim: bool = False, 

218 ) -> CustomerSeat: 

219 product = self._get_product(container) 

220 source_id = self._get_container_id(container) 

221 

222 if product is None: 

223 raise SeatNotAvailable( 

224 source_id, 

225 "Container has no associated product", 

226 ) 

227 

228 organization_id = self._get_organization_id(container) 

229 billing_manager_customer = container.customer 

230 is_subscription = self._is_subscription(container) 

231 

232 await self.check_seat_feature_enabled(session, organization_id) 

233 

234 # Validate order payment status 

235 if isinstance(container, Order): 

236 if container.status == OrderStatus.pending: 

237 raise SeatNotAvailable( 

238 source_id, "Order must be paid before assigning seats" 

239 ) 

240 

241 repository = CustomerSeatRepository.from_session(session) 

242 

243 available_seats = await repository.get_available_seats_count_for_container( 

244 container 

245 ) 

246 

247 if available_seats <= 0: 

248 raise SeatNotAvailable(source_id) 

249 

250 customer = await self._find__or_create_customer( 

251 session, 

252 organization_id, 

253 email, 

254 external_customer_id, 

255 customer_id, 

256 ) 

257 

258 existing_seat = await repository.get_by_container_and_customer( 

259 container, customer.id 

260 ) 

261 

262 if existing_seat and not existing_seat.is_revoked(): 

263 identifier = email or external_customer_id or str(customer_id) 

264 raise SeatAlreadyAssigned(identifier) 

265 

266 # Only generate invitation token for standard (non-immediate) claims 

267 if immediate_claim: 

268 invitation_token = None 

269 token_expires_at = None 

270 else: 

271 invitation_token = secrets.token_urlsafe(32) 

272 token_expires_at = datetime.now(UTC) + timedelta(days=1) 

273 

274 revoked_seat = await repository.get_revoked_seat_by_container(container) 

275 

276 if revoked_seat: 

277 seat = revoked_seat 

278 seat.status = SeatStatus.claimed if immediate_claim else SeatStatus.pending 

279 seat.invitation_token = invitation_token 

280 seat.invitation_token_expires_at = token_expires_at 

281 seat.customer_id = customer.id 

282 seat.seat_metadata = metadata or {} 

283 seat.revoked_at = None 

284 seat.claimed_at = datetime.now(UTC) if immediate_claim else None 

285 else: 

286 seat_data = { 

287 "status": SeatStatus.claimed if immediate_claim else SeatStatus.pending, 

288 "invitation_token": invitation_token, 

289 "invitation_token_expires_at": token_expires_at, 

290 "customer_id": customer.id, 

291 "seat_metadata": metadata or {}, 

292 "claimed_at": datetime.now(UTC) if immediate_claim else None, 

293 } 

294 if is_subscription: 

295 seat_data["subscription_id"] = source_id 

296 else: 

297 seat_data["order_id"] = source_id 

298 

299 seat = CustomerSeat(**seat_data) 

300 session.add(seat) 

301 

302 await session.flush() 

303 

304 if immediate_claim: 

305 # Immediate claim flow: grant benefits and trigger claimed webhook 

306 log.info( 

307 "Seat immediately claimed", 

308 subscription_id=seat.subscription_id, 

309 order_id=seat.order_id, 

310 email=email, 

311 customer_id=customer.id, 

312 ) 

313 

314 await self._publish_seat_claimed_event(seat, product.id) 

315 await self._enqueue_benefit_grant(seat, product.id) 

316 await self._send_seat_claimed_webhook(session, organization_id, seat) 

317 else: 

318 # Standard flow: send invitation email and trigger assigned webhook 

319 log.info( 

320 "Seat assigned", 

321 subscription_id=seat.subscription_id, 

322 order_id=seat.order_id, 

323 email=email, 

324 customer_id=customer.id, 

325 invitation_token=invitation_token or "none", 

326 ) 

327 

328 organization_repository = OrganizationRepository.from_session(session) 

329 organization = await organization_repository.get_by_id(organization_id) 

330 if organization: 

331 send_seat_invitation_email( 

332 customer_email=customer.email, 

333 seat=seat, 

334 organization=organization, 

335 product_name=product.name, 

336 billing_manager_email=billing_manager_customer.email, 

337 ) 

338 

339 await webhook_service.send( 

340 session, 

341 organization, 

342 WebhookEventType.customer_seat_assigned, 

343 seat, 

344 ) 

345 

346 return seat 

347 

348 async def get_seat_by_token( 1a

349 self, 

350 session: AsyncReadSession, 

351 invitation_token: str, 

352 ) -> CustomerSeat | None: 

353 repository = CustomerSeatRepository.from_session(session) 

354 seat = await repository.get_by_invitation_token( 

355 invitation_token, 

356 options=repository.get_eager_options(), 

357 ) 

358 

359 if not seat or seat.is_revoked() or seat.is_claimed(): 

360 return None 

361 

362 if ( 

363 seat.invitation_token_expires_at 

364 and seat.invitation_token_expires_at < datetime.now(UTC) 

365 ): 

366 return None 

367 

368 return seat 

369 

370 async def claim_seat( 1a

371 self, 

372 session: AsyncSession, 

373 invitation_token: str, 

374 request_metadata: dict[str, Any] | None = None, 

375 ) -> tuple[CustomerSeat, str]: 

376 repository = CustomerSeatRepository.from_session(session) 

377 

378 seat = await repository.get_by_invitation_token( 

379 invitation_token, 

380 options=repository.get_eager_options(), 

381 ) 

382 

383 if not seat or seat.is_revoked(): 

384 raise InvalidInvitationToken(invitation_token) 

385 

386 if ( 

387 seat.invitation_token_expires_at 

388 and seat.invitation_token_expires_at < datetime.now(UTC) 

389 ): 

390 raise InvalidInvitationToken(invitation_token) 

391 

392 # Reject already-claimed tokens for security 

393 if seat.is_claimed(): 

394 raise InvalidInvitationToken(invitation_token) 

395 

396 # Get product and organization_id from either subscription or order 

397 if seat.subscription_id and seat.subscription: 

398 product = seat.subscription.product 

399 organization_id = product.organization_id 

400 product_id = product.id 

401 elif seat.order_id and seat.order: 

402 assert seat.order.product is not None 

403 product = seat.order.product 

404 organization_id = product.organization_id 

405 product_id = product.id 

406 else: 

407 raise InvalidInvitationToken(invitation_token) 

408 

409 await self.check_seat_feature_enabled(session, organization_id) 

410 

411 if not seat.customer_id or not seat.customer: 

412 raise InvalidInvitationToken(invitation_token) 

413 

414 seat.status = SeatStatus.claimed 

415 seat.claimed_at = datetime.now(UTC) 

416 seat.invitation_token = None # Single-use token 

417 

418 await session.flush() 

419 

420 await self._publish_seat_claimed_event(seat, product_id) 

421 await self._enqueue_benefit_grant(seat, product_id) 

422 session_token, _ = await customer_session_service.create_customer_session( 

423 session, seat.customer 

424 ) 

425 

426 log.info( 

427 "Seat claimed", 

428 seat_id=seat.id, 

429 customer_id=seat.customer_id, 

430 subscription_id=seat.subscription_id, 

431 **(request_metadata or {}), 

432 ) 

433 

434 await self._send_seat_claimed_webhook(session, organization_id, seat) 

435 

436 return seat, session_token 

437 

438 async def revoke_seat( 1a

439 self, 

440 session: AsyncSession, 

441 seat: CustomerSeat, 

442 ) -> CustomerSeat: 

443 # Get product from either subscription or order 

444 if seat.subscription_id and seat.subscription: 

445 organization_id = seat.subscription.product.organization_id 

446 product_id = seat.subscription.product_id 

447 elif seat.order_id and seat.order and seat.order.product_id: 

448 organization_id = seat.order.organization.id 

449 product_id = seat.order.product_id 

450 else: 

451 raise ValueError("Seat must have either subscription or order") 

452 

453 await self.check_seat_feature_enabled(session, organization_id) 

454 

455 # Capture customer_id before clearing to avoid race condition 

456 original_customer_id = seat.customer_id 

457 

458 # Revoke benefits from the customer before clearing the customer_id 

459 if original_customer_id: 

460 if seat.subscription_id: 

461 enqueue_job( 

462 "benefit.enqueue_benefits_grants", 

463 task="revoke", 

464 customer_id=original_customer_id, 

465 product_id=product_id, 

466 subscription_id=seat.subscription_id, 

467 ) 

468 else: 

469 enqueue_job( 

470 "benefit.enqueue_benefits_grants", 

471 task="revoke", 

472 customer_id=original_customer_id, 

473 product_id=product_id, 

474 order_id=seat.order_id, 

475 ) 

476 

477 seat.status = SeatStatus.revoked 

478 seat.revoked_at = datetime.now(UTC) 

479 seat.customer_id = None 

480 seat.invitation_token = None 

481 

482 await session.flush() 

483 

484 log.info( 

485 "Seat revoked", 

486 seat_id=seat.id, 

487 subscription_id=seat.subscription_id, 

488 order_id=seat.order_id, 

489 ) 

490 

491 organization_repository = OrganizationRepository.from_session(session) 

492 organization = await organization_repository.get_by_id(organization_id) 

493 if organization: 

494 await webhook_service.send( 

495 session, 

496 organization, 

497 WebhookEventType.customer_seat_revoked, 

498 seat, 

499 ) 

500 

501 return seat 

502 

503 async def get_seat( 1a

504 self, 

505 session: AsyncReadSession, 

506 auth_subject: AuthSubject[User | Organization], 

507 seat_id: uuid.UUID, 

508 ) -> CustomerSeat | None: 

509 repository = CustomerSeatRepository.from_session(session) 

510 

511 seat = await repository.get_by_id( 

512 seat_id, 

513 options=repository.get_eager_options(), 

514 ) 

515 

516 if not seat: 

517 return None 

518 

519 # Get organization_id from either subscription or order 

520 if seat.subscription_id and seat.subscription: 

521 organization_id = seat.subscription.product.organization_id 

522 elif seat.order_id and seat.order: 

523 organization_id = seat.order.organization.id 

524 else: 

525 return None 

526 

527 if isinstance(auth_subject.subject, Organization): 

528 if organization_id != auth_subject.subject.id: 

529 return None 

530 elif isinstance(auth_subject.subject, User): 

531 pass 

532 

533 await self.check_seat_feature_enabled(session, organization_id) 

534 return seat 

535 

536 async def resend_invitation( 1a

537 self, 

538 session: AsyncSession, 

539 seat: CustomerSeat, 

540 ) -> CustomerSeat: 

541 # Get product info from either subscription or order 

542 if seat.subscription_id and seat.subscription and seat.subscription.product: 

543 organization_id = seat.subscription.product.organization_id 

544 product_name = seat.subscription.product.name 

545 billing_manager_email = seat.subscription.customer.email 

546 elif seat.order_id and seat.order and seat.order.product: 

547 organization_id = seat.order.product.organization_id 

548 product_name = seat.order.product.name 

549 billing_manager_email = seat.order.customer.email 

550 else: 

551 raise ValueError("Seat must have either subscription or order") 

552 

553 await self.check_seat_feature_enabled(session, organization_id) 

554 

555 if not seat.is_pending(): 

556 raise SeatNotPending() 

557 

558 if not seat.customer or not seat.invitation_token: 

559 raise InvalidInvitationToken(seat.invitation_token or "") 

560 

561 log.info( 

562 "Resending seat invitation", 

563 seat_id=seat.id, 

564 customer_id=seat.customer_id, 

565 subscription_id=seat.subscription_id, 

566 order_id=seat.order_id, 

567 ) 

568 

569 organization_repository = OrganizationRepository.from_session(session) 

570 organization = await organization_repository.get_by_id(organization_id) 

571 if organization: 

572 send_seat_invitation_email( 

573 customer_email=seat.customer.email, 

574 seat=seat, 

575 organization=organization, 

576 product_name=product_name, 

577 billing_manager_email=billing_manager_email, 

578 ) 

579 

580 return seat 

581 

582 async def get_seat_for_customer( 1a

583 self, 

584 session: AsyncReadSession, 

585 customer: Customer, 

586 seat_id: uuid.UUID, 

587 ) -> CustomerSeat | None: 

588 """Get a seat and verify it belongs to a subscription or order owned by the customer.""" 

589 repository = CustomerSeatRepository.from_session(session) 

590 

591 seat = await repository.get_by_id_for_customer( 

592 seat_id, 

593 customer.id, 

594 options=repository.get_eager_options(), 

595 ) 

596 

597 return seat 

598 

599 async def revoke_all_seats_for_subscription( 1a

600 self, 

601 session: AsyncSession, 

602 subscription: Subscription, 

603 ) -> int: 

604 """ 

605 Revoke all non-revoked seats for a subscription. 

606 

607 This is typically called when a subscription is cancelled to ensure 

608 all seat holders lose access to their benefits. 

609 

610 Returns the number of seats revoked. 

611 """ 

612 repository = CustomerSeatRepository.from_session(session) 

613 

614 all_seats = await repository.list_by_subscription_id( 

615 subscription.id, 

616 options=repository.get_eager_options(), 

617 ) 

618 

619 active_seats = [seat for seat in all_seats if not seat.is_revoked()] 

620 

621 revoked_count = 0 

622 for seat in active_seats: 

623 await self.revoke_seat(session, seat) 

624 revoked_count += 1 

625 

626 if revoked_count > 0: 

627 await session.flush() 

628 log.info( 

629 "Revoked all seats for subscription", 

630 subscription_id=subscription.id, 

631 seats_revoked=revoked_count, 

632 ) 

633 

634 return revoked_count 

635 

636 async def _find__or_create_customer( 1a

637 self, 

638 session: AsyncSession, 

639 organization_id: uuid.UUID, 

640 email: str | None, 

641 external_customer_id: str | None, 

642 customer_id: uuid.UUID | None, 

643 ) -> Customer: 

644 # Validate that exactly one identifier is provided 

645 provided_identifiers = [email, external_customer_id, customer_id] 

646 non_null_count = sum( 

647 1 for identifier in provided_identifiers if identifier is not None 

648 ) 

649 

650 if non_null_count != 1: 

651 raise InvalidSeatAssignmentRequest() 

652 

653 customer_repository = CustomerRepository.from_session(session) 

654 customer = None 

655 

656 # Find customer based on provided identifier 

657 if email: 

658 customer = await customer_repository.get_by_email_and_organization( 

659 email, organization_id 

660 ) 

661 elif external_customer_id: 

662 customer = await customer_repository.get_by_external_id_and_organization( 

663 external_customer_id, organization_id 

664 ) 

665 elif customer_id: 

666 customer = await customer_repository.get_by_id_and_organization( 

667 customer_id, organization_id 

668 ) 

669 

670 if not customer and not email: 

671 raise CustomerNotFound(external_customer_id or str(customer_id)) 

672 elif not customer: 

673 customer = Customer( 

674 organization_id=organization_id, 

675 email=email, 

676 ) 

677 session.add(customer) 

678 await session.flush() 

679 return customer 

680 

681 

682seat_service = SeatService() 1a