Coverage for polar/oauth2/endpoints/oauth2.py: 51%
83 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1from collections.abc import Sequence 1a
2from typing import Literal, cast 1a
4from fastapi import Depends, Form, HTTPException, Request, Response 1a
5from fastapi.openapi.constants import REF_TEMPLATE 1a
7from polar.auth.dependencies import WebUserOrAnonymous, WebUserRead, WebUserWrite 1a
8from polar.auth.models import is_user 1a
9from polar.kit.pagination import ListResource, PaginationParamsQuery 1a
10from polar.models import OAuth2Token, Organization 1a
11from polar.openapi import APITag 1a
12from polar.organization.repository import OrganizationRepository 1a
13from polar.postgres import AsyncSession, get_db_session 1a
14from polar.routing import APIRouter 1a
16from ..authorization_server import ( 1a
17 AuthorizationServer,
18 ClientConfigurationEndpoint,
19 ClientRegistrationEndpoint,
20 IntrospectionEndpoint,
21 RevocationEndpoint,
22)
23from ..dependencies import get_authorization_server, get_token 1a
24from ..grants import AuthorizationCodeGrant 1a
25from ..schemas import ( 1a
26 AuthorizeResponse,
27 IntrospectTokenResponse,
28 OAuth2Client,
29 OAuth2ClientConfiguration,
30 OAuth2ClientConfigurationUpdate,
31 RevokeTokenResponse,
32 TokenResponse,
33 authorize_response_adapter,
34)
35from ..schemas import ( 1a
36 UserInfo as UserInfoSchema,
37)
38from ..service.oauth2_client import oauth2_client as oauth2_client_service 1a
39from ..sub_type import SubType 1a
40from ..userinfo import UserInfo, generate_user_info 1a
# Router collecting every OAuth2 endpoint defined below; all of them are
# mounted under the "/oauth2" prefix and tagged "oauth2".
router = APIRouter(tags=["oauth2"], prefix="/oauth2")
@router.get(
    "/",
    tags=["clients", APITag.private],
    summary="List Clients",
    response_model=ListResource[OAuth2Client],
)
async def list(
    auth_subject: WebUserRead,
    pagination: PaginationParamsQuery,
    session: AsyncSession = Depends(get_db_session),
) -> ListResource[OAuth2Client]:
    """List OAuth2 clients."""
    # Ask the client service for one page of clients (scoped by the caller's
    # auth subject) together with the count used for pagination metadata.
    clients, total = await oauth2_client_service.list(
        session, auth_subject, pagination=pagination
    )
    # Convert ORM rows to the public schema before wrapping them in the
    # paginated envelope. (The builtin `list` is shadowed by this function's
    # name, so only a comprehension is used here.)
    items = [OAuth2Client.model_validate(client) for client in clients]
    return ListResource.from_paginated_results(items, total, pagination)
