Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/cancellation.py: 43%
272 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1"""
2Utilities for cancellation in synchronous and asynchronous contexts.
3"""
5import abc 1a
6import asyncio 1a
7import contextlib 1a
8import ctypes 1a
9import math 1a
10import os 1a
11import signal 1a
12import sys 1a
13import threading 1a
14import time 1a
15from types import TracebackType 1a
16from typing import TYPE_CHECKING, Any, Callable, Optional, overload 1a
18import anyio 1a
20from prefect._internal.concurrency import logger 1a
21from prefect._internal.concurrency.event_loop import get_running_loop 1a
23_THREAD_SHIELDS: dict[threading.Thread, "ThreadShield"] = {} 1a
24_THREAD_SHIELDS_LOCK = threading.Lock() 1a
class ThreadShield:
    """
    Reentrant guard that lets a thread defer exceptions delivered to it remotely.

    Two usage patterns are supported:

    1. From *another* thread, enter it as a context manager to block until the
       owning thread has released its shield before delivering an exception.

    2. From the owning thread, call `set_exception` while shielded; the stored
       exception is raised once the outermost shield context exits.

    Because the underlying lock is reentrant, shields nest and the deferred
    exception fires only when the final (outermost) context is left.
    """

    def __init__(self, owner: threading.Thread):
        # The pure-Python RLock is chosen deliberately: it exposes `_count`, which
        # `active` reads to decide whether an alarm-triggered exception must wait.
        self._lock = threading._RLock()  # type: ignore # yes, we want the private version
        self._exception = None
        self._owner = owner

    def __enter__(self) -> None:
        self._lock.__enter__()

    def __exit__(self, *exc_info: Any):
        result = self._lock.__exit__(*exc_info)

        in_owner_thread = threading.current_thread().ident == self._owner.ident
        if self.active() or not self._exception or not in_owner_thread:
            return result

        # The owning thread just left its outermost shield: raise the deferred
        # exception exactly once by clearing it before raising.
        pending, self._exception = self._exception, None
        raise pending from None

    def set_exception(self, exc: BaseException):
        self._exception = exc

    def active(self) -> bool:
        """
        Report whether any level of the shield is currently held.
        """
        depth = getattr(self._lock, "_count")
        return depth > 0
class CancelledError(asyncio.CancelledError):
    """
    Cancellation error raised by the cancel scopes in this module.

    Deriving from `asyncio.CancelledError` keeps it a `BaseException` rather than an
    `Exception`, while a dedicated subclass lets downstream logic tell which
    cancelled error it is handling.
    """
def _get_thread_shield(thread: threading.Thread) -> ThreadShield:
    """
    Return the `ThreadShield` registered for `thread`, creating it on first use.

    Registering a new shield also drops shields owned by threads that are no longer
    alive, so the registry does not grow without bound.
    """
    with _THREAD_SHIELDS_LOCK:
        thread_shield = _THREAD_SHIELDS.get(thread)
        if thread_shield is None:
            thread_shield = ThreadShield(thread)
            _THREAD_SHIELDS[thread] = thread_shield

            # Perform garbage collection for old threads. The requested thread is
            # skipped even if it is not alive (e.g. not yet started, or already
            # marked stopped). Evicting it here used to make the final lookup raise
            # `KeyError`, and repeated calls would otherwise get fresh, unrelated shields.
            for thread_ in tuple(_THREAD_SHIELDS):
                if thread_ is not thread and not thread_.is_alive():
                    del _THREAD_SHIELDS[thread_]

        return thread_shield
@contextlib.contextmanager
def shield():
    """
    Protect the enclosed code from cancellation.

    The current thread's `ThreadShield` is held for the duration of the block. This
    defers alarm-signal and injected exceptions from this module until the block
    exits.

    When an event loop is running in this thread, an `anyio.CancelScope` with
    `shield=True` is entered first, so asynchronous cancellation is blocked as well.
    """
    if get_running_loop():
        async_guard = anyio.CancelScope(shield=True)
    else:
        async_guard = contextlib.nullcontext()

    with async_guard, _get_thread_shield(threading.current_thread()):
        yield
