Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/cancellation.py: 22%
272 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1"""
2Utilities for cancellation in synchronous and asynchronous contexts.
3"""
5import abc 1a
6import asyncio 1a
7import contextlib 1a
8import ctypes 1a
9import math 1a
10import os 1a
11import signal 1a
12import sys 1a
13import threading 1a
14import time 1a
15from types import TracebackType 1a
16from typing import TYPE_CHECKING, Any, Callable, Optional, overload 1a
18import anyio 1a
20from prefect._internal.concurrency import logger 1a
21from prefect._internal.concurrency.event_loop import get_running_loop 1a
# Registry of per-thread shields, keyed by the thread being protected. All access
# goes through `_THREAD_SHIELDS_LOCK`; entries for dead threads are swept lazily by
# `_get_thread_shield`.
_THREAD_SHIELDS: dict[threading.Thread, "ThreadShield"] = dict()
_THREAD_SHIELDS_LOCK = threading.Lock()
class ThreadShield:
    """
    Reentrant guard that defers remote exceptions aimed at a thread.

    Two usage patterns are supported:

    1. Another thread enters the shield as a context manager. Because the shield wraps
       a reentrant lock held by the owner while it is shielded, this blocks until the
       owner has released every level, after which it is safe to send an exception.

    2. The owning thread records an exception with `set_exception`; it is raised when
       the owner exits the outermost level of the shield.

    Nesting is allowed: a deferred exception is only raised once the last nested
    context has been exited.
    """

    def __init__(self, owner: threading.Thread):
        # The pure-Python RLock is chosen on purpose: its `_count` attribute can be
        # inspected to tell whether the shield is currently held, which is needed to
        # delay raising exceptions delivered by alarms.
        self._lock = threading._RLock()  # type: ignore # intentionally the private version
        self._exception = None
        self._owner = owner

    def __enter__(self) -> None:
        self._lock.__enter__()

    def __exit__(self, *exc_info: Any):
        result = self._lock.__exit__(*exc_info)

        # Only deliver once the outermost level is released, something is pending,
        # and we are running on the owning thread.
        if self.active() or not self._exception:
            return result
        if self._owner.ident != threading.current_thread().ident:
            return result

        # Clear before raising so the exception is delivered exactly once
        pending, self._exception = self._exception, None
        raise pending from None

    def set_exception(self, exc: BaseException):
        self._exception = exc

    def active(self) -> bool:
        """
        Return True while any level of the shield is held.
        """
        lock_depth = getattr(self._lock, "_count")
        return lock_depth > 0
class CancelledError(asyncio.CancelledError):
    """
    Cancellation error raised by the cancel scopes in this module.

    Subclassing `asyncio.CancelledError` keeps it a `BaseException` rather than an
    `Exception`. It also gives downstream logic a distinct type, so it can tell which
    cancelled error it is handling.
    """
def _get_thread_shield(thread: threading.Thread) -> ThreadShield:
    """
    Return the shared `ThreadShield` for `thread`, creating it on first use.

    While the registry lock is held, shields registered for *other* threads that are
    no longer alive are dropped, so the registry does not grow without bound.
    """
    with _THREAD_SHIELDS_LOCK:
        thread_shield = _THREAD_SHIELDS.get(thread)
        if thread_shield is None:
            thread_shield = ThreadShield(thread)
            _THREAD_SHIELDS[thread] = thread_shield

        # Garbage-collect shields for threads that have exited. The requested thread
        # is skipped: it may legitimately be not-alive (not started yet, or already
        # finished), and sweeping it would discard the shield we are about to return.
        stale = [
            other
            for other in _THREAD_SHIELDS
            if other is not thread and not other.is_alive()
        ]
        for other in stale:
            del _THREAD_SHIELDS[other]

        return thread_shield
@contextlib.contextmanager
def shield():
    """
    Protect the enclosed code from cancellation.

    The current thread's `ThreadShield` is held for the duration of the block. The
    alarm and watcher-thread cancel scopes in this module consult it before raising
    or injecting their cancelled error.

    If an event loop is running in this thread, an anyio `CancelScope(shield=True)` is
    also entered first, so asynchronous cancellation is shielded as well.
    """
    if get_running_loop():
        async_guard = anyio.CancelScope(shield=True)
    else:
        async_guard = contextlib.nullcontext()

    # Equivalent to nesting: the async guard is entered before the thread shield is
    # looked up and entered.
    with async_guard, _get_thread_shield(threading.current_thread()):
        yield
