Coverage for polar/customer_portal/endpoints/oauth_accounts.py: 38%

80 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import uuid 1a

2from typing import Any 1a

3 

4import structlog 1a

5from fastapi import Depends, Query, Request 1a

6from fastapi.responses import RedirectResponse 1a

7from httpx_oauth.clients.discord import DiscordOAuth2 1a

8from httpx_oauth.clients.github import GitHubOAuth2 1a

9from httpx_oauth.exceptions import GetProfileError 1a

10from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError 1a

11from pydantic import UUID4 1a

12 

13from polar.auth.models import Customer, is_anonymous, is_customer 1a

14from polar.config import settings 1a

15from polar.customer.repository import CustomerRepository 1a

16from polar.customer_session.service import customer_session as customer_session_service 1a

17from polar.exceptions import PolarError 1a

18from polar.integrations.github.client import Forbidden 1a

19from polar.kit import jwt 1a

20from polar.kit.http import ReturnTo, add_query_parameters, get_safe_return_url 1a

21from polar.logging import Logger 1a

22from polar.models.customer import CustomerOAuthAccount, CustomerOAuthPlatform 1a

23from polar.openapi import APITag 1a

24from polar.postgres import AsyncSession, get_db_session 1a

25from polar.routing import APIRouter 1a

26 

27from .. import auth 1a

28from ..schemas.oauth_accounts import AuthorizeResponse 1a

29 

router = APIRouter(
    prefix="/oauth-accounts",
    tags=["oauth-accounts", APITag.private],
)

log: Logger = structlog.get_logger()


# OAuth2 provider clients used to link a customer's external account, keyed by
# platform. Client credentials come from application settings.
OAUTH_CLIENTS: dict[CustomerOAuthPlatform, BaseOAuth2[Any]] = {
    CustomerOAuthPlatform.github: GitHubOAuth2(
        client_id=settings.GITHUB_CLIENT_ID,
        client_secret=settings.GITHUB_CLIENT_SECRET,
    ),
    CustomerOAuthPlatform.discord: DiscordOAuth2(
        client_id=settings.DISCORD_CLIENT_ID,
        client_secret=settings.DISCORD_CLIENT_SECRET,
        # NOTE(review): "guilds.join" presumably allows adding the user to a
        # Discord server later on — confirm with the Discord benefit code.
        scopes=["identify", "email", "guilds.join"],
    ),
}

45 

46 

class OAuthCallbackError(PolarError):
    """Error for a failed customer OAuth callback, built with status 400."""

    def __init__(self, message: str) -> None:
        # Second PolarError argument; presumably the HTTP status code (Bad Request).
        bad_request = 400
        super().__init__(message, bad_request)

50 

51 

@router.get("/authorize", name="customer_portal.oauth_accounts.authorize")
async def authorize(
    request: Request,
    return_to: ReturnTo,
    auth_subject: auth.CustomerPortalWrite,
    platform: CustomerOAuthPlatform = Query(...),
    customer_id: UUID4 = Query(...),
    session: AsyncSession = Depends(get_db_session),
) -> AuthorizeResponse:
    """Return the provider authorization URL for linking an OAuth account.

    The authenticated customer's ID, the requested ``platform`` and the
    ``return_to`` URL are packed into a JWT signed with ``settings.SECRET``
    (type ``customer_oauth``) and sent as the OAuth ``state``. The ``callback``
    endpoint is given as the redirect URI.

    Note: the customer written into the state is always the authenticated
    subject; the ``customer_id`` query parameter is accepted but not read here.
    """
    signed_state = jwt.encode(
        data={
            "customer_id": str(auth_subject.subject.id),
            "platform": platform,
            "return_to": return_to,
        },
        secret=settings.SECRET,
        type="customer_oauth",
    )
    oauth_client = OAUTH_CLIENTS[platform]
    callback_url = request.url_for("customer_portal.oauth_accounts.callback")
    url = await oauth_client.get_authorization_url(
        redirect_uri=str(callback_url),
        state=signed_state,
    )
    return AuthorizeResponse(url=url)

77 

78 

