Coverage for polar/organization_access_token/repository.py: 69%
27 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1from datetime import datetime 1a
2from uuid import UUID 1a
4from sqlalchemy import Select, or_, select, update 1a
5from sqlalchemy.orm import contains_eager 1a
7from polar.auth.models import AuthSubject, User 1a
8from polar.kit.repository import ( 1a
9 RepositoryBase,
10 RepositorySoftDeletionIDMixin,
11 RepositorySoftDeletionMixin,
12)
13from polar.kit.utils import utc_now 1a
14from polar.models import Organization, OrganizationAccessToken, UserOrganization 1a
15from polar.postgres import sql 1a
class OrganizationAccessTokenRepository(
    RepositorySoftDeletionIDMixin[OrganizationAccessToken, UUID],
    RepositorySoftDeletionMixin[OrganizationAccessToken],
    RepositoryBase[OrganizationAccessToken],
):
    """Data-access helpers for ``OrganizationAccessToken`` rows."""

    model = OrganizationAccessToken

    async def get_by_token_hash(
        self, token_hash: str, *, expired: bool = False
    ) -> OrganizationAccessToken | None:
        """Look up a token by the hash stored in ``OrganizationAccessToken.token``.

        Only tokens whose organization has ``can_authenticate`` set to true
        are matched, and the organization is loaded from the same join.

        Args:
            token_hash: Value compared for equality against
                ``OrganizationAccessToken.token``.
            expired: When ``False`` (the default), tokens whose
                ``expires_at`` is set and not in the future are excluded.
                When ``True``, no expiry filter is applied at all, so both
                expired and unexpired tokens can match.

        Returns:
            The matching token, or ``None`` when nothing matches.
        """
        # NOTE(review): get_base_statement() comes from the repository base /
        # soft-deletion mixins; presumably it already excludes soft-deleted
        # rows — confirm in polar.kit.repository.
        statement = (
            self.get_base_statement()
            .join(OrganizationAccessToken.organization)
            .where(
                OrganizationAccessToken.token == token_hash,
                Organization.can_authenticate.is_(True),
            )
            # Populate ``.organization`` from the join above rather than
            # loading it with a separate query.
            .options(contains_eager(OrganizationAccessToken.organization))
        )
        if not expired:
            # A NULL ``expires_at`` passes this filter, i.e. such tokens are
            # treated as never expiring.
            statement = statement.where(
                or_(
                    OrganizationAccessToken.expires_at.is_(None),
                    OrganizationAccessToken.expires_at > utc_now(),
                )
            )
        return await self.get_one_or_none(statement)

    async def record_usage(self, id: UUID, last_used_at: datetime) -> None:
        """Set ``last_used_at`` on the token whose ``id`` equals ``id``.

        This issues a bare UPDATE on the session. It filters on ``id`` only,
        so it does not check ``deleted_at``, and it does not commit. The
        caller's session or transaction handling decides when the change is
        persisted.
        """
        statement = (
            update(OrganizationAccessToken)
            .where(OrganizationAccessToken.id == id)
            .values(last_used_at=last_used_at)
        )
        await self.session.execute(statement)

    def get_readable_statement(
        self, auth_subject: AuthSubject[User]
    ) -> Select[tuple[OrganizationAccessToken]]:
        """Build a SELECT limited to tokens the authenticated user may read.

        A token is readable when its ``organization_id`` belongs to an
        organization the user is a member of through a ``UserOrganization``
        row whose ``deleted_at`` is NULL.

        Args:
            auth_subject: Authenticated subject whose ``subject`` is the user.

        Returns:
            An unexecuted SELECT statement over ``OrganizationAccessToken``.
        """
        user = auth_subject.subject
        member_organization_ids = select(UserOrganization.organization_id).where(
            UserOrganization.user_id == user.id,
            UserOrganization.deleted_at.is_(None),
        )
        return self.get_base_statement().where(
            OrganizationAccessToken.organization_id.in_(member_organization_ids)
        )

    async def count_by_organization_id(
        self,
        organization_id: UUID,
    ) -> int:
        """Count non-deleted access tokens belonging to an organization.

        Only soft-deleted rows (``deleted_at`` set) are excluded. Expiry is
        *not* considered, so tokens whose ``expires_at`` has passed are still
        counted.

        Args:
            organization_id: Organization whose tokens are counted.

        Returns:
            The number of matching tokens. Returns ``0`` when the scalar
            query yields ``None``.
        """
        count = await self.session.scalar(
            sql.select(sql.func.count(OrganizationAccessToken.id)).where(
                OrganizationAccessToken.organization_id == organization_id,
                OrganizationAccessToken.deleted_at.is_(None),
            )
        )
        return count or 0