Coverage for polar/payment/repository.py: 29%
44 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1from collections.abc import Sequence 1a
2from uuid import UUID 1a
4from sqlalchemy import Select, func, select 1a
6from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a
7from polar.enums import PaymentProcessor 1a
8from polar.kit.repository import ( 1a
9 RepositoryBase,
10 RepositorySoftDeletionIDMixin,
11 RepositorySoftDeletionMixin,
12 RepositorySortingMixin,
13 SortingClause,
14)
15from polar.models import Order, Payment, UserOrganization 1a
16from polar.models.payment import PaymentStatus 1a
18from .sorting import PaymentSortProperty 1a
21class PaymentRepository( 1a
22 RepositorySortingMixin[Payment, PaymentSortProperty],
23 RepositorySoftDeletionIDMixin[Payment, UUID],
24 RepositorySoftDeletionMixin[Payment],
25 RepositoryBase[Payment],
26):
27 model = Payment 1a
29 async def get_all_by_customer( 1a
30 self, customer_id: UUID, *, status: PaymentStatus | None = None
31 ) -> Sequence[Payment]:
32 statement = (
33 self.get_base_statement()
34 .join(Order, Payment.order_id == Order.id)
35 .where(Order.deleted_at.is_(None), Order.customer_id == customer_id)
36 )
37 if status is not None:
38 statement = statement.where(Payment.status == status)
39 return await self.get_all(statement)
41 async def get_by_processor_id( 1a
42 self, processor: PaymentProcessor, processor_id: str
43 ) -> Payment | None:
44 statement = self.get_base_statement().where(
45 Payment.processor == processor, Payment.processor_id == processor_id
46 )
47 return await self.get_one_or_none(statement)
49 def get_readable_statement( 1a
50 self, auth_subject: AuthSubject[User | Organization]
51 ) -> Select[tuple[Payment]]:
52 statement = self.get_base_statement()
54 if is_user(auth_subject):
55 user = auth_subject.subject
56 statement = statement.where(
57 Payment.organization_id.in_(
58 select(UserOrganization.organization_id).where(
59 UserOrganization.user_id == user.id,
60 UserOrganization.deleted_at.is_(None),
61 )
62 )
63 )
64 elif is_organization(auth_subject):
65 statement = statement.where(
66 Payment.organization_id == auth_subject.subject.id,
67 )
69 return statement
71 def get_sorting_clause(self, property: PaymentSortProperty) -> SortingClause: 1a
72 match property:
73 case PaymentSortProperty.created_at:
74 return Payment.created_at
75 case PaymentSortProperty.status:
76 return Payment.status
77 case PaymentSortProperty.amount:
78 return Payment.amount
79 case PaymentSortProperty.method:
80 return Payment.method
82 async def count_failed_payments_for_order(self, order_id: UUID) -> int: 1a
83 """Count the number of failed payments for a specific order."""
84 statement = select(func.count(Payment.id)).where(
85 Payment.order_id == order_id,
86 Payment.status == PaymentStatus.failed,
87 )
88 result = await self.session.execute(statement)
89 return result.scalar() or 0
91 async def get_latest_for_order(self, order_id: UUID) -> Payment | None: 1a
92 """Get the latest payment for a specific order."""
93 statement = (
94 select(Payment)
95 .where(Payment.order_id == order_id)
96 .order_by(Payment.created_at.desc())
97 .limit(1)
98 )
99 return await self.get_one_or_none(statement)