Coverage for polar/subscription/tasks.py: 34%

84 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import uuid 1a

2 

3import structlog 1a

4from sqlalchemy import func, select 1a

5from sqlalchemy.orm import selectinload 1a

6 

7from polar.exceptions import PolarTaskError 1a

8from polar.kit.utils import utc_now 1a

9from polar.locker import Locker 1a

10from polar.logging import Logger 1a

11from polar.models import ( 1a

12 Customer, 

13 Organization, 

14 Product, 

15 Subscription, 

16 SubscriptionMeter, 

17) 

18from polar.product.repository import ProductRepository 1a

19from polar.subscription.repository import SubscriptionRepository 1a

20from polar.worker import ( 1a

21 AsyncSessionMaker, 

22 RedisMiddleware, 

23 TaskPriority, 

24 actor, 

25 enqueue_job, 

26) 

27 

28from .service import SubscriptionNotReadyForMigration 1a

29from .service import subscription as subscription_service 1a

30 

log: Logger = structlog.get_logger()


class SubscriptionTaskError(PolarTaskError):
    """Common base class for errors raised by the subscription background tasks."""

35 

36 

class SubscriptionDoesNotExist(SubscriptionTaskError):
    """Raised when a task receives the ID of a subscription that cannot be loaded.

    Attributes:
        subscription_id: The ID that failed to resolve to a subscription.
    """

    def __init__(self, subscription_id: uuid.UUID) -> None:
        self.subscription_id = subscription_id
        super().__init__(
            f"The subscription with id {subscription_id} does not exist."
        )

42 

43 

class SubscriptionTierDoesNotExist(SubscriptionTaskError):
    """Raised when a task cannot find the product ("subscription tier") it was given.

    Attributes:
        subscription_tier_id: The product ID that failed to resolve.
    """

    def __init__(self, subscription_tier_id: uuid.UUID) -> None:
        self.subscription_tier_id = subscription_tier_id
        super().__init__(
            f"The subscription tier with id {subscription_tier_id} does not exist."
        )

51 

52 

@actor(actor_name="subscription.cycle", priority=TaskPriority.LOW)
async def subscription_cycle(subscription_id: uuid.UUID, force: bool = False) -> None:
    """Cycle a subscription via ``subscription_service.cycle`` under a Redis lock.

    Args:
        subscription_id: ID of the subscription to cycle.
        force: When true, cycle even if ``current_period_end`` is still in the
            future.

    Raises:
        SubscriptionDoesNotExist: If no subscription has ``subscription_id``.
    """
    subscription_locker = Locker(RedisMiddleware.get())
    lock_key = f"subscription:cycle:{subscription_id}"

    # Another worker already holds the lock for this subscription: bail out.
    if await subscription_locker.is_locked(lock_key):
        log.info(
            "Subscription is already being cycled by another worker",
            subscription_id=subscription_id,
        )
        return

    # NOTE(review): the is_locked() check above and this acquisition are not
    # atomic; acquisition only waits 0.1s (blocking_timeout) — confirm what
    # Locker does when it cannot acquire. Also, if `timeout` is the lock TTL in
    # seconds (as with redis-py locks), a cycle lasting over 1s would let the
    # lock expire mid-run — confirm the intended value.
    cycle_lock = subscription_locker.lock(lock_key, timeout=1.0, blocking_timeout=0.1)
    async with cycle_lock, AsyncSessionMaker() as session:
        subscriptions = SubscriptionRepository.from_session(session)
        subscription = await subscriptions.get_by_id(
            subscription_id, options=subscriptions.get_eager_options()
        )
        if subscription is None:
            raise SubscriptionDoesNotExist(subscription_id)

        period_end = subscription.current_period_end
        if not force and period_end and period_end > utc_now():
            # The current period has not ended yet, so there is nothing to cycle.
            log.info(
                "Subscription has already been cycled",
                subscription_id=subscription_id,
            )
            # Clear scheduler_locked_at (presumably releasing the scheduler's
            # claim on this subscription — confirm against the scheduler).
            await subscriptions.update(
                subscription, update_dict={"scheduler_locked_at": None}
            )
            return

        await subscription_service.cycle(session, subscription)

90 

91 

@actor(
    actor_name="subscription.subscription.update_product_benefits_grants",
    priority=TaskPriority.MEDIUM,
)
async def subscription_update_product_benefits_grants(
    subscription_tier_id: uuid.UUID,
) -> None:
    """Load a product and pass it to ``update_product_benefits_grants``.

    Args:
        subscription_tier_id: ID of the product (subscription tier) to process.

    Raises:
        SubscriptionTierDoesNotExist: If no product has ``subscription_tier_id``.
    """
    async with AsyncSessionMaker() as session:
        product = await ProductRepository.from_session(session).get_by_id(
            subscription_tier_id
        )
        if product is None:
            raise SubscriptionTierDoesNotExist(subscription_tier_id)
        await subscription_service.update_product_benefits_grants(session, product)

