Coverage for polar/kit/jwt.py: 39%

27 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1from datetime import datetime, timedelta 1ab

2from typing import Any, Literal 1ab

3 

4import jwt 1ab

5 

6from .utils import utc_now 1ab

7 

# Token lifetime in seconds (15 minutes); ``encode`` falls back to this when
# no explicit expiry is supplied.
DEFAULT_EXPIRATION = 15 * 60
# Algorithm used both to sign tokens and to verify them on decode.
ALGORITHM = "HS256"

10 

# Module-level aliases for PyJWT's exception classes.
DecodeError, ExpiredSignatureError = jwt.DecodeError, jwt.ExpiredSignatureError

13 

14 

def create_expiration_dt(seconds: int) -> datetime:
    """Return the moment ``seconds`` seconds after ``utc_now()``.

    Args:
        seconds: Offset from the current time, in seconds.

    Returns:
        ``utc_now()`` advanced by ``timedelta(seconds=seconds)``.
    """
    now = utc_now()
    return now + timedelta(seconds=seconds)

17 

18 

# Token kinds accepted by ``encode`` and ``decode``. ``encode`` stores the
# chosen value in the ``"type"`` claim and ``decode`` rejects tokens whose
# ``"type"`` claim does not match the expected value.
TYPE = Literal[
    "github_oauth",
    "discord_oauth",
    "google_oauth",
    "apple_oauth",
    "discord_guild_token",
    "auth",
    "github_repository_benefit_oauth",
    "customer_oauth",
]

29 

30 

def encode(
    *,
    data: dict[str, Any],
    secret: str,
    expires_at: datetime | None = None,
    expires_in: int | None = DEFAULT_EXPIRATION,
    type: TYPE,
) -> str:
    """Sign ``data`` as a JWT tagged with ``type`` and an expiration claim.

    Args:
        data: Claims to include in the token. The caller's dict is left
            untouched; a shallow copy is what gets encoded.
        secret: Key used to sign the token with ``ALGORITHM``.
        expires_at: Absolute expiration time. When given, it takes
            precedence over ``expires_in``.
        expires_in: Token lifetime in seconds, used only when ``expires_at``
            is not given. A falsy value (``None`` or ``0``) falls back to
            ``DEFAULT_EXPIRATION``.
        type: Token kind, written to the ``"type"`` claim (overriding any
            ``"type"`` key in ``data``) and later checked by ``decode``.

    Returns:
        The token string produced by ``jwt.encode``.
    """
    # Copy before adding claims so the caller's ``data`` dict is never
    # mutated (previously ``"type"`` was written into ``data`` itself).
    to_encode = data.copy()
    if type:
        to_encode["type"] = type

    if not expires_at:
        # ``or`` (not ``is None``) deliberately keeps the existing behavior
        # of treating ``expires_in=0`` as "use the default lifetime".
        lifetime = expires_in or DEFAULT_EXPIRATION
        expires_at = create_expiration_dt(seconds=lifetime)

    to_encode["exp"] = expires_at
    return jwt.encode(to_encode, secret, algorithm=ALGORITHM)

49 

50 

def decode_unsafe(*, token: str, secret: str) -> dict[str, Any]:
    """Decode ``token`` with ``jwt.decode`` without checking its type claim.

    Only ``ALGORITHM`` is accepted. Unlike ``decode``, the ``"type"`` claim
    is not compared against an expected value; errors raised by
    ``jwt.decode`` propagate to the caller.

    Args:
        token: Encoded JWT string.
        secret: Key the token is expected to be signed with.

    Returns:
        The decoded claims dict.
    """
    accepted_algorithms = [ALGORITHM]
    return jwt.decode(token, secret, algorithms=accepted_algorithms)

53 

54 

def decode(
    *,
    token: str,
    secret: str,
    type: TYPE,
) -> dict[str, Any]:
    """Decode ``token`` and require its ``"type"`` claim to equal ``type``.

    Args:
        token: Encoded JWT string.
        secret: Key the token is expected to be signed with.
        type: Expected value of the ``"type"`` claim.

    Returns:
        The decoded claims dict.

    Raises:
        Exception: If the ``"type"`` claim is missing or differs from
            ``type``; the message names both the expected and actual value.
        Any exception raised by ``decode_unsafe`` propagates unchanged.
    """
    res = decode_unsafe(token=token, secret=secret)

    actual_type = res.get("type", "")
    if actual_type != type:
        # ``Exception`` does not %-format its arguments, so build the
        # message up front; otherwise ``str(exc)`` is a raw tuple repr.
        raise Exception(
            f"JWT of unexpected type, expected '{type}' got '{actual_type}'"
        )

    return res