class CancelScope(abc.ABC):
    """
    Defines a context where cancellation can be requested.

    If cancelled, any code within the context should be interrupted. The cancellation
    implementation varies depending on the environment and may not interrupt some system
    calls.

    A timeout can be defined to automatically cancel the scope after a given duration if
    it has not exited.
    """

    def __init__(
        self, name: Optional[str] = None, timeout: Optional[float] = None
    ) -> None:
        self.name = name
        # Absolute `time.monotonic()` deadline, computed from `_timeout` on entry.
        self._deadline: Optional[float] = None
        self._cancelled = False
        self._completed = False
        self._started = False
        self._start_time: Optional[float] = None
        self._end_time: Optional[float] = None
        self._timeout = timeout
        # Guards the state flags above; subclasses call `cancel` from other threads
        # (an enforcer thread, or via `call_soon_threadsafe`).
        self._lock = threading.Lock()
        self._callbacks: list[Callable[[], None]] = []
        super().__init__()

    def __enter__(self):
        with self._lock:
            self._deadline = get_deadline(self._timeout)
            self._started = True
            self._start_time = time.monotonic()

        # Logged outside the lock: `__repr__` acquires the same non-reentrant lock.
        logger.debug("%r entered", self)
        return self

    def __exit__(
        self, exc_type: type[BaseException], exc_val: Exception, exc_tb: TracebackType
    ) -> Optional[bool]:
        with self._lock:
            # A cancelled scope is never also reported as completed.
            if not self._cancelled:
                self._completed = True
            self._end_time = time.monotonic()

        logger.debug("%r exited", self)

    @property
    def timeout(self) -> Optional[float]:
        return self._timeout

    def started(self) -> bool:
        with self._lock:
            return self._started

    def cancelled(self) -> bool:
        with self._lock:
            return self._cancelled

    def timedout(self) -> bool:
        """
        Return True if the scope was cancelled and exited after its deadline.
        """
        with self._lock:
            if not self._end_time or not self._deadline:
                return False
            return self._cancelled and self._end_time > self._deadline

    def set_timeout(self, timeout: float):
        """
        Set the timeout; only allowed before the scope has been entered.

        Raises:
            RuntimeError: If the scope has already started.
        """
        with self._lock:
            if self._started:
                raise RuntimeError("Cannot set timeout after scope has started.")
            self._timeout = timeout

    def completed(self) -> bool:
        with self._lock:
            return self._completed

    def cancel(self, throw: bool = True) -> bool:
        """
        Cancel this scope.

        If `throw` is not set, this will only mark the scope as cancelled and will not
        throw the cancelled error.

        Returns:
            False if the scope already completed, otherwise True (including when it
            was already cancelled). Callbacks run only on the first cancellation.

        Raises:
            RuntimeError: If the scope has not been entered.
        """
        with self._lock:
            if not self._started:
                raise RuntimeError("Scope has not been entered.")

            if self._completed:
                return False

            if self._cancelled:
                return True

            self._cancelled = True

        # Outside the lock: `__repr__` and callbacks may need to acquire it.
        logger.info("%r cancelling", self)

        for callback in self._callbacks:
            callback()

        return True

    def add_cancel_callback(self, callback: Callable[[], None]):
        """
        Add a callback to execute on cancellation.
        """
        self._callbacks.append(callback)

    def __repr__(self) -> str:
        with self._lock:
            state = (
                "completed"
                if self._completed
                else (
                    "cancelled"
                    if self._cancelled
                    else "running"
                    if self._started
                    else "pending"
                )
            ).upper()
            # `is not None` so a zero timeout is still shown.
            timeout = (
                f", timeout={self._timeout:.2f}" if self._timeout is not None else ""
            )
            runtime = (
                f", runtime={(self._end_time or time.monotonic()) - self._start_time:.2f}"
                if self._start_time is not None
                else ""
            )
            # Leading space keeps the class name and the id separated.
            name = f", name={self.name!r}" if self.name else f" at {hex(id(self))}"
            return f"<{type(self).__name__}{name} {state}{timeout}{runtime}>"
