Coverage for polar/oauth2/endpoints/oauth2.py: 51%

83 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1from collections.abc import Sequence 1a

2from typing import Literal, cast 1a

3 

4from fastapi import Depends, Form, HTTPException, Request, Response 1a

5from fastapi.openapi.constants import REF_TEMPLATE 1a

6 

7from polar.auth.dependencies import WebUserOrAnonymous, WebUserRead, WebUserWrite 1a

8from polar.auth.models import is_user 1a

9from polar.kit.pagination import ListResource, PaginationParamsQuery 1a

10from polar.models import OAuth2Token, Organization 1a

11from polar.openapi import APITag 1a

12from polar.organization.repository import OrganizationRepository 1a

13from polar.postgres import AsyncSession, get_db_session 1a

14from polar.routing import APIRouter 1a

15 

16from ..authorization_server import ( 1a

17 AuthorizationServer, 

18 ClientConfigurationEndpoint, 

19 ClientRegistrationEndpoint, 

20 IntrospectionEndpoint, 

21 RevocationEndpoint, 

22) 

23from ..dependencies import get_authorization_server, get_token 1a

24from ..grants import AuthorizationCodeGrant 1a

25from ..schemas import ( 1a

26 AuthorizeResponse, 

27 IntrospectTokenResponse, 

28 OAuth2Client, 

29 OAuth2ClientConfiguration, 

30 OAuth2ClientConfigurationUpdate, 

31 RevokeTokenResponse, 

32 TokenResponse, 

33 authorize_response_adapter, 

34) 

35from ..schemas import ( 1a

36 UserInfo as UserInfoSchema, 

37) 

38from ..service.oauth2_client import oauth2_client as oauth2_client_service 1a

39from ..sub_type import SubType 1a

40from ..userinfo import UserInfo, generate_user_info 1a

41 

# Router for every OAuth2 endpoint in this module, mounted under "/oauth2"
# and tagged "oauth2" by default (individual routes add their own tags).
router = APIRouter(tags=["oauth2"], prefix="/oauth2")

43 

44 

# Paginated listing of the OAuth2 clients visible to the authenticated web user.
# NOTE: the function name shadows the builtin ``list`` at module level; it is
# kept because FastAPI derives the route name from it.
@router.get(
    "/",
    response_model=ListResource[OAuth2Client],
    summary="List Clients",
    tags=["clients", APITag.private],
)
async def list(
    auth_subject: WebUserRead,
    pagination: PaginationParamsQuery,
    session: AsyncSession = Depends(get_db_session),
) -> ListResource[OAuth2Client]:
    """List OAuth2 clients."""
    # The service returns the current page of rows plus the total row count.
    clients, total = await oauth2_client_service.list(
        session, auth_subject, pagination=pagination
    )
    # Convert each ORM row into the public schema before wrapping it in the page.
    items = [OAuth2Client.model_validate(client) for client in clients]
    return ListResource.from_paginated_results(items, total, pagination)

63 

64 

# Dynamic client registration: hands the request to the authorization server's
# registration endpoint after staging the caller and the submitted metadata.
@router.post(
    "/register",
    name="oauth2:create_client",
    summary="Create Client",
    tags=["clients", APITag.public],
)
async def create(
    client_configuration: OAuth2ClientConfiguration,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Create an OAuth2 client."""
    # Data is passed to the endpoint via ``request.state`` rather than via
    # arguments (presumably read by ClientRegistrationEndpoint — confirm).
    # Anonymous callers are recorded as ``None``.
    if is_user(auth_subject):
        request.state.user = auth_subject.subject
    else:
        request.state.user = None
    # JSON-mode dump with ``None`` fields dropped, so unset metadata is omitted.
    request.state.parsed_data = client_configuration.model_dump(
        mode="json", exclude_none=True
    )
    endpoint_name = ClientRegistrationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)

85 

86 

# Reads a registered client's configuration via the client configuration endpoint.
@router.get(
    "/register/{client_id}",
    name="oauth2:get_client",
    summary="Get Client",
    tags=["clients", APITag.public],
)
async def get(
    client_id: str,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Get an OAuth2 client by Client ID."""
    # ``client_id`` is not used in this body; it binds the "{client_id}" path
    # parameter. The endpoint presumably reads it from the request itself (confirm).
    acting_user = auth_subject.subject if is_user(auth_subject) else None
    request.state.user = acting_user
    endpoint_name = ClientConfigurationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)

