Coverage for polar/kit/tax.py: 50%
297 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import hashlib 1ab
2import json 1ab
3import uuid 1ab
4from collections.abc import Sequence 1ab
5from enum import StrEnum 1ab
6from typing import Annotated, Any, Literal, LiteralString, Protocol, TypedDict 1ab
8import stdnum.ca.bn 1ab
9import stdnum.cl.rut 1ab
10import stdnum.exceptions 1ab
11import stdnum.in_.gstin 1ab
12import stdnum.tr.vkn 1ab
13import stripe as stripe_lib 1ab
14import structlog 1ab
15from pydantic import Field 1ab
16from sqlalchemy.dialects.postgresql import JSONB 1ab
17from sqlalchemy.engine.interfaces import Dialect 1ab
18from sqlalchemy.types import TypeDecorator 1ab
19from stdnum import get_cc_module 1ab
21from polar.config import settings 1ab
22from polar.exceptions import PolarError 1ab
23from polar.integrations.stripe.service import stripe as stripe_service 1ab
24from polar.kit.address import Address 1ab
25from polar.logging import Logger 1ab
# Module-level structlog logger shared by the functions in this module.
log: Logger = structlog.get_logger()
class TaxIDFormat(StrEnum):
    """
    List of supported tax ID formats.

    Ref: https://docs.stripe.com/billing/customer/tax-ids#supported-tax-id
    """

    # Every value is identical to its member name and has the shape
    # ``<country>_<kind>``. StdNumValidator relies on that shape: it splits the
    # value on the first underscore to look up the python-stdnum module.

    ad_nrt = "ad_nrt"
    ae_trn = "ae_trn"
    ar_cuit = "ar_cuit"
    au_abn = "au_abn"
    au_arn = "au_arn"
    bg_uic = "bg_uic"
    bh_vat = "bh_vat"
    bo_tin = "bo_tin"
    br_cnpj = "br_cnpj"
    br_cpf = "br_cpf"
    ca_bn = "ca_bn"
    ca_gst_hst = "ca_gst_hst"
    ca_pst_bc = "ca_pst_bc"
    ca_pst_mb = "ca_pst_mb"
    ca_pst_sk = "ca_pst_sk"
    ca_qst = "ca_qst"
    ch_uid = "ch_uid"
    ch_vat = "ch_vat"
    cl_tin = "cl_tin"
    cn_tin = "cn_tin"
    co_nit = "co_nit"
    cr_tin = "cr_tin"
    de_stn = "de_stn"
    do_rcn = "do_rcn"
    ec_ruc = "ec_ruc"
    eg_tin = "eg_tin"
    es_cif = "es_cif"
    eu_oss_vat = "eu_oss_vat"
    eu_vat = "eu_vat"
    gb_vat = "gb_vat"
    ge_vat = "ge_vat"
    hk_br = "hk_br"
    hr_oib = "hr_oib"
    hu_tin = "hu_tin"
    id_npwp = "id_npwp"
    il_vat = "il_vat"
    in_gst = "in_gst"
    is_vat = "is_vat"
    jp_cn = "jp_cn"
    jp_rn = "jp_rn"
    jp_trn = "jp_trn"
    ke_pin = "ke_pin"
    kr_brn = "kr_brn"
    kz_bin = "kz_bin"
    li_uid = "li_uid"
    mx_rfc = "mx_rfc"
    my_frp = "my_frp"
    my_itn = "my_itn"
    my_sst = "my_sst"
    ng_tin = "ng_tin"
    no_vat = "no_vat"
    no_voec = "no_voec"
    nz_gst = "nz_gst"
    om_vat = "om_vat"
    pe_ruc = "pe_ruc"
    ph_tin = "ph_tin"
    ro_tin = "ro_tin"
    rs_pib = "rs_pib"
    ru_inn = "ru_inn"
    ru_kpp = "ru_kpp"
    sa_vat = "sa_vat"
    sg_gst = "sg_gst"
    sg_uen = "sg_uen"
    si_tin = "si_tin"
    sv_nit = "sv_nit"
    th_vat = "th_vat"
    tr_tin = "tr_tin"
    tw_vat = "tw_vat"
    ua_vat = "ua_vat"
    us_ein = "us_ein"
    uy_ruc = "uy_ruc"
    ve_rif = "ve_rif"
    vn_tin = "vn_tin"
    za_vat = "za_vat"
