Coverage for polar/integrations/discord/endpoints.py: 40%

58 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from typing import Any 1a

2from uuid import UUID 1a

3 

4import structlog 1a

5from fastapi import Request 1a

6from fastapi.responses import RedirectResponse 1a

7from httpx_oauth.oauth2 import GetAccessTokenError 1a

8 

9from polar.auth.dependencies import WebUserWrite 1a

10from polar.config import settings 1a

11from polar.exceptions import Unauthorized 1a

12from polar.kit import jwt 1a

13from polar.kit.http import ReturnTo, add_query_parameters, get_safe_return_url 1a

14from polar.openapi import APITag 1a

15from polar.routing import APIRouter 1a

16 

17from . import oauth 1a

18from .schemas import DiscordGuild 1a

19from .service import discord_bot as discord_bot_service 1a

20 

# Module-level structured logger used by the endpoints below.
log = structlog.get_logger()

# Every Discord integration route is mounted under /integrations/discord and
# tagged with both the integration tag and APITag.private.
router = APIRouter(
    tags=["integrations_discord", APITag.private],
    prefix="/integrations/discord",
)

27 

28 

29############################################################################### 

30# OAUTH2 

31############################################################################### 

32 

33 

def get_decoded_token_state(state: str) -> dict[str, Any]:
    """Decode and verify the signed OAuth2 ``state`` round-tripped via Discord.

    Args:
        state: The ``state`` query parameter Discord echoes back to the
            callback; it is a ``discord_oauth``-typed JWT signed with
            ``settings.SECRET``.

    Returns:
        The decoded JWT claims.

    Raises:
        Unauthorized: If the token cannot be decoded (``jwt.DecodeError``) or
            has expired (``jwt.ExpiredSignatureError``).
    """
    try:
        return jwt.decode(
            token=state,
            secret=settings.SECRET,
            type="discord_oauth",
        )
    except (jwt.DecodeError, jwt.ExpiredSignatureError) as e:
        # ExpiredSignatureError is caught explicitly (as discord_guild_lookup
        # does): in PyJWT it is not a subclass of DecodeError, so without it an
        # expired state would surface as an unhandled 500 instead of a 401.
        raise Unauthorized("Invalid state") from e

45 

46 

47# ------------------------------------------------------------------------------- 

48# BOT 

49# ------------------------------------------------------------------------------- 

50 

51 

@router.get("/bot/authorize", name="integrations.discord.bot_authorize")
async def discord_bot_authorize(
    return_to: ReturnTo, request: Request, auth_subject: WebUserWrite
) -> RedirectResponse:
    """Start the Discord bot installation OAuth2 flow.

    The current user's ID, the flow type (``"bot"``) and ``return_to`` are
    packed into a signed ``discord_oauth`` JWT that is sent as the OAuth2
    ``state``. The user is then sent (303) to Discord's authorization URL,
    requesting ``settings.DISCORD_BOT_PERMISSIONS``, with the bot callback
    route as the redirect URI.
    """
    callback_uri = str(request.url_for("integrations.discord.bot_callback"))

    signed_state = jwt.encode(
        data={
            "auth_type": "bot",
            "user_id": str(auth_subject.subject.id),
            "return_to": return_to,
        },
        secret=settings.SECRET,
        type="discord_oauth",
    )

    discord_url = await oauth.bot_client.get_authorization_url(
        redirect_uri=callback_uri,
        state=signed_state,
        extras_params={"permissions": settings.DISCORD_BOT_PERMISSIONS},
    )
    return RedirectResponse(url=discord_url, status_code=303)

75 

76 

def _redirect_with_error(return_to: str, error: str) -> RedirectResponse:
    """Send the user back to ``return_to`` (sanitized) with an ``error`` query param."""
    redirect_url = get_safe_return_url(add_query_parameters(return_to, error=error))
    return RedirectResponse(redirect_url, 303)


@router.get("/bot/callback", name="integrations.discord.bot_callback")
async def discord_bot_callback(
    auth_subject: WebUserWrite,
    request: Request,
    state: str,
    code: str | None = None,
    code_verifier: str | None = None,
    error: str | None = None,
) -> RedirectResponse:
    """Finish the Discord bot installation OAuth2 flow.

    Verifies the signed ``state``, exchanges ``code`` for a token and
    redirects (303) back to the ``return_to`` stored in the state with
    ``guild_id`` and a signed ``guild_token`` query parameter. Failures
    reported by Discord, a failed token exchange, or a token response without
    a guild redirect back with an ``error`` query parameter instead.

    ``code_verifier`` is accepted but not used: the authorize endpoint in this
    module does not send a PKCE challenge. It is kept so the route's query
    interface stays unchanged.

    Raises:
        Unauthorized: If ``state`` is invalid/expired, was issued to another
            user, or was not issued for the bot flow.
    """
    decoded_state = get_decoded_token_state(state)

    # Check that this state belongs to the current user and the bot flow
    # *before* acting on it, i.e. before redirecting to its return_to or
    # exchanging the authorization code with Discord.
    user_id = UUID(decoded_state["user_id"])
    if user_id != auth_subject.subject.id or decoded_state["auth_type"] != "bot":
        raise Unauthorized()

    return_to = decoded_state["return_to"]

    if code is None or error is not None:
        return _redirect_with_error(
            return_to, error or "Failed to authorize Discord bot."
        )

    try:
        access_token = await oauth.bot_client.get_access_token(
            code, str(request.url_for("integrations.discord.bot_callback"))
        )
    except GetAccessTokenError as e:
        log.error("Failed to get Discord bot access token", error=str(e))
        return _redirect_with_error(
            return_to, "Failed to get access token. Please try again later."
        )

    # The token response is expected to carry the guild the bot was added to;
    # if it does not, redirect with an error rather than crash with KeyError.
    guild = access_token.get("guild") or {}
    guild_id = guild.get("id")
    if guild_id is None:
        log.error("Discord bot access token response has no guild")
        return _redirect_with_error(return_to, "Failed to authorize Discord bot.")

    # We need to set this ID on a subsequent API call (e.g. create Discord benefit).
    # To make sure a malicious user won't arbitrarily set guild IDs, we pass it as
    # a signed JWT token.
    guild_token = jwt.encode(
        data={"guild_id": guild_id},
        secret=settings.SECRET,
        type="discord_guild_token",
    )

    redirect_url = get_safe_return_url(
        add_query_parameters(return_to, guild_token=guild_token, guild_id=guild_id)
    )
    return RedirectResponse(redirect_url, 303)

129 

130 

131############################################################################### 

132# API 

133############################################################################### 

134 

135 

@router.get("/guild/lookup", response_model=DiscordGuild)
async def discord_guild_lookup(
    guild_token: str, auth_subject: WebUserWrite
) -> DiscordGuild:
    """Resolve a signed ``discord_guild_token`` JWT into the guild's details.

    ``auth_subject`` is not read here; it is declared so the ``WebUserWrite``
    dependency runs for this route.

    Raises:
        Unauthorized: If the token cannot be decoded, has expired, or carries
            no ``guild_id`` claim.
    """
    try:
        claims = jwt.decode(
            token=guild_token, secret=settings.SECRET, type="discord_guild_token"
        )
        # Kept inside the try so a token without the claim maps to 401 too.
        guild_id = claims["guild_id"]
    except (KeyError, jwt.DecodeError, jwt.ExpiredSignatureError) as exc:
        raise Unauthorized() from exc

    guild = await discord_bot_service.get_guild(guild_id)
    return guild