104 

105 

# Replaces a registered client's configuration via the client configuration endpoint.
@router.put(
    "/register/{client_id}",
    name="oauth2:update_client",
    summary="Update Client",
    tags=["clients", APITag.public],
)
async def update(
    client_id: str,
    client_configuration: OAuth2ClientConfigurationUpdate,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Update an OAuth2 client."""
    # Stage the acting user (None when anonymous) and then the new metadata on
    # ``request.state`` before delegating to the configuration endpoint.
    state = request.state
    if is_user(auth_subject):
        state.user = auth_subject.subject
    else:
        state.user = None
    state.parsed_data = client_configuration.model_dump(
        mode="json", exclude_none=True
    )
    endpoint_name = ClientConfigurationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)

127 

128 

# Removes a registered client via the client configuration endpoint.
@router.delete(
    "/register/{client_id}",
    name="oauth2:delete_client",
    summary="Delete Client",
    tags=["clients", APITag.public],
)
async def delete(
    client_id: str,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Delete an OAuth2 client."""
    # ``client_id`` only binds the path parameter; it is not read in this body.
    acting_user = auth_subject.subject if is_user(auth_subject) else None
    request.state.user = acting_user
    endpoint_name = ClientConfigurationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)

146 

147 

# Authorization step of the authorization-code flow.
# - prompt "login": the caller must authenticate first (401).
# - prompt "none": the authorization response is issued immediately, without
#   persisting consent.
# - otherwise: returns the data a consent screen needs (client, requested
#   scopes, subject, and the user's organizations when the subject is an
#   organization).
# (Intentionally no docstring: it would become this route's OpenAPI description.)
@router.get("/authorize", tags=[APITag.public])
async def authorize(
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
    session: AsyncSession = Depends(get_db_session),
) -> AuthorizeResponse:
    end_user = auth_subject.subject if is_user(auth_subject) else None
    # Parse the body up front so it is cached on the request before the
    # authorization server inspects it synchronously (assumption — confirm).
    await request.form()
    grant: AuthorizationCodeGrant = authorization_server.get_consent_grant(
        request=request, end_user=end_user
    )

    # Guard clauses for the prompt values handled without a consent screen.
    if grant.prompt == "login":
        raise HTTPException(status_code=401)
    if grant.prompt == "none":
        return authorization_server.create_authorization_response(
            request=request, grant_user=end_user, save_consent=False
        )

    # The user's organizations are only needed when the token subject
    # is an organization, so the consent screen can offer a choice.
    organizations: Sequence[Organization] | None = None
    if grant.sub_type == SubType.organization:
        assert is_user(auth_subject)
        repository = OrganizationRepository.from_session(session)
        organizations = await repository.get_all_by_user(auth_subject.subject.id)

    payload = grant.request.payload
    assert payload is not None

    consent_data = {
        "client": grant.client,
        "scopes": payload.scope,
        "sub_type": grant.sub_type,
        "sub": grant.sub,
        "organizations": organizations,
    }
    return authorize_response_adapter.validate_python(consent_data)

188 

189 

