Coverage for polar/auth/middlewares.py: 67%
88 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import logfire 1a
2import structlog 1a
3from fastapi import Request 1a
4from fastapi.security.utils import get_authorization_scheme_param 1a
5from starlette.types import ASGIApp, Receive, Send 1a
6from starlette.types import Scope as ASGIScope 1a
8from polar.customer_session.service import customer_session as customer_session_service 1a
9from polar.kit.utils import utc_now 1a
10from polar.logging import Logger 1a
11from polar.models import ( 1a
12 CustomerSession,
13 OAuth2Token,
14 OrganizationAccessToken,
15 PersonalAccessToken,
16 UserSession,
17)
18from polar.oauth2.constants import is_registration_token_prefix 1a
19from polar.oauth2.exception_handlers import OAuth2Error, oauth2_error_exception_handler 1a
20from polar.oauth2.exceptions import InvalidTokenError 1a
21from polar.oauth2.service.oauth2_token import oauth2_token as oauth2_token_service 1a
22from polar.organization_access_token.service import ( 1a
23 organization_access_token as organization_access_token_service,
24)
25from polar.personal_access_token.service import ( 1a
26 personal_access_token as personal_access_token_service,
27)
28from polar.postgres import AsyncSession 1a
29from polar.sentry import set_sentry_user 1a
30from polar.worker._enqueue import enqueue_job 1a
32from .models import Anonymous, AuthSubject, Subject 1a
33from .scope import Scope 1a
34from .service import auth as auth_service 1a
# Structured logger for this module, named after the module's import path.
log: Logger = structlog.get_logger(__name__)
async def get_user_session(
    request: Request, session: AsyncSession
) -> UserSession | None:
    """Return the ``UserSession`` that the auth service resolves for *request*.

    This is a thin wrapper around ``auth_service.authenticate``. It yields
    ``None`` when no user session can be resolved.
    """
    user_session = await auth_service.authenticate(session, request)
    return user_session
def get_bearer_token(request: Request) -> str | None:
    """Extract the bearer credential from the ``Authorization`` header.

    Returns the credential part of the header only when all of the following hold:

    * the scheme is ``bearer`` (compared case-insensitively);
    * both the scheme and the credential are non-empty;
    * the credential consists solely of ASCII characters.

    Otherwise returns ``None``.
    """
    scheme, credentials = get_authorization_scheme_param(
        request.headers.get("Authorization")
    )
    # Short-circuiting keeps ``scheme.lower()`` from running on an empty scheme.
    is_bearer = bool(scheme) and bool(credentials) and scheme.lower() == "bearer"
    if is_bearer and credentials.isascii():
        return credentials
    return None
async def get_oauth2_token(session: AsyncSession, value: str) -> OAuth2Token | None:
    """Return the ``OAuth2Token`` whose access token equals *value*, else ``None``."""
    oauth2_token = await oauth2_token_service.get_by_access_token(session, value)
    return oauth2_token
async def get_personal_access_token(
    session: AsyncSession, value: str
) -> PersonalAccessToken | None:
    """Look up a personal access token by its raw *value*.

    On a hit, a ``personal_access_token.record_usage`` job is enqueued with the
    token id and ``utc_now().timestamp()`` as ``last_used_at``. The token is then
    returned. On a miss, returns ``None`` and enqueues nothing.
    """
    personal_access_token = await personal_access_token_service.get_by_token(
        session, value
    )
    if personal_access_token is None:
        return None

    # Usage bookkeeping goes to a background worker job instead of being
    # written during request authentication.
    enqueue_job(
        "personal_access_token.record_usage",
        personal_access_token_id=personal_access_token.id,
        last_used_at=utc_now().timestamp(),
    )
    return personal_access_token
async def get_organization_access_token(
    session: AsyncSession, value: str
) -> OrganizationAccessToken | None:
    """Look up an organization access token by its raw *value*.

    On a hit, an ``organization_access_token.record_usage`` job is enqueued with
    the token id and ``utc_now().timestamp()`` as ``last_used_at``. The token is
    then returned. On a miss, returns ``None`` and enqueues nothing.
    """
    organization_access_token = await organization_access_token_service.get_by_token(
        session, value
    )
    if organization_access_token is None:
        return None

    # Usage bookkeeping goes to a background worker job instead of being
    # written during request authentication.
    enqueue_job(
        "organization_access_token.record_usage",
        organization_access_token_id=organization_access_token.id,
        last_used_at=utc_now().timestamp(),
    )
    return organization_access_token
