Coverage for polar/customer_meter/service.py: 25%
97 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import uuid 1a
2from collections.abc import Sequence 1a
3from decimal import Decimal 1a
5from sqlalchemy import Select, or_ 1a
6from sqlalchemy.orm import joinedload 1a
7from sqlalchemy.orm.strategy_options import contains_eager 1a
9from polar.auth.models import AuthSubject, Organization, User 1a
10from polar.customer.repository import CustomerRepository 1a
11from polar.event.repository import EventRepository 1a
12from polar.kit.math import non_negative_running_sum 1a
13from polar.kit.pagination import PaginationParams 1a
14from polar.kit.sorting import Sorting 1a
15from polar.locker import Locker 1a
16from polar.meter.repository import MeterRepository 1a
17from polar.meter.service import meter as meter_service 1a
18from polar.models import Customer, CustomerMeter, Event, Meter 1a
19from polar.models.event import EventSource 1a
20from polar.models.webhook_endpoint import WebhookEventType 1a
21from polar.postgres import AsyncSession 1a
22from polar.worker import enqueue_job 1a
24from .repository import CustomerMeterRepository 1a
25from .sorting import CustomerMeterSortProperty 1a
class CustomerMeterService:
    """Query ``CustomerMeter`` rows and recompute their consumed/credited balances."""

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        customer_id: Sequence[uuid.UUID] | None = None,
        external_customer_id: Sequence[str] | None = None,
        meter_id: Sequence[uuid.UUID] | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[CustomerMeterSortProperty]] | None = None,
    ) -> tuple[Sequence[CustomerMeter], int]:
        """List customer meters readable by ``auth_subject``, paginated.

        Args:
            session: Database session.
            auth_subject: Subject whose read access scopes the query.
            organization_id: Keep only customers in these organizations.
            customer_id: Keep only these customer IDs.
            external_customer_id: Keep only customers with these external IDs.
            meter_id: Keep only these meter IDs. When given, archived meters
                are NOT filtered out; when omitted, archived meters are
                excluded.
            pagination: Page size (``limit``) and page number (``page``).
            sorting: Sort criteria. ``None`` (the default) means
                ``[(CustomerMeterSortProperty.modified_at, True)]``
                (the ``True`` flag presumably means descending — TODO confirm
                against ``Sorting``).

        Returns:
            The result of ``repository.paginate``: the customer meters on the
            requested page (with ``meter`` populated from the join) and an
            integer, presumably the total matching count.
        """
        # Build the default per call; a list literal as the parameter default
        # would be a single object shared across every invocation.
        if sorting is None:
            sorting = [(CustomerMeterSortProperty.modified_at, True)]

        repository = CustomerMeterRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .join(CustomerMeter.meter)
            # Reuse the explicit join above to populate CustomerMeter.meter.
            .options(contains_eager(CustomerMeter.meter))
        )

        # NOTE(review): the Customer.* filters assume get_readable_statement
        # already joins Customer — not visible here, confirm in the repository.
        if organization_id is not None:
            statement = statement.where(Customer.organization_id.in_(organization_id))

        if customer_id is not None:
            statement = statement.where(Customer.id.in_(customer_id))

        if external_customer_id is not None:
            statement = statement.where(Customer.external_id.in_(external_customer_id))

        if meter_id is not None:
            statement = statement.where(Meter.id.in_(meter_id))
        else:
            # Only filter archived meters when not querying for specific meter IDs
            statement = statement.where(Meter.archived_at.is_(None))

        statement = repository.apply_sorting(statement, sorting)

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> CustomerMeter | None:
        """Return the customer meter ``id`` if ``auth_subject`` can read it.

        The related ``meter`` is eagerly loaded with a joined load.

        Returns:
            The ``CustomerMeter``, or ``None`` if no readable row matches.
        """
        repository = CustomerMeterRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(CustomerMeter.id == id)
            .options(joinedload(CustomerMeter.meter))
        )
        return await repository.get_one_or_none(statement)

    async def update_customer(
        self, session: AsyncSession, locker: Locker, customer: Customer
    ) -> None:
        """Recompute every meter of the customer's organization for ``customer``.

        Meters are processed in ascending ``created_at`` order. If at least one
        customer meter changed, a ``customer.webhook`` job is enqueued with
        ``WebhookEventType.customer_state_changed``. ``set_meters_updated_at`` is
        called for the customer in every case.
        """
        repository = MeterRepository.from_session(session)
        statement = (
            repository.get_base_statement()
            .where(Meter.organization_id == customer.organization_id)
            .order_by(Meter.created_at.asc())
        )

        updated = False
        async for meter in repository.stream(statement):
            # Always recompute; `updated` only records whether any meter changed.
            _, meter_updated = await self.update_customer_meter(
                session, locker, customer, meter
            )
            updated = updated or meter_updated

        if updated:
            enqueue_job(
                "customer.webhook", WebhookEventType.customer_state_changed, customer.id
            )

        customer_repository = CustomerRepository.from_session(session)
        await customer_repository.set_meters_updated_at((customer,))

    async def update_customer_meter(
        self, session: AsyncSession, locker: Locker, customer: Customer, meter: Meter
    ) -> tuple[CustomerMeter | None, bool]:
        """Recompute the balance of one (customer, meter) pair.

        The work runs while holding the lock ``customer_meter:{customer}:{meter}``
        (``timeout=5.0``, ``blocking_timeout=0.2``). What happens when the lock
        cannot be acquired is decided by ``Locker`` — presumably it raises;
        TODO confirm.

        Returns:
            ``(customer_meter, changed)``:

            - ``(existing_or_None, False)`` when the current window has no events.
              No row is created in this case.
            - ``(customer_meter, False)`` when the row is already balanced up to
              the latest event.
            - ``(updated_customer_meter, True)`` after recomputing
              ``consumed_units``, ``credited_units`` and ``balance``.
        """
        async with locker.lock(
            f"customer_meter:{customer.id}:{meter.id}",
            timeout=5.0,
            blocking_timeout=0.2,
        ):
            repository = CustomerMeterRepository.from_session(session)
            customer_meter = await repository.get_by_customer_and_meter(
                customer.id, meter.id
            )

            event_repository = EventRepository.from_session(session)
            events_statement = await self._get_current_window_events_statement(
                session, customer, meter
            )
            last_event = await self._get_last_event(event_repository, events_statement)

            if last_event is None:
                return customer_meter, False

            if customer_meter is None:
                customer_meter = await repository.create(
                    CustomerMeter(customer=customer, meter=meter)
                )

            # Nothing new was ingested since the last recomputation.
            if customer_meter.last_balanced_event_id == last_event.id:
                return customer_meter, False

            # Consumption counts only user-sourced events in the window.
            usage_events_statement = events_statement.with_only_columns(Event.id).where(
                Event.source == EventSource.user
            )
            usage_units = await meter_service.get_quantity(
                session, meter, usage_events_statement
            )
            customer_meter.consumed_units = Decimal(usage_units)

            credit_events_statement = events_statement.where(
                Event.is_meter_credit.is_(True)
            )
            credit_events = await event_repository.get_all(credit_events_statement)
            # Credit events are summed in ingestion order via
            # non_negative_running_sum.
            credited_units = non_negative_running_sum(
                event.user_metadata["units"] for event in credit_events
            )
            customer_meter.credited_units = credited_units
            customer_meter.balance = (
                customer_meter.credited_units - customer_meter.consumed_units
            )
            customer_meter.last_balanced_event = last_event

            return await repository.update(customer_meter), True

    async def get_rollover_units(
        self, session: AsyncSession, customer: Customer, meter: Meter
    ) -> int:
        """Return how many credited units may roll over in the current window.

        Credit events are split by their ``user_metadata["rollover"]`` flag.
        The result is the remaining balance
        (``non_rollover + rollover - usage``), truncated with ``int()`` and
        clamped to ``[0, rollover_units]``. It is ``0`` when the window has no
        events.
        """
        event_repository = EventRepository.from_session(session)
        events_statement = await self._get_current_window_events_statement(
            session, customer, meter
        )
        last_event = await self._get_last_event(event_repository, events_statement)

        if last_event is None:
            return 0

        usage_events_statement = events_statement.with_only_columns(Event.id).where(
            Event.source == EventSource.user
        )
        usage_units = await meter_service.get_quantity(
            session, meter, usage_events_statement
        )

        credit_events_statement = events_statement.where(
            Event.is_meter_credit.is_(True)
        )
        credit_events = await event_repository.get_all(credit_events_statement)
        non_rollover_units = non_negative_running_sum(
            event.user_metadata["units"]
            for event in credit_events
            if not event.user_metadata["rollover"]
        )
        rollover_units = non_negative_running_sum(
            event.user_metadata["units"]
            for event in credit_events
            if event.user_metadata["rollover"]
        )
        balance = non_rollover_units + rollover_units - usage_units

        # Only the rollover portion of what is left can carry over, never
        # less than zero.
        return max(0, min(int(balance), rollover_units))

    async def _get_last_event(
        self,
        event_repository: EventRepository,
        events_statement: Select[tuple[Event]],
    ) -> Event | None:
        """Return the most recently ingested event matched by ``events_statement``."""
        # Drop the window's ascending ordering, then take the newest row.
        return await event_repository.get_one_or_none(
            events_statement.order_by(None)
            .order_by(Event.ingested_at.desc())
            .limit(1)
        )

    async def _get_current_window_events_statement(
        self, session: AsyncSession, customer: Customer, meter: Meter
    ) -> Select[tuple[Event]]:
        """Build the query for events in the meter's current window for ``customer``.

        The query selects events of the meter's organization and this customer
        that match the meter's clause or its system clause, ordered by
        ``ingested_at`` ascending. When a meter reset event exists, only events
        ingested at or after that reset are included.
        """
        event_repository = EventRepository.from_session(session)
        meter_reset_event = await event_repository.get_latest_meter_reset(
            customer, meter.id
        )
        statement = (
            event_repository.get_base_statement()
            .where(
                Event.organization_id == meter.organization_id,
                Event.customer == customer,
                or_(
                    # Events matching meter definitions
                    event_repository.get_meter_clause(meter),
                    # System events impacting the meter balance
                    event_repository.get_meter_system_clause(meter),
                ),
            )
            .order_by(Event.ingested_at.asc())
        )
        if meter_reset_event is not None:
            # `>=` keeps the reset event itself inside the window.
            statement = statement.where(
                Event.ingested_at >= meter_reset_event.ingested_at
            )

        return statement
# Shared module-level instance of the customer-meter service.
customer_meter: CustomerMeterService = CustomerMeterService()