Coverage for polar/customer_seat/repository.py: 31%

103 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from collections.abc import Sequence 1a

2from uuid import UUID 1a

3 

4from sqlalchemy import Select, func, select 1a

5from sqlalchemy.orm import joinedload 1a

6 

7from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a

8from polar.kit.repository import RepositoryBase 1a

9from polar.kit.repository.base import Options 1a

10from polar.models import ( 1a

11 Customer, 

12 CustomerSeat, 

13 Order, 

14 Product, 

15 Subscription, 

16 UserOrganization, 

17) 

18from polar.models.customer_seat import SeatStatus 1a

19from polar.order.repository import OrderRepository 1a

20from polar.subscription.repository import SubscriptionRepository 1a

21 

22SeatContainer = Subscription | Order 1a

23 

24 

25class CustomerSeatRepository(RepositoryBase[CustomerSeat]): 1a

26 model = CustomerSeat 1a

27 

28 async def list_by_container( 1a

29 self, container: SeatContainer, *, options: Options = () 

30 ) -> Sequence[CustomerSeat]: 

31 """List seats for a subscription or order.""" 

32 if isinstance(container, Subscription): 

33 return await self.list_by_subscription_id(container.id, options=options) 

34 else: 

35 return await self.list_by_order_id(container.id, options=options) 

36 

37 async def get_available_seats_count_for_container( 1a

38 self, container: SeatContainer 

39 ) -> int: 

40 """Get available seats count for a subscription or order.""" 

41 if isinstance(container, Subscription): 

42 return await self.get_available_seats_count(container.id) 

43 else: 

44 return await self.get_available_seats_count_for_order(container.id) 

45 

46 async def get_by_container_and_customer( 1a

47 self, 

48 container: SeatContainer, 

49 customer_id: UUID, 

50 *, 

51 options: Options = (), 

52 ) -> CustomerSeat | None: 

53 if isinstance(container, Subscription): 

54 return await self.get_by_subscription_and_customer( 

55 container.id, customer_id, options=options 

56 ) 

57 else: 

58 return await self.get_by_order_and_customer( 

59 container.id, customer_id, options=options 

60 ) 

61 

62 async def get_revoked_seat_by_container( 1a

63 self, 

64 container: SeatContainer, 

65 *, 

66 options: Options = (), 

67 ) -> CustomerSeat | None: 

68 if isinstance(container, Subscription): 

69 return await self.get_revoked_seat_by_subscription( 

70 container.id, options=options 

71 ) 

72 else: 

73 return await self.get_revoked_seat_by_order(container.id, options=options) 

74 

75 async def list_by_subscription_id( 1a

76 self, subscription_id: UUID, *, options: Options = () 

77 ) -> Sequence[CustomerSeat]: 

78 statement = ( 

79 select(CustomerSeat) 

80 .where(CustomerSeat.subscription_id == subscription_id) 

81 .options(*options) 

82 ) 

83 return await self.get_all(statement) 

84 

85 async def list_by_order_id( 1a

86 self, order_id: UUID, *, options: Options = () 

87 ) -> Sequence[CustomerSeat]: 

88 statement = ( 

89 select(CustomerSeat) 

90 .where(CustomerSeat.order_id == order_id) 

91 .options(*options) 

92 ) 

93 return await self.get_all(statement) 

94 

95 async def get_by_invitation_token( 1a

96 self, token: str, *, options: Options = () 

97 ) -> CustomerSeat | None: 

98 statement = ( 

99 select(CustomerSeat) 

100 .where(CustomerSeat.invitation_token == token) 

101 .options(*options) 

102 ) 

103 return await self.get_one_or_none(statement) 

104 

105 async def count_assigned_seats_for_subscription(self, subscription_id: UUID) -> int: 1a

106 statement = select(func.count(CustomerSeat.id)).where( 

107 CustomerSeat.subscription_id == subscription_id, 

108 CustomerSeat.status.in_([SeatStatus.pending, SeatStatus.claimed]), 

109 ) 

110 result = await self.session.execute(statement) 

111 return result.scalar_one() 

112 

113 async def get_available_seats_count(self, subscription_id: UUID) -> int: 1a

114 subscription_statement = select(Subscription).where( 

115 Subscription.id == subscription_id 

116 ) 

117 subscription_repository = SubscriptionRepository.from_session(self.session) 

118 subscription = await subscription_repository.get_one_or_none( 

119 subscription_statement 

120 ) 

121 

