Coverage for polar/customer_session/service.py: 48%
61 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import uuid 1a
3import structlog 1a
4from pydantic import HttpUrl 1a
5from sqlalchemy import delete, select 1a
6from sqlalchemy.orm import joinedload 1a
7from sqlalchemy.orm.strategy_options import contains_eager 1a
9from polar.auth.models import AuthSubject, Organization, User 1a
10from polar.config import settings 1a
11from polar.customer.repository import CustomerRepository 1a
12from polar.enums import TokenType 1a
13from polar.exceptions import PolarRequestValidationError 1a
14from polar.kit.crypto import generate_token_hash_pair, get_token_hash 1a
15from polar.kit.services import ResourceServiceReader 1a
16from polar.kit.utils import utc_now 1a
17from polar.logging import Logger 1a
18from polar.models import Customer, CustomerSession 1a
19from polar.postgres import AsyncSession 1a
21from .schemas import CustomerSessionCreate, CustomerSessionCustomerIDCreate 1a
# Module-level structlog logger; used by CustomerSessionService.revoke_leaked.
log: Logger = structlog.get_logger()

# Passed as `prefix` to generate_token_hash_pair when minting customer session
# tokens, so issued tokens are recognizable by this prefix (presumably used by
# secret scanners to route leak reports — verify against the caller).
CUSTOMER_SESSION_TOKEN_PREFIX = "polar_cst_"
class CustomerSessionService(ResourceServiceReader[CustomerSession]):
    """Create, look up, expire and revoke customer sessions."""

    async def create(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        customer_create: CustomerSessionCreate,
    ) -> CustomerSession:
        """Open a new session for a customer readable by ``auth_subject``.

        The customer is located either by its Polar ID (when the payload is a
        ``CustomerSessionCustomerIDCreate``) or by its external customer ID.

        Args:
            session: Database session used for the lookup and the insert.
            auth_subject: Subject whose readable customers are searched.
            customer_create: Payload identifying the customer, plus an
                optional ``return_url``.

        Returns:
            The flushed ``CustomerSession``. Its ``raw_token`` attribute holds
            the plain token; only the token hash is stored on ``token``.

        Raises:
            PolarRequestValidationError: No readable customer matches the
                given identifier.
        """
        # Pick the identifying column and remember which request field it
        # came from, so a miss can be reported against that field.
        lookup_field: str
        lookup_value: uuid.UUID | str
        if isinstance(customer_create, CustomerSessionCustomerIDCreate):
            lookup_field = "customer_id"
            lookup_value = customer_create.customer_id
            matches_customer = Customer.id == customer_create.customer_id
        else:
            lookup_field = "external_customer_id"
            lookup_value = customer_create.external_customer_id
            matches_customer = (
                Customer.external_id == customer_create.external_customer_id
            )

        repository = CustomerRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .where(matches_customer)
            .options(joinedload(Customer.organization))
        )
        customer = await repository.get_one_or_none(statement)

        if customer is None:
            raise PolarRequestValidationError(
                [
                    {
                        "loc": ("body", lookup_field),
                        "msg": "Customer does not exist.",
                        "type": "value_error",
                        "input": lookup_value,
                    }
                ]
            )

        plain_token, created = await self.create_customer_session(
            session, customer, customer_create.return_url
        )
        # Expose the plain token on the returned object for the caller.
        created.raw_token = plain_token
        return created

    async def create_customer_session(
        self,
        session: AsyncSession,
        customer: Customer,
        return_url: HttpUrl | None = None,
    ) -> tuple[str, CustomerSession]:
        """Persist a new session for ``customer`` and return its plain token.

        Args:
            session: Database session; the new row is added and flushed.
            customer: Customer the session belongs to.
            return_url: Optional URL, stored as a string (``None`` when falsy).

        Returns:
            A ``(plain_token, customer_session)`` pair. Only the hash of the
            token is stored on the model.
        """
        plain_token, hashed_token = generate_token_hash_pair(
            secret=settings.SECRET, prefix=CUSTOMER_SESSION_TOKEN_PREFIX
        )
        stored_return_url = None if not return_url else str(return_url)

        new_session = CustomerSession(
            token=hashed_token,
            customer=customer,
            return_url=stored_return_url,
        )
        session.add(new_session)
        await session.flush()

        return plain_token, new_session

    async def get_by_token(
        self, session: AsyncSession, token: str, *, expired: bool = False
    ) -> CustomerSession | None:
        """Find the session whose stored hash matches ``token``.

        A match must not be soft-deleted (``deleted_at`` is NULL), and its
        customer must have ``can_authenticate`` set to true. Unless
        ``expired`` is true, the session must also have ``expires_at`` in
        the future. The customer relationship is populated from the join.

        Args:
            session: Database session to query.
            token: Plain token as presented by the client.
            expired: When true, skip the expiry filter.

        Returns:
            The matching ``CustomerSession``, or ``None``.
        """
        conditions = [
            CustomerSession.token == get_token_hash(token, secret=settings.SECRET),
            CustomerSession.deleted_at.is_(None),
            Customer.can_authenticate.is_(True),
        ]
        if not expired:
            conditions.append(CustomerSession.expires_at > utc_now())

        # Multiple criteria passed to where() are combined with AND.
        statement = (
            select(CustomerSession)
            .join(CustomerSession.customer)
            .where(*conditions)
            .options(contains_eager(CustomerSession.customer))
        )
        result = await session.execute(statement)
        return result.unique().scalar_one_or_none()

    async def delete_expired(self, session: AsyncSession) -> None:
        """Bulk-delete every session whose ``expires_at`` is in the past.

        Issues a single DELETE statement; no flush or commit is performed here.
        """
        await session.execute(
            delete(CustomerSession).where(CustomerSession.expires_at < utc_now())
        )

    async def revoke_leaked(
        self,
        session: AsyncSession,
        token: str,
        token_type: TokenType,
        *,
        notifier: str,
        url: str | None,
    ) -> bool:
        """Delete the session behind a reported leaked token.

        The lookup uses ``get_by_token`` with its default filters, so expired
        or soft-deleted sessions are not found. ``token_type`` is accepted
        but not used here.

        Args:
            session: Database session.
            token: The leaked plain token.
            token_type: Reported token type (unused).
            notifier: Who reported the leak; logged.
            url: Where the leak was found, if known; logged.

        Returns:
            ``True`` if a matching session was deleted, else ``False``.
        """
        leaked = await self.get_by_token(session, token)

        if leaked is not None:
            await session.delete(leaked)
            log.info(
                "Revoke leaked customer session token",
                id=leaked.id,
                notifier=notifier,
                url=url,
            )

        return leaked is not None
# Module-level service instance, constructed with the CustomerSession model class.
customer_session = CustomerSessionService(CustomerSession)