Coverage for polar/oauth2/grants/authorization_code.py: 30%

154 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import typing 1a

2import uuid 1a

3 

4from authlib.oauth2.rfc6749.errors import ( 1a

5 AccessDeniedError, 

6 InvalidRequestError, 

7 OAuth2Error, 

8) 

9from authlib.oauth2.rfc6749.grants import ( 1a

10 AuthorizationCodeGrant as _AuthorizationCodeGrant, 

11) 

12from authlib.oauth2.rfc6749.requests import OAuth2Request 1a

13from authlib.oauth2.rfc7636 import CodeChallenge as _CodeChallenge 1a

14from authlib.oidc.core.errors import ConsentRequiredError, LoginRequiredError 1a

15from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode 1a

16from authlib.oidc.core.grants import OpenIDToken as _OpenIDToken 1a

17from sqlalchemy import and_, select 1a

18from sqlalchemy.orm import Session 1a

19 

20from polar.config import settings 1a

21from polar.kit.crypto import generate_token, get_token_hash 1a

22from polar.models import ( 1a

23 OAuth2AuthorizationCode, 

24 OAuth2Client, 

25 Organization, 

26 User, 

27 UserOrganization, 

28) 

29 

30from ..constants import AUTHORIZATION_CODE_PREFIX, JWT_CONFIG 1a

31from ..requests import StarletteOAuth2Request 1a

32from ..service.oauth2_grant import oauth2_grant as oauth2_grant_service 1a

33from ..sub_type import SubType, SubTypeValue 1a

34from ..userinfo import UserInfo, generate_user_info 1a

35 

36if typing.TYPE_CHECKING: 36 ↛ 37line 36 didn't jump to line 37 because the condition on line 36 was never true1a

37 from ..authorization_server import AuthorizationServer 

38 

39 

def _exists_nonce(
    session: Session, nonce: str, request: StarletteOAuth2Request
) -> bool:
    """
    Return whether a stored authorization code for the requesting client
    already carries ``nonce``.

    Backs ``OpenIDCode.exists_nonce``.

    :param session: Database session used to run the lookup.
    :param nonce: The ``nonce`` value from the authorization request payload.
    :param request: The OAuth2 request; only its ``client_id`` is read.
    :return: ``True`` if at least one matching authorization code exists.
    """
    statement = (
        select(OAuth2AuthorizationCode)
        .where(
            OAuth2AuthorizationCode.client_id == request.client_id,
            OAuth2AuthorizationCode.nonce == nonce,
        )
        # Only existence matters, so stop at the first match. Using `first()`
        # rather than `scalar_one_or_none()` also avoids MultipleResultsFound
        # should duplicate nonces ever be stored (e.g. concurrent requests).
        .limit(1)
    )
    result = session.execute(statement)
    return result.unique().scalars().first() is not None

49 

50 

class SubTypeGrantMixin:
    """
    Mixin carrying the subject an authorization grant is issued for.

    Both attributes default to ``None`` at class level; grant extensions
    assign them on the grant instance once the subject is resolved.
    """

    # Which kind of subject the grant targets (user or organization).
    sub_type: SubType | None = None
    # The resolved subject entity matching ``sub_type``.
    sub: User | Organization | None = None

54 

55 

