Coverage for polar/payment/repository.py: 29%

44 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from collections.abc import Sequence 1a

2from uuid import UUID 1a

3 

4from sqlalchemy import Select, func, select 1a

5 

6from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a

7from polar.enums import PaymentProcessor 1a

8from polar.kit.repository import ( 1a

9 RepositoryBase, 

10 RepositorySoftDeletionIDMixin, 

11 RepositorySoftDeletionMixin, 

12 RepositorySortingMixin, 

13 SortingClause, 

14) 

15from polar.models import Order, Payment, UserOrganization 1a

16from polar.models.payment import PaymentStatus 1a

17 

18from .sorting import PaymentSortProperty 1a

19 

20 

21class PaymentRepository( 1a

22 RepositorySortingMixin[Payment, PaymentSortProperty], 

23 RepositorySoftDeletionIDMixin[Payment, UUID], 

24 RepositorySoftDeletionMixin[Payment], 

25 RepositoryBase[Payment], 

26): 

27 model = Payment 1a

28 

29 async def get_all_by_customer( 1a

30 self, customer_id: UUID, *, status: PaymentStatus | None = None 

31 ) -> Sequence[Payment]: 

32 statement = ( 

33 self.get_base_statement() 

34 .join(Order, Payment.order_id == Order.id) 

35 .where(Order.deleted_at.is_(None), Order.customer_id == customer_id) 

36 ) 

37 if status is not None: 

38 statement = statement.where(Payment.status == status) 

39 return await self.get_all(statement) 

40 

41 async def get_by_processor_id( 1a

42 self, processor: PaymentProcessor, processor_id: str 

43 ) -> Payment | None: 

44 statement = self.get_base_statement().where( 

45 Payment.processor == processor, Payment.processor_id == processor_id 

46 ) 

47 return await self.get_one_or_none(statement) 

48 

49 def get_readable_statement( 1a

50 self, auth_subject: AuthSubject[User | Organization] 

51 ) -> Select[tuple[Payment]]: 

52 statement = self.get_base_statement() 

53 

54 if is_user(auth_subject): 

55 user = auth_subject.subject 

56 statement = statement.where( 

57 Payment.organization_id.in_( 

58 select(UserOrganization.organization_id).where( 

59 UserOrganization.user_id == user.id, 

60 UserOrganization.deleted_at.is_(None), 

61 ) 

62 ) 

63 ) 

64 elif is_organization(auth_subject): 

65 statement = statement.where( 

66 Payment.organization_id == auth_subject.subject.id, 

67 ) 

68 

69 return statement 

70 

71 def get_sorting_clause(self, property: PaymentSortProperty) -> SortingClause: 1a

72 match property: 

73 case PaymentSortProperty.created_at: 

74 return Payment.created_at 

75 case PaymentSortProperty.status: 

76 return Payment.status 

77 case PaymentSortProperty.amount: 

78 return Payment.amount 

79 case PaymentSortProperty.method: 

80 return Payment.method 

81 

82 async def count_failed_payments_for_order(self, order_id: UUID) -> int: 1a

83 """Count the number of failed payments for a specific order.""" 

84 statement = select(func.count(Payment.id)).where( 

85 Payment.order_id == order_id, 

86 Payment.status == PaymentStatus.failed, 

87 ) 

88 result = await self.session.execute(statement) 

89 return result.scalar() or 0 

90 

91 async def get_latest_for_order(self, order_id: UUID) -> Payment | None: 1a

92 """Get the latest payment for a specific order.""" 

93 statement = ( 

94 select(Payment) 

95 .where(Payment.order_id == order_id) 

96 .order_by(Payment.created_at.desc()) 

97 .limit(1) 

98 ) 

99 return await self.get_one_or_none(statement)