@router.get("/callback", name="customer_portal.oauth_accounts.callback")
async def callback(
    request: Request,
    auth_subject: auth.CustomerPortalOAuthAccount,
    state: str,
    code: str | None = None,
    error: str | None = None,
    session: AsyncSession = Depends(get_db_session),
) -> RedirectResponse:
    """Finish the OAuth flow started by ``authorize`` and link the account.

    The signed ``state`` is decoded to recover the customer ID, the platform
    and the ``return_to`` URL. The provider ``code`` is exchanged for a token,
    the account profile is fetched, and the resulting ``CustomerOAuthAccount``
    is stored on the customer.

    Every handled outcome ends in a 303 redirect to the sanitized
    ``return_to`` URL. Failures (provider error, missing code, token or
    profile errors) add an ``error`` query parameter. Anonymous callers also
    get a freshly created ``customer_session_token`` query parameter.

    Raises:
        Forbidden: if ``state`` cannot be decoded, if no customer can be
            resolved, or if the authenticated customer is not the one the
            flow was started for.
    """
    try:
        state_data = jwt.decode(
            token=state,
            secret=settings.SECRET,
            type="customer_oauth",
        )
    except jwt.DecodeError as e:
        raise Forbidden("Invalid state") from e

    customer_repository = CustomerRepository.from_session(session)
    customer_id = uuid.UUID(state_data.get("customer_id"))
    customer: Customer | None = None
    if is_customer(auth_subject):
        # The flow must be completed by the customer that started it. Otherwise
        # a crafted callback URL (another user's code + state) could attach a
        # foreign OAuth account to the currently authenticated customer.
        if auth_subject.subject.id == customer_id:
            customer = auth_subject.subject
    elif is_anonymous(auth_subject):
        # Trust the customer ID in the (signed) state for anonymous users
        customer = await customer_repository.get_by_id(customer_id)

    if customer is None:
        raise Forbidden("Invalid customer")

    return_to = state_data["return_to"]
    platform = CustomerOAuthPlatform(state_data["platform"])

    redirect_url = get_safe_return_url(return_to)
    # If not authenticated, create a new customer session; we trust the
    # customer ID in the state. The token is returned even on failure below.
    if is_anonymous(auth_subject):
        token, _ = await customer_session_service.create_customer_session(
            session, customer
        )
        redirect_url = add_query_parameters(redirect_url, customer_session_token=token)

    if code is None or error is not None:
        redirect_url = add_query_parameters(
            redirect_url, error=error or "Failed to authorize."
        )
        return RedirectResponse(redirect_url, 303)

    client = OAUTH_CLIENTS[platform]
    try:
        oauth2_token_data = await client.get_access_token(
            code, str(request.url_for("customer_portal.oauth_accounts.callback"))
        )
    except GetAccessTokenError as e:
        redirect_url = add_query_parameters(
            redirect_url, error="Failed to get access token. Please try again later."
        )
        log.error("Failed to get access token", error=str(e))
        return RedirectResponse(redirect_url, 303)

    try:
        profile = await client.get_profile(oauth2_token_data["access_token"])
    except GetProfileError as e:
        redirect_url = add_query_parameters(
            redirect_url,
            error="Failed to get profile information. Please try again later.",
        )
        log.error("Failed to get account ID", error=str(e))
        return RedirectResponse(redirect_url, 303)

    oauth_account = CustomerOAuthAccount(
        access_token=oauth2_token_data["access_token"],
        # Expiry and refresh token are absent for non-expiring tokens (e.g. a
        # GitHub app with token expiration disabled). Indexing would raise a
        # KeyError (HTTP 500).
        # NOTE(review): assumes these CustomerOAuthAccount fields accept None.
        expires_at=oauth2_token_data.get("expires_at"),
        refresh_token=oauth2_token_data.get("refresh_token"),
        account_id=platform.get_account_id(profile),
        account_username=platform.get_account_username(profile),
    )

    customer.set_oauth_account(oauth_account, platform)
    await customer_repository.update(customer)

    # 303 like every other exit of this endpoint (the default would be 307).
    return RedirectResponse(redirect_url, 303)