class AuthorizationCodeGrant(SubTypeGrantMixin, _AuthorizationCodeGrant):
    """
    Authorization code grant whose codes are bound to a user or organization.

    On top of authlib's grant this class:

    * defaults the requested ``scope`` to the client's scope when absent;
    * exposes a ``before_create_authorization_response`` hook point;
    * stores only a keyed hash of each issued code, via the server's session.
    """

    server: "AuthorizationServer"
    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]

    def __init__(self, request: OAuth2Request, server: "AuthorizationServer") -> None:
        super().__init__(request, server)
        hooks = self._hooks
        # Start with an empty set so extensions can register into it.
        hooks["before_create_authorization_response"] = set()
        # Install scope defaulting as the sole payload-validation hook.
        hooks["before_validate_authorization_request_payload"] = {
            self.before_validate_authorization_request_payload
        }

    def before_validate_authorization_request_payload(
        self, grant: "typing.Self", redirect_uri: str
    ) -> None:
        """
        Use the client's scope when the request does not specify one.

        Only a missing ``scope`` key triggers the fallback; any provided value
        (even an empty string) is kept.
        """
        requested_scope: str | None = self.request.payload.data.get("scope")
        if requested_scope is not None:
            return
        self.request.payload.data["scope"] = self.request.client.scope

    def create_authorization_response(
        self, redirect_uri: str, grant_user: User | None
    ) -> tuple[int, str | dict[str, typing.Any], list[tuple[str, str]]]:
        """
        Build the authorization response once ``grant_user`` has consented.

        :raises AccessDeniedError: If no granting user is given.
        """
        payload = self.request.payload
        assert payload is not None

        if not grant_user:
            raise AccessDeniedError(state=payload.state, redirect_uri=redirect_uri)

        self.request.user = grant_user  # pyright: ignore
        # Give registered extensions (e.g. subject validation) a last chance
        # to reject the request before authlib creates the code.
        self.execute_hook(
            "before_create_authorization_response", redirect_uri, grant_user
        )
        return super().create_authorization_response(redirect_uri, grant_user)  # pyright: ignore

    def generate_authorization_code(self) -> str:
        """Create a new authorization code token with the code prefix."""
        return generate_token(prefix=AUTHORIZATION_CODE_PREFIX)

    def save_authorization_code(
        self, code: str, request: StarletteOAuth2Request
    ) -> OAuth2AuthorizationCode:
        """
        Persist an authorization code row for ``code`` and the request data.

        Only the keyed hash of ``code`` is stored, never the plain value.
        The row is flushed so it is visible within the current session.
        """
        payload = request.payload
        assert payload is not None

        data = payload.data
        nonce = data.get("nonce")
        code_challenge = data.get("code_challenge")
        code_challenge_method = data.get("code_challenge_method")

        # The subject must have been resolved by the grant extensions.
        assert self.sub_type is not None
        assert self.sub is not None

        record = OAuth2AuthorizationCode(
            code=get_token_hash(code, secret=settings.SECRET),
            client_id=payload.client_id,
            sub_type=self.sub_type,
            scope=payload.scope,
            redirect_uri=payload.redirect_uri,
            nonce=nonce,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
        )
        record.sub = self.sub

        session = self.server.session
        session.add(record)
        session.flush()
        return record

    def query_authorization_code(
        self, code: str, client: OAuth2Client
    ) -> OAuth2AuthorizationCode | None:
        """
        Look up a non-expired authorization code issued to ``client``.

        :return: The matching code row, or ``None`` if absent or expired.
        """
        hashed_code = get_token_hash(code, secret=settings.SECRET)
        result = self.server.session.execute(
            select(OAuth2AuthorizationCode).where(
                OAuth2AuthorizationCode.code == hashed_code,
                OAuth2AuthorizationCode.client_id == client.client_id,
            )
        )
        found = result.unique().scalar_one_or_none()
        if found is None:
            return None
        if typing.cast(bool, found.is_expired()):
            return None
        return found

    def delete_authorization_code(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> None:
        """Delete ``authorization_code`` and flush the session."""
        session = self.server.session
        session.delete(authorization_code)
        session.flush()

    def authenticate_user(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> SubTypeValue | None:
        """Return the subject value stored on ``authorization_code``."""
        return authorization_code.get_sub_type_value()

153 

154 

class CodeChallenge(_CodeChallenge):
    """PKCE (RFC 7636) code challenge extension, used unchanged from authlib."""

157 

158 

class OpenIDCode(_OpenIDCode):
    """
    OpenID Connect extension for the authorization code flow.

    Nonce lookups run against the database session given at construction.
    """

    def __init__(self, session: Session, require_nonce: bool = False):
        super().__init__(require_nonce)
        self._session = session

    def exists_nonce(self, nonce: str, request: StarletteOAuth2Request) -> bool:
        """Return whether a stored code for the requesting client has ``nonce``."""
        session = self._session
        return _exists_nonce(session, nonce, request)

    def get_jwt_config(self, grant: AuthorizationCodeGrant) -> dict[str, typing.Any]:
        """Return the module-wide ``JWT_CONFIG`` regardless of ``grant``."""
        config = JWT_CONFIG
        return config

    def generate_user_info(self, user: SubTypeValue, scope: str) -> UserInfo:
        """Delegate user-info claim generation to the shared helper."""
        info = generate_user_info(user, scope)
        return info

172 

173 

class OpenIDToken(_OpenIDToken):
    """OpenID Connect ID-token extension using the shared JWT configuration."""

    def get_jwt_config(self, grant: AuthorizationCodeGrant) -> dict[str, typing.Any]:
        """Return the module-wide ``JWT_CONFIG`` regardless of ``grant``."""
        config = JWT_CONFIG
        return config

    def generate_user_info(self, user: SubTypeValue, scope: str) -> UserInfo:
        """Delegate user-info claim generation to the shared helper."""
        info = generate_user_info(user, scope)
        return info

180 

181 

class InvalidSubError(OAuth2Error):
    """
    OAuth2 error for a ``sub`` that cannot be resolved to an allowed subject.

    Serialized with the ``invalid_sub`` error code.
    """

    error = "invalid_sub"

184 

185 

class ValidateSubAndPrompt:
    """
    Grant extension that resolves the token subject and honours ``prompt``.

    Calling an instance with an :class:`AuthorizationCodeGrant` registers:

    * ``after_validate_consent_request`` -> :meth:`_validate`: resolve the
      subject and decide whether the consent screen may be skipped;
    * ``before_create_authorization_response`` -> :meth:`_validate_resolved_sub`:
      resolve the subject again and refuse to continue without one.
    """

    def __init__(self, session: Session) -> None:
        # Session used for granted-scope and organization membership lookups.
        self._session = session

    def __call__(self, grant: AuthorizationCodeGrant) -> None:
        grant.register_hook("after_validate_consent_request", self._validate)
        grant.register_hook(
            "before_create_authorization_response", self._validate_resolved_sub
        )

    def _validate(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """Resolve the subject, then check scope consent against ``prompt``."""
        self._validate_sub(grant, redirect_uri, redirect_fragment)
        self._validate_scope_consent(grant, redirect_uri, redirect_fragment)

    def _validate_sub(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """
        Resolve ``grant.sub_type`` and ``grant.sub`` from the request payload.

        ``sub_type`` is read from the payload, falling back to the client's
        ``default_sub_type``. For the user sub type the subject is the request
        user and an explicit ``sub`` is rejected. For the organization sub
        type, ``sub`` must be the UUID of an organization the user is an active
        member of; if ``sub`` or the user is missing, ``grant.sub`` is not
        assigned here.

        ``redirect_uri`` and ``redirect_fragment`` are unused in this method.

        :raises InvalidRequestError: On an unknown ``sub_type``, or an explicit
            ``sub`` combined with the user sub type.
        :raises InvalidSubError: If ``sub`` is not a valid UUID or does not
            match an organization the user belongs to.
        """
        payload = grant.request.payload
        assert payload is not None

        sub_type: str | None = payload.data.get("sub_type")
        if sub_type:
            try:
                grant.sub_type = SubType(sub_type)
            except ValueError as e:
                raise InvalidRequestError("Invalid sub_type") from e
        else:
            client: OAuth2Client = typing.cast(OAuth2Client, grant.client)
            grant.sub_type = client.default_sub_type

        sub: str | None = payload.data.get("sub")
        user = grant.request.user

        if grant.sub_type == SubType.user:
            grant.sub = user
            if sub is not None:
                raise InvalidRequestError("Can't specify sub for user sub_type")
        elif (
            grant.sub_type == SubType.organization
            and sub is not None
            and user is not None
        ):
            try:
                sub_uuid = uuid.UUID(sub)
            except ValueError as e:
                raise InvalidSubError() from e
            organization = self._get_organization_admin(sub_uuid, user)
            if organization is None:
                raise InvalidSubError()
            grant.sub = organization

    def _validate_scope_consent(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """
        Enforce the ``prompt`` parameter against previously granted scopes.

        * ``prompt=none`` requires a resolved subject that has already granted
          the requested scope to this client.
        * With no ``prompt`` and the scope already granted, ``grant.prompt`` is
          set to ``"none"`` so the consent screen is skipped.

        :raises LoginRequiredError: ``prompt=none`` without a resolved subject.
        :raises ConsentRequiredError: ``prompt=none`` without prior consent.
        """
        payload = grant.request.payload
        assert payload is not None

        prompt = payload.data.get("prompt")

        # Check if the sub has granted the requested scope or a subset of it
        has_granted_scope = False
        if grant.sub is not None:
            assert grant.client is not None
            assert grant.sub_type is not None
            has_granted_scope = oauth2_grant_service.has_granted_scope(
                self._session,
                sub_type=grant.sub_type,
                sub_id=grant.sub.id,
                client_id=grant.client.client_id,
                scope=payload.scope,
            )

        # If the prompt is "none", the sub must be authenticated and have granted the requested scope
        if prompt == "none":
            if grant.sub is None:
                raise LoginRequiredError(
                    redirect_uri=redirect_uri, redirect_fragment=redirect_fragment
                )
            if not has_granted_scope:
                raise ConsentRequiredError(
                    redirect_uri=redirect_uri, redirect_fragment=redirect_fragment
                )

        # Bypass everything if nothing is specified and conditions are met
        if prompt is None and has_granted_scope:
            grant.prompt = "none"

    def _validate_resolved_sub(
        self,
        grant: AuthorizationCodeGrant,
        redirect_uri: str,
        redirect_fragment: bool = False,
    ) -> None:
        """
        Re-resolve the subject and require one before a code is issued.

        NOTE(review): the grant runs this hook with ``(redirect_uri,
        grant_user)``, so ``redirect_fragment`` most likely receives the
        granting user here. It is only forwarded to ``_validate_sub``, which
        ignores it.

        :raises InvalidSubError: If no subject could be resolved.
        """
        self._validate_sub(grant, redirect_uri, redirect_fragment)
        if grant.sub is None:
            raise InvalidSubError()

    def _get_organization_admin(
        self, organization_id: uuid.UUID, user: User
    ) -> Organization | None:
        """
        Return the organization ``organization_id`` if ``user`` is a member.

        Membership means a ``UserOrganization`` row linking this user to this
        organization whose ``deleted_at`` is unset.

        :return: The organization, or ``None`` if it does not exist or the
            user has no active membership in it.
        """
        statement = (
            select(Organization)
            .join(
                UserOrganization,
                onclause=and_(
                    # Tie the membership row to *this* organization; without
                    # it, any active membership of the user in any
                    # organization would match an arbitrary organization_id.
                    UserOrganization.organization_id == Organization.id,
                    UserOrganization.user_id == user.id,
                    UserOrganization.deleted_at.is_(None),
                ),
            )
            .where(Organization.id == organization_id)
        )
        result = self._session.execute(statement)
        return result.unique().scalar_one_or_none()