# Candidate tax ID formats for each two-letter, upper-case country code.
# validate_tax_id() tries the formats of a country in the order listed here and
# returns the first one that validates, so the order inside each tuple matters.
# A country missing from this map makes validate_tax_id() raise InvalidTaxID.
COUNTRY_TAX_ID_MAP: dict[str, Sequence[TaxIDFormat]] = {
    "AD": (TaxIDFormat.ad_nrt,),
    "AE": (TaxIDFormat.ae_trn,),
    "AR": (TaxIDFormat.ar_cuit,),
    "AT": (TaxIDFormat.eu_vat,),
    "AU": (TaxIDFormat.au_abn, TaxIDFormat.au_arn),
    "BE": (TaxIDFormat.eu_vat,),
    "BG": (TaxIDFormat.bg_uic, TaxIDFormat.eu_vat),
    "BH": (TaxIDFormat.bh_vat,),
    "BO": (TaxIDFormat.bo_tin,),
    "BR": (TaxIDFormat.br_cnpj, TaxIDFormat.br_cpf),
    "CA": (
        TaxIDFormat.ca_gst_hst,
        TaxIDFormat.ca_pst_bc,
        TaxIDFormat.ca_pst_mb,
        TaxIDFormat.ca_pst_sk,
        TaxIDFormat.ca_qst,
        TaxIDFormat.ca_bn,
    ),
    "CH": (TaxIDFormat.ch_uid, TaxIDFormat.ch_vat),
    "CL": (TaxIDFormat.cl_tin,),
    "CN": (TaxIDFormat.cn_tin,),
    "CO": (TaxIDFormat.co_nit,),
    "CR": (TaxIDFormat.cr_tin,),
    "CY": (TaxIDFormat.eu_vat,),
    "CZ": (TaxIDFormat.eu_vat,),
    "DE": (TaxIDFormat.de_stn, TaxIDFormat.eu_vat),
    "DK": (TaxIDFormat.eu_vat,),
    "DO": (TaxIDFormat.do_rcn,),
    "EC": (TaxIDFormat.ec_ruc,),
    "EE": (TaxIDFormat.eu_vat,),
    "EG": (TaxIDFormat.eg_tin,),
    "ES": (TaxIDFormat.es_cif, TaxIDFormat.eu_vat),
    "FI": (TaxIDFormat.eu_vat,),
    "FR": (TaxIDFormat.eu_vat,),
    "GB": (TaxIDFormat.gb_vat,),
    "GE": (TaxIDFormat.ge_vat,),
    "GR": (TaxIDFormat.eu_vat,),
    "HK": (TaxIDFormat.hk_br,),
    "HR": (TaxIDFormat.hr_oib, TaxIDFormat.eu_vat),
    "HU": (TaxIDFormat.hu_tin, TaxIDFormat.eu_vat),
    "ID": (TaxIDFormat.id_npwp,),
    "IE": (TaxIDFormat.eu_vat,),
    "IL": (TaxIDFormat.il_vat,),
    "IN": (TaxIDFormat.in_gst,),
    "IS": (TaxIDFormat.is_vat,),
    "IT": (TaxIDFormat.eu_vat,),
    "JP": (TaxIDFormat.jp_cn, TaxIDFormat.jp_rn, TaxIDFormat.jp_trn),
    "KE": (TaxIDFormat.ke_pin,),
    "KR": (TaxIDFormat.kr_brn,),
    "KZ": (TaxIDFormat.kz_bin,),
    "LI": (TaxIDFormat.li_uid,),
    "LT": (TaxIDFormat.eu_vat,),
    "LU": (TaxIDFormat.eu_vat,),
    "LV": (TaxIDFormat.eu_vat,),
    "MT": (TaxIDFormat.eu_vat,),
    "MX": (TaxIDFormat.mx_rfc,),
    "MY": (TaxIDFormat.my_frp, TaxIDFormat.my_itn, TaxIDFormat.my_sst),
    "NG": (TaxIDFormat.ng_tin,),
    "NL": (TaxIDFormat.eu_vat,),
    "NO": (TaxIDFormat.no_vat, TaxIDFormat.no_voec),
    "NZ": (TaxIDFormat.nz_gst,),
    "OM": (TaxIDFormat.om_vat,),
    "PE": (TaxIDFormat.pe_ruc,),
    "PH": (TaxIDFormat.ph_tin,),
    "PL": (TaxIDFormat.eu_vat,),
    "PT": (TaxIDFormat.eu_vat,),
    "RO": (TaxIDFormat.ro_tin, TaxIDFormat.eu_vat),
    "RS": (TaxIDFormat.rs_pib,),
    "RU": (TaxIDFormat.ru_inn, TaxIDFormat.ru_kpp),
    "SA": (TaxIDFormat.sa_vat,),
    "SE": (TaxIDFormat.eu_vat,),
    "SG": (TaxIDFormat.sg_gst, TaxIDFormat.sg_uen),
    "SI": (TaxIDFormat.si_tin, TaxIDFormat.eu_vat),
    "SK": (TaxIDFormat.eu_vat,),
    "SV": (TaxIDFormat.sv_nit,),
    "TH": (TaxIDFormat.th_vat,),
    "TR": (TaxIDFormat.tr_tin,),
    "TW": (TaxIDFormat.tw_vat,),
    "UA": (TaxIDFormat.ua_vat,),
    "US": (TaxIDFormat.us_ein,),
    "UY": (TaxIDFormat.uy_ruc,),
    "VE": (TaxIDFormat.ve_rif,),
    "VN": (TaxIDFormat.vn_tin,),
    "ZA": (TaxIDFormat.za_vat,),
}
# A tax ID expressed as a ``(number, format)`` pair. The pydantic ``Field``
# metadata only attaches example values.
TaxID = Annotated[
    tuple[str, TaxIDFormat],
    Field(examples=[("911144442", "us_ein"), ("FR61954506077", "eu_vat")]),
]
class TaxError(PolarError):
    """Base class of the tax ID errors defined in this module."""
