Coverage for polar/oauth2/grants/authorization_code.py: 30%
154 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import typing 1a
2import uuid 1a
4from authlib.oauth2.rfc6749.errors import ( 1a
5 AccessDeniedError,
6 InvalidRequestError,
7 OAuth2Error,
8)
9from authlib.oauth2.rfc6749.grants import ( 1a
10 AuthorizationCodeGrant as _AuthorizationCodeGrant,
11)
12from authlib.oauth2.rfc6749.requests import OAuth2Request 1a
13from authlib.oauth2.rfc7636 import CodeChallenge as _CodeChallenge 1a
14from authlib.oidc.core.errors import ConsentRequiredError, LoginRequiredError 1a
15from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode 1a
16from authlib.oidc.core.grants import OpenIDToken as _OpenIDToken 1a
17from sqlalchemy import and_, select 1a
18from sqlalchemy.orm import Session 1a
20from polar.config import settings 1a
21from polar.kit.crypto import generate_token, get_token_hash 1a
22from polar.models import ( 1a
23 OAuth2AuthorizationCode,
24 OAuth2Client,
25 Organization,
26 User,
27 UserOrganization,
28)
30from ..constants import AUTHORIZATION_CODE_PREFIX, JWT_CONFIG 1a
31from ..requests import StarletteOAuth2Request 1a
32from ..service.oauth2_grant import oauth2_grant as oauth2_grant_service 1a
33from ..sub_type import SubType, SubTypeValue 1a
34from ..userinfo import UserInfo, generate_user_info 1a
36if typing.TYPE_CHECKING: 36 ↛ 37line 36 didn't jump to line 37 because the condition on line 36 was never true1a
37 from ..authorization_server import AuthorizationServer
def _exists_nonce(
    session: Session, nonce: str, request: StarletteOAuth2Request
) -> bool:
    """
    Check whether an authorization code of the request's client uses ``nonce``.

    :param session: Database session used to run the lookup.
    :param nonce: The ``nonce`` value taken from the authorization request.
    :param request: Current OAuth2 request; its ``client_id`` scopes the lookup.
    :return: ``True`` if at least one stored authorization code for the client
        has this nonce, ``False`` otherwise.
    """
    statement = (
        select(OAuth2AuthorizationCode)
        .where(
            OAuth2AuthorizationCode.client_id == request.client_id,
            OAuth2AuthorizationCode.nonce == nonce,
        )
        # Only existence matters. Limiting to one row keeps
        # `scalar_one_or_none()` from raising `MultipleResultsFound` if
        # duplicate nonces were ever persisted (e.g. by concurrent requests).
        .limit(1)
    )
    result = session.execute(statement)
    return result.unique().scalar_one_or_none() is not None
class SubTypeGrantMixin:
    """
    Mixin carrying the subject an authorization grant is issued for.

    Both attributes default to ``None`` at class level and are assigned on the
    grant instance once the subject has been resolved.
    """

    # Which kind of subject the grant targets (a ``SubType`` member).
    sub_type: SubType | None = None
    # The resolved subject object itself: a user or an organization.
    sub: User | Organization | None = None
