Coverage for polar/checkout_link/service.py: 18%

118 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import uuid 1a

2from collections.abc import Sequence 1a

3from typing import Any, cast 1a

4 

5from sqlalchemy import UnaryExpression, asc, desc 1a

6from sqlalchemy.orm import contains_eager 1a

7 

8from polar.auth.models import AuthSubject 1a

9from polar.checkout_link.repository import CheckoutLinkRepository 1a

10from polar.discount.service import discount as discount_service 1a

11from polar.exceptions import PolarRequestValidationError, ValidationError 1a

12from polar.kit.crypto import generate_token 1a

13from polar.kit.pagination import PaginationParams 1a

14from polar.kit.services import ResourceServiceReader 1a

15from polar.kit.sorting import Sorting 1a

16from polar.models import ( 1a

17 CheckoutLink, 

18 CheckoutLinkProduct, 

19 Discount, 

20 Organization, 

21 Product, 

22 ProductPrice, 

23 User, 

24) 

25from polar.postgres import AsyncSession 1a

26from polar.product.repository import ProductPriceRepository, ProductRepository 1a

27from polar.product.service import product as product_service 1a

28 

29from .schemas import ( 1a

30 CheckoutLinkCreate, 

31 CheckoutLinkCreateProduct, 

32 CheckoutLinkCreateProductPrice, 

33 CheckoutLinkCreateProducts, 

34 CheckoutLinkUpdate, 

35) 

36from .sorting import CheckoutLinkSortProperty 1a

37 

# Prefix passed to ``generate_token`` when minting a checkout link's
# ``client_secret``; it makes these tokens recognisable at a glance.
CHECKOUT_LINK_CLIENT_SECRET_PREFIX: str = "polar_cl_"

39 

40 

