Coverage for polar/oauth2/grants/web.py: 22%
81 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1import uuid 1a
2from collections.abc import Iterable 1a
3from typing import Any 1a
5from authlib.oauth2.rfc6749 import ClientMixin 1a
6from authlib.oauth2.rfc6749.errors import ( 1a
7 InvalidGrantError,
8 InvalidRequestError,
9 UnauthorizedClientError,
10)
11from authlib.oauth2.rfc6749.grants import BaseGrant, TokenEndpointMixin 1a
12from authlib.oauth2.rfc6749.hooks import hooked 1a
13from sqlalchemy import and_, select 1a
15from polar.config import settings 1a
16from polar.kit.crypto import get_token_hash 1a
17from polar.kit.utils import utc_now 1a
18from polar.models import Organization, User, UserOrganization, UserSession 1a
20from ..sub_type import SubType, SubTypeValue 1a
class WebGrant(BaseGrant, TokenEndpointMixin):
    """Custom ``grant_type=web`` grant that trades a session token for an access token.

    The client authenticates at the token endpoint (``client_secret_basic`` or
    ``client_secret_post``) and sends a ``session_token`` plus optional
    ``sub_type``, ``sub`` and ``scope`` parameters. The session token is hashed
    with ``settings.SECRET`` and matched against a non-expired ``UserSession``.
    The subject of the issued token is either the session's user or, when
    ``sub_type`` is ``organization``, an organization the user is a member of.
    No refresh token is issued.
    """

    GRANT_TYPE = "web"
    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]

    def validate_token_request(self) -> None:
        """Authenticate the client and resolve the subject of the token.

        Stores the authenticated client on ``self.request.client`` and the
        resolved ``(sub_type, subject)`` pair on ``self.request.user`` so that
        ``create_token_response`` can read them.
        """
        client = self._validate_request_client()
        self.request.client = client
        sub_type_value = self._validate_request_token(client)
        self.request.user = sub_type_value

    @hooked
    def create_token_response(self) -> tuple[int, Any, Iterable[tuple[str, str]]]:
        """Generate, persist and return an access token without a refresh token.

        Uses the scope from the request payload, falling back to the client's
        scope when the payload does not carry one.

        Returns:
            A ``(status_code, token, headers)`` tuple with status 200.
        """
        client = self.request.client
        sub_type_value = self.request.user
        scope = self.request.payload.scope or client.scope

        token = self.generate_token(
            user=sub_type_value, scope=scope, include_refresh_token=False
        )
        self.save_token(token)

        return 200, token, self.TOKEN_RESPONSE_HEADER

    def _validate_request_client(self) -> ClientMixin:
        """Authenticate the client and ensure it may use this grant type.

        Raises:
            UnauthorizedClientError: If the client is not allowed to use
                ``grant_type=web``.
        """
        client = self.authenticate_token_endpoint_client()

        if not client.check_grant_type(self.GRANT_TYPE):
            raise UnauthorizedClientError(
                f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'"
            )

        return client

    def _validate_request_token(self, client: ClientMixin) -> SubTypeValue:
        """Validate the request parameters and resolve the token subject.

        Args:
            client: The authenticated client (not referenced by this method).

        Returns:
            A ``(sub_type, subject)`` tuple where ``subject`` is the session's
            ``User`` for ``SubType.user`` or the requested ``Organization`` for
            ``SubType.organization``.

        Raises:
            InvalidRequestError: If the payload or ``session_token`` is missing,
                ``sub_type`` is not a valid ``SubType``, ``sub`` is missing for
                an organization / present for a user, or ``sub`` is not a UUID.
            InvalidGrantError: If no non-expired session matches the token, or
                the user is not an active member of the requested organization.
        """
        payload = self.request.payload
        if payload is None:
            raise InvalidRequestError("Missing request payload.")

        data = payload.data
        token = data.get("session_token")
        if token is None:
            raise InvalidRequestError("Missing 'session_token' in request.")

        raw_sub_type: str | None = data.get("sub_type")
        try:
            # Absent/empty sub_type defaults to a user-subject token.
            sub_type = SubType(raw_sub_type) if raw_sub_type else SubType.user
        except ValueError as e:
            raise InvalidRequestError("Invalid sub_type") from e

        sub: str | None = data.get("sub")
        if sub_type == SubType.organization and sub is None:
            raise InvalidRequestError("Missing 'sub' for organization sub_type")
        elif sub_type == SubType.user and sub is not None:
            raise InvalidRequestError("Can't specify 'sub' for user sub_type")

        scope = data.get("scope", "")
        if scope:
            # Delegates scope validation to the authorization server.
            self.server.validate_requested_scope(scope)

        # Sessions are stored by hash, so hash the presented token before lookup.
        token_hash = get_token_hash(token, secret=settings.SECRET)
        statement = select(UserSession).where(
            UserSession.token == token_hash, UserSession.expires_at > utc_now()
        )
        result = self.server.session.execute(statement)
        user_session: UserSession | None = result.unique().scalar_one_or_none()
        if user_session is None:
            raise InvalidGrantError()

        user = user_session.user
        sub_value: User | Organization | None = None
        if sub_type == SubType.user:
            sub_value = user
        elif sub_type == SubType.organization:
            # Guaranteed by the 'sub' presence check above; narrows the type.
            assert sub is not None
            try:
                sub_uuid = uuid.UUID(sub)
            except ValueError as e:
                raise InvalidRequestError("Invalid 'sub' UUID") from e
            organization = self._get_organization_admin(sub_uuid, user)
            if organization is None:
                raise InvalidGrantError()
            sub_value = organization

        assert sub_value is not None
        return sub_type, sub_value

    def _get_organization_admin(
        self, organization_id: uuid.UUID, user: User
    ) -> Organization | None:
        """Return the organization if ``user`` is an active member of it.

        Args:
            organization_id: ID of the requested organization.
            user: The user owning the web session.

        Returns:
            The ``Organization`` when ``user`` has a non-deleted membership in
            it, otherwise ``None``.

        NOTE(review): despite the name, no role is checked here — any
        non-deleted membership qualifies. Confirm this is intended.
        """
        statement = (
            select(Organization)
            .join(
                UserOrganization,
                # An explicit onclause replaces the FK-inferred join condition,
                # so the link to Organization must be spelled out. Without it,
                # any user with a membership in *some* organization would match
                # every organization_id.
                onclause=and_(
                    UserOrganization.organization_id == Organization.id,
                    UserOrganization.user_id == user.id,
                    UserOrganization.deleted_at.is_(None),
                ),
            )
            .where(Organization.id == organization_id)
        )
        result = self.server.session.execute(statement)
        return result.unique().scalar_one_or_none()