class AsyncCancelScope(CancelScope):
    """
    Cancel scope for code running on an asyncio event loop.

    Enforcement is delegated to an `anyio.CancelScope` whose deadline mirrors this
    scope's deadline (`math.inf` when there is no timeout).
    """

    def __init__(
        self, name: Optional[str] = None, timeout: Optional[float] = None
    ) -> None:
        super().__init__(name=name, timeout=timeout)

    def __enter__(self):
        self.loop = asyncio.get_running_loop()

        super().__enter__()

        # Use anyio as the cancellation enforcer because it's very complicated and
        # they have done a good job.
        anyio_deadline = math.inf if self._deadline is None else self._deadline
        self._anyio_scope = anyio.CancelScope(deadline=anyio_deadline).__enter__()

        return self

    def __exit__(
        self, exc_type: type[BaseException], exc_val: Exception, exc_tb: TracebackType
    ) -> bool:
        # If cancellation was requested on the anyio scope, record it on this scope
        # too, without asking anyio to cancel again.
        if self._anyio_scope.cancel_called:
            self.cancel(throw=False)

        # TODO: Can we also delete the scope?
        # The anyio scope must always be exited to prevent leaking memory (issue #10952).
        self._anyio_scope.__exit__(exc_type, exc_val, exc_tb)
        super().__exit__(exc_type, exc_val, exc_tb)

        # A cancelled scope must surface our `CancelledError` to the caller.
        must_raise = self.cancelled() and exc_type is not CancelledError
        if must_raise:
            raise CancelledError() from exc_val

        return False

    def cancel(self, throw: bool = True):
        if not super().cancel():
            return False

        if not throw:
            return True

        if self.loop is get_running_loop():
            self._anyio_scope.cancel()
        else:
            # `Task.cancel` is not thread safe
            self.loop.call_soon_threadsafe(self._anyio_scope.cancel)
        return True
