Coverage for polar/login_code/service.py: 42%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import datetime 1a

2import secrets 1a

3import string 1a

4from math import ceil 1a

5 

6import structlog 1a

7from sqlalchemy import select 1a

8from sqlalchemy.orm import joinedload 1a

9 

10from polar.config import settings 1a

11from polar.email.react import render_email_template 1a

12from polar.email.schemas import LoginCodeEmail, LoginCodeProps 1a

13from polar.email.sender import enqueue_email 1a

14from polar.exceptions import PolarError 1a

15from polar.kit.crypto import get_token_hash 1a

16from polar.kit.utils import utc_now 1a

17from polar.models import LoginCode, User 1a

18from polar.postgres import AsyncSession 1a

19from polar.user.repository import UserRepository 1a

20from polar.user.schemas import UserSignupAttribution 1a

21from polar.user.service import user as user_service 1a

22 

# Module-level structlog logger used by the login-code service.
log = structlog.get_logger()

24 

25 

class LoginCodeError(PolarError):
    """Base class for errors raised by the login-code flow."""

27 

28 

class LoginCodeInvalidOrExpired(LoginCodeError):
    """Raised when a submitted login code cannot be redeemed.

    ``LoginCodeService.authenticate`` raises this when no unexpired code
    matches the given code and email. Passes ``status_code=401`` to the
    ``PolarError`` base.
    """

    def __init__(self) -> None:
        message = "This login code is invalid or has expired."
        super().__init__(message, status_code=401)

32 

33 