class UnsupportedTaxIDFormat(TaxError):
    """
    Raised when no validator is available for a tax ID format.

    Attributes:
        tax_id_type: The format that could not be handled.
    """

    def __init__(self, tax_id_type: TaxIDFormat) -> None:
        message = f"Tax ID format {tax_id_type} is not supported."
        self.tax_id_type = tax_id_type
        super().__init__(message)
class InvalidTaxID(TaxError):
    """
    Raised when a tax ID cannot be validated for a country.

    Attributes:
        tax_id: The rejected tax ID, as handed to the exception.
        country: The country the tax ID was checked against.
    """

    def __init__(self, tax_id: str, country: str) -> None:
        # The message is deliberately generic; details live on the attributes.
        self.country = country
        self.tax_id = tax_id
        super().__init__("Invalid tax ID.")
class ValidatorProtocol(Protocol):
    """Interface shared by all tax ID validators of this module."""

    def validate(self, number: str, country: str) -> str:
        """
        Validate `number` for `country` and return the validated tax ID.

        The concrete validators in this module raise `InvalidTaxID` when the
        number is rejected.
        """
class StdNumValidator(ValidatorProtocol):
    """
    Generic validator delegating to the python-stdnum module of a format.

    The module is looked up from the two halves of the format value, e.g.
    `eu_vat` is resolved with `get_cc_module("eu", "vat")`.
    """

    def __init__(self, tax_id_type: TaxIDFormat):
        """
        Resolve the stdnum module matching `tax_id_type`.

        Raises:
            UnsupportedTaxIDFormat: stdnum has no module for this format.
        """
        # Split on the first underscore only: "ca_pst_bc" -> ("ca", "pst_bc").
        country_code, format_name = tax_id_type.split("_", 1)
        resolved = get_cc_module(country_code, format_name)
        if resolved is None:
            raise UnsupportedTaxIDFormat(tax_id_type)
        self.module = resolved

    def validate(self, number: str, country: str) -> str:
        """
        Validate `number` with the resolved stdnum module.

        Raises:
            InvalidTaxID: stdnum rejected the number.
        """
        try:
            validated = self.module.validate(number)
        except stdnum.exceptions.ValidationError as exc:
            raise InvalidTaxID(number, country) from exc
        return validated
