Coverage for /usr/local/lib/python3.12/site-packages/prefect/locking/memory.py: 0%

142 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import logging 

5import threading 

6from typing import Any, TypedDict 

7 

8from typing_extensions import Self 

9 

10from prefect.logging.loggers import get_logger 

11 

12from .protocol import LockManager 

13 

14logger: logging.Logger = get_logger("locking.memory") 

15 

16 

17class _LockInfo(TypedDict): 

18 """ 

19 A dictionary containing information about a lock. 

20 

21 Attributes: 

22 holder: The holder of the lock. 

23 lock: The lock object. 

24 expiration_timer: The timer for the lock expiration 

25 """ 

26 

27 holder: str 

28 lock: threading.Lock 

29 expiration_timer: threading.Timer | None 

30 

31 

32class MemoryLockManager(LockManager): 

33 """ 

34 A lock manager that stores lock information in memory. 

35 

36 Note: because this lock manager stores data in memory, it is not suitable for 

37 use in a distributed environment or across different processes. 

38 """ 

39 

40 _instance = None 

41 _initialized = False 

42 

43 def __new__(cls, *args: Any, **kwargs: Any) -> Self: 

44 if cls._instance is None: 

45 cls._instance = super().__new__(cls) 

46 return cls._instance 

47 

48 def __init__(self): 

49 if self.__class__._initialized: 

50 return 

51 self._locks_dict_lock = threading.Lock() 

52 self._locks: dict[str, _LockInfo] = {} 

53 self.__class__._initialized = True 

54 

55 def _expire_lock(self, key: str): 

56 """ 

57 Expire the lock for the given key. 

58 

59 Used as a callback for the expiration timer of a lock. 

60 

61 Args: 

62 key: The key of the lock to expire. 

63 """ 

64 with self._locks_dict_lock: 

65 if key in self._locks: 

66 lock_info = self._locks[key] 

67 if lock_info["lock"].locked(): 

68 lock_info["lock"].release() 

69 if lock_info["expiration_timer"]: 

70 lock_info["expiration_timer"].cancel() 

71 del self._locks[key] 

72 

73 def acquire_lock( 

74 self, 

75 key: str, 

76 holder: str, 

77 acquire_timeout: float | None = None, 

78 hold_timeout: float | None = None, 

79 ) -> bool: 

80 with self._locks_dict_lock: 

81 if key not in self._locks: 

82 lock = threading.Lock() 

83 lock.acquire() 

84 expiration_timer = None 

85 if hold_timeout is not None: 

86 expiration_timer = threading.Timer( 

87 hold_timeout, self._expire_lock, args=(key,) 

88 ) 

89 expiration_timer.start() 

90 self._locks[key] = _LockInfo( 

91 holder=holder, lock=lock, expiration_timer=expiration_timer 

92 ) 

93 return True 

94 elif self._locks[key]["holder"] == holder: 

95 return True 

96 else: 

97 existing_lock_info = self._locks[key] 

98 

99 if acquire_timeout is not None: 

100 existing_lock_acquired = existing_lock_info["lock"].acquire( 

101 timeout=acquire_timeout 

102 ) 

103 else: 

104 existing_lock_acquired = existing_lock_info["lock"].acquire() 

105 

106 if existing_lock_acquired: 

107 with self._locks_dict_lock: 

108 if ( 

109 expiration_timer := existing_lock_info["expiration_timer"] 

110 ) is not None: 

111 expiration_timer.cancel() 

112 expiration_timer = None 

113 if hold_timeout is not None: 

114 expiration_timer = threading.Timer( 

115 hold_timeout, self._expire_lock, args=(key,) 

116 ) 

117 expiration_timer.start() 

118 self._locks[key] = _LockInfo( 

119 holder=holder, 

120 lock=existing_lock_info["lock"], 

121 expiration_timer=expiration_timer, 

122 ) 

123 return True 

124 return False 

125 

126 async def aacquire_lock( 

127 self, 

128 key: str, 

129 holder: str, 

130 acquire_timeout: float | None = None, 

131 hold_timeout: float | None = None, 

132 ) -> bool: 

133 with self._locks_dict_lock: 

134 if key not in self._locks: 

135 lock = threading.Lock() 

136 lock.acquire() 

137 expiration_timer = None 

138 if hold_timeout is not None: 

139 expiration_timer = threading.Timer( 

140 hold_timeout, self._expire_lock, args=(key,) 

141 ) 

142 expiration_timer.start() 

143 self._locks[key] = _LockInfo( 

144 holder=holder, lock=lock, expiration_timer=expiration_timer 

145 ) 

146 return True 

