Coverage for polar/customer_portal/endpoints/oauth_accounts.py: 38%
80 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import uuid 1a
2from typing import Any 1a
4import structlog 1a
5from fastapi import Depends, Query, Request 1a
6from fastapi.responses import RedirectResponse 1a
7from httpx_oauth.clients.discord import DiscordOAuth2 1a
8from httpx_oauth.clients.github import GitHubOAuth2 1a
9from httpx_oauth.exceptions import GetProfileError 1a
10from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError 1a
11from pydantic import UUID4 1a
13from polar.auth.models import Customer, is_anonymous, is_customer 1a
14from polar.config import settings 1a
15from polar.customer.repository import CustomerRepository 1a
16from polar.customer_session.service import customer_session as customer_session_service 1a
17from polar.exceptions import PolarError 1a
18from polar.integrations.github.client import Forbidden 1a
19from polar.kit import jwt 1a
20from polar.kit.http import ReturnTo, add_query_parameters, get_safe_return_url 1a
21from polar.logging import Logger 1a
22from polar.models.customer import CustomerOAuthAccount, CustomerOAuthPlatform 1a
23from polar.openapi import APITag 1a
24from polar.postgres import AsyncSession, get_db_session 1a
25from polar.routing import APIRouter 1a
27from .. import auth 1a
28from ..schemas.oauth_accounts import AuthorizeResponse 1a
# Private (non-public API) router for the customer portal's OAuth account linking.
router = APIRouter(
    prefix="/oauth-accounts",
    tags=["oauth-accounts", APITag.private],
)

# Module-level structured logger.
log: Logger = structlog.get_logger()
# One httpx-oauth client per supported customer OAuth platform, built once at
# import time from the application settings. GitHub uses the client's default
# scopes; Discord explicitly requests identify, email and guilds.join.
OAUTH_CLIENTS: dict[CustomerOAuthPlatform, BaseOAuth2[Any]] = {
    CustomerOAuthPlatform.github: GitHubOAuth2(
        client_id=settings.GITHUB_CLIENT_ID,
        client_secret=settings.GITHUB_CLIENT_SECRET,
    ),
    CustomerOAuthPlatform.discord: DiscordOAuth2(
        client_id=settings.DISCORD_CLIENT_ID,
        client_secret=settings.DISCORD_CLIENT_SECRET,
        scopes=["identify", "email", "guilds.join"],
    ),
}
class OAuthCallbackError(PolarError):
    """Error raised for a failed customer OAuth callback.

    Passes ``400`` as the second argument to ``PolarError`` (presumably the
    HTTP status code -- confirm against ``PolarError``). Not referenced by the
    handlers visible in this module section.
    """

    def __init__(self, message: str) -> None:
        status_code = 400
        super().__init__(message, status_code)
@router.get("/authorize", name="customer_portal.oauth_accounts.authorize")
async def authorize(
    request: Request,
    return_to: ReturnTo,
    auth_subject: auth.CustomerPortalWrite,
    platform: CustomerOAuthPlatform = Query(...),
    customer_id: UUID4 = Query(...),
    session: AsyncSession = Depends(get_db_session),
) -> AuthorizeResponse:
    """Return the provider URL that starts linking an OAuth account.

    The customer is taken from the authenticated ``auth_subject``. The
    ``customer_id`` query parameter is required but not read here, and
    ``session`` is likewise unused by this handler.

    The customer ID, platform and ``return_to`` are packed into a JWT of type
    ``customer_oauth`` signed with ``settings.SECRET``. That JWT is sent to the
    provider as the OAuth ``state``, and the provider redirects back to the
    ``customer_portal.oauth_accounts.callback`` route.
    """
    signed_state = jwt.encode(
        data=dict(
            customer_id=str(auth_subject.subject.id),
            platform=platform,
            return_to=return_to,
        ),
        secret=settings.SECRET,
        type="customer_oauth",
    )
    oauth_client = OAUTH_CLIENTS[platform]
    callback_url = str(request.url_for("customer_portal.oauth_accounts.callback"))
    url = await oauth_client.get_authorization_url(
        redirect_uri=callback_url,
        state=signed_state,
    )
    return AuthorizeResponse(url=url)
