Coverage for polar/transaction/service/dispute.py: 26%

96 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import math 1a

2 

3import stripe as stripe_lib 1a

4import structlog 1a

5from sqlalchemy import select 1a

6 

7from polar.integrations.stripe.utils import get_expandable_id 1a

8from polar.logging import Logger 1a

9from polar.models import Transaction 1a

10from polar.models.transaction import Processor, TransactionType 1a

11from polar.postgres import AsyncSession 1a

12 

13from .balance import balance_transaction as balance_transaction_service 1a

14from .base import BaseTransactionService, BaseTransactionServiceError 1a

15from .platform_fee import platform_fee_transaction as platform_fee_transaction_service 1a

16from .processor_fee import ( 1a

17 processor_fee_transaction as processor_fee_transaction_service, 

18) 

19 

# Module-level structured logger; used by the dispute service to warn when
# dispute fees cannot be balanced on the organization's account.
log: Logger = structlog.get_logger()

21 

22 

class DisputeTransactionError(BaseTransactionServiceError):
    """Base class for every error raised while recording dispute transactions."""

24 

25 

class DisputeClosed(DisputeTransactionError):
    """Raised when the Stripe dispute has the ``warning_closed`` status.

    Attributes:
        dispute_id: Stripe identifier of the offending dispute.
    """

    def __init__(self, dispute_id: str) -> None:
        self.dispute_id = dispute_id
        super().__init__(f"Dispute {dispute_id} is closed.")

31 

32 

class DisputeNotResolved(DisputeTransactionError):
    """Raised when the Stripe dispute is neither ``won`` nor ``lost`` yet.

    Attributes:
        dispute_id: Stripe identifier of the unresolved dispute.
    """

    def __init__(self, dispute_id: str) -> None:
        self.dispute_id = dispute_id
        super().__init__(f"Dispute {dispute_id} is not resolved.")

38 

39 

class DisputeUnknownPaymentTransaction(DisputeTransactionError):
    """Raised when no ``payment`` transaction matches the disputed charge.

    Attributes:
        dispute_id: Stripe identifier of the dispute.
        charge_id: Stripe identifier of the disputed charge.
    """

    def __init__(self, dispute_id: str, charge_id: str) -> None:
        self.dispute_id = dispute_id
        self.charge_id = charge_id
        super().__init__(
            f"Dispute {dispute_id} created for charge {charge_id}, "
            "but the payment transaction is unknown."
        )

49 

50 

class NotBalancedPaymentTransaction(DisputeTransactionError):
    """Raised when a payment has no balance transactions to charge fees against.

    Attributes:
        payment_transaction: The payment transaction lacking balances.
    """

    def __init__(self, payment_transaction: Transaction) -> None:
        self.payment_transaction = payment_transaction
        super().__init__(
            f"Payment transaction {payment_transaction.id} is not balanced, "
            "cannot create dispute fees balances."
        )

59 

60 

