Coverage for /usr/local/lib/python3.12/site-packages/prefect/utilities/asyncutils.py: 43%
225 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1"""
2Utilities for interoperability with async functions and workers from various contexts.
3"""
5import asyncio 1a
6import inspect 1a
7import threading 1a
8import warnings 1a
9from collections.abc import AsyncGenerator, Awaitable, Coroutine 1a
10from contextlib import AbstractAsyncContextManager, asynccontextmanager 1a
11from contextvars import ContextVar 1a
12from functools import partial, wraps 1a
13from logging import Logger 1a
14from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Union, overload 1a
15from uuid import UUID, uuid4 1a
17import anyio 1a
18import anyio.abc 1a
19import anyio.from_thread 1a
20import anyio.to_thread 1a
21import sniffio 1a
22from typing_extensions import ( 1a
23 Literal,
24 ParamSpec,
25 Self,
26 TypeAlias,
27 TypeGuard,
28 TypeVar,
29 TypeVarTuple,
30 Unpack,
31)
33from prefect._internal.concurrency.api import cast_to_call, from_sync 1a
34from prefect._internal.concurrency.threads import ( 1a
35 get_run_sync_loop,
36 in_run_sync_loop,
37)
38from prefect.logging import get_logger 1a
# Generic type variables shared by the helpers in this module.
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)
F = TypeVar("F", bound=Callable[..., Any])

# Literal markers used to parameterize types over sync vs. async behavior.
Async = Literal[True]
Sync = Literal[False]
A = TypeVar("A", Async, Sync, covariant=True)
PosArgsT = TypeVarTuple("PosArgsT")

# A callable that may be either synchronous or asynchronous.
_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[R, Awaitable[R]]]

# Global references to prevent garbage collection for `add_event_loop_shutdown_callback`
EVENT_LOOP_GC_REFS: dict[int, AsyncGenerator[None, Any]] = {}

# Context flags used to detect re-entrant use of the "run sync" loop (and to let
# children know they are already running async) so we can avoid deadlocks.
RUNNING_IN_RUN_SYNC_LOOP_FLAG = ContextVar("running_in_run_sync_loop", default=False)
RUNNING_ASYNC_FLAG = ContextVar("run_async", default=False)

# Strong references to "fire and forget" tasks created by `create_task`; the lock
# guards mutation of the set from multiple threads.
BACKGROUND_TASKS: set[asyncio.Task[Any]] = set()
background_task_lock: threading.Lock = threading.Lock()

# Thread-local storage to keep track of worker thread state
_thread_local = threading.local()

logger: Logger = get_logger()

# Lazily created shared limiter for worker threads; see `get_thread_limiter`.
_prefect_thread_limiter: Optional[anyio.CapacityLimiter] = None
def get_thread_limiter() -> anyio.CapacityLimiter:
    """Return the module-wide worker-thread capacity limiter, creating it on first use."""
    global _prefect_thread_limiter

    if _prefect_thread_limiter is None:
        # First caller creates the shared limiter: at most 250 concurrent threads.
        _prefect_thread_limiter = anyio.CapacityLimiter(250)

    return _prefect_thread_limiter
def is_async_fn(
    func: _SyncOrAsyncCallable[P, R],
) -> TypeGuard[Callable[P, Coroutine[Any, Any, Any]]]:
    """
    Returns `True` if a function returns a coroutine.

    See https://github.com/microsoft/pyright/issues/2142 for an example use
    """
    # Look through decorator wrappers so we inspect the innermost callable.
    return inspect.iscoroutinefunction(inspect.unwrap(func))
def is_async_gen_fn(
    func: Callable[P, Any],
) -> TypeGuard[Callable[P, AsyncGenerator[Any, Any]]]:
    """
    Returns `True` if a function is an async generator.
    """
    # Unwrap any decorators before checking the underlying function.
    return inspect.isasyncgenfunction(inspect.unwrap(func))