class CancelScope(abc.ABC):
    """
    Defines a context where cancellation can be requested.

    If cancelled, any code within the context should be interrupted. The cancellation
    implementation varies depending on the subclass and may not interrupt some system
    calls.

    A timeout can be defined to automatically cancel the scope after a given duration
    if it has not exited.

    State is guarded by a non-reentrant lock. Logging (whose `%r` formatting calls
    `__repr__`, which takes the same lock) and cancel callbacks therefore run outside
    of it.
    """

    def __init__(
        self, name: Optional[str] = None, timeout: Optional[float] = None
    ) -> None:
        self.name = name
        self._deadline: Optional[float] = None
        self._cancelled = False
        self._completed = False
        self._started = False
        self._start_time: Optional[float] = None
        self._end_time: Optional[float] = None
        self._timeout = timeout
        self._lock = threading.Lock()
        self._callbacks: list[Callable[[], None]] = []
        super().__init__()

    def __enter__(self):
        with self._lock:
            # Deadline is computed on the monotonic clock at entry time
            self._deadline = get_deadline(self._timeout)
            self._started = True
            self._start_time = time.monotonic()

        logger.debug("%r entered", self)
        return self

    def __exit__(
        self, exc_type: type[BaseException], exc_val: Exception, exc_tb: TracebackType
    ) -> Optional[bool]:
        with self._lock:
            # A cancelled scope is never reported as completed
            if not self._cancelled:
                self._completed = True
            # Always record the end time; `timedout` compares it to the deadline
            self._end_time = time.monotonic()

        logger.debug("%r exited", self)

    @property
    def timeout(self):
        return self._timeout

    def started(self) -> bool:
        with self._lock:
            return self._started

    def cancelled(self) -> bool:
        with self._lock:
            return self._cancelled

    def timedout(self) -> bool:
        """
        Return True if the scope was cancelled and exited after its deadline.
        """
        with self._lock:
            if not self._end_time or not self._deadline:
                return False
            return self._cancelled and self._end_time > self._deadline

    def set_timeout(self, timeout: float):
        """
        Change the timeout.

        Raises:
            RuntimeError: if the scope has already been entered.
        """
        with self._lock:
            if self._started:
                raise RuntimeError("Cannot set timeout after scope has started.")
            self._timeout = timeout

    def completed(self):
        with self._lock:
            return self._completed

    def cancel(self, throw: bool = True) -> bool:
        """
        Cancel this scope.

        If `throw` is not set, this will only mark the scope as cancelled and will not
        throw the cancelled error. The base class only records the cancellation;
        subclasses decide how to deliver it when `throw` is set.

        Returns False if the scope already completed, otherwise True (including when it
        was already cancelled). Callbacks only run on the first cancellation.

        Raises:
            RuntimeError: if the scope has not been entered.
        """
        with self._lock:
            if not self._started:
                raise RuntimeError("Scope has not been entered.")

            if self._completed:
                return False

            if self._cancelled:
                return True

            self._cancelled = True

        # Outside the lock: `__repr__` (used by this log call) and callbacks may need it
        logger.info("%r cancelling", self)

        for callback in self._callbacks:
            callback()

        return True

    def add_cancel_callback(self, callback: Callable[[], None]):
        """
        Add a callback to execute on cancellation.

        Callbacks run during the first successful `cancel`; a callback added after the
        scope was cancelled is not run.
        """
        self._callbacks.append(callback)

    def __repr__(self) -> str:
        with self._lock:
            if self._completed:
                state = "completed"
            elif self._cancelled:
                state = "cancelled"
            elif self._started:
                state = "running"
            else:
                state = "pending"
            state = state.upper()
            # `is not None` so that a zero timeout is still shown
            timeout = (
                f", timeout={self._timeout:.2f}" if self._timeout is not None else ""
            )
            runtime = (
                f", runtime={(self._end_time or time.monotonic()) - self._start_time:.2f}"
                if self._start_time is not None
                else ""
            )
        # Leading space so the class name and address do not run together
        name = f", name={self.name!r}" if self.name else f" at {hex(id(self))}"
        return f"<{type(self).__name__}{name} {state}{timeout}{runtime}>"
