Coverage for polar/customer_session/service.py: 48%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import uuid 1a

2 

3import structlog 1a

4from pydantic import HttpUrl 1a

5from sqlalchemy import delete, select 1a

6from sqlalchemy.orm import joinedload 1a

7from sqlalchemy.orm.strategy_options import contains_eager 1a

8 

9from polar.auth.models import AuthSubject, Organization, User 1a

10from polar.config import settings 1a

11from polar.customer.repository import CustomerRepository 1a

12from polar.enums import TokenType 1a

13from polar.exceptions import PolarRequestValidationError 1a

14from polar.kit.crypto import generate_token_hash_pair, get_token_hash 1a

15from polar.kit.services import ResourceServiceReader 1a

16from polar.kit.utils import utc_now 1a

17from polar.logging import Logger 1a

18from polar.models import Customer, CustomerSession 1a

19from polar.postgres import AsyncSession 1a

20 

21from .schemas import CustomerSessionCreate, CustomerSessionCustomerIDCreate 1a

22 

# Module logger; used by revoke_leaked to record revoked tokens.
log: Logger = structlog.get_logger()

# Passed as `prefix` to generate_token_hash_pair when minting customer
# session tokens in CustomerSessionService.create_customer_session.
CUSTOMER_SESSION_TOKEN_PREFIX = "polar_cst_"

26 

27 

class CustomerSessionService(ResourceServiceReader[CustomerSession]):
    """Create, look up, expire and revoke customer sessions.

    Only a hash of each session token is stored on the ``CustomerSession``
    row; the raw token is handed back to the caller once, at creation time.
    """

    async def create(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        customer_create: CustomerSessionCreate,
    ) -> CustomerSession:
        """Create a session for a customer readable by ``auth_subject``.

        The customer is identified either by internal ID
        (``CustomerSessionCustomerIDCreate``) or by external customer ID
        (any other ``CustomerSessionCreate`` variant).

        Args:
            session: Database session used for the lookup and the insert.
            auth_subject: Subject whose readable-customer scope restricts
                the lookup.
            customer_create: Payload identifying the customer plus an
                optional ``return_url``.

        Returns:
            The new ``CustomerSession`` with its raw token attached as
            ``raw_token``.

        Raises:
            PolarRequestValidationError: No matching customer is readable by
                ``auth_subject``. The error is reported against the body
                field that was used for the lookup.
        """
        customer_repository = CustomerRepository.from_session(session)

        # Remember which field was used so a "not found" error points at it.
        lookup_field: str
        lookup_value: uuid.UUID | str
        if isinstance(customer_create, CustomerSessionCustomerIDCreate):
            lookup_field = "customer_id"
            lookup_value = customer_create.customer_id
            lookup_clause = Customer.id == customer_create.customer_id
        else:
            lookup_field = "external_customer_id"
            lookup_value = customer_create.external_customer_id
            lookup_clause = (
                Customer.external_id == customer_create.external_customer_id
            )

        customer_statement = (
            customer_repository.get_readable_statement(auth_subject)
            .options(joinedload(Customer.organization))
            .where(lookup_clause)
        )
        customer = await customer_repository.get_one_or_none(customer_statement)

        if customer is None:
            raise PolarRequestValidationError(
                [
                    {
                        "loc": ("body", lookup_field),
                        "msg": "Customer does not exist.",
                        "type": "value_error",
                        "input": lookup_value,
                    }
                ]
            )

        raw_token, new_customer_session = await self.create_customer_session(
            session, customer, customer_create.return_url
        )
        # Expose the plaintext token on the returned object; only its hash
        # was persisted.
        new_customer_session.raw_token = raw_token
        return new_customer_session

    async def create_customer_session(
        self,
        session: AsyncSession,
        customer: Customer,
        return_url: HttpUrl | None = None,
    ) -> tuple[str, CustomerSession]:
        """Insert a new session row for ``customer`` and flush it.

        Args:
            session: Database session the new row is added to.
            customer: Customer the session belongs to.
            return_url: Optional URL, stored as a string, or ``None`` when
                falsy.

        Returns:
            A ``(raw_token, customer_session)`` pair. The model stores only
            the token hash.
        """
        raw_token, hashed_token = generate_token_hash_pair(
            secret=settings.SECRET, prefix=CUSTOMER_SESSION_TOKEN_PREFIX
        )
        stored_return_url = None if not return_url else str(return_url)

        new_customer_session = CustomerSession(
            token=hashed_token,
            customer=customer,
            return_url=stored_return_url,
        )
        session.add(new_customer_session)
        await session.flush()

        return raw_token, new_customer_session

    async def get_by_token(
        self, session: AsyncSession, token: str, *, expired: bool = False
    ) -> CustomerSession | None:
        """Return the live session matching the raw ``token``, if any.

        A session matches only if all of the following hold:
        - its stored hash equals the hash of ``token``;
        - it is not soft-deleted (``deleted_at`` is NULL);
        - its customer has ``can_authenticate`` set.

        Unless ``expired`` is true, the session must also not have expired.
        With ``expired=True`` the expiry filter is dropped entirely, so both
        expired and unexpired sessions can match.

        The related customer is loaded from the same joined query.
        """
        hashed_token = get_token_hash(token, secret=settings.SECRET)

        conditions = [
            CustomerSession.token == hashed_token,
            CustomerSession.deleted_at.is_(None),
            Customer.can_authenticate.is_(True),
        ]
        if not expired:
            conditions.append(CustomerSession.expires_at > utc_now())

        statement = (
            select(CustomerSession)
            .join(CustomerSession.customer)
            .where(*conditions)
            # Populate CustomerSession.customer from the JOIN above instead
            # of issuing a second query.
            .options(contains_eager(CustomerSession.customer))
        )

        result = await session.execute(statement)
        return result.unique().scalar_one_or_none()

    async def delete_expired(self, session: AsyncSession) -> None:
        """Issue a bulk DELETE for every session whose expiry is before now."""
        cutoff = utc_now()
        await session.execute(
            delete(CustomerSession).where(CustomerSession.expires_at < cutoff)
        )

    async def revoke_leaked(
        self,
        session: AsyncSession,
        token: str,
        token_type: TokenType,
        *,
        notifier: str,
        url: str | None,
    ) -> bool:
        """Delete the session behind a leaked raw ``token``, if one is active.

        The lookup uses ``get_by_token`` with its default filters, so expired,
        soft-deleted, or non-authenticating sessions are not matched.

        ``token_type`` is not used by this implementation. It is presumably
        required by a shared revoke-leaked signature; confirm with callers.
        ``notifier`` and ``url`` are only written to the log.

        Returns:
            ``True`` if a session was found and deleted, ``False`` otherwise.
        """
        leaked_session = await self.get_by_token(session, token)

        if leaked_session is not None:
            await session.delete(leaked_session)
            log.info(
                "Revoke leaked customer session token",
                id=leaked_session.id,
                notifier=notifier,
                url=url,
            )

        return leaked_session is not None

142 

143 

# Shared module-level instance of the service, bound to the CustomerSession model.
customer_session = CustomerSessionService(CustomerSession)