Coverage for polar/customer/service.py: 23%

123 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import uuid 1a

2from collections.abc import Sequence 1a

3from typing import Any 1a

4 

5from sqlalchemy import UnaryExpression, asc, desc, func, or_ 1a

6from sqlalchemy.orm import joinedload 1a

7from sqlalchemy_utils.types.range import timedelta 1a

8 

9from polar.auth.models import AuthSubject 1a

10from polar.benefit.grant.repository import BenefitGrantRepository 1a

11from polar.customer_meter.repository import CustomerMeterRepository 1a

12from polar.exceptions import PolarRequestValidationError, ValidationError 1a

13from polar.kit.metadata import MetadataQuery, apply_metadata_clause 1a

14from polar.kit.pagination import PaginationParams 1a

15from polar.kit.sorting import Sorting 1a

16from polar.member import member_service 1a

17from polar.member.schemas import Member as MemberSchema 1a

18from polar.models import BenefitGrant, Customer, Organization, User 1a

19from polar.models.webhook_endpoint import CustomerWebhookEventType, WebhookEventType 1a

20from polar.organization.resolver import get_payload_organization 1a

21from polar.postgres import AsyncReadSession, AsyncSession 1a

22from polar.redis import Redis 1a

23from polar.subscription.repository import SubscriptionRepository 1a

24from polar.webhook.service import webhook as webhook_service 1a

25from polar.worker import enqueue_job 1a

26 

27from .repository import CustomerRepository 1a

28from .schemas.customer import ( 1a

29 CustomerCreate, 

30 CustomerUpdate, 

31 CustomerUpdateExternalID, 

32) 

33from .schemas.state import CustomerState 1a

34from .sorting import CustomerSortProperty 1a

35 

36 

