Coverage for polar/user/oauth_service.py: 42%
40 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1import structlog 1a
2from sqlalchemy import func, select 1a
4from polar.exceptions import PolarError 1a
5from polar.kit.services import ResourceServiceReader 1a
6from polar.logging import Logger 1a
7from polar.models import OAuthAccount, User 1a
8from polar.models.user import OAuthPlatform 1a
9from polar.postgres import AsyncSession 1a
# Structured logger for this module (used below to record OAuth disconnects).
log: Logger
log = structlog.get_logger()
class OAuthError(PolarError):
    """Common base class for the OAuth account errors raised in this module."""
class OAuthAccountNotFound(OAuthError):
    """Raised when a user has no linked OAuth account on the requested platform.

    The offending platform is kept on ``self.platform``; status code 404 is
    passed to ``PolarError``.
    """

    def __init__(self, platform: OAuthPlatform) -> None:
        self.platform = platform
        super().__init__(f"No {platform} OAuth account found for this user.", 404)
class CannotDisconnectLastAuthMethod(OAuthError):
    """Raised when unlinking an OAuth account would leave the user unable to sign in.

    Status code 400 is passed to ``PolarError``.
    """

    def __init__(self) -> None:
        super().__init__(
            "Cannot disconnect this OAuth account as it's your only authentication method. "
            "Please verify your email or connect another OAuth provider before disconnecting.",
            400,
        )
class OAuthAccountService(ResourceServiceReader[OAuthAccount]):
    """Queries and mutations on users' linked OAuth accounts."""

    async def get_by_platform_and_account_id(
        self, session: AsyncSession, platform: OAuthPlatform, account_id: str
    ) -> OAuthAccount | None:
        """Find the OAuth account linked to a given provider-side account.

        Args:
            session: Database session used to run the query.
            platform: OAuth provider to match.
            account_id: Provider-side account identifier to match.

        Returns:
            The matching ``OAuthAccount``, or ``None`` when no row matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches
                (behavior of ``one_or_none``).
        """
        stmt = select(OAuthAccount).where(
            OAuthAccount.platform == platform,
            OAuthAccount.account_id == account_id,
        )
        result = await session.execute(stmt)
        return result.scalars().one_or_none()

    async def disconnect_platform(
        self, session: AsyncSession, user: User, platform: OAuthPlatform
    ) -> None:
        """Unlink every OAuth account ``user`` has on ``platform``.

        Nothing is deleted when the accounts on ``platform`` are the user's
        only OAuth accounts and the user's email is not verified.

        The deletions are flushed to the session; this method does not
        commit.

        Args:
            session: Database session used for the queries and deletions.
            user: Owner of the accounts to disconnect.
            platform: OAuth provider to disconnect.

        Raises:
            OAuthAccountNotFound: If ``user`` has no account on ``platform``.
            CannotDisconnectLastAuthMethod: If no other OAuth account would
                remain and ``user.email_verified`` is false.
        """
        oauth_accounts_statement = select(OAuthAccount).where(
            OAuthAccount.platform == platform,
            OAuthAccount.user_id == user.id,
        )
        oauth_account_result = await session.execute(oauth_accounts_statement)
        # Fetch all rows rather than one: some users are in a buggy state with
        # several OAuth accounts on the same platform, and all of them go.
        oauth_accounts = oauth_account_result.scalars().all()

        if not oauth_accounts:
            raise OAuthAccountNotFound(platform)

        # Count the user's OAuth accounts that would survive this disconnect.
        disconnected_ids = [oauth_account.id for oauth_account in oauth_accounts]
        other_accounts_count_statement = select(func.count(OAuthAccount.id)).where(
            OAuthAccount.user_id == user.id,
            OAuthAccount.id.not_in(disconnected_ids),
        )
        other_accounts_count_result = await session.execute(
            other_accounts_count_statement
        )
        other_accounts_count = other_accounts_count_result.scalar_one()

        # Within this check, another OAuth account or a verified email is
        # what keeps the user able to authenticate after the disconnect.
        if other_accounts_count == 0 and not user.email_verified:
            raise CannotDisconnectLastAuthMethod()

        for oauth_account in oauth_accounts:
            await session.delete(oauth_account)
            log.info(
                "oauth_account.disconnect",
                oauth_account_id=oauth_account.id,
                platform=platform,
            )

        await session.flush()
# Module-level instance of the service, constructed with the OAuthAccount model.
oauth_account_service: OAuthAccountService = OAuthAccountService(OAuthAccount)