Coverage for polar/customer_portal/service/subscription.py: 32%

82 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import uuid 1a

2from collections.abc import Sequence 1a

3from enum import StrEnum 1a

4from typing import Any 1a

5 

6from sqlalchemy import Select, UnaryExpression, asc, desc, select 1a

7from sqlalchemy.orm import contains_eager, joinedload, selectinload 1a

8 

9from polar.auth.models import AuthSubject 1a

10from polar.exceptions import PolarError 1a

11from polar.kit.db.postgres import AsyncSession 1a

12from polar.kit.pagination import PaginationParams, paginate 1a

13from polar.kit.services import ResourceServiceReader 1a

14from polar.kit.sorting import Sorting 1a

15from polar.models import ( 1a

16 Customer, 

17 Organization, 

18 Product, 

19 Subscription, 

20 SubscriptionMeter, 

21) 

22from polar.models.subscription import CustomerCancellationReason 1a

23from polar.subscription.service import subscription as subscription_service 1a

24 

25from ..schemas.subscription import ( 1a

26 CustomerSubscriptionUpdate, 

27 CustomerSubscriptionUpdateProduct, 

28 CustomerSubscriptionUpdateSeats, 

29) 

30 

31 

class CustomerSubscriptionError(PolarError):
    """Base class for the customer-portal subscription errors defined in this module."""

33 

34 

class UpdateSubscriptionNotAllowed(CustomerSubscriptionError):
    """Raised when a customer may not change their subscription's product.

    `CustomerSubscriptionService.update` raises it when the subscription's
    organization has `allow_customer_updates` disabled. The error carries
    HTTP status 403.
    """

    def __init__(self) -> None:
        message = "Updating subscription is not allowed."
        status_code = 403
        super().__init__(message, status_code)

38 

39 

40class CustomerSubscriptionSortProperty(StrEnum): 1a

41 started_at = "started_at" 1a

42 amount = "amount" 1a

43 status = "status" 1a

44 organization = "organization" 1a

45 product = "product" 1a

46 

47 

