Coverage for polar/oauth2/authorization_server.py: 27%

261 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import json 1a

2import secrets 1a

3import time 1a

4import typing 1a

5 

6import structlog 1a

7from authlib.oauth2 import AuthorizationServer as _AuthorizationServer 1a

8from authlib.oauth2 import OAuth2Error 1a

9from authlib.oauth2.rfc6749.errors import ( 1a

10 UnsupportedResponseTypeError, 

11) 

12from authlib.oauth2.rfc6750 import BearerTokenGenerator 1a

13from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint 1a

14from authlib.oauth2.rfc7591 import ( 1a

15 ClientRegistrationEndpoint as _ClientRegistrationEndpoint, 

16) 

17from authlib.oauth2.rfc7592 import ( 1a

18 ClientConfigurationEndpoint as _ClientConfigurationEndpoint, 

19) 

20from authlib.oauth2.rfc7662 import IntrospectionEndpoint as _IntrospectionEndpoint 1a

21from sqlalchemy import or_, select 1a

22from sqlalchemy.orm import Session 1a

23from starlette.requests import Request 1a

24from starlette.responses import Response 1a

25 

26from polar.config import settings 1a

27from polar.kit.crypto import generate_token, get_token_hash 1a

28from polar.logging import Logger 1a

29from polar.models import OAuth2Client, OAuth2Token, User 1a

30from polar.oauth2.sub_type import SubTypeValue 1a

31 

32from .constants import ( 1a

33 ACCESS_TOKEN_PREFIX, 

34 CLIENT_ID_PREFIX, 

35 CLIENT_REGISTRATION_TOKEN_PREFIX, 

36 CLIENT_SECRET_PREFIX, 

37 ISSUER, 

38 REFRESH_TOKEN_PREFIX, 

39) 

40from .grants import AuthorizationCodeGrant, CodeChallenge, register_grants 1a

41from .metadata import get_server_metadata 1a

42from .requests import StarletteJsonRequest, StarletteOAuth2Request 1a

43from .service.oauth2_grant import oauth2_grant as oauth2_grant_service 1a

44 

# Structured logger for this module; AuthorizationServer.send_signal writes
# Authlib signals to it at debug level.
logger: Logger = structlog.get_logger(__name__)

46 

47 

def _get_server_metadata(server: "AuthorizationServer") -> dict[str, typing.Any]:
    """Return ``server``'s metadata as a plain dict, leaving out unset fields.

    The URL resolver handed to ``get_server_metadata`` is the identity function,
    so any URL-valued fields contain the raw route names instead of real URLs.
    """
    metadata = get_server_metadata(server, lambda name: name)
    return metadata.model_dump(exclude_unset=True)

53 

54 