class CAGSTHSTValidator(ValidatorProtocol):
    """Validator for the Canadian `ca_gst_hst` format, based on `stdnum.ca.bn`."""

    def validate(self, number: str, country: str) -> str:
        """
        Validate a Canadian GST/HST number.

        Only numbers that are exactly 15 characters long once compacted are
        accepted (presumably the business number followed by its program
        account suffix -- confirm against the stdnum.ca.bn documentation).

        Raises:
            InvalidTaxID: Wrong length or rejected by stdnum. The exception
                carries the compacted number.
        """
        compacted = stdnum.ca.bn.compact(number)
        if len(compacted) != 15:
            raise InvalidTaxID(compacted, country)
        try:
            validated = stdnum.ca.bn.validate(compacted)
        except stdnum.exceptions.ValidationError as exc:
            raise InvalidTaxID(compacted, country) from exc
        return validated
class CLTINValidator(ValidatorProtocol):
    """Validator for the Chilean `cl_tin` format, based on `stdnum.cl.rut`."""

    def validate(self, number: str, country: str) -> str:
        """
        Validate a Chilean tax ID.

        Raises:
            InvalidTaxID: Rejected by stdnum. The exception carries the
                compacted number.
        """
        compacted = stdnum.cl.rut.compact(number)
        try:
            validated = stdnum.cl.rut.validate(compacted)
        except stdnum.exceptions.ValidationError as exc:
            raise InvalidTaxID(compacted, country) from exc
        return validated
class TRTINValidator(ValidatorProtocol):
    """Validator for the Turkish `tr_tin` format, based on `stdnum.tr.vkn`."""

    def validate(self, number: str, country: str) -> str:
        """
        Validate a Turkish tax ID.

        Raises:
            InvalidTaxID: Rejected by stdnum. The exception carries the
                compacted number.
        """
        compacted = stdnum.tr.vkn.compact(number)
        try:
            validated = stdnum.tr.vkn.validate(compacted)
        except stdnum.exceptions.ValidationError as exc:
            raise InvalidTaxID(compacted, country) from exc
        return validated
class INGSTValidator(ValidatorProtocol):
    """Validator for the Indian `in_gst` format, based on `stdnum.in_.gstin`."""

    def validate(self, number: str, country: str) -> str:
        """
        Validate an Indian GST number.

        Raises:
            InvalidTaxID: Rejected by stdnum. The exception carries the
                compacted number.
        """
        compacted = stdnum.in_.gstin.compact(number)
        try:
            validated = stdnum.in_.gstin.validate(compacted)
        except stdnum.exceptions.ValidationError as exc:
            raise InvalidTaxID(compacted, country) from exc
        return validated
def _get_validator(tax_id_type: TaxIDFormat) -> ValidatorProtocol:
    """
    Return the validator to use for a tax ID format.

    Formats with a dedicated validator class get that class; every other
    format falls back to the generic stdnum-based validator.

    Raises:
        UnsupportedTaxIDFormat: The fallback validator found no stdnum module
            for this format.
    """
    if tax_id_type == TaxIDFormat.ca_gst_hst:
        return CAGSTHSTValidator()
    if tax_id_type == TaxIDFormat.cl_tin:
        return CLTINValidator()
    if tax_id_type == TaxIDFormat.tr_tin:
        return TRTINValidator()
    if tax_id_type == TaxIDFormat.in_gst:
        return INGSTValidator()
    return StdNumValidator(tax_id_type)
def validate_tax_id(number: str, country: str) -> TaxID:
    """
    Validate a tax ID against the formats known for a country.

    The formats registered for the country in `COUNTRY_TAX_ID_MAP` are tried
    in order; the first one accepting the number wins.

    Args:
        number: The tax ID to validate.
        country: The country code used to look up the candidate formats.

    Returns:
        A tuple of the validated tax ID and the format that accepted it.

    Raises:
        InvalidTaxID: The country is unknown, or none of its formats
            accepted the number (unsupported formats are skipped).
    """
    try:
        candidates = COUNTRY_TAX_ID_MAP[country]
    except KeyError as exc:
        raise InvalidTaxID(number, country) from exc

    for candidate in candidates:
        try:
            validated = _get_validator(candidate).validate(number, country)
        except (UnsupportedTaxIDFormat, InvalidTaxID):
            # This format does not fit; move on to the next candidate.
            continue
        return validated, candidate

    raise InvalidTaxID(number, country)