def create_task(coroutine: Coroutine[Any, Any, R]) -> asyncio.Task[R]:
    """
    Replacement for asyncio.create_task that will ensure that tasks aren't
    garbage collected before they complete. Allows for "fire and forget"
    behavior in which tasks can be created and the application can move on.
    Tasks can also be awaited normally.

    See https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
    for details (and essentially this implementation)
    """
    new_task = asyncio.create_task(coroutine)

    # Hold a strong reference so the event loop cannot garbage collect the task
    # mid-flight; the lock guards registration from multiple threads.
    with background_task_lock:
        BACKGROUND_TASKS.add(new_task)

    # Drop the reference once the task completes so the set does not grow forever.
    new_task.add_done_callback(BACKGROUND_TASKS.discard)

    return new_task
# Overloads: with `wait_for_result=True` (the default) the coroutine's result is
# returned; with `wait_for_result=False` the coroutine is merely scheduled and the
# caller receives `None` (see the implementation's docstring).
@overload
def run_coro_as_sync(
    coroutine: Coroutine[Any, Any, R],
    *,
    force_new_thread: bool = ...,
    wait_for_result: Literal[True] = ...,
) -> R: ...


@overload
def run_coro_as_sync(
    coroutine: Coroutine[Any, Any, R],
    *,
    force_new_thread: bool = ...,
    # Fixed: this overload previously claimed `-> R`, but the implementation
    # returns `None` when `wait_for_result=False`.
    wait_for_result: Literal[False] = False,
) -> None: ...
def run_coro_as_sync(
    coroutine: Coroutine[Any, Any, R],
    *,
    force_new_thread: bool = False,
    wait_for_result: bool = True,
) -> Optional[R]:
    """
    Runs a coroutine from a synchronous context, as if it were a synchronous
    function.

    The coroutine is scheduled to run in the "run sync" event loop, which is
    running in its own thread and is started the first time it is needed. This
    allows us to share objects like async httpx clients among all coroutines
    running in the loop.

    If run_sync is called from within the run_sync loop, it will run the
    coroutine in a new thread, because otherwise a deadlock would occur. Note
    that this behavior should not appear anywhere in the Prefect codebase or in
    user code.

    Args:
        coroutine (Awaitable): The coroutine to be run as a synchronous function.
        force_new_thread (bool, optional): If True, the coroutine will always be run in a new thread.
            Defaults to False.
        wait_for_result (bool, optional): If True, the function will wait for the coroutine to complete
            and return the result. If False, the function will submit the coroutine to the "run sync"
            event loop and return immediately, where it will eventually be run. Defaults to True.

    Returns:
        The result of the coroutine if wait_for_result is True, otherwise None.
    """

    async def coroutine_wrapper() -> Optional[R]:
        """
        Set flags so that children (and grandchildren...) of this task know they are running in a new
        thread and do not try to run on the run_sync thread, which would cause a
        deadlock.
        """
        token1 = RUNNING_IN_RUN_SYNC_LOOP_FLAG.set(True)
        token2 = RUNNING_ASYNC_FLAG.set(True)
        try:
            # use `asyncio.create_task` because it copies context variables automatically
            task = create_task(coroutine)
            if wait_for_result:
                return await task
            # when not waiting, the task keeps running in the background and the
            # wrapper falls through, returning None
        finally:
            RUNNING_IN_RUN_SYNC_LOOP_FLAG.reset(token1)
            RUNNING_ASYNC_FLAG.reset(token2)

    # if we are already in the run_sync loop, or a descendent of a coroutine
    # that is running in the run_sync loop, we need to run this coroutine in a
    # new thread
    if in_run_sync_loop() or RUNNING_IN_RUN_SYNC_LOOP_FLAG.get() or force_new_thread:
        result = from_sync.call_in_new_thread(coroutine_wrapper)
        return result

    # otherwise, we can run the coroutine in the run_sync loop
    # and wait for the result
    else:
        call = cast_to_call(coroutine_wrapper)
        runner = get_run_sync_loop()
        runner.submit(call)
        try:
            return call.result()
        except KeyboardInterrupt:
            # Propagate the interrupt to the caller, but cancel the in-flight
            # coroutine first so it does not keep running on the loop thread.
            call.cancel()

            logger.debug("Coroutine cancelled due to KeyboardInterrupt.")
            raise