147 elif self._locks[key]["holder"] == holder: 

148 return True 

149 else: 

150 existing_lock_info = self._locks[key] 

151 

152 if acquire_timeout is not None: 

153 existing_lock_acquired = await asyncio.to_thread( 

154 existing_lock_info["lock"].acquire, timeout=acquire_timeout 

155 ) 

156 else: 

157 existing_lock_acquired = await asyncio.to_thread( 

158 existing_lock_info["lock"].acquire 

159 ) 

160 

161 if existing_lock_acquired: 

162 with self._locks_dict_lock: 

163 if ( 

164 expiration_timer := existing_lock_info["expiration_timer"] 

165 ) is not None: 

166 expiration_timer.cancel() 

167 expiration_timer = None 

168 if hold_timeout is not None: 

169 expiration_timer = threading.Timer( 

170 hold_timeout, self._expire_lock, args=(key,) 

171 ) 

172 expiration_timer.start() 

173 self._locks[key] = _LockInfo( 

174 holder=holder, 

175 lock=existing_lock_info["lock"], 

176 expiration_timer=expiration_timer, 

177 ) 

178 return True 

179 return False 

180 

181 def release_lock(self, key: str, holder: str) -> None: 

182 with self._locks_dict_lock: 

183 if key in self._locks and self._locks[key]["holder"] == holder: 

184 if ( 

185 expiration_timer := self._locks[key]["expiration_timer"] 

186 ) is not None: 

187 expiration_timer.cancel() 

188 self._locks[key]["lock"].release() 

189 del self._locks[key] 

190 else: 

191 raise ValueError( 

192 f"No lock held by {holder} for transaction with key {key}" 

193 ) 

194 

195 def is_locked(self, key: str) -> bool: 

196 return key in self._locks and self._locks[key]["lock"].locked() 

197 

198 def is_lock_holder(self, key: str, holder: str) -> bool: 

199 lock_info = self._locks.get(key) 

200 return ( 

201 lock_info is not None 

202 and lock_info["lock"].locked() 

203 and lock_info["holder"] == holder 

204 ) 

205 

206 def wait_for_lock(self, key: str, timeout: float | None = None) -> bool: 

207 lock_info: _LockInfo | None = self._locks.get(key) 

208 if lock_info is None: 

209 return True 

210 if lock_info["lock"].locked(): 

211 if timeout is not None: 

212 lock_acquired = lock_info["lock"].acquire(timeout=timeout) 

213 else: 

214 lock_acquired = lock_info["lock"].acquire() 

215 if lock_acquired: 

216 lock_info["lock"].release() 

217 return lock_acquired 

218 return True 

219 

220 async def await_for_lock(self, key: str, timeout: float | None = None) -> bool: 

221 lock_info: _LockInfo | None = self._locks.get(key, None) 

222 if lock_info is None: 

223 return True 

224 if lock_info["lock"].locked(): 

225 if timeout is not None: 

226 lock_acquired = await asyncio.to_thread( 

227 lock_info["lock"].acquire, timeout=timeout 

228 ) 

229 else: 

230 lock_acquired = await asyncio.to_thread(lock_info["lock"].acquire) 

231 if lock_acquired: 

232 lock_info["lock"].release() 

233 return lock_acquired 

234 return True 

235 

236 def __getstate__(self) -> dict[str, Any]: 

237 """ 

238 Prepare the lock manager for serialization. 

239 

240 If there are any locks held, log a warning that lock information will not 

241 be available after deserialization. 

242 """ 

243 state = self.__dict__.copy() 

244 

245 # Check if there are any locks held 

246 if self._locks: 

247 logger.warning( 

248 "Serializing MemoryLockManager with %d active lock(s). " 

249 "Lock information will not be available after deserialization.", 

250 len(self._locks), 

251 ) 

252 

253 # Remove unpicklable objects 

254 # The _locks_dict_lock will be recreated in __setstate__ 

255 state.pop("_locks_dict_lock", None) 

256 # Clear all locks since threading.Lock objects cannot be pickled 

257 state.pop("_locks", None) 

258 

259 return state 

260 

261 def __setstate__(self, state: dict[str, Any]) -> None: 

262 """ 

263 Restore the lock manager after deserialization. 

264 

265 Reinitializes the lock manager with empty locks and a new lock for the 

266 locks dictionary. 

267 """ 

268 self.__dict__.update(state) 

269 # Handle case where the lock manager is being deserialized in the same process as the original instance 

270 if self.__class__._initialized: 

271 return 

272 self._locks_dict_lock = threading.Lock() 

273 self._locks: dict[str, _LockInfo] = {} 

274 self.__class__._initialized = True