class NullCancelScope(CancelScope):
    """
    A cancel scope that never cancels anything.

    Intended for environments where cancellation is unsupported. `reason` says why,
    and it is included in the warning logged whenever cancellation is attempted.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        timeout: Optional[float] = None,
        reason: Optional[str] = None,
    ) -> None:
        super().__init__(name, timeout)
        self.reason = "null cancel scope" if not reason else reason

    def cancel(self, throw: bool = True) -> bool:
        # Nothing can be cancelled here: warn and report that no cancellation happened.
        logger.warning("%r cannot cancel %s.", self, self.reason)
        return False
class AlarmCancelScope(CancelScope):
    """
    A cancel scope that uses an alarm signal which can interrupt long-running system
    calls.

    Only the main thread can be cancelled with an alarm signal, so this scope is only
    available in the main thread.
    """

    def __enter__(self):
        # Validate before entering the base scope so a rejected entry does not leave
        # this scope marked as started (its `__exit__` would never run).
        if threading.current_thread() is not threading.main_thread():
            raise ValueError(
                "Alarm based timeouts can only be used in the main thread."
            )

        super().__enter__()
        self._previous_timer = None

        # May be None when the existing handler was not installed from Python.
        self._previous_alarm_handler = signal.getsignal(signal.SIGALRM)

        if self._previous_alarm_handler != signal.SIG_DFL:
            logger.warning(
                "%r overriding existing alarm handler %s",
                self,
                self._previous_alarm_handler,
            )

        # Capture alarm signals and raise a timeout
        signal.signal(signal.SIGALRM, self._sigalarm_to_error)

        # Set a timer to raise an alarm signal
        if self.timeout is not None:
            # Use `setitimer` instead of `signal.alarm` for float support; raises a SIGALRM
            logger.debug("%r set alarm timer for %f seconds", self, self.timeout)
            self._previous_timer = signal.setitimer(signal.ITIMER_REAL, self.timeout)

        return self

    def _sigalarm_to_error(self, *args: object) -> None:
        """
        SIGALRM handler: mark the scope cancelled and raise `CancelledError`.

        If the main thread's shield is active, the exception is stored on the shield
        and raised when the shield is released instead of immediately.
        """
        logger.debug("%r captured alarm raising as cancelled error", self)
        if self.cancel(throw=False):
            thread_shield = _get_thread_shield(threading.main_thread())
            if thread_shield.active():
                logger.debug("%r thread shield active; delaying exception", self)
                thread_shield.set_exception(CancelledError())
            else:
                raise CancelledError()

    def __exit__(self, *_: Any) -> Optional[bool]:
        retval = super().__exit__(*_)

        if self.timeout is not None:
            # Restore the previous timer
            if TYPE_CHECKING:
                assert self._previous_timer is not None
            signal.setitimer(signal.ITIMER_REAL, *self._previous_timer)

        # Restore the previous signal handler. `signal.signal` rejects None (returned
        # by `getsignal` for handlers not installed from Python), so fall back to the
        # default handler in that case.
        previous_handler = self._previous_alarm_handler
        if previous_handler is None:
            previous_handler = signal.SIG_DFL
        signal.signal(signal.SIGALRM, previous_handler)

        return retval

    def cancel(self, throw: bool = True):
        if not super().cancel():
            return False

        if throw:
            logger.debug("%r sending alarm signal to main thread", self)
            os.kill(os.getpid(), signal.SIGALRM)

        return True
class WatcherThreadCancelScope(CancelScope):
    """
    Cancel scope enforced by a watcher thread that injects an exception.

    The injected exception cannot interrupt blocking calls; it surfaces at roughly
    the next instruction the supervised thread executes, which may be an unexpected
    place. Use `shield` to guard code that must not be interrupted.

    When a timeout is given, a watcher thread waits `timeout` seconds and then sends
    the exception to the thread that entered the scope.
    """

    def __enter__(self):
        super().__enter__()
        self._event = threading.Event()
        self._enforcer_thread = None
        self._supervised_thread = threading.current_thread()

        if self.timeout is None:
            return self

        label = self.name or f"for scope {hex(id(self))}"
        enforcer = threading.Thread(
            target=self._timeout_enforcer,
            name=f"timeout-watcher {label} {self.timeout:.2f}",
        )
        self._enforcer_thread = enforcer
        enforcer.start()
        return self

    def __exit__(self, *_: Any) -> Optional[bool]:
        retval = super().__exit__(*_)
        # Release the enforcer from its waits, then wait for it to finish.
        self._event.set()
        enforcer = self._enforcer_thread
        if enforcer:
            logger.debug("%r joining enforcer thread %r", self, enforcer)
            enforcer.join()
        return retval

    def _send_cancelled_error(self):
        """
        Inject `CancelledError` into the supervised thread if it is still alive.
        """
        target = self._supervised_thread
        if not target.is_alive():
            return

        logger.debug("%r sending exception to supervised thread %r", self, target)
        # Entering the target's shield from this thread blocks until the target is
        # no longer shielded.
        with _get_thread_shield(target):
            try:
                _send_exception_to_thread(target, CancelledError)
            except ValueError:
                # If the thread is gone; just move on without error
                logger.debug("Thread missing!")

    def _timeout_enforcer(self):
        """
        Watcher-thread body: cancel the scope if it outlives its timeout.
        """
        timed_out = not self._event.wait(self.timeout)
        if timed_out:
            logger.debug("%r enforcer detected timeout!", self)
            if self.cancel(throw=False):
                with _get_thread_shield(self._supervised_thread):
                    self._send_cancelled_error()

        # Wait for the supervised thread to exit its context
        logger.debug("%r waiting for supervised thread to exit", self)
        self._event.wait()

    def cancel(self, throw: bool = True):
        was_cancelled = super().cancel()
        if was_cancelled and throw:
            self._send_cancelled_error()
        return was_cancelled
@overload
def get_deadline(timeout: float) -> float: ...


@overload
def get_deadline(timeout: None) -> None: ...


def get_deadline(timeout: Optional[float]) -> Optional[float]:
    """
    Convert a relative timeout into an absolute deadline.

    The deadline is expressed on the `time.monotonic` clock. A `None` timeout
    (no timeout) maps to a `None` deadline.
    """
    return None if timeout is None else time.monotonic() + timeout
def get_timeout(deadline: Optional[float]):
    """
    Convert an absolute `time.monotonic` deadline into the time remaining until it.

    Returns `None` when there is no deadline. The result is never negative: a
    deadline that has already passed yields `0`.
    """
    if deadline is not None:
        remaining = deadline - time.monotonic()
        return max(0, remaining)
    return None
@contextlib.contextmanager
def cancel_async_at(deadline: Optional[float], name: Optional[str] = None):
    """
    Cancel async calls within the context if it has not exited by `deadline`.

    `deadline` must come from the monotonic clock (see `get_deadline`). It is
    converted to a relative timeout and passed to `cancel_async_after`.

    Cancellation takes effect on the next `await` after the deadline passes.

    Yields the `AsyncCancelScope` created by `cancel_async_after`.
    """
    remaining = get_timeout(deadline)
    with cancel_async_after(remaining, name=name) as scope:
        yield scope
@contextlib.contextmanager
def cancel_async_after(timeout: Optional[float], name: Optional[str] = None):
    """
    Cancel async calls within the context if it has not exited after `timeout` seconds.

    Cancellation takes effect on the next `await` once the timeout expires.

    Yields the entered `AsyncCancelScope`.
    """
    scope = AsyncCancelScope(timeout=timeout, name=name)
    with scope as ctx:
        yield ctx
@contextlib.contextmanager
def cancel_sync_at(deadline: Optional[float], name: Optional[str] = None):
    """
    Cancel any sync calls within the context if it does not exit by the given deadline.

    Deadlines must be computed with the monotonic clock. See `get_deadline`.

    The cancel method varies depending on whether this is called in the main thread or
    not. See `cancel_sync_after` for details.

    Yields the cancel scope chosen by `cancel_sync_after`.
    """
    # Reuse the shared deadline -> timeout conversion (clamped at zero, None stays
    # None) instead of duplicating it inline, matching `cancel_async_at`.
    with cancel_sync_after(get_timeout(deadline), name=name) as ctx:
        yield ctx
@contextlib.contextmanager
def cancel_sync_after(timeout: Optional[float], name: Optional[str] = None):
    """
    Cancel any sync calls within the context if it does not exit after the given
    timeout.

    The timeout method varies depending on if this is called in the main thread or not.
    See `AlarmCancelScope` and `WatcherThreadCancelScope` for details.

    Yields the scope in use: an `AlarmCancelScope` or `WatcherThreadCancelScope`
    (entered for the duration of the context), or on Windows a `NullCancelScope`
    that cannot cancel anything and is not entered.
    """

    # Checked first: the code below relies on `signal.SIGALRM`.
    if sys.platform.startswith("win"):
        # Keep the caller's name and timeout so the null scope's warnings identify it.
        yield NullCancelScope(
            name=name,
            timeout=timeout,
            reason="cancellation is not supported on Windows",
        )
        return

    thread = threading.current_thread()
    existing_alarm_handler = signal.getsignal(signal.SIGALRM) != signal.SIG_DFL

    if (
        thread is threading.main_thread()
        # Avoid nested alarm handlers; it's hard to follow and they will interfere with
        # each other
        and not existing_alarm_handler
        # Avoid using an alarm when there is no timeout; it's better saved for that case
        and timeout is not None
    ):
        scope = AlarmCancelScope(name=name, timeout=timeout)
    else:
        scope = WatcherThreadCancelScope(name=name, timeout=timeout)

    with scope:
        yield scope
589def _send_exception_to_thread(thread: threading.Thread, exc_type: type[BaseException]): 1a
590 """
591 Raise an exception in a thread.
593 This will not interrupt long-running system calls like `sleep` or `wait`.
594 """
595 if not thread.ident:
596 raise ValueError("Thread is not started.")
597 ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
598 ctypes.c_long(thread.ident), ctypes.py_object(exc_type)
599 )
600 if ret == 0:
601 raise ValueError("Thread not found.")