class AsyncCancelScope(CancelScope):
    """
    Cancel scope for code running on an asyncio event loop.

    Enforcement of both the deadline and explicit cancellation is delegated to an
    `anyio.CancelScope` that is entered together with this scope. The scope must be
    entered while an event loop is running in the current thread, because
    `asyncio.get_running_loop()` is called on entry.
    """

    def __init__(
        self, name: Optional[str] = None, timeout: Optional[float] = None
    ) -> None:
        super().__init__(name=name, timeout=timeout)

    def __enter__(self):
        # Remember the loop so `cancel` can tell whether it is being called on the
        # loop's own thread or from elsewhere
        self.loop = asyncio.get_running_loop()

        super().__enter__()

        # Use anyio as the cancellation enforcer because it's very complicated and they
        # have done a good job. The deadline computed by the base class is passed
        # through unchanged, with math.inf meaning "no deadline".
        # NOTE(review): this assumes anyio's deadline clock matches `time.monotonic()`
        # (the asyncio loop clock by default) — confirm for the backend in use.
        self._anyio_scope = anyio.CancelScope(
            deadline=self._deadline if self._deadline is not None else math.inf
        ).__enter__()

        return self

    def __exit__(
        self, exc_type: type[BaseException], exc_val: Exception, exc_tb: TracebackType
    ) -> bool:
        """
        Exit the anyio scope, then this scope.

        If the anyio scope was cancelled (deadline or explicit cancel), this scope is
        marked cancelled first. The anyio scope's return value is ignored. If this scope
        ends up cancelled and the in-flight exception is not already this module's
        `CancelledError`, a new `CancelledError` chained to it is raised. Otherwise,
        returns False.
        """
        if self._anyio_scope.cancel_called:
            # Mark as cancelled (without re-requesting cancellation from anyio)
            self.cancel(throw=False)

        # TODO: Can we also delete the scope?
        # We have to exit this scope to prevent leaking memory. A fix for
        # issue #10952.
        self._anyio_scope.__exit__(exc_type, exc_val, exc_tb)

        super().__exit__(exc_type, exc_val, exc_tb)

        if self.cancelled() and exc_type is not CancelledError:
            # Ensure cancellation error is propagated on exit
            raise CancelledError() from exc_val

        return False

    def cancel(self, throw: bool = True):
        """
        Mark the scope cancelled and, if `throw` is set, cancel the anyio scope.

        Returns False if the scope had already completed, otherwise True.

        `get_running_loop` here is the helper imported from
        `prefect._internal.concurrency.event_loop`. When it does not return this
        scope's loop, the anyio cancel is scheduled onto that loop with
        `call_soon_threadsafe`.
        """
        if not super().cancel():
            return False

        if throw:
            if self.loop is get_running_loop():
                self._anyio_scope.cancel()
            else:
                # `Task.cancel` is not thread safe
                self.loop.call_soon_threadsafe(self._anyio_scope.cancel)

        return True
