Coverage for polar/models/customer.py: 60%

155 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import dataclasses 1ab

2import string 1ab

3import time 1ab

4from collections.abc import Sequence 1ab

5from datetime import datetime 1ab

6from enum import StrEnum 1ab

7from typing import TYPE_CHECKING, Any 1ab

8from uuid import UUID 1ab

9 

10import sqlalchemy as sa 1ab

11from sqlalchemy import ( 1ab

12 TIMESTAMP, 

13 Boolean, 

14 Column, 

15 ColumnElement, 

16 ForeignKey, 

17 Index, 

18 Integer, 

19 String, 

20 UniqueConstraint, 

21 Uuid, 

22 func, 

23) 

24from sqlalchemy.dialects.postgresql import JSONB 1ab

25from sqlalchemy.ext.hybrid import hybrid_property 1ab

26from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship 1ab

27 

28from polar.kit.address import Address, AddressType 1ab

29from polar.kit.db.models import RecordModel 1ab

30from polar.kit.metadata import MetadataMixin 1ab

31from polar.kit.tax import TaxID, TaxIDType 1ab

32from polar.kit.utils import utc_now 1ab

33 

34if TYPE_CHECKING: 34 ↛ 35line 34 didn't jump to line 35 because the condition on line 34 was never true1ab

35 from .benefit_grant import BenefitGrant 

36 from .customer_meter import CustomerMeter 

37 from .member import Member 

38 from .organization import Organization 

39 from .payment_method import PaymentMethod 

40 from .subscription import Subscription 

41 

42 

def short_id_to_base26(short_id: int) -> str:
    """Convert a numeric short_id to a base-26 string using the letters A-Z.

    Digits map ``A`` = 0 through ``Z`` = 25, most significant digit first.
    The result is left-padded with ``A`` (the zero digit) to at least 8
    characters. Values of ``26**8`` or more are not truncated, so they
    produce longer strings.

    :param short_id: Non-negative integer to encode.
    :return: The encoded string, at least 8 characters long.
    :raises ValueError: If ``short_id`` is negative. Without this check a
        negative value would encode as ``"AAAAAAAA"`` and collide with ``0``.
    """
    if short_id < 0:
        raise ValueError(f"short_id must be non-negative, got {short_id}")

    alphabet = string.ascii_uppercase
    digits: list[str] = []
    remaining = short_id
    while remaining > 0:
        remaining, digit = divmod(remaining, 26)
        digits.append(alphabet[digit])

    # Digits are produced least-significant first, so reverse before padding.
    return "".join(reversed(digits)).rjust(8, "A")

56 

57 

58class CustomerOAuthPlatform(StrEnum): 1ab

59 github = "github" 1ab

60 discord = "discord" 1ab

61 

62 def get_account_key(self, account_id: str) -> str: 1ab

63 return f"{self.value}:{account_id}" 

64 

65 def get_account_id(self, data: dict[str, Any]) -> str: 1ab

66 if self == CustomerOAuthPlatform.github: 

67 return str(data["id"]) 

68 if self == CustomerOAuthPlatform.discord: 

69 return str(data["id"]) 

70 raise NotImplementedError() 

71 

72 def get_account_username(self, data: dict[str, Any]) -> str: 1ab

73 if self == CustomerOAuthPlatform.github: 

74 return data["login"] 

75 if self == CustomerOAuthPlatform.discord: 

76 return data["username"] 

77 raise NotImplementedError() 

78 

79 

80@dataclasses.dataclass 1ab

81class CustomerOAuthAccount: 1ab

82 access_token: str 1ab

83 account_id: str 1ab

84 account_username: str | None = None 1ab

85 expires_at: int | None = None 1ab

86 refresh_token: str | None = None 1ab

87 refresh_token_expires_at: int | None = None 1ab

88 

89 def is_expired(self) -> bool: 1ab

90 if self.expires_at is None: 

91 return False 

92 return time.time() > self.expires_at 

93 

94 