122 if not subscription or subscription.seats is None: 

123 return 0 

124 claimed_statement = select(CustomerSeat).where( 

125 CustomerSeat.subscription_id == subscription_id, 

126 CustomerSeat.status.in_([SeatStatus.claimed, SeatStatus.pending]), 

127 ) 

128 claimed_seats = await self.get_all(claimed_statement) 

129 

130 return max(0, subscription.seats - len(claimed_seats)) 

131 

132 async def get_available_seats_count_for_order(self, order_id: UUID) -> int: 1a

133 order_repository = OrderRepository.from_session(self.session) 

134 order_statement = select(Order).where(Order.id == order_id) 

135 order = await order_repository.get_one_or_none(order_statement) 

136 

137 if not order or order.seats is None: 

138 return 0 

139 

140 claimed_statement = select(CustomerSeat).where( 

141 CustomerSeat.order_id == order_id, 

142 CustomerSeat.status.in_([SeatStatus.claimed, SeatStatus.pending]), 

143 ) 

144 claimed_seats = await self.get_all(claimed_statement) 

145 

146 return max(0, order.seats - len(claimed_seats)) 

147 

148 async def list_by_customer_id( 1a

149 self, customer_id: UUID, *, options: Options = () 

150 ) -> Sequence[CustomerSeat]: 

151 statement = ( 

152 select(CustomerSeat) 

153 .where(CustomerSeat.customer_id == customer_id) 

154 .options(*options) 

155 ) 

156 return await self.get_all(statement) 

157 

158 async def get_by_subscription_and_customer( 1a

159 self, 

160 subscription_id: UUID, 

161 customer_id: UUID, 

162 *, 

163 options: Options = (), 

164 ) -> CustomerSeat | None: 

165 statement = ( 

166 select(CustomerSeat) 

167 .where( 

168 CustomerSeat.subscription_id == subscription_id, 

169 CustomerSeat.customer_id == customer_id, 

170 ) 

171 .options(*options) 

172 ) 

173 return await self.get_one_or_none(statement) 

174 

175 async def get_by_order_and_customer( 1a

176 self, 

177 order_id: UUID, 

178 customer_id: UUID, 

179 *, 

180 options: Options = (), 

181 ) -> CustomerSeat | None: 

182 statement = ( 

183 select(CustomerSeat) 

184 .where( 

185 CustomerSeat.order_id == order_id, 

186 CustomerSeat.customer_id == customer_id, 

187 ) 

188 .options(*options) 

189 ) 

190 return await self.get_one_or_none(statement) 

191 

192 async def get_revoked_seat_by_subscription( 1a

193 self, 

194 subscription_id: UUID, 

195 *, 

196 options: Options = (), 

197 ) -> CustomerSeat | None: 

198 """Get a revoked seat for a subscription that can be reused.""" 

199 statement = ( 

200 select(CustomerSeat) 

201 .where( 

202 CustomerSeat.subscription_id == subscription_id, 

203 CustomerSeat.status == SeatStatus.revoked, 

204 ) 

205 .options(*options) 

206 .limit(1) 

207 ) 

208 return await self.get_one_or_none(statement) 

209 

210 async def get_revoked_seat_by_order( 1a

211 self, 

212 order_id: UUID, 

213 *, 

214 options: Options = (), 

215 ) -> CustomerSeat | None: 

216 """Get a revoked seat for an order that can be reused.""" 

217 statement = ( 

218 select(CustomerSeat) 

219 .where( 

220 CustomerSeat.order_id == order_id, 

221 CustomerSeat.status == SeatStatus.revoked, 

222 ) 

223 .options(*options) 

224 .limit(1) 

225 ) 

226 return await self.get_one_or_none(statement) 

227 

228 async def get_by_id( 1a

229 self, 

230 seat_id: UUID, 

231 *, 

232 options: Options = (), 

233 ) -> CustomerSeat | None: 

234 """Get a seat by ID.""" 

235 statement = ( 

236 select(CustomerSeat).where(CustomerSeat.id == seat_id).options(*options) 

237 ) 

238 return await self.get_one_or_none(statement) 

239 

240 async def get_by_id_for_customer( 1a

241 self, 

242 seat_id: UUID, 

243 customer_id: UUID, 

244 *, 

245 options: Options = (), 

246 ) -> CustomerSeat | None: 

247 """Get a seat by ID and verify it belongs to a subscription or order owned by the customer.""" 