class ClientRegistrationEndpoint(_ClientRegistrationEndpoint):
    """Dynamic client registration endpoint (RFC 7591) persisting via the server session."""

    server: "AuthorizationServer"

    def generate_client_registration_info(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> dict[str, str]:
        """Return the management URI and registration access token for ``client``."""
        assert client.registration_access_token is not None
        management_uri = request.url_for("oauth2:get_client", client_id=client.client_id)
        return {
            "registration_client_uri": str(management_uri),
            "registration_access_token": client.registration_access_token,
        }

    def generate_client_id(self, request: StarletteJsonRequest) -> str:
        """Return a new random client id carrying ``CLIENT_ID_PREFIX``."""
        return generate_token(prefix=CLIENT_ID_PREFIX)

    def generate_client_secret(self, request: StarletteJsonRequest) -> str:
        """Return a new random client secret carrying ``CLIENT_SECRET_PREFIX``."""
        return generate_token(prefix=CLIENT_SECRET_PREFIX)

    def create_registration_response(
        self, request: StarletteJsonRequest
    ) -> tuple[int, dict[str, typing.Any], list[tuple[str, str]]]:
        """
        Create client registration response.

        Temporary workaround: when the registered client is public
        (``token_endpoint_auth_method == "none"``), drop ``client_secret`` and
        ``client_secret_expires_at`` from the body so clients that do not yet
        handle public clients correctly are not confused by them.
        """
        status, body, headers = super().create_registration_response(request)

        is_public_client = (
            isinstance(body, dict) and body.get("token_endpoint_auth_method") == "none"
        )
        if is_public_client:
            for secret_field in ("client_secret", "client_secret_expires_at"):
                body.pop(secret_field, None)

        return status, body, headers

    def get_server_metadata(self) -> dict[str, typing.Any]:
        """Return the owning server's metadata as a dict."""
        return _get_server_metadata(self.server)

    def authenticate_token(self, request: StarletteJsonRequest) -> User | str:
        """Identify the registrant: the logged-in user, else the ``"dynamic_client"`` marker."""
        if request.user is None:
            return "dynamic_client"
        return request.user

    def save_client(
        self,
        client_info: dict[str, typing.Any],
        client_metadata: dict[str, typing.Any],
        request: StarletteJsonRequest,
    ) -> OAuth2Client:
        """Create, persist and flush a new ``OAuth2Client``.

        Every new client gets a registration access token; the client is tied
        to the requesting user only when one is authenticated.
        """
        oauth2_client = OAuth2Client(**client_info)
        oauth2_client.set_client_metadata(client_metadata)

        oauth2_client.registration_access_token = generate_token(
            prefix=CLIENT_REGISTRATION_TOKEN_PREFIX
        )
        if request.user is not None:
            oauth2_client.user_id = request.user.id

        session = self.server.session
        session.add(oauth2_client)
        session.flush()
        return oauth2_client

121 

122 

class ClientConfigurationEndpoint(_ClientConfigurationEndpoint):
    """Client configuration endpoint (RFC 7592): read, update and delete a client."""

    server: "AuthorizationServer"

    def generate_client_registration_info(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> dict[str, str | None]:
        """Return the management URI and registration access token for ``client``.

        Unlike the registration endpoint, the token is not asserted to be set
        here, so a missing token is passed through as ``None``.
        """
        return {
            "registration_client_uri": str(
                request.url_for("oauth2:get_client", client_id=client.client_id)
            ),
            "registration_access_token": client.registration_access_token,
        }

    def create_read_client_response(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> tuple[int, dict[str, typing.Any], list[tuple[str, str]]]:
        """
        Create client read response (GET endpoint).

        Temporary workaround: when the client is public
        (``token_endpoint_auth_method == "none"``), drop ``client_secret`` and
        ``client_secret_expires_at`` from the body so clients that do not yet
        handle public clients correctly are not confused by them.
        """
        status, body, headers = super().create_read_client_response(client, request)

        if isinstance(body, dict) and body.get("token_endpoint_auth_method") == "none":
            body.pop("client_secret", None)
            body.pop("client_secret_expires_at", None)

        return status, body, headers

    def authenticate_token(self, request: StarletteJsonRequest) -> User | str | None:
        """Return the caller's credential.

        :returns: The logged-in user if there is one, otherwise the non-empty
            token from an ``Authorization: Bearer <token>`` header (scheme
            matched case-insensitively), otherwise ``None``.
        """
        if request.user is not None:
            return request.user

        authorization = request.headers.get("Authorization")
        if authorization is None:
            return None

        scheme, _, token = authorization.partition(" ")
        if scheme.lower() == "bearer" and token != "":
            return token

        return None

    def authenticate_client(self, request: StarletteJsonRequest) -> OAuth2Client | None:
        """Load the client named in the URL path if the credential may manage it.

        A ``User`` credential must own the client. A string credential must
        equal the client's registration access token.

        :returns: The client, or ``None`` if it does not exist, is soft-deleted,
            or the credential does not authorize access to it.
        """
        client_id = request.path_params.get("client_id")
        if client_id is None:
            return None

        # Same lookup (skipping soft-deleted clients) the server uses elsewhere.
        client = self.server.query_client(client_id)
        if client is None:
            return None

        credential = request.credential
        if credential is None:
            return None

        if isinstance(credential, User):
            return client if client.user_id == credential.id else None

        if isinstance(credential, str):
            expected = client.registration_access_token
            # Compare as bytes: compare_digest raises TypeError for non-ASCII
            # str (the token comes straight from a request header) and for None,
            # which would otherwise surface as a server error instead of a 401.
            if expected is None or not secrets.compare_digest(
                expected.encode(), credential.encode()
            ):
                return None

        return client

    def revoke_access_token(
        self, token: typing.Any, request: StarletteJsonRequest
    ) -> None:
        """No-op: registration access tokens are not revoked here."""
        return None

    def check_permission(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> bool:
        """Always allow; access control is done in ``authenticate_client``."""
        return True

    def delete_client(
        self, client: OAuth2Client, request: StarletteJsonRequest
    ) -> None:
        """Soft-delete ``client`` by setting its ``deleted_at`` marker, then flush."""
        client.set_deleted_at()
        self.server.session.flush()

    def update_client(
        self,
        client: OAuth2Client,
        client_metadata: dict[str, typing.Any],
        request: StarletteJsonRequest,
    ) -> OAuth2Client:
        """Merge ``client_metadata`` over the stored metadata (new keys win) and flush."""
        client.set_client_metadata({**client.client_metadata, **client_metadata})
        self.server.session.add(client)
        self.server.session.flush()
        return client

    def get_server_metadata(self) -> dict[str, typing.Any]:
        """Return the owning server's metadata as a dict."""
        return _get_server_metadata(self.server)

230 

231 

class _QueryTokenMixin:
    """Token lookup shared by the revocation and introspection endpoints."""

    server: "AuthorizationServer"

    def query_token(
        self,
        token_string: str,
        token_type_hint: typing.Literal["access_token", "refresh_token"] | None,
    ) -> OAuth2Token | None:
        """Find the stored token whose hashed access or refresh token matches.

        Tokens are stored hashed, so ``token_string`` is hashed before lookup.
        An explicit hint restricts the search to that column; without a hint,
        both columns are searched.
        """
        hashed = get_token_hash(token_string, secret=settings.SECRET)
        by_access = OAuth2Token.access_token == hashed
        by_refresh = OAuth2Token.refresh_token == hashed

        if token_type_hint == "access_token":
            criterion = by_access
        elif token_type_hint == "refresh_token":
            criterion = by_refresh
        else:
            criterion = or_(by_access, by_refresh)

        rows = self.server.session.execute(select(OAuth2Token).where(criterion))
        return rows.unique().scalar_one_or_none()

256 

257 

class RevocationEndpoint(_QueryTokenMixin, _RevocationEndpoint):
    """Token revocation endpoint (RFC 7009), for confidential clients only."""

    CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]

    def revoke_token(self, token: OAuth2Token, request: StarletteOAuth2Request) -> None:
        """Mark ``token`` revoked at the current Unix time and flush.

        The access token is always revoked. The refresh token is revoked too
        unless the request's ``token_type_hint`` is exactly ``"access_token"``.
        """
        revoked_at = int(time.time())
        token.access_token_revoked_at = revoked_at  # pyright: ignore
        if request.form.get("token_type_hint") != "access_token":
            token.refresh_token_revoked_at = revoked_at  # pyright: ignore

        session = self.server.session
        session.add(token)
        session.flush()

269 

270 

class IntrospectionEndpoint(_QueryTokenMixin, _IntrospectionEndpoint):
    """Token introspection endpoint (RFC 7662), for confidential clients only."""

    CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]

    def check_permission(
        self, token: OAuth2Token, client: OAuth2Client, request: StarletteOAuth2Request
    ) -> bool:
        """Allow introspection only when the token passes ``check_client(client)``."""
        allowed = token.check_client(client)  # pyright: ignore
        return allowed

    def introspect_token(self, token: OAuth2Token) -> dict[str, typing.Any]:
        """Return the token's introspection data, tagged with this server's issuer."""
        introspection = token.get_introspection_data(ISSUER)
        return introspection

281 

282 

class AuthorizationServer(_AuthorizationServer):
    """Polar's OAuth 2.0 authorization server built on Authlib.

    Clients, tokens and consents are read and written through the SQLAlchemy
    ``Session`` given at construction. Starlette requests are wrapped into
    Authlib request objects, and Authlib results become Starlette ``Response``s.
    The ``*_supported`` properties describe the server's capabilities; they
    return sorted lists so the metadata is stable between processes.
    """

    if typing.TYPE_CHECKING:

        def create_endpoint_response(
            self, name: str, request: Request | None = None
        ) -> Response: ...

    def __init__(
        self,
        session: Session,
        *,
        scopes_supported: list[str] | None = None,
        error_uris: list[tuple[str, str]] | None = None,
    ) -> None:
        """Create a bare server with only the default bearer token generator.

        Use :meth:`build` to also register endpoints and grants.

        :param session: Session used for every query and write the server makes.
        :param scopes_supported: Forwarded to Authlib's base server.
        :param error_uris: ``(error_code, uri)`` pairs served by :meth:`get_error_uri`.
        """
        super().__init__(scopes_supported)
        self.session = session
        self._error_uris = dict(error_uris) if error_uris is not None else None

        self.register_token_generator("default", self.create_bearer_token_generator())

    @classmethod
    def build(
        cls,
        session: Session,
        *,
        scopes_supported: list[str] | None = None,
        error_uris: list[tuple[str, str]] | None = None,
    ) -> typing.Self:
        """Create a fully configured server.

        Registers the revocation, introspection, client registration and client
        configuration endpoints, plus the grants added by ``register_grants``.
        """
        authorization_server = cls(
            session, scopes_supported=scopes_supported, error_uris=error_uris
        )
        for endpoint_cls in (
            RevocationEndpoint,
            IntrospectionEndpoint,
            ClientRegistrationEndpoint,
            ClientConfigurationEndpoint,
        ):
            authorization_server.register_endpoint(endpoint_cls)
        register_grants(authorization_server)
        return authorization_server

    def query_client(self, client_id: str) -> OAuth2Client | None:
        """Return the client with ``client_id``, ignoring soft-deleted clients."""
        statement = select(OAuth2Client).where(
            OAuth2Client.deleted_at.is_(None), OAuth2Client.client_id == client_id
        )
        result = self.session.execute(statement)
        return result.unique().scalar_one_or_none()

    def save_token(
        self, token: dict[str, typing.Any], request: StarletteOAuth2Request
    ) -> None:
        """Persist an issued token as an ``OAuth2Token`` row and flush.

        Only hashes of the access and refresh tokens are stored. ``token``
        itself is not modified. The subject comes from ``request.user``, a
        ``(sub_type, sub)`` pair, and the client from ``request.client``.
        """

        def _hash(value: str | None) -> str | None:
            if value is None:
                return None
            return get_token_hash(value, secret=settings.SECRET)

        token_data = {
            **token,
            "access_token": _hash(token.get("access_token")),
            "refresh_token": _hash(token.get("refresh_token")),
        }
        sub_type, sub = typing.cast(SubTypeValue, request.user)
        client = typing.cast(OAuth2Client, request.client)
        oauth2_token = OAuth2Token(
            **token_data, client_id=client.client_id, sub_type=sub_type
        )
        oauth2_token.sub = sub
        self.session.add(oauth2_token)
        self.session.flush()

    def get_error_uri(self, request: Request, error: OAuth2Error) -> str | None:
        """Return the configured documentation URI for ``error``'s code, if any."""
        if self._error_uris is None or error.error is None:
            return None
        return self._error_uris.get(error.error)

    def create_oauth2_request(self, request: Request) -> StarletteOAuth2Request:
        """Wrap a Starlette request for Authlib's OAuth 2.0 flows."""
        return StarletteOAuth2Request(request)

    def create_json_request(self, request: Request) -> StarletteJsonRequest:
        """Wrap a Starlette request for Authlib's JSON endpoints (e.g. registration)."""
        return StarletteJsonRequest(request)

    def send_signal(self, name: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        """Log an Authlib signal at debug level, forwarding its arguments to the logger."""
        logger.debug(f"Authlib signal: {name}", *args, **kwargs)

    def handle_response(
        self,
        status_code: int,
        payload: dict[str, typing.Any] | str,
        headers: list[tuple[str, str]],
    ) -> Response:
        """Turn an Authlib ``(status, payload, headers)`` result into a Response.

        Dict payloads are JSON-encoded. If a header name repeats, the last
        value wins.
        """
        if isinstance(payload, dict):
            payload = json.dumps(payload)
        return Response(payload, status_code, dict(headers))

    def create_bearer_token_generator(self) -> BearerTokenGenerator:
        """Return a bearer generator whose token prefixes depend on the subject type.

        Access and refresh tokens take their prefix from ``ACCESS_TOKEN_PREFIX``
        and ``REFRESH_TOKEN_PREFIX``, keyed by the ``sub_type`` of ``user``.
        """

        def _access_token_generator(
            client: OAuth2Client, grant_type: str, user: SubTypeValue, scope: str
        ) -> str:
            sub_type, _ = user
            return generate_token(prefix=ACCESS_TOKEN_PREFIX[sub_type])

        def _refresh_token_generator(
            client: OAuth2Client, grant_type: str, user: SubTypeValue, scope: str
        ) -> str:
            sub_type, _ = user
            return generate_token(prefix=REFRESH_TOKEN_PREFIX[sub_type])

        return BearerTokenGenerator(_access_token_generator, _refresh_token_generator)

    def create_authorization_response(
        self,
        request: Request,
        grant_user: User | None = None,
        save_consent: bool = False,
    ) -> typing.Any:
        """Validate an authorization request and build the grant's response.

        OAuth2 errors from resolving, validating or answering the request are
        turned into error responses. Consent is recorded only after the grant
        produced a response without raising, and only when ``save_consent`` is set.
        """
        if not isinstance(request, StarletteOAuth2Request):
            oauth2_request = self.create_oauth2_request(request)
        else:
            oauth2_request = request

        try:
            grant: AuthorizationCodeGrant = self.get_authorization_grant(oauth2_request)
        except UnsupportedResponseTypeError as error:
            return self.handle_error_response(oauth2_request, error)

        try:
            redirect_uri = grant.validate_authorization_request()
            status_code, body, headers = grant.create_authorization_response(
                redirect_uri, grant_user
            )
        except OAuth2Error as error:
            return self.handle_error_response(oauth2_request, error)

        if save_consent:
            self._save_consent(oauth2_request, grant)

        return self.handle_response(status_code, body, headers)

    def _save_consent(
        self, request: StarletteOAuth2Request, grant: AuthorizationCodeGrant
    ) -> None:
        """Create or update the stored grant (consent) for this subject/client/scope."""
        assert grant.sub_type is not None
        assert grant.sub is not None
        assert grant.client is not None
        payload = request.payload
        assert payload is not None
        oauth2_grant_service.create_or_update_grant(
            self.session,
            sub_type=grant.sub_type,
            sub_id=grant.sub.id,
            client_id=grant.client.client_id,
            scope=payload.scope,
        )

    def _endpoint_auth_methods(self, endpoint_name: str) -> list[str]:
        """Return the sorted union of ``CLIENT_AUTH_METHODS`` across endpoints named ``endpoint_name``."""
        return sorted(
            {
                method
                for endpoint in self._endpoints.get(endpoint_name, [])
                for method in getattr(endpoint, "CLIENT_AUTH_METHODS", [])
            }
        )

    @property
    def response_types_supported(self) -> list[str]:
        """Response types declared by the registered authorization grants."""
        response_types: list[str] = []
        for grant, _ in self._authorization_grants:
            response_types.extend(getattr(grant, "RESPONSE_TYPES", ()))
        return response_types

    @property
    def response_modes_supported(self) -> list[str]:
        """Only the ``query`` response mode is supported."""
        return ["query"]

    @property
    def grant_types_supported(self) -> list[str]:
        """Sorted grant types declared by the registered authorization and token grants."""
        grant_types: set[str] = set()
        for grant, _ in [*self._authorization_grants, *self._token_grants]:
            grant_type = getattr(grant, "GRANT_TYPE", None)
            # Skip grants with a missing or None GRANT_TYPE: None is not a
            # valid grant type for metadata and would break sorting.
            # NOTE(review): Authlib's base grant is believed to default it to None.
            if grant_type is not None:
                grant_types.add(grant_type)
        return sorted(grant_types)

    @property
    def token_endpoint_auth_methods_supported(self) -> list[str]:
        """Client authentication methods accepted at the token endpoint."""
        return ["client_secret_basic", "client_secret_post", "none"]

    @property
    def revocation_endpoint_auth_methods_supported(self) -> list[str]:
        """Sorted client authentication methods accepted by the revocation endpoint(s)."""
        return self._endpoint_auth_methods(RevocationEndpoint.ENDPOINT_NAME)

    @property
    def introspection_endpoint_auth_methods_supported(self) -> list[str]:
        """Sorted client authentication methods accepted by the introspection endpoint(s)."""
        return self._endpoint_auth_methods(IntrospectionEndpoint.ENDPOINT_NAME)

    @property
    def code_challenge_methods_supported(self) -> list[str]:
        """Sorted PKCE methods from ``CodeChallenge`` extensions on authorization grants."""
        code_challenge_methods: set[str] = set()
        for _, extensions in self._authorization_grants:
            for extension in extensions:
                if isinstance(extension, CodeChallenge):
                    code_challenge_methods.update(
                        extension.SUPPORTED_CODE_CHALLENGE_METHOD
                    )
        return sorted(code_challenge_methods)