class Customer(MetadataMixin, RecordModel):
    """A customer belonging to a single organization.

    Uniqueness rules declared in ``__table_args__``:

    * ``(organization_id, lower(email), deleted_at)`` is unique with NULLs
      treated as equal. An organization can therefore have at most one
      customer per case-insensitive email among rows whose ``deleted_at``
      is NULL.
    * ``(organization_id, external_id)`` and ``(organization_id, short_id)``
      are unique.
    """

    __tablename__ = "customers"
    __table_args__ = (
        # Non-unique index for case-insensitive email lookups across all
        # organizations.
        Index(
            "ix_customers_email_case_insensitive",
            func.lower(Column("email")),
            "deleted_at",
            postgresql_nulls_not_distinct=True,
        ),
        Index(
            "ix_customers_organization_id_email_case_insensitive",
            "organization_id",
            func.lower(Column("email")),
            "deleted_at",
            unique=True,
            # NULLS NOT DISTINCT: two rows with deleted_at IS NULL collide.
            postgresql_nulls_not_distinct=True,
        ),
        UniqueConstraint("organization_id", "external_id"),
        UniqueConstraint("organization_id", "short_id"),
    )

    # Optional identifier; unique within the organization when set.
    external_id: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    # Generated by the database function generate_customer_short_id() on
    # insert. short_id_str renders it as an A-Z string.
    short_id: Mapped[int] = mapped_column(
        sa.BigInteger,
        nullable=False,
        index=True,
        server_default=sa.text("generate_customer_short_id()"),
    )
    email: Mapped[str] = mapped_column(String(320), nullable=False)
    email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    stripe_customer_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None, unique=False
    )

    name: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    # Stored in the "billing_name" column. Access it through the billing_name
    # property, which falls back to `name`.
    _billing_name: Mapped[str | None] = mapped_column(
        "billing_name", String, nullable=True, default=None
    )
    billing_address: Mapped[Address | None] = mapped_column(
        AddressType, nullable=True, default=None
    )
    tax_id: Mapped[TaxID | None] = mapped_column(TaxIDType, nullable=True, default=None)

    # Stored in the "oauth_accounts" column. Maps a platform-prefixed account
    # key (CustomerOAuthPlatform.get_account_key) to a serialized
    # CustomerOAuthAccount. Modify it only via the *_oauth_account methods.
    _oauth_accounts: Mapped[dict[str, dict[str, Any]]] = mapped_column(
        "oauth_accounts", JSONB, nullable=False, default=dict
    )

    _legacy_user_id: Mapped[UUID | None] = mapped_column(
        "legacy_user_id",
        Uuid,
        ForeignKey("users.id", ondelete="set null"),
        nullable=True,
    )
    """
    Before implementing customers, every customer was a user. This field is used to
    keep track of the user that originated this customer.

    It helps us keep backwards compatibility with integrations that used the user ID as
    reference to the customer.

    For new customers, this field will be null.
    """

    meters_dirtied_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )
    # The column is nullable with a None default, so the annotation must admit
    # None (previously annotated as plain Mapped[datetime]).
    meters_updated_at: Mapped[datetime | None] = mapped_column(
        TIMESTAMP(timezone=True), nullable=True, default=None, index=True
    )

    invoice_next_number: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

    organization_id: Mapped[UUID] = mapped_column(
        Uuid,
        ForeignKey("organizations.id", ondelete="cascade"),
        nullable=False,
        index=True,
    )

    @declared_attr
    def organization(cls) -> Mapped["Organization"]:
        # lazy="raise": the relationship must be loaded explicitly in the query.
        return relationship("Organization", lazy="raise")

    @declared_attr
    def payment_methods(cls) -> Mapped[Sequence["PaymentMethod"]]:
        # foreign_keys disambiguates from default_payment_method_id, which is a
        # second FK path between customers and payment_methods.
        return relationship(
            "PaymentMethod",
            lazy="raise",
            back_populates="customer",
            cascade="all, delete-orphan",
            foreign_keys="[PaymentMethod.customer_id]",
        )

    @declared_attr
    def members(cls) -> Mapped[Sequence["Member"]]:
        return relationship(
            "Member",
            lazy="raise",
            back_populates="customer",
            cascade="all, delete-orphan",
        )

    default_payment_method_id: Mapped[UUID | None] = mapped_column(
        "default_payment_method_id",
        Uuid,
        ForeignKey("payment_methods.id", ondelete="set null"),
        nullable=True,
    )

    @declared_attr
    def default_payment_method(cls) -> Mapped["PaymentMethod | None"]:
        return relationship(
            "PaymentMethod",
            lazy="raise",
            uselist=False,
            foreign_keys=[cls.default_payment_method_id],  # type: ignore
        )

    @hybrid_property
    def can_authenticate(self) -> bool:
        """True while the customer has not been soft-deleted."""
        return self.deleted_at is None

    @can_authenticate.inplace.expression
    @classmethod
    def _can_authenticate_expression(cls) -> ColumnElement[bool]:
        # SQL counterpart of can_authenticate, for use in query filters.
        return cls.deleted_at.is_(None)

    def get_oauth_account(
        self, account_id: str, platform: CustomerOAuthPlatform
    ) -> CustomerOAuthAccount | None:
        """Return the linked OAuth account for ``account_id`` on ``platform``.

        :return: The deserialized account, or None if no such account is stored.
        """
        oauth_account_data = self._oauth_accounts.get(
            platform.get_account_key(account_id)
        )
        if oauth_account_data is None:
            return None

        return CustomerOAuthAccount(**oauth_account_data)

    def set_oauth_account(
        self, oauth_account: CustomerOAuthAccount, platform: CustomerOAuthPlatform
    ) -> None:
        """Store ``oauth_account``, replacing any existing entry for the same key."""
        account_key = platform.get_account_key(oauth_account.account_id)
        # Reassign a new dict rather than mutating in place. A plain JSONB
        # column (no MutableDict wrapper) does not track in-place mutation, so
        # an in-place change would not be flushed.
        self._oauth_accounts = {
            **self._oauth_accounts,
            account_key: dataclasses.asdict(oauth_account),
        }

    def remove_oauth_account(
        self, account_id: str, platform: CustomerOAuthPlatform
    ) -> None:
        """Remove the OAuth account for ``account_id`` on ``platform``, if present."""
        account_key = platform.get_account_key(account_id)
        # Reassign (see set_oauth_account) so the change is detected.
        self._oauth_accounts = {
            k: v for k, v in self._oauth_accounts.items() if k != account_key
        }

    @property
    def oauth_accounts(self) -> dict[str, Any]:
        """Raw stored OAuth accounts keyed by platform-prefixed account key.

        NOTE(review): this returns the live dict. In-place mutations by callers
        are likely not persisted; use set_/remove_oauth_account instead.
        """
        return self._oauth_accounts

    @property
    def short_id_str(self) -> str:
        """Get the base-26 string representation of the short_id."""
        return short_id_to_base26(self.short_id)

    @property
    def legacy_user_id(self) -> UUID:
        """The originating user's id, or this customer's own id if there is none."""
        return self._legacy_user_id or self.id

    @property
    def legacy_user_public_name(self) -> str:
        """First character of the name, or of the email when the name is empty.

        Assumes ``email`` is non-empty; an empty email would raise IndexError.
        """
        if self.name:
            return self.name[0]
        return self.email[0]

    # The properties below are transient (not mapped columns). They default to
    # None until a caller assigns them.

    @property
    def active_subscriptions(self) -> Sequence["Subscription"] | None:
        return getattr(self, "_active_subscriptions", None)

    @active_subscriptions.setter
    def active_subscriptions(self, value: Sequence["Subscription"]) -> None:
        self._active_subscriptions = value

    @property
    def granted_benefits(self) -> Sequence["BenefitGrant"] | None:
        return getattr(self, "_granted_benefits", None)

    @granted_benefits.setter
    def granted_benefits(self, value: Sequence["BenefitGrant"]) -> None:
        self._granted_benefits = value

    @property
    def active_meters(self) -> Sequence["CustomerMeter"] | None:
        return getattr(self, "_active_meters", None)

    @active_meters.setter
    def active_meters(self, value: Sequence["CustomerMeter"]) -> None:
        self._active_meters = value

    @property
    def billing_name(self) -> str | None:
        """Billing name, falling back to ``name`` when no billing name is stored."""
        return self._billing_name or self.name

    @billing_name.setter
    def billing_name(self, value: str | None) -> None:
        self._billing_name = value

    @property
    def actual_billing_name(self) -> str | None:
        """The stored billing name only, without the fallback to ``name``."""
        return self._billing_name

    def touch_meters_dirtied_at(self) -> None:
        """Set ``meters_dirtied_at`` to the current UTC time.

        Presumably this flags the customer's meters for recomputation; confirm
        against the consumer of meters_dirtied_at.
        """
        self.meters_dirtied_at = utc_now()