Coverage for polar/user/repository.py: 29%
45 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
from collections.abc import Sequence
from uuid import UUID

from sqlalchemy import Select, func, select

from polar.kit.repository import (
    RepositoryBase,
    RepositorySoftDeletionIDMixin,
    RepositorySoftDeletionMixin,
    RepositorySortingMixin,
)
from polar.kit.repository.base import SortingClause
from polar.models import OAuthAccount, User, UserOrganization
from polar.models.user import OAuthPlatform

from .sorting import UserSortProperty
19class UserRepository( 1a
20 RepositorySortingMixin[User, UserSortProperty],
21 RepositorySoftDeletionIDMixin[User, UUID],
22 RepositorySoftDeletionMixin[User],
23 RepositoryBase[User],
24):
25 model = User 1a
27 async def get_by_email( 1a
28 self,
29 email: str,
30 *,
31 include_deleted: bool = False,
32 included_blocked: bool = False,
33 ) -> User | None:
34 statement = self.get_base_statement(include_deleted=include_deleted).where(
35 func.lower(User.email) == email.lower()
36 )
37 if not included_blocked:
38 statement = statement.where(User.blocked_at.is_(None))
39 return await self.get_one_or_none(statement)
41 async def get_by_stripe_customer_id( 1a
42 self,
43 stripe_customer_id: str,
44 *,
45 include_deleted: bool = False,
46 included_blocked: bool = False,
47 ) -> User | None:
48 statement = self.get_base_statement(include_deleted=include_deleted).where(
49 User.stripe_customer_id == stripe_customer_id
50 )
51 if not included_blocked:
52 statement = statement.where(User.blocked_at.is_(None))
53 return await self.get_one_or_none(statement)
55 async def get_by_oauth_account( 1a
56 self,
57 platform: OAuthPlatform,
58 account_id: str,
59 *,
60 include_deleted: bool = False,
61 included_blocked: bool = False,
62 ) -> User | None:
63 statement = (
64 self.get_base_statement(include_deleted=include_deleted)
65 .join(User.oauth_accounts)
66 .where(
67 OAuthAccount.platform == platform,
68 OAuthAccount.account_id == account_id,
69 )
70 )
71 if not included_blocked:
72 statement = statement.where(User.blocked_at.is_(None))
73 return await self.get_one_or_none(statement)
75 async def get_by_identity_verification_id( 1a
76 self,
77 identity_verification_id: str,
78 *,
79 include_deleted: bool = False,
80 included_blocked: bool = False,
81 ) -> User | None:
82 statement = self.get_base_statement(include_deleted=include_deleted).where(
83 User.identity_verification_id == identity_verification_id
84 )
85 if not included_blocked:
86 statement = statement.where(User.blocked_at.is_(None))
87 return await self.get_one_or_none(statement)
89 async def get_all_by_organization( 1a
90 self,
91 organization_id: UUID,
92 *,
93 include_deleted: bool = False,
94 included_blocked: bool = False,
95 ) -> Sequence[User]:
96 statement = (
97 self.get_base_statement(include_deleted=include_deleted)
98 .join(UserOrganization, UserOrganization.user_id == User.id)
99 .where(
100 UserOrganization.deleted_at.is_(None),
101 UserOrganization.organization_id == organization_id,
102 )
103 )
104 if not included_blocked:
105 statement = statement.where(User.blocked_at.is_(None))
106 return await self.get_all(statement)
108 async def is_organization_member( 1a
109 self,
110 user_id: UUID,
111 organization_id: UUID,
112 ) -> bool:
113 statement = select(UserOrganization).where(
114 UserOrganization.user_id == user_id,
115 UserOrganization.organization_id == organization_id,
116 UserOrganization.deleted_at.is_(None),
117 )
118 result = await self.session.execute(statement)
119 return result.scalar_one_or_none() is not None
121 def get_sorting_clause(self, property: UserSortProperty) -> SortingClause: 1a
122 match property:
123 case UserSortProperty.created_at:
124 return self.model.created_at
125 case UserSortProperty.email:
126 return self.model.email