Coverage for polar/oauth2/service/oauth2_client.py: 40%

53 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from collections.abc import Sequence 1a

2from typing import cast 1a

3 

4import structlog 1a

5from sqlalchemy import select 1a

6from sqlalchemy.orm import joinedload 1a

7 

8from polar.auth.models import AuthSubject 1a

9from polar.email.react import render_email_template 1a

10from polar.email.schemas import OAuth2LeakedClientEmail, OAuth2LeakedClientProps 1a

11from polar.email.sender import enqueue_email 1a

12from polar.enums import TokenType 1a

13from polar.exceptions import PolarError 1a

14from polar.kit.crypto import generate_token 1a

15from polar.kit.pagination import PaginationParams, paginate 1a

16from polar.kit.services import ResourceServiceReader 1a

17from polar.logging import Logger 1a

18from polar.models import OAuth2Client, User 1a

19from polar.postgres import AsyncSession 1a

20 

21from ..constants import CLIENT_REGISTRATION_TOKEN_PREFIX, CLIENT_SECRET_PREFIX 1a

22 

# Module-level structlog logger; ``revoke_leaked`` emits an info event through it.
log: Logger = structlog.get_logger()

24 

25 

class OAuth2ClientError(PolarError):
    """Error type for the OAuth2 client service.

    Adds no attributes or behavior on top of ``PolarError``.
    NOTE(review): nothing in the visible module raises it — ``revoke_leaked``
    raises ``ValueError`` for unsupported token types.
    """

27 

28 

class OAuth2ClientService(ResourceServiceReader[OAuth2Client]):
    """Looks up OAuth2 clients and rotates credentials reported as leaked."""

    async def list(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User],
        *,
        pagination: PaginationParams,
    ) -> tuple[Sequence[OAuth2Client], int]:
        """List the OAuth2 clients owned by the authenticated user.

        Clients with a non-null ``deleted_at`` are excluded, and results are
        ordered by ``created_at``, newest first.

        Args:
            session: Database session used to run the query.
            auth_subject: Authenticated user whose clients are listed; matched
                against ``OAuth2Client.user_id``.
            pagination: Page parameters forwarded to ``paginate``.

        Returns:
            The result of ``paginate`` for the query: a sequence of clients
            and an integer (presumably the total row count — confirm in
            ``polar.kit.pagination``).
        """
        owner_id = auth_subject.subject.id
        query = select(OAuth2Client).where(
            OAuth2Client.user_id == owner_id,
            OAuth2Client.deleted_at.is_(None),
        )
        query = query.order_by(OAuth2Client.created_at.desc())
        return await paginate(session, query, pagination=pagination)

    async def get_by_client_id(
        self, session: AsyncSession, client_id: str
    ) -> OAuth2Client | None:
        """Fetch the client whose ``client_id`` matches, ignoring deleted ones.

        Returns ``None`` when no client with a null ``deleted_at`` matches;
        ``scalar_one_or_none`` raises if more than one row matches.
        """
        query = (
            select(OAuth2Client)
            .where(OAuth2Client.client_id == client_id)
            .where(OAuth2Client.deleted_at.is_(None))
        )
        return (await session.execute(query)).scalar_one_or_none()

    async def revoke_leaked(
        self,
        session: AsyncSession,
        token: str,
        token_type: TokenType,
        *,
        notifier: str,
        url: str | None = None,
    ) -> bool:
        """Replace a leaked client credential and notify the client's owner.

        The client is located by the leaked value in the column selected by
        ``token_type``, its credential is regenerated, and — when the client
        has a ``user`` — a security-notice email is enqueued to that user.
        Unlike ``list`` and ``get_by_client_id``, this lookup does not exclude
        clients with a non-null ``deleted_at``. This method does not commit;
        it only adds the updated client to ``session``.

        Args:
            session: Database session used for the lookup and update.
            token: The leaked credential value.
            token_type: Which credential leaked; only
                ``TokenType.client_secret`` and
                ``TokenType.client_registration_token`` are accepted.
            notifier: Who reported the leak; included in the email and log.
            url: Optional location of the leak; included in the email (as an
                empty string when ``None``) and in the log.

        Returns:
            ``True`` if a client matched and its credential was replaced,
            ``False`` if no client matched ``token``.

        Raises:
            ValueError: If ``token_type`` is not one of the accepted types.
        """
        # Resolve the column to search before touching the database.
        if token_type == TokenType.client_secret:
            token_column = OAuth2Client.client_secret
        elif token_type == TokenType.client_registration_token:
            token_column = OAuth2Client.registration_access_token
        else:
            raise ValueError(f"Unsupported token type: {token_type}")

        query = (
            select(OAuth2Client)
            .options(joinedload(OAuth2Client.user))
            .where(token_column == token)
        )
        rows = await session.execute(query)
        # unique() de-duplicates rows produced by the joined eager load of ``user``.
        client = rows.unique().scalar_one_or_none()
        if client is None:
            return False

        # Swap the exposed credential for a freshly generated one.
        if token_type == TokenType.client_secret:
            client.client_secret = generate_token(prefix=CLIENT_SECRET_PREFIX)  # pyright: ignore
            email_subject = (
                "Security Notice - Your Polar OAuth2 Client Secret has been leaked"
            )
        else:  # TokenType.client_registration_token — the only other accepted type
            client.registration_access_token = generate_token(
                prefix=CLIENT_REGISTRATION_TOKEN_PREFIX
            )
            email_subject = (
                "Security Notice - "
                "Your Polar OAuth2 Client Registration Token has been leaked"
            )
        session.add(client)

        owner = client.user
        if owner is not None:
            recipient = owner.email
            props = OAuth2LeakedClientProps(
                email=recipient,
                token_type=token_type,
                client_name=cast(str, client.client_name),
                notifier=notifier,
                url=url or "",
            )
            html_body = render_email_template(OAuth2LeakedClientEmail(props=props))
            enqueue_email(
                to_email_addr=recipient,
                subject=email_subject,
                html_content=html_body,
            )

        log.info(
            "Revoke leaked OAuth2 client",
            id=client.id,
            token_type=token_type,
            notifier=notifier,
            url=url,
        )
        return True

121 

122 

# Module-level OAuth2ClientService instance bound to the OAuth2Client model;
# presumably imported by callers rather than constructing the service
# themselves — confirm against call sites.
oauth2_client = OAuth2ClientService(OAuth2Client)