Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/calls.py: 21%
241 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1"""
2Implementation of the `Call` data structure for transport of deferred function calls
3and low-level management of call execution.
4"""
6import abc 1a
7import asyncio 1a
8import concurrent.futures 1a
9import contextlib 1a
10import contextvars 1a
11import dataclasses 1a
12import inspect 1a
13import threading 1a
14import weakref 1a
15from collections.abc import Awaitable, Generator 1a
16from concurrent.futures._base import ( 1a
17 CANCELLED,
18 CANCELLED_AND_NOTIFIED,
19 FINISHED,
20 RUNNING,
21)
22from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union 1a
24from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, TypeVarTuple 1a
26from prefect._internal.concurrency import logger 1a
27from prefect._internal.concurrency.cancellation import ( 1a
28 AsyncCancelScope,
29 CancelledError,
30 cancel_async_at,
31 cancel_sync_at,
32 get_deadline,
33)
34from prefect._internal.concurrency.event_loop import get_running_loop 1a
# Result type shared by `Future` and `Call`; variance is inferred rather than
# declared explicitly.
T = TypeVar("T", infer_variance=True)
# Variadic type variable; not referenced in the visible portion of this module.
Ts = TypeVarTuple("Ts")
# Captures the parameters of the callable wrapped by `Call.new`.
P = ParamSpec("P")

# A callable that returns either a plain value or an awaitable of that value.
_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]]


# Tracks the current call being executed. Note that storing the `Call`
# object for an async call directly in the contextvar appears to create a
# memory leak, despite the fact that we `reset` when leaving the context
# that sets this contextvar. A weakref avoids the leak and works because a)
# we already have strong references to the `Call` objects in other places
# and b) this is used for performance optimizations where we have fallback
# behavior if this weakref is garbage collected. A fix for issue #10952.
# Readers must therefore handle the weak reference resolving to `None`.
current_call: contextvars.ContextVar["weakref.ref[Call[Any]]"] = (  # novm
    contextvars.ContextVar("current_call")
)

# Create a strong reference to tasks to prevent destruction during execution errors.
# `Call.run` adds each scheduled task here and drops it again once the task is done.
_ASYNC_TASK_REFS: set[asyncio.Task[None]] = set()
@contextlib.contextmanager
def set_current_call(call: "Call[Any]") -> Generator[None, Any, None]:
    """
    Publish `call` as the current call for the duration of the `with` block.

    Only a weak reference to `call` is stored in the `current_call` context
    variable. The previous value is restored on exit, including when the
    body raises.
    """
    call_ref = weakref.ref(call)
    reset_token = current_call.set(call_ref)
    try:
        yield
    finally:
        # Always restore the prior value, even if the body raised
        current_call.reset(reset_token)
