Coverage for polar/license_key/service.py: 16%

165 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1from collections.abc import Sequence 1a

2from typing import cast 1a

3from uuid import UUID 1a

4 

5import structlog 1a

6from sqlalchemy import Select, func, select 1a

7from sqlalchemy.orm import joinedload 1a

8 

9from polar.auth.models import AuthSubject 1a

10from polar.benefit.strategies.license_keys.properties import ( 1a

11 BenefitLicenseKeysProperties, 

12) 

13from polar.exceptions import BadRequest, NotPermitted, ResourceNotFound 1a

14from polar.kit.pagination import PaginationParams, paginate 1a

15from polar.kit.utils import utc_now 1a

16from polar.models import ( 1a

17 Benefit, 

18 Customer, 

19 LicenseKey, 

20 LicenseKeyActivation, 

21 Organization, 

22 User, 

23) 

24from polar.postgres import AsyncReadSession, AsyncSession 1a

25 

26from .repository import LicenseKeyRepository 1a

27from .schemas import ( 1a

28 LicenseKeyActivate, 

29 LicenseKeyCreate, 

30 LicenseKeyDeactivate, 

31 LicenseKeyUpdate, 

32 LicenseKeyValidate, 

33) 

34 

# Module-level structlog logger used by the license key service for
# structured events.
log = structlog.get_logger()

36 

37 