async def get_customer_session(
    session: AsyncSession, value: str
) -> CustomerSession | None:
    """Return the ``CustomerSession`` matching the raw token *value*, else ``None``."""
    customer_session = await customer_session_service.get_by_token(session, value)
    return customer_session
async def get_auth_subject(
    request: Request, session: AsyncSession
) -> AuthSubject[Subject]:
    """Resolve the authenticated subject behind *request*.

    Resolution order:

    1. No bearer token (see ``get_bearer_token``): use the user session from
       ``get_user_session`` if there is one. Its scopes are copied into a
       ``set``. Otherwise the subject is anonymous with no scopes.
    2. A bearer token with a registration-token prefix: anonymous subject with
       no scopes. No lookup is performed.
    3. Any other bearer token is tried, in this order, as:

       * a customer session (granted ``Scope.customer_portal_write``);
       * an organization access token;
       * an OAuth2 access token;
       * a personal access token.

       The first match wins, using the token's own subject and scopes.

    Raises:
        InvalidTokenError: a bearer token was supplied but matched none of the
            lookups above.
    """
    token = get_bearer_token(request)

    if token is None:
        user_session = await get_user_session(request, session)
        if user_session is None:
            return AuthSubject(Anonymous(), set(), None)
        return AuthSubject(user_session.user, set(user_session.scopes), user_session)

    if is_registration_token_prefix(token):
        # NOTE(review): registration tokens are deliberately not resolved here;
        # presumably the registration endpoint validates them itself — confirm.
        return AuthSubject(Anonymous(), set(), None)

    if customer_session := await get_customer_session(session, token):
        return AuthSubject(
            customer_session.customer,
            {Scope.customer_portal_write},
            customer_session,
        )

    if organization_access_token := await get_organization_access_token(
        session, token
    ):
        return AuthSubject(
            organization_access_token.organization,
            organization_access_token.scopes,
            organization_access_token,
        )

    if oauth2_token := await get_oauth2_token(session, token):
        return AuthSubject(oauth2_token.sub, oauth2_token.scopes, oauth2_token)

    if personal_access_token := await get_personal_access_token(session, token):
        return AuthSubject(
            personal_access_token.user,
            personal_access_token.scopes,
            personal_access_token,
        )

    # A bearer token was presented but nothing recognized it: reject instead of
    # falling back to anonymous access.
    raise InvalidTokenError()
class AuthSubjectMiddleware:
    """Pure ASGI middleware that attaches an ``AuthSubject`` to each HTTP request.

    For ``http`` scopes it:

    * resolves the subject with ``get_auth_subject``;
    * stores it in ``scope["state"]["auth_subject"]``;
    * sets logfire baggage, logs the subject and sets the Sentry user;
    * forwards the request to the wrapped app.

    If resolution raises ``OAuth2Error``, the error response built by
    ``oauth2_error_exception_handler`` is sent and the wrapped app is not
    called. All other scope types are forwarded unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            await self._authenticate_and_forward(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    async def _authenticate_and_forward(
        self, scope: ASGIScope, receive: Receive, send: Send
    ) -> None:
        """Resolve the auth subject for an HTTP scope, then call the wrapped app."""
        state = scope["state"]
        # NOTE(review): assumes an earlier middleware put an AsyncSession into
        # scope["state"]["async_session"]; otherwise this raises KeyError.
        session: AsyncSession = state["async_session"]
        request = Request(scope)

        try:
            auth_subject = await get_auth_subject(request, session)
        except OAuth2Error as error:
            error_response = await oauth2_error_exception_handler(request, error)
            await error_response(scope, receive, send)
            return

        state["auth_subject"] = auth_subject

        # The wrapped app runs inside the baggage context so downstream spans
        # carry the subject's log context.
        with logfire.set_baggage(**auth_subject.log_context):
            log.info("Authenticated subject", **auth_subject.log_context)
            set_sentry_user(auth_subject)
            await self.app(scope, receive, send)