Coverage for polar/benefit/strategies/discord/service.py: 18%

100 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from typing import Any, cast 1a

2 

3import httpx 1a

4import structlog 1a

5from httpx_oauth.clients.discord import DiscordOAuth2 1a

6from httpx_oauth.oauth2 import RefreshTokenError 1a

7 

8from polar.auth.models import AuthSubject 1a

9from polar.config import settings 1a

10from polar.customer.repository import CustomerRepository 1a

11from polar.integrations.discord.service import discord_bot as discord_bot_service 1a

12from polar.logging import Logger 1a

13from polar.models import Benefit, Customer, Organization, User 1a

14from polar.models.customer import CustomerOAuthAccount, CustomerOAuthPlatform 1a

15 

16from ..base.service import ( 1a

17 BenefitActionRequiredError, 

18 BenefitPropertiesValidationError, 

19 BenefitRetriableError, 

20 BenefitServiceProtocol, 

21) 

22from .properties import BenefitDiscordProperties, BenefitGrantDiscordProperties 1a

23 

# Module-level structlog logger; the service methods below bind per-call
# context (benefit/customer identifiers) onto it before emitting events.
log: Logger = structlog.get_logger()

25 

26 

class BenefitDiscordService(
    BenefitServiceProtocol[BenefitDiscordProperties, BenefitGrantDiscordProperties]
):
    """Benefit strategy that drives the Discord bot service.

    Granting calls ``discord_bot_service.add_member`` with the guild and role
    configured on the benefit. Revoking calls either ``remove_member`` or
    ``remove_member_role`` depending on the benefit's ``kick_member`` property.
    """

    @staticmethod
    def _warn_http_error(logger: Logger, error: httpx.HTTPError, message: str) -> None:
        """Emit a warning for ``error``, enriched with response details if any.

        The stringified error is always bound; the status code and response
        body are only bound when ``error`` is an ``httpx.HTTPStatusError``.
        """
        context: dict[str, Any] = {"error": str(error)}
        if isinstance(error, httpx.HTTPStatusError):
            context["status_code"] = error.response.status_code
            context["body"] = error.response.text
        logger.bind(**context).warning(message)

    async def grant(
        self,
        benefit: Benefit,
        customer: Customer,
        grant_properties: BenefitGrantDiscordProperties,
        *,
        update: bool = False,
        attempt: int = 1,
    ) -> BenefitGrantDiscordProperties:
        """Add the customer's Discord account to the benefit's guild and role.

        When ``update`` is set and ``grant_properties`` is non-empty, the
        previously stored guild, role and granted account are compared with
        the current ones; if any stored value differs, ``revoke`` is awaited
        first.

        Returns:
            ``grant_properties`` extended with the ``guild_id``, ``role_id``
            and ``granted_account_id`` used for this grant.

        Raises:
            BenefitActionRequiredError: ``grant_properties`` has no
                ``account_id`` (also raised by the OAuth account lookup).
            BenefitRetriableError: the ``add_member`` call raised an
                ``httpx.HTTPError``.
        """
        logger = log.bind(benefit_id=str(benefit.id), customer_id=str(customer.id))
        logger.debug("Grant benefit")

        benefit_properties = self._get_properties(benefit)
        guild_id = benefit_properties["guild_id"]
        role_id = benefit_properties["role_id"]

        # On update, undo the previous grant if it targeted another guild,
        # role or account. A missing (None) stored value never triggers this.
        if update and grant_properties:
            logger.debug("Grant benefit update")
            stored_vs_current = (
                (grant_properties.get("guild_id"), guild_id),
                (grant_properties.get("role_id"), role_id),
                (
                    grant_properties.get("granted_account_id"),
                    grant_properties.get("account_id"),
                ),
            )
            if any(
                stored is not None and stored != current
                for stored, current in stored_vs_current
            ):
                logger.debug(
                    "Revoke before granting because guild, role or account have changed"
                )
                await self.revoke(benefit, customer, grant_properties, attempt=attempt)

        account_id = grant_properties.get("account_id")
        if account_id is None:
            raise BenefitActionRequiredError(
                "The customer needs to connect their Discord account"
            )

        oauth_account = await self._get_customer_oauth_account(customer, account_id)

        try:
            await discord_bot_service.add_member(
                guild_id,
                role_id,
                oauth_account.account_id,
                oauth_account.access_token,
            )
        except httpx.HTTPError as exc:
            self._warn_http_error(logger, exc, "HTTP error while adding member")
            raise BenefitRetriableError() from exc

        logger.debug("Benefit granted")

        # Remember what was granted: the benefit's guild/role can change later.
        return {
            **grant_properties,
            "guild_id": guild_id,
            "role_id": role_id,
            "granted_account_id": account_id,
        }

    async def cycle(
        self,
        benefit: Benefit,
        customer: Customer,
        grant_properties: BenefitGrantDiscordProperties,
        *,
        attempt: int = 1,
    ) -> BenefitGrantDiscordProperties:
        """Return ``grant_properties`` unchanged; nothing is done on cycle."""
        return grant_properties

    async def revoke(
        self,
        benefit: Benefit,
        customer: Customer,
        grant_properties: BenefitGrantDiscordProperties,
        *,
        attempt: int = 1,
    ) -> BenefitGrantDiscordProperties:
        """Undo a previous grant using the values stored in ``grant_properties``.

        Returns:
            An empty dict when the stored ``guild_id``, ``role_id`` or
            ``granted_account_id`` is missing or falsy (no Discord call is
            made). Otherwise a dict holding only the stored ``account_id``.

        Raises:
            BenefitRetriableError: the Discord bot call raised an
                ``httpx.HTTPError``.
        """
        logger = log.bind(benefit_id=str(benefit.id), customer_id=str(customer.id))

        guild_id = grant_properties.get("guild_id")
        role_id = grant_properties.get("role_id")
        granted_account_id = grant_properties.get("granted_account_id")

        # Without a complete record of what was granted there is nothing to undo.
        if not guild_id or not role_id or not granted_account_id:
            return {}

        kick_member = self._get_properties(benefit)["kick_member"]

        try:
            if kick_member:
                await discord_bot_service.remove_member(guild_id, granted_account_id)
            else:
                await discord_bot_service.remove_member_role(
                    guild_id, role_id, granted_account_id
                )
        except httpx.HTTPError as exc:
            self._warn_http_error(logger, exc, "HTTP error while removing member")
            raise BenefitRetriableError() from exc

        logger.debug("Benefit revoked")

        # Keep account_id in case we need to re-grant later
        return {"account_id": grant_properties.get("account_id")}

    async def requires_update(
        self, benefit: Benefit, previous_properties: BenefitDiscordProperties
    ) -> bool:
        """Tell whether the benefit's guild or role differs from before."""
        current_properties = self._get_properties(benefit)
        if current_properties["guild_id"] != previous_properties["guild_id"]:
            return True
        return current_properties["role_id"] != previous_properties["role_id"]

    async def validate_properties(
        self, auth_subject: AuthSubject[User | Organization], properties: dict[str, Any]
    ) -> BenefitDiscordProperties:
        """Check that the configured role exists and can be granted by the bot.

        ``auth_subject`` is accepted but not used by this implementation.

        Returns:
            ``properties`` cast to ``BenefitDiscordProperties``.

        Raises:
            BenefitPropertiesValidationError: the role id is not among the
                guild's roles, or ``is_bot_role_above_role`` is falsy for it.
        """
        guild_id: str = properties["guild_id"]
        role_id: str = properties["role_id"]

        def role_error(
            error_type: str, message: str
        ) -> BenefitPropertiesValidationError:
            # Both failures point at the same field and input value.
            return BenefitPropertiesValidationError(
                [
                    {
                        "type": error_type,
                        "msg": message,
                        "loc": ("role_id",),
                        "input": role_id,
                    }
                ]
            )

        guild = await discord_bot_service.get_guild(guild_id)
        if not any(role.id == role_id for role in guild.roles):
            raise role_error("invalid_role", "This role does not exist on this server.")

        bot_is_above = await discord_bot_service.is_bot_role_above_role(
            guild_id, role_id
        )
        if not bot_is_above:
            raise role_error(
                "invalid_role_position",
                "This role is above the Polar bot role, so Discord won't let our bot grants it. Please reorder them so the Polar bot is above.",
            )

        return cast(BenefitDiscordProperties, properties)

    async def _get_customer_oauth_account(
        self, customer: Customer, account_id: str
    ) -> CustomerOAuthAccount:
        """Return the customer's Discord OAuth account with a usable token.

        If the account reports itself as expired, its token is refreshed
        through ``DiscordOAuth2.refresh_token``, the new token data is stored
        on the account and the customer is saved through
        ``CustomerRepository.update``.

        Raises:
            BenefitActionRequiredError: the customer has no such Discord
                account, the expired account has no refresh token, or the
                refresh raised ``RefreshTokenError``.
        """
        platform = CustomerOAuthPlatform.discord
        oauth_account = customer.get_oauth_account(account_id, platform)
        if oauth_account is None:
            raise BenefitActionRequiredError(
                "The customer needs to connect their Discord account"
            )

        if not oauth_account.is_expired():
            return oauth_account

        if oauth_account.refresh_token is None:
            raise BenefitActionRequiredError(
                "The customer needs to reconnect their Discord account"
            )

        log_context: dict[str, Any] = {
            "oauth_account_id": oauth_account.account_id,
            "customer_id": str(customer.id),
        }
        log.debug("Refresh Discord access token", **log_context)

        discord_client = DiscordOAuth2(
            settings.DISCORD_CLIENT_ID,
            settings.DISCORD_CLIENT_SECRET,
            scopes=["identify", "email", "guilds.join"],
        )
        try:
            token_data = await discord_client.refresh_token(
                oauth_account.refresh_token
            )
        except RefreshTokenError as exc:
            log.warning(
                "Failed to refresh Discord access token",
                **log_context,
                error=str(exc),
            )
            raise BenefitActionRequiredError(
                "The customer needs to reconnect their Discord account"
            ) from exc

        oauth_account.access_token = token_data["access_token"]
        oauth_account.expires_at = token_data["expires_at"]
        oauth_account.refresh_token = token_data["refresh_token"]
        customer.set_oauth_account(oauth_account, platform)

        # Save the refreshed credentials on the customer.
        repository = CustomerRepository.from_session(self.session)
        await repository.update(customer)

        return oauth_account

242 await customer_repository.update(customer) 

243 

244 return oauth_account