Coverage for polar/customer/repository.py: 23%
112 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import contextlib 1a
2from collections.abc import AsyncGenerator, Iterable, Sequence 1a
3from typing import Any 1a
4from uuid import UUID 1a
6from sqlalchemy import Select, func, select, update 1a
7from sqlalchemy import inspect as orm_inspect 1a
8from sqlalchemy.orm import InstanceState 1a
10from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a
11from polar.event.system import CustomerUpdatedFields, SystemEvent 1a
12from polar.kit.repository import ( 1a
13 Options,
14 RepositoryBase,
15 RepositorySoftDeletionIDMixin,
16 RepositorySoftDeletionMixin,
17)
18from polar.kit.utils import utc_now 1a
19from polar.models import Customer, UserOrganization 1a
20from polar.models.webhook_endpoint import WebhookEventType 1a
21from polar.worker import enqueue_job 1a
def _get_changed_value(
    inspection: InstanceState[Customer], attr_name: str
) -> tuple[bool, Any]:
    """
    Tell whether `attr_name` has a pending change on the inspected instance.

    The result is a `(has_changed, new_value)` pair. `(False, None)` is
    returned both when the attribute history records no change and when the
    first removed value compares equal to the first added value.
    """
    unchanged: tuple[bool, Any] = (False, None)

    history = inspection.attrs[attr_name].history
    if not history.has_changes():
        return unchanged

    # Only the first entry on each side of the history is considered.
    old_values, new_values = history.deleted, history.added
    previous = old_values[0] if old_values else None
    current = new_values[0] if new_values else None

    # The history can report a change even though the value was replaced by
    # an equal one; such a replacement is not treated as a change.
    if previous == current:
        return unchanged

    return (True, current)
