Coverage for polar/customer_portal/service/customer_session.py: 39%

73 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import secrets 1a

2import string 1a

3import uuid 1a

4from math import ceil 1a

5 

6from sqlalchemy import select 1a

7 

8from polar.config import settings 1a

9from polar.customer.repository import CustomerRepository 1a

10from polar.customer_session.service import customer_session as customer_session_service 1a

11from polar.email.react import render_email_template 1a

12from polar.email.schemas import CustomerSessionCodeEmail, CustomerSessionCodeProps 1a

13from polar.email.sender import enqueue_email 1a

14from polar.exceptions import PolarError 1a

15from polar.kit.crypto import get_token_hash 1a

16from polar.kit.utils import utc_now 1a

17from polar.models import CustomerSession, CustomerSessionCode, Organization 1a

18from polar.organization.repository import OrganizationRepository 1a

19from polar.postgres import AsyncSession 1a

20 

21 

class CustomerSessionError(PolarError):
    """Base class for the customer session errors defined in this module."""

23 

24 

class OrganizationDoesNotExist(CustomerSessionError):
    """Raised when no organization is found for the requested ID.

    Attributes:
        organization_id: The ID that was looked up.
    """

    def __init__(self, organization_id: uuid.UUID) -> None:
        self.organization_id = organization_id
        super().__init__(f"Organization {organization_id} does not exist.")

30 

31 

class CustomerDoesNotExist(CustomerSessionError):
    """Raised when an organization has no customer with the given email.

    Attributes:
        email: The email address that was looked up.
        organization: The organization the lookup was scoped to.
    """

    def __init__(self, email: str, organization: Organization) -> None:
        self.email = email
        self.organization = organization
        super().__init__(
            f"Customer does not exist for email {email} "
            f"and organization {organization.id}."
        )

38 

39 

class CustomerSessionCodeInvalidOrExpired(CustomerSessionError):
    """Raised when a submitted session code matches no unexpired code."""

    def __init__(self) -> None:
        message = "This customer session code is invalid or has expired."
        # 401 is forwarded to PolarError as ``status_code``
        # (presumably the HTTP status returned to the client — confirm).
        super().__init__(message, status_code=401)

45 

46 

class CustomerSessionService:
    """Issue, email, and redeem one-time customer session codes."""

    async def request(
        self, session: AsyncSession, email: str, organization_id: uuid.UUID
    ) -> tuple[CustomerSessionCode, str]:
        """Create a session code for the customer matching *email*.

        Args:
            session: Database session used for lookups and for adding the
                new code.
            email: Email address of the customer requesting access.
            organization_id: ID of the organization the customer belongs to.

        Returns:
            A ``(customer_session_code, code)`` tuple. The model instance is
            added to *session*; ``code`` is the plaintext value, while only
            its hash is stored on the model.

        Raises:
            OrganizationDoesNotExist: No organization has *organization_id*.
            CustomerDoesNotExist: The organization has no customer with
                *email*.
        """
        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(organization_id)
        if organization is None:
            raise OrganizationDoesNotExist(organization_id)

        repository = CustomerRepository.from_session(session)
        customer = await repository.get_by_email_and_organization(
            email, organization.id
        )
        if customer is None:
            raise CustomerDoesNotExist(email, organization)

        code, code_hash = self._generate_code_hash()

        # Persist only the hash; the plaintext code is handed back to the
        # caller so it can be delivered to the customer.
        customer_session_code = CustomerSessionCode(
            code=code_hash, email=customer.email, customer=customer
        )
        session.add(customer_session_code)

        return customer_session_code, code

    async def send(
        self,
        session: AsyncSession,
        customer_session_code: CustomerSessionCode,
        code: str,
    ) -> None:
        """Render the session code email and pass it to ``enqueue_email``.

        Args:
            session: Database session used to load the customer's
                organization.
            customer_session_code: The code record whose customer receives
                the email.
            code: The plaintext code to include in the email.
        """
        customer = customer_session_code.customer
        organization_repository = OrganizationRepository.from_session(session)
        organization = await organization_repository.get_by_id(
            customer.organization_id
        )
        # Invariant check: the customer references this organization, so the
        # lookup is expected to succeed.
        assert organization is not None

        # total_seconds() covers the whole interval. timedelta.seconds is
        # only the sub-day component: it drops whole days and turns a
        # negative (already expired) delta into a large positive number.
        delta = customer_session_code.expires_at - utc_now()
        code_lifetime_minutes = max(0, ceil(delta.total_seconds() / 60))

        body = render_email_template(
            CustomerSessionCodeEmail(
                props=CustomerSessionCodeProps.model_validate(
                    {
                        "email": customer.email,
                        "organization": organization,
                        "code": code,
                        "code_lifetime_minutes": code_lifetime_minutes,
                        "url": settings.generate_frontend_url(
                            f"/{organization.slug}/portal/authenticate"
                        ),
                    }
                )
            )
        )

        enqueue_email(
            **organization.email_from_reply,
            to_email_addr=customer.email,
            subject=f"Access your {organization.name} purchases",
            html_content=body,
        )

    async def authenticate(
        self, session: AsyncSession, code: str
    ) -> tuple[str, CustomerSession]:
        """Redeem a plaintext session code for a customer session.

        Args:
            session: Database session used to look up and delete the code.
            code: The plaintext code submitted by the customer.

        Returns:
            The result of ``customer_session_service.create_customer_session``
            for the code's customer.

        Raises:
            CustomerSessionCodeInvalidOrExpired: No unexpired code matches
                the hash of *code*.
        """
        code_hash = get_token_hash(code, secret=settings.SECRET)

        statement = select(CustomerSessionCode).where(
            CustomerSessionCode.expires_at > utc_now(),
            CustomerSessionCode.code == code_hash,
        )
        result = await session.execute(statement)
        customer_session_code = result.scalar_one_or_none()

        if customer_session_code is None:
            raise CustomerSessionCodeInvalidOrExpired()

        customer = customer_session_code.customer
        # Only mark the email verified when the code was issued for the
        # customer's current address (compared case-insensitively).
        if customer_session_code.email.lower() == customer.email.lower():
            customer_repository = CustomerRepository.from_session(session)
            await customer_repository.update(
                customer, update_dict={"email_verified": True}
            )

        # Codes are single-use: delete the record once it has been redeemed.
        await session.delete(customer_session_code)

        return await customer_session_service.create_customer_session(
            session, customer
        )

    def _generate_code_hash(self) -> tuple[str, str]:
        """Generate a random code and its hash.

        Returns:
            A ``(code, code_hash)`` tuple. ``code`` consists of
            ``settings.CUSTOMER_SESSION_CODE_LENGTH`` characters drawn from
            uppercase ASCII letters and digits.
        """
        # Build the alphabet once rather than on every character draw.
        alphabet = string.ascii_uppercase + string.digits
        code = "".join(
            secrets.choice(alphabet)
            for _ in range(settings.CUSTOMER_SESSION_CODE_LENGTH)
        )
        code_hash = get_token_hash(code, secret=settings.SECRET)
        return code, code_hash

146 

147 

# Module-level service instance, exposed under the name `customer_session`.
customer_session = CustomerSessionService()