class Future(concurrent.futures.Future[T]):
    """
    A `concurrent.futures.Future` that additionally supports cancelling a
    future that is already running.

    `Call` stores its outcome on an instance of this class.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        self._name = name
        self._deadline = None
        self._cancel_scope = None
        self._timed_out = False
        self._cancel_callbacks: list[Callable[[], None]] = []

    def set_running_or_notify_cancel(self, timeout: Optional[float] = None):
        """
        Record the deadline derived from `timeout` and then defer to the
        base class to mark the future as running.

        Returns whatever the base class implementation returns.
        """
        self._deadline = get_deadline(timeout)
        return super().set_running_or_notify_cancel()

    def _attach_cancel_callbacks(self) -> None:
        """Register every tracked cancel callback on the active cancel scope."""
        for callback in self._cancel_callbacks:
            self._cancel_scope.add_cancel_callback(callback)

    @contextlib.contextmanager
    def enforce_async_deadline(self) -> Generator[AsyncCancelScope]:
        """
        Enter the scope returned by `cancel_async_at` for this future's
        deadline, store it on the future, attach the tracked cancel callbacks
        and yield the scope.
        """
        with cancel_async_at(self._deadline, name=self._name) as self._cancel_scope:
            self._attach_cancel_callbacks()
            yield self._cancel_scope

    @contextlib.contextmanager
    def enforce_sync_deadline(self):
        """
        Enter the scope returned by `cancel_sync_at` for this future's
        deadline, store it on the future, attach the tracked cancel callbacks
        and yield the scope.
        """
        with cancel_sync_at(self._deadline, name=self._name) as self._cancel_scope:
            self._attach_cancel_callbacks()
            yield self._cancel_scope

    def add_cancel_callback(self, callback: Callable[[], Any]) -> None:
        """
        Register a callback to run when the future is cancelled.

        In contrast to "done" callbacks, cancel callbacks run _before_ the
        future is marked cancelled. Registering one after the future has been
        cancelled has no effect.
        """
        # Running these like "done" callbacks would be too late: chained
        # cancellation in waiters would not propagate in time to actually
        # interrupt calls.
        if self._cancel_scope:
            # A scope is already active, so hook into it right away
            self._cancel_scope.add_cancel_callback(callback)

        # Keep track of the callback as well, for scopes entered later and for
        # cancellation without a scope
        self._cancel_callbacks.append(callback)

    def timedout(self) -> bool:
        """Return the timed-out flag, read while holding the future's condition."""
        with self._condition:
            return self._timed_out

    def cancel(self) -> bool:
        """Cancel the future if possible.

        Returns True if the future was cancelled, False otherwise. A future cannot be
        cancelled if it has already completed.
        """
        with self._condition:
            # Departure from the stdlib: a RUNNING future may be cancelled, but
            # only through its cancel scope
            if self._state == RUNNING:
                if self._cancel_scope is None:
                    return False
                if not self._cancel_scope.cancelled() and not self._cancel_scope.cancel():
                    # The scope refused the cancellation
                    return False

            if self._state == FINISHED:
                return False

            if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                return True

            # The cancel scope normally takes care of the cancel callbacks; with
            # no scope in place, honor them here
            if not self._cancel_scope:
                for callback in self._cancel_callbacks:
                    callback()

            self._state = CANCELLED
            self._condition.notify_all()

        self._invoke_callbacks()
        return True

    if TYPE_CHECKING:

        def __get_result(self) -> T: ...

    def result(self, timeout: Optional[float] = None) -> T:
        """Return the result of the call that the future represents.

        Args:
            timeout: The number of seconds to wait for the result if the future
                isn't done. If None, then there is no limit on the wait time.

        Returns:
            The result of the call that the future represents.

        Raises:
            CancelledError: If the future was cancelled.
            TimeoutError: If the future didn't finish executing before the given
                timeout.
            Exception: If the call raised then that exception will be raised.
        """
        try:
            with self._condition:
                # Inspect the state at most twice: once immediately and once
                # after a single wait on the condition
                for already_waited in (False, True):
                    if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                        # Surface Prefect's cancelled error rather than
                        # `concurrent.futures._base.CancelledError`
                        raise CancelledError()
                    if self._state == FINISHED:
                        return self.__get_result()
                    if already_waited:
                        raise TimeoutError()

                    self._condition.wait(timeout)
        finally:
            # Break a reference cycle with the exception in self._exception
            self = None

    _done_callbacks: list[Callable[[Self], object]]

    def _invoke_callbacks(self) -> None:
        """
        Run the done callbacks, then drop the cancel callbacks and detach the
        cancel scope. This cleanup fixes a memory leak that hung on to Call
        objects and prevented garbage collection of Futures.

        A fix for #10952.
        """
        pending = self._done_callbacks
        if pending:
            # Snapshot and empty the list in place before invoking anything
            to_invoke = pending[:]
            pending[:] = []

            for done_callback in to_invoke:
                try:
                    done_callback(self)
                except Exception:
                    logger.exception("exception calling callback for %r", self)

        self._cancel_callbacks = []
        if self._cancel_scope:
            setattr(self._cancel_scope, "_callbacks", [])
            self._cancel_scope = None