async def run_sync_in_worker_thread(
    __fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
    """
    Runs a sync function in a new worker thread so that the main thread's event loop
    is not blocked.

    Unlike the anyio function, this defaults to a cancellable thread and does not allow
    passing arguments to the anyio function so users can pass kwargs to their function.

    Note that cancellation of threads will not result in interrupted computation, the
    thread may continue running — the outcome will just be ignored.
    """
    # Mark that we are no longer in an async context so any root sync-compatible
    # functions executed in the worker run their sync code paths.
    token = RUNNING_ASYNC_FLAG.set(False)
    try:
        bound_fn = partial(__fn, *args, **kwargs)
        return await anyio.to_thread.run_sync(
            call_with_mark,
            bound_fn,
            abandon_on_cancel=True,
            limiter=get_thread_limiter(),
        )
    finally:
        RUNNING_ASYNC_FLAG.reset(token)
def call_with_mark(call: Callable[..., R]) -> R:
    """Tag the current thread as a worker thread, then invoke and return `call()`."""
    mark_as_worker_thread()
    return call()
def run_async_from_worker_thread(
    __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs
) -> R:
    """
    Runs an async function in the main thread's event loop, blocking the worker
    thread until completion
    """
    # Bind the arguments up front and hand the call to anyio's portal machinery.
    return anyio.from_thread.run(partial(__fn, *args, **kwargs))
def run_async_in_new_loop(
    __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs
) -> R:
    """Run an async function on a brand-new event loop and return its result."""
    bound = partial(__fn, *args, **kwargs)
    return anyio.run(bound)
def mark_as_worker_thread() -> None:
    """Record on thread-local storage that the current thread is a worker thread."""
    _thread_local.is_worker_thread = True
def in_async_worker_thread() -> bool:
    """Report whether the current thread was marked as an async worker thread."""
    return getattr(_thread_local, "is_worker_thread", False)
def in_async_main_thread() -> bool:
    """Return True when called from the thread running an async event loop."""
    try:
        sniffio.current_async_library()
    except sniffio.AsyncLibraryNotFoundError:
        # No async library is running in this thread at all.
        return False
    # An async library is running, but we may be in a worker thread rather than
    # the main (event loop) thread.
    return not in_async_worker_thread()
def sync_compatible(
    async_fn: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]:
    """
    Converts an async function into a dual async and sync function.

    When the returned function is called, we will attempt to determine the best way
    to enter the async function.

    - If in a thread with a running event loop, we will return the coroutine for the
      caller to await. This is normal async behavior.
    - If in a blocking worker thread with access to an event loop in another thread, we
      will submit the async method to the event loop.
    - If we cannot find an event loop, we will create a new one and run the async method
      then tear down the loop.

    Note: Type checkers will infer functions decorated with `@sync_compatible` are synchronous. If
    you want to use the decorated function in an async context, you will need to ignore the types
    and "cast" the return type to a coroutine. For example:

    ```
    python result: Coroutine = sync_compatible(my_async_function)(arg1, arg2) # type: ignore
    ```
    """

    @wraps(async_fn)
    def coroutine_wrapper(
        *args: Any, _sync: Optional[bool] = None, **kwargs: Any
    ) -> Union[R, Coroutine[Any, Any, R]]:
        # Imported here to avoid a circular import at module load time.
        from prefect.context import MissingContextError, get_run_context

        # Explicit request for async behavior: hand back the raw coroutine.
        if _sync is False:
            return async_fn(*args, **kwargs)

        is_async = True

        # if _sync is set, we do as we're told
        # otherwise, we make some determinations
        if _sync is None:
            try:
                run_ctx = get_run_context()
                parent_obj = getattr(run_ctx, "task", None)
                if not parent_obj:
                    parent_obj = getattr(run_ctx, "flow", None)
                # Mirror the async-ness of the enclosing task/flow, defaulting
                # to async when it cannot be determined.
                is_async = getattr(parent_obj, "isasync", True)
            except MissingContextError:
                # not in an execution context, make best effort to
                # decide whether to syncify
                try:
                    asyncio.get_running_loop()
                    is_async = True
                except RuntimeError:
                    is_async = False

        async def ctx_call():
            """
            Wrapper that is submitted using copy_context().run to ensure
            mutations of RUNNING_ASYNC_FLAG are tightly scoped to this coroutine's frame.
            """
            token = RUNNING_ASYNC_FLAG.set(True)
            try:
                result = await async_fn(*args, **kwargs)
            finally:
                RUNNING_ASYNC_FLAG.reset(token)
            return result

        if _sync is True:
            # Explicit request for sync behavior: block on the coroutine.
            return run_coro_as_sync(ctx_call())
        elif RUNNING_ASYNC_FLAG.get() or is_async:
            # Already async (or asked to be): return the coroutine to await.
            return ctx_call()
        else:
            return run_coro_as_sync(ctx_call())

    if is_async_fn(async_fn):
        wrapper = coroutine_wrapper
    elif is_async_gen_fn(async_fn):
        raise ValueError("Async generators cannot yet be marked as `sync_compatible`")
    else:
        raise TypeError("The decorated function must be async.")

    # Expose the undecorated coroutine function for callers that want to await
    # it directly, bypassing the sync/async determination above.
    wrapper.aio = async_fn  # type: ignore
    return wrapper
@overload
def asyncnullcontext(
    value: None = None, *args: Any, **kwargs: Any
) -> AbstractAsyncContextManager[None, None]: ...


@overload
def asyncnullcontext(
    value: R, *args: Any, **kwargs: Any
) -> AbstractAsyncContextManager[R, None]: ...


@asynccontextmanager
async def asyncnullcontext(
    value: Optional[R] = None, *args: Any, **kwargs: Any
) -> AsyncGenerator[Any, Optional[R]]:
    """
    An async analogue of `contextlib.nullcontext`: performs no setup or
    teardown and simply yields `value`. Extra positional/keyword arguments
    are accepted and ignored.
    """
    yield value
def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
    """
    Call an async function from a synchronous context. Block until completion.

    If in an asynchronous context, we will run the code in a separate loop instead of
    failing but a warning will be displayed since this is not recommended.
    """
    if in_async_main_thread():
        warnings.warn(
            "`sync` called from an asynchronous context; "
            "you should `await` the async function directly instead."
        )
        # Run the coroutine on a fresh loop in another thread via a blocking portal.
        with anyio.from_thread.start_blocking_portal() as portal:
            return portal.call(partial(__async_fn, *args, **kwargs))

    if in_async_worker_thread():
        # In a sync context but we can access the event loop thread; send the async
        # call to the parent
        return run_async_from_worker_thread(__async_fn, *args, **kwargs)

    # In a sync context and there is no event loop; just create an event loop
    # to run the async code then tear it down
    return run_async_in_new_loop(__async_fn, *args, **kwargs)
async def add_event_loop_shutdown_callback(
    coroutine_fn: Callable[[], Awaitable[Any]],
) -> None:
    """
    Adds a callback to the given callable on event loop closure. The callable must be
    a coroutine function. It will be awaited when the current event loop is shutting
    down.

    Requires use of `asyncio.run()` which waits for async generator shutdown by
    default or explicit call of `asyncio.shutdown_asyncgens()`. If the application
    is entered with `asyncio.run_until_complete()` and the user calls
    `asyncio.close()` without the generator shutdown call, this will not trigger
    callbacks.

    asyncio does not provide _any_ other way to clean up a resource when the event
    loop is about to close.
    """

    async def on_shutdown(key: int) -> AsyncGenerator[None, Any]:
        # It appears that EVENT_LOOP_GC_REFS is somehow being garbage collected early.
        # We hold a reference to it so as to preserve it, at least for the lifetime of
        # this coroutine. See the issue below for the initial report/discussion:
        # https://github.com/PrefectHQ/prefect/issues/7709#issuecomment-1560021109
        _ = EVENT_LOOP_GC_REFS
        try:
            yield
        except GeneratorExit:
            # The event loop is shutting down async generators — run the callback.
            await coroutine_fn()
            # Remove self from the garbage collection set
            EVENT_LOOP_GC_REFS.pop(key)

    # Create the iterator and store it in a global variable so it is not garbage
    # collected. If the iterator is garbage collected before the event loop closes, the
    # callback will not run. Since this function does not know the scope of the event
    # loop that is calling it, a reference with global scope is necessary to ensure
    # garbage collection does not occur until after event loop closure.
    key = id(on_shutdown)
    EVENT_LOOP_GC_REFS[key] = on_shutdown(key)

    # Begin iterating so it will be cleaned up as an incomplete generator
    try:
        await EVENT_LOOP_GC_REFS[key].__anext__()
    # There is a poorly understood edge case we've seen in CI where the key is
    # removed from the dict before we begin generator iteration.
    except KeyError:
        # Fixed: removed a redundant `pass` after this statement and a stray
        # trailing space in the message.
        logger.warning("The event loop shutdown callback was not properly registered.")
class GatherIncomplete(RuntimeError):
    """Raised when a gather result is requested before its task group has finished."""
class GatherTaskGroup(anyio.abc.TaskGroup):
    """
    A task group that gathers results.

    AnyIO does not include `gather` support. This class extends the `TaskGroup`
    interface to allow simple gathering.

    See https://github.com/agronholm/anyio/issues/100

    This class should be instantiated with `create_gather_task_group`.
    """

    def __init__(self, task_group: anyio.abc.TaskGroup):
        # Maps each task's key to its eventual result.
        self._results: dict[UUID, Any] = {}
        # The concrete task group implementation to use
        self._task_group: anyio.abc.TaskGroup = task_group

    async def _run_and_store(
        self,
        key: UUID,
        fn: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
    ) -> None:
        # Await the callable and record its return value under `key`.
        self._results[key] = await fn(*args)

    def start_soon(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> UUID:
        """Schedule `func` on the group and return a key for fetching its result."""
        key = uuid4()
        # Seed with a sentinel so retrieval before completion is detectable.
        self._results[key] = GatherIncomplete
        self._task_group.start_soon(self._run_and_store, key, func, *args, name=name)
        return key

    async def start(self, func: object, *args: object, name: object = None) -> NoReturn:
        """
        Since `start` returns the result of `task_status.started()` but here we must
        return the key instead, we just won't support this method for now.
        """
        raise RuntimeError("`GatherTaskGroup` does not support `start`.")

    def get_result(self, key: UUID) -> Any:
        """Return the stored result for `key`, raising if the task has not finished."""
        outcome = self._results[key]
        if outcome is GatherIncomplete:
            raise GatherIncomplete(
                "Task is not complete. "
                "Results should not be retrieved until the task group exits."
            )
        return outcome

    async def __aenter__(self) -> Self:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(self, *tb: Any) -> Optional[bool]:  # pyright: ignore[reportIncompatibleMethodOverride]
        try:
            return await self._task_group.__aexit__(*tb)
        finally:
            # Drop the inner task group so this gather group cannot be reused.
            del self._task_group
def create_gather_task_group() -> GatherTaskGroup:
    """Create a new task group that gathers results"""
    # This function matches the AnyIO API which uses callables since the concrete
    # task group class depends on the async library being used and cannot be
    # determined until runtime
    inner = anyio.create_task_group()
    return GatherTaskGroup(inner)
async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> list[T]:
    """
    Run calls concurrently and gather their results.

    Unlike `asyncio.gather` this expects to receive _callables_ not _coroutines_.
    This matches `anyio` semantics.
    """
    async with create_gather_task_group() as tg:
        # Kick off every call and remember its result key.
        keys = [tg.start_soon(call) for call in calls]
    # The group has exited, so every result is available now.
    return [tg.get_result(key) for key in keys]
class LazySemaphore:
    """
    An async context manager wrapping an `asyncio.Semaphore` whose initial value
    is computed lazily, on first entry, by calling `initial_value_func`.
    """

    def __init__(self, initial_value_func: Callable[[], int]) -> None:
        # Created on first use so the value function runs as late as possible.
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._initial_value_func = initial_value_func

    async def __aenter__(self) -> asyncio.Semaphore:
        self._initialize_semaphore()
        if TYPE_CHECKING:
            assert self._semaphore is not None
        await self._semaphore.__aenter__()
        return self._semaphore

    async def __aexit__(self, *args: Any) -> None:
        if TYPE_CHECKING:
            assert self._semaphore is not None
        await self._semaphore.__aexit__(*args)

    def _initialize_semaphore(self) -> None:
        # Idempotent: only the first call constructs the semaphore.
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self._initial_value_func())