Coverage for polar/customer/repository.py: 23%

112 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import contextlib 1a

2from collections.abc import AsyncGenerator, Iterable, Sequence 1a

3from typing import Any 1a

4from uuid import UUID 1a

5 

6from sqlalchemy import Select, func, select, update 1a

7from sqlalchemy import inspect as orm_inspect 1a

8from sqlalchemy.orm import InstanceState 1a

9 

10from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a

11from polar.event.system import CustomerUpdatedFields, SystemEvent 1a

12from polar.kit.repository import ( 1a

13 Options, 

14 RepositoryBase, 

15 RepositorySoftDeletionIDMixin, 

16 RepositorySoftDeletionMixin, 

17) 

18from polar.kit.utils import utc_now 1a

19from polar.models import Customer, UserOrganization 1a

20from polar.models.webhook_endpoint import WebhookEventType 1a

21from polar.worker import enqueue_job 1a

22 

23 

def _get_changed_value(
    inspection: InstanceState[Customer], attr_name: str
) -> tuple[bool, Any]:
    """
    Check if attribute changed and return (has_changed, new_value).

    Returns (False, None) when the attribute has no pending change in its
    SQLAlchemy history, or when the recorded previous value equals the new one.

    When the previous value was never loaded (e.g. the attribute was expired
    before being set), SQLAlchemy records no "deleted" entry. There is then
    nothing to compare against, so the change is reported. Otherwise setting
    the field to ``None`` would be indistinguishable from "unchanged".
    """
    history = inspection.attrs[attr_name].history

    if not history.has_changes():
        return (False, None)

    new_value = history.added[0] if history.added else None

    if not history.deleted:
        # Previous value unknown: don't mistake a `None` new value for "no change".
        return (True, new_value)

    if history.deleted[0] == new_value:
        return (False, None)

    return (True, new_value)

44 

45 