@dataclasses.dataclass(eq=False)
class Call(Generic[T]):
    """
    A deferred function call.

    Bundles a callable with its arguments, the context to run it in, an
    optional timeout and the `Future` that receives the outcome.
    """

    # Receives the result, exception or cancellation of the call
    future: Future[T]
    # The sync or async callable to execute
    fn: "_SyncOrAsyncCallable[..., T]"
    # Positional arguments for `fn`; deleted once the call has executed
    args: tuple[Any, ...]
    # Keyword arguments for `fn`; deleted once the call has executed
    kwargs: dict[str, Any]
    # Context in which `fn` is executed
    context: contextvars.Context
    # Timeout handed to the future when the call starts; `None` means no timeout
    timeout: Optional[float]
    # Portal assigned through `set_runner`, if any
    runner: Optional["Portal"] = None

    def __eq__(self, other: object) -> bool:
        """this is to avoid attempts at invalid access of args/kwargs in <3.13 stemming from the
        auto-generated __eq__ method on the dataclass.

        this will no longer be required in 3.13+, see https://github.com/python/cpython/issues/128294
        """
        if self is other:
            return True
        if not isinstance(other, Call):
            return NotImplemented

        try:
            # Attempt to access args/kwargs. If any are missing on self or other,
            # an AttributeError will be raised by the access attempt on one of them.
            s_args, s_kwargs = self.args, self.kwargs
            o_args, o_kwargs = other.args, other.kwargs
        except AttributeError:
            # If args/kwargs are missing on self or other (and self is not other),
            # they are considered not equal. This ensures that a Call with deleted
            # args/kwargs compares as different from one that still has them
            return False

        # If all args/kwargs were accessible on both, proceed with full field comparison.
        # Note: self.future == other.future will use Future's __eq__ (default is identity).
        return (
            (self.future == other.future)
            and (self.fn == other.fn)
            and (s_args == o_args)
            and (s_kwargs == o_kwargs)
            and (self.context == other.context)
            and (self.timeout == other.timeout)
            and (self.runner == other.runner)
        )

    # Calls define `__eq__` and are mutable, so they are explicitly unhashable
    __hash__ = None  # type: ignore

    @classmethod
    def new(
        cls,
        __fn: _SyncOrAsyncCallable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Self:
        """
        Create a call for `__fn` with the given arguments.

        The call gets a fresh `Future` named after the function, a copy of the
        current context and no timeout.
        """
        return cls(
            future=Future(name=getattr(__fn, "__name__", str(__fn))),
            fn=__fn,
            args=args,
            kwargs=kwargs,
            context=contextvars.copy_context(),
            timeout=None,
        )

    def set_timeout(self, timeout: Optional[float] = None) -> None:
        """
        Set the timeout for the call.

        The timeout begins when the call starts.

        Raises:
            RuntimeError: If the call's future is already running or done.
        """
        if self.future.done() or self.future.running():
            raise RuntimeError("Timeouts cannot be added when the call has started.")

        self.timeout = timeout

    def set_runner(self, portal: "Portal") -> None:
        """
        Update the portal used to run this call.

        Raises:
            RuntimeError: If a portal has already been set.
        """
        if self.runner is not None:
            raise RuntimeError("The portal is already set for this call.")

        self.runner = portal

    def run(self) -> Optional[Awaitable[None]]:
        """
        Execute the call and place the result on the future.

        All exceptions during execution of the call are captured.

        Returns:
            A task to await when the function produced an awaitable and an
            event loop is running; otherwise `None`.
        """
        # Do not execute if the future is cancelled
        if not self.future.set_running_or_notify_cancel(self.timeout):
            logger.debug("Skipping execution of cancelled call %r", self)
            return None

        logger.debug(
            "Running call %r in thread %r%s",
            self,
            threading.current_thread().name,
            f" with timeout of {self.timeout}s" if self.timeout is not None else "",
        )

        coro = self.context.run(self._run_sync)

        if coro is not None:
            loop = get_running_loop()
            if loop:
                # If an event loop is available, return a task to be awaited
                # Note we must create a task for context variables to propagate
                logger.debug(
                    "Scheduling coroutine for call %r in running loop %r",
                    self,
                    loop,
                )
                task = self.context.run(loop.create_task, self._run_async(coro))

                # Prevent tasks from being garbage collected before completion
                # See https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task
                _ASYNC_TASK_REFS.add(task)
                # `discard` receives the finished task and never raises if the
                # reference has already been dropped
                task.add_done_callback(_ASYNC_TASK_REFS.discard)

                return task

            else:
                # Otherwise, execute the function here
                logger.debug("Executing coroutine for call %r in new loop", self)
                return self.context.run(asyncio.run, self._run_async(coro))

        return None

    def result(self, timeout: Optional[float] = None) -> T:
        """
        Wait for the result of the call.

        Not safe for use from asynchronous contexts.
        """
        return self.future.result(timeout=timeout)

    async def aresult(self) -> T:
        """
        Wait for the result of the call.

        For use from asynchronous contexts.

        Raises:
            CancelledError: Prefect's cancelled error, raised in place of
                `asyncio.CancelledError`.
        """
        try:
            return await asyncio.wrap_future(self.future)
        except asyncio.CancelledError as exc:
            raise CancelledError() from exc

    def cancelled(self) -> bool:
        """
        Check if the call was cancelled.
        """
        return self.future.cancelled()

    def timedout(self) -> bool:
        """
        Check if the call timed out.
        """
        return self.future.timedout()

    def cancel(self) -> bool:
        """
        Attempt to cancel the call; returns the result of `Future.cancel`.
        """
        return self.future.cancel()

    def _report_cancellation(self, cancel_scope: Optional[Any]) -> bool:
        """
        Record a `CancelledError` caught while running the call on the future.

        Returns True when the cancellation was attributed to this call and the
        future was cancelled, and False when the error does not belong to
        `cancel_scope`, in which case the caller should re-raise it.
        """
        if cancel_scope is None:
            # In rare cases the deadline context raises CancelledError before
            # it yields a scope; the call never started, so just cancel
            self.future.cancel()
            return True

        if cancel_scope.timedout():
            # Flag the timeout before cancelling so observers can tell them apart
            setattr(self.future, "_timed_out", True)
            self.future.cancel()
            return True

        if cancel_scope.cancelled():
            self.future.cancel()
            return True

        return False

    def _run_sync(self) -> Optional[Awaitable[T]]:
        """
        Call `fn` under the sync deadline and record the outcome on the future.

        Returns the awaitable produced by `fn`, if any, so it can be finished
        by `_run_async`; in that case nothing is set on the future here.
        """
        cancel_scope = None
        try:
            with set_current_call(self):
                with self.future.enforce_sync_deadline() as cancel_scope:
                    try:
                        result = self.fn(*self.args, **self.kwargs)
                    finally:
                        # Forget this call's arguments in order to free up any memory
                        # that may be referenced by them; after a call has happened,
                        # there's no need to keep a reference to them
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs

            # Return the coroutine for async execution
            if inspect.isawaitable(result):
                return result

        except CancelledError:
            # Report cancellation; errors not raised by our scope propagate
            if not self._report_cancellation(cancel_scope):
                raise
        except BaseException as exc:
            logger.debug("Encountered exception in call %r", self, exc_info=True)
            self.future.set_exception(exc)

            # Prevent reference cycle in `exc`
            del self
        else:
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished call %r", self)  # noqa: F821

    async def _run_async(self, coro: Awaitable[T]) -> None:
        """
        Await `coro` under the async deadline and record the outcome on the future.
        """
        cancel_scope = result = None
        try:
            with set_current_call(self):
                with self.future.enforce_async_deadline() as cancel_scope:
                    try:
                        result = await coro
                    finally:
                        # Forget this call's arguments in order to free up any memory
                        # that may be referenced by them; after a call has happened,
                        # there's no need to keep a reference to them
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs
        except CancelledError:
            # Report cancellation. `cancel_scope` is still None if the error
            # was raised before the scope was yielded; the helper handles that
            # the same way as the sync path instead of failing on `None`
            if not self._report_cancellation(cancel_scope):
                raise
        except BaseException as exc:
            logger.debug("Encountered exception in async call %r", self, exc_info=True)

            self.future.set_exception(exc)
            # Prevent reference cycle in `exc`
            del self
        else:
            # F821 ignored because Ruff gets confused about the del self above.
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished async call %r", self)  # noqa: F821

    def __call__(self) -> Union[T, Awaitable[T]]:
        """
        Execute the call and return its result.

        All exceptions raised during execution of the call are re-raised.
        """
        coro = self.run()

        # Return an awaitable if in an async context
        if coro is not None:

            async def run_and_return_result() -> T:
                await coro
                return self.result()

            return run_and_return_result()
        else:
            return self.result()

    def __repr__(self) -> str:
        """
        Render the call as `name(args)`, with the argument text capped at 100
        characters and shown as `<dropped>` once the arguments were deleted.
        """
        name = getattr(self.fn, "__name__", str(self.fn))

        try:
            args, kwargs = self.args, self.kwargs
        except AttributeError:
            call_args = "<dropped>"
        else:
            call_args = ", ".join(
                [repr(arg) for arg in args]
                + [f"{key}={repr(val)}" for key, val in kwargs.items()]
            )

        # Enforce a maximum length
        if len(call_args) > 100:
            call_args = call_args[:100] + "..."

        return f"{name}({call_args})"
class Portal(abc.ABC):
    """
    Interface for objects that accept calls and execute them elsewhere.
    """

    @abc.abstractmethod
    def submit(self, call: "Call[T]") -> "Call[T]":
        """
        Hand `call` over for execution elsewhere.

        Use `call.result()` to retrieve the outcome.

        The submitted call is returned for convenience.
        """