@router.post(
    "/register",
    name="oauth2:create_client",
    summary="Create Client",
    tags=["clients", APITag.public],
)
async def create(
    client_configuration: OAuth2ClientConfiguration,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Create an OAuth2 client."""
    # Stash the acting user (None for anonymous callers) and the JSON-ready
    # client configuration on request.state, then let the registration
    # endpoint of the authorization server build the response.
    acting_user = auth_subject.subject if is_user(auth_subject) else None
    request.state.user = acting_user
    parsed_configuration = client_configuration.model_dump(
        mode="json", exclude_none=True
    )
    request.state.parsed_data = parsed_configuration
    endpoint_name = ClientRegistrationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)
@router.get(
    "/register/{client_id}",
    name="oauth2:get_client",
    summary="Get Client",
    tags=["clients", APITag.public],
)
async def get(
    client_id: str,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Get an OAuth2 client by Client ID."""
    # `client_id` is declared for path-parameter validation/documentation; the
    # handler does not read it and forwards the whole request unchanged.
    if is_user(auth_subject):
        request.state.user = auth_subject.subject
    else:
        request.state.user = None
    return authorization_server.create_endpoint_response(
        ClientConfigurationEndpoint.ENDPOINT_NAME, request
    )
@router.put(
    "/register/{client_id}",
    name="oauth2:update_client",
    summary="Update Client",
    tags=["clients", APITag.public],
)
async def update(
    client_id: str,
    client_configuration: OAuth2ClientConfigurationUpdate,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Update an OAuth2 client."""
    # Same hand-off pattern as client creation: the acting user (None when
    # anonymous) and the new configuration, serialised to JSON-compatible
    # values with unset (None) fields dropped, are placed on request.state.
    state = request.state
    state.user = auth_subject.subject if is_user(auth_subject) else None
    state.parsed_data = client_configuration.model_dump(
        mode="json", exclude_none=True
    )
    endpoint_name = ClientConfigurationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)
@router.delete(
    "/register/{client_id}",
    name="oauth2:delete_client",
    summary="Delete Client",
    tags=["clients", APITag.public],
)
async def delete(
    client_id: str,
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Delete an OAuth2 client."""
    # Record who is acting (None for anonymous callers) and delegate the
    # DELETE to the client configuration endpoint with the original request.
    current_user = None
    if is_user(auth_subject):
        current_user = auth_subject.subject
    request.state.user = current_user
    return authorization_server.create_endpoint_response(
        ClientConfigurationEndpoint.ENDPOINT_NAME, request
    )
@router.get("/authorize", tags=[APITag.public])
async def authorize(
    request: Request,
    auth_subject: WebUserOrAnonymous,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
    session: AsyncSession = Depends(get_db_session),
) -> AuthorizeResponse:
    # Entry point of the authorization-code flow. Depending on the grant's
    # `prompt`, this either rejects the caller (401), answers immediately
    # without persisting consent, or returns the data the consent screen needs
    # (client, requested scopes, subject and, for organization subjects, the
    # user's organizations).
    user = auth_subject.subject if is_user(auth_subject) else None
    # NOTE(review): presumably pre-reads/caches the request body so the
    # authorization server can access it synchronously — TODO confirm.
    await request.form()
    grant: AuthorizationCodeGrant = authorization_server.get_consent_grant(
        request=request, end_user=user
    )

    if grant.prompt == "login":
        raise HTTPException(status_code=401)
    elif grant.prompt == "none":
        # No interactive consent: build the authorization response right away
        # and do not store a consent record.
        return authorization_server.create_authorization_response(
            request=request, grant_user=user, save_consent=False
        )

    organizations: Sequence[Organization] | None = None
    if grant.sub_type == SubType.organization:
        # Listing organizations requires an authenticated user. This used to
        # be an `assert`, which is stripped under `python -O` and would then
        # fail with an AttributeError on `.subject.id`; an explicit check
        # keeps the guard (and the type narrowing) in every interpreter mode.
        if not is_user(auth_subject):
            raise HTTPException(status_code=401)
        organization_repository = OrganizationRepository.from_session(session)
        organizations = await organization_repository.get_all_by_user(
            auth_subject.subject.id
        )

    payload = grant.request.payload
    # Internal invariant (not caller input): narrows Optional for type checkers.
    assert payload is not None

    return authorize_response_adapter.validate_python(
        {
            "client": grant.client,
            "scopes": payload.scope,
            "sub_type": grant.sub_type,
            "sub": grant.sub,
            "organizations": organizations,
        }
    )
@router.post("/consent", tags=[APITag.private])
async def consent(
    request: Request,
    auth_subject: WebUserWrite,
    action: Literal["allow", "deny"] = Form(...),
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    # Apply the user's consent decision and finish the authorization request,
    # always persisting the consent outcome (save_consent=True).
    await request.form()
    # "allow" grants on behalf of the authenticated user; "deny" is expressed
    # by passing no grant user at all.
    if action == "allow":
        grant_user = auth_subject.subject
    else:
        grant_user = None
    return authorization_server.create_authorization_response(
        request=request, grant_user=grant_user, save_consent=True
    )
@router.post(
    "/token",
    name="oauth2:request_token",
    operation_id="oauth2:request_token",
    summary="Request Token",
    tags=[APITag.public],
    response_model=TokenResponse,
    openapi_extra={
        "requestBody": {
            "required": True,
            "content": {
                "application/x-www-form-urlencoded": {
                    "schema": {
                        # One schema reference per supported grant type.
                        "oneOf": [
                            {"$ref": REF_TEMPLATE.format(model=model_name)}
                            for model_name in (
                                "AuthorizationCodeTokenRequest",
                                "RefreshTokenRequest",
                                "WebTokenRequest",
                            )
                        ]
                    }
                }
            },
        },
    },
)
async def token(
    request: Request,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Request an access token using a valid grant."""
    # The form body is documented via openapi_extra rather than Form params;
    # it is read here and the authorization server builds the token response.
    await request.form()
    response = authorization_server.create_token_response(request)
    return response
@router.post(
    "/revoke",
    name="oauth2:revoke_token",
    operation_id="oauth2:revoke_token",
    summary="Revoke Token",
    tags=[APITag.public],
    response_model=RevokeTokenResponse,
    openapi_extra={
        "requestBody": {
            "required": True,
            "content": {
                "application/x-www-form-urlencoded": {
                    "schema": {
                        "$ref": REF_TEMPLATE.format(model="RevokeTokenRequest")
                    }
                }
            },
        },
    },
)
async def revoke(
    request: Request,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Revoke an access token or a refresh token."""
    # Read the form body, then delegate to the revocation endpoint.
    await request.form()
    endpoint_name = RevocationEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)
@router.post(
    "/introspect",
    name="oauth2:introspect_token",
    operation_id="oauth2:introspect_token",
    summary="Introspect Token",
    tags=[APITag.public],
    response_model=IntrospectTokenResponse,
    openapi_extra={
        "requestBody": {
            "required": True,
            "content": {
                "application/x-www-form-urlencoded": {
                    "schema": {
                        "$ref": REF_TEMPLATE.format(model="IntrospectTokenRequest")
                    }
                }
            },
        },
    },
)
async def introspect(
    request: Request,
    authorization_server: AuthorizationServer = Depends(get_authorization_server),
) -> Response:
    """Get information about an access token."""
    # Read the form body, then delegate to the introspection endpoint.
    await request.form()
    endpoint_name = IntrospectionEndpoint.ENDPOINT_NAME
    return authorization_server.create_endpoint_response(endpoint_name, request)
@router.get(
    "/userinfo",
    name="oauth2:userinfo",
    operation_id="oauth2:userinfo",
    summary="Get User Info",
    tags=[APITag.public],
    response_model=UserInfoSchema,
    response_model_exclude_unset=True,
    openapi_extra={"x-speakeasy-name-override": "userinfo"},
)
async def userinfo_get(token: OAuth2Token = Depends(get_token)) -> UserInfo:
    """Get information about the authenticated user."""
    # Build the userinfo claims from the token's subject and granted scopes;
    # unset fields are omitted from the response (response_model_exclude_unset).
    subject = token.get_sub_type_value()
    granted_scope = cast(str, token.scope)
    return generate_user_info(subject, granted_scope)
# POST variant of /userinfo with the same behavior as the GET endpoint,
# hidden from the OpenAPI schema (include_in_schema=False).
@router.post(
    "/userinfo",
    summary="Get User Info",
    include_in_schema=False,
    response_model=UserInfoSchema,
    response_model_exclude_unset=True,
)
async def userinfo_post(token: OAuth2Token = Depends(get_token)) -> UserInfo:
    """Get information about the authenticated user."""
    subject = token.get_sub_type_value()
    granted_scope = cast(str, token.scope)
    return generate_user_info(subject, granted_scope)