class CustomerRepository(
    RepositorySoftDeletionIDMixin[Customer, UUID],
    RepositorySoftDeletionMixin[Customer],
    RepositoryBase[Customer],
):
    """
    Repository for `Customer`.

    The write methods (`create_context`, `update`, `soft_delete`) enqueue the
    `customer.webhook` / `customer.event` background jobs that mirror customer
    lifecycle changes. The read methods scope queries either to an
    organization id or to what an auth subject is allowed to see.
    """

    model = Customer

    async def create(self, object: Customer, *, flush: bool = False) -> Customer:
        """
        Create `object` and guarantee it has an `id` on return.

        Without a flush the database has not assigned the id yet, so it is
        generated from the `id` column's Python-side default. Jobs for the new
        customer are keyed by this id.
        """
        customer = await super().create(object, flush=flush)

        # We need the id to enqueue the job
        if customer.id is None:
            customer.id = Customer.__table__.c.id.default.arg(None)

        return customer

    @contextlib.asynccontextmanager
    async def create_context(
        self, object: Customer, *, flush: bool = False
    ) -> AsyncGenerator[Customer]:
        """
        Create a customer, yield it, and enqueue the creation jobs on exit.

        The caller may keep working with the customer inside the `async with`
        block. If that block raises, the exception propagates out of the
        `yield` and no job is enqueued.
        """
        customer = await self.create(object, flush=flush)
        yield customer
        assert customer.id is not None, "Customer.id is None"

        # If the customer has an external_id, enqueue a meter update job
        # to create meters for any pre-existing events with that external_id.
        if customer.external_id is not None:
            enqueue_job("customer_meter.update_customer", customer.id)

        enqueue_job("customer.webhook", WebhookEventType.customer_created, customer.id)
        enqueue_job("customer.event", customer.id, SystemEvent.customer_created)

    async def update(
        self,
        object: Customer,
        *,
        update_dict: dict[str, Any] | None = None,
        flush: bool = False,
    ) -> Customer:
        """
        Update `object` and enqueue the `customer_updated` webhook and event.

        The event is skipped when the customer is (being) soft-deleted. Its
        payload holds the fields whose values actually changed.
        """
        inspection = orm_inspect(object)

        # Apply the changes but defer the flush: a flush resets the attribute
        # history that the changed fields are read from. With the original
        # `flush=flush` call, events were always empty when `flush=True`.
        # NOTE(review): assumes the base `update(flush=True)` amounts to a
        # `session.flush()` after applying changes — confirm in RepositoryBase.
        customer = await super().update(object, update_dict=update_dict, flush=False)

        # Only create an event if the customer is not being deleted
        updated_fields: CustomerUpdatedFields | None = None
        if not customer.deleted_at:
            updated_fields = self._get_updated_fields(inspection)

        if flush:
            await self.session.flush()

        # Jobs are enqueued after the flush, as before, so a failing flush
        # still prevents them from being enqueued.
        enqueue_job("customer.webhook", WebhookEventType.customer_updated, customer.id)
        if updated_fields is not None:
            enqueue_job(
                "customer.event",
                customer.id,
                SystemEvent.customer_updated,
                updated_fields,
            )

        return customer

    @staticmethod
    def _get_updated_fields(
        inspection: InstanceState[Customer],
    ) -> CustomerUpdatedFields:
        """Collect the pending field changes reported in a `customer_updated` event."""
        updated_fields: CustomerUpdatedFields = {}

        changed, value = _get_changed_value(inspection, "name")
        if changed:
            updated_fields["name"] = value

        changed, value = _get_changed_value(inspection, "email")
        if changed:
            updated_fields["email"] = value

        changed, value = _get_changed_value(inspection, "billing_address")
        if changed:
            updated_fields["billing_address"] = value.to_dict() if value else None

        changed, value = _get_changed_value(inspection, "tax_id")
        if changed:
            # Only the first element of the stored tax id value is reported.
            updated_fields["tax_id"] = value[0] if value else None

        changed, value = _get_changed_value(inspection, "user_metadata")
        if changed:
            updated_fields["metadata"] = value

        return updated_fields

    async def soft_delete(self, object: Customer, *, flush: bool = False) -> Customer:
        """
        Soft-delete `object` and enqueue the `customer_deleted` webhook and event.

        A set `external_id` is cleared so it can be recycled by a future
        customer. Its old value is kept in `user_metadata["__external_id"]`
        for support debugging.
        """
        customer = await super().soft_delete(object, flush=flush)

        if customer.external_id:
            # Assign a *new* dict. Mutating the loaded dict in place and
            # re-assigning the same object compares equal to the committed
            # value, so SQLAlchemy would skip the UPDATE for this column
            # (unless the column uses mutation tracking).
            customer.user_metadata = {
                **customer.user_metadata,
                "__external_id": customer.external_id,
            }
            customer.external_id = None

        enqueue_job("customer.webhook", WebhookEventType.customer_deleted, customer.id)
        enqueue_job("customer.event", customer.id, SystemEvent.customer_deleted)

        return customer

    async def touch_meters(self, customers: Iterable[Customer]) -> None:
        """Set `meters_dirtied_at` to now for the given customers."""
        await self._set_timestamp_now(customers, "meters_dirtied_at")

    async def set_meters_updated_at(self, customers: Iterable[Customer]) -> None:
        """Set `meters_updated_at` to now for the given customers."""
        await self._set_timestamp_now(customers, "meters_updated_at")

    async def _set_timestamp_now(
        self, customers: Iterable[Customer], column: str
    ) -> None:
        """Bulk-set `column` to `utc_now()` for the customers' rows."""
        customer_ids = [customer.id for customer in customers]
        if not customer_ids:
            # Nothing to update; skip the database round-trip.
            return
        statement = (
            update(Customer)
            .where(Customer.id.in_(customer_ids))
            .values({column: utc_now()})
        )
        await self.session.execute(statement)

    async def get_by_id_and_organization(
        self, id: UUID, organization_id: UUID
    ) -> Customer | None:
        """Return the customer with `id` in `organization_id`, or None."""
        statement = self.get_base_statement().where(
            Customer.id == id, Customer.organization_id == organization_id
        )
        return await self.get_one_or_none(statement)

    async def get_by_email_and_organization(
        self, email: str, organization_id: UUID
    ) -> Customer | None:
        """Return the customer in `organization_id` whose email matches case-insensitively."""
        statement = self.get_base_statement().where(
            func.lower(Customer.email) == email.lower(),
            Customer.organization_id == organization_id,
        )
        return await self.get_one_or_none(statement)

    async def get_by_external_id_and_organization(
        self, external_id: str, organization_id: UUID
    ) -> Customer | None:
        """Return the customer with `external_id` in `organization_id`, or None."""
        statement = self.get_base_statement().where(
            Customer.external_id == external_id,
            Customer.organization_id == organization_id,
        )
        return await self.get_one_or_none(statement)

    async def get_by_stripe_customer_id_and_organization(
        self, stripe_customer_id: str, organization_id: UUID
    ) -> Customer | None:
        """Return the customer with `stripe_customer_id` in `organization_id`, or None."""
        statement = self.get_base_statement().where(
            Customer.stripe_customer_id == stripe_customer_id,
            Customer.organization_id == organization_id,
        )
        return await self.get_one_or_none(statement)

    async def stream_by_organization(
        self,
        auth_subject: AuthSubject[User | Organization],
        organization_id: Sequence[UUID] | None,
    ) -> AsyncGenerator[Customer]:
        """
        Stream the customers readable by `auth_subject`.

        When `organization_id` is not None, results are further restricted to
        those organizations (an empty sequence matches nothing).
        """
        statement = self.get_readable_statement(auth_subject)

        if organization_id is not None:
            statement = statement.where(
                Customer.organization_id.in_(organization_id),
            )

        async for customer in self.stream(statement):
            yield customer

    async def get_readable_by_id(
        self,
        auth_subject: AuthSubject[User | Organization],
        id: UUID,
        *,
        options: Options = (),
    ) -> Customer | None:
        """Return the customer with `id` if readable by `auth_subject`, else None."""
        statement = (
            self.get_readable_statement(auth_subject)
            .where(Customer.id == id)
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    async def get_readable_by_external_id(
        self,
        auth_subject: AuthSubject[User | Organization],
        external_id: str,
        *,
        options: Options = (),
    ) -> Customer | None:
        """Return the customer with `external_id` if readable by `auth_subject`, else None."""
        statement = (
            self.get_readable_statement(auth_subject)
            .where(Customer.external_id == external_id)
            .options(*options)
        )
        return await self.get_one_or_none(statement)

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[Customer]]:
        """
        Build the base customer query restricted to what `auth_subject` may read.

        - A user sees customers of organizations they belong to, excluding
          deleted memberships.
        - An organization sees only its own customers.

        NOTE(review): any other subject type gets the unrestricted base
        statement; the annotation only admits User | Organization.
        """
        statement = self.get_base_statement()

        if is_user(auth_subject):
            user = auth_subject.subject
            statement = statement.where(
                Customer.organization_id.in_(
                    select(UserOrganization.organization_id).where(
                        UserOrganization.user_id == user.id,
                        UserOrganization.deleted_at.is_(None),
                    )
                )
            )
        elif is_organization(auth_subject):
            statement = statement.where(
                Customer.organization_id == auth_subject.subject.id,
            )

        return statement