Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/calls.py: 48%
241 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1"""
2Implementation of the `Call` data structure for transport of deferred function calls
3and low-level management of call execution.
4"""
6import abc 1d
7import asyncio 1d
8import concurrent.futures 1d
9import contextlib 1d
10import contextvars 1d
11import dataclasses 1d
12import inspect 1d
13import threading 1d
14import weakref 1d
15from collections.abc import Awaitable, Generator 1d
16from concurrent.futures._base import ( 1d
17 CANCELLED,
18 CANCELLED_AND_NOTIFIED,
19 FINISHED,
20 RUNNING,
21)
22from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union 1d
24from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, TypeVarTuple 1d
26from prefect._internal.concurrency import logger 1d
27from prefect._internal.concurrency.cancellation import ( 1d
28 AsyncCancelScope,
29 CancelledError,
30 cancel_async_at,
31 cancel_sync_at,
32 get_deadline,
33)
34from prefect._internal.concurrency.event_loop import get_running_loop 1d
# Type variables shared by the call / future machinery in this module.
# `infer_variance=True` asks type checkers to infer T's variance from usage
# rather than requiring an explicit covariant/contravariant declaration.
T = TypeVar("T", infer_variance=True)
Ts = TypeVarTuple("Ts")
P = ParamSpec("P")

# A callable taking parameters `P` that returns either a `T` directly (sync
# function) or an awaitable resolving to `T` (async function).
_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]]


# Tracks the current call being executed. Note that storing the `Call`
# object for an async call directly in the contextvar appears to create a
# memory leak, despite the fact that we `reset` when leaving the context
# that sets this contextvar. A weakref avoids the leak and works because a)
# we already have strong references to the `Call` objects in other places
# and b) this is used for performance optimizations where we have fallback
# behavior if this weakref is garbage collected. A fix for issue #10952.
current_call: contextvars.ContextVar["weakref.ref[Call[Any]]"] = (  # novm
    contextvars.ContextVar("current_call")
)

# Create a strong reference to tasks to prevent destruction during execution errors.
# `Call.run` adds every task it schedules on a running loop to this set and
# removes it again from the task's done callback.
_ASYNC_TASK_REFS: set[asyncio.Task[None]] = set()
@contextlib.contextmanager
def set_current_call(call: "Call[Any]") -> Generator[None, Any, None]:
    """
    Mark `call` as the call currently executing for the duration of the block.

    Only a weak reference to `call` is placed in the `current_call` context
    variable. The previous value is restored on exit, even if the body raises.
    """
    call_ref = weakref.ref(call)
    reset_token = current_call.set(call_ref)
    try:
        yield
    finally:
        # Undo our `set` using its token so nested usages unwind correctly
        current_call.reset(reset_token)
