Coverage for polar/user/oauth_service.py: 42%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import structlog 1a

2from sqlalchemy import func, select 1a

3 

4from polar.exceptions import PolarError 1a

5from polar.kit.services import ResourceServiceReader 1a

6from polar.logging import Logger 1a

7from polar.models import OAuthAccount, User 1a

8from polar.models.user import OAuthPlatform 1a

9from polar.postgres import AsyncSession 1a

10 

# Module-level structlog logger; disconnect_platform below emits
# "oauth_account.disconnect" events through it.
log: Logger = structlog.get_logger()

12 

13 

class OAuthError(PolarError):
    """Common base for the OAuth account errors defined in this module."""

15 

16 

class OAuthAccountNotFound(OAuthError):
    """Raised when the user has no OAuth account linked for ``platform``.

    The requested platform is kept on ``self.platform``; ``404`` is passed to
    :class:`PolarError` alongside the message (presumably as the HTTP status).
    """

    def __init__(self, platform: OAuthPlatform) -> None:
        self.platform = platform
        super().__init__(
            f"No {platform} OAuth account found for this user.",
            404,
        )

22 

23 

class CannotDisconnectLastAuthMethod(OAuthError):
    """Raised when the OAuth account to disconnect is the user's only way in.

    The message tells the user to verify their email or link another OAuth
    provider first; ``400`` is passed to :class:`PolarError` with it
    (presumably as the HTTP status).
    """

    def __init__(self) -> None:
        super().__init__(
            "Cannot disconnect this OAuth account as it's your only "
            "authentication method. Please verify your email or connect "
            "another OAuth provider before disconnecting.",
            400,
        )

31 

32 

class OAuthAccountService(ResourceServiceReader[OAuthAccount]):
    """Look up and disconnect the OAuth accounts linked to users."""

    async def get_by_platform_and_account_id(
        self, session: AsyncSession, platform: OAuthPlatform, account_id: str
    ) -> OAuthAccount | None:
        """Return the OAuth account matching both ``platform`` and ``account_id``.

        Returns ``None`` when no row matches. Because the result is read with
        ``one_or_none()``, more than one matching row raises.
        """
        lookup = await session.execute(
            select(OAuthAccount).where(
                OAuthAccount.platform == platform,
                OAuthAccount.account_id == account_id,
            )
        )
        return lookup.scalars().one_or_none()

    async def disconnect_platform(
        self, session: AsyncSession, user: User, platform: OAuthPlatform
    ) -> None:
        """Delete every OAuth account ``user`` has linked for ``platform``.

        The deletions are flushed to the session; this method does not commit.

        Raises:
            OAuthAccountNotFound: ``user`` has no account for ``platform``.
            CannotDisconnectLastAuthMethod: after removing these accounts the
                user would have no other OAuth account and an unverified email.
        """
        platform_lookup = await session.execute(
            select(OAuthAccount).where(
                OAuthAccount.platform == platform,
                OAuthAccount.user_id == user.id,
            )
        )
        # Some users have a buggy state with multiple OAuth accounts for the
        # same platform, so every match is collected and removed together.
        platform_accounts = platform_lookup.scalars().all()
        if not platform_accounts:
            raise OAuthAccountNotFound(platform)

        doomed_ids = [account.id for account in platform_accounts]
        remaining_lookup = await session.execute(
            select(func.count(OAuthAccount.id)).where(
                OAuthAccount.user_id == user.id,
                OAuthAccount.id.not_in(doomed_ids),
            )
        )
        remaining_count = remaining_lookup.scalar_one()

        # The user must keep a way to sign in: another OAuth account or a
        # verified email address.
        keeps_a_login = remaining_count != 0 or user.email_verified
        if not keeps_a_login:
            raise CannotDisconnectLastAuthMethod()

        for account in platform_accounts:
            await session.delete(account)
            log.info(
                "oauth_account.disconnect",
                oauth_account_id=account.id,
                platform=platform,
            )

        await session.flush()

79 

80 

# Shared module-level service instance, constructed with the OAuthAccount model.
oauth_account_service = OAuthAccountService(OAuthAccount)