class CustomerSubscriptionService(ResourceServiceReader[Subscription]):
    """Customer-portal reads and self-service updates for a customer's subscriptions.

    Reads are built on `_get_readable_subscription_statement`. That statement
    restricts rows to non-deleted subscriptions owned by the authenticated
    customer.
    """

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        *,
        product_id: Sequence[uuid.UUID] | None = None,
        active: bool | None = None,
        query: str | None = None,
        pagination: PaginationParams,
        sorting: list[Sorting[CustomerSubscriptionSortProperty]] | None = None,
    ) -> tuple[Sequence[Subscription], int]:
        """List the customer's started subscriptions, one page at a time.

        Args:
            session: Database session.
            auth_subject: The authenticated customer; only their subscriptions
                are returned.
            product_id: If given, keep only subscriptions on one of these products.
            active: ``True`` keeps subscriptions whose ``active`` flag is set.
                ``False`` keeps subscriptions whose ``revoked`` flag is set.
                ``None`` applies no status filter.
            query: Case-insensitive substring matched against the product name.
            pagination: Forwarded unchanged to ``paginate``.
            sorting: ``(property, descending)`` pairs, applied in order.
                ``None`` means newest ``started_at`` first.

        Returns:
            Whatever ``paginate`` returns for the statement. This is presumably
            the page of subscriptions and the total match count.
        """
        # Build the default per call instead of sharing a mutable default argument.
        if sorting is None:
            sorting = [(CustomerSubscriptionSortProperty.started_at, True)]

        # Only subscriptions that have a start date are listed.
        statement = self._get_readable_subscription_statement(auth_subject).where(
            Subscription.started_at.is_not(None)
        )

        # Product and Organization are joined explicitly so they can be filtered
        # and sorted on. contains_eager then fills the relationships from those
        # same joined rows instead of adding a second join.
        statement = (
            statement.join(Product, onclause=Subscription.product_id == Product.id)
            .join(Organization, onclause=Product.organization_id == Organization.id)
            .options(
                joinedload(Subscription.customer),
                contains_eager(Subscription.product).options(
                    selectinload(Product.product_medias),
                    contains_eager(Product.organization),
                ),
                selectinload(Subscription.meters).joinedload(SubscriptionMeter.meter),
            )
        )

        if product_id is not None:
            statement = statement.where(Subscription.product_id.in_(product_id))

        if active is not None:
            if active:
                statement = statement.where(Subscription.active.is_(True))
            else:
                # NOTE: active=False filters on `revoked`, not on "not active".
                # A subscription that is neither active nor revoked (if such a
                # state exists) matches neither filter.
                statement = statement.where(Subscription.revoked.is_(True))

        if query is not None:
            statement = statement.where(Product.name.ilike(f"%{query}%"))

        order_by_clauses: list[UnaryExpression[Any]] = []
        for criterion, is_desc in sorting:
            clause_function = desc if is_desc else asc
            if criterion == CustomerSubscriptionSortProperty.started_at:
                order_by_clauses.append(clause_function(Subscription.started_at))
            elif criterion == CustomerSubscriptionSortProperty.amount:
                order_by_clauses.append(clause_function(Subscription.amount))
            elif criterion == CustomerSubscriptionSortProperty.status:
                order_by_clauses.append(clause_function(Subscription.status))
            elif criterion == CustomerSubscriptionSortProperty.organization:
                order_by_clauses.append(clause_function(Organization.slug))
            elif criterion == CustomerSubscriptionSortProperty.product:
                order_by_clauses.append(clause_function(Product.name))
        # None of the sort keys above is guaranteed unique. Without a total
        # order, the database may return ties in any order, so page boundaries
        # can repeat or skip rows. The primary key makes the order deterministic.
        order_by_clauses.append(asc(Subscription.id))
        statement = statement.order_by(*order_by_clauses)

        return await paginate(session, statement, pagination=pagination)

    async def get_by_id(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[Customer],
        id: uuid.UUID,
    ) -> Subscription | None:
        """Fetch one of the customer's subscriptions by id.

        Customer, product (with medias and organization) and meters are loaded
        eagerly.

        Returns:
            The subscription, or ``None`` if it does not exist, is soft-deleted,
            or belongs to another customer.
        """
        statement = (
            self._get_readable_subscription_statement(auth_subject)
            .where(Subscription.id == id)
            .options(
                joinedload(Subscription.customer),
                joinedload(Subscription.product).options(
                    selectinload(Product.product_medias),
                    joinedload(Product.organization),
                ),
                selectinload(Subscription.meters).joinedload(SubscriptionMeter.meter),
            )
        )

        result = await session.execute(statement)
        return result.scalar_one_or_none()

    async def update(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        updates: CustomerSubscriptionUpdate,
    ) -> Subscription:
        """Apply a customer-initiated update, dispatching on its concrete type.

        - ``CustomerSubscriptionUpdateProduct``: switch product. Requires the
          organization's ``allow_customer_updates``.
        - ``CustomerSubscriptionUpdateSeats``: change the seat count with the
          requested proration behavior.
        - Any other variant: ``cancel_at_period_end`` set to ``True`` cancels
          and ``False`` uncancels. Any other value (e.g. ``None``) returns the
          subscription unchanged.

        Raises:
            UpdateSubscriptionNotAllowed: A product change was requested but the
                organization disallows customer updates.
        """
        if isinstance(updates, CustomerSubscriptionUpdateProduct):
            organization = subscription.product.organization
            if not organization.allow_customer_updates:
                raise UpdateSubscriptionNotAllowed()

            return await self.update_product(
                session,
                subscription,
                product_id=updates.product_id,
            )

        if isinstance(updates, CustomerSubscriptionUpdateSeats):
            # NOTE(review): unlike the product branch, seat changes are not
            # gated on `organization.allow_customer_updates`. Confirm this is
            # intended. (The previous unused `organization` lookup was removed.)
            return await subscription_service.update_seats(
                session,
                subscription,
                seats=updates.seats,
                proration_behavior=updates.proration_behavior,
            )

        cancel = updates.cancel_at_period_end is True
        uncancel = updates.cancel_at_period_end is False
        if not (cancel or uncancel):
            return subscription

        if cancel:
            return await self.cancel(
                session,
                subscription,
                reason=updates.cancellation_reason,
                comment=updates.cancellation_comment,
            )

        return await self.uncancel(session, subscription)

    async def update_product(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        product_id: uuid.UUID,
    ) -> Subscription:
        """Switch ``subscription`` to ``product_id`` via the core subscription service."""
        return await subscription_service.update_product(
            session, subscription, product_id=product_id
        )

    async def uncancel(
        self,
        session: AsyncSession,
        subscription: Subscription,
    ) -> Subscription:
        """Undo a pending cancellation via the core subscription service."""
        return await subscription_service.uncancel(
            session,
            subscription,
        )

    async def cancel(
        self,
        session: AsyncSession,
        subscription: Subscription,
        *,
        reason: CustomerCancellationReason | None = None,
        comment: str | None = None,
    ) -> Subscription:
        """Cancel ``subscription`` via the core subscription service.

        ``reason`` and ``comment`` are forwarded as the customer-supplied
        cancellation reason and comment.
        """
        return await subscription_service.cancel(
            session,
            subscription,
            customer_reason=reason,
            customer_comment=comment,
        )

    def _get_readable_subscription_statement(
        self, auth_subject: AuthSubject[Customer]
    ) -> Select[tuple[Subscription]]:
        """Base SELECT of non-deleted subscriptions owned by the authenticated customer."""
        return select(Subscription).where(
            Subscription.deleted_at.is_(None),
            Subscription.customer_id == auth_subject.subject.id,
        )

216 

217 

# Shared service instance bound to the Subscription model.
customer_subscription: CustomerSubscriptionService = CustomerSubscriptionService(
    Subscription
)