@router.get("/callback", name="customer_portal.oauth_accounts.callback")
async def callback(
    request: Request,
    auth_subject: auth.CustomerPortalOAuthAccount,
    state: str,
    code: str | None = None,
    error: str | None = None,
    session: AsyncSession = Depends(get_db_session),
) -> RedirectResponse:
    """Finish the provider OAuth flow and store the linked account on the customer.

    ``state`` must be a ``customer_oauth`` JWT signed with ``settings.SECRET``.
    It carries ``customer_id``, ``platform`` and ``return_to``.

    The customer comes from ``auth_subject`` when it is a customer. For an
    anonymous subject, the ``customer_id`` in the signed state is trusted, and a
    freshly created customer session token is appended to the redirect URL.

    Every non-exceptional outcome is a 303 redirect to the sanitized
    ``return_to`` URL. On failure, an ``error`` query parameter is added for
    these cases:

    - the provider reported an error or sent no code;
    - the token exchange failed;
    - the profile lookup failed.

    Raises:
        Forbidden: if ``state`` cannot be decoded, is missing fields, or has
            malformed fields, or if no customer can be resolved.
    """
    try:
        state_data = jwt.decode(
            token=state,
            secret=settings.SECRET,
            type="customer_oauth",
        )
    except jwt.DecodeError as e:
        raise Forbidden("Invalid state") from e

    # A missing or malformed field (e.g. ``uuid.UUID(None)`` or an unknown
    # platform) must surface as a 403, not as an unhandled exception / 500.
    try:
        customer_id = uuid.UUID(state_data["customer_id"])
        platform = CustomerOAuthPlatform(state_data["platform"])
        return_to = state_data["return_to"]
    except (AttributeError, KeyError, TypeError, ValueError) as e:
        raise Forbidden("Invalid state") from e

    customer_repository = CustomerRepository.from_session(session)
    customer: Customer | None = None
    if is_customer(auth_subject):
        customer = auth_subject.subject
    elif is_anonymous(auth_subject):
        # Trust the customer ID in the state for anonymous users: the state was
        # signature-checked by ``jwt.decode`` above.
        customer = await customer_repository.get_by_id(customer_id)

    if customer is None:
        raise Forbidden("Invalid customer")

    redirect_url = get_safe_return_url(return_to)
    # Anonymous users get a new customer session so they land back in the
    # portal signed in; this happens before the error checks below, so it
    # applies to the failure redirects too.
    if is_anonymous(auth_subject):
        token, _ = await customer_session_service.create_customer_session(
            session, customer
        )
        redirect_url = add_query_parameters(redirect_url, customer_session_token=token)

    if code is None or error is not None:
        redirect_url = add_query_parameters(
            redirect_url, error=error or "Failed to authorize."
        )
        return RedirectResponse(redirect_url, 303)

    client = OAUTH_CLIENTS[platform]
    try:
        oauth2_token_data = await client.get_access_token(
            code, str(request.url_for("customer_portal.oauth_accounts.callback"))
        )
    except GetAccessTokenError as e:
        redirect_url = add_query_parameters(
            redirect_url, error="Failed to get access token. Please try again later."
        )
        log.error("Failed to get access token", error=str(e))
        return RedirectResponse(redirect_url, 303)

    access_token = oauth2_token_data["access_token"]
    try:
        profile = await client.get_profile(access_token)
    except GetProfileError as e:
        redirect_url = add_query_parameters(
            redirect_url,
            error="Failed to get profile information. Please try again later.",
        )
        log.error("Failed to get profile", error=str(e))
        return RedirectResponse(redirect_url, 303)

    oauth_account = CustomerOAuthAccount(
        access_token=access_token,
        # Not every provider/app issues expiring or refreshable tokens, so
        # these keys may be absent from the token response. Indexing them
        # directly raised KeyError (-> 500).
        # NOTE(review): assumes CustomerOAuthAccount accepts None for both
        # fields -- confirm against the model.
        expires_at=oauth2_token_data.get("expires_at"),
        refresh_token=oauth2_token_data.get("refresh_token"),
        account_id=platform.get_account_id(profile),
        account_username=platform.get_account_username(profile),
    )

    customer.set_oauth_account(oauth_account, platform)
    await customer_repository.update(customer)

    # 303 like every other outcome of this handler (previously Starlette's
    # default 307).
    return RedirectResponse(redirect_url, 303)