Coverage for polar/auth/middlewares.py: 67%

88 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import logfire 1a

2import structlog 1a

3from fastapi import Request 1a

4from fastapi.security.utils import get_authorization_scheme_param 1a

5from starlette.types import ASGIApp, Receive, Send 1a

6from starlette.types import Scope as ASGIScope 1a

7 

8from polar.customer_session.service import customer_session as customer_session_service 1a

9from polar.kit.utils import utc_now 1a

10from polar.logging import Logger 1a

11from polar.models import ( 1a

12 CustomerSession, 

13 OAuth2Token, 

14 OrganizationAccessToken, 

15 PersonalAccessToken, 

16 UserSession, 

17) 

18from polar.oauth2.constants import is_registration_token_prefix 1a

19from polar.oauth2.exception_handlers import OAuth2Error, oauth2_error_exception_handler 1a

20from polar.oauth2.exceptions import InvalidTokenError 1a

21from polar.oauth2.service.oauth2_token import oauth2_token as oauth2_token_service 1a

22from polar.organization_access_token.service import ( 1a

23 organization_access_token as organization_access_token_service, 

24) 

25from polar.personal_access_token.service import ( 1a

26 personal_access_token as personal_access_token_service, 

27) 

28from polar.postgres import AsyncSession 1a

29from polar.sentry import set_sentry_user 1a

30from polar.worker._enqueue import enqueue_job 1a

31 

32from .models import Anonymous, AuthSubject, Subject 1a

33from .scope import Scope 1a

34from .service import auth as auth_service 1a

35 

# Module-level structlog logger named after this module; the middleware in
# this file uses it to log every authenticated subject.
log: Logger = structlog.get_logger(__name__)

37 

38 

async def get_user_session(
    request: Request, session: AsyncSession
) -> UserSession | None:
    """Return the user session authenticated for ``request``, or ``None``.

    Thin adapter over ``auth_service.authenticate``, which takes its
    arguments in ``(session, request)`` order.
    """
    user_session = await auth_service.authenticate(session, request)
    return user_session

43 

44 

def get_bearer_token(request: Request) -> str | None:
    """Extract the bearer token from the request's ``Authorization`` header.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        The credential string when the header uses the ``Bearer`` scheme
        (case-insensitive) and carries a non-empty, ASCII-only credential.
        ``None`` otherwise: missing header, another scheme, or an empty or
        non-ASCII credential.
    """
    authorization = request.headers.get("Authorization")
    scheme, value = get_authorization_scheme_param(authorization)
    # get_authorization_scheme_param splits on the FIRST space only, so a
    # header like "Bearer  <token>" (RFC 6750 allows 1*SP between scheme and
    # token) would leave a leading space in the credential that can never
    # match a stored token. Strip only SP so non-ASCII whitespace is still
    # rejected by the ASCII check below.
    value = value.strip(" ")
    if not scheme or not value or scheme.lower() != "bearer":
        return None
    # Non-ASCII credentials are treated as "no bearer token" (None) rather
    # than as an invalid token.
    if not value.isascii():
        return None
    return value

53 

54 

async def get_oauth2_token(session: AsyncSession, value: str) -> OAuth2Token | None:
    """Look up the OAuth2 token whose access token equals ``value``.

    Delegates to ``oauth2_token_service.get_by_access_token``; returns
    ``None`` when the service finds no match.
    """
    oauth2_token = await oauth2_token_service.get_by_access_token(session, value)
    return oauth2_token

57 

58 

async def get_personal_access_token(
    session: AsyncSession, value: str
) -> PersonalAccessToken | None:
    """Look up a personal access token by its token value.

    When a token is found, a ``personal_access_token.record_usage`` job is
    enqueued with the token id and the current ``utc_now()`` time as a POSIX
    timestamp. Returns ``None`` (and enqueues nothing) when there is no match.
    """
    personal_access_token = await personal_access_token_service.get_by_token(
        session, value
    )
    if personal_access_token is None:
        return None

    enqueue_job(
        "personal_access_token.record_usage",
        personal_access_token_id=personal_access_token.id,
        last_used_at=utc_now().timestamp(),
    )
    return personal_access_token