106 

107 

@actor(actor_name="subscription.update_meters", priority=TaskPriority.LOW)
async def subscription_update_meters(subscription_id: uuid.UUID) -> None:
    """Load a subscription with its meters and run ``update_meters`` on it.

    Args:
        subscription_id: ID of the subscription whose meters are updated.

    Raises:
        SubscriptionDoesNotExist: If no subscription has ``subscription_id``.
    """
    async with AsyncSessionMaker() as session:
        subscriptions = SubscriptionRepository.from_session(session)
        # Eager-load meters and each meter's Meter row in the same round trip.
        meters_loader = selectinload(Subscription.meters).joinedload(
            SubscriptionMeter.meter
        )
        subscription = await subscriptions.get_by_id(
            subscription_id, options=(meters_loader,)
        )
        if subscription is None:
            raise SubscriptionDoesNotExist(subscription_id)
        await subscription_service.update_meters(session, subscription)

121 

122 

@actor(actor_name="subscription.cancel_customer", priority=TaskPriority.HIGH)
async def subscription_cancel_customer(customer_id: uuid.UUID) -> None:
    """Run ``subscription_service.cancel_customer`` for a customer in a new session.

    Args:
        customer_id: ID of the customer passed to ``cancel_customer``.
    """
    async with AsyncSessionMaker() as db_session:
        await subscription_service.cancel_customer(db_session, customer_id)

127 

128 

@actor(
    actor_name="subscription.enqueue_stripe_subscription_migrate",
    priority=TaskPriority.LOW,
)
async def enqueue_stripe_subscription_migrate(
    max_subscriptions_count: int, limit: int
) -> None:
    """Enqueue Stripe-subscription migration jobs for the smallest organizations.

    Organizations are ranked by how many subscriptions with a
    ``stripe_subscription_id`` they own through their products. Only those with
    fewer than ``max_subscriptions_count`` are kept, smallest first, capped at
    ``limit`` organizations. A ``subscription.migrate_stripe_subscription`` job
    is then enqueued for every Stripe-backed subscription whose customer
    belongs to one of the selected organizations.

    Args:
        max_subscriptions_count: Exclusive upper bound on an organization's
            Stripe subscription count.
        limit: Maximum number of organizations selected in this run.
    """
    stripe_subscription_count = func.count(Subscription.id)

    # The WHERE on stripe_subscription_id discards the NULL rows the outer
    # joins would add, so in practice these behave like inner joins.
    eligible_organization_ids = (
        select(Organization.id)
        .outerjoin(Product, Product.organization_id == Organization.id)
        .outerjoin(Subscription, Subscription.product_id == Product.id)
        .where(Subscription.stripe_subscription_id.is_not(None))
        .group_by(Organization.id)
        .having(stripe_subscription_count < max_subscriptions_count)
        .order_by(stripe_subscription_count.asc())
        .limit(limit)
    )

    async with AsyncSessionMaker() as session:
        repository = SubscriptionRepository.from_session(session)
        statement = (
            repository.get_base_statement()
            .join(Customer, Customer.id == Subscription.customer_id)
            .where(
                Subscription.stripe_subscription_id.is_not(None),
                Customer.organization_id.in_(eligible_organization_ids),
            )
        )
        async for subscription in repository.stream(statement):
            enqueue_job("subscription.migrate_stripe_subscription", subscription.id)

166 

167 

@actor(actor_name="subscription.migrate_stripe_subscription", priority=TaskPriority.LOW)
async def migrate_stripe_subscription(subscription_id: uuid.UUID) -> None:
    """Run ``subscription_service.migrate_stripe_subscription`` for one subscription.

    A subscription that is not ready yet is skipped and logged instead of
    failing the job.

    Args:
        subscription_id: ID of the subscription to migrate.

    Raises:
        SubscriptionDoesNotExist: If no subscription has ``subscription_id``.
    """
    async with AsyncSessionMaker() as session:
        repository = SubscriptionRepository.from_session(session)
        subscription = await repository.get_by_id(subscription_id)
        if subscription is None:
            raise SubscriptionDoesNotExist(subscription_id)
        try:
            await subscription_service.migrate_stripe_subscription(
                session, subscription
            )
        except SubscriptionNotReadyForMigration:
            # Deliberately not an error. Nothing here re-enqueues the job;
            # presumably a later enqueue_stripe_subscription_migrate run picks
            # the subscription up again. Log it so skips are not invisible.
            log.info(
                "Subscription is not ready for migration, skipping",
                subscription_id=subscription_id,
            )