Coverage for polar/wallet/repository.py: 40%

37 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from uuid import UUID 1a

2 

3from sqlalchemy import Select, func, select 1a

4from sqlalchemy.orm import contains_eager, joinedload 1a

5 

6from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a

7from polar.kit.repository import ( 1a

8 Options, 

9 RepositoryBase, 

10 RepositoryIDMixin, 

11 RepositorySoftDeletionIDMixin, 

12 RepositorySoftDeletionMixin, 

13 RepositorySortingMixin, 

14 SortingClause, 

15) 

16from polar.models import Customer, UserOrganization, Wallet, WalletTransaction 1a

17from polar.models.wallet import WalletType 1a

18 

19from .sorting import WalletSortProperty 1a

20 

21 

22class WalletRepository( 1a

23 RepositorySortingMixin[Wallet, WalletSortProperty], 

24 RepositorySoftDeletionIDMixin[Wallet, UUID], 

25 RepositorySoftDeletionMixin[Wallet], 

26 RepositoryBase[Wallet], 

27): 

28 model = Wallet 1a

29 

30 async def get_by_type_currency_customer( 1a

31 self, type: WalletType, currency: str, customer_id: UUID 

32 ) -> Wallet | None: 

33 statement = self.get_base_statement().where( 

34 Wallet.type == type, 

35 Wallet.currency == currency, 

36 Wallet.customer_id == customer_id, 

37 ) 

38 return await self.get_one_or_none(statement) 

39 

40 def get_readable_statement( 1a

41 self, auth_subject: AuthSubject[User | Organization] 

42 ) -> Select[tuple[Wallet]]: 

43 statement = ( 

44 self.get_base_statement() 

45 .join(Customer, Wallet.customer_id == Customer.id) 

46 .options( 

47 contains_eager(Wallet.customer).joinedload(Customer.organization), 

48 ) 

49 ) 

50 

51 if is_user(auth_subject): 

52 user = auth_subject.subject 

53 statement = statement.where( 

54 Customer.organization_id.in_( 

55 select(UserOrganization.organization_id).where( 

56 UserOrganization.user_id == user.id, 

57 UserOrganization.deleted_at.is_(None), 

58 ) 

59 ) 

60 ) 

61 elif is_organization(auth_subject): 

62 statement = statement.where( 

63 Customer.organization_id == auth_subject.subject.id, 

64 ) 

65 

66 return statement 

67 

68 def get_eager_options(self) -> Options: 1a

69 return (joinedload(Wallet.customer).joinedload(Customer.organization),) 

70 

71 def get_sorting_clause(self, property: WalletSortProperty) -> SortingClause: 1a

72 match property: 

73 case WalletSortProperty.created_at: 

74 return Wallet.created_at 

75 case WalletSortProperty.balance: 

76 return Wallet.balance 

77 

78 

class WalletTransactionRepository(
    RepositoryIDMixin[WalletTransaction, UUID],
    RepositoryBase[WalletTransaction],
):
    """Repository for ``WalletTransaction`` rows."""

    model = WalletTransaction

    async def get_balance(self, wallet_id: UUID) -> int:
        """Return the sum of transaction amounts recorded for ``wallet_id``.

        ``COALESCE`` turns a NULL sum (e.g. a wallet with no transactions)
        into ``0``.
        """
        total_amount = func.coalesce(func.sum(WalletTransaction.amount), 0)
        balance_query = select(total_amount).where(
            WalletTransaction.wallet_id == wallet_id,
        )
        outcome = await self.session.execute(balance_query)
        # scalar_one() requires exactly one result row and raises otherwise.
        return outcome.scalar_one()

    def get_eager_options(self) -> Options:
        """Return loader options for the wallet, its customer and organization."""
        customer_with_organization = joinedload(Wallet.customer).joinedload(
            Customer.organization
        )
        wallet_loader = joinedload(WalletTransaction.wallet).options(
            customer_with_organization
        )
        return (wallet_loader,)