Coverage for polar/license_key/service.py: 16%
165 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1from collections.abc import Sequence 1a
2from typing import cast 1a
3from uuid import UUID 1a
5import structlog 1a
6from sqlalchemy import Select, func, select 1a
7from sqlalchemy.orm import joinedload 1a
9from polar.auth.models import AuthSubject 1a
10from polar.benefit.strategies.license_keys.properties import ( 1a
11 BenefitLicenseKeysProperties,
12)
13from polar.exceptions import BadRequest, NotPermitted, ResourceNotFound 1a
14from polar.kit.pagination import PaginationParams, paginate 1a
15from polar.kit.utils import utc_now 1a
16from polar.models import ( 1a
17 Benefit,
18 Customer,
19 LicenseKey,
20 LicenseKeyActivation,
21 Organization,
22 User,
23)
24from polar.postgres import AsyncReadSession, AsyncSession 1a
26from .repository import LicenseKeyRepository 1a
27from .schemas import ( 1a
28 LicenseKeyActivate,
29 LicenseKeyCreate,
30 LicenseKeyDeactivate,
31 LicenseKeyUpdate,
32 LicenseKeyValidate,
33)
# Module-level structlog logger; methods below log through it directly or via
# ``log.bind(...)`` to attach license-key identifiers.
log = structlog.get_logger()
class LicenseKeyService:
    """Service layer for license keys.

    Covers listing/lookup, validation, activation/deactivation, and the
    grant/revoke flow used when a license-keys benefit is granted to or
    revoked from a customer.
    """

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        pagination: PaginationParams,
        organization_id: Sequence[UUID] | None = None,
        benefit_id: Sequence[UUID] | None = None,
    ) -> tuple[Sequence[LicenseKey], int]:
        """List license keys visible to ``auth_subject``, oldest first.

        Args:
            session: Read session used for the query.
            auth_subject: User or organization making the request. The base
                statement comes from ``get_readable_statement`` (presumably
                scoped to keys this subject may read).
            pagination: Page number and page size.
            organization_id: If given, keep only keys whose organization is
                in this sequence.
            benefit_id: If given, keep only keys whose benefit is in this
                sequence.

        Returns:
            The ``(items, count)`` pair produced by ``repository.paginate``.
        """
        repository = LicenseKeyRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .order_by(LicenseKey.created_at.asc())
            .options(*repository.get_eager_options())
        )

        if organization_id is not None:
            statement = statement.where(LicenseKey.organization_id.in_(organization_id))

        if benefit_id is not None:
            statement = statement.where(LicenseKey.benefit_id.in_(benefit_id))

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: UUID,
    ) -> LicenseKey | None:
        """Return the license key with ``id`` if readable by ``auth_subject``.

        Returns ``None`` when no matching readable key exists.
        """
        repository = LicenseKeyRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(LicenseKey.id == id)
            .options(*repository.get_eager_options())
        )
        return await repository.get_one_or_none(statement)

    async def get_or_raise_by_key(
        self,
        session: AsyncSession,
        *,
        organization_id: UUID,
        key: str,
    ) -> LicenseKey:
        """Look up a license key by its key string within an organization.

        Raises:
            ResourceNotFound: If no such key exists for the organization.
        """
        repository = LicenseKeyRepository.from_session(session)
        lk = await repository.get_by_organization_and_key(
            organization_id, key, options=repository.get_eager_options()
        )
        if lk is None:
            raise ResourceNotFound()
        return lk

    async def get_by_grant_or_raise(
        self,
        session: AsyncSession,
        *,
        id: UUID,
        organization_id: UUID,
        customer_id: UUID,
        benefit_id: UUID,
    ) -> LicenseKey:
        """Fetch a key matching id, organization, customer and benefit at once.

        Raises:
            ResourceNotFound: If no key matches all four identifiers.
        """
        repository = LicenseKeyRepository.from_session(session)
        lk = await repository.get_by_id_organization_customer_and_benefit(
            id,
            organization_id,
            customer_id,
            benefit_id,
            options=repository.get_eager_options(),
        )
        if lk is None:
            raise ResourceNotFound()
        return lk

    async def get_activation_or_raise(
        self, session: AsyncReadSession, *, license_key: LicenseKey, activation_id: UUID
    ) -> LicenseKeyActivation:
        """Return a non-deleted activation belonging to ``license_key``.

        The returned activation has its ``license_key`` attribute set to the
        passed-in instance.

        Raises:
            ResourceNotFound: If the activation does not exist, belongs to a
                different key, or has been soft-deleted.
        """
        query = select(LicenseKeyActivation).where(
            LicenseKeyActivation.id == activation_id,
            LicenseKeyActivation.license_key_id == license_key.id,
            LicenseKeyActivation.deleted_at.is_(None),
        )
        result = await session.execute(query)
        record = result.scalar_one_or_none()
        if record is None:
            raise ResourceNotFound()

        # Reuse the already-loaded key rather than relying on a lazy load.
        record.license_key = license_key
        return record

    async def update(
        self,
        session: AsyncSession,
        *,
        license_key: LicenseKey,
        updates: LicenseKeyUpdate,
    ) -> LicenseKey:
        """Apply the fields explicitly set on ``updates`` and flush.

        Only fields present in ``model_dump(exclude_unset=True)`` are written,
        so omitted fields keep their current values.
        """
        update_dict = updates.model_dump(exclude_unset=True)
        for key, value in update_dict.items():
            setattr(license_key, key, value)

        session.add(license_key)
        await session.flush()
        return license_key

    @staticmethod
    def _is_expired(license_key: LicenseKey) -> bool:
        """Return True if the key has an expiry time that is now or in the past."""
        # A key without ``expires_at`` never expires.
        expires_at = license_key.expires_at
        return expires_at is not None and utc_now() >= expires_at

    async def validate(
        self,
        session: AsyncSession,
        *,
        license_key: LicenseKey,
        validate: LicenseKeyValidate,
    ) -> LicenseKey:
        """Validate ``license_key`` against the request and record the check.

        Checks run in order: active status, expiry, optional activation and
        its conditions, optional benefit match, optional customer match, and
        remaining usage when ``increment_usage`` is requested on a key with a
        usage limit. On success ``license_key.mark_validated`` is called with
        the requested increment and the key is added to the session (not
        flushed here).

        Raises:
            ResourceNotFound: For inactive, expired, condition-mismatched,
                benefit-mismatched or owner-mismatched keys, or an unknown
                activation.
            BadRequest: If the requested usage increment exceeds what remains.
        """
        bound_logger = log.bind(
            license_key_id=license_key.id,
            organization_id=license_key.organization_id,
            customer_id=license_key.customer_id,
            benefit_id=license_key.benefit_id,
        )
        if not license_key.is_active():
            bound_logger.info("license_key.validate.invalid_status")
            raise ResourceNotFound("License key is no longer active.")

        if self._is_expired(license_key):
            bound_logger.info("license_key.validate.invalid_ttl")
            raise ResourceNotFound("License key has expired.")

        if validate.activation_id:
            activation = await self.get_activation_or_raise(
                session,
                license_key=license_key,
                activation_id=validate.activation_id,
            )
            if activation.conditions and validate.conditions != activation.conditions:
                # Skip logging UGC conditions
                bound_logger.info("license_key.validate.invalid_conditions")
                raise ResourceNotFound("License key does not match required conditions")
            license_key.activation = activation

        if validate.benefit_id and validate.benefit_id != license_key.benefit_id:
            bound_logger.info("license_key.validate.invalid_benefit")
            raise ResourceNotFound("License key does not match given benefit.")

        if validate.customer_id and validate.customer_id != license_key.customer_id:
            bound_logger.warning(
                "license_key.validate.invalid_owner",
                validate_customer_id=validate.customer_id,
            )
            raise ResourceNotFound("License key does not match given user.")

        if validate.increment_usage and license_key.limit_usage:
            remaining = license_key.limit_usage - license_key.usage
            if validate.increment_usage > remaining:
                bound_logger.info(
                    "license_key.validate.insufficient_usage",
                    usage_remaining=remaining,
                    usage_requested=validate.increment_usage,
                )
                # ``usage`` can exceed ``limit_usage`` when a grant update
                # lowers the limit (customer_update_grant rewrites
                # ``limit_usage`` but not ``usage``); never tell the caller a
                # negative number of usages remain.
                raise BadRequest(
                    f"License key only has {max(remaining, 0)} more usages."
                )

        license_key.mark_validated(increment_usage=validate.increment_usage)
        session.add(license_key)
        bound_logger.info("license_key.validate")
        return license_key

    async def get_activation_count(
        self,
        session: AsyncSession,
        license_key: LicenseKey,
    ) -> int:
        """Count the non-deleted activations of ``license_key``."""
        query = select(func.count(LicenseKeyActivation.id)).where(
            LicenseKeyActivation.license_key_id == license_key.id,
            LicenseKeyActivation.deleted_at.is_(None),
        )
        res = await session.execute(query)
        # COUNT yields a single row; ``or 0`` normalises a NULL scalar to 0.
        return res.scalar() or 0

    async def activate(
        self,
        session: AsyncSession,
        license_key: LicenseKey,
        activate: LicenseKeyActivate,
    ) -> LicenseKeyActivation:
        """Create a new activation for ``license_key`` and flush it.

        Raises:
            NotPermitted: If the key is inactive, expired, has no activation
                limit configured, or its activation limit is already reached.
        """
        if not license_key.is_active():
            raise NotPermitted(
                "License key is no longer active. "
                "This license key can not be activated."
            )

        if self._is_expired(license_key):
            raise NotPermitted("License key has expired.")

        # Keys without an activation limit are validate-only.
        if not license_key.limit_activations:
            raise NotPermitted(
                "This license key does not support activations. "
                "Use the /validate endpoint instead to check license validity."
            )

        current_activation_count = await self.get_activation_count(
            session,
            license_key=license_key,
        )
        if current_activation_count >= license_key.limit_activations:
            log.info(
                "license_key.activate.limit_reached",
                license_key_id=license_key.id,
                organization_id=license_key.organization_id,
                customer_id=license_key.customer_id,
                benefit_id=license_key.benefit_id,
            )
            raise NotPermitted("License key activation limit already reached")

        instance = LicenseKeyActivation(
            license_key=license_key,
            label=activate.label,
            conditions=activate.conditions,
            meta=activate.meta,
        )
        session.add(instance)
        await session.flush()
        assert instance.id
        log.info(
            "license_key.activate",
            license_key_id=license_key.id,
            organization_id=license_key.organization_id,
            customer_id=license_key.customer_id,
            benefit_id=license_key.benefit_id,
            activation_id=instance.id,
        )
        return instance

    async def deactivate(
        self,
        session: AsyncSession,
        license_key: LicenseKey,
        deactivate: LicenseKeyDeactivate,
    ) -> bool:
        """Soft-delete the given activation of ``license_key``.

        Returns:
            Always ``True`` once the activation has been marked deleted.

        Raises:
            ResourceNotFound: If the activation is unknown or already deleted.
        """
        activation = await self.get_activation_or_raise(
            session,
            license_key=license_key,
            activation_id=deactivate.activation_id,
        )
        activation.mark_deleted()
        session.add(activation)
        await session.flush()
        assert activation.deleted_at is not None
        log.info(
            "license_key.deactivate",
            license_key_id=license_key.id,
            organization_id=license_key.organization_id,
            customer_id=license_key.customer_id,
            benefit_id=license_key.benefit_id,
            activation_id=activation.id,
        )
        return True

    async def customer_grant(
        self,
        session: AsyncSession,
        *,
        customer: Customer,
        benefit: Benefit,
        license_key_id: UUID | None = None,
    ) -> LicenseKey:
        """Grant a license key for ``benefit`` to ``customer``.

        The key settings (prefix, usage limit, activations, expiry) are read
        from the benefit's properties. If ``license_key_id`` is given, that
        existing key is updated; otherwise a new key is created.
        """
        props = cast(BenefitLicenseKeysProperties, benefit.properties)
        create_schema = LicenseKeyCreate.build(
            organization_id=benefit.organization_id,
            customer_id=customer.id,
            benefit_id=benefit.id,
            prefix=props.get("prefix", None),
            limit_usage=props.get("limit_usage", None),
            activations=props.get("activations", None),
            expires=props.get("expires", None),
        )
        log.info(
            "license_key.grant.request",
            organization_id=benefit.organization_id,
            customer_id=customer.id,
            benefit_id=benefit.id,
        )
        if license_key_id:
            return await self.customer_update_grant(
                session,
                create_schema=create_schema,
                license_key_id=license_key_id,
            )

        return await self.customer_create_grant(
            session,
            create_schema=create_schema,
        )

    async def customer_update_grant(
        self,
        session: AsyncSession,
        *,
        license_key_id: UUID,
        create_schema: LicenseKeyCreate,
    ) -> LicenseKey:
        """Sync an existing granted key with freshly built grant settings.

        Only ``status``, ``expires_at``, ``limit_activations`` and
        ``limit_usage`` are copied over; ``usage`` is left untouched.

        Raises:
            ResourceNotFound: If the key does not match the schema's
                organization, customer and benefit.
        """
        key = await self.get_by_grant_or_raise(
            session,
            id=license_key_id,
            organization_id=create_schema.organization_id,
            customer_id=create_schema.customer_id,
            benefit_id=create_schema.benefit_id,
        )

        update_attrs = (
            "status",
            "expires_at",
            "limit_activations",
            "limit_usage",
        )
        for attr in update_attrs:
            current = getattr(key, attr)
            updated = getattr(create_schema, attr)
            # Only assign on change to avoid marking unchanged columns dirty.
            if current != updated:
                setattr(key, attr, updated)

        session.add(key)
        await session.flush()
        assert key.id is not None
        log.info(
            "license_key.grant.update",
            license_key_id=key.id,
            organization_id=key.organization_id,
            customer_id=key.customer_id,
            benefit_id=key.benefit_id,
        )
        return key

    async def customer_create_grant(
        self,
        session: AsyncSession,
        *,
        create_schema: LicenseKeyCreate,
    ) -> LicenseKey:
        """Create and flush a new ``LicenseKey`` from ``create_schema``."""
        key = LicenseKey(**create_schema.model_dump())
        session.add(key)
        await session.flush()
        assert key.id is not None
        log.info(
            "license_key.grant.create",
            license_key_id=key.id,
            organization_id=key.organization_id,
            customer_id=key.customer_id,
            benefit_id=key.benefit_id,
        )
        return key

    async def customer_revoke(
        self,
        session: AsyncSession,
        customer: Customer,
        benefit: Benefit,
        license_key_id: UUID,
    ) -> LicenseKey:
        """Revoke the customer's key for ``benefit`` (via ``mark_revoked``).

        Raises:
            ResourceNotFound: If the key does not belong to this customer,
                benefit and the benefit's organization.
        """
        key = await self.get_by_grant_or_raise(
            session,
            id=license_key_id,
            organization_id=benefit.organization_id,
            customer_id=customer.id,
            benefit_id=benefit.id,
        )
        key.mark_revoked()
        session.add(key)
        await session.flush()
        log.info(
            "license_key.revoke",
            license_key_id=key.id,
            organization_id=key.organization_id,
            customer_id=key.customer_id,
            benefit_id=key.benefit_id,
        )
        return key

    async def get_customer_list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        *,
        pagination: PaginationParams,
        benefit_id: UUID | None = None,
    ) -> tuple[Sequence[LicenseKey], int]:
        """Paginate the authenticated customer's non-deleted keys, oldest first.

        Args:
            benefit_id: If given, keep only keys for this benefit.
        """
        query = (
            self._get_select_customer_base(auth_subject)
            .order_by(LicenseKey.created_at.asc())
            .options(
                joinedload(LicenseKey.benefit),
            )
        )

        if benefit_id:
            query = query.where(LicenseKey.benefit_id == benefit_id)

        return await paginate(session, query, pagination=pagination)

    async def get_customer_license_key(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        license_key_id: UUID,
    ) -> LicenseKey | None:
        """Return one of the authenticated customer's non-deleted keys, or None.

        Activations and benefit are eager-loaded.
        """
        query = (
            self._get_select_customer_base(auth_subject)
            .where(LicenseKey.id == license_key_id)
            .options(joinedload(LicenseKey.activations), joinedload(LicenseKey.benefit))
        )
        result = await session.execute(query)
        # SQLAlchemy requires ``unique()`` when joined-eager-loading a
        # collection (``activations`` presumably is one) to de-duplicate rows.
        return result.unique().scalar_one_or_none()

    def _get_select_customer_base(
        self, auth_subject: AuthSubject[Customer]
    ) -> Select[tuple[LicenseKey]]:
        """Build the base select: non-deleted keys owned by the subject customer."""
        return (
            select(LicenseKey)
            .options(joinedload(LicenseKey.customer))
            .where(
                LicenseKey.deleted_at.is_(None),
                LicenseKey.customer_id == auth_subject.subject.id,
            )
        )
# Shared module-level instance of the stateless license key service.
license_key = LicenseKeyService()