Coverage for polar/integrations/discord/endpoints.py: 40%
58 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1from typing import Any 1a
2from uuid import UUID 1a
4import structlog 1a
5from fastapi import Request 1a
6from fastapi.responses import RedirectResponse 1a
7from httpx_oauth.oauth2 import GetAccessTokenError 1a
9from polar.auth.dependencies import WebUserWrite 1a
10from polar.config import settings 1a
11from polar.exceptions import Unauthorized 1a
12from polar.kit import jwt 1a
13from polar.kit.http import ReturnTo, add_query_parameters, get_safe_return_url 1a
14from polar.openapi import APITag 1a
15from polar.routing import APIRouter 1a
17from . import oauth 1a
18from .schemas import DiscordGuild 1a
19from .service import discord_bot as discord_bot_service 1a
# Module-level structured logger used by the handlers below.
log = structlog.get_logger()

# Router for all Discord integration endpoints. Every route is mounted under
# ``/integrations/discord``. It is tagged ``APITag.private`` (presumably
# excluded from the public API docs — confirm against polar.openapi).
router = APIRouter(
    tags=["integrations_discord", APITag.private],
    prefix="/integrations/discord",
)
29###############################################################################
30# OAUTH2
31###############################################################################
def get_decoded_token_state(state: str) -> dict[str, Any]:
    """Decode and verify the signed OAuth ``state`` parameter.

    The state is a JWT of type ``discord_oauth`` signed with
    ``settings.SECRET``, as produced when the bot authorization flow starts.

    Args:
        state: The raw ``state`` query parameter echoed back by Discord.

    Returns:
        The decoded state payload.

    Raises:
        Unauthorized: If decoding fails with ``jwt.DecodeError`` or the token
            has expired (``jwt.ExpiredSignatureError``).
    """
    try:
        state_data = jwt.decode(
            token=state,
            secret=settings.SECRET,
            type="discord_oauth",
        )
    # An expired state (e.g. the user lingered on Discord's consent screen)
    # must be reported as a 401 like any other invalid token rather than
    # escaping as an unhandled error. This mirrors the exception handling
    # used for guild tokens in ``discord_guild_lookup``.
    except (jwt.DecodeError, jwt.ExpiredSignatureError) as e:
        raise Unauthorized("Invalid state") from e

    return state_data
47# -------------------------------------------------------------------------------
48# BOT
49# -------------------------------------------------------------------------------
@router.get(
    "/bot/authorize",
    name="integrations.discord.bot_authorize",
)
async def discord_bot_authorize(
    return_to: ReturnTo, request: Request, auth_subject: WebUserWrite
) -> RedirectResponse:
    """Start the Discord bot OAuth2 flow.

    Builds a signed ``discord_oauth`` JWT state that records the flow type,
    the initiating user's id and where to send them afterwards. It then
    redirects (303) to Discord's authorization URL, requesting the configured
    bot permissions. Discord sends the user back to ``bot_callback``.
    """
    signed_state = jwt.encode(
        data=dict(
            auth_type="bot",
            user_id=str(auth_subject.subject.id),
            return_to=return_to,
        ),
        secret=settings.SECRET,
        type="discord_oauth",
    )

    callback_url = str(request.url_for("integrations.discord.bot_callback"))
    discord_url = await oauth.bot_client.get_authorization_url(
        redirect_uri=callback_url,
        state=signed_state,
        extras_params={"permissions": settings.DISCORD_BOT_PERMISSIONS},
    )
    return RedirectResponse(discord_url, 303)
def _redirect_with_params(return_to: str, **params: Any) -> RedirectResponse:
    """Redirect (303) to the sanitized ``return_to`` URL with ``params`` appended."""
    redirect_url = get_safe_return_url(add_query_parameters(return_to, **params))
    return RedirectResponse(redirect_url, 303)


@router.get("/bot/callback", name="integrations.discord.bot_callback")
async def discord_bot_callback(
    auth_subject: WebUserWrite,
    request: Request,
    state: str,
    code: str | None = None,
    code_verifier: str | None = None,
    error: str | None = None,
) -> RedirectResponse:
    """Complete the Discord bot OAuth2 flow.

    Steps:
    1. Verify that the signed ``state`` belongs to the current user and to
       the bot flow.
    2. Exchange the authorization ``code`` for a token.
    3. Redirect back to the ``return_to`` URL stored in the state, adding
       ``guild_id`` and a signed ``guild_token`` (a ``discord_guild_token``
       JWT) as query parameters.

    Failures reported by Discord, a failed token exchange, or a token response
    without a guild redirect back with an ``error`` query parameter instead.
    ``code_verifier`` is accepted but not used here.

    Raises:
        Unauthorized: If the state is invalid/expired, was issued to a
            different user, or is not for the bot flow.
    """
    decoded_state = get_decoded_token_state(state)

    # Validate the state *before* acting on it. The authorization code must
    # not be exchanged, and the stored return URL must not be used, on behalf
    # of a session the state was not issued to (OAuth CSRF protection).
    user_id = UUID(decoded_state["user_id"])
    if user_id != auth_subject.subject.id or decoded_state["auth_type"] != "bot":
        raise Unauthorized()

    return_to = decoded_state["return_to"]
    if code is None or error is not None:
        return _redirect_with_params(
            return_to, error=error or "Failed to authorize Discord bot."
        )

    try:
        access_token = await oauth.bot_client.get_access_token(
            code, str(request.url_for("integrations.discord.bot_callback"))
        )
    except GetAccessTokenError as e:
        log.error("Failed to get Discord bot access token", error=str(e))
        return _redirect_with_params(
            return_to, error="Failed to get access token. Please try again later."
        )

    # The token response is external data; fail gracefully instead of with a
    # 500 if it does not carry the guild the bot was added to.
    try:
        guild_id = access_token["guild"]["id"]
    except (KeyError, TypeError):
        log.error("Discord bot access token response has no guild")
        return _redirect_with_params(
            return_to, error="Failed to authorize Discord bot."
        )

    # We need to set this ID on a subsequent API call (e.g. create Discord benefit).
    # To make sure a malicious user won't arbitrarily set guild IDs, we pass it as
    # a signed JWT token.
    guild_token = jwt.encode(
        data={"guild_id": guild_id},
        secret=settings.SECRET,
        type="discord_guild_token",
    )

    return _redirect_with_params(return_to, guild_token=guild_token, guild_id=guild_id)
131###############################################################################
132# API
133###############################################################################
@router.get("/guild/lookup", response_model=DiscordGuild)
async def discord_guild_lookup(
    guild_token: str, auth_subject: WebUserWrite
) -> DiscordGuild:
    """Resolve a signed guild token into the corresponding Discord guild.

    ``guild_token`` is the ``discord_guild_token`` JWT handed out by the bot
    callback. Only guild ids signed by this server can be looked up.
    ``auth_subject`` is not read in the body. It is presumably declared so
    that the ``WebUserWrite`` dependency enforces authentication (confirm).

    Raises:
        Unauthorized: If the token fails to decode, has expired, or carries
            no ``guild_id``.
    """
    try:
        guild_id = jwt.decode(
            token=guild_token,
            secret=settings.SECRET,
            type="discord_guild_token",
        )["guild_id"]
    except (KeyError, jwt.DecodeError, jwt.ExpiredSignatureError) as e:
        raise Unauthorized() from e

    return await discord_bot_service.get_guild(guild_id)