def to_stripe_tax_id(value: TaxID) -> stripe_lib.Customer.CreateParamsTaxIdDatum:
    """
    Build the Stripe representation of a tax ID.

    Args:
        value: The tax ID as a `(number, format)` pair.

    Returns:
        A dictionary with the format under `type` (as a plain string) and the
        number under `value`.
    """
    number, tax_id_format = value
    return {
        "type": str(tax_id_format),  # type: ignore
        "value": number,
    }
class TaxIDType(TypeDecorator[Any]):
    """
    SQLAlchemy column type storing a tax ID pair on top of `JSONB`.

    Values are serialized with `json.dumps` on the way in and parsed with
    `json.loads` on the way out, so a stored pair is read back as a list.
    """

    impl = JSONB
    cache_ok = True

    def process_bind_param(self, value: Any, dialect: Dialect) -> Any:
        """
        Serialize a tax ID pair before it is bound to a statement.

        Raises:
            TypeError: The value is neither `None` nor a tuple/list of
                exactly two items.
        """
        if value is None:
            return value
        is_pair = isinstance(value, tuple | list) and len(value) == 2
        if not is_pair:
            raise TypeError("Invalid tax ID value.")
        return json.dumps(value)

    def process_result_value(self, value: str | None, dialect: Dialect) -> Any:
        """Parse the JSON string loaded from the database; `None` passes through."""
        if value is None:
            return value
        return json.loads(value)
class TaxCalculationError(PolarError):
    """
    Error raised when a Stripe tax calculation fails.

    Attributes:
        stripe_error: The Stripe error that triggered this error.
        message: The error message, also passed to the base class.
    """

    message: LiteralString

    def __init__(
        self,
        stripe_error: stripe_lib.StripeError,
        message: LiteralString = "An error occurred while calculating tax.",
    ) -> None:
        self.message = message
        self.stripe_error = stripe_error
        super().__init__(message)
class IncompleteTaxLocation(TaxCalculationError):
    """Tax calculation error for missing tax location information."""

    def __init__(self, stripe_error: stripe_lib.InvalidRequestError) -> None:
        super().__init__(
            stripe_error,
            message="Required tax location information is missing.",
        )
class InvalidTaxLocation(TaxCalculationError):
    """Tax calculation error for an address that yields no tax location."""

    def __init__(self, stripe_error: stripe_lib.StripeError) -> None:
        message: LiteralString = (
            "We could not determine the customer's tax location "
            "based on the provided customer address."
        )
        super().__init__(stripe_error, message)
384class TaxabilityReason(StrEnum): 1ab
385 standard_rated = "standard_rated" 1ab
386 """Purchases that are subject to the standard rate of tax.""" 1ab
388 not_collecting = "not_collecting" 1ab
389 """Purchases for countries where we don't collect tax.""" 1ab
391 product_exempt = "product_exempt" 1ab
392 """Purchases for products that are exempt from tax.""" 1ab
394 reverse_charge = "reverse_charge" 1ab
395 """Purchases where the customer is responsible for paying tax, e.g. B2B transactions with provided tax ID.""" 1ab
397 not_subject_to_tax = "not_subject_to_tax" 1ab
398 """Purchases where the customer provided a tax ID, but on countries where we don't collect tax.""" 1ab
400 not_supported = "not_supported" 1ab
401 """Purchases from countries where we don't support tax.""" 1ab
403 customer_exempt = "customer_exempt" 1ab
404 """Purchases where the customer is exempt from tax, e.g. if the subscription was created before our tax registration.""" 1ab
406 @classmethod 1ab
407 def from_stripe( 1ab
408 cls, stripe_reason: str | None, tax_amount: int
409 ) -> "TaxabilityReason | None":
410 if stripe_reason is None or stripe_reason == "not_available":
411 # Stripe sometimes returns `None` or `not_available` even if taxes are collected.
412 if tax_amount != 0:
413 return TaxabilityReason.standard_rated
414 return None
416 return cls(stripe_reason)
# Normalized description of a tax rate, as built from Stripe objects by
# from_stripe_tax_rate() and from_stripe_tax_rate_details().
class TaxRate(TypedDict):
    # "fixed" when Stripe reports the "flat_amount" rate type, "percentage"
    # for any other rate type.
    rate_type: Literal["percentage"] | Literal["fixed"]
    # Percentage rate multiplied by 100 (20% -> 2000); None when Stripe gives
    # no percentage.
    basis_points: int | None
    # Flat amount and its currency; None when Stripe gives no flat amount.
    amount: int | None
    amount_currency: str | None
    # Label of the tax, e.g. "VAT" or "Sales Tax".
    display_name: str
    # Country and state copied from the Stripe object.
    country: str | None
    state: str | None