class AuthorizationCodeGrant(SubTypeGrantMixin, _AuthorizationCodeGrant):
    """
    Authorization code grant whose codes are bound to a resolved subject.

    Codes are persisted only as a hash (computed with ``settings.SECRET``), and
    the ``sub_type``/``sub`` resolved on the grant is attached to each stored
    code.
    """

    server: "AuthorizationServer"
    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]

    def __init__(self, request: OAuth2Request, server: "AuthorizationServer") -> None:
        super().__init__(request, server)
        hooks = self._hooks
        # Start with an empty set of pre-response hooks; extensions add theirs.
        hooks["before_create_authorization_response"] = set()
        # Install the scope-defaulting hook as the only payload pre-validator.
        hooks["before_validate_authorization_request_payload"] = {
            self.before_validate_authorization_request_payload
        }

    def before_validate_authorization_request_payload(
        self, grant: "typing.Self", redirect_uri: str
    ) -> None:
        """
        Fall back to the client's scope when the request carries none.

        Only a missing (or ``None``) ``scope`` triggers the fallback; any other
        value, including an empty string, is left untouched.
        """
        requested = self.request.payload.data.get("scope")
        if requested is not None:
            return
        self.request.payload.data["scope"] = self.request.client.scope

    def create_authorization_response(
        self, redirect_uri: str, grant_user: User | None
    ) -> tuple[int, str | dict[str, typing.Any], list[tuple[str, str]]]:
        """
        Build the authorization response for the user who approved the request.

        Runs the ``before_create_authorization_response`` hooks with
        ``(redirect_uri, grant_user)`` before delegating to the base grant.

        :raises AccessDeniedError: If ``grant_user`` is falsy (no approval).
        """
        request_payload = self.request.payload
        assert request_payload is not None

        if not grant_user:
            raise AccessDeniedError(
                state=request_payload.state, redirect_uri=redirect_uri
            )

        self.request.user = grant_user  # pyright: ignore
        self.execute_hook(
            "before_create_authorization_response", redirect_uri, grant_user
        )
        return super().create_authorization_response(redirect_uri, grant_user)  # pyright: ignore

    def generate_authorization_code(self) -> str:
        """Return a fresh token carrying the authorization-code prefix."""
        return generate_token(prefix=AUTHORIZATION_CODE_PREFIX)

    def save_authorization_code(
        self, code: str, request: StarletteOAuth2Request
    ) -> OAuth2AuthorizationCode:
        """
        Persist ``code`` (hashed) for the subject resolved on this grant.

        ``nonce``, ``code_challenge`` and ``code_challenge_method`` are copied
        from the request payload when present. The session is flushed but not
        committed here.
        """
        request_payload = request.payload
        assert request_payload is not None

        extra = request_payload.data
        nonce, challenge, challenge_method = (
            extra.get("nonce"),
            extra.get("code_challenge"),
            extra.get("code_challenge_method"),
        )

        assert self.sub_type is not None
        assert self.sub is not None

        record = OAuth2AuthorizationCode(
            code=get_token_hash(code, secret=settings.SECRET),
            client_id=request_payload.client_id,
            sub_type=self.sub_type,
            scope=request_payload.scope,
            redirect_uri=request_payload.redirect_uri,
            nonce=nonce,
            code_challenge=challenge,
            code_challenge_method=challenge_method,
        )
        record.sub = self.sub

        db = self.server.session
        db.add(record)
        db.flush()
        return record

    def query_authorization_code(
        self, code: str, client: OAuth2Client
    ) -> OAuth2AuthorizationCode | None:
        """
        Find the stored, non-expired authorization code issued to ``client``.

        :return: The matching code, or ``None`` when it is unknown or expired.
        """
        hashed = get_token_hash(code, secret=settings.SECRET)
        lookup = select(OAuth2AuthorizationCode).where(
            OAuth2AuthorizationCode.code == hashed,
            OAuth2AuthorizationCode.client_id == client.client_id,
        )
        found = self.server.session.execute(lookup).unique().scalar_one_or_none()
        if found is None or typing.cast(bool, found.is_expired()):
            return None
        return found

    def delete_authorization_code(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> None:
        """Remove ``authorization_code`` from the database and flush."""
        db = self.server.session
        db.delete(authorization_code)
        db.flush()

    def authenticate_user(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> SubTypeValue | None:
        """Return the subject value stored on ``authorization_code``."""
        return authorization_code.get_sub_type_value()
class CodeChallenge(_CodeChallenge):
    """Project alias of authlib's RFC 7636 ``CodeChallenge`` extension, unchanged."""
class OpenIDCode(_OpenIDCode):
    """
    OpenID Connect extension for the authorization code grant.

    Nonce lookups are performed with the database session supplied at
    construction time.
    """

    def __init__(self, session: Session, require_nonce: bool = False):
        super().__init__(require_nonce)
        self._session = session

    def exists_nonce(self, nonce: str, request: StarletteOAuth2Request) -> bool:
        """Tell whether the request's client already has a code with ``nonce``."""
        return _exists_nonce(self._session, nonce, request)

    def get_jwt_config(self, grant: AuthorizationCodeGrant) -> dict[str, typing.Any]:
        """Hand back the module-wide ``JWT_CONFIG``, whatever the grant."""
        return JWT_CONFIG

    def generate_user_info(self, user: SubTypeValue, scope: str) -> UserInfo:
        """Delegate to the shared ``generate_user_info`` helper."""
        return generate_user_info(user, scope)
class OpenIDToken(_OpenIDToken):
    """OpenID Connect token extension wired to the project's JWT settings."""

    def get_jwt_config(self, grant: AuthorizationCodeGrant) -> dict[str, typing.Any]:
        """Hand back the module-wide ``JWT_CONFIG``, whatever the grant."""
        return JWT_CONFIG

    def generate_user_info(self, user: SubTypeValue, scope: str) -> UserInfo:
        """Delegate to the shared ``generate_user_info`` helper."""
        return generate_user_info(user, scope)
class InvalidSubError(OAuth2Error):
    """OAuth2 error reported when the requested ``sub`` cannot be resolved."""

    error = "invalid_sub"
class ValidateSubAndPrompt:
    """
    Grant extension that resolves the authorization subject and applies ``prompt``.

    When called with an :class:`AuthorizationCodeGrant` it registers two hooks:

    * ``after_validate_consent_request``: resolve ``grant.sub_type`` and
      ``grant.sub`` from the request, then apply ``prompt`` semantics depending
      on whether the subject already granted the requested scope to the client.
    * ``before_create_authorization_response``: resolve the subject again and
      reject the request if none could be resolved.
    """

    def __init__(self, session: Session) -> None:
        self._session = session

    def __call__(self, grant: AuthorizationCodeGrant) -> None:
        grant.register_hook("after_validate_consent_request", self._validate)
        grant.register_hook(
            "before_create_authorization_response", self._validate_resolved_sub
        )

    def _validate(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """Resolve the subject, then check ``prompt`` against prior consent."""
        self._validate_sub(grant, redirect_uri, redirect_fragment)
        self._validate_scope_consent(grant, redirect_uri, redirect_fragment)

    def _validate_sub(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """
        Resolve ``grant.sub_type`` and ``grant.sub`` from the request payload.

        ``sub_type`` is read from the payload, falling back to the client's
        ``default_sub_type``.

        * For the user sub type, the subject is the request user, and an
          explicit ``sub`` is rejected.
        * For the organization sub type, the subject is resolved only when both
          ``sub`` and a request user are present. Otherwise ``grant.sub`` is
          left untouched. ``sub`` must be the UUID of an organization the user
          is an active member of.

        ``redirect_uri`` and ``redirect_fragment`` are accepted for hook
        signature compatibility and are not used here.

        :raises InvalidRequestError: On an unknown ``sub_type``, or when ``sub``
            is supplied for the user sub type.
        :raises InvalidSubError: When ``sub`` is not a valid UUID or does not
            match an organization of the user.
        """
        payload = grant.request.payload
        assert payload is not None

        sub_type: str | None = payload.data.get("sub_type")
        if sub_type:
            try:
                grant.sub_type = SubType(sub_type)
            except ValueError as e:
                raise InvalidRequestError("Invalid sub_type") from e
        else:
            client: OAuth2Client = typing.cast(OAuth2Client, grant.client)
            grant.sub_type = client.default_sub_type

        sub: str | None = payload.data.get("sub")
        user = grant.request.user

        if grant.sub_type == SubType.user:
            grant.sub = user
            if sub is not None:
                raise InvalidRequestError("Can't specify sub for user sub_type")
        elif (
            grant.sub_type == SubType.organization
            and sub is not None
            and user is not None
        ):
            try:
                sub_uuid = uuid.UUID(sub)
            except ValueError as e:
                raise InvalidSubError() from e
            organization = self._get_organization_admin(sub_uuid, user)
            if organization is None:
                raise InvalidSubError()
            grant.sub = organization

    def _validate_scope_consent(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """
        Apply ``prompt`` based on whether the subject already consented.

        * ``prompt == "none"`` requires a resolved subject that has already
          granted the requested scope (or a superset) to the client.
        * With no ``prompt``, an existing grant sets ``grant.prompt = "none"``
          so the consent step is skipped.

        :raises LoginRequiredError: ``prompt == "none"`` without a subject.
        :raises ConsentRequiredError: ``prompt == "none"`` without prior consent.
        """
        payload = grant.request.payload
        assert payload is not None

        prompt = payload.data.get("prompt")

        # Check if the sub has granted the requested scope or a subset of it
        has_granted_scope = False
        if grant.sub is not None:
            assert grant.client is not None
            assert grant.sub_type is not None
            has_granted_scope = oauth2_grant_service.has_granted_scope(
                self._session,
                sub_type=grant.sub_type,
                sub_id=grant.sub.id,
                client_id=grant.client.client_id,
                scope=payload.scope,
            )

        # If the prompt is "none", the sub must be authenticated and have granted the requested scope
        if prompt == "none":
            if grant.sub is None:
                raise LoginRequiredError(
                    redirect_uri=redirect_uri, redirect_fragment=redirect_fragment
                )
            if not has_granted_scope:
                raise ConsentRequiredError(
                    redirect_uri=redirect_uri, redirect_fragment=redirect_fragment
                )

        # Bypass everything if nothing is specified and conditions are met
        if prompt is None and has_granted_scope:
            grant.prompt = "none"

    def _validate_resolved_sub(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """
        Re-resolve the subject right before the authorization response is built.

        :raises InvalidSubError: If no subject could be resolved.
        """
        # NOTE(review): `create_authorization_response` fires this hook with
        # `(redirect_uri, grant_user)`, so the third positional argument is
        # presumably the granting user rather than a bool. It is therefore not
        # forwarded. `_validate_sub` ignores it anyway.
        self._validate_sub(grant, redirect_uri)
        if grant.sub is None:
            raise InvalidSubError()

    def _get_organization_admin(
        self, organization_id: uuid.UUID, user: User
    ) -> Organization | None:
        """
        Return organization ``organization_id`` if ``user`` is an active member.

        Memberships with ``deleted_at`` set are ignored.

        NOTE(review): despite the name, only membership is checked, not an
        admin role. Confirm this is intended.
        """
        statement = (
            select(Organization)
            .join(
                UserOrganization,
                onclause=and_(
                    # Bind the membership row to the organization being
                    # selected. Without this the join degenerates into a cross
                    # join, and any user with *some* active membership would
                    # resolve *any* organization ID.
                    UserOrganization.organization_id == Organization.id,
                    UserOrganization.user_id == user.id,
                    UserOrganization.deleted_at.is_(None),
                ),
            )
            .where(Organization.id == organization_id)
        )
        result = self._session.execute(statement)
        return result.unique().scalar_one_or_none()