Coverage for polar/subscription/repository.py: 28%

102 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from collections.abc import Sequence 1a

2from dataclasses import dataclass 1a

3from typing import TYPE_CHECKING 1a

4from uuid import UUID 1a

5 

6from sqlalchemy import Select, case, or_, select 1a

7from sqlalchemy.orm import contains_eager 1a

8from sqlalchemy.orm.strategy_options import joinedload, selectinload 1a

9 

10from polar.auth.models import ( 1a

11 AuthSubject, 

12 Organization, 

13 User, 

14 is_customer, 

15 is_organization, 

16 is_user, 

17) 

18from polar.auth.models import ( 1a

19 Customer as AuthCustomer, 

20) 

21from polar.enums import SubscriptionRecurringInterval 1a

22from polar.kit.repository import ( 1a

23 Options, 

24 RepositoryBase, 

25 RepositorySoftDeletionIDMixin, 

26 RepositorySoftDeletionMixin, 

27 RepositorySortingMixin, 

28 SortingClause, 

29) 

30from polar.models import ( 1a

31 Customer, 

32 CustomerSeat, 

33 Discount, 

34 Product, 

35 ProductPrice, 

36 ProductPriceMeteredUnit, 

37 Subscription, 

38 SubscriptionMeter, 

39 SubscriptionProductPrice, 

40 UserOrganization, 

41) 

42from polar.models.customer_seat import SeatStatus 1a

43from polar.models.subscription import SubscriptionStatus 1a

44from polar.product.guard import is_metered_price 1a

45 

46from .sorting import SubscriptionSortProperty 1a

47 

48if TYPE_CHECKING: 48 ↛ 49line 48 didn't jump to line 49 because the condition on line 48 was never true1a

49 from sqlalchemy.orm.strategy_options import _AbstractLoad 

50 

51 

@dataclass
class CustomerSubscriptionProductPrice:
    """
    The paying customer and the subscription product price matched for a meter.

    When the customer that was looked up only holds a seat on someone else's
    subscription, ``customer_id`` identifies the subscription's owner rather
    than the looked-up customer.

    The paying customer object itself is reachable through
    ``subscription_product_price.subscription.customer``.
    """

    # ID of the customer the matched subscription belongs to (the payer).
    customer_id: UUID
    # The subscription product price that matched the requested meter.
    subscription_product_price: SubscriptionProductPrice

65 

66 

67class SubscriptionRepository( 1a

68 RepositorySortingMixin[Subscription, SubscriptionSortProperty], 

69 RepositorySoftDeletionIDMixin[Subscription, UUID], 

70 RepositorySoftDeletionMixin[Subscription], 

71 RepositoryBase[Subscription], 

72): 

73 model = Subscription 1a

74 

75 async def list_active_by_customer( 1a

76 self, customer_id: UUID, *, options: Options = () 

77 ) -> Sequence[Subscription]: 

78 statement = ( 

79 self.get_base_statement() 

80 .where( 

81 Subscription.customer_id == customer_id, 

82 Subscription.active.is_(True), 

83 ) 

84 .options(*options) 

85 ) 

86 return await self.get_all(statement) 

87 

88 async def get_by_id_and_organization( 1a

89 self, 

90 id: UUID, 

91 organization_id: UUID, 

92 *, 

93 options: Options = (), 

94 ) -> Subscription | None: 

95 statement = ( 

96 self.get_base_statement() 

97 .join(Product) 

98 .where( 

99 Subscription.id == id, 

100 Product.organization_id == organization_id, 

101 ) 

102 .options(contains_eager(Subscription.product), *options) 

103 ) 

104 return await self.get_one_or_none(statement) 

105 

106 async def get_by_checkout_id( 1a

107 self, checkout_id: UUID, *, options: Options = () 

108 ) -> Subscription | None: 

109 statement = ( 

110 self.get_base_statement() 

111 .where(Subscription.checkout_id == checkout_id) 

112 .options(*options) 

113 ) 

114 result = await self.session.execute(statement) 

115 return result.scalar_one_or_none() 

116 

117 async def get_by_stripe_subscription_id( 1a

118 self, stripe_subscription_id: str, *, options: Options = () 

119 ) -> Subscription | None: 

120 statement = ( 

121 self.get_base_statement() 

122 .where( 

123 or_( 

124 Subscription.stripe_subscription_id == stripe_subscription_id, 

125 Subscription.legacy_stripe_subscription_id 

126 == stripe_subscription_id, 

127 ) 

128 ) 

129 .options(*options) 

130 ) 

131 return await self.get_one_or_none(statement) 

132 

133 def get_eager_options( 1a

134 self, *, product_load: "_AbstractLoad | None" = None 

135 ) -> Options: 

136 if product_load is None: 

137 product_load = joinedload(Subscription.product) 

138 return ( 

139 joinedload(Subscription.customer), 

140 product_load.options( 

141 joinedload(Product.organization), 

142 selectinload(Product.product_medias), 

143 selectinload(Product.attached_custom_fields), 

144 ), 

145 selectinload(Subscription.meters).joinedload(SubscriptionMeter.meter), 

146 ) 