class DisputeTransactionService(BaseTransactionService):
    """Record resolved Stripe disputes as ledger transactions.

    A resolved dispute produces a ``dispute`` transaction withdrawing the
    disputed amount and, when the dispute was won, a ``dispute_reversal``
    transaction reinstating it. Processor dispute fees are attached to each,
    and the organization's balances are adjusted accordingly.
    """

    async def create_dispute(
        self, session: AsyncSession, *, dispute: stripe_lib.Dispute
    ) -> tuple[Transaction, Transaction | None]:
        """Create the ledger transactions for a resolved Stripe dispute.

        Args:
            session: Database session; new transactions are added and flushed.
            dispute: The Stripe dispute; its status must be ``won`` or ``lost``.

        Returns:
            ``(dispute_transaction, dispute_reversal_transaction)``; the
            reversal is ``None`` unless the dispute was won.

        Raises:
            DisputeClosed: The dispute status is ``warning_closed``.
            DisputeNotResolved: The dispute is neither ``won`` nor ``lost``.
            DisputeUnknownPaymentTransaction: No payment transaction exists
                for the disputed charge.

        If the payment was never balanced, the dispute fees are not balanced
        either; a warning is logged instead of raising.
        """
        if dispute.status == "warning_closed":
            raise DisputeClosed(dispute.id)

        if dispute.status not in {"won", "lost"}:
            raise DisputeNotResolved(dispute.id)

        charge_id: str = get_expandable_id(dispute.charge)
        payment_transaction = await self.get_by(
            session, type=TransactionType.payment, charge_id=charge_id
        )
        if payment_transaction is None:
            raise DisputeUnknownPaymentTransaction(dispute.id, charge_id)

        dispute_amount = dispute.amount
        total_amount = payment_transaction.amount + payment_transaction.tax_amount
        # Split the disputed amount into a tax part, pro rata to the payment's
        # tax / (amount + tax) ratio, and the remaining net part.
        tax_refund_amount = self._prorate_amount(
            payment_transaction.tax_amount, dispute_amount, total_amount
        )
        net_amount = dispute_amount - tax_refund_amount

        # Create the dispute, i.e. the transaction withdrawing the amount
        dispute_transaction = self._build_dispute_transaction(
            TransactionType.dispute,
            dispute=dispute,
            payment_transaction=payment_transaction,
            charge_id=charge_id,
            amount=-net_amount,
            tax_amount=-tax_refund_amount,
        )
        session.add(dispute_transaction)
        dispute_fees = await processor_fee_transaction_service.create_dispute_fees(
            session, dispute_transaction=dispute_transaction, category="dispute"
        )
        dispute_transaction.incurred_transactions = dispute_fees
        # Fresh list so extending it never mutates the `dispute_fees` list
        # returned by the fee service.
        all_fees = list(dispute_fees)

        dispute_reversal_transaction: Transaction | None = None
        if dispute.status == "won":
            # We won: create the dispute reversal, reinstating the amount
            dispute_reversal_transaction = self._build_dispute_transaction(
                TransactionType.dispute_reversal,
                dispute=dispute,
                payment_transaction=payment_transaction,
                charge_id=charge_id,
                amount=net_amount,
                tax_amount=tax_refund_amount,
            )
            session.add(dispute_reversal_transaction)
            dispute_reversal_fees = (
                await processor_fee_transaction_service.create_dispute_fees(
                    session,
                    dispute_transaction=dispute_reversal_transaction,
                    category="dispute_reversal",
                )
            )
            dispute_reversal_transaction.incurred_transactions = dispute_reversal_fees
            all_fees.extend(dispute_reversal_fees)
        elif dispute.status == "lost":
            # We lost: reverse the balances on the organization's account if
            # the payment was already balanced
            await self._create_reversal_balances(
                session,
                payment_transaction=payment_transaction,
                dispute_amount=dispute_amount,
            )

        # Balance the (crazy high) dispute fees on the organization's account
        try:
            await self._create_dispute_fees_balances(
                session, payment_transaction=payment_transaction, dispute_fees=all_fees
            )
        except NotBalancedPaymentTransaction:
            log.warning(
                "Dispute fees balances could not be created for payment transaction",
                payment_transaction_id=payment_transaction.id,
                dispute_id=dispute.id,
            )

        await session.flush()

        return dispute_transaction, dispute_reversal_transaction

    async def create_reversal_balances_for_payment(
        self, session: AsyncSession, *, payment_transaction: Transaction
    ) -> list[tuple[Transaction, Transaction]]:
        """
        Create reversal balances for a disputed payment transaction.

        Mostly useful when releasing held balances: if a payment transaction has
        been disputed before the Account creation, we need to create the reversal
        balances so the dispute is correctly accounted for.

        Disputes that already have a ``dispute_reversal`` transaction (i.e. were
        won) are skipped, since withdrawal and reinstatement cancel out.

        Returns:
            Every reversal balance pair created, across all matching disputes.
        """
        statement = select(Transaction).where(
            Transaction.type == TransactionType.dispute,
            Transaction.charge_id == payment_transaction.charge_id,
        )
        result = await session.execute(statement)
        disputes = result.scalars().all()

        reversal_balances: list[tuple[Transaction, Transaction]] = []
        for dispute in disputes:
            # Skip if there is a dispute reversal: the operations are neutral
            dispute_reversal = await self.get_by(
                session,
                type=TransactionType.dispute_reversal,
                dispute_id=dispute.dispute_id,
            )
            if dispute_reversal is not None:
                continue

            # The dispute transaction stores -(gross - tax) as `amount` and
            # -tax as `tax_amount`. `_create_reversal_balances` prorates
            # against the payment's tax-inclusive total, so it needs the gross
            # disputed amount, exactly like `create_dispute` passes it.
            gross_dispute_amount = -(dispute.amount + dispute.tax_amount)
            reversal_balances += await self._create_reversal_balances(
                session,
                payment_transaction=payment_transaction,
                dispute_amount=gross_dispute_amount,
            )

        return reversal_balances

    async def _create_reversal_balances(
        self,
        session: AsyncSession,
        *,
        payment_transaction: Transaction,
        dispute_amount: int,
    ) -> list[tuple[Transaction, Transaction]]:
        """Reverse each balance of the payment in proportion to the dispute.

        Args:
            session: Database session.
            payment_transaction: The disputed payment.
            dispute_amount: Gross disputed amount (tax included); only its
                magnitude is used.

        Returns:
            One reversal balance pair per balance pair of the payment; empty
            if the payment was never balanced.
        """
        total_amount = payment_transaction.amount + payment_transaction.tax_amount
        balance_transactions_couples = await self._get_balance_transactions_for_payment(
            session, payment_transaction=payment_transaction
        )

        reversal_balances: list[tuple[Transaction, Transaction]] = []
        for balance_transactions_couple in balance_transactions_couples:
            outgoing, _ = balance_transactions_couple
            # Refund each balance proportionally to the disputed share
            balance_refund_amount = self._prorate_amount(
                outgoing.amount, dispute_amount, total_amount
            )
            reversal_balances.append(
                await balance_transaction_service.create_reversal_balance(
                    session,
                    balance_transactions=balance_transactions_couple,
                    amount=balance_refund_amount,
                )
            )

        return reversal_balances

    async def _create_dispute_fees_balances(
        self,
        session: AsyncSession,
        *,
        payment_transaction: Transaction,
        dispute_fees: list[Transaction],
    ) -> list[tuple[Transaction, Transaction]]:
        """Balance the dispute fees against the payment's first balance pair.

        Raises:
            NotBalancedPaymentTransaction: The payment has no balance
                transactions.
        """
        balance_transactions_couples = await self._get_balance_transactions_for_payment(
            session, payment_transaction=payment_transaction
        )
        if not balance_transactions_couples:
            raise NotBalancedPaymentTransaction(payment_transaction)
        return await platform_fee_transaction_service.create_dispute_fees_balances(
            session,
            dispute_fees=dispute_fees,
            balance_transactions=balance_transactions_couples[0],
        )

    @staticmethod
    def _prorate_amount(amount: int, numerator: int, denominator: int) -> int:
        """Return ``floor(|amount * numerator / denominator|)`` exactly.

        Uses integer floor division on magnitudes instead of float division,
        so large amounts never lose precision. This matches truncating the
        signed ratio and taking its absolute value. Returns 0 when
        ``denominator`` is 0, since there is nothing to prorate against.
        """
        if denominator == 0:
            return 0
        return abs(amount * numerator) // abs(denominator)

    @staticmethod
    def _build_dispute_transaction(
        transaction_type: TransactionType,
        *,
        dispute: stripe_lib.Dispute,
        payment_transaction: Transaction,
        charge_id: str,
        amount: int,
        tax_amount: int,
    ) -> Transaction:
        """Build a dispute-related transaction linked like its payment.

        ``amount`` and ``tax_amount`` are signed. They are reused as the
        account and presentment amounts, all in the dispute's currency. Tax
        location and customer/organization/order links are copied from the
        payment transaction.
        """
        return Transaction(
            type=transaction_type,
            processor=Processor.stripe,
            currency=dispute.currency,
            amount=amount,
            account_currency=dispute.currency,
            account_amount=amount,
            tax_amount=tax_amount,
            tax_country=payment_transaction.tax_country,
            tax_state=payment_transaction.tax_state,
            presentment_currency=dispute.currency,
            presentment_amount=amount,
            presentment_tax_amount=tax_amount,
            customer_id=payment_transaction.customer_id,
            charge_id=charge_id,
            dispute_id=dispute.id,
            payment_customer_id=payment_transaction.payment_customer_id,
            payment_organization_id=payment_transaction.payment_organization_id,
            payment_user_id=payment_transaction.payment_user_id,
            pledge_id=payment_transaction.pledge_id,
            issue_reward_id=payment_transaction.issue_reward_id,
            order_id=payment_transaction.order_id,
            incurred_transactions=[],
        )

264 

265 

# Module-level service instance constructed with the Transaction model;
# presumably imported by callers as the dispute transaction service — confirm.
dispute_transaction = DisputeTransactionService(Transaction)