72 

73 

async def get_organization_access_token(
    session: AsyncSession, value: str
) -> OrganizationAccessToken | None:
    """Look up an organization access token by its token value.

    When a token is found, an ``organization_access_token.record_usage`` job
    is enqueued with the token id and the current ``utc_now()`` time as a
    POSIX timestamp. Returns ``None`` (and enqueues nothing) when there is no
    match.
    """
    organization_access_token = await organization_access_token_service.get_by_token(
        session, value
    )
    if organization_access_token is None:
        return None

    enqueue_job(
        "organization_access_token.record_usage",
        organization_access_token_id=organization_access_token.id,
        last_used_at=utc_now().timestamp(),
    )
    return organization_access_token

87 

88 

async def get_customer_session(
    session: AsyncSession, value: str
) -> CustomerSession | None:
    """Look up a customer session by its token value.

    Delegates to ``customer_session_service.get_by_token``; returns ``None``
    when the service finds no match.
    """
    customer_session = await customer_session_service.get_by_token(session, value)
    return customer_session

93 

94 

async def get_auth_subject(
    request: Request, session: AsyncSession
) -> AuthSubject[Subject]:
    """Resolve the authentication subject for ``request``.

    Without a bearer token, the user session from ``get_user_session`` is
    used; if there is none, an anonymous subject is returned.

    With a bearer token:
      * a token carrying a registration-token prefix yields an anonymous
        subject without any lookup;
      * otherwise the token is matched, in this order, against customer
        sessions, organization access tokens, OAuth2 tokens and personal
        access tokens, and the first match wins.

    Raises:
        InvalidTokenError: a bearer token was supplied but matched none of
            the supported token kinds.
    """
    token = get_bearer_token(request)

    if token is None:
        user_session = await get_user_session(request, session)
        if user_session is None:
            return AuthSubject(Anonymous(), set(), None)
        return AuthSubject(user_session.user, set(user_session.scopes), user_session)

    if is_registration_token_prefix(token):
        # NOTE(review): such tokens are deliberately not authenticated here;
        # presumably the consuming endpoint validates them itself — confirm.
        return AuthSubject(Anonymous(), set(), None)

    # Order matters: lookups stop at the first match, and the organization /
    # personal access token lookups enqueue a usage-recording job on a match.
    if customer_session := await get_customer_session(session, token):
        return AuthSubject(
            customer_session.customer,
            {Scope.customer_portal_write},
            customer_session,
        )

    if org_token := await get_organization_access_token(session, token):
        return AuthSubject(org_token.organization, org_token.scopes, org_token)

    if oauth2_token := await get_oauth2_token(session, token):
        return AuthSubject(oauth2_token.sub, oauth2_token.scopes, oauth2_token)

    if pat := await get_personal_access_token(session, token):
        return AuthSubject(pat.user, pat.scopes, pat)

    # A bearer token was presented but matched nothing.
    raise InvalidTokenError()

138 

139 

class AuthSubjectMiddleware:
    """Pure ASGI middleware that attaches the resolved auth subject to HTTP scopes.

    For ``http`` scopes it reads the database session from
    ``scope["state"]["async_session"]``, resolves the subject with
    ``get_auth_subject``, and stores it in ``scope["state"]["auth_subject"]``.
    It then calls the wrapped app inside a logfire baggage context built from
    the subject's ``log_context``. An ``OAuth2Error`` raised during resolution
    is turned into a response by ``oauth2_error_exception_handler``, and the
    wrapped app is not called. Non-HTTP scopes pass through untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            # Only HTTP requests are authenticated here.
            await self.app(scope, receive, send)
            return

        state = scope["state"]
        db_session: AsyncSession = state["async_session"]
        request = Request(scope)

        try:
            subject = await get_auth_subject(request, db_session)
        except OAuth2Error as error:
            error_response = await oauth2_error_exception_handler(request, error)
            await error_response(scope, receive, send)
            return

        state["auth_subject"] = subject

        with logfire.set_baggage(**subject.log_context):
            log.info("Authenticated subject", **subject.log_context)
            set_sentry_user(subject)
            await self.app(scope, receive, send)