Coverage for polar/locker.py: 48%

44 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import contextlib 1a

2from collections.abc import AsyncGenerator 1a

3 

4import logfire 1a

5import structlog 1a

6from fastapi import Depends 1a

7from redis.asyncio.lock import Lock 1a

8from redis.exceptions import LockNotOwnedError 1a

9 

10from polar.exceptions import PolarError 1a

11from polar.logging import Logger 1a

12from polar.redis import Redis, get_redis 1a

13 

14log: Logger = structlog.get_logger() 1a

15 

16 

class LockerError(PolarError):
    """
    Base class for errors raised by the distributed-locking helpers.

    Args:
        message: Human-readable description of the failure.
        status_code: Status code forwarded to `PolarError`. Defaults to 500.
    """

    def __init__(
        self,
        message: str = "A concurrency error occurred. Try again later.",
        status_code: int = 500,
    ) -> None:
        super().__init__(message, status_code)

24 

25 

class TimeoutLockError(LockerError):
    """Raised when a lock cannot be acquired within its `blocking_timeout`."""

28 

29 

class Locker:
    """
    Helper class to acquire distributed locks stored on a Redis server.

    Every lock key is namespaced with the `polarlock:` prefix.
    """

    def __init__(self, redis: Redis) -> None:
        self.redis = redis

    @contextlib.asynccontextmanager
    async def lock(
        self,
        name: str,
        *,
        timeout: float,
        blocking_timeout: float,
        sleep: float = 0.1,
        thread_local: bool = True,
    ) -> AsyncGenerator[Lock, None]:
        """
        Acquire a distributed lock on the Redis server.

        The lock is released when the context exits, even if the body raises.
        If the lock is no longer owned at release time (e.g. its `timeout`
        elapsed while the body was running), a warning is logged instead of
        raising.

        Args:
            name: Name of the lock. Automatically prefixed by `polarlock:`.
            timeout: The lifetime of the lock in seconds.
            blocking_timeout: The maximum amount of time in seconds to spend trying
                to acquire the lock.
            sleep: Amount of time in seconds to sleep between each iteration.
                Defaults to 0.1 seconds.
            thread_local: Forwarded unchanged to redis-py's `Lock`, which uses it
                to decide whether the lock token is kept in thread-local storage.
                Defaults to True.

        Yields:
            The acquired `redis.asyncio.lock.Lock` instance.

        Raises:
            TimeoutLockError: The lock could not be acquired within `blocking_timeout`
                limit.
        """
        lock = Lock(
            self.redis,
            self._get_key(name),
            timeout=timeout,
            sleep=sleep,
            blocking=True,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

        with logfire.span(
            "Acquire distributed lock {name}",
            name=name,
            timeout=timeout,
            blocking_timeout=blocking_timeout,
        ):
            log.debug("try to acquire lock", name=name)
            acquired = await lock.acquire()

            # Raise inside the acquisition span so the timeout is attached to it.
            if not acquired:
                log.error(
                    "could not acquire lock before set limit",
                    name=name,
                    blocking_timeout=blocking_timeout,
                )
                raise TimeoutLockError()

            log.debug("acquired lock", name=name)

        # Separate span covering the time the lock is held by the caller.
        with logfire.span(
            "Distributed lock {name} acquired",
            name=name,
            timeout=timeout,
            blocking_timeout=blocking_timeout,
        ):
            try:
                yield lock
            finally:
                try:
                    await lock.release()
                except LockNotOwnedError:
                    # The key expired (or changed owner) while held; nothing to release.
                    log.warning(
                        "Already expired lock cannot be released",
                        name=name,
                        timeout=timeout,
                    )
                else:
                    log.debug("released lock", name=name)

    async def is_locked(self, name: str) -> bool:
        """
        Check if a lock is currently held.

        Args:
            name: Name of the lock. Automatically prefixed by `polarlock:`.

        Returns:
            bool: True if the lock is currently held, False otherwise.
        """
        lock = Lock(self.redis, self._get_key(name))
        return await lock.locked()

    def _get_key(self, name: str) -> str:
        """Return the namespaced Redis key for the lock called `name`."""
        return f"polarlock:{name}"

127 

128 

async def get_locker(redis: Redis = Depends(get_redis)) -> Locker:
    """
    FastAPI dependency returning a `Locker` bound to the injected Redis client.
    """
    locker = Locker(redis)
    return locker