Coverage for polar/models/customer.py: 60%
155 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import dataclasses 1ab
2import string 1ab
3import time 1ab
4from collections.abc import Sequence 1ab
5from datetime import datetime 1ab
6from enum import StrEnum 1ab
7from typing import TYPE_CHECKING, Any 1ab
8from uuid import UUID 1ab
10import sqlalchemy as sa 1ab
11from sqlalchemy import ( 1ab
12 TIMESTAMP,
13 Boolean,
14 Column,
15 ColumnElement,
16 ForeignKey,
17 Index,
18 Integer,
19 String,
20 UniqueConstraint,
21 Uuid,
22 func,
23)
24from sqlalchemy.dialects.postgresql import JSONB 1ab
25from sqlalchemy.ext.hybrid import hybrid_property 1ab
26from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship 1ab
28from polar.kit.address import Address, AddressType 1ab
29from polar.kit.db.models import RecordModel 1ab
30from polar.kit.metadata import MetadataMixin 1ab
31from polar.kit.tax import TaxID, TaxIDType 1ab
32from polar.kit.utils import utc_now 1ab
34if TYPE_CHECKING: 34 ↛ 35line 34 didn't jump to line 35 because the condition on line 34 was never true1ab
35 from .benefit_grant import BenefitGrant
36 from .customer_meter import CustomerMeter
37 from .member import Member
38 from .organization import Organization
39 from .payment_method import PaymentMethod
40 from .subscription import Subscription
def short_id_to_base26(short_id: int) -> str:
    """Render ``short_id`` as uppercase base-26 letters, left-padded to 8 chars.

    Digit values map 0 -> "A" through 25 -> "Z". The result is padded on the
    left with "A" up to eight characters. Values needing more than eight
    digits are returned at their full length (no truncation). Zero and
    negative values both yield "AAAAAAAA".
    """
    alphabet = string.ascii_uppercase
    digits: list[str] = []
    remaining = short_id
    while remaining > 0:
        remaining, digit = divmod(remaining, 26)
        digits.append(alphabet[digit])
    # Digits come out least-significant first, so reverse before joining.
    return "".join(reversed(digits)).rjust(8, "A")
