Coverage for polar/oauth2/authorization_server.py: 27%
261 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import json 1a
2import secrets 1a
3import time 1a
4import typing 1a
6import structlog 1a
7from authlib.oauth2 import AuthorizationServer as _AuthorizationServer 1a
8from authlib.oauth2 import OAuth2Error 1a
9from authlib.oauth2.rfc6749.errors import ( 1a
10 UnsupportedResponseTypeError,
11)
12from authlib.oauth2.rfc6750 import BearerTokenGenerator 1a
13from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint 1a
14from authlib.oauth2.rfc7591 import ( 1a
15 ClientRegistrationEndpoint as _ClientRegistrationEndpoint,
16)
17from authlib.oauth2.rfc7592 import ( 1a
18 ClientConfigurationEndpoint as _ClientConfigurationEndpoint,
19)
20from authlib.oauth2.rfc7662 import IntrospectionEndpoint as _IntrospectionEndpoint 1a
21from sqlalchemy import or_, select 1a
22from sqlalchemy.orm import Session 1a
23from starlette.requests import Request 1a
24from starlette.responses import Response 1a
26from polar.config import settings 1a
27from polar.kit.crypto import generate_token, get_token_hash 1a
28from polar.logging import Logger 1a
29from polar.models import OAuth2Client, OAuth2Token, User 1a
30from polar.oauth2.sub_type import SubTypeValue 1a
32from .constants import ( 1a
33 ACCESS_TOKEN_PREFIX,
34 CLIENT_ID_PREFIX,
35 CLIENT_REGISTRATION_TOKEN_PREFIX,
36 CLIENT_SECRET_PREFIX,
37 ISSUER,
38 REFRESH_TOKEN_PREFIX,
39)
40from .grants import AuthorizationCodeGrant, CodeChallenge, register_grants 1a
41from .metadata import get_server_metadata 1a
42from .requests import StarletteJsonRequest, StarletteOAuth2Request 1a
43from .service.oauth2_grant import oauth2_grant as oauth2_grant_service 1a
# Module-level structured logger (used by AuthorizationServer.send_signal).
logger: Logger = structlog.get_logger(__name__)


def _get_server_metadata(server: "AuthorizationServer") -> dict[str, typing.Any]:
    """
    Build the server metadata dictionary for ``server``.

    The URL resolver handed to ``get_server_metadata`` is an identity
    function, so route names are passed through as-is rather than being
    turned into absolute URLs. Only fields explicitly set on the resulting
    metadata model are included in the returned dict.
    """

    def _route_name_as_url(route_name: str) -> str:
        # Identity resolver: no request is available to build real URLs.
        return route_name

    metadata = get_server_metadata(server, _route_name_as_url)
    return metadata.model_dump(exclude_unset=True)