class CustomerRepository(
    RepositorySoftDeletionIDMixin[Customer, UUID],
    RepositorySoftDeletionMixin[Customer],
    RepositoryBase[Customer],
):
    """
    Repository for `Customer` rows.

    The write methods (`create_context`, `update`, `soft_delete`) enqueue the
    `customer.webhook` and `customer.event` background jobs in addition to
    delegating to the inherited repository behaviour. The read helpers scope
    their queries either to an explicit organization ID or to what the
    authenticated subject is allowed to read.
    """

    model = Customer

    async def create(self, object: Customer, *, flush: bool = False) -> Customer:
        """
        Create the customer through the parent repository and make sure the
        returned instance has an ID.
        """
        customer = await super().create(object, flush=flush)

        # We need the id to enqueue the job: when it has not been assigned
        # yet, generate it by calling the `id` column's default.
        if customer.id is None:
            customer.id = Customer.__table__.c.id.default.arg(None)

        return customer

    @contextlib.asynccontextmanager
    async def create_context(
        self, object: Customer, *, flush: bool = False
    ) -> AsyncGenerator[Customer]:
        """
        Create the customer, yield it to the `async with` body, then enqueue
        the creation jobs once the body has finished.

        The jobs are enqueued after the `yield`, so they are not enqueued if
        the body raises.
        """
        customer = await self.create(object, flush=flush)
        yield customer
        assert customer.id is not None, "Customer.id is None"

        # If the customer has an external_id, enqueue a meter update job
        # to create meters for any pre-existing events with that external_id.
        if customer.external_id is not None:
            enqueue_job("customer_meter.update_customer", customer.id)

        enqueue_job("customer.webhook", WebhookEventType.customer_created, customer.id)
        enqueue_job("customer.event", customer.id, SystemEvent.customer_created)

    @staticmethod
    def _collect_updated_fields(
        inspection: InstanceState[Customer],
    ) -> CustomerUpdatedFields:
        """
        Build the `customer_updated` event payload from the pending attribute
        changes visible on `inspection`. Unchanged attributes are left out.
        """
        updated_fields: CustomerUpdatedFields = {}

        changed, value = _get_changed_value(inspection, "name")
        if changed:
            updated_fields["name"] = value

        changed, value = _get_changed_value(inspection, "email")
        if changed:
            updated_fields["email"] = value

        changed, value = _get_changed_value(inspection, "billing_address")
        if changed:
            updated_fields["billing_address"] = value.to_dict() if value else None

        changed, value = _get_changed_value(inspection, "tax_id")
        if changed:
            # Only the first element of the tax ID value is reported.
            updated_fields["tax_id"] = value[0] if value else None

        # The `user_metadata` attribute is exposed as `metadata` in the event.
        changed, value = _get_changed_value(inspection, "user_metadata")
        if changed:
            updated_fields["metadata"] = value

        return updated_fields

    async def update(
        self,
        object: Customer,
        *,
        update_dict: dict[str, Any] | None = None,
        flush: bool = False,
    ) -> Customer:
        """
        Update the customer and enqueue the update webhook and system event.

        The `customer.webhook` job is always enqueued. The `customer.event`
        job, carrying the fields that changed, is only enqueued when the
        customer is not soft-deleted.
        """
        inspection = orm_inspect(object)

        # Attribute history is reset when the session is flushed, so the
        # changed fields must be read *before* flushing. The flush requested
        # by the caller is therefore deferred until the history was read.
        # NOTE(review): assumes the parent's `flush=True` amounts to
        # `self.session.flush()` — confirm against RepositoryBase.update.
        customer = await super().update(object, update_dict=update_dict, flush=False)

        # Only create an event if the customer is not being deleted
        updated_fields: CustomerUpdatedFields | None = None
        if not customer.deleted_at:
            updated_fields = self._collect_updated_fields(inspection)

        # Flush before enqueuing so that a failing flush enqueues nothing.
        if flush:
            await self.session.flush()

        enqueue_job("customer.webhook", WebhookEventType.customer_updated, customer.id)

        if updated_fields is not None:
            enqueue_job(
                "customer.event",
                customer.id,
                SystemEvent.customer_updated,
                updated_fields,
            )

        return customer

    async def soft_delete(self, object: Customer, *, flush: bool = False) -> Customer:
        """
        Soft-delete the customer, release its `external_id` and enqueue the
        deletion webhook and system event.
        """
        customer = await super().soft_delete(object, flush=flush)

        # Clear external_id for future recycling
        # NOTE(review): these assignments happen after the parent call, so
        # they are not part of a flush the parent may have issued — confirm
        # callers passing `flush=True` do not rely on that.
        if customer.external_id:
            # Store external_id in `user_metadata` for support debugging.
            # A new dict is assigned instead of mutating the current one in
            # place: re-assigning the same, already mutated, object may not be
            # seen as a change by the ORM unless the column tracks mutations.
            customer.user_metadata = {
                **customer.user_metadata,
                "__external_id": customer.external_id,
            }
            customer.external_id = None

        enqueue_job("customer.webhook", WebhookEventType.customer_deleted, customer.id)
        enqueue_job("customer.event", customer.id, SystemEvent.customer_deleted)

        return customer

    async def _update_customers(
        self, customers: Iterable[Customer], **values: Any
    ) -> None:
        """
        Apply `values` to the rows of all given customers in a single UPDATE.
        Nothing is executed when `customers` is empty.
        """
        customer_ids = [customer.id for customer in customers]
        if not customer_ids:
            return

        statement = (
            update(Customer).where(Customer.id.in_(customer_ids)).values(**values)
        )
        await self.session.execute(statement)

    async def touch_meters(self, customers: Iterable[Customer]) -> None:
        """
        Set `meters_dirtied_at` to the current UTC time for the given customers.
        """
        await self._update_customers(customers, meters_dirtied_at=utc_now())

    async def set_meters_updated_at(self, customers: Iterable[Customer]) -> None:
        """
        Set `meters_updated_at` to the current UTC time for the given customers.
        """
        await self._update_customers(customers, meters_updated_at=utc_now())

    async def get_by_id_and_organization(
        self, id: UUID, organization_id: UUID
    ) -> Customer | None:
        """
        Return the customer with this ID in the given organization, if any.
        """
        statement = self.get_base_statement().where(
            Customer.id == id, Customer.organization_id == organization_id
        )
        return await self.get_one_or_none(statement)

    async def get_by_email_and_organization(
        self, email: str, organization_id: UUID
    ) -> Customer | None:
        """
        Return the customer with this email in the given organization, if any.

        Both sides of the email comparison are lowercased.
        """
        statement = self.get_base_statement().where(
            func.lower(Customer.email) == email.lower(),
            Customer.organization_id == organization_id,
        )
        return await self.get_one_or_none(statement)

    async def get_by_external_id_and_organization(
        self, external_id: str, organization_id: UUID
    ) -> Customer | None:
        """
        Return the customer with this external ID in the given organization,
        if any.
        """
        statement = self.get_base_statement().where(
            Customer.external_id == external_id,
            Customer.organization_id == organization_id,
        )
        return await self.get_one_or_none(statement)

    async def get_by_stripe_customer_id_and_organization(
        self, stripe_customer_id: str, organization_id: UUID
    ) -> Customer | None:
        """
        Return the customer with this Stripe customer ID in the given
        organization, if any.
        """
        statement = self.get_base_statement().where(
            Customer.stripe_customer_id == stripe_customer_id,
            Customer.organization_id == organization_id,
        )
        return await self.get_one_or_none(statement)

    async def stream_by_organization(
        self,
        auth_subject: AuthSubject[User | Organization],
        organization_id: Sequence[UUID] | None,
    ) -> AsyncGenerator[Customer]:
        """
        Stream the customers readable by `auth_subject`.

        When `organization_id` is not `None`, only customers belonging to one
        of the listed organizations are yielded.
        """
        statement = self.get_readable_statement(auth_subject)

        if organization_id is not None:
            statement = statement.where(
                Customer.organization_id.in_(organization_id),
            )

        async for customer in self.stream(statement):
            yield customer

    async def get_readable_by_id(
        self,
        auth_subject: AuthSubject[User | Organization],
        id: UUID,
        *,
        options: Options = (),
    ) -> Customer | None:
        """
        Return the customer with this ID if `auth_subject` can read it.

        `options` are forwarded to the statement's `.options()`.
        """
        statement = (
            self.get_readable_statement(auth_subject)
            .where(Customer.id == id)
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_readable_by_external_id(
        self,
        auth_subject: AuthSubject[User | Organization],
        external_id: str,
        *,
        options: Options = (),
    ) -> Customer | None:
        """
        Return the customer with this external ID if `auth_subject` can read it.

        `options` are forwarded to the statement's `.options()`.
        """
        statement = (
            self.get_readable_statement(auth_subject)
            .where(Customer.external_id == external_id)
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[Customer]]:
        """
        Build the base SELECT restricted to what `auth_subject` can read.

        A user is restricted to the organizations they have a non-deleted
        `UserOrganization` row for; an organization is restricted to its own
        customers. No restriction is added for any other kind of subject.
        """
        statement = self.get_base_statement()

        if is_user(auth_subject):
            user = auth_subject.subject
            statement = statement.where(
                Customer.organization_id.in_(
                    select(UserOrganization.organization_id).where(
                        UserOrganization.user_id == user.id,
                        UserOrganization.deleted_at.is_(None),
                    )
                )
            )
        elif is_organization(auth_subject):
            statement = statement.where(
                Customer.organization_id == auth_subject.subject.id,
            )

        return statement