def from_stripe_tax_rate(tax_rate: stripe_lib.TaxRate) -> TaxRate | None:
    """
    Convert a Stripe tax rate object to a `TaxRate` dictionary.

    Args:
        tax_rate: The Stripe tax rate to convert.

    Returns:
        The converted tax rate, or `None` if the Stripe object has no
        rate type.
    """
    # Local import: keeps this function self-contained.
    from decimal import Decimal

    rate_type = tax_rate.rate_type
    if rate_type is None:
        return None

    basis_points: int | None = None
    if tax_rate.percentage is not None:
        # Convert through Decimal(str(...)) instead of multiplying the float
        # directly: `int(8.2 * 100)` yields 819 because of binary rounding,
        # while the exact decimal product is 820. Fractions of a basis point
        # are still truncated, as before.
        basis_points = int(Decimal(str(tax_rate.percentage)) * 100)

    flat_amount = tax_rate.flat_amount
    return {
        "rate_type": "fixed" if rate_type == "flat_amount" else "percentage",
        "basis_points": basis_points,
        "amount": flat_amount.amount if flat_amount else None,
        "amount_currency": flat_amount.currency if flat_amount else None,
        "display_name": tax_rate.display_name,
        "country": tax_rate.country,
        "state": tax_rate.state,
    }
def from_stripe_tax_rate_details(
    tax_rate_details: stripe_lib.tax.Calculation.TaxBreakdown.TaxRateDetails,
) -> TaxRate | None:
    """
    Convert the rate details of a Stripe tax breakdown to a `TaxRate`.

    Args:
        tax_rate_details: The `tax_rate_details` of a Stripe tax calculation
            breakdown.

    Returns:
        The converted tax rate, or `None` if the details have no rate type.
        The display name is derived from the tax type: known acronyms are
        upper-cased ("vat" -> "VAT"), other types are title-cased
        ("sales_tax" -> "Sales Tax"), and "Tax" is used when there is no
        tax type.
    """
    # Local import: keeps this function self-contained.
    from decimal import Decimal

    rate_type = tax_rate_details.rate_type
    if rate_type is None:
        return None

    basis_points = None
    amount = None
    amount_currency = None

    if tax_rate_details.percentage_decimal is not None:
        # Use exact decimal arithmetic: `int(float("8.2") * 100)` yields 819
        # because of binary rounding, while the exact product is 820.
        # Fractions of a basis point are still truncated, as before.
        basis_points = int(Decimal(str(tax_rate_details.percentage_decimal)) * 100)
    elif tax_rate_details.flat_amount is not None:
        amount = tax_rate_details.flat_amount.amount
        amount_currency = tax_rate_details.flat_amount.currency

    tax_type = tax_rate_details.tax_type
    display_name = "Tax"
    if tax_type is not None:
        # Acronyms are upper-cased. "qst" replaces the former "qct" entry,
        # which matched no Stripe tax type and left QST rendered as "Qst".
        if tax_type in {"gst", "hst", "igst", "jct", "pst", "qst", "rst", "vat"}:
            display_name = tax_type.upper()
        else:
            display_name = tax_type.replace("_", " ").title()

    return {
        "rate_type": "fixed" if rate_type == "flat_amount" else "percentage",
        "basis_points": basis_points,
        "amount": amount,
        "amount_currency": amount_currency,
        "display_name": display_name,
        "country": tax_rate_details.country,
        "state": tax_rate_details.state,
    }