class Future(concurrent.futures.Future[T]):
    """
    A `concurrent.futures.Future` that can also be cancelled while RUNNING.

    Cancellation of a running future is delegated to the cancel scope that is
    active while a deadline is enforced (see `enforce_async_deadline` and
    `enforce_sync_deadline`). `result()` raises Prefect's `CancelledError`
    rather than the stdlib one.

    Used by `Call`.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        # Label handed to the cancel scopes created for this future
        self._name = name
        # Populated while a deadline is enforced; cleared in `_invoke_callbacks`
        self._cancel_scope = None
        self._deadline = None
        self._timed_out = False
        # Callbacks to run *before* the future transitions to cancelled
        self._cancel_callbacks: list[Callable[[], None]] = []

    def set_running_or_notify_cancel(self, timeout: Optional[float] = None):
        """
        Record the deadline derived from `timeout`, then let the stdlib move the
        future to RUNNING.

        Returns the stdlib result: False if the future had been cancelled, True
        once it is marked as running.
        """
        deadline = get_deadline(timeout)
        self._deadline = deadline
        return super().set_running_or_notify_cancel()

    @contextlib.contextmanager
    def enforce_async_deadline(self) -> Generator[AsyncCancelScope]:
        """
        Run the body inside an async cancel scope bounded by this future's
        deadline, yielding that scope.
        """
        deadline_cm = cancel_async_at(self._deadline, name=self._name)
        with deadline_cm as entered_scope:
            self._cancel_scope = entered_scope
            # Replay callbacks that were registered before any scope existed
            for pending in self._cancel_callbacks:
                self._cancel_scope.add_cancel_callback(pending)
            yield self._cancel_scope

    @contextlib.contextmanager
    def enforce_sync_deadline(self):
        """
        Run the body inside a sync cancel scope bounded by this future's
        deadline, yielding that scope.
        """
        deadline_cm = cancel_sync_at(self._deadline, name=self._name)
        with deadline_cm as entered_scope:
            self._cancel_scope = entered_scope
            # Replay callbacks that were registered before any scope existed
            for pending in self._cancel_callbacks:
                self._cancel_scope.add_cancel_callback(pending)
            yield self._cancel_scope

    def add_cancel_callback(self, callback: Callable[[], Any]) -> None:
        """
        Add a callback to be enforced on cancellation.

        Unlike "done" callbacks, this callback will be invoked _before_ the future is
        cancelled. If added after the future is cancelled, nothing will happen.
        """
        # If we were to invoke cancel callbacks the same as "done" callbacks, we
        # would not propagate chained cancellation in waiters in time to actually
        # interrupt calls.
        live_scope = self._cancel_scope
        if live_scope:
            # Hook into the scope that is currently enforcing the deadline
            live_scope.add_cancel_callback(callback)

        # Always remember it: later scopes replay this list, and `cancel()` runs
        # it directly when no scope is present
        self._cancel_callbacks.append(callback)

    def timedout(self) -> bool:
        """
        Return the timed-out flag, read under the future's condition lock.

        The flag is set externally (by `Call`) when a deadline elapses.
        """
        with self._condition:
            flag = self._timed_out
        return flag

    def cancel(self) -> bool:
        """Cancel the future if possible.

        Returns True if the future was cancelled, False otherwise. A future cannot be
        cancelled if it has already completed.
        """
        with self._condition:
            # Unlike the stdlib, we allow attempted cancellation of RUNNING futures
            if self._state == RUNNING:
                scope = self._cancel_scope
                if scope is None:
                    # Running, but there is nothing we can use to interrupt it
                    return False
                if not scope.cancelled() and not scope.cancel():
                    # The scope declined to cancel
                    return False

            if self._state == FINISHED:
                return False

            if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                return True

            # Normally cancel callbacks are handled by the cancel scope but if there
            # is not one let's respect them still
            if not self._cancel_scope:
                for on_cancel in self._cancel_callbacks:
                    on_cancel()

            self._state = CANCELLED
            self._condition.notify_all()

        self._invoke_callbacks()
        return True

    if TYPE_CHECKING:

        def __get_result(self) -> T: ...

    def result(self, timeout: Optional[float] = None) -> T:
        """Return the result of the call that the future represents.

        Args:
            timeout: The number of seconds to wait for the result if the future
                isn't done. If None, then there is no limit on the wait time.

        Returns:
            The result of the call that the future represents.

        Raises:
            CancelledError: If the future was cancelled.
            TimeoutError: If the future didn't finish executing before the given
                timeout.
            Exception: If the call raised then that exception will be raised.
        """
        try:
            with self._condition:
                already_waited = False
                # At most two passes: check, wait once, check again
                while True:
                    if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                        # Raise Prefect cancelled error instead of
                        # `concurrent.futures._base.CancelledError`
                        raise CancelledError()
                    if self._state == FINISHED:
                        return self.__get_result()
                    if already_waited:
                        raise TimeoutError()
                    self._condition.wait(timeout)
                    already_waited = True
        finally:
            # Break a reference cycle with the exception in self._exception
            self = None

    _done_callbacks: list[Callable[[Self], object]]

    def _invoke_callbacks(self) -> None:
        """
        Invoke our done callbacks and clean up cancel scopes and cancel
        callbacks. Fixes a memory leak that hung on to Call objects,
        preventing garbage collection of Futures.

        A fix for #10952.
        """
        if self._done_callbacks:
            # Snapshot, then empty the list in place so it stops holding
            # references to the callbacks
            snapshot = list(self._done_callbacks)
            del self._done_callbacks[:]
            for done_callback in snapshot:
                try:
                    done_callback(self)
                except Exception:
                    logger.exception("exception calling callback for %r", self)

        # Release the cancel machinery so it cannot keep this future alive
        self._cancel_callbacks = []
        scope = self._cancel_scope
        if scope:
            setattr(scope, "_callbacks", [])
            self._cancel_scope = None
@dataclasses.dataclass(eq=False)
class Call(Generic[T]):
    """
    A deferred function call.

    Bundles the function, its arguments, the `contextvars.Context` to run it in,
    an optional timeout and the `Future` that receives the outcome. The
    arguments are dropped once the function has been invoked.
    """

    future: Future[T]
    fn: "_SyncOrAsyncCallable[..., T]"
    args: tuple[Any, ...]
    kwargs: dict[str, Any]
    context: contextvars.Context
    timeout: Optional[float]
    runner: Optional["Portal"] = None

    def __eq__(self, other: object) -> bool:
        """This is to avoid attempts at invalid access of args/kwargs in <3.13 stemming from the
        auto-generated __eq__ method on the dataclass.

        this will no longer be required in 3.13+, see https://github.com/python/cpython/issues/128294
        """
        if self is other:
            return True
        if not isinstance(other, Call):
            return NotImplemented

        try:
            # Attempt to access args/kwargs. If any are missing on self or other,
            # an AttributeError will be raised by the access attempt on one of them.
            s_args, s_kwargs = self.args, self.kwargs
            o_args, o_kwargs = other.args, other.kwargs
        except AttributeError:
            # If args/kwargs are missing on self or other (and self is not other),
            # they are considered not equal. This ensures that a Call with deleted
            # args/kwargs compares as different from one that still has them
            return False

        # If all args/kwargs were accessible on both, proceed with full field comparison.
        # Note: self.future == other.future will use Future's __eq__ (default is identity).
        return (
            (self.future == other.future)
            and (self.fn == other.fn)
            and (s_args == o_args)
            and (s_kwargs == o_kwargs)
            and (self.context == other.context)
            and (self.timeout == other.timeout)
            and (self.runner == other.runner)
        )

    __hash__ = None  # type: ignore

    @classmethod
    def new(
        cls,
        __fn: _SyncOrAsyncCallable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Self:
        """
        Create a call for `__fn(*args, **kwargs)`.

        The call gets a fresh `Future` named after the function, a copy of the
        current context, and no timeout.
        """
        return cls(
            future=Future(name=getattr(__fn, "__name__", str(__fn))),
            fn=__fn,
            args=args,
            kwargs=kwargs,
            context=contextvars.copy_context(),
            timeout=None,
        )

    def set_timeout(self, timeout: Optional[float] = None) -> None:
        """
        Set the timeout for the call.

        The timeout begins when the call starts.

        Raises:
            RuntimeError: If the call is already running or done.
        """
        if self.future.done() or self.future.running():
            raise RuntimeError("Timeouts cannot be added when the call has started.")

        self.timeout = timeout

    def set_runner(self, portal: "Portal") -> None:
        """
        Update the portal used to run this call.

        Raises:
            RuntimeError: If a portal has already been set.
        """
        if self.runner is not None:
            raise RuntimeError("The portal is already set for this call.")

        self.runner = portal

    def run(self) -> Optional[Awaitable[None]]:
        """
        Execute the call and place the result on the future.

        All exceptions during execution of the call are captured.

        Returns:
            A task that finishes the call, which the caller must await. This is
            returned only when `fn` returned an awaitable and `get_running_loop()`
            found a loop. Otherwise `None`, and the outcome is already on the
            future.
        """
        # Do not execute if the future is cancelled
        if not self.future.set_running_or_notify_cancel(self.timeout):
            logger.debug("Skipping execution of cancelled call %r", self)
            return None

        logger.debug(
            "Running call %r in thread %r%s",
            self,
            threading.current_thread().name,
            f" with timeout of {self.timeout}s" if self.timeout is not None else "",
        )

        coro = self.context.run(self._run_sync)

        if coro is not None:
            loop = get_running_loop()
            if loop:
                # If an event loop is available, return a task to be awaited
                # Note we must create a task for context variables to propagate
                logger.debug(
                    "Scheduling coroutine for call %r in running loop %r",
                    self,
                    loop,
                )
                task = self.context.run(loop.create_task, self._run_async(coro))

                # Prevent tasks from being garbage collected before completion
                # See https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task
                # Registering the bound `discard` (rather than a lambda closing over
                # `task` and calling `remove`) avoids a task -> callback -> task
                # reference cycle and cannot raise KeyError in the done callback.
                _ASYNC_TASK_REFS.add(task)
                task.add_done_callback(_ASYNC_TASK_REFS.discard)

                return task

            else:
                # Otherwise, execute the function here
                logger.debug("Executing coroutine for call %r in new loop", self)
                return self.context.run(asyncio.run, self._run_async(coro))

        return None

    def result(self, timeout: Optional[float] = None) -> T:
        """
        Wait for the result of the call.

        Not safe for use from asynchronous contexts.
        """
        return self.future.result(timeout=timeout)

    async def aresult(self):
        """
        Wait for the result of the call.

        For use from asynchronous contexts.
        """
        try:
            return await asyncio.wrap_future(self.future)
        except asyncio.CancelledError as exc:
            raise CancelledError() from exc

    def cancelled(self) -> bool:
        """
        Check if the call was cancelled.
        """
        return self.future.cancelled()

    def timedout(self) -> bool:
        """
        Check if the call timed out.
        """
        return self.future.timedout()

    def cancel(self) -> bool:
        """
        Attempt to cancel the call by cancelling its future.

        Returns whatever `Future.cancel` returns.
        """
        return self.future.cancel()

    def _run_sync(self) -> Optional[Awaitable[T]]:
        """
        Invoke `fn` under the sync deadline.

        If `fn` returns an awaitable, it is returned so the caller can run it via
        `_run_async`. Otherwise the result, exception or cancellation is recorded
        on the future and None is returned.
        """
        cancel_scope = None
        try:
            with set_current_call(self):
                with self.future.enforce_sync_deadline() as cancel_scope:
                    try:
                        result = self.fn(*self.args, **self.kwargs)
                    finally:
                        # Forget this call's arguments in order to free up any memory
                        # that may be referenced by them; after a call has happened,
                        # there's no need to keep a reference to them
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs

            # Return the coroutine for async execution
            if inspect.isawaitable(result):
                return result

        except CancelledError:
            # Report cancellation
            # in rare cases, enforce_sync_deadline raises CancelledError
            # prior to yielding
            if cancel_scope is None:
                self.future.cancel()
                return None
            if cancel_scope.timedout():
                setattr(self.future, "_timed_out", True)
                self.future.cancel()
            elif cancel_scope.cancelled():
                self.future.cancel()
            else:
                raise
        except BaseException as exc:
            logger.debug("Encountered exception in call %r", self, exc_info=True)
            self.future.set_exception(exc)

            # Prevent reference cycle in `exc`
            del self
        else:
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished call %r", self)  # noqa: F821

    async def _run_async(self, coro: Awaitable[T]) -> None:
        """
        Await `coro` under the async deadline and record its outcome (result,
        exception or cancellation) on the future.
        """
        cancel_scope = result = None
        try:
            with set_current_call(self):
                with self.future.enforce_async_deadline() as cancel_scope:
                    try:
                        result = await coro
                    finally:
                        # Forget this call's arguments in order to free up any memory
                        # that may be referenced by them; after a call has happened,
                        # there's no need to keep a reference to them
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs
        except CancelledError:
            # Report cancellation
            if cancel_scope is None:
                # As in `_run_sync`: the deadline context manager raised before
                # yielding a scope, so there is no scope to ask whether it timed
                # out. Without this guard the handler below would raise an
                # unrelated AttributeError.
                self.future.cancel()
                return
            if cancel_scope.timedout():
                setattr(self.future, "_timed_out", True)
                self.future.cancel()
            elif cancel_scope.cancelled():
                self.future.cancel()
            else:
                raise
        except BaseException as exc:
            logger.debug("Encountered exception in async call %r", self, exc_info=True)

            self.future.set_exception(exc)
            # Prevent reference cycle in `exc`
            del self
        else:
            # F821 ignored because Ruff gets confused about the del self above.
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished async call %r", self)  # noqa: F821

    def __call__(self) -> Union[T, Awaitable[T]]:
        """
        Execute the call and return its result.

        All exceptions during execution of the call are re-raised.

        If `run` scheduled the call as a task on a running event loop, an
        awaitable is returned instead; awaiting it yields the result.
        """
        coro = self.run()

        # Return an awaitable if in an async context
        if coro is not None:

            async def run_and_return_result() -> T:
                await coro
                return self.result()

            return run_and_return_result()
        else:
            return self.result()

    def __repr__(self) -> str:
        name = getattr(self.fn, "__name__", str(self.fn))

        try:
            args, kwargs = self.args, self.kwargs
        except AttributeError:
            # Arguments are deleted once the call has executed
            call_args = "<dropped>"
        else:
            call_args = ", ".join(
                [repr(arg) for arg in args]
                + [f"{key}={repr(val)}" for key, val in kwargs.items()]
            )

        # Enforce a maximum length
        if len(call_args) > 100:
            call_args = call_args[:100] + "..."

        return f"{name}({call_args})"
class Portal(abc.ABC):
    """
    Allows submission of calls to execute elsewhere.

    Abstract base class: concrete portals must implement `submit`. A `Call`
    can record the portal that runs it through `Call.set_runner`.
    """

    @abc.abstractmethod
    def submit(self, call: "Call[T]") -> "Call[T]":
        """
        Submit a call to execute elsewhere.

        The call's result can be retrieved with `call.result()`.

        Returns the call for convenience.
        """