Coverage for polar/user/repository.py: 29%

45 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1from collections.abc import Sequence 1a

2from uuid import UUID 1a

3 

4from sqlalchemy import func, select 1a

5 

6from polar.kit.repository import ( 1a

7 RepositoryBase, 

8 RepositorySoftDeletionIDMixin, 

9 RepositorySoftDeletionMixin, 

10 RepositorySortingMixin, 

11) 

12from polar.kit.repository.base import SortingClause 1a

13from polar.models import OAuthAccount, User, UserOrganization 1a

14from polar.models.user import OAuthPlatform 1a

15 

16from .sorting import UserSortProperty 1a

17 

18 

19class UserRepository( 1a

20 RepositorySortingMixin[User, UserSortProperty], 

21 RepositorySoftDeletionIDMixin[User, UUID], 

22 RepositorySoftDeletionMixin[User], 

23 RepositoryBase[User], 

24): 

25 model = User 1a

26 

27 async def get_by_email( 1a

28 self, 

29 email: str, 

30 *, 

31 include_deleted: bool = False, 

32 included_blocked: bool = False, 

33 ) -> User | None: 

34 statement = self.get_base_statement(include_deleted=include_deleted).where( 

35 func.lower(User.email) == email.lower() 

36 ) 

37 if not included_blocked: 

38 statement = statement.where(User.blocked_at.is_(None)) 

39 return await self.get_one_or_none(statement) 

40 

41 async def get_by_stripe_customer_id( 1a

42 self, 

43 stripe_customer_id: str, 

44 *, 

45 include_deleted: bool = False, 

46 included_blocked: bool = False, 

47 ) -> User | None: 

48 statement = self.get_base_statement(include_deleted=include_deleted).where( 

49 User.stripe_customer_id == stripe_customer_id 

50 ) 

51 if not included_blocked: 

52 statement = statement.where(User.blocked_at.is_(None)) 

53 return await self.get_one_or_none(statement) 

54 

55 async def get_by_oauth_account( 1a

56 self, 

57 platform: OAuthPlatform, 

58 account_id: str, 

59 *, 

60 include_deleted: bool = False, 

61 included_blocked: bool = False, 

62 ) -> User | None: 

63 statement = ( 

64 self.get_base_statement(include_deleted=include_deleted) 

65 .join(User.oauth_accounts) 

66 .where( 

67 OAuthAccount.platform == platform, 

68 OAuthAccount.account_id == account_id, 

69 ) 

70 ) 

71 if not included_blocked: 

72 statement = statement.where(User.blocked_at.is_(None)) 

73 return await self.get_one_or_none(statement) 

74 

75 async def get_by_identity_verification_id( 1a

76 self, 

77 identity_verification_id: str, 

78 *, 

79 include_deleted: bool = False, 

80 included_blocked: bool = False, 

81 ) -> User | None: 

82 statement = self.get_base_statement(include_deleted=include_deleted).where( 

83 User.identity_verification_id == identity_verification_id 

84 ) 

85 if not included_blocked: 

86 statement = statement.where(User.blocked_at.is_(None)) 

87 return await self.get_one_or_none(statement) 

88 

89 async def get_all_by_organization( 1a

90 self, 

91 organization_id: UUID, 

92 *, 

93 include_deleted: bool = False, 

94 included_blocked: bool = False, 

95 ) -> Sequence[User]: 

96 statement = ( 

97 self.get_base_statement(include_deleted=include_deleted) 

98 .join(UserOrganization, UserOrganization.user_id == User.id) 

99 .where( 

100 UserOrganization.deleted_at.is_(None), 

101 UserOrganization.organization_id == organization_id, 

102 ) 

103 ) 

104 if not included_blocked: 

105 statement = statement.where(User.blocked_at.is_(None)) 

106 return await self.get_all(statement) 

107 

108 async def is_organization_member( 1a

109 self, 

110 user_id: UUID, 

111 organization_id: UUID, 

112 ) -> bool: 

113 statement = select(UserOrganization).where( 

114 UserOrganization.user_id == user_id, 

115 UserOrganization.organization_id == organization_id, 

116 UserOrganization.deleted_at.is_(None), 

117 ) 

118 result = await self.session.execute(statement) 

119 return result.scalar_one_or_none() is not None 

120 

121 def get_sorting_clause(self, property: UserSortProperty) -> SortingClause: 1a

122 match property: 

123 case UserSortProperty.created_at: 

124 return self.model.created_at 

125 case UserSortProperty.email: 

126 return self.model.email