class LoginCodeService:
    """Issue, email and redeem one-time login codes.

    Codes are random strings of ``settings.LOGIN_CODE_LENGTH`` characters
    drawn from uppercase ASCII letters and digits. Only a keyed hash of each
    code (``get_token_hash(code, secret=settings.SECRET)``) is stored on the
    ``LoginCode`` row. The plaintext is returned to the caller once so it can
    be delivered.
    """

    async def request(
        self,
        session: AsyncSession,
        email: str,
        *,
        return_to: str | None = None,
        signup_attribution: UserSignupAttribution | None = None,
    ) -> tuple[LoginCode, str]:
        """Create a new login code for ``email`` and flush it to the database.

        If a user with this email already exists, the code is linked to that
        user. Otherwise ``user_id`` is left as ``None``, and ``authenticate``
        later resolves or creates the user via
        ``user_service.get_by_email_or_create``.

        Args:
            session: Session the new ``LoginCode`` is added to and flushed on.
            email: Address the code is issued for.
            return_to: Not read by this method.
            signup_attribution: Not read by this method; ``authenticate``
                accepts its own ``signup_attribution`` argument.

        Returns:
            The flushed ``LoginCode`` and the plaintext code. This is the only
            place the plaintext is available, since only its hash is stored.
        """
        user_repository = UserRepository.from_session(session)
        user = await user_repository.get_by_email(email)

        code, code_hash = self._generate_code_hash()

        login_code = LoginCode(
            code_hash=code_hash,
            email=email,
            user_id=user.id if user is not None else None,
            expires_at=utc_now()
            + datetime.timedelta(seconds=settings.LOGIN_CODE_TTL_SECONDS),
        )
        session.add(login_code)
        # Issue the INSERT now; committing is left to the caller.
        await session.flush()

        return login_code, code

    async def send(
        self,
        login_code: LoginCode,
        code: str,
    ) -> None:
        """Enqueue the "Sign in to Polar" email carrying ``code``.

        The email goes to ``login_code.email``. It states how many minutes
        (rounded up) remain before ``login_code.expires_at``. When
        ``settings.is_development()`` is true, the code is also logged.

        Args:
            login_code: The row returned by ``request``. It supplies the
                recipient address and the expiry time.
            code: The plaintext code returned alongside ``login_code``.
        """
        # Use ``total_seconds()``. ``timedelta.seconds`` is only the seconds
        # *component*: it ignores whole days and wraps a negative (already
        # expired) delta to nearly 86400. Clamp at zero so an expired code
        # never advertises a bogus lifetime.
        remaining_seconds = max(
            (login_code.expires_at - utc_now()).total_seconds(), 0.0
        )
        code_lifetime_minutes = ceil(remaining_seconds / 60)

        email = login_code.email
        subject = "Sign in to Polar"
        body = render_email_template(
            LoginCodeEmail(
                props=LoginCodeProps(
                    email=email,
                    code=code,
                    code_lifetime_minutes=code_lifetime_minutes,
                )
            )
        )

        enqueue_email(to_email_addr=email, subject=subject, html_content=body)

        if settings.is_development():
            log.info(
                "\n"
                "╔══════════════════════════════════════════════════════════╗\n"
                "║ ║\n"
                f"║ 🔑 LOGIN CODE: {code} ║\n"
                "║ ║\n"
                "╚══════════════════════════════════════════════════════════╝"
            )

    async def authenticate(
        self,
        session: AsyncSession,
        code: str,
        email: str,
        *,
        signup_attribution: UserSignupAttribution | None = None,
    ) -> tuple[User, bool]:
        """Redeem ``code`` for ``email`` and return the signed-in user.

        The code is matched by its keyed hash together with ``email``, and
        its ``expires_at`` must still be in the future. On success:

        - The user is resolved. If the code was not linked to a user, one is
          fetched or created via ``user_service.get_by_email_or_create``.
        - The user's email is marked verified.
        - The redeemed code is deleted.

        Args:
            session: Session used for the lookup, user update and deletion.
            code: Plaintext code submitted by the user.
            email: Email the code was requested for.
            signup_attribution: Forwarded to ``get_by_email_or_create`` when a
                user must be resolved or created.

        Returns:
            ``(user, is_signup)``. ``is_signup`` is the flag reported by
            ``get_by_email_or_create``, or ``False`` if the code was already
            linked to a user. It is forced to ``True`` when the user's email
            had not been verified before this login.

        Raises:
            LoginCodeInvalidOrExpired: No unexpired code matches ``code`` and
                ``email``.
        """
        code_hash = get_token_hash(code, secret=settings.SECRET)

        statement = (
            select(LoginCode)
            .where(
                LoginCode.code_hash == code_hash,
                LoginCode.email == email,
                LoginCode.expires_at > utc_now(),
            )
            .options(joinedload(LoginCode.user))
        )
        result = await session.execute(statement)
        login_code = result.unique().scalar_one_or_none()

        if login_code is None:
            raise LoginCodeInvalidOrExpired()

        is_signup = False
        user = login_code.user
        if user is None:
            user, is_signup = await user_service.get_by_email_or_create(
                session,
                login_code.email,
                signup_attribution=signup_attribution,
            )

        # Successfully receiving a code proves ownership of the address; a
        # first-time verification is reported as a signup.
        if not user.email_verified:
            is_signup = True
            user.email_verified = True
            session.add(user)

        # NOTE(review): only the redeemed code is deleted. Any other unexpired
        # codes issued for the same email remain valid until they expire.
        await session.delete(login_code)

        return user, is_signup

    def _generate_code_hash(self) -> tuple[str, str]:
        """Return a fresh random code and its keyed hash.

        Each character is drawn with ``secrets.choice`` from uppercase ASCII
        letters and digits. The code is ``settings.LOGIN_CODE_LENGTH``
        characters long.
        """
        alphabet = string.ascii_uppercase + string.digits
        code = "".join(
            secrets.choice(alphabet) for _ in range(settings.LOGIN_CODE_LENGTH)
        )
        code_hash = get_token_hash(code, secret=settings.SECRET)
        return code, code_hash

143 

144 

# Module-level service instance; LoginCodeService keeps no per-instance state.
login_code = LoginCodeService()