Coverage for polar/email_update/service.py: 46%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from math import ceil 1a

2from urllib.parse import urlencode 1a

3 

4from sqlalchemy import delete 1a

5from sqlalchemy.orm import joinedload 1a

6 

7from polar.auth.models import AuthSubject 1a

8from polar.config import settings 1a

9from polar.email.react import render_email_template 1a

10from polar.email.schemas import EmailUpdateEmail, EmailUpdateProps 1a

11from polar.email.sender import enqueue_email 1a

12from polar.exceptions import PolarError, PolarRequestValidationError 1a

13from polar.kit.crypto import generate_token_hash_pair, get_token_hash 1a

14from polar.kit.extensions.sqlalchemy import sql 1a

15from polar.kit.services import ResourceServiceReader 1a

16from polar.kit.utils import utc_now 1a

17from polar.models import EmailVerification 1a

18from polar.models.user import User 1a

19from polar.postgres import AsyncSession 1a

20from polar.user.repository import UserRepository 1a

21 

# Prefix given to every generated email-update token; it is passed as
# ``prefix`` to generate_token_hash_pair when a request is created.
TOKEN_PREFIX: str = "polar_ev_"

23 

24 

class EmailUpdateError(PolarError):
    """Base class for the email-update errors defined in this module."""

26 

27 

class InvalidEmailUpdate(EmailUpdateError):
    """Raised when no unexpired email-update record matches a token.

    Always carries HTTP status 401 and a fixed, user-facing message.
    """

    def __init__(self) -> None:
        message = "This email update request is invalid or has expired."
        super().__init__(message, status_code=401)

33 

34 

class EmailUpdateService(ResourceServiceReader[EmailVerification]):
    """Token-verified email address changes for authenticated users.

    Flow:
      1. ``request_email_update`` stores an ``EmailVerification`` record that
         holds only the hash of a freshly generated token.
      2. ``send_email`` mails a link containing the plain token to the new
         address.
      3. ``verify`` looks the token up again, applies the new email to the
         user, and deletes the record.
    """

    async def request_email_update(
        self,
        email: str,
        session: AsyncSession,
        auth_subject: AuthSubject[User],
    ) -> tuple[EmailVerification, str]:
        """Create a pending email-update record for the authenticated user.

        Args:
            email: The requested new email address.
            session: Database session. The record is added and flushed here,
                not committed.
            auth_subject: The authenticated user asking for the change.

        Returns:
            The persisted ``EmailVerification`` record and the plain token.
            Only the token hash is stored, so this return value is the only
            place the plain token is available.

        Raises:
            PolarRequestValidationError: A different user already has
                ``email``.
        """
        user = auth_subject.subject

        user_repository = UserRepository.from_session(session)
        existing_user = await user_repository.get_by_email(email)
        # The user's own current address is allowed; only collisions with
        # other accounts are rejected.
        if existing_user is not None and existing_user.id != user.id:
            raise PolarRequestValidationError(
                [
                    {
                        "type": "value_error",
                        "loc": ("body", "email"),
                        "msg": "Another user is already using this email.",
                        "input": email,
                    }
                ]
            )

        token, token_hash = generate_token_hash_pair(
            secret=settings.SECRET, prefix=TOKEN_PREFIX
        )
        email_update_record = EmailVerification(
            email=email, token_hash=token_hash, user=user
        )

        session.add(email_update_record)
        # Flush so server/model defaults are populated before the record is
        # returned. NOTE(review): expires_at is read by send_email but never
        # set here, so it is presumably a model/DB default. Confirm.
        await session.flush()

        return email_update_record, token

    async def send_email(
        self,
        email_update_record: EmailVerification,
        token: str,
        base_url: str,
        *,
        extra_url_params: dict[str, str] | None = None,
    ) -> None:
        """Render the email-update message and enqueue it to the new address.

        Args:
            email_update_record: The pending record. Its ``email`` is the
                recipient and its ``expires_at`` determines the lifetime
                shown in the message.
            token: The plain token returned by ``request_email_update``.
            base_url: URL of the confirmation page. The query string is
                appended as ``?token=...``.
            extra_url_params: Optional additional query parameters. They are
                merged after ``token``, so a ``"token"`` key here overrides it.
        """
        # total_seconds() includes whole days and keeps the sign.
        # timedelta.seconds ignores days and wraps negative deltas to a large
        # positive number, which misreported long or already-elapsed lifetimes.
        remaining_seconds = (
            email_update_record.expires_at - utc_now()
        ).total_seconds()
        token_lifetime_minutes = max(0, ceil(remaining_seconds / 60))

        email = email_update_record.email
        url_params = {"token": token, **(extra_url_params or {})}
        body = render_email_template(
            EmailUpdateEmail(
                props=EmailUpdateProps(
                    email=email,
                    token_lifetime_minutes=token_lifetime_minutes,
                    url=f"{base_url}?{urlencode(url_params)}",
                )
            )
        )

        enqueue_email(
            to_email_addr=email, subject="Update your email", html_content=body
        )

    async def verify(self, session: AsyncSession, token: str) -> User:
        """Apply the email change identified by ``token``.

        Finds the unexpired record whose hash matches ``token``, copies its
        email onto the linked user, and deletes the record. All changes are
        staged on ``session``; nothing is flushed or committed here.

        Args:
            session: Database session.
            token: The plain token the user received by email.

        Returns:
            The updated user.

        Raises:
            InvalidEmailUpdate: No unexpired record matches the token.
        """
        token_hash = get_token_hash(token, secret=settings.SECRET)
        email_update_record = await self._get_email_update_record_by_token_hash(
            session, token_hash
        )

        if email_update_record is None:
            raise InvalidEmailUpdate()

        # NOTE(review): uniqueness was only checked when the update was
        # requested. Another account could have claimed this address since
        # then. Confirm a DB constraint covers that case.
        user = email_update_record.user
        user.email = email_update_record.email
        session.add(user)

        # The token is single-use: remove the record once it has been applied.
        await session.delete(email_update_record)

        return user

    async def _get_email_update_record_by_token_hash(
        self, session: AsyncSession, token_hash: str
    ) -> EmailVerification | None:
        """Return the unexpired record with ``token_hash``, or None.

        The related user is eager-loaded with ``joinedload`` so ``verify`` can
        access it without a further query.
        """
        statement = (
            sql.select(EmailVerification)
            .where(
                EmailVerification.token_hash == token_hash,
                EmailVerification.expires_at > utc_now(),
            )
            .options(joinedload(EmailVerification.user))
        )

        res = await session.execute(statement)
        return res.scalars().unique().one_or_none()

    async def delete_expired_record(self, session: AsyncSession) -> None:
        """Bulk-delete every record whose ``expires_at`` is in the past, then flush."""
        statement = delete(EmailVerification).where(
            EmailVerification.expires_at < utc_now()
        )
        await session.execute(statement)
        await session.flush()

135 

136 

# Shared module-level service instance, bound to the EmailVerification model.
email_update: EmailUpdateService = EmailUpdateService(EmailVerification)