Coverage for polar/member/repository.py: 38%

34 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from collections.abc import Sequence 1a

2from uuid import UUID 1a

3 

4from sqlalchemy import Select, select 1a

5 

6from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a

7from polar.kit.repository import ( 1a

8 RepositoryBase, 

9 RepositorySoftDeletionIDMixin, 

10 RepositorySoftDeletionMixin, 

11) 

12from polar.models.customer import Customer 1a

13from polar.models.member import Member 1a

14from polar.models.user_organization import UserOrganization 1a

15from polar.postgres import AsyncReadSession, AsyncSession 1a

16 

17 

class MemberRepository(
    RepositorySoftDeletionIDMixin[Member, UUID],
    RepositorySoftDeletionMixin[Member],
    RepositoryBase[Member],
):
    """Repository for :class:`Member` rows.

    The customer-scoped lookups below only return rows whose ``deleted_at``
    is NULL, i.e. members that have not been soft-deleted.
    """

    model = Member

    async def get_by_customer_and_email(
        self,
        session: AsyncSession,
        customer: Customer,
        email: str | None = None,
    ) -> Member | None:
        """
        Fetch the live member of ``customer`` that has the given email.

        Args:
            session: Session used to run the query.
            customer: Customer whose members are searched.
            email: Email to match. When falsy (None or empty), the
                customer's own ``email`` is used instead.

        Returns:
            The matching member, or None when no live member matches.
            As with SQLAlchemy's ``scalar_one_or_none``, more than one
            matching row raises ``MultipleResultsFound``.
        """
        # Fall back to the customer's own email when none (or "") is given.
        wanted_email = email if email else customer.email

        lookup = select(Member).where(
            Member.customer_id == customer.id,
            Member.email == wanted_email,
            Member.deleted_at.is_(None),
        )
        rows = await session.execute(lookup)
        return rows.scalar_one_or_none()

    async def list_by_customer(
        self,
        session: AsyncReadSession,
        customer_id: UUID,
    ) -> Sequence[Member]:
        """Return every non-soft-deleted member of a single customer."""
        live_members = select(Member).where(
            Member.customer_id == customer_id,
            Member.deleted_at.is_(None),
        )
        rows = await session.execute(live_members)
        return rows.scalars().all()

    async def list_by_customers(
        self,
        session: AsyncReadSession,
        customer_ids: Sequence[UUID],
    ) -> Sequence[Member]:
        """
        Return the non-soft-deleted members of several customers in one query.

        Batching the lookup avoids issuing one query per customer (N+1).
        An empty ``customer_ids`` short-circuits to ``[]`` without touching
        the database.
        """
        if customer_ids:
            live_members = select(Member).where(
                Member.customer_id.in_(customer_ids),
                Member.deleted_at.is_(None),
            )
            rows = await session.execute(live_members)
            return rows.scalars().all()

        return []

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[Member]]:
        """Build a Member query limited to what the auth subject may see.

        - A user subject sees members of organizations it is linked to
          through a ``UserOrganization`` row whose ``deleted_at`` is NULL.
        - An organization subject sees only its own members.
        - Any other subject gets ``get_base_statement()`` without extra
          filtering.
        """
        base = self.get_base_statement()

        if is_user(auth_subject):
            # Organizations the user currently belongs to.
            accessible_org_ids = select(UserOrganization.organization_id).where(
                UserOrganization.user_id == auth_subject.subject.id,
                UserOrganization.deleted_at.is_(None),
            )
            return base.where(Member.organization_id.in_(accessible_org_ids))

        if is_organization(auth_subject):
            return base.where(
                Member.organization_id == auth_subject.subject.id,
            )

        return base