Coverage for polar/transaction/service/refund.py: 29%
92 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1import itertools 1a
2import math 1a
4from sqlalchemy import select 1a
5from sqlalchemy.orm import joinedload 1a
7from polar.models import Refund, Transaction 1a
8from polar.models.refund import RefundStatus 1a
9from polar.models.transaction import TransactionType 1a
10from polar.postgres import AsyncSession 1a
11from polar.transaction.repository import RefundTransactionRepository 1a
13from .balance import balance_transaction as balance_transaction_service 1a
14from .base import BaseTransactionService, BaseTransactionServiceError 1a
15from .processor_fee import ( 1a
16 processor_fee_transaction as processor_fee_transaction_service,
17)
class RefundTransactionError(BaseTransactionServiceError):
    """Common base class for the refund-transaction errors defined in this module."""
class NotSucceededRefundError(RefundTransactionError):
    """Raised by ``RefundTransactionService.create`` when the refund has not succeeded."""

    def __init__(self, refund: Refund) -> None:
        # Keep the offending refund around so handlers can inspect it.
        self.refund = refund
        message = f"Refund {refund.id} is not succeeded"
        super().__init__(message)
class RefundTransactionAlreadyExistsError(RefundTransactionError):
    """
    Raised by ``RefundTransactionService.create`` when a refund transaction is
    already recorded for the refund's processor ID.
    """

    def __init__(self, refund: Refund) -> None:
        # Keep the offending refund around so handlers can inspect it.
        self.refund = refund
        message = f"Refund transaction already exists for {refund.id}"
        super().__init__(message)
class NotCanceledRefundError(RefundTransactionError):
    """
    Raised by ``RefundTransactionService.revert`` when the refund status is
    neither ``canceled`` nor ``failed``.
    """

    def __init__(self, refund: Refund) -> None:
        # Keep the offending refund around so handlers can inspect it.
        self.refund = refund
        message = f"Refund {refund.id} is not canceled or failed"
        super().__init__(message)
class RefundTransactionDoesNotExistError(RefundTransactionError):
    """
    Raised by ``RefundTransactionService.revert`` when there is no refund
    transaction recorded for the refund's processor ID.
    """

    def __init__(self, refund: Refund) -> None:
        # Keep the offending refund around so handlers can inspect it.
        self.refund = refund
        message = f"Refund transaction does not exist for {refund.id}"
        super().__init__(message)