class CustomerService:
    """Service layer for customer listing, CRUD, state snapshots and webhooks."""

    # Escape character used when embedding free-text search input in LIKE
    # patterns. "/" avoids any backslash-quoting differences between SQL
    # string-literal modes when SQLAlchemy renders the ESCAPE clause.
    _LIKE_ESCAPE_CHAR = "/"

    @staticmethod
    def _escape_like(value: str, escape_char: str = _LIKE_ESCAPE_CHAR) -> str:
        """Escape LIKE/ILIKE wildcards in *value* so it matches literally.

        The escape character itself is doubled first, then ``%`` and ``_`` are
        prefixed with it. The resulting pattern must be used together with an
        ``ESCAPE`` clause naming the same character.
        """
        return (
            value.replace(escape_char, escape_char * 2)
            .replace("%", escape_char + "%")
            .replace("_", escape_char + "_")
        )

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        email: str | None = None,
        metadata: MetadataQuery | None = None,
        query: str | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[CustomerSortProperty]] | None = None,
    ) -> tuple[Sequence[Customer], int]:
        """List customers visible to ``auth_subject`` with filters and sorting.

        Args:
            session: Read session used for the query.
            auth_subject: Caller; the base statement comes from
                ``CustomerRepository.get_readable_statement(auth_subject)``.
            organization_id: If given, keep customers of these organizations.
            email: If given, exact email match, compared case-insensitively.
            metadata: If given, applied via ``apply_metadata_clause``.
            query: Free-text search. Matches a substring of email or name, or
                a prefix of external ID, case-insensitively. LIKE wildcards in
                the text are escaped, so they match literally.
            pagination: Page size and page number passed to the repository.
            sorting: ``(property, is_descending)`` pairs applied in order.
                Defaults to ``created_at`` descending. Unrecognized
                properties are ignored.

        Returns:
            Whatever ``repository.paginate`` returns for the statement
            (a sequence of customers and an int, presumably the total count).
        """
        if sorting is None:
            # A None default avoids sharing one mutable list across calls.
            sorting = [(CustomerSortProperty.created_at, True)]

        repository = CustomerRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject)

        if organization_id is not None:
            statement = statement.where(Customer.organization_id.in_(organization_id))

        if email is not None:
            statement = statement.where(func.lower(Customer.email) == email.lower())

        if metadata is not None:
            statement = apply_metadata_clause(Customer, statement, metadata)

        if query is not None:
            # Escape user input so "%" / "_" are searched for literally
            # instead of acting as wildcards.
            escape_char = self._LIKE_ESCAPE_CHAR
            term = self._escape_like(query, escape_char)
            statement = statement.where(
                or_(
                    Customer.email.ilike(f"%{term}%", escape=escape_char),
                    Customer.name.ilike(f"%{term}%", escape=escape_char),
                    Customer.external_id.ilike(f"{term}%", escape=escape_char),
                )
            )

        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            clause_function = desc if is_desc else asc
            if criterion == CustomerSortProperty.created_at:
                order_by_clauses.append(clause_function(Customer.created_at))
            elif criterion == CustomerSortProperty.email:
                order_by_clauses.append(clause_function(Customer.email))
            elif criterion == CustomerSortProperty.customer_name:
                order_by_clauses.append(clause_function(Customer.name))
        statement = statement.order_by(*order_by_clauses)

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Customer | None:
        """Return the customer with primary key ``id``, or None.

        The lookup is restricted by
        ``CustomerRepository.get_readable_statement(auth_subject)``.
        """
        repository = CustomerRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject).where(
            Customer.id == id
        )
        return await repository.get_one_or_none(statement)

    async def get_external(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        external_id: str,
    ) -> Customer | None:
        """Return the customer whose ``external_id`` equals ``external_id``, or None.

        The lookup is restricted by
        ``CustomerRepository.get_readable_statement(auth_subject)``.
        """
        repository = CustomerRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject).where(
            Customer.external_id == external_id
        )
        return await repository.get_one_or_none(statement)

    async def create(
        self,
        session: AsyncSession,
        customer_create: CustomerCreate,
        auth_subject: AuthSubject[User | Organization],
    ) -> Customer:
        """Create a customer in the payload's organization, plus its owner member.

        The organization is resolved with ``get_payload_organization``. The
        owner member is created through ``member_service.create_owner_member``,
        using the optional ``customer_create.owner`` details.

        Raises:
            PolarRequestValidationError: if the repository finds an existing
                customer in the organization with the same email and/or the
                same external ID. All such errors are reported together.
        """
        organization = await get_payload_organization(
            session, auth_subject, customer_create
        )
        repository = CustomerRepository.from_session(session)

        errors: list[ValidationError] = []

        if await repository.get_by_email_and_organization(
            customer_create.email, organization.id
        ):
            errors.append(
                {
                    "type": "value_error",
                    "loc": ("body", "email"),
                    "msg": "A customer with this email address already exists.",
                    "input": customer_create.email,
                }
            )

        if customer_create.external_id is not None:
            if await repository.get_by_external_id_and_organization(
                customer_create.external_id, organization.id
            ):
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "external_id"),
                        "msg": "A customer with this external ID already exists.",
                        "input": customer_create.external_id,
                    }
                )

        if errors:
            raise PolarRequestValidationError(errors)

        async with repository.create_context(
            Customer(
                organization=organization,
                **customer_create.model_dump(
                    exclude={"organization_id", "owner"}, by_alias=True
                ),
            )
        ) as customer:
            owner = customer_create.owner
            await member_service.create_owner_member(
                session,
                customer,
                organization,
                owner_email=owner.email if owner else None,
                owner_name=owner.name if owner else None,
                owner_external_id=owner.external_id if owner else None,
            )
        return customer

    async def update(
        self,
        session: AsyncSession,
        customer: Customer,
        customer_update: CustomerUpdate | CustomerUpdateExternalID,
    ) -> Customer:
        """Validate ``customer_update`` and apply it to ``customer``.

        Email handling:
            A new email is compared case-insensitively with the current one.
            If it differs, it is written and ``email_verified`` is reset to
            False. A change in letter case only is treated as "unchanged" and
            is not written, because ``email`` is excluded from the generic
            update dict.

        External ID handling (``CustomerUpdate`` only):
            An already-set external ID cannot be changed. A new non-None
            external ID must be unique within the organization.

        Raises:
            PolarRequestValidationError: if any of the checks above fail.
                All errors are collected and raised together, before
                ``customer`` is modified.
        """
        repository = CustomerRepository.from_session(session)
        errors: list[ValidationError] = []

        # Run every check before mutating `customer`. Assigning the new email
        # first would leave the instance dirty when validation fails. It could
        # also let the external-ID lookup below autoflush the (possibly
        # duplicate) email to the database, if the session autoflushes.
        new_email: str | None = None
        if (
            customer_update.email is not None
            and customer.email.lower() != customer_update.email.lower()
        ):
            new_email = customer_update.email
            if await repository.get_by_email_and_organization(
                new_email, customer.organization_id
            ):
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "email"),
                        "msg": "A customer with this email address already exists.",
                        "input": customer_update.email,
                    }
                )

        if isinstance(customer_update, CustomerUpdate):
            external_id_is_locked = (
                "external_id" in customer_update.model_fields_set
                and customer.external_id is not None
                and customer.external_id != customer_update.external_id
            )
            if external_id_is_locked:
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "external_id"),
                        "msg": "Customer external ID cannot be updated.",
                        "input": customer_update.external_id,
                    }
                )
            # Uniqueness only matters when the change is allowed at all; this
            # avoids a redundant query and a second error on the same field.
            elif (
                customer_update.external_id is not None
                and customer.external_id != customer_update.external_id
                and await repository.get_by_external_id_and_organization(
                    customer_update.external_id, customer.organization_id
                )
            ):
                errors.append(
                    {
                        "type": "value_error",
                        "loc": ("body", "external_id"),
                        "msg": "A customer with this external ID already exists.",
                        "input": customer_update.external_id,
                    }
                )

        if errors:
            raise PolarRequestValidationError(errors)

        # Validation passed: now it is safe to touch the ORM instance.
        if new_email is not None:
            customer.email = new_email
            customer.email_verified = False

        return await repository.update(
            customer,
            update_dict=customer_update.model_dump(
                exclude={"email"}, exclude_unset=True, by_alias=True
            ),
        )

    async def delete(self, session: AsyncSession, customer: Customer) -> Customer:
        """Enqueue cleanup jobs for ``customer``, then soft-delete it.

        Enqueues the ``subscription.cancel_customer`` and
        ``benefit.revoke_customer`` jobs, then returns the result of
        ``CustomerRepository.soft_delete``.
        """
        enqueue_job("subscription.cancel_customer", customer_id=customer.id)
        enqueue_job("benefit.revoke_customer", customer_id=customer.id)

        repository = CustomerRepository.from_session(session)
        return await repository.soft_delete(customer)

    async def get_state(
        self,
        session: AsyncReadSession,
        redis: Redis,
        customer: Customer,
        cache: bool = True,
    ) -> CustomerState:
        """Build the customer's ``CustomerState``, using Redis as a one-hour cache.

        If ``cache`` is True and a cached value exists, it is returned as-is.
        Otherwise the following are loaded onto ``customer``:
            - active subscriptions;
            - granted benefits (with each grant's benefit joined);
            - meters.
        The state is then validated into ``CustomerState`` and written to
        Redis with a one-hour expiry. The write also happens when ``cache``
        is False, so that call refreshes the cache.
        """
        # 👋 Whenever you change the state schema,
        # please also update the cache key with a version number.
        cache_key = f"polar:customer_state:v3:{customer.id}"

        if cache:
            raw_state = await redis.get(cache_key)
            if raw_state is not None:
                return CustomerState.model_validate_json(raw_state)

        subscription_repository = SubscriptionRepository.from_session(session)
        customer.active_subscriptions = (
            await subscription_repository.list_active_by_customer(customer.id)
        )

        benefit_grant_repository = BenefitGrantRepository.from_session(session)
        customer.granted_benefits = (
            await benefit_grant_repository.list_granted_by_customer(
                customer.id, options=(joinedload(BenefitGrant.benefit),)
            )
        )

        customer_meter_repository = CustomerMeterRepository.from_session(session)
        customer.active_meters = await customer_meter_repository.get_all_by_customer(
            customer.id
        )

        state = CustomerState.model_validate(customer)

        # NOTE(review): `timedelta` is imported at module top from
        # `sqlalchemy_utils.types.range`, presumably a re-export of
        # `datetime.timedelta`; consider importing it from `datetime` directly.
        await redis.set(
            cache_key,
            state.model_dump_json(),
            ex=int(timedelta(hours=1).total_seconds()),
        )

        return state

    async def webhook(
        self,
        session: AsyncSession,
        redis: Redis,
        event_type: CustomerWebhookEventType,
        customer: Customer,
    ) -> None:
        """Send a customer webhook to the customer's organization.

        ``customer_state_changed`` events carry a freshly computed
        ``CustomerState`` (the cache read is skipped). Every other event
        carries the customer itself. ``customer_created``,
        ``customer_updated`` and ``customer_deleted`` additionally trigger a
        ``customer_state_changed`` webhook.
        """
        if event_type == WebhookEventType.customer_state_changed:
            data = await self.get_state(session, redis, customer, cache=False)
            await webhook_service.send(
                session,
                customer.organization,
                WebhookEventType.customer_state_changed,
                data,
            )
        else:
            await webhook_service.send(
                session, customer.organization, event_type, customer
            )

        # For created, updated and deleted events, also trigger a state changed event
        if event_type in (
            WebhookEventType.customer_created,
            WebhookEventType.customer_updated,
            WebhookEventType.customer_deleted,
        ):
            await self.webhook(
                session, redis, WebhookEventType.customer_state_changed, customer
            )

    async def load_members(
        self,
        session: AsyncReadSession,
        customer_id: uuid.UUID,
    ) -> Sequence[MemberSchema]:
        """Return the customer's members, as listed by ``member_service``, as schemas."""
        members = await member_service.list_by_customer(session, customer_id)
        return [MemberSchema.model_validate(member) for member in members]

335 

336 

# Module-level service instance shared by importers. CustomerService defines
# no __init__ and stores no attributes on self, so a single instance is enough.
customer: CustomerService = CustomerService()