248 statement = ( 

249 select(CustomerSeat) 

250 .outerjoin(Subscription, CustomerSeat.subscription_id == Subscription.id) 

251 .outerjoin(Order, CustomerSeat.order_id == Order.id) 

252 .where( 

253 CustomerSeat.id == seat_id, 

254 ( 

255 (Subscription.customer_id == customer_id) 

256 | (Order.customer_id == customer_id) 

257 ), 

258 ) 

259 .options(*options) 

260 ) 

261 return await self.get_one_or_none(statement) 

262 

263 def get_readable_statement( 1a

264 self, auth_subject: AuthSubject[User | Organization] 

265 ) -> Select[tuple[CustomerSeat]]: 

266 """ 

267 Get a statement filtered by authorization. 

268 

269 Seats are readable by users/organizations who have access to the product's organization. 

270 Handles both subscription-based and order-based seats. 

271 """ 

272 

273 statement = ( 

274 self.get_base_statement() 

275 .outerjoin(Subscription, CustomerSeat.subscription_id == Subscription.id) 

276 .outerjoin(Order, CustomerSeat.order_id == Order.id) 

277 .outerjoin( 

278 Product, 

279 (Subscription.product_id == Product.id) 

280 | (Order.product_id == Product.id), 

281 ) 

282 ) 

283 

284 if is_user(auth_subject): 

285 user_org_ids = select(UserOrganization.organization_id).where( 

286 UserOrganization.user_id == auth_subject.subject.id, 

287 UserOrganization.deleted_at.is_(None), 

288 ) 

289 statement = statement.where(Product.organization_id.in_(user_org_ids)) 

290 elif is_organization(auth_subject): 

291 statement = statement.where( 

292 Product.organization_id == auth_subject.subject.id 

293 ) 

294 

295 return statement 

296 

297 async def get_by_id_and_auth_subject( 1a

298 self, 

299 auth_subject: AuthSubject[User | Organization], 

300 seat_id: UUID, 

301 *, 

302 options: Options = (), 

303 ) -> CustomerSeat | None: 

304 """Get a seat by ID filtered by auth subject.""" 

305 statement = ( 

306 self.get_readable_statement(auth_subject) 

307 .where(CustomerSeat.id == seat_id) 

308 .options(*options) 

309 ) 

310 return await self.get_one_or_none(statement) 

311 

312 async def get_by_subscription_and_auth_subject( 1a

313 self, 

314 auth_subject: AuthSubject[User | Organization], 

315 seat_id: UUID, 

316 subscription_id: UUID, 

317 *, 

318 options: Options = (), 

319 ) -> CustomerSeat | None: 

320 """Get a seat by ID and subscription ID filtered by auth subject.""" 

321 statement = ( 

322 self.get_readable_statement(auth_subject) 

323 .where( 

324 CustomerSeat.id == seat_id, 

325 CustomerSeat.subscription_id == subscription_id, 

326 ) 

327 .options(*options) 

328 ) 

329 return await self.get_one_or_none(statement) 

330 

331 async def get_active_seat_for_customer( 1a

332 self, 

333 customer_id: UUID, 

334 *, 

335 options: Options = (), 

336 ) -> CustomerSeat | None: 

337 """ 

338 Get an active (claimed) seat for a customer. 

339 

340 Used to determine if a customer is a seat holder and should have 

341 their usage charges routed to the billing manager's subscription. 

342 """ 

343 statement = ( 

344 select(CustomerSeat) 

345 .where( 

346 CustomerSeat.customer_id == customer_id, 

347 CustomerSeat.status == SeatStatus.claimed, 

348 ) 

349 .options(*options) 

350 .limit(1) 

351 ) 

352 return await self.get_one_or_none(statement) 

353 

354 def get_eager_options(self) -> Options: 1a

355 return ( 

356 joinedload(CustomerSeat.subscription).options( 

357 joinedload(Subscription.product).joinedload(Product.organization), 

358 joinedload(Subscription.customer), 

359 ), 

360 joinedload(CustomerSeat.order).options( 

361 joinedload(Order.product), 

362 joinedload(Order.customer).joinedload(Customer.organization), 

363 ), 

364 joinedload(CustomerSeat.customer), 

365 ) 

366 

367 def get_eager_options_with_prices(self) -> Options: 1a

368 return ( 

369 *self.get_eager_options(), 

370 joinedload(CustomerSeat.subscription).joinedload( 

371 Subscription.subscription_product_prices 

372 ), 

373 )