class CheckoutLinkService(ResourceServiceReader[CheckoutLink]):
    """Business logic for checkout links: listing, lookup, create, update, delete."""

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        product_id: Sequence[uuid.UUID] | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[CheckoutLinkSortProperty]] | None = None,
    ) -> tuple[Sequence[CheckoutLink], int]:
        """List checkout links readable by ``auth_subject``, one page at a time.

        Args:
            session: Database session.
            auth_subject: Subject used to build the base readable statement.
            organization_id: If given, keep only links whose organization id is
                in this sequence.
            product_id: If given, keep only links that have at least one
                ``CheckoutLinkProduct`` whose product id is in this sequence.
            pagination: Supplies ``limit`` and ``page``.
            sorting: ``(property, is_descending)`` pairs applied in order.
                Defaults to ``created_at`` ascending when omitted or ``None``.
                Unrecognised properties are ignored.

        Returns:
            The result of ``CheckoutLinkRepository.paginate``: the page of
            checkout links and an integer count.
        """
        # ``None`` sentinel instead of a list-literal default, so no mutable
        # default object is shared between calls. Same effective default.
        if sorting is None:
            sorting = [(CheckoutLinkSortProperty.created_at, False)]

        repository = CheckoutLinkRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject)
        checkout_link_product_load = None

        if organization_id is not None:
            statement = statement.where(
                CheckoutLink.organization_id.in_(organization_id)
            )

        if product_id is not None:
            statement = statement.join(
                CheckoutLinkProduct,
                onclause=CheckoutLinkProduct.checkout_link_id == CheckoutLink.id,
            ).where(CheckoutLinkProduct.product_id.in_(product_id))
            # Fill ``checkout_link_products`` from the join above instead of a
            # separate eager load. NOTE(review): with a filtered join,
            # ``contains_eager`` only puts the matching rows in the collection.
            checkout_link_product_load = contains_eager(
                CheckoutLink.checkout_link_products
            )

        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            clause_function = desc if is_desc else asc
            if criterion == CheckoutLinkSortProperty.created_at:
                order_by_clauses.append(clause_function(CheckoutLink.created_at))
            elif criterion == CheckoutLinkSortProperty.label:
                order_by_clauses.append(clause_function(CheckoutLink.label))
            elif criterion == CheckoutLinkSortProperty.success_url:
                order_by_clauses.append(clause_function(CheckoutLink._success_url))
            elif criterion == CheckoutLinkSortProperty.allow_discount_codes:
                order_by_clauses.append(
                    clause_function(CheckoutLink.allow_discount_codes)
                )
        statement = statement.order_by(*order_by_clauses)

        statement = statement.options(
            *repository.get_eager_options(
                checkout_link_product_load=checkout_link_product_load
            )
        )

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get_by_id(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> CheckoutLink | None:
        """Fetch one checkout link by id, restricted to the readable statement.

        Args:
            session: Database session.
            auth_subject: Subject used to build the base readable statement.
            id: Checkout link id.

        Returns:
            The checkout link, loaded with the repository's eager options, or
            ``None`` when no readable row matches.
        """
        repository = CheckoutLinkRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(CheckoutLink.id == id)
            .options(*repository.get_eager_options())
        )
        return await repository.get_one_or_none(statement)

    async def create(
        self,
        session: AsyncSession,
        checkout_link_create: CheckoutLinkCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> CheckoutLink:
        """Validate the input and persist a new checkout link.

        The product set comes from whichever create variant was sent:
        a list of products, a single product, or a single price (whose
        product is used). The organization is taken from the first product.

        Args:
            session: Database session.
            checkout_link_create: One of the ``CheckoutLinkCreate*`` variants.
            auth_subject: Subject used to look up products, prices and
                discounts.

        Returns:
            The checkout link as returned by ``CheckoutLinkRepository.create``.

        Raises:
            PolarRequestValidationError: If a product, price or discount fails
                validation.
        """
        if isinstance(checkout_link_create, CheckoutLinkCreateProducts):
            products = await self._get_validated_products(
                session, checkout_link_create.products, auth_subject
            )
        elif isinstance(checkout_link_create, CheckoutLinkCreateProduct):
            products = await self._get_validated_products(
                session, [checkout_link_create.product_id], auth_subject
            )
        elif isinstance(checkout_link_create, CheckoutLinkCreateProductPrice):
            product, _ = await self._get_validated_price(
                session, checkout_link_create.product_price_id, auth_subject
            )
            products = [product]
        # ``_get_validated_products`` guarantees a non-empty list of products
        # sharing one organization, so indexing [0] is safe here.
        organization = products[0].organization

        discount: Discount | None = None
        if checkout_link_create.discount_id is not None:
            discount = await self._get_validated_discount(
                session, checkout_link_create.discount_id, organization, products
            )

        checkout_link = CheckoutLink(
            client_secret=generate_token(prefix=CHECKOUT_LINK_CLIENT_SECRET_PREFIX),
            organization=organization,
            discount=discount,
            # Preserve the caller's product ordering via ``order``.
            checkout_link_products=[
                CheckoutLinkProduct(product=product, order=i)
                for i, product in enumerate(products)
            ],
            # Remaining scalar fields map straight onto the model; the
            # relationship-driving ids were resolved above, so drop them.
            **checkout_link_create.model_dump(
                exclude={
                    "products",
                    "product_id",
                    "product_price_id",
                    "discount_id",
                },
                by_alias=True,
            ),
        )

        repository = CheckoutLinkRepository.from_session(session)
        return await repository.create(checkout_link)

    async def update(
        self,
        session: AsyncSession,
        checkout_link: CheckoutLink,
        checkout_link_update: CheckoutLinkUpdate,
        auth_subject: AuthSubject[User | Organization],
    ) -> CheckoutLink:
        """Apply a partial update to an existing checkout link.

        Args:
            session: Database session.
            checkout_link: The link being modified.
            checkout_link_update: Update payload; only fields that were set are
                applied.
            auth_subject: Subject used to look up the new products.

        Returns:
            The checkout link as returned by ``CheckoutLinkRepository.update``.

        Raises:
            PolarRequestValidationError: If the new products fail validation or
                belong to a different organization, or the discount is invalid.
        """
        if checkout_link_update.products is not None:
            products = await self._get_validated_products(
                session, checkout_link_update.products, auth_subject
            )
            if checkout_link.organization_id != products[0].organization_id:
                raise PolarRequestValidationError(
                    [
                        {
                            "type": "value_error",
                            "loc": ("body", "products"),
                            "msg": (
                                "Products don't belong to "
                                "the checkout link's organization."
                            ),
                            "input": checkout_link_update.products,
                        }
                    ]
                )
            # Clear and flush before assigning the new set. Presumably this makes
            # the old association rows go away before new ones are inserted
            # (e.g. to avoid unique-constraint clashes) — TODO confirm.
            checkout_link.checkout_link_products = []
            await session.flush()
            checkout_link.checkout_link_products = [
                CheckoutLinkProduct(product=product, order=i)
                for i, product in enumerate(products)
            ]

        # ``model_fields_set`` distinguishes "explicitly set to null" (remove
        # the discount) from "not sent" (leave it untouched).
        if "discount_id" in checkout_link_update.model_fields_set:
            if checkout_link_update.discount_id is None:
                checkout_link.discount = None
            else:
                discount = await self._get_validated_discount(
                    session,
                    checkout_link_update.discount_id,
                    checkout_link.organization,
                    checkout_link.products,
                )
                checkout_link.discount = discount

        repository = CheckoutLinkRepository.from_session(session)
        return await repository.update(
            checkout_link,
            update_dict=checkout_link_update.model_dump(
                exclude_unset=True,
                exclude={"products", "discount_id"},
                by_alias=True,
            ),
        )

    async def delete(
        self, session: AsyncSession, checkout_link: CheckoutLink
    ) -> CheckoutLink:
        """Soft-delete ``checkout_link`` via ``CheckoutLinkRepository.soft_delete``."""
        repository = CheckoutLinkRepository.from_session(session)
        return await repository.soft_delete(checkout_link)

    async def _get_validated_products(
        self,
        session: AsyncSession,
        product_ids: Sequence[uuid.UUID],
        auth_subject: AuthSubject[User | Organization],
    ) -> Sequence[Product]:
        """Load and validate the requested products.

        Every id must resolve through ``product_service.get`` and must not be
        archived. All products must share one organization. Errors are
        collected and raised together.

        Args:
            session: Database session.
            product_ids: Requested product ids, in the caller's order.
            auth_subject: Subject passed to ``product_service.get``.

        Returns:
            The products, in input order. The list is never empty.

        Raises:
            PolarRequestValidationError: If ``product_ids`` is empty, or any
                product is missing or archived, or the products span several
                organizations.
        """
        # Callers index ``products[0]``, so an empty request must be a
        # validation error rather than an ``IndexError`` later on.
        if len(product_ids) == 0:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "products"),
                        "msg": "At least one product is required.",
                        "input": product_ids,
                    }
                ]
            )

        products: list[Product] = []
        errors: list[ValidationError] = []

        for index, product_id in enumerate(product_ids):
            product = await product_service.get(session, auth_subject, product_id)

            if product is None:
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "products", index),
                        "msg": "Product does not exist.",
                        "input": product_id,
                    }
                )
                continue

            if product.is_archived:
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "products", index),
                        "msg": "Product is archived.",
                        "input": product_id,
                    }
                )
                continue

            products.append(product)

        organization_ids = {product.organization_id for product in products}
        if len(organization_ids) > 1:
            errors.append(
                {
                    "type": "value_error",
                    "loc": ("body", "products"),
                    "msg": "Products must all belong to the same organization.",
                    # Echo the request input (ids), like every other error
                    # here, rather than the loaded ORM objects.
                    "input": product_ids,
                }
            )

        if len(errors) > 0:
            raise PolarRequestValidationError(errors)

        return products

    async def _get_validated_price(
        self,
        session: AsyncSession,
        price_id: uuid.UUID,
        auth_subject: AuthSubject[User | Organization],
    ) -> tuple[Product, ProductPrice]:
        """Load a readable, non-archived price and its non-archived product.

        Args:
            session: Database session.
            price_id: Requested price id.
            auth_subject: Subject passed to ``get_readable_by_id``.

        Returns:
            ``(product, price)``. The product is re-fetched with
            ``ProductRepository`` eager options.

        Raises:
            PolarRequestValidationError: If the price does not exist, or the
                price or its product is archived.
        """
        product_price_repository = ProductPriceRepository.from_session(session)
        price = await product_price_repository.get_readable_by_id(
            price_id, auth_subject
        )

        if price is None:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "product_price_id"),
                        "msg": "Price does not exist.",
                        "input": price_id,
                    }
                ]
            )

        if price.is_archived:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "product_price_id"),
                        "msg": "Price is archived.",
                        "input": price_id,
                    }
                ]
            )

        product = price.product
        if product.is_archived:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "product_price_id"),
                        "msg": "Product is archived.",
                        "input": price_id,
                    }
                ]
            )

        # Reload the product with the repository's eager options; the row is
        # known to exist (reached via ``price.product``), hence the cast.
        product_repository = ProductRepository.from_session(session)
        product = cast(
            Product,
            await product_repository.get_by_id(
                product.id, options=product_repository.get_eager_options()
            ),
        )
        return (product, price)

    async def _get_validated_discount(
        self,
        session: AsyncSession,
        discount_id: uuid.UUID,
        organization: Organization,
        products: Sequence[Product],
    ) -> Discount:
        """Resolve a discount for ``organization`` that applies to ``products``.

        Delegates to ``discount_service.get_by_id_and_organization`` with
        ``redeemable=False`` (presumably skipping the redeemability check —
        confirm against that service).

        Raises:
            PolarRequestValidationError: If no matching discount is returned.
        """
        discount = await discount_service.get_by_id_and_organization(
            session,
            discount_id,
            organization,
            products=products,
            redeemable=False,
        )

        if discount is None:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "discount_id"),
                        "msg": (
                            "Discount does not exist or "
                            "is not applicable to this product."
                        ),
                        "input": discount_id,
                    }
                ]
            )

        return discount

360 

361 

# Module-level service instance, bound to the ``CheckoutLink`` model.
checkout_link: CheckoutLinkService = CheckoutLinkService(CheckoutLink)