Coverage for polar/oauth2/grants/web.py: 22%

81 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import uuid 1a

2from collections.abc import Iterable 1a

3from typing import Any 1a

4 

5from authlib.oauth2.rfc6749 import ClientMixin 1a

6from authlib.oauth2.rfc6749.errors import ( 1a

7 InvalidGrantError, 

8 InvalidRequestError, 

9 UnauthorizedClientError, 

10) 

11from authlib.oauth2.rfc6749.grants import BaseGrant, TokenEndpointMixin 1a

12from authlib.oauth2.rfc6749.hooks import hooked 1a

13from sqlalchemy import and_, select 1a

14 

15from polar.config import settings 1a

16from polar.kit.crypto import get_token_hash 1a

17from polar.kit.utils import utc_now 1a

18from polar.models import Organization, User, UserOrganization, UserSession 1a

19 

20from ..sub_type import SubType, SubTypeValue 1a

21 

22 

class WebGrant(BaseGrant, TokenEndpointMixin):
    """Custom token-endpoint grant exchanging a user session token for a token.

    The request must carry a ``session_token`` and may carry ``sub_type``,
    ``sub`` and ``scope``. The resulting token is issued either for the user
    owning the session or for an organization that user is a member of.
    """

    GRANT_TYPE = "web"
    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]

    def validate_token_request(self) -> None:
        """Authenticate the client and resolve the token subject.

        Stores the authenticated client on ``self.request.client`` and the
        ``(sub_type, subject)`` pair on ``self.request.user`` for later use by
        :meth:`create_token_response`.

        Raises:
            UnauthorizedClientError: The client may not use this grant type.
            InvalidRequestError: The request payload is missing or malformed.
            InvalidGrantError: The session token or organization is not valid
                for this request.
        """
        client = self._validate_request_client()
        self.request.client = client
        sub_type_value = self._validate_request_token(client)
        self.request.user = sub_type_value

    @hooked
    def create_token_response(self) -> tuple[int, Any, Iterable[tuple[str, str]]]:
        """Generate, persist and return the token for the validated request.

        Returns:
            A ``(status, token, headers)`` triple with status ``200``.
        """
        client = self.request.client
        sub_type_value = self.request.user
        # Fall back to the client's scope when the request did not ask for one.
        scope = self.request.payload.scope or client.scope

        token = self.generate_token(
            user=sub_type_value, scope=scope, include_refresh_token=False
        )
        self.save_token(token)

        return 200, token, self.TOKEN_RESPONSE_HEADER

    def _validate_request_client(self) -> ClientMixin:
        """Authenticate the calling client and check it may use this grant.

        Raises:
            UnauthorizedClientError: The client is not allowed to use
                ``grant_type=web``.
        """
        client = self.authenticate_token_endpoint_client()

        if not client.check_grant_type(self.GRANT_TYPE):
            raise UnauthorizedClientError(
                f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'"
            )

        return client

    def _validate_request_token(self, client: ClientMixin) -> SubTypeValue:
        """Validate the request payload and resolve the subject of the token.

        Args:
            client: The authenticated client (not used by the checks below).

        Returns:
            A ``(sub_type, subject)`` pair where ``subject`` is the session's
            user or the requested organization.

        Raises:
            InvalidRequestError: Missing payload or ``session_token``, invalid
                ``sub_type``, missing/forbidden/malformed ``sub``.
            InvalidGrantError: No unexpired session matches the token, or the
                session's user is not a member of the requested organization.
        """
        payload = self.request.payload
        if payload is None:
            raise InvalidRequestError("Missing request payload.")

        data = payload.data
        session_token = data.get("session_token")
        if session_token is None:
            raise InvalidRequestError("Missing 'session_token' in request.")

        raw_sub_type: str | None = data.get("sub_type")
        try:
            # An absent or empty sub_type defaults to a user token.
            sub_type = SubType(raw_sub_type) if raw_sub_type else SubType.user
        except ValueError as e:
            raise InvalidRequestError("Invalid sub_type") from e

        sub: str | None = data.get("sub")
        if sub_type == SubType.organization and sub is None:
            raise InvalidRequestError("Missing 'sub' for organization sub_type")
        elif sub_type == SubType.user and sub is not None:
            raise InvalidRequestError("Can't specify 'sub' for user sub_type")

        scope = data.get("scope", "")
        if scope:
            self.server.validate_requested_scope(scope)

        # Sessions are looked up by the hash of the token, not the raw value.
        token_hash = get_token_hash(session_token, secret=settings.SECRET)
        statement = select(UserSession).where(
            UserSession.token == token_hash, UserSession.expires_at > utc_now()
        )
        result = self.server.session.execute(statement)
        user_session: UserSession | None = result.unique().scalar_one_or_none()
        if user_session is None:
            raise InvalidGrantError()

        user = user_session.user
        sub_value: User | Organization | None = None
        if sub_type == SubType.user:
            sub_value = user
        elif sub_type == SubType.organization:
            # Guaranteed by the 'sub' validation above; narrows the type.
            assert sub is not None
            try:
                sub_uuid = uuid.UUID(sub)
            except ValueError as e:
                raise InvalidRequestError("Invalid 'sub' UUID") from e
            organization = self._get_organization_admin(sub_uuid, user)
            if organization is None:
                raise InvalidGrantError()
            sub_value = organization

        assert sub_value is not None
        return sub_type, sub_value

    def _get_organization_admin(
        self, organization_id: uuid.UUID, user: User
    ) -> Organization | None:
        """Return the organization if ``user`` has an active membership in it.

        Args:
            organization_id: ID of the organization requested as subject.
            user: The user owning the session.

        Returns:
            The organization, or ``None`` when it does not exist or the user
            has no non-deleted membership in it.
        """
        statement = (
            select(Organization)
            .join(
                UserOrganization,
                onclause=and_(
                    # Tie the membership row to the selected organization.
                    # Without this condition the join degenerates into a
                    # cross join and a membership in *any* organization
                    # would match *every* organization ID.
                    # NOTE(review): assumes the membership model exposes
                    # `organization_id` alongside `user_id` — confirm.
                    UserOrganization.organization_id == Organization.id,
                    UserOrganization.user_id == user.id,
                    UserOrganization.deleted_at.is_(None),
                ),
            )
            .where(Organization.id == organization_id)
        )
        result = self.server.session.execute(statement)
        return result.unique().scalar_one_or_none()

124 return result.unique().scalar_one_or_none()