class ClientRegistrationEndpoint(_ClientRegistrationEndpoint):
    """
    Dynamic client registration endpoint persisting clients via the
    authorization server's SQLAlchemy session.
    """

    server: "AuthorizationServer"

    def generate_client_registration_info(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> dict[str, str]:
        """
        Return the client management URI and the registration access token
        to include in the registration response.
        """
        assert client.registration_access_token is not None
        management_uri = request.url_for(
            "oauth2:get_client", client_id=client.client_id
        )
        return {
            "registration_client_uri": str(management_uri),
            "registration_access_token": client.registration_access_token,
        }

    def generate_client_id(self, request: StarletteJsonRequest) -> str:
        """Generate a fresh client identifier carrying ``CLIENT_ID_PREFIX``."""
        return generate_token(prefix=CLIENT_ID_PREFIX)

    def generate_client_secret(self, request: StarletteJsonRequest) -> str:
        """Generate a fresh client secret carrying ``CLIENT_SECRET_PREFIX``."""
        return generate_token(prefix=CLIENT_SECRET_PREFIX)

    def create_registration_response(
        self, request: StarletteJsonRequest
    ) -> tuple[int, dict[str, typing.Any], list[tuple[str, str]]]:
        """
        Create client registration response.

        Temporary workaround: for public clients
        (``token_endpoint_auth_method == "none"``) the ``client_secret`` and
        ``client_secret_expires_at`` fields are stripped from the body, which
        helps clients that don't yet handle public clients properly.
        """
        status, body, headers = super().create_registration_response(request)

        is_public_client = (
            isinstance(body, dict)
            and body.get("token_endpoint_auth_method") == "none"
        )
        if is_public_client:
            for secret_field in ("client_secret", "client_secret_expires_at"):
                body.pop(secret_field, None)

        return status, body, headers

    def get_server_metadata(self) -> dict[str, typing.Any]:
        """Return the authorization server metadata for this endpoint."""
        return _get_server_metadata(self.server)

    def authenticate_token(self, request: StarletteJsonRequest) -> User | str:
        """
        Identify the registering party: the authenticated user when present,
        otherwise the marker string ``"dynamic_client"``.
        """
        current_user = request.user
        if current_user is None:
            return "dynamic_client"
        return current_user

    def save_client(
        self,
        client_info: dict[str, typing.Any],
        client_metadata: dict[str, typing.Any],
        request: StarletteJsonRequest,
    ) -> OAuth2Client:
        """
        Create and flush a new ``OAuth2Client``.

        The client is linked to the requesting user when one is
        authenticated, and always receives a newly generated registration
        access token.
        """
        new_client = OAuth2Client(**client_info)
        new_client.set_client_metadata(client_metadata)

        owner = request.user
        if owner is not None:
            new_client.user_id = owner.id
        new_client.registration_access_token = generate_token(
            prefix=CLIENT_REGISTRATION_TOKEN_PREFIX
        )

        session = self.server.session
        session.add(new_client)
        session.flush()
        return new_client
class ClientConfigurationEndpoint(_ClientConfigurationEndpoint):
    """
    Client configuration endpoint (read / update / delete a registered client).

    A request is authorized either by the authenticated user who owns the
    client, or by the client's registration access token sent as a Bearer
    token in the ``Authorization`` header.
    """

    server: "AuthorizationServer"

    def generate_client_registration_info(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> dict[str, str]:
        """Return the client management URI and registration access token."""
        return {
            "registration_client_uri": str(
                request.url_for("oauth2:get_client", client_id=client.client_id)
            ),
            "registration_access_token": client.registration_access_token,
        }

    def create_read_client_response(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> tuple[int, dict[str, typing.Any], list[tuple[str, str]]]:
        """
        Create client read response (GET endpoint).

        Temporary workaround: Exclude client_secret and client_secret_expires_at
        from the response when token_endpoint_auth_method is 'none', as this
        helps clients that haven't yet updated to properly handle public clients.
        """
        status, body, headers = super().create_read_client_response(client, request)

        # Check if this is a public client (token_endpoint_auth_method = none)
        if isinstance(body, dict):
            token_endpoint_auth_method = body.get("token_endpoint_auth_method")
            if token_endpoint_auth_method == "none":
                # Remove client_secret fields for public clients as a temporary workaround
                body.pop("client_secret", None)
                body.pop("client_secret_expires_at", None)

        return status, body, headers

    def authenticate_token(self, request: StarletteJsonRequest) -> User | str | None:
        """
        Extract the request credential.

        Returns the authenticated user if there is one. Otherwise returns the
        raw token from an ``Authorization: Bearer <token>`` header (scheme
        matched case-insensitively). Returns ``None`` when neither is present.
        """
        if request.user is not None:
            return request.user

        authorization = request.headers.get("Authorization")
        if authorization is None:
            return None

        scheme, _, token = authorization.partition(" ")
        if scheme.lower() == "bearer" and token != "":
            return token

        return None

    @staticmethod
    def _registration_token_matches(expected: str | None, provided: str) -> bool:
        """
        Compare a registration access token in constant time.

        Returns False when the client has no registration token at all.
        Both values are UTF-8 encoded first: ``secrets.compare_digest``
        raises ``TypeError`` for ``str`` arguments containing non-ASCII
        characters, and ``provided`` comes straight from a request header.
        """
        if expected is None:
            return False
        return secrets.compare_digest(
            expected.encode("utf-8"), provided.encode("utf-8")
        )

    def authenticate_client(self, request: StarletteJsonRequest) -> OAuth2Client | None:
        """
        Load the non-deleted client named by the ``client_id`` path parameter
        and check that the request credential is allowed to manage it.

        A ``User`` credential must own the client. A ``str`` credential must
        equal the client's registration access token. Any other credential
        (including ``None``) is rejected. Returns ``None`` on failure.
        """
        client_id = request.path_params.get("client_id")
        if client_id is None:
            return None

        statement = select(OAuth2Client).where(
            OAuth2Client.deleted_at.is_(None), OAuth2Client.client_id == client_id
        )
        result = self.server.session.execute(statement)
        client = result.unique().scalar_one_or_none()

        if client is None:
            return None

        credential = request.credential
        if isinstance(credential, User):
            authorized = client.user_id == credential.id
        elif isinstance(credential, str):
            authorized = self._registration_token_matches(
                client.registration_access_token, credential
            )
        else:
            # Unknown or missing credential: fail closed.
            authorized = False

        return client if authorized else None

    def revoke_access_token(
        self, token: typing.Any, request: StarletteJsonRequest
    ) -> None:
        """No-op: registration access tokens are not revoked by this endpoint."""
        return None

    def check_permission(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> bool:
        """Always allow; access control is enforced in ``authenticate_client``."""
        return True

    def delete_client(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> None:
        """
        Mark the client as deleted and flush.

        Lookups in this module filter on ``deleted_at IS NULL``, so the client
        is no longer found afterwards.
        """
        client.set_deleted_at()
        self.server.session.flush()

    def update_client(
        self,
        client: OAuth2Client,
        client_metadata: dict[str, typing.Any],
        request: StarletteJsonRequest,
    ) -> OAuth2Client:
        """Merge ``client_metadata`` over the existing metadata and flush."""
        client.set_client_metadata({**client.client_metadata, **client_metadata})
        self.server.session.add(client)
        self.server.session.flush()
        return client

    def get_server_metadata(self) -> dict[str, typing.Any]:
        """Return the authorization server metadata for this endpoint."""
        return _get_server_metadata(self.server)
class _QueryTokenMixin:
    """Token lookup shared by the revocation and introspection endpoints."""

    server: "AuthorizationServer"

    def query_token(
        self,
        token_string: str,
        token_type_hint: typing.Literal["access_token", "refresh_token"] | None,
    ) -> OAuth2Token | None:
        """
        Return the token whose stored hash matches ``token_string``, if any.

        Tokens are persisted as keyed hashes (see ``save_token``), so the raw
        value is hashed with ``settings.SECRET`` before lookup. A hint of
        ``"access_token"`` or ``"refresh_token"`` restricts the search to
        that column; without a hint both columns are searched.
        """
        hashed = get_token_hash(token_string, secret=settings.SECRET)
        matches_access = OAuth2Token.access_token == hashed
        matches_refresh = OAuth2Token.refresh_token == hashed

        if token_type_hint == "access_token":
            criterion = matches_access
        elif token_type_hint == "refresh_token":
            criterion = matches_refresh
        else:
            criterion = or_(matches_access, matches_refresh)

        query = select(OAuth2Token).where(criterion)
        return self.server.session.execute(query).unique().scalar_one_or_none()
class RevocationEndpoint(_QueryTokenMixin, _RevocationEndpoint):
    """Token revocation endpoint (client secret basic/post authentication)."""

    CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]

    def revoke_token(self, token: OAuth2Token, request: StarletteOAuth2Request) -> None:
        """
        Stamp ``token`` as revoked at the current Unix time and flush.

        The access token is always revoked. The refresh token is revoked as
        well unless the form's ``token_type_hint`` is exactly
        ``"access_token"``.
        """
        revoked_at = int(time.time())
        also_revoke_refresh = request.form.get("token_type_hint") != "access_token"

        token.access_token_revoked_at = revoked_at  # pyright: ignore
        if also_revoke_refresh:
            token.refresh_token_revoked_at = revoked_at  # pyright: ignore

        session = self.server.session
        session.add(token)
        session.flush()
class IntrospectionEndpoint(_QueryTokenMixin, _IntrospectionEndpoint):
    """Token introspection endpoint (client secret basic/post authentication)."""

    CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]

    def check_permission(
        self, token: OAuth2Token, client: OAuth2Client, request: StarletteOAuth2Request
    ) -> bool:
        """
        Decide whether ``client`` may introspect ``token``.

        Delegates to ``token.check_client`` (presumably: the token belongs to
        ``client`` — confirm against the OAuth2Token model).
        """
        is_allowed = token.check_client(client)  # pyright: ignore
        return is_allowed

    def introspect_token(self, token: OAuth2Token) -> dict[str, typing.Any]:
        """Return the token's introspection payload, stamped with ``ISSUER``."""
        introspection = token.get_introspection_data(ISSUER)
        return introspection
class AuthorizationServer(_AuthorizationServer):
    """
    Authlib authorization server backed by a synchronous SQLAlchemy session.

    Clients and tokens are read from and written to ``session``. Use
    :meth:`build` to get an instance with the revocation, introspection,
    client registration and client configuration endpoints, plus the grants
    from ``register_grants``, already registered.

    The ``*_supported`` properties feed the server metadata document. Their
    list values are de-duplicated and sorted so the document is identical
    across processes (plain ``set`` iteration order of strings is not).
    """

    if typing.TYPE_CHECKING:

        def create_endpoint_response(
            self, name: str, request: Request | None = None
        ) -> Response: ...

    def __init__(
        self,
        session: Session,
        *,
        scopes_supported: list[str] | None = None,
        error_uris: list[tuple[str, str]] | None = None,
    ) -> None:
        """
        :param session: Session used for all client/token queries and writes.
        :param scopes_supported: Forwarded to authlib's AuthorizationServer.
        :param error_uris: Optional ``(error_code, uri)`` pairs; see
            :meth:`get_error_uri`.
        """
        super().__init__(scopes_supported)
        self.session = session
        self._error_uris = dict(error_uris) if error_uris is not None else None

        self.register_token_generator("default", self.create_bearer_token_generator())

    @classmethod
    def build(
        cls,
        session: Session,
        *,
        scopes_supported: list[str] | None = None,
        error_uris: list[tuple[str, str]] | None = None,
    ) -> typing.Self:
        """Create a server with all endpoints and grants registered."""
        authorization_server = cls(
            session, scopes_supported=scopes_supported, error_uris=error_uris
        )
        authorization_server.register_endpoint(RevocationEndpoint)
        authorization_server.register_endpoint(IntrospectionEndpoint)
        authorization_server.register_endpoint(ClientRegistrationEndpoint)
        authorization_server.register_endpoint(ClientConfigurationEndpoint)
        register_grants(authorization_server)
        return authorization_server

    def query_client(self, client_id: str) -> OAuth2Client | None:
        """Return the non-deleted client with ``client_id``, or ``None``."""
        statement = select(OAuth2Client).where(
            OAuth2Client.deleted_at.is_(None), OAuth2Client.client_id == client_id
        )
        result = self.session.execute(statement)
        return result.unique().scalar_one_or_none()

    def save_token(
        self, token: dict[str, typing.Any], request: StarletteOAuth2Request
    ) -> None:
        """
        Persist an issued token.

        Only keyed hashes of the access and refresh tokens are stored, never
        the raw values. ``request.user`` is expected to be a
        ``(sub_type, sub)`` pair.
        """
        access_token = token.get("access_token", None)
        access_token_hash = (
            get_token_hash(access_token, secret=settings.SECRET)
            if access_token is not None
            else None
        )

        refresh_token = token.get("refresh_token", None)
        refresh_token_hash = (
            get_token_hash(refresh_token, secret=settings.SECRET)
            if refresh_token is not None
            else None
        )

        token_data = {
            **token,
            "access_token": access_token_hash,
            "refresh_token": refresh_token_hash,
        }
        sub_type, sub = typing.cast(SubTypeValue, request.user)
        client = typing.cast(OAuth2Client, request.client)
        oauth2_token = OAuth2Token(
            **token_data, client_id=client.client_id, sub_type=sub_type
        )
        oauth2_token.sub = sub
        self.session.add(oauth2_token)
        self.session.flush()

    def get_error_uri(self, request: Request, error: OAuth2Error) -> str | None:
        """Return the configured documentation URI for ``error.error``, if any."""
        if self._error_uris is None or error.error is None:
            return None
        return self._error_uris.get(error.error)

    def create_oauth2_request(self, request: Request) -> StarletteOAuth2Request:
        """Wrap a Starlette request for authlib's form-based endpoints."""
        return StarletteOAuth2Request(request)

    def create_json_request(self, request: Request) -> StarletteJsonRequest:
        """Wrap a Starlette request for authlib's JSON-based endpoints."""
        return StarletteJsonRequest(request)

    def send_signal(self, name: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        """Log authlib's internal signals at debug level."""
        logger.debug(f"Authlib signal: {name}", *args, **kwargs)

    def handle_response(
        self,
        status_code: int,
        payload: dict[str, typing.Any] | str,
        headers: list[tuple[str, str]],
    ) -> Response:
        """
        Convert authlib's ``(status, payload, headers)`` triple into a
        Starlette ``Response``; dict payloads are JSON-encoded.
        """
        if isinstance(payload, dict):
            payload = json.dumps(payload)
        # For duplicate header names the last value wins.
        return Response(payload, status_code, dict(headers))

    def create_bearer_token_generator(self) -> BearerTokenGenerator:
        """
        Build the bearer token generator.

        Access and refresh tokens are prefixed according to the subject type
        found in ``user`` (a ``(sub_type, sub)`` pair).
        """

        def _access_token_generator(
            client: OAuth2Client, grant_type: str, user: SubTypeValue, scope: str
        ) -> str:
            sub_type, _ = user
            return generate_token(prefix=ACCESS_TOKEN_PREFIX[sub_type])

        def _refresh_token_generator(
            client: OAuth2Client, grant_type: str, user: SubTypeValue, scope: str
        ) -> str:
            sub_type, _ = user
            return generate_token(prefix=REFRESH_TOKEN_PREFIX[sub_type])

        return BearerTokenGenerator(_access_token_generator, _refresh_token_generator)

    def create_authorization_response(
        self,
        request: Request,
        grant_user: User | None = None,
        save_consent: bool = False,
    ) -> typing.Any:
        """
        Validate an authorization request and produce the grant's response.

        OAuth2 errors are rendered with ``handle_error_response``. When
        ``save_consent`` is true and the response was produced successfully,
        the granted scope is recorded via the grant service.
        """
        if not isinstance(request, StarletteOAuth2Request):
            oauth2_request = self.create_oauth2_request(request)
        else:
            oauth2_request = request

        try:
            grant: AuthorizationCodeGrant = self.get_authorization_grant(oauth2_request)
        except UnsupportedResponseTypeError as error:
            return self.handle_error_response(oauth2_request, error)

        try:
            redirect_uri = grant.validate_authorization_request()
            status_code, body, headers = grant.create_authorization_response(
                redirect_uri, grant_user
            )
        except OAuth2Error as error:
            return self.handle_error_response(oauth2_request, error)

        if save_consent:
            self._save_consent(oauth2_request, grant)

        return self.handle_response(status_code, body, headers)

    def _save_consent(
        self, request: StarletteOAuth2Request, grant: AuthorizationCodeGrant
    ) -> None:
        """Create or update the stored grant for this subject/client/scope."""
        assert grant.sub_type is not None
        assert grant.sub is not None
        assert grant.client is not None
        payload = request.payload
        assert payload is not None
        oauth2_grant_service.create_or_update_grant(
            self.session,
            sub_type=grant.sub_type,
            sub_id=grant.sub.id,
            client_id=grant.client.client_id,
            scope=payload.scope,
        )

    @property
    def response_types_supported(self) -> list[str]:
        """Sorted, de-duplicated ``RESPONSE_TYPES`` of the authorization grants."""
        response_types: set[str] = set()
        for grant, _ in self._authorization_grants:
            response_types.update(getattr(grant, "RESPONSE_TYPES", ()))
        return sorted(response_types)

    @property
    def response_modes_supported(self) -> list[str]:
        """Only the ``query`` response mode is supported."""
        return ["query"]

    @property
    def grant_types_supported(self) -> list[str]:
        """
        Sorted ``GRANT_TYPE`` values of all registered grants.

        Grants without a ``GRANT_TYPE`` attribute, or with it set to ``None``,
        are skipped so no ``null`` ends up in the metadata.
        """
        grant_types: set[str] = set()
        for grant, _ in [*self._authorization_grants, *self._token_grants]:
            grant_type = getattr(grant, "GRANT_TYPE", None)
            if grant_type is not None:
                grant_types.add(grant_type)
        return sorted(grant_types)

    @property
    def token_endpoint_auth_methods_supported(self) -> list[str]:
        """Client authentication methods accepted at the token endpoint."""
        return ["client_secret_basic", "client_secret_post", "none"]

    def _client_auth_methods(self, endpoint_name: str) -> list[str]:
        """Sorted union of ``CLIENT_AUTH_METHODS`` of endpoints under a name."""
        auth_methods: set[str] = set()
        for endpoint in self._endpoints.get(endpoint_name, []):
            auth_methods.update(getattr(endpoint, "CLIENT_AUTH_METHODS", []))
        return sorted(auth_methods)

    @property
    def revocation_endpoint_auth_methods_supported(self) -> list[str]:
        """Client authentication methods of the registered revocation endpoints."""
        return self._client_auth_methods(RevocationEndpoint.ENDPOINT_NAME)

    @property
    def introspection_endpoint_auth_methods_supported(self) -> list[str]:
        """Client authentication methods of the registered introspection endpoints."""
        return self._client_auth_methods(IntrospectionEndpoint.ENDPOINT_NAME)

    @property
    def code_challenge_methods_supported(self) -> list[str]:
        """Sorted PKCE methods declared by ``CodeChallenge`` grant extensions."""
        code_challenge_methods: set[str] = set()
        for _, extensions in self._authorization_grants:
            for extension in extensions:
                if isinstance(extension, CodeChallenge):
                    code_challenge_methods.update(
                        extension.SUPPORTED_CODE_CHALLENGE_METHOD
                    )
        return sorted(code_challenge_methods)