class NullCancelScope(CancelScope):
    """
    A cancel scope that cannot cancel anything.

    Intended for environments where cancellation is not supported. Calling `cancel`
    only logs a warning, naming `reason`, and reports that nothing was cancelled.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        timeout: Optional[float] = None,
        reason: Optional[str] = None,
    ) -> None:
        super().__init__(name=name, timeout=timeout)
        self.reason = reason if reason else "null cancel scope"

    def cancel(self, throw: bool = True) -> bool:
        logger.warning("%r cannot cancel %s.", self, self.reason)
        return False
class AlarmCancelScope(CancelScope):
    """
    A cancel scope that uses an alarm signal (SIGALRM), which can interrupt
    long-running system calls.

    Only the main thread can be cancelled with an alarm signal, so this scope is only
    available in the main thread.
    """

    # `setitimer(ITIMER_REAL, 0)` *clears* the timer instead of firing it. Timeouts
    # that are already exhausted (0 or negative) are therefore raised to this small
    # positive value, so the alarm fires effectively immediately.
    _MIN_ALARM_SECONDS = 1e-3

    def __enter__(self):
        # Validate before entering the base scope so a rejected scope is not left
        # marked as started.
        if threading.current_thread() is not threading.main_thread():
            raise ValueError(
                "Alarm based timeouts can only be used in the main thread."
            )

        super().__enter__()

        self._previous_timer = None
        self._previous_alarm_handler = signal.getsignal(signal.SIGALRM)

        if self._previous_alarm_handler != signal.SIG_DFL:
            logger.warning(
                "%r overriding existing alarm handler %s",
                self,
                self._previous_alarm_handler,
            )

        # Capture alarm signals and raise a timeout
        signal.signal(signal.SIGALRM, self._sigalarm_to_error)

        # Set a timer to raise an alarm signal
        if self.timeout is not None:
            # Use `setitimer` instead of `signal.alarm` for float support; raises a SIGALRM
            logger.debug("%r set alarm timer for %f seconds", self, self.timeout)
            seconds = max(self.timeout, self._MIN_ALARM_SECONDS)
            self._previous_timer = signal.setitimer(signal.ITIMER_REAL, seconds)

        return self

    def _sigalarm_to_error(self, *args: object) -> None:
        """
        SIGALRM handler: mark the scope cancelled and raise `CancelledError`.

        If the main thread's shield is active, the error is stored on the shield and
        raised when the shield is released, rather than immediately.
        """
        logger.debug("%r captured alarm raising as cancelled error", self)
        if self.cancel(throw=False):
            shield = _get_thread_shield(threading.main_thread())
            if shield.active():
                logger.debug("%r thread shield active; delaying exception", self)
                shield.set_exception(CancelledError())
            else:
                raise CancelledError()

    def __exit__(self, *_: Any) -> Optional[bool]:
        retval = super().__exit__(*_)

        if self.timeout is not None:
            # Restore the previous timer. When there was none, `setitimer` returned
            # (0.0, 0.0), and restoring that disarms our own pending alarm.
            if TYPE_CHECKING:
                assert self._previous_timer is not None
            signal.setitimer(signal.ITIMER_REAL, *self._previous_timer)

        # Restore the previous signal handler. `signal.getsignal` returns None when that
        # handler was not installed from Python, and `signal.signal` rejects None, so
        # fall back to the default action in that case.
        previous_handler = self._previous_alarm_handler
        if previous_handler is None:
            previous_handler = signal.SIG_DFL
        signal.signal(signal.SIGALRM, previous_handler)

        return retval

    def cancel(self, throw: bool = True):
        """
        Mark the scope cancelled; if `throw`, deliver SIGALRM to this process so the
        handler raises the error in the main thread.
        """
        if not super().cancel():
            return False

        if throw:
            logger.debug("%r sending alarm signal to main thread", self)
            os.kill(os.getpid(), signal.SIGALRM)

        return True
class WatcherThreadCancelScope(CancelScope):
    """
    A cancel scope that uses a watcher thread and an injected exception to enforce
    cancellation.

    The injected exception cannot interrupt blocking calls and will be raised on the
    ~next instruction the supervised thread executes. This can raise exceptions in
    unexpected places. See `shield` for guarding against interruption.

    If a timeout is specified, a watcher thread is spawned that waits up to `timeout`
    seconds for the scope to exit. If it does not, the watcher sends the exception to
    the supervised thread.
    """

    def __enter__(self):
        # Set up everything `cancel` and `__exit__` rely on *before* the base scope is
        # marked as started. Once started, another thread may call `cancel()`, which
        # reads `_supervised_thread`.
        self._event = threading.Event()
        self._enforcer_thread = None
        self._supervised_thread = threading.current_thread()

        super().__enter__()

        if self.timeout is not None:
            name = self.name or f"for scope {hex(id(self))}"
            self._enforcer_thread = threading.Thread(
                target=self._timeout_enforcer,
                name=f"timeout-watcher {name} {self.timeout:.2f}",
            )
            self._enforcer_thread.start()

        return self

    def __exit__(self, *_: Any) -> Optional[bool]:
        retval = super().__exit__(*_)
        # Wake the enforcer (it waits on this event) and wait for it to finish
        self._event.set()
        if self._enforcer_thread:
            logger.debug("%r joining enforcer thread %r", self, self._enforcer_thread)
            self._enforcer_thread.join()
        return retval

    def _send_cancelled_error(self):
        """
        Send a cancelled error to the supervised thread.

        The supervised thread's shield is entered first. That blocks until the
        supervised thread has released any shield it holds.
        """
        if self._supervised_thread.is_alive():
            logger.debug(
                "%r sending exception to supervised thread %r",
                self,
                self._supervised_thread,
            )
            with _get_thread_shield(self._supervised_thread):
                try:
                    _send_exception_to_thread(self._supervised_thread, CancelledError)
                except ValueError:
                    # If the thread is gone; just move on without error
                    logger.debug("Thread missing!")

    def _timeout_enforcer(self):
        """
        Target for a thread that enforces a timeout.

        Waits up to `timeout` seconds for the scope to exit. On timeout, the scope is
        marked cancelled and, unless it had already completed, the error is sent. The
        watcher then waits until the supervised thread exits the scope.
        """
        if not self._event.wait(self.timeout):
            logger.debug("%r enforcer detected timeout!", self)
            if self.cancel(throw=False):
                with _get_thread_shield(self._supervised_thread):
                    self._send_cancelled_error()

            # Wait for the supervised thread to exit its context
            logger.debug("%r waiting for supervised thread to exit", self)
            self._event.wait()

    def cancel(self, throw: bool = True):
        if not super().cancel():
            return False

        if throw:
            self._send_cancelled_error()

        return True
@overload
def get_deadline(timeout: float) -> float: ...


@overload
def get_deadline(timeout: None) -> None: ...


def get_deadline(timeout: Optional[float]) -> Optional[float]:
    """
    Convert a relative timeout into an absolute deadline.

    The deadline is a `time.monotonic()` timestamp. A `None` timeout means "no
    deadline" and yields `None`.
    """
    return None if timeout is None else time.monotonic() + timeout
def get_timeout(deadline: Optional[float]) -> Optional[float]:
    """
    Compute a timeout given a deadline.

    `deadline` must be a `time.monotonic()` timestamp (see `get_deadline`). Returns
    `None` when there is no deadline. Otherwise, returns the non-negative number of
    seconds remaining; a deadline that has already passed yields `0.0`.
    """
    if deadline is None:
        return None

    # Clamp with a float so the result type does not flip to `int` once the deadline
    # has passed.
    return max(0.0, deadline - time.monotonic())
@contextlib.contextmanager
def cancel_async_at(deadline: Optional[float], name: Optional[str] = None):
    """
    Cancel any async calls within the context if it does not exit by `deadline`.

    `deadline` must be a `time.monotonic()` timestamp (see `get_deadline`); `None`
    means no deadline. The time remaining is handed to `cancel_async_after`.

    Once the deadline passes, cancellation takes effect on the next `await`.

    Yields the `AsyncCancelScope` created by `cancel_async_after`.
    """
    remaining = get_timeout(deadline)
    with cancel_async_after(remaining, name=name) as scope:
        yield scope
@contextlib.contextmanager
def cancel_async_after(timeout: Optional[float], name: Optional[str] = None):
    """
    Cancel any async calls within the context if it runs longer than `timeout`
    seconds; `None` disables the timeout.

    Once the timeout expires, cancellation takes effect on the next `await`.

    Yields the entered `AsyncCancelScope`.
    """
    scope = AsyncCancelScope(timeout=timeout, name=name)
    with scope:
        yield scope
@contextlib.contextmanager
def cancel_sync_at(deadline: Optional[float], name: Optional[str] = None):
    """
    Cancel any sync calls within the context if it does not exit by `deadline`.

    `deadline` must be a `time.monotonic()` timestamp (see `get_deadline`); `None`
    means no deadline. The enforcement mechanism depends on the platform and on
    whether this runs in the main thread; see `cancel_sync_after` for details.

    Yields the cancel scope created by `cancel_sync_after`.
    """
    # Reuse the shared deadline -> timeout helper (as `cancel_async_at` does) rather
    # than duplicating the clamp-to-zero arithmetic here.
    with cancel_sync_after(get_timeout(deadline), name=name) as scope:
        yield scope
@contextlib.contextmanager
def cancel_sync_after(timeout: Optional[float], name: Optional[str] = None):
    """
    Cancel any sync calls within the context if it runs longer than `timeout`
    seconds; `None` disables the timeout.

    The enforcement mechanism depends on the environment:

    - On Windows, cancellation is unsupported and a `NullCancelScope` is used.
    - In the main thread, when a timeout is given and no SIGALRM handler is already
      installed, an `AlarmCancelScope` is used.
    - Otherwise, a `WatcherThreadCancelScope` is used.

    Yields the entered cancel scope.
    """
    scope: CancelScope
    if sys.platform.startswith("win"):
        # Cancellation is unsupported here. The null scope still receives the caller's
        # name/timeout and is entered like the other scopes, so its state and runtime
        # are tracked and reported consistently.
        scope = NullCancelScope(
            name=name,
            timeout=timeout,
            reason="cancellation is not supported on Windows",
        )
    else:
        thread = threading.current_thread()
        # SIGALRM is not available on Windows, so this lookup stays in this branch
        existing_alarm_handler = signal.getsignal(signal.SIGALRM) != signal.SIG_DFL

        if (
            thread is threading.main_thread()
            # Avoid nested alarm handlers; it's hard to follow and they will interfere
            # with each other
            and not existing_alarm_handler
            # Avoid using an alarm when there is no timeout; it's better saved for
            # that case
            and timeout is not None
        ):
            scope = AlarmCancelScope(name=name, timeout=timeout)
        else:
            scope = WatcherThreadCancelScope(name=name, timeout=timeout)

    with scope:
        yield scope
589def _send_exception_to_thread(thread: threading.Thread, exc_type: type[BaseException]): 1a
590 """
591 Raise an exception in a thread.
593 This will not interrupt long-running system calls like `sleep` or `wait`.
594 """
595 if not thread.ident:
596 raise ValueError("Thread is not started.")
597 ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
598 ctypes.c_long(thread.ident), ctypes.py_object(exc_type)
599 )
600 if ret == 0:
601 raise ValueError("Thread not found.")