58class CustomerOAuthPlatform(StrEnum): 1ab
59 github = "github" 1ab
60 discord = "discord" 1ab
62 def get_account_key(self, account_id: str) -> str: 1ab
63 return f"{self.value}:{account_id}"
65 def get_account_id(self, data: dict[str, Any]) -> str: 1ab
66 if self == CustomerOAuthPlatform.github:
67 return str(data["id"])
68 if self == CustomerOAuthPlatform.discord:
69 return str(data["id"])
70 raise NotImplementedError()
72 def get_account_username(self, data: dict[str, Any]) -> str: 1ab
73 if self == CustomerOAuthPlatform.github:
74 return data["login"]
75 if self == CustomerOAuthPlatform.discord:
76 return data["username"]
77 raise NotImplementedError()
80@dataclasses.dataclass 1ab
81class CustomerOAuthAccount: 1ab
82 access_token: str 1ab
83 account_id: str 1ab
84 account_username: str | None = None 1ab
85 expires_at: int | None = None 1ab
86 refresh_token: str | None = None 1ab
87 refresh_token_expires_at: int | None = None 1ab
89 def is_expired(self) -> bool: 1ab
90 if self.expires_at is None:
91 return False
92 return time.time() > self.expires_at
class Customer(MetadataMixin, RecordModel):
    """A customer record belonging to a single organization (``customers`` table).

    Besides the mapped columns and relationships, the model provides:

    * helpers for the JSONB-backed store of linked OAuth accounts;
    * ``active_subscriptions``, ``granted_benefits`` and ``active_meters``,
      which are plain in-memory instance attributes, not mapped columns. They
      read as ``None`` until something assigns them.
    """

    __tablename__ = "customers"
    __table_args__ = (
        # Non-unique index on lower(email) plus deleted_at.
        Index(
            "ix_customers_email_case_insensitive",
            func.lower(Column("email")),
            "deleted_at",
            postgresql_nulls_not_distinct=True,
        ),
        # Case-insensitive email uniqueness within an organization.
        # deleted_at is part of the key and NULLs are not distinct, so two rows
        # of the same organization that both have deleted_at IS NULL cannot
        # share a lowercased email. Rows with differing deleted_at values can.
        Index(
            "ix_customers_organization_id_email_case_insensitive",
            "organization_id",
            func.lower(Column("email")),
            "deleted_at",
            unique=True,
            postgresql_nulls_not_distinct=True,
        ),
        UniqueConstraint("organization_id", "external_id"),
        UniqueConstraint("organization_id", "short_id"),
    )

    # Optional external identifier, unique within an organization.
    external_id: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    # Assigned by the database function generate_customer_short_id(). Unique
    # within an organization and rendered as letters by short_id_str.
    short_id: Mapped[int] = mapped_column(
        sa.BigInteger,
        nullable=False,
        index=True,
        server_default=sa.text("generate_customer_short_id()"),
    )
    email: Mapped[str] = mapped_column(String(320), nullable=False)
    email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # Explicitly not unique.
    stripe_customer_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None, unique=False
    )

    name: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    # Stored in the "billing_name" column. Public access goes through the
    # billing_name property (falls back to name) or actual_billing_name (raw).
    _billing_name: Mapped[str | None] = mapped_column(
        "billing_name", String, nullable=True, default=None
    )
    billing_address: Mapped[Address | None] = mapped_column(
        AddressType, nullable=True, default=None
    )
    tax_id: Mapped[TaxID | None] = mapped_column(TaxIDType, nullable=True, default=None)

    # Stored in the "oauth_accounts" column. Maps CustomerOAuthPlatform
    # account keys ("<platform>:<account_id>") to dataclasses.asdict() dumps of
    # CustomerOAuthAccount. Use the *_oauth_account methods to change it.
    _oauth_accounts: Mapped[dict[str, dict[str, Any]]] = mapped_column(
        "oauth_accounts", JSONB, nullable=False, default=dict
    )

    _legacy_user_id: Mapped[UUID | None] = mapped_column(
        "legacy_user_id",
        Uuid,
        ForeignKey("users.id", ondelete="set null"),
        nullable=True,
    )
    """
    Before implementing customers, every customer was a user. This field is used to
    keep track of the user that originated this customer.

    It helps us keep backwards compatibility with integrations that used the user ID as
    reference to the customer.

    For new customers, this field will be null.
    """

    # Set to the current time by touch_meters_dirtied_at().
    meters_dirtied_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    # The column is nullable and defaults to None, so the annotation is Optional
    # to match. Previously it was Mapped[datetime], which hid the None case from
    # type checkers.
    meters_updated_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )

    # Starts at 1. Presumably the sequence number for this customer's next
    # invoice; confirm against the invoicing code.
    invoice_next_number: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

    # Deleting the organization cascades to its customers at the database level.
    organization_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("organizations.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def organization(cls) -> Mapped["Organization"]:
        """Owning organization. ``lazy="raise"``: must be loaded explicitly."""
        return relationship("Organization", lazy="raise")

    @declared_attr
    def payment_methods(cls) -> Mapped[Sequence["PaymentMethod"]]:
        """Payment methods owned by this customer (deleted with the customer)."""
        # customers and payment_methods are joined by two foreign-key paths
        # (payment_methods.customer_id and customers.default_payment_method_id),
        # so the join column for this collection must be named explicitly.
        return relationship(
            "PaymentMethod",
            lazy="raise",
            back_populates="customer",
            cascade="all, delete-orphan",
            foreign_keys="[PaymentMethod.customer_id]",
        )

    @declared_attr
    def members(cls) -> Mapped[Sequence["Member"]]:
        """Members attached to this customer (deleted with the customer)."""
        return relationship(
            "Member",
            lazy="raise",
            back_populates="customer",
            cascade="all, delete-orphan",
        )

    # Cleared by the database if the referenced payment method is deleted.
    default_payment_method_id: Mapped[UUID | None] = mapped_column(
        "default_payment_method_id",
        Uuid,
        ForeignKey("payment_methods.id", ondelete="set null"),
        nullable=True,
    )

    @declared_attr
    def default_payment_method(cls) -> Mapped["PaymentMethod | None"]:
        """Scalar relationship through ``default_payment_method_id``."""
        return relationship(
            "PaymentMethod",
            lazy="raise",
            uselist=False,
            foreign_keys=[cls.default_payment_method_id],  # type: ignore
        )

    @hybrid_property
    def can_authenticate(self) -> bool:
        """True while the customer has no ``deleted_at`` timestamp."""
        return self.deleted_at is None

    @can_authenticate.inplace.expression
    @classmethod
    def _can_authenticate_expression(cls) -> ColumnElement[bool]:
        # SQL form of can_authenticate, usable in query filters.
        return cls.deleted_at.is_(None)

    def get_oauth_account(
        self, account_id: str, platform: CustomerOAuthPlatform
    ) -> CustomerOAuthAccount | None:
        """Return the linked OAuth account for ``platform``/``account_id``.

        Returns ``None`` when no such account is stored. The stored dict is
        unpacked directly into ``CustomerOAuthAccount``, so keys it does not
        accept raise ``TypeError``.
        """
        oauth_account_data = self._oauth_accounts.get(
            platform.get_account_key(account_id)
        )
        if oauth_account_data is None:
            return None

        return CustomerOAuthAccount(**oauth_account_data)

    def set_oauth_account(
        self, oauth_account: CustomerOAuthAccount, platform: CustomerOAuthPlatform
    ) -> None:
        """Insert or replace the stored entry for ``oauth_account``."""
        account_key = platform.get_account_key(oauth_account.account_id)
        # Assign a new dict instead of mutating in place. SQLAlchemy does not
        # track in-place changes to a plain JSONB value, so reassignment is what
        # marks the attribute as modified.
        self._oauth_accounts = {
            **self._oauth_accounts,
            account_key: dataclasses.asdict(oauth_account),
        }

    def remove_oauth_account(
        self, account_id: str, platform: CustomerOAuthPlatform
    ) -> None:
        """Drop the stored entry for ``platform``/``account_id``.

        Removing an absent key is a no-op for the content, but a new dict is
        still assigned.
        """
        account_key = platform.get_account_key(account_id)
        # Rebuild rather than `del`, for the same change-tracking reason as
        # set_oauth_account.
        self._oauth_accounts = {
            k: v for k, v in self._oauth_accounts.items() if k != account_key
        }

    @property
    def oauth_accounts(self) -> dict[str, Any]:
        """The raw OAuth account store (the underlying dict, not a copy).

        In-place mutation of the returned dict is presumably not detected by
        SQLAlchemy. Use set_oauth_account / remove_oauth_account instead.
        """
        return self._oauth_accounts

    @property
    def short_id_str(self) -> str:
        """Get the base-26 string representation of the short_id."""
        return short_id_to_base26(self.short_id)

    @property
    def legacy_user_id(self) -> UUID:
        """Originating user id, or this customer's own ``id`` if there is none."""
        return self._legacy_user_id or self.id

    @property
    def legacy_user_public_name(self) -> str:
        """First character of ``name`` if non-empty, else of ``email``."""
        if self.name:
            return self.name[0]
        return self.email[0]

    @property
    def active_subscriptions(self) -> Sequence["Subscription"] | None:
        """In-memory value set via the setter. ``None`` until assigned."""
        return getattr(self, "_active_subscriptions", None)

    @active_subscriptions.setter
    def active_subscriptions(self, value: Sequence["Subscription"]) -> None:
        self._active_subscriptions = value

    @property
    def granted_benefits(self) -> Sequence["BenefitGrant"] | None:
        """In-memory value set via the setter. ``None`` until assigned."""
        return getattr(self, "_granted_benefits", None)

    @granted_benefits.setter
    def granted_benefits(self, value: Sequence["BenefitGrant"]) -> None:
        self._granted_benefits = value

    @property
    def active_meters(self) -> Sequence["CustomerMeter"] | None:
        """In-memory value set via the setter. ``None`` until assigned."""
        return getattr(self, "_active_meters", None)

    @active_meters.setter
    def active_meters(self, value: Sequence["CustomerMeter"]) -> None:
        self._active_meters = value

    @property
    def billing_name(self) -> str | None:
        """Stored billing name, falling back to ``name`` when it is empty/None."""
        return self._billing_name or self.name

    @billing_name.setter
    def billing_name(self, value: str | None) -> None:
        self._billing_name = value

    @property
    def actual_billing_name(self) -> str | None:
        """The stored billing name without the ``name`` fallback."""
        return self._billing_name

    def touch_meters_dirtied_at(self) -> None:
        """Set ``meters_dirtied_at`` to the current time (``utc_now()``)."""
        self.meters_dirtied_at = utc_now()