Coverage for polar/customer_meter/service.py: 25%

97 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import uuid 1a

2from collections.abc import Sequence 1a

3from decimal import Decimal 1a

4 

5from sqlalchemy import Select, or_ 1a

6from sqlalchemy.orm import joinedload 1a

7from sqlalchemy.orm.strategy_options import contains_eager 1a

8 

9from polar.auth.models import AuthSubject, Organization, User 1a

10from polar.customer.repository import CustomerRepository 1a

11from polar.event.repository import EventRepository 1a

12from polar.kit.math import non_negative_running_sum 1a

13from polar.kit.pagination import PaginationParams 1a

14from polar.kit.sorting import Sorting 1a

15from polar.locker import Locker 1a

16from polar.meter.repository import MeterRepository 1a

17from polar.meter.service import meter as meter_service 1a

18from polar.models import Customer, CustomerMeter, Event, Meter 1a

19from polar.models.event import EventSource 1a

20from polar.models.webhook_endpoint import WebhookEventType 1a

21from polar.postgres import AsyncSession 1a

22from polar.worker import enqueue_job 1a

23 

24from .repository import CustomerMeterRepository 1a

25from .sorting import CustomerMeterSortProperty 1a

26 

27 

class CustomerMeterService:
    """Read and maintain ``CustomerMeter`` records.

    A customer meter stores, for one (customer, meter) pair, the consumed units,
    the credited units and the resulting balance. These are computed from the
    events in the meter's current window: events at or after the customer's
    latest meter reset event, or all matching events if there is none.
    """

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        customer_id: Sequence[uuid.UUID] | None = None,
        external_customer_id: Sequence[str] | None = None,
        meter_id: Sequence[uuid.UUID] | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[CustomerMeterSortProperty]] | None = None,
    ) -> tuple[Sequence[CustomerMeter], int]:
        """List customer meters readable by ``auth_subject``.

        Each filter argument, when not ``None``, restricts results to rows whose
        corresponding column is in the given sequence. Archived meters are
        excluded unless ``meter_id`` is given, in which case the requested
        meters are returned even if archived.

        Args:
            session: Database session.
            auth_subject: Subject whose read permissions scope the query.
            organization_id: Restrict to customers of these organizations.
            customer_id: Restrict to these customer IDs.
            external_customer_id: Restrict to these customer external IDs.
            meter_id: Restrict to these meter IDs (archived ones included).
            pagination: Page number and page size.
            sorting: Sort criteria. Defaults to ``modified_at`` with the
                ``True`` flag (presumably descending; verify against
                ``apply_sorting``).

        Returns:
            The tuple produced by ``repository.paginate``: the page of
            ``CustomerMeter`` rows (with ``meter`` eagerly loaded) and an
            integer count (presumably the total number of matches).
        """
        if sorting is None:
            # Built per call rather than as a mutable default argument, which
            # would be a single list object shared by every call.
            sorting = [(CustomerMeterSortProperty.modified_at, True)]

        repository = CustomerMeterRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .join(CustomerMeter.meter)
            # Reuse the explicit JOIN to populate ``CustomerMeter.meter``.
            .options(contains_eager(CustomerMeter.meter))
        )

        # NOTE(review): the Customer filters below assume get_readable_statement
        # already joins Customer; confirm in CustomerMeterRepository.
        if organization_id is not None:
            statement = statement.where(Customer.organization_id.in_(organization_id))

        if customer_id is not None:
            statement = statement.where(Customer.id.in_(customer_id))

        if external_customer_id is not None:
            statement = statement.where(Customer.external_id.in_(external_customer_id))

        if meter_id is not None:
            statement = statement.where(Meter.id.in_(meter_id))
        else:
            # Only filter archived meters when not querying for specific meter IDs
            statement = statement.where(Meter.archived_at.is_(None))

        statement = repository.apply_sorting(statement, sorting)

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> CustomerMeter | None:
        """Return the readable customer meter with this ``id``, or ``None``.

        The related ``meter`` is loaded in the same query via a joined load.
        """
        repository = CustomerMeterRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(CustomerMeter.id == id)
            .options(joinedload(CustomerMeter.meter))
        )
        return await repository.get_one_or_none(statement)

    async def update_customer(
        self, session: AsyncSession, locker: Locker, customer: Customer
    ) -> None:
        """Recompute every meter of the customer's organization for ``customer``.

        Meters are processed in ascending ``created_at`` order. If at least one
        customer meter changed, a ``customer.webhook`` job is enqueued with
        ``customer_state_changed``. In every case, the customer's
        meters-updated timestamp is refreshed through
        ``CustomerRepository.set_meters_updated_at``.
        """
        repository = MeterRepository.from_session(session)
        statement = (
            repository.get_base_statement()
            .where(Meter.organization_id == customer.organization_id)
            .order_by(Meter.created_at.asc())
        )

        updated = False
        async for meter in repository.stream(statement):
            _, meter_updated = await self.update_customer_meter(
                session, locker, customer, meter
            )
            updated = updated or meter_updated

        if updated:
            enqueue_job(
                "customer.webhook", WebhookEventType.customer_state_changed, customer.id
            )

        customer_repository = CustomerRepository.from_session(session)
        await customer_repository.set_meters_updated_at((customer,))

    async def update_customer_meter(
        self, session: AsyncSession, locker: Locker, customer: Customer, meter: Meter
    ) -> tuple[CustomerMeter | None, bool]:
        """Recompute the balance of one (customer, meter) pair.

        The work runs under a lock keyed on the customer and meter IDs.

        Returns:
            ``(customer_meter, updated)``. Possible outcomes:

            - No event in the current window: the existing record (possibly
              ``None``) is returned with ``False``, and no record is created.
            - The newest event is already the last balanced event: the record is
              returned with ``False``.
            - Otherwise, consumed units, credited units and balance are
              recomputed, and the saved record is returned with ``True``.
        """
        # NOTE(review): timeout/blocking_timeout follow Locker.lock's contract
        # (presumably lock TTL and maximum wait to acquire); confirm there.
        async with locker.lock(
            f"customer_meter:{customer.id}:{meter.id}",
            timeout=5.0,
            blocking_timeout=0.2,
        ):
            repository = CustomerMeterRepository.from_session(session)
            customer_meter = await repository.get_by_customer_and_meter(
                customer.id, meter.id
            )

            event_repository = EventRepository.from_session(session)
            events_statement = await self._get_current_window_events_statement(
                session, customer, meter
            )
            last_event = await event_repository.get_one_or_none(
                self._last_event_statement(events_statement)
            )

            if last_event is None:
                return customer_meter, False

            if customer_meter is None:
                customer_meter = await repository.create(
                    CustomerMeter(customer=customer, meter=meter)
                )

            # Nothing new was ingested since the last computation.
            if customer_meter.last_balanced_event_id == last_event.id:
                return customer_meter, False

            usage_units = await meter_service.get_quantity(
                session, meter, self._usage_events_statement(events_statement)
            )
            customer_meter.consumed_units = Decimal(usage_units)

            credit_events = await event_repository.get_all(
                self._credit_events_statement(events_statement)
            )
            # Accumulated with non_negative_running_sum (see polar.kit.math);
            # presumably clamps the running total at zero at each step.
            credited_units = non_negative_running_sum(
                event.user_metadata["units"] for event in credit_events
            )
            customer_meter.credited_units = credited_units
            customer_meter.balance = (
                customer_meter.credited_units - customer_meter.consumed_units
            )
            customer_meter.last_balanced_event = last_event

            return await repository.update(customer_meter), True

    async def get_rollover_units(
        self, session: AsyncSession, customer: Customer, meter: Meter
    ) -> int:
        """Return how many units of the current window can be rolled over.

        Credit events are split by their ``user_metadata["rollover"]`` flag, and
        each group is accumulated with ``non_negative_running_sum``. The
        remaining balance (all credits minus usage) is converted with ``int()``,
        which truncates toward zero. It is then capped at the rollover credit
        total and floored at zero. Returns 0 when the window has no events.
        """
        event_repository = EventRepository.from_session(session)
        events_statement = await self._get_current_window_events_statement(
            session, customer, meter
        )
        last_event = await event_repository.get_one_or_none(
            self._last_event_statement(events_statement)
        )

        if last_event is None:
            return 0

        usage_units = await meter_service.get_quantity(
            session, meter, self._usage_events_statement(events_statement)
        )

        credit_events = await event_repository.get_all(
            self._credit_events_statement(events_statement)
        )
        non_rollover_units = non_negative_running_sum(
            event.user_metadata["units"]
            for event in credit_events
            if not event.user_metadata["rollover"]
        )
        rollover_units = non_negative_running_sum(
            event.user_metadata["units"]
            for event in credit_events
            if event.user_metadata["rollover"]
        )
        balance = non_rollover_units + rollover_units - usage_units

        return max(0, min(int(balance), rollover_units))

    async def _get_current_window_events_statement(
        self, session: AsyncSession, customer: Customer, meter: Meter
    ) -> Select[tuple[Event]]:
        """Build the query for events in the meter's current window for ``customer``.

        Selected events belong to the meter's organization and to ``customer``,
        and match either the meter's own clause or its system-event clause. They
        are ordered by ``ingested_at`` ascending. If a latest meter reset event
        exists, only events ingested at or after it are kept.
        """
        event_repository = EventRepository.from_session(session)
        meter_reset_event = await event_repository.get_latest_meter_reset(
            customer, meter.id
        )
        statement = (
            event_repository.get_base_statement()
            .where(
                Event.organization_id == meter.organization_id,
                Event.customer == customer,
                or_(
                    # Events matching meter definitions
                    event_repository.get_meter_clause(meter),
                    # System events impacting the meter balance
                    event_repository.get_meter_system_clause(meter),
                ),
            )
            .order_by(Event.ingested_at.asc())
        )
        if meter_reset_event is not None:
            statement = statement.where(
                Event.ingested_at >= meter_reset_event.ingested_at
            )

        return statement

    # NOTE: the helpers below must not use ``list[...]`` annotations. Inside
    # this class body, ``list`` now refers to the ``list`` method above.

    @staticmethod
    def _last_event_statement(
        events_statement: Select[tuple[Event]],
    ) -> Select[tuple[Event]]:
        """Select only the most recently ingested event of ``events_statement``."""
        # order_by(None) drops the ascending ordering from the window query.
        return events_statement.order_by(None).order_by(Event.ingested_at.desc()).limit(1)

    @staticmethod
    def _usage_events_statement(
        events_statement: Select[tuple[Event]],
    ) -> Select[tuple[uuid.UUID]]:
        """Select the IDs of user-sourced events of ``events_statement``."""
        return events_statement.with_only_columns(Event.id).where(
            Event.source == EventSource.user
        )

    @staticmethod
    def _credit_events_statement(
        events_statement: Select[tuple[Event]],
    ) -> Select[tuple[Event]]:
        """Restrict ``events_statement`` to meter-credit events."""
        return events_statement.where(Event.is_meter_credit.is_(True))

233 

234 

# Module-level instance of the service; the class defines no __init__ and no
# per-instance attributes, so a single shared instance suffices.
customer_meter: CustomerMeterService = CustomerMeterService()