Coverage for /usr/local/lib/python3.12/site-packages/prefect/locking/memory.py: 0%
142 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
from __future__ import annotations

import asyncio
import logging
import threading
import time
from typing import Any, TypedDict

from typing_extensions import Self

from prefect.logging.loggers import get_logger

from .protocol import LockManager
# Name handed to prefect's ``get_logger`` for this module's logger.
_LOGGER_NAME = "locking.memory"
logger: logging.Logger = get_logger(_LOGGER_NAME)
class _LockInfo(TypedDict):
    """
    Per-key record stored by ``MemoryLockManager`` while a key is held.

    Attributes:
        holder: Identifier of the current owner of the lock.
        lock: The ``threading.Lock`` guarding the key.
        expiration_timer: Timer scheduled to expire the lock after its hold
            timeout, or ``None`` when no hold timeout was requested.
    """

    holder: str  # who currently owns the key
    lock: threading.Lock  # the underlying lock object for the key
    expiration_timer: threading.Timer | None  # hold-timeout timer, if any
class MemoryLockManager(LockManager):
    """
    A lock manager that stores lock information in memory.

    Constructing the class always returns the same process-wide instance.
    Each held key maps to a ``_LockInfo`` record holding an acquired
    ``threading.Lock``, the holder's identifier, and an optional timer that
    expires the lock once ``hold_timeout`` seconds have elapsed.

    Note: because this lock manager stores data in memory, it is not suitable for
    use in a distributed environment or across different processes.
    """

    _instance = None
    _initialized = False

    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
        # Singleton: every construction hands back the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every construction; only the first call sets up
        # state, so existing locks survive repeated ``MemoryLockManager()``.
        if self.__class__._initialized:
            return
        self._locks_dict_lock = threading.Lock()
        self._locks: dict[str, _LockInfo] = {}
        self.__class__._initialized = True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _register_holder(
        self,
        key: str,
        holder: str,
        lock: threading.Lock,
        hold_timeout: float | None,
    ) -> None:
        """
        Record ``holder`` as the owner of ``key``.

        Must be called with ``_locks_dict_lock`` held and with ``lock`` already
        acquired on the holder's behalf.

        Args:
            key: The key being locked.
            holder: The new owner of the lock.
            lock: The acquired lock guarding ``key``.
            hold_timeout: Seconds after which the lock is expired
                automatically, or ``None`` to hold it until released.
        """
        lock_info = _LockInfo(holder=holder, lock=lock, expiration_timer=None)
        self._locks[key] = lock_info
        if hold_timeout is not None:
            # The record itself is passed to the callback so that a timer
            # outliving this holder cannot expire a later holder's lock.
            timer = threading.Timer(
                hold_timeout, self._expire_lock, args=(key, lock_info)
            )
            # Daemon so a pending hold timeout does not keep the interpreter
            # alive at shutdown.
            timer.daemon = True
            lock_info["expiration_timer"] = timer
            timer.start()

    def _expire_lock(self, key: str, lock_info: _LockInfo | None = None):
        """
        Expire the lock for the given key.

        Used as a callback for the expiration timer of a lock.

        Args:
            key: The key of the lock to expire.
            lock_info: The record the timer was created for. When given, the
                lock is only expired if that record is still the current one
                for ``key``. This makes a timer that fired just as its holder
                released (after which someone else took the key) a no-op.
        """
        with self._locks_dict_lock:
            current = self._locks.get(key)
            if current is None:
                return
            if lock_info is not None and current is not lock_info:
                return
            if current["lock"].locked():
                current["lock"].release()
            if (timer := current["expiration_timer"]) is not None:
                timer.cancel()
            del self._locks[key]

    def _claim_or_get_contended_lock(
        self, key: str, holder: str, hold_timeout: float | None
    ) -> threading.Lock | None:
        """
        Try to take ``key`` for ``holder`` without blocking.

        Returns:
            ``None`` if ``holder`` now owns ``key`` (it was free, or ``holder``
            already held it). Otherwise, the lock owned by another holder,
            which the caller has to wait on.
        """
        with self._locks_dict_lock:
            lock_info = self._locks.get(key)
            if lock_info is None:
                lock = threading.Lock()
                lock.acquire()
                self._register_holder(key, holder, lock, hold_timeout)
                return None
            if lock_info["holder"] == holder:
                # Re-entrant request: already held; existing timer is kept.
                return None
            return lock_info["lock"]

    def _take_over_contended_lock(
        self,
        key: str,
        holder: str,
        lock: threading.Lock,
        hold_timeout: float | None,
    ) -> bool:
        """
        Register ``holder`` after it acquired a previously contended ``lock``.

        The previous holder removes its record when it releases, so while the
        caller was waiting a third party may have created a brand-new lock
        for ``key``. In that case ``lock`` is stale: it is released again and
        ``False`` is returned so the caller retries against the new lock
        instead of overwriting another holder's record.

        Returns:
            ``True`` if ``holder`` now owns ``key``, ``False`` if the caller
            must retry.
        """
        with self._locks_dict_lock:
            current = self._locks.get(key)
            if current is None or current["lock"] is lock:
                if current is not None and (
                    timer := current["expiration_timer"]
                ) is not None:
                    timer.cancel()
                self._register_holder(key, holder, lock, hold_timeout)
                return True
        lock.release()
        return False

    @staticmethod
    def _acquire_within(lock: threading.Lock, timeout: float | None) -> bool:
        """Block on ``lock`` for at most ``timeout`` seconds (forever if ``None``)."""
        if timeout is None:
            return lock.acquire()
        return lock.acquire(timeout=timeout)

    @classmethod
    async def _aacquire_within(
        cls, lock: threading.Lock, timeout: float | None
    ) -> bool:
        """
        Async variant of ``_acquire_within`` that blocks in a worker thread.

        A worker thread cannot be interrupted. If the awaiting task is
        cancelled, the thread may still go on to acquire ``lock``. A
        done-callback then releases it, so the lock is not left held by
        nobody.
        """
        acquisition = asyncio.ensure_future(
            asyncio.to_thread(cls._acquire_within, lock, timeout)
        )
        try:
            return await asyncio.shield(acquisition)
        except asyncio.CancelledError:

            def _release_if_acquired(done: asyncio.Future[bool]) -> None:
                if (
                    not done.cancelled()
                    and done.exception() is None
                    and done.result()
                ):
                    lock.release()

            acquisition.add_done_callback(_release_if_acquired)
            raise

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def acquire_lock(
        self,
        key: str,
        holder: str,
        acquire_timeout: float | None = None,
        hold_timeout: float | None = None,
    ) -> bool:
        """
        Acquire the lock for ``key`` on behalf of ``holder``.

        Args:
            key: The key to lock.
            holder: Identifier of the caller. If ``holder`` already holds
                ``key``, this returns ``True`` immediately.
            acquire_timeout: Maximum seconds to wait for another holder to
                release the lock, or ``None`` to wait indefinitely.
            hold_timeout: Seconds after which the lock is expired
                automatically, or ``None`` to hold it until released.

        Returns:
            ``True`` if the lock was acquired, ``False`` if ``acquire_timeout``
            elapsed first.
        """
        deadline = (
            None if acquire_timeout is None else time.monotonic() + acquire_timeout
        )
        timeout = acquire_timeout
        while True:
            contended = self._claim_or_get_contended_lock(key, holder, hold_timeout)
            if contended is None:
                return True
            if not self._acquire_within(contended, timeout):
                return False
            if self._take_over_contended_lock(key, holder, contended, hold_timeout):
                return True
            # Lost a race to a newly created lock; wait on it with the time left.
            if deadline is not None:
                timeout = max(0.0, deadline - time.monotonic())

    async def aacquire_lock(
        self,
        key: str,
        holder: str,
        acquire_timeout: float | None = None,
        hold_timeout: float | None = None,
    ) -> bool:
        """
        Async version of ``acquire_lock``; waiting happens in a worker thread.

        Args:
            key: The key to lock.
            holder: Identifier of the caller. If ``holder`` already holds
                ``key``, this returns ``True`` immediately.
            acquire_timeout: Maximum seconds to wait for another holder to
                release the lock, or ``None`` to wait indefinitely.
            hold_timeout: Seconds after which the lock is expired
                automatically, or ``None`` to hold it until released.

        Returns:
            ``True`` if the lock was acquired, ``False`` if ``acquire_timeout``
            elapsed first.
        """
        deadline = (
            None if acquire_timeout is None else time.monotonic() + acquire_timeout
        )
        timeout = acquire_timeout
        while True:
            contended = self._claim_or_get_contended_lock(key, holder, hold_timeout)
            if contended is None:
                return True
            if not await self._aacquire_within(contended, timeout):
                return False
            if self._take_over_contended_lock(key, holder, contended, hold_timeout):
                return True
            # Lost a race to a newly created lock; wait on it with the time left.
            if deadline is not None:
                timeout = max(0.0, deadline - time.monotonic())

    def release_lock(self, key: str, holder: str) -> None:
        """
        Release the lock on ``key`` held by ``holder``.

        Raises:
            ValueError: If ``key`` is not currently held by ``holder``.
        """
        with self._locks_dict_lock:
            lock_info = self._locks.get(key)
            if lock_info is None or lock_info["holder"] != holder:
                raise ValueError(
                    f"No lock held by {holder} for transaction with key {key}"
                )
            if (timer := lock_info["expiration_timer"]) is not None:
                timer.cancel()
            lock_info["lock"].release()
            del self._locks[key]

    def is_locked(self, key: str) -> bool:
        """Return whether ``key`` currently has an acquired lock."""
        # Single lookup: a separate membership test followed by indexing can
        # raise KeyError if another thread deletes the key in between.
        lock_info = self._locks.get(key)
        return lock_info is not None and lock_info["lock"].locked()

    def is_lock_holder(self, key: str, holder: str) -> bool:
        """Return whether ``holder`` currently holds the acquired lock on ``key``."""
        lock_info = self._locks.get(key)
        return (
            lock_info is not None
            and lock_info["lock"].locked()
            and lock_info["holder"] == holder
        )

    def wait_for_lock(self, key: str, timeout: float | None = None) -> bool:
        """
        Block until ``key`` is released, without taking ownership of it.

        Returns:
            ``True`` if the key is (or became) free, ``False`` if ``timeout``
            elapsed first.
        """
        lock_info: _LockInfo | None = self._locks.get(key)
        if lock_info is None or not lock_info["lock"].locked():
            return True
        lock = lock_info["lock"]
        if self._acquire_within(lock, timeout):
            # Acquired only to observe the release; hand it straight back.
            lock.release()
            return True
        return False

    async def await_for_lock(self, key: str, timeout: float | None = None) -> bool:
        """
        Async version of ``wait_for_lock``; waiting happens in a worker thread.

        Returns:
            ``True`` if the key is (or became) free, ``False`` if ``timeout``
            elapsed first.
        """
        lock_info: _LockInfo | None = self._locks.get(key, None)
        if lock_info is None or not lock_info["lock"].locked():
            return True
        lock = lock_info["lock"]
        if await self._aacquire_within(lock, timeout):
            # Acquired only to observe the release; hand it straight back.
            lock.release()
            return True
        return False

    def __getstate__(self) -> dict[str, Any]:
        """
        Prepare the lock manager for serialization.

        If there are any locks held, log a warning that lock information will not
        be available after deserialization.
        """
        state = self.__dict__.copy()

        # Check if there are any locks held
        if self._locks:
            logger.warning(
                "Serializing MemoryLockManager with %d active lock(s). "
                "Lock information will not be available after deserialization.",
                len(self._locks),
            )

        # Remove unpicklable objects
        # The _locks_dict_lock will be recreated in __setstate__
        state.pop("_locks_dict_lock", None)
        # Clear all locks since threading.Lock objects cannot be pickled
        state.pop("_locks", None)

        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """
        Restore the lock manager after deserialization.

        Reinitializes the lock manager with empty locks and a new lock for the
        locks dictionary.
        """
        self.__dict__.update(state)
        # Handle case where the lock manager is being deserialized in the same process as the original instance
        if self.__class__._initialized:
            return
        self._locks_dict_lock = threading.Lock()
        self._locks: dict[str, _LockInfo] = {}
        self.__class__._initialized = True