class LicenseKeyService:
    """Business operations on license keys and their activations.

    Covers listing and fetching keys for users, organizations and customers,
    validating keys, creating/removing activations, and creating, updating
    and revoking keys granted to customers through a license-key benefit.
    """

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        pagination: PaginationParams,
        organization_id: Sequence[UUID] | None = None,
        benefit_id: Sequence[UUID] | None = None,
    ) -> tuple[Sequence[LicenseKey], int]:
        """List license keys readable by ``auth_subject``, oldest first.

        Args:
            session: Read session used for the query.
            auth_subject: Subject whose readable statement scopes the query.
            pagination: Page number and page size.
            organization_id: If given, keep only keys in these organizations.
            benefit_id: If given, keep only keys issued for these benefits.

        Returns:
            The page of license keys and the integer returned alongside it by
            ``repository.paginate`` (presumably the total count).
        """
        repository = LicenseKeyRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .order_by(LicenseKey.created_at.asc())
            .options(*repository.get_eager_options())
        )

        if organization_id is not None:
            statement = statement.where(LicenseKey.organization_id.in_(organization_id))

        if benefit_id is not None:
            statement = statement.where(LicenseKey.benefit_id.in_(benefit_id))

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: UUID,
    ) -> LicenseKey | None:
        """Return the key with ``id`` from ``auth_subject``'s readable set.

        Returns ``None`` when no matching key is found.
        """
        repository = LicenseKeyRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(LicenseKey.id == id)
            .options(*repository.get_eager_options())
        )
        return await repository.get_one_or_none(statement)

    async def get_or_raise_by_key(
        self,
        session: AsyncSession,
        *,
        organization_id: UUID,
        key: str,
    ) -> LicenseKey:
        """Look up a license key by its key string within an organization.

        Raises:
            ResourceNotFound: no key matches ``organization_id`` and ``key``.
        """
        repository = LicenseKeyRepository.from_session(session)
        lk = await repository.get_by_organization_and_key(
            organization_id, key, options=repository.get_eager_options()
        )
        if lk is None:
            raise ResourceNotFound()
        return lk

    async def get_by_grant_or_raise(
        self,
        session: AsyncSession,
        *,
        id: UUID,
        organization_id: UUID,
        customer_id: UUID,
        benefit_id: UUID,
    ) -> LicenseKey:
        """Fetch the key matching id, organization, customer and benefit.

        Raises:
            ResourceNotFound: no key matches all four identifiers.
        """
        repository = LicenseKeyRepository.from_session(session)
        lk = await repository.get_by_id_organization_customer_and_benefit(
            id,
            organization_id,
            customer_id,
            benefit_id,
            options=repository.get_eager_options(),
        )
        if lk is None:
            raise ResourceNotFound()
        return lk

    async def get_activation_or_raise(
        self, session: AsyncReadSession, *, license_key: LicenseKey, activation_id: UUID
    ) -> LicenseKeyActivation:
        """Fetch a non-deleted activation of ``license_key`` by id.

        The already-loaded ``license_key`` is assigned onto the returned
        activation, presumably so callers don't trigger another load.

        Raises:
            ResourceNotFound: the activation does not exist, belongs to a
                different key, or has been soft-deleted.
        """
        query = select(LicenseKeyActivation).where(
            LicenseKeyActivation.id == activation_id,
            LicenseKeyActivation.license_key_id == license_key.id,
            LicenseKeyActivation.deleted_at.is_(None),
        )
        result = await session.execute(query)
        record = result.scalar_one_or_none()
        if record is None:
            raise ResourceNotFound()

        record.license_key = license_key
        return record

    async def update(
        self,
        session: AsyncSession,
        *,
        license_key: LicenseKey,
        updates: LicenseKeyUpdate,
    ) -> LicenseKey:
        """Apply the explicitly-set fields of ``updates`` and flush.

        Only fields the caller actually set are written
        (``model_dump(exclude_unset=True)``), so omitted fields keep their
        current values.
        """
        for key, value in updates.model_dump(exclude_unset=True).items():
            setattr(license_key, key, value)

        session.add(license_key)
        await session.flush()
        return license_key

    async def validate(
        self,
        session: AsyncSession,
        *,
        license_key: LicenseKey,
        validate: LicenseKeyValidate,
    ) -> LicenseKey:
        """Validate ``license_key`` against the checks requested in ``validate``.

        Checks, in order:

        1. The key is active.
        2. The key has not expired.
        3. If an activation is given, it exists. If the activation has
           conditions, they must equal the supplied ones.
        4. If a benefit is given, it matches the key's benefit.
        5. If a customer is given, it matches the key's customer.
        6. If ``increment_usage`` is given and the key has a usage limit,
           enough usage must remain.

        On success the key is marked validated and added to the session. No
        flush is performed here.

        Raises:
            ResourceNotFound: the key is inactive or expired, the activation
                is unknown, or the conditions, benefit or customer do not
                match.
            BadRequest: not enough usage remains for ``increment_usage``.
        """
        bound_logger = log.bind(**self._log_context(license_key))
        if not license_key.is_active():
            bound_logger.info("license_key.validate.invalid_status")
            raise ResourceNotFound("License key is no longer active.")

        if self._is_expired(license_key):
            bound_logger.info("license_key.validate.invalid_ttl")
            raise ResourceNotFound("License key has expired.")

        if validate.activation_id:
            activation = await self.get_activation_or_raise(
                session,
                license_key=license_key,
                activation_id=validate.activation_id,
            )
            if activation.conditions and validate.conditions != activation.conditions:
                # Conditions are user-generated content: deliberately not logged.
                bound_logger.info("license_key.validate.invalid_conditions")
                raise ResourceNotFound("License key does not match required conditions")
            license_key.activation = activation

        if validate.benefit_id and validate.benefit_id != license_key.benefit_id:
            bound_logger.info("license_key.validate.invalid_benefit")
            raise ResourceNotFound("License key does not match given benefit.")

        if validate.customer_id and validate.customer_id != license_key.customer_id:
            bound_logger.warning(
                "license_key.validate.invalid_owner",
                validate_customer_id=validate.customer_id,
            )
            raise ResourceNotFound("License key does not match given user.")

        if validate.increment_usage and license_key.limit_usage:
            remaining = license_key.limit_usage - license_key.usage
            if validate.increment_usage > remaining:
                bound_logger.info(
                    "license_key.validate.insufficient_usage",
                    usage_remaining=remaining,
                    usage_requested=validate.increment_usage,
                )
                raise BadRequest(f"License key only has {remaining} more usages.")

        license_key.mark_validated(increment_usage=validate.increment_usage)
        session.add(license_key)
        bound_logger.info("license_key.validate")
        return license_key

    async def get_activation_count(
        self,
        session: AsyncSession,
        license_key: LicenseKey,
    ) -> int:
        """Return the number of non-deleted activations of ``license_key``."""
        query = select(func.count(LicenseKeyActivation.id)).where(
            LicenseKeyActivation.license_key_id == license_key.id,
            LicenseKeyActivation.deleted_at.is_(None),
        )
        result = await session.execute(query)
        # COUNT yields a single row; ``or 0`` also covers a None scalar.
        return result.scalar() or 0

    async def activate(
        self,
        session: AsyncSession,
        license_key: LicenseKey,
        activate: LicenseKeyActivate,
    ) -> LicenseKeyActivation:
        """Create and flush a new activation for ``license_key``.

        Raises:
            NotPermitted: the key is inactive, expired, has no
                ``limit_activations`` (activations unsupported), or already
                has ``limit_activations`` non-deleted activations.
        """
        if not license_key.is_active():
            raise NotPermitted(
                "License key is no longer active. "
                "This license key can not be activated."
            )

        if self._is_expired(license_key):
            raise NotPermitted("License key has expired.")

        if not license_key.limit_activations:
            raise NotPermitted(
                "This license key does not support activations. "
                "Use the /validate endpoint instead to check license validity."
            )

        # NOTE(review): counting and inserting are separate statements, so two
        # concurrent requests could both pass this check. Confirm whether a
        # lock or DB constraint enforces the limit.
        current_activation_count = await self.get_activation_count(
            session,
            license_key=license_key,
        )
        if current_activation_count >= license_key.limit_activations:
            log.info(
                "license_key.activate.limit_reached",
                **self._log_context(license_key),
            )
            raise NotPermitted("License key activation limit already reached")

        instance = LicenseKeyActivation(
            license_key=license_key,
            label=activate.label,
            conditions=activate.conditions,
            meta=activate.meta,
        )
        session.add(instance)
        await session.flush()
        assert instance.id
        log.info(
            "license_key.activate",
            **self._log_context(license_key),
            activation_id=instance.id,
        )
        return instance

    async def deactivate(
        self,
        session: AsyncSession,
        license_key: LicenseKey,
        deactivate: LicenseKeyDeactivate,
    ) -> bool:
        """Soft-delete one activation of ``license_key``. Returns ``True``.

        Raises:
            ResourceNotFound: the activation is unknown or already deleted.
        """
        activation = await self.get_activation_or_raise(
            session,
            license_key=license_key,
            activation_id=deactivate.activation_id,
        )
        activation.mark_deleted()
        session.add(activation)
        await session.flush()
        assert activation.deleted_at is not None
        log.info(
            "license_key.deactivate",
            **self._log_context(license_key),
            activation_id=activation.id,
        )
        return True

    async def customer_grant(
        self,
        session: AsyncSession,
        *,
        customer: Customer,
        benefit: Benefit,
        license_key_id: UUID | None = None,
    ) -> LicenseKey:
        """Grant ``customer`` a license key for ``benefit``.

        The key settings come from the benefit's license-key properties. When
        ``license_key_id`` is given, that existing key is updated. Otherwise
        a new key is created.
        """
        props = cast(BenefitLicenseKeysProperties, benefit.properties)
        create_schema = LicenseKeyCreate.build(
            organization_id=benefit.organization_id,
            customer_id=customer.id,
            benefit_id=benefit.id,
            prefix=props.get("prefix", None),
            limit_usage=props.get("limit_usage", None),
            activations=props.get("activations", None),
            expires=props.get("expires", None),
        )
        log.info(
            "license_key.grant.request",
            organization_id=benefit.organization_id,
            customer_id=customer.id,
            benefit_id=benefit.id,
        )
        if license_key_id is not None:
            return await self.customer_update_grant(
                session,
                create_schema=create_schema,
                license_key_id=license_key_id,
            )

        return await self.customer_create_grant(
            session,
            create_schema=create_schema,
        )

    async def customer_update_grant(
        self,
        session: AsyncSession,
        *,
        license_key_id: UUID,
        create_schema: LicenseKeyCreate,
    ) -> LicenseKey:
        """Sync an existing granted key with freshly built grant settings.

        Overwrites ``status``, ``expires_at``, ``limit_activations`` and
        ``limit_usage`` when they differ from ``create_schema``, then flushes.

        Raises:
            ResourceNotFound: no key matches the id, organization, customer
                and benefit from ``create_schema``.
        """
        key = await self.get_by_grant_or_raise(
            session,
            id=license_key_id,
            organization_id=create_schema.organization_id,
            customer_id=create_schema.customer_id,
            benefit_id=create_schema.benefit_id,
        )

        for attr in ("status", "expires_at", "limit_activations", "limit_usage"):
            updated = getattr(create_schema, attr)
            # Only assign real changes so unchanged columns are left untouched.
            if getattr(key, attr) != updated:
                setattr(key, attr, updated)

        session.add(key)
        await session.flush()
        assert key.id is not None
        log.info("license_key.grant.update", **self._log_context(key))
        return key

    async def customer_create_grant(
        self,
        session: AsyncSession,
        *,
        create_schema: LicenseKeyCreate,
    ) -> LicenseKey:
        """Create and flush a new license key from ``create_schema``."""
        key = LicenseKey(**create_schema.model_dump())
        session.add(key)
        await session.flush()
        assert key.id is not None
        log.info("license_key.grant.create", **self._log_context(key))
        return key

    async def customer_revoke(
        self,
        session: AsyncSession,
        customer: Customer,
        benefit: Benefit,
        license_key_id: UUID,
    ) -> LicenseKey:
        """Revoke the key granted to ``customer`` for ``benefit`` and flush.

        Raises:
            ResourceNotFound: no matching key for this customer and benefit.
        """
        key = await self.get_by_grant_or_raise(
            session,
            id=license_key_id,
            organization_id=benefit.organization_id,
            customer_id=customer.id,
            benefit_id=benefit.id,
        )
        key.mark_revoked()
        session.add(key)
        await session.flush()
        log.info("license_key.revoke", **self._log_context(key))
        return key

    async def get_customer_list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        *,
        pagination: PaginationParams,
        benefit_id: UUID | None = None,
    ) -> tuple[Sequence[LicenseKey], int]:
        """Paginate the authenticated customer's non-deleted keys, oldest first.

        Args:
            benefit_id: If given, keep only keys issued for this benefit.
        """
        query = (
            self._get_select_customer_base(auth_subject)
            .order_by(LicenseKey.created_at.asc())
            .options(
                joinedload(LicenseKey.benefit),
            )
        )

        if benefit_id is not None:
            query = query.where(LicenseKey.benefit_id == benefit_id)

        return await paginate(session, query, pagination=pagination)

    async def get_customer_license_key(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        license_key_id: UUID,
    ) -> LicenseKey | None:
        """Return one of the authenticated customer's keys, or ``None``.

        Activations and benefit are eager-loaded.
        NOTE(review): ``joinedload(LicenseKey.activations)`` applies no
        ``deleted_at`` filter here. Confirm whether the relationship itself
        excludes soft-deleted activations.
        """
        query = (
            self._get_select_customer_base(auth_subject)
            .where(LicenseKey.id == license_key_id)
            .options(joinedload(LicenseKey.activations), joinedload(LicenseKey.benefit))
        )
        result = await session.execute(query)
        # unique() is required when joined-eager-loading a collection.
        return result.unique().scalar_one_or_none()

    def _get_select_customer_base(
        self, auth_subject: AuthSubject[Customer]
    ) -> Select[tuple[LicenseKey]]:
        """Select the non-deleted keys owned by the subject customer."""
        return (
            select(LicenseKey)
            .options(joinedload(LicenseKey.customer))
            .where(
                LicenseKey.deleted_at.is_(None),
                LicenseKey.customer_id == auth_subject.subject.id,
            )
        )

    @staticmethod
    def _log_context(license_key: LicenseKey) -> dict[str, object]:
        """Return the identifiers attached to license-key log events."""
        return {
            "license_key_id": license_key.id,
            "organization_id": license_key.organization_id,
            "customer_id": license_key.customer_id,
            "benefit_id": license_key.benefit_id,
        }

    @staticmethod
    def _is_expired(license_key: LicenseKey) -> bool:
        """Return whether ``license_key`` has an expiry that is now or past."""
        return (
            license_key.expires_at is not None
            and utc_now() >= license_key.expires_at
        )

467 

468 

# Module-level singleton instance of the (attribute-free) license key service.
license_key = LicenseKeyService()