# Receives the user's consent decision and issues the authorization response,
# persisting the consent (save_consent=True).
# (Intentionally no docstring: it would become this route's OpenAPI description.)
@router.post("/consent", tags=[APITag.private])
async def consent(
    request: Request,
    auth_subject: WebUserWrite,
    action: Literal["allow", "deny"] = Form(...),
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    await request.form()
    # A denied request is passed on with no grant user; only "allow" attaches
    # the authenticated user.
    grant_user = None
    if action == "allow":
        grant_user = auth_subject.subject
    return authorization_server.create_authorization_response(
        request=request, grant_user=grant_user, save_consent=True
    )

202 

203 

# Token endpoint. The form-encoded request body is documented manually in
# OpenAPI as a oneOf over the supported token-request schemas, because the
# handler reads the raw form instead of a typed body parameter.
@router.post(
    "/token",
    summary="Request Token",
    name="oauth2:request_token",
    operation_id="oauth2:request_token",
    tags=[APITag.public],
    openapi_extra={
        "requestBody": {
            "required": True,
            "content": {
                "application/x-www-form-urlencoded": {
                    "schema": {
                        "oneOf": [
                            {"$ref": REF_TEMPLATE.format(model=schema_name)}
                            for schema_name in (
                                "AuthorizationCodeTokenRequest",
                                "RefreshTokenRequest",
                                "WebTokenRequest",
                            )
                        ]
                    }
                }
            },
        },
    },
    response_model=TokenResponse,
)
async def token(
    request: Request,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Request an access token using a valid grant."""
    # Parse and cache the form body before the (synchronous) token handler runs.
    await request.form()
    response = authorization_server.create_token_response(request)
    return response

239 

240 

# Token revocation endpoint; the form body schema is declared manually for OpenAPI.
@router.post(
    "/revoke",
    summary="Revoke Token",
    name="oauth2:revoke_token",
    operation_id="oauth2:revoke_token",
    tags=[APITag.public],
    openapi_extra={
        "requestBody": {
            "required": True,
            "content": {
                "application/x-www-form-urlencoded": {
                    "schema": {
                        "$ref": REF_TEMPLATE.format(model="RevokeTokenRequest"),
                    }
                }
            },
        },
    },
    response_model=RevokeTokenResponse,
)
async def revoke(
    request: Request,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Revoke an access token or a refresh token."""
    # Parse and cache the form body before the revocation endpoint reads it.
    await request.form()
    endpoint_name = RevocationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)

268 

269 

# Token introspection endpoint; the form body schema is declared manually for OpenAPI.
@router.post(
    "/introspect",
    summary="Introspect Token",
    name="oauth2:introspect_token",
    operation_id="oauth2:introspect_token",
    tags=[APITag.public],
    openapi_extra={
        "requestBody": {
            "required": True,
            "content": {
                "application/x-www-form-urlencoded": {
                    "schema": {
                        "$ref": REF_TEMPLATE.format(model="IntrospectTokenRequest"),
                    }
                }
            },
        },
    },
    response_model=IntrospectTokenResponse,
)
async def introspect(
    request: Request,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Get information about an access token."""
    # Parse and cache the form body before the introspection endpoint reads it.
    await request.form()
    endpoint_name = IntrospectionEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)

299 

300 

# UserInfo endpoint (GET): builds the claims for the subject of the bearer token.
# ``response_model_exclude_unset`` omits claims that were not set.
@router.get(
    "/userinfo",
    summary="Get User Info",
    name="oauth2:userinfo",
    operation_id="oauth2:userinfo",
    response_model=UserInfoSchema,
    response_model_exclude_unset=True,
    tags=[APITag.public],
    openapi_extra={"x-speakeasy-name-override": "userinfo"},
)
async def userinfo_get(token: OAuth2Token = Depends(get_token)) -> UserInfo:
    """Get information about the authenticated user."""
    subject = token.get_sub_type_value()
    # ``cast`` only narrows the type for the checker; nothing is converted at runtime.
    granted_scope = cast(str, token.scope)
    return generate_user_info(subject, granted_scope)

314 

315 

# POST variant of /userinfo: the OpenID Connect UserInfo endpoint must accept
# both GET and POST. It is hidden from the OpenAPI schema so only the GET
# operation is published.
@router.post(
    "/userinfo",
    summary="Get User Info",
    response_model=UserInfoSchema,
    response_model_exclude_unset=True,
    include_in_schema=False,
)
async def userinfo_post(token: OAuth2Token = Depends(get_token)) -> UserInfo:
    """Get information about the authenticated user."""
    # Delegate to the GET handler so both methods always return the same claims.
    return await userinfo_get(token)