485class TaxCode(StrEnum): 1ab
486 general_electronically_supplied_services = ( 1ab
487 "general_electronically_supplied_services"
488 )
490 def to_stripe(self) -> str: 1ab
491 match self:
492 case TaxCode.general_electronically_supplied_services:
493 return "txcd_10000000"
# Result returned by calculate_tax().
class TaxCalculation(TypedDict):
    # ID of the Stripe tax calculation, or a generated
    # "taxcalc_sandbox_<hex>" ID for the sandbox rate-limit fallback.
    processor_id: str
    # Tax amount taken from Stripe's `tax_amount_exclusive`; 0 for the sandbox
    # fallback.
    amount: int
    # Both are derived from the first entry of Stripe's tax breakdown.
    taxability_reason: TaxabilityReason | None
    tax_rate: TaxRate | None
async def calculate_tax(
    identifier: uuid.UUID | str,
    currency: str,
    amount: int,
    tax_code: TaxCode,
    address: Address,
    tax_ids: list[TaxID],
    customer_exempt: bool,
) -> TaxCalculation:
    """
    Calculate the tax due on an amount through the Stripe Tax API.

    Args:
        identifier: Reference of the item being taxed; sent to Stripe as the
            line item reference.
        currency: Currency of the amount.
        amount: Amount to compute tax on.
        tax_code: Tax code of the product.
        address: Billing address of the customer.
        tax_ids: Tax IDs of the customer.
        customer_exempt: Whether to ask Stripe to treat the customer as
            exempt from tax.

    Returns:
        The calculation ID, the tax amount, and the taxability reason and
        tax rate of the first breakdown entry. In sandbox mode, a zero-tax
        result with a generated ID is returned if Stripe rate-limits the call.

    Raises:
        IncompleteTaxLocation: Stripe rejected a `customer_details[address]`
            parameter.
        InvalidTaxLocation: Stripe reported `customer_tax_location_invalid`.
        stripe.StripeError: Any other Stripe error is re-raised unchanged.
    """
    taxability_override: Literal["customer_exempt", "none"] = "none"
    if customer_exempt:
        taxability_override = "customer_exempt"

    # Derive the idempotency key from every input, so that identical requests
    # share one key and it works as a sort of cache.
    serialized_address = address.model_dump_json()
    serialized_tax_ids = ",".join(f"{tax_id[0]}:{tax_id[1]}" for tax_id in tax_ids)
    raw_key = f"{identifier}:{currency}:{amount}:{tax_code}:{serialized_address}:{serialized_tax_ids}:{taxability_override}"
    idempotency_key = hashlib.sha256(raw_key.encode()).hexdigest()

    try:
        calculation = await stripe_service.create_tax_calculation(
            currency=currency,
            line_items=[
                {
                    "amount": amount,
                    "tax_code": tax_code.to_stripe(),
                    "quantity": 1,
                    "reference": str(identifier),
                }
            ],
            customer_details={
                "address": address.to_dict(),
                "address_source": "billing",
                "tax_ids": [to_stripe_tax_id(tax_id) for tax_id in tax_ids],
                "taxability_override": taxability_override,
            },
            idempotency_key=idempotency_key,
        )
    except stripe_lib.RateLimitError:
        if not settings.is_sandbox():
            raise
        # Sandbox only: degrade to a zero-tax result instead of failing.
        log.warning(
            "Stripe Tax API rate limit exceeded in sandbox mode, returning zero tax",
            identifier=str(identifier),
            currency=currency,
            amount=amount,
        )
        return {
            "processor_id": f"taxcalc_sandbox_{uuid.uuid4().hex}",
            "amount": 0,
            "taxability_reason": None,
            "tax_rate": None,
        }
    except stripe_lib.InvalidRequestError as exc:
        param = exc.error.param if exc.error is not None else None
        if param is not None and param.startswith("customer_details[address]"):
            raise IncompleteTaxLocation(exc) from exc
        raise
    except stripe_lib.StripeError as exc:
        if exc.error is not None and exc.error.code == "customer_tax_location_invalid":
            raise InvalidTaxLocation(exc) from exc
        raise

    assert calculation.id is not None
    tax_amount = calculation.tax_amount_exclusive
    # Only the first breakdown entry is used.
    first_breakdown = calculation.tax_breakdown[0]
    return {
        "processor_id": calculation.id,
        "amount": tax_amount,
        "taxability_reason": TaxabilityReason.from_stripe(
            first_breakdown.taxability_reason, tax_amount
        ),
        "tax_rate": from_stripe_tax_rate_details(first_breakdown.tax_rate_details),
    }