Coverage for polar/oauth2/service/oauth2_token.py: 35%

60 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import time 1a

2from typing import cast 1a

3 

4import structlog 1a

5from sqlalchemy import select 1a

6from sqlalchemy.orm import joinedload 1a

7 

8from polar.config import settings 1a

9from polar.email.react import render_email_template 1a

10from polar.email.schemas import OAuth2LeakedTokenEmail, OAuth2LeakedTokenProps 1a

11from polar.email.sender import enqueue_email 1a

12from polar.enums import TokenType 1a

13from polar.exceptions import PolarError 1a

14from polar.kit.crypto import get_token_hash 1a

15from polar.kit.services import ResourceServiceReader 1a

16from polar.logging import Logger 1a

17from polar.models import OAuth2Token, User 1a

18from polar.postgres import AsyncSession 1a

19from polar.user_organization.service import ( 1a

20 user_organization as user_organization_service, 

21) 

22 

log: Logger = structlog.get_logger()


class OAuth2TokenError(PolarError):
    """Error type for OAuth2 token operations.

    NOTE(review): not raised by any code visible in this module; presumably
    raised or caught elsewhere — confirm before relying on it.
    """

27 

28 

class OAuth2TokenService(ResourceServiceReader[OAuth2Token]):
    """Read access to OAuth2 tokens, plus revocation of tokens reported as leaked."""

    async def get_by_access_token(
        self, session: AsyncSession, access_token: str
    ) -> OAuth2Token | None:
        """Look up a usable token by its plaintext access token.

        The plaintext value is hashed with ``settings.SECRET`` and matched
        against the stored ``OAuth2Token.access_token`` hash; the token's
        client relationship is eager-loaded.

        Args:
            session: Database session used for the lookup.
            access_token: The plaintext access token presented by the caller.

        Returns:
            The matching token, or ``None`` when no token matches, the token
            is revoked, or its subject cannot authenticate.
        """
        access_token_hash = get_token_hash(access_token, secret=settings.SECRET)
        statement = (
            select(OAuth2Token)
            .where(OAuth2Token.access_token == access_token_hash)
            .options(joinedload(OAuth2Token.client))
        )
        result = await session.execute(statement)
        # unique() is needed because joinedload may yield duplicate rows.
        token = result.unique().scalar_one_or_none()

        if token is None:
            return None
        if cast(bool, token.is_revoked()):
            return None
        if not token.sub.can_authenticate:
            return None
        # NOTE(review): no expiry check is performed here; confirm callers
        # reject expired tokens.
        return token

    async def revoke_leaked(
        self,
        session: AsyncSession,
        token: str,
        token_type: TokenType,
        *,
        notifier: str,
        url: str | None = None,
    ) -> bool:
        """Revoke a token reported as leaked and email the affected users.

        Both the access and refresh token of the matching record are marked
        revoked. A security-notice email is then enqueued for each recipient:
        the user subject, or every member of the organization subject.

        Args:
            session: Database session. The revoked token is added to it but
                is not committed by this method.
            token: The plaintext leaked token.
            token_type: Whether ``token`` is an access token or a refresh token.
            notifier: Who reported the leak; included in the email and log.
            url: Optional location where the leak was found; included in the
                email (empty string when absent) and the log.

        Returns:
            ``True`` if a matching token exists, whether it was revoked now or
            already revoked. ``False`` if no token matches.

        Raises:
            ValueError: If ``token_type`` is neither an access nor a refresh
                token.
        """
        # Choose which stored hash to match before doing any other work.
        if token_type == TokenType.access_token:
            hashed_column = OAuth2Token.access_token
        elif token_type == TokenType.refresh_token:
            hashed_column = OAuth2Token.refresh_token
        else:
            raise ValueError(f"Unsupported token type: {token_type}")

        token_hash = get_token_hash(token, secret=settings.SECRET)
        statement = (
            select(OAuth2Token)
            .options(
                joinedload(OAuth2Token.user),
                joinedload(OAuth2Token.organization),
                joinedload(OAuth2Token.client),
            )
            .where(hashed_column == token_hash)
        )
        result = await session.execute(statement)
        oauth2_token = result.unique().scalar_one_or_none()

        if oauth2_token is None:
            return False

        # Already revoked: report success without sending any email.
        if cast(bool, oauth2_token.is_revoked()):
            return True

        # Revoke. A single timestamp keeps both revocation markers identical,
        # even if a second boundary passes while this code runs.
        revoked_at = int(time.time())
        oauth2_token.access_token_revoked_at = revoked_at  # pyright: ignore
        oauth2_token.refresh_token_revoked_at = revoked_at  # pyright: ignore
        session.add(oauth2_token)

        # Notify
        recipients = await self._get_leak_recipients(session, oauth2_token)
        oauth2_client = oauth2_token.client
        for recipient in recipients:
            body = render_email_template(
                OAuth2LeakedTokenEmail(
                    props=OAuth2LeakedTokenProps(
                        email=recipient,
                        client_name=cast(str, oauth2_client.client_name),
                        notifier=notifier,
                        url=url or "",
                    )
                )
            )
            enqueue_email(
                to_email_addr=recipient,
                subject="Security Notice - Your Polar Access Token has been leaked",
                html_content=body,
            )

        log.info(
            "Revoke leaked access token and refresh token",
            id=oauth2_token.id,
            token_type=token_type,
            notifier=notifier,
            url=url,
        )

        return True

    async def _get_leak_recipients(
        self, session: AsyncSession, oauth2_token: OAuth2Token
    ) -> list[str]:
        """Return the email addresses to notify about a leaked token.

        For a user subject this is the user's email. For an organization
        subject it is the email of every member returned by
        ``user_organization_service.list_by_org``.
        """
        sub = oauth2_token.sub
        if isinstance(sub, User):
            return [sub.email]
        members = await user_organization_service.list_by_org(session, sub.id)
        return [member.user.email for member in members]

132 

133 

# Module-level instance of the OAuth2 token service, bound to the OAuth2Token model.
oauth2_token: OAuth2TokenService = OAuth2TokenService(OAuth2Token)