147 

148 def get_readable_statement( 1a

149 self, auth_subject: AuthSubject[User | Organization | Customer] 

150 ) -> Select[tuple[Subscription]]: 

151 statement = self.get_base_statement().join(Product) 

152 

153 if is_user(auth_subject): 

154 user = auth_subject.subject 

155 statement = statement.where( 

156 Product.organization_id.in_( 

157 select(UserOrganization.organization_id).where( 

158 UserOrganization.user_id == user.id, 

159 UserOrganization.deleted_at.is_(None), 

160 ) 

161 ) 

162 ) 

163 elif is_organization(auth_subject): 

164 statement = statement.where( 

165 Product.organization_id == auth_subject.subject.id, 

166 ) 

167 elif is_customer(auth_subject): 

168 customer = auth_subject.subject 

169 statement = statement.where( 

170 Subscription.customer_id == customer.id, 

171 Subscription.deleted_at.is_(None), 

172 ) 

173 

174 return statement 

175 

176 def get_claimed_subscriptions_statement( 1a

177 self, auth_subject: AuthSubject[AuthCustomer] 

178 ) -> Select[tuple[Subscription]]: 

179 """Get subscriptions where the customer has a claimed seat.""" 

180 customer = auth_subject.subject 

181 

182 statement = ( 

183 self.get_base_statement() 

184 .join(CustomerSeat, CustomerSeat.subscription_id == Subscription.id) 

185 .where( 

186 CustomerSeat.customer_id == customer.id, 

187 CustomerSeat.status == SeatStatus.claimed, 

188 ) 

189 ) 

190 

191 return statement 

192 

193 def get_sorting_clause(self, property: SubscriptionSortProperty) -> SortingClause: 1a

194 match property: 

195 case SubscriptionSortProperty.customer: 

196 return Customer.email 

197 case SubscriptionSortProperty.status: 

198 return case( 

199 (Subscription.status == SubscriptionStatus.incomplete, 1), 

200 ( 

201 Subscription.status == SubscriptionStatus.incomplete_expired, 

202 2, 

203 ), 

204 (Subscription.status == SubscriptionStatus.trialing, 3), 

205 ( 

206 Subscription.status == SubscriptionStatus.active, 

207 case( 

208 (Subscription.cancel_at_period_end.is_(False), 4), 

209 (Subscription.cancel_at_period_end.is_(True), 5), 

210 ), 

211 ), 

212 (Subscription.status == SubscriptionStatus.past_due, 6), 

213 (Subscription.status == SubscriptionStatus.canceled, 7), 

214 (Subscription.status == SubscriptionStatus.unpaid, 8), 

215 ) 

216 case SubscriptionSortProperty.started_at: 

217 return Subscription.started_at 

218 case SubscriptionSortProperty.current_period_end: 

219 return Subscription.current_period_end 

220 case SubscriptionSortProperty.amount: 

221 return case( 

222 ( 

223 Subscription.recurring_interval 

224 == SubscriptionRecurringInterval.year, 

225 Subscription.amount / 12, 

226 ), 

227 ( 

228 Subscription.recurring_interval 

229 == SubscriptionRecurringInterval.month, 

230 Subscription.amount, 

231 ), 

232 ( 

233 Subscription.recurring_interval 

234 == SubscriptionRecurringInterval.week, 

235 Subscription.amount * 4, 

236 ), 

237 ( 

238 Subscription.recurring_interval 

239 == SubscriptionRecurringInterval.day, 

240 Subscription.amount * 30, 

241 ), 

242 ) 

243 case SubscriptionSortProperty.product: 

244 return Product.name 

245 case SubscriptionSortProperty.discount: 

246 return Discount.name 

247 

248 

class SubscriptionProductPriceRepository(
    RepositorySoftDeletionIDMixin[SubscriptionProductPrice, UUID],
    RepositorySoftDeletionMixin[SubscriptionProductPrice],
    RepositoryBase[SubscriptionProductPrice],
):
    """Repository for ``SubscriptionProductPrice`` rows."""

    model = SubscriptionProductPrice

    async def get_by_customer_and_meter(
        self, customer_id: UUID, meter_id: UUID
    ) -> CustomerSubscriptionProductPrice | None:
        """
        Get the paying customer and subscription product price for a customer and meter.

        If the customer has a direct subscription with the meter, returns that.
        Otherwise, if the customer is a seat holder, returns the price from a
        subscription on which they hold a claimed seat (paid by that
        subscription's owner).

        Args:
            customer_id: Customer whose usage is being attributed.
            meter_id: Meter the metered price must be attached to.

        Returns:
            The match, or None when neither lookup finds a metered price.
        """
        result = await self._get_direct_subscription_price(customer_id, meter_id)
        if result is not None:
            return result

        return await self._get_seat_subscription_price(customer_id, meter_id)

    async def _get_direct_subscription_price(
        self, customer_id: UUID, meter_id: UUID
    ) -> CustomerSubscriptionProductPrice | None:
        """
        Find a metered price for ``meter_id`` on a billable subscription owned by the customer.

        When several subscriptions match, the one with the earliest
        ``started_at`` is used. The product price, subscription and the
        subscription's customer are loaded eagerly.

        NOTE(review): ``ProductPriceMeteredUnit.meter_id`` is filtered on without
        a dedicated join; this presumably works because the subclass maps onto
        the same product-price table (single-table inheritance) — confirm.
        """
        statement = (
            self.get_base_statement()
            .join(
                ProductPrice,
                SubscriptionProductPrice.product_price_id == ProductPrice.id,
            )
            .join(
                Subscription,
                Subscription.id == SubscriptionProductPrice.subscription_id,
            )
            .where(
                ProductPrice.is_metered.is_(True),
                ProductPriceMeteredUnit.meter_id == meter_id,
                Subscription.billable.is_(True),
                Subscription.customer_id == customer_id,
            )
            # In case customer has several subscriptions, take the earliest one
            .order_by(Subscription.started_at.asc())
            .limit(1)
            .options(
                contains_eager(SubscriptionProductPrice.product_price),
                contains_eager(SubscriptionProductPrice.subscription).joinedload(
                    Subscription.customer
                ),
            )
        )

        subscription_product_price = await self.get_one_or_none(statement)
        if subscription_product_price is None:
            return None

        return CustomerSubscriptionProductPrice(
            customer_id=subscription_product_price.subscription.customer_id,
            subscription_product_price=subscription_product_price,
        )

    async def _get_seat_subscription_price(
        self, customer_id: UUID, meter_id: UUID
    ) -> CustomerSubscriptionProductPrice | None:
        """
        Get subscription product price for a customer who is a seat holder.

        Every claimed seat of the customer is considered, earliest-started
        subscription first (mirroring the direct lookup). The first
        subscription exposing a metered price for ``meter_id`` wins. Checking
        only one arbitrary seat could miss a seat that grants the meter.

        The returned ``customer_id`` is the subscription owner's (the billing
        manager), not the seat holder's.

        NOTE(review): unlike the direct lookup, the seat's subscription is not
        filtered on ``billable``; confirm this is intended.
        """
        for seat in await self._list_active_seats_for_customer(customer_id):
            subscription = seat.subscription
            if subscription is None:
                continue

            # Find matching metered price in billing manager's subscription
            metered_price = self._find_metered_price_in_subscription(
                subscription, meter_id
            )
            if metered_price is not None:
                return CustomerSubscriptionProductPrice(
                    customer_id=subscription.customer_id,
                    subscription_product_price=metered_price,
                )

        return None

    async def _list_active_seats_for_customer(
        self, customer_id: UUID
    ) -> Sequence[CustomerSeat]:
        """
        List the customer's claimed seats, with subscription data eagerly loaded.

        Seats are ordered by their subscription's ``started_at`` (earliest
        first). Seats with no subscription or no start date come last.
        """
        statement = (
            select(CustomerSeat)
            .where(
                CustomerSeat.customer_id == customer_id,
                CustomerSeat.status == SeatStatus.claimed,
            )
            .options(
                joinedload(CustomerSeat.subscription).options(
                    joinedload(Subscription.customer),
                    joinedload(Subscription.subscription_product_prices).options(
                        joinedload(SubscriptionProductPrice.product_price),
                        # Load the back-reference to satisfy lazy='raise_on_sql'
                        # This points to the same Subscription already being loaded above
                        joinedload(SubscriptionProductPrice.subscription),
                    ),
                )
            )
        )
        result = await self.session.scalars(statement)
        # The joined-eager-loaded collection (subscription_product_prices)
        # repeats each seat once per price row. SQLAlchemy requires unique()
        # before fetching all rows of such a result.
        seats = result.unique().all()

        def _start_key(seat: CustomerSeat) -> tuple:
            started_at = seat.subscription.started_at if seat.subscription else None
            # (False, dt) sorts before (True, None); None is never compared to a datetime.
            return (started_at is None, started_at)

        return sorted(seats, key=_start_key)

    async def _get_active_seat_for_customer(
        self, customer_id: UUID
    ) -> CustomerSeat | None:
        """
        Get the first claimed seat for a customer, with subscription data eagerly loaded.

        "First" follows the ordering of ``_list_active_seats_for_customer``
        (earliest subscription start). Returns None when the customer holds no
        claimed seat.
        """
        seats = await self._list_active_seats_for_customer(customer_id)
        return seats[0] if seats else None

    def _find_metered_price_in_subscription(
        self, subscription: Subscription, meter_id: UUID
    ) -> SubscriptionProductPrice | None:
        """
        Find a metered price for the given meter in a subscription.

        Scans the already-loaded ``subscription_product_prices`` and returns
        the first one whose product price is metered and attached to
        ``meter_id``. Returns None if no matching metered price is found.
        """
        for spp in subscription.subscription_product_prices:
            if (
                is_metered_price(spp.product_price)
                and spp.product_price.meter_id == meter_id
            ):
                return spp
        return None