class RefundTransactionService(BaseTransactionService):
    """
    Record refunds, and reverted refunds, in the transaction ledger.

    Besides the ``refund`` / ``refund_reversal`` transactions themselves, the
    service keeps balances consistent. Recording a refund reverses a
    proportional share of the balance transactions already created for the
    payment, and reverting a refund credits balances back to the affected
    accounts.
    """

    async def create(
        self,
        session: AsyncSession,
        *,
        charge_id: str,
        payment_transaction: Transaction,
        refund: Refund,
    ) -> Transaction:
        """
        Record a succeeded refund as a ``refund`` transaction.

        The new transaction copies the customer, tax location and related
        entities (pledge, issue reward, order) from the payment. It stores the
        refund amounts negated and links the fee transactions returned by
        ``processor_fee_transaction_service.create_refund_fees`` as its
        incurred transactions. If balance transactions already exist for the
        payment, a proportional share of each is reversed.

        Args:
            session: Database session the new transactions are added to.
            charge_id: Processor charge identifier stored on the transaction.
            payment_transaction: Payment transaction being refunded.
            refund: Refund to record; it must have succeeded.

        Returns:
            The new refund transaction, added to ``session``.

        Raises:
            NotSucceededRefundError: ``refund.succeeded`` is false.
            RefundTransactionAlreadyExistsError: A refund transaction already
                exists for ``refund.processor_id``.
        """
        if not refund.succeeded:
            raise NotSucceededRefundError(refund)

        repository = RefundTransactionRepository.from_session(session)
        if await repository.get_by_refund_id(refund.processor_id) is not None:
            raise RefundTransactionAlreadyExistsError(refund)

        # Amounts are negated; ``revert`` records the same fields with the
        # opposite sign.
        refund_transaction = Transaction(
            type=TransactionType.refund,
            processor=refund.processor,
            currency=refund.currency,
            amount=-refund.amount,
            account_currency=refund.currency,
            account_amount=-refund.amount,
            tax_amount=-refund.tax_amount,
            tax_country=payment_transaction.tax_country,
            tax_state=payment_transaction.tax_state,
            presentment_currency=refund.currency,
            presentment_amount=-refund.amount,
            presentment_tax_amount=-refund.tax_amount,
            customer_id=payment_transaction.customer_id,
            charge_id=charge_id,
            refund_id=refund.processor_id,
            polar_refund_id=refund.id,
            payment_customer_id=payment_transaction.payment_customer_id,
            payment_organization_id=payment_transaction.payment_organization_id,
            payment_user_id=payment_transaction.payment_user_id,
            pledge_id=payment_transaction.pledge_id,
            issue_reward_id=payment_transaction.issue_reward_id,
            order_id=payment_transaction.order_id,
        )

        # Compute and link fees
        transaction_fees = await processor_fee_transaction_service.create_refund_fees(
            session, refund=refund, refund_transaction=refund_transaction
        )
        refund_transaction.incurred_transactions = transaction_fees
        session.add(refund_transaction)

        # Create reversal balances if it was already balanced
        await self._create_reversal_balances(
            session,
            payment_transaction=payment_transaction,
            refund_amount=refund.amount,
        )
        return refund_transaction

    async def revert(
        self,
        session: AsyncSession,
        *,
        charge_id: str,
        payment_transaction: Transaction,
        refund: Refund,
    ) -> Transaction:
        """
        Record the reversal of a refund that ended up canceled or failed.

        Creates a ``refund_reversal`` transaction carrying the refund amounts
        with a positive sign, the opposite of what ``create`` stored. Then
        credits the reversed balances back to their accounts. The existing
        refund transaction is only checked for existence; it is not modified.

        Args:
            session: Database session the new transactions are added to.
            charge_id: Processor charge identifier stored on the transaction.
            payment_transaction: Payment transaction the refund belonged to.
            refund: Refund being reverted; its status must be ``canceled``
                or ``failed``.

        Returns:
            The new refund reversal transaction, added to ``session``.

        Raises:
            NotCanceledRefundError: The refund status is neither ``canceled``
                nor ``failed``.
            RefundTransactionDoesNotExistError: No refund transaction exists
                for ``refund.processor_id``.
        """
        if refund.status not in {RefundStatus.canceled, RefundStatus.failed}:
            raise NotCanceledRefundError(refund)

        repository = RefundTransactionRepository.from_session(session)
        refund_transaction = await repository.get_by_refund_id(refund.processor_id)
        if refund_transaction is None:
            raise RefundTransactionDoesNotExistError(refund)

        refund_reversal_transaction = Transaction(
            type=TransactionType.refund_reversal,
            processor=refund.processor,
            currency=refund.currency,
            amount=refund.amount,
            account_currency=refund.currency,
            account_amount=refund.amount,
            tax_amount=refund.tax_amount,
            tax_country=payment_transaction.tax_country,
            tax_state=payment_transaction.tax_state,
            presentment_currency=refund.currency,
            presentment_amount=refund.amount,
            presentment_tax_amount=refund.tax_amount,
            customer_id=payment_transaction.customer_id,
            charge_id=charge_id,
            refund_id=refund.processor_id,
            polar_refund_id=refund.id,
            payment_customer_id=payment_transaction.payment_customer_id,
            payment_organization_id=payment_transaction.payment_organization_id,
            payment_user_id=payment_transaction.payment_user_id,
            pledge_id=payment_transaction.pledge_id,
            issue_reward_id=payment_transaction.issue_reward_id,
            order_id=payment_transaction.order_id,
        )
        session.add(refund_reversal_transaction)

        # Create reversal balances if it was already balanced
        await self._create_revert_reversal_balances(
            session,
            payment_transaction=payment_transaction,
            refund_amount=refund.amount,
        )
        return refund_reversal_transaction

    async def create_reversal_balances_for_payment(
        self, session: AsyncSession, *, payment_transaction: Transaction
    ) -> list[tuple[Transaction, Transaction]]:
        """
        Create reversal balances for a refunded payment transaction.

        Mostly useful when releasing held balances. If a payment transaction
        was refunded before the Account was created, the reversal balances
        must be created now so the refund is correctly accounted for.

        Returns:
            Every reversal balance couple created, across all refund
            transactions sharing the payment's ``charge_id``. Returns an empty
            list when the payment has no ``charge_id``.
        """
        # ``Transaction.charge_id == None`` compiles to ``IS NULL`` and would
        # match refunds belonging to unrelated payments, so bail out instead.
        if payment_transaction.charge_id is None:
            return []

        statement = select(Transaction).where(
            Transaction.type == TransactionType.refund,
            Transaction.charge_id == payment_transaction.charge_id,
        )
        result = await session.execute(statement)
        refund_transactions = result.scalars().all()

        reversal_balances: list[tuple[Transaction, Transaction]] = []
        for refund_txn in refund_transactions:
            # Refund transactions store negated amounts; the proration works on
            # absolute values, so the sign does not matter here.
            reversal_balances += await self._create_reversal_balances(
                session,
                payment_transaction=payment_transaction,
                refund_amount=refund_txn.amount,
            )

        return reversal_balances

    @staticmethod
    def _prorate_amount(amount: int, refund_amount: int, total_amount: int) -> int:
        """
        Return ``|amount * refund_amount / total_amount|`` rounded down.

        Uses exact integer floor division. The previous float true-division
        could round a quotient lying just below an integer up to that integer
        once the operands exceed float precision.

        Raises:
            ZeroDivisionError: ``total_amount`` is zero.
        """
        return abs(amount * refund_amount) // abs(total_amount)

    async def _create_reversal_balances(
        self,
        session: AsyncSession,
        *,
        payment_transaction: Transaction,
        refund_amount: int,
    ) -> list[tuple[Transaction, Transaction]]:
        """
        Reverse a proportional share of every balance made for the payment.

        Each couple returned by ``_get_balance_transactions_for_payment`` is
        reversed through ``balance_transaction_service.create_reversal_balance``
        for ``|outgoing.amount| * |refund_amount| // |payment amount|``. This
        means a partial refund only takes back its share of each balance.

        Returns:
            The created reversal couples. The list is empty when no balance
            couple is found for the payment.
        """
        # Read before the first await, like the other payment attributes.
        total_amount = payment_transaction.amount

        reversal_balances: list[tuple[Transaction, Transaction]] = []
        balance_transactions_couples = await self._get_balance_transactions_for_payment(
            session, payment_transaction=payment_transaction
        )
        for balance_transactions_couple in balance_transactions_couples:
            outgoing, _ = balance_transactions_couple
            # Refund each balance proportionally
            balance_refund_amount = self._prorate_amount(
                outgoing.amount, refund_amount, total_amount
            )
            reversal_balances.append(
                await balance_transaction_service.create_reversal_balance(
                    session,
                    balance_transactions=balance_transactions_couple,
                    amount=balance_refund_amount,
                )
            )
        return reversal_balances

    async def _create_revert_reversal_balances(
        self,
        session: AsyncSession,
        *,
        payment_transaction: Transaction,
        refund_amount: int,
    ) -> list[tuple[Transaction, Transaction]]:
        """
        Credit balances back after a refund has been reverted.

        For every reversal couple returned by
        ``_get_reverse_balance_transactions_for_payment``, a new balance is
        created with no source account and ``outgoing.account`` as destination.
        Its amount is ``|outgoing.amount| * |refund_amount| // |payment amount|``.
        The new transactions are tied to the payment's original balance
        transactions rather than to the refund reversals.

        Returns:
            The created ``(outgoing, incoming)`` couples.
        """
        # Read before the first await, like the other payment attributes.
        total_amount = payment_transaction.amount

        revert_reversal_balances: list[tuple[Transaction, Transaction]] = []
        reverse_balance_transactions_couples = (
            await self._get_reverse_balance_transactions_for_payment(
                session, payment_transaction=payment_transaction
            )
        )
        for reverse_balance_transactions_couple in reverse_balance_transactions_couples:
            outgoing, incoming = reverse_balance_transactions_couple
            assert outgoing.account is not None
            # NOTE(review): ``outgoing`` is a reversal balance, whose amount was
            # presumably already prorated by ``_create_reversal_balances``.
            # Prorating it again only restores the full reversal for a full
            # refund. The query also returns reversals from every refund of the
            # payment, not only this one. Confirm the intended amounts for
            # partial refunds.
            balance_reversal_amount = self._prorate_amount(
                outgoing.amount, refund_amount, total_amount
            )
            (
                outgoing_reversal,
                incoming_reversal,
            ) = await balance_transaction_service.create_balance(
                session,
                source_account=None,
                destination_account=outgoing.account,
                amount=balance_reversal_amount,
                pledge=outgoing.pledge,
                order=outgoing.order,
                issue_reward=outgoing.issue_reward,
            )

            # Tie the reversal to the original transactions, not the refunds
            # This way, it'll get picked up when transferring the payment
            # Basically, it'll do (+100 - 100 + 100)
            outgoing_reversal.balance_reversal_transaction = (
                incoming.balance_reversal_transaction
            )
            incoming_reversal.balance_reversal_transaction = (
                outgoing.balance_reversal_transaction
            )
            session.add(outgoing_reversal)
            session.add(incoming_reversal)
            revert_reversal_balances.append((outgoing_reversal, incoming_reversal))

        return revert_reversal_balances

    async def _get_reverse_balance_transactions_for_payment(
        self, session: AsyncSession, *, payment_transaction: Transaction
    ) -> list[tuple[Transaction, Transaction]]:
        """
        Get the balance transactions that reverse the payment's balances.

        Selects ``balance`` transactions whose ``balance_reversal_transaction``
        is one of the payment's own balance transactions, excluding platform
        fee reversals. The results are grouped into pairs by
        ``balance_correlation_key``. Within a pair, rows are ordered by
        ``account_id`` with NULLs last, so a side with an account comes before
        a side without one. Reversals from every refund of the payment are
        included.

        Each correlation key is expected to group exactly two rows; otherwise
        unpacking the group raises ``ValueError``.
        """
        balance_transactions_statement = select(Transaction.id).where(
            Transaction.type == TransactionType.balance,
            Transaction.payment_transaction_id == payment_transaction.id,
        )
        statement = (
            select(Transaction)
            .where(
                Transaction.type == TransactionType.balance,
                Transaction.balance_reversal_transaction_id.in_(
                    balance_transactions_statement
                ),
                # WARNING: not a bulletproof solution
                # In most cases, reversal balances should either be platform fees or refunds,
                # but other situations may appear in the future.
                Transaction.platform_fee_type.is_(None),
            )
            .order_by(
                Transaction.balance_correlation_key,
                Transaction.account_id.nulls_last(),
            )
            .options(
                joinedload(Transaction.account),
                joinedload(Transaction.pledge),
                joinedload(Transaction.order),
                joinedload(Transaction.issue_reward),
                joinedload(Transaction.balance_reversal_transaction),
            )
        )

        result = await session.execute(statement)
        transactions = list(result.scalars().all())
        # groupby only merges adjacent rows; the ORDER BY above keeps each
        # correlation key contiguous.
        return [
            (t1, t2)
            for _, (t1, t2) in itertools.groupby(
                transactions, key=lambda t: t.balance_correlation_key
            )
        ]
# Module-level service instance, constructed with the ``Transaction`` model.
refund_transaction: RefundTransactionService = RefundTransactionService(Transaction)