Coverage for /usr/local/lib/python3.12/site-packages/prefect/utilities/asyncutils.py: 52%

225 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Utilities for interoperability with async functions and workers from various contexts. 

3""" 

4 

5import asyncio 1a

6import inspect 1a

7import threading 1a

8import warnings 1a

9from collections.abc import AsyncGenerator, Awaitable, Coroutine 1a

10from contextlib import AbstractAsyncContextManager, asynccontextmanager 1a

11from contextvars import ContextVar 1a

12from functools import partial, wraps 1a

13from logging import Logger 1a

14from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Union, overload 1a

15from uuid import UUID, uuid4 1a

16 

17import anyio 1a

18import anyio.abc 1a

19import anyio.from_thread 1a

20import anyio.to_thread 1a

21import sniffio 1a

22from typing_extensions import ( 1a

23 Literal, 

24 ParamSpec, 

25 Self, 

26 TypeAlias, 

27 TypeGuard, 

28 TypeVar, 

29 TypeVarTuple, 

30 Unpack, 

31) 

32 

33from prefect._internal.concurrency.api import cast_to_call, from_sync 1a

34from prefect._internal.concurrency.threads import ( 1a

35 get_run_sync_loop, 

36 in_run_sync_loop, 

37) 

38from prefect.logging import get_logger 1a

39 

# --- Type variables and aliases used throughout this module ----------------
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)
F = TypeVar("F", bound=Callable[..., Any])
Async = Literal[True]
Sync = Literal[False]
A = TypeVar("A", Async, Sync, covariant=True)
PosArgsT = TypeVarTuple("PosArgsT")

# A callable that may be either synchronous or asynchronous.
_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[R, Awaitable[R]]]

# Global references to prevent garbage collection for `add_event_loop_shutdown_callback`
EVENT_LOOP_GC_REFS: dict[int, AsyncGenerator[None, Any]] = {}


# Context flags used to detect re-entrant calls into the "run sync" loop and
# whether we are currently executing in an async context.
RUNNING_IN_RUN_SYNC_LOOP_FLAG = ContextVar("running_in_run_sync_loop", default=False)
RUNNING_ASYNC_FLAG = ContextVar("run_async", default=False)
# Strong references to "fire and forget" tasks created by `create_task` so they
# are not garbage collected before they complete.
BACKGROUND_TASKS: set[asyncio.Task[Any]] = set()
# Guards mutation of BACKGROUND_TASKS, which may happen from multiple threads.
background_task_lock: threading.Lock = threading.Lock()

# Thread-local storage to keep track of worker thread state
_thread_local = threading.local()

logger: Logger = get_logger()


# Lazily-created, process-wide thread limiter; see `get_thread_limiter`.
_prefect_thread_limiter: Optional[anyio.CapacityLimiter] = None

67 

68 

def get_thread_limiter() -> anyio.CapacityLimiter:
    """
    Return the shared capacity limiter used for Prefect worker threads.

    The limiter is created lazily on first call (with a capacity of 250
    concurrent threads) and reused for the lifetime of the process.
    """
    global _prefect_thread_limiter

    if _prefect_thread_limiter is not None:
        return _prefect_thread_limiter

    _prefect_thread_limiter = anyio.CapacityLimiter(250)
    return _prefect_thread_limiter

76 

77 

78def is_async_fn( 1a

79 func: _SyncOrAsyncCallable[P, R], 

80) -> TypeGuard[Callable[P, Coroutine[Any, Any, Any]]]: 

81 """ 

82 Returns `True` if a function returns a coroutine. 

83 

84 See https://github.com/microsoft/pyright/issues/2142 for an example use 

85 """ 

86 func = inspect.unwrap(func) 1a

87 return inspect.iscoroutinefunction(func) 1a

88 

89 

90def is_async_gen_fn( 1a

91 func: Callable[P, Any], 

92) -> TypeGuard[Callable[P, AsyncGenerator[Any, Any]]]: 

93 """ 

94 Returns `True` if a function is an async generator. 

95 """ 

96 func = inspect.unwrap(func) 

97 return inspect.isasyncgenfunction(func) 

98 

99 

def create_task(coroutine: Coroutine[Any, Any, R]) -> asyncio.Task[R]:
    """
    Replacement for asyncio.create_task that will ensure that tasks aren't
    garbage collected before they complete. Allows for "fire and forget"
    behavior in which tasks can be created and the application can move on.
    Tasks can also be awaited normally.

    See https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
    for details (and essentially this implementation)
    """
    bg_task = asyncio.create_task(coroutine)

    # Hold a strong reference in the module-level set so the task cannot be
    # garbage collected mid-flight. Guard with a lock because tasks may be
    # created from multiple threads.
    with background_task_lock:
        BACKGROUND_TASKS.add(bg_task)

    # Drop the reference once the task finishes so completed tasks do not
    # accumulate in the set forever.
    bg_task.add_done_callback(BACKGROUND_TASKS.discard)

    return bg_task

124 

125 

@overload
def run_coro_as_sync(
    coroutine: Coroutine[Any, Any, R],
    *,
    force_new_thread: bool = ...,
    wait_for_result: Literal[True] = ...,
) -> R: ...


@overload
def run_coro_as_sync(
    coroutine: Coroutine[Any, Any, R],
    *,
    force_new_thread: bool = ...,
    # NOTE(review): when `wait_for_result=False` the implementation returns
    # `None` (the wrapper never awaits the task), so this overload is annotated
    # `-> None` rather than `-> R`.
    wait_for_result: Literal[False] = False,
) -> None: ...


def run_coro_as_sync(
    coroutine: Coroutine[Any, Any, R],
    *,
    force_new_thread: bool = False,
    wait_for_result: bool = True,
) -> Optional[R]:
    """
    Runs a coroutine from a synchronous context, as if it were a synchronous
    function.

    The coroutine is scheduled to run in the "run sync" event loop, which is
    running in its own thread and is started the first time it is needed. This
    allows us to share objects like async httpx clients among all coroutines
    running in the loop.

    If run_sync is called from within the run_sync loop, it will run the
    coroutine in a new thread, because otherwise a deadlock would occur. Note
    that this behavior should not appear anywhere in the Prefect codebase or in
    user code.

    Args:
        coroutine (Awaitable): The coroutine to be run as a synchronous function.
        force_new_thread (bool, optional): If True, the coroutine will always be run in a new thread.
            Defaults to False.
        wait_for_result (bool, optional): If True, the function will wait for the coroutine to complete
            and return the result. If False, the function will submit the coroutine to the "run sync"
            event loop and return immediately, where it will eventually be run. Defaults to True.

    Returns:
        The result of the coroutine if wait_for_result is True, otherwise None.
    """

    async def coroutine_wrapper() -> Optional[R]:
        """
        Set flags so that children (and grandchildren...) of this task know they are running in a new
        thread and do not try to run on the run_sync thread, which would cause a
        deadlock.
        """
        token1 = RUNNING_IN_RUN_SYNC_LOOP_FLAG.set(True)
        token2 = RUNNING_ASYNC_FLAG.set(True)
        try:
            # use `asyncio.create_task` because it copies context variables automatically
            task = create_task(coroutine)
            if wait_for_result:
                return await task
            # If not waiting, fall through and return None once the task is scheduled.
        finally:
            RUNNING_IN_RUN_SYNC_LOOP_FLAG.reset(token1)
            RUNNING_ASYNC_FLAG.reset(token2)

    # if we are already in the run_sync loop, or a descendent of a coroutine
    # that is running in the run_sync loop, we need to run this coroutine in a
    # new thread
    if in_run_sync_loop() or RUNNING_IN_RUN_SYNC_LOOP_FLAG.get() or force_new_thread:
        result = from_sync.call_in_new_thread(coroutine_wrapper)
        return result

    # otherwise, we can run the coroutine in the run_sync loop
    # and wait for the result
    else:
        call = cast_to_call(coroutine_wrapper)
        runner = get_run_sync_loop()
        runner.submit(call)
        try:
            # Blocks until the wrapper completes; when `wait_for_result` is False
            # the wrapper returns as soon as the task is scheduled.
            return call.result()
        except KeyboardInterrupt:
            # Cancel the pending call, then re-raise so Ctrl-C reaches the caller.
            call.cancel()

            logger.debug("Coroutine cancelled due to KeyboardInterrupt.")
            raise

213 

214 

async def run_sync_in_worker_thread(
    __fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
    """
    Runs a sync function in a new worker thread so that the main thread's event loop
    is not blocked.

    Unlike the anyio function, this defaults to a cancellable thread and does not allow
    passing arguments to the anyio function so users can pass kwargs to their function.

    Note that cancellation of threads will not result in interrupted computation, the
    thread may continue running — the outcome will just be ignored.
    """
    # Flag that we are leaving the async context so any root sync-compatible
    # functions invoked inside the worker thread will run synchronously.
    token = RUNNING_ASYNC_FLAG.set(False)
    try:
        bound_call = partial(__fn, *args, **kwargs)
        return await anyio.to_thread.run_sync(
            call_with_mark,
            bound_call,
            abandon_on_cancel=True,
            limiter=get_thread_limiter(),
        )
    finally:
        RUNNING_ASYNC_FLAG.reset(token)

239 

240 

def call_with_mark(call: Callable[..., R]) -> R:
    """Mark the current thread as a worker thread, then invoke `call` and return its result."""
    mark_as_worker_thread()
    return call()

244 

245 

def run_async_from_worker_thread(
    __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs
) -> R:
    """
    Runs an async function in the main thread's event loop, blocking the worker
    thread until completion
    """
    return anyio.from_thread.run(partial(__fn, *args, **kwargs))

255 

256 

def run_async_in_new_loop(
    __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs
) -> R:
    """Run ``__fn`` to completion on a brand-new event loop and return its result."""
    bound_call = partial(__fn, *args, **kwargs)
    return anyio.run(bound_call)

261 

262 

def mark_as_worker_thread() -> None:
    """Flag the current thread as a Prefect worker thread (read back by `in_async_worker_thread`)."""
    _thread_local.is_worker_thread = True

265 

266 

def in_async_worker_thread() -> bool:
    """Return True if the current thread was flagged by `mark_as_worker_thread`."""
    return getattr(_thread_local, "is_worker_thread", False)

269 

270 

def in_async_main_thread() -> bool:
    """
    Return True if called from the thread running an async event loop.

    Returns False both when no async library is running and when we are on a
    worker thread spawned from the event loop thread.
    """
    try:
        sniffio.current_async_library()
    except sniffio.AsyncLibraryNotFoundError:
        return False
    # An async library is running, but we could still be on a worker thread
    # rather than the main (event loop) thread.
    return not in_async_worker_thread()

279 

280 

def sync_compatible(
    async_fn: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]:
    """
    Converts an async function into a dual async and sync function.

    When the returned function is called, we will attempt to determine the best way
    to enter the async function.

    - If in a thread with a running event loop, we will return the coroutine for the
      caller to await. This is normal async behavior.
    - If in a blocking worker thread with access to an event loop in another thread, we
      will submit the async method to the event loop.
    - If we cannot find an event loop, we will create a new one and run the async method
      then tear down the loop.

    Note: Type checkers will infer functions decorated with `@sync_compatible` are synchronous. If
    you want to use the decorated function in an async context, you will need to ignore the types
    and "cast" the return type to a coroutine. For example:
    ```
    python result: Coroutine = sync_compatible(my_async_function)(arg1, arg2) # type: ignore
    ```
    """

    @wraps(async_fn)
    def coroutine_wrapper(
        *args: Any, _sync: Optional[bool] = None, **kwargs: Any
    ) -> Union[R, Coroutine[Any, Any, R]]:
        # Imported here to avoid a circular import at module load time.
        from prefect.context import MissingContextError, get_run_context

        # `_sync=False`: caller explicitly wants the coroutine back to await.
        if _sync is False:
            return async_fn(*args, **kwargs)

        is_async = True

        # if _sync is set, we do as we're told
        # otherwise, we make some determinations
        if _sync is None:
            try:
                run_ctx = get_run_context()
                parent_obj = getattr(run_ctx, "task", None)
                if not parent_obj:
                    parent_obj = getattr(run_ctx, "flow", None)
                # Inherit sync/async-ness from the enclosing task or flow.
                is_async = getattr(parent_obj, "isasync", True)
            except MissingContextError:
                # not in an execution context, make best effort to
                # decide whether to syncify
                try:
                    asyncio.get_running_loop()
                    is_async = True
                except RuntimeError:
                    is_async = False

        async def ctx_call():
            """
            Wrapper that is submitted using copy_context().run to ensure
            mutations of RUNNING_ASYNC_FLAG are tightly scoped to this coroutine's frame.
            """
            token = RUNNING_ASYNC_FLAG.set(True)
            try:
                result = await async_fn(*args, **kwargs)
            finally:
                RUNNING_ASYNC_FLAG.reset(token)
            return result

        if _sync is True:
            return run_coro_as_sync(ctx_call())
        elif RUNNING_ASYNC_FLAG.get() or is_async:
            # Already async (or told to be): hand the coroutine back.
            return ctx_call()
        else:
            return run_coro_as_sync(ctx_call())

    if is_async_fn(async_fn):
        wrapper = coroutine_wrapper
    elif is_async_gen_fn(async_fn):
        raise ValueError("Async generators cannot yet be marked as `sync_compatible`")
    else:
        raise TypeError("The decorated function must be async.")

    # Expose the undecorated coroutine function for callers that want to await
    # it directly, bypassing the sync/async dispatch logic.
    wrapper.aio = async_fn  # type: ignore
    return wrapper

362 

363 

@overload
def asyncnullcontext(
    value: None = None, *args: Any, **kwargs: Any
) -> AbstractAsyncContextManager[None, None]: ...


@overload
def asyncnullcontext(
    value: R, *args: Any, **kwargs: Any
) -> AbstractAsyncContextManager[R, None]: ...


@asynccontextmanager
async def asyncnullcontext(
    value: Optional[R] = None, *args: Any, **kwargs: Any
) -> AsyncGenerator[Any, Optional[R]]:
    """
    Async analogue of `contextlib.nullcontext`: yields `value` and performs no
    setup or teardown. Extra positional/keyword arguments are accepted and
    ignored so it can stand in for other context-manager factories.
    """
    yield value

381 

382 

def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
    """
    Call an async function from a synchronous context. Block until completion.

    If in an asynchronous context, we will run the code in a separate loop instead of
    failing but a warning will be displayed since this is not recommended.
    """
    if in_async_main_thread():
        warnings.warn(
            "`sync` called from an asynchronous context; "
            "you should `await` the async function directly instead."
        )
        # Run the coroutine on a fresh loop in a separate thread via a blocking
        # portal, since we cannot block this thread's own running loop.
        with anyio.from_thread.start_blocking_portal() as portal:
            return portal.call(partial(__async_fn, *args, **kwargs))
    elif in_async_worker_thread():
        # In a sync context but we can access the event loop thread; send the async
        # call to the parent
        return run_async_from_worker_thread(__async_fn, *args, **kwargs)
    else:
        # In a sync context and there is no event loop; just create an event loop
        # to run the async code then tear it down
        return run_async_in_new_loop(__async_fn, *args, **kwargs)

405 

406 

async def add_event_loop_shutdown_callback(
    coroutine_fn: Callable[[], Awaitable[Any]],
) -> None:
    """
    Adds a callback to the given callable on event loop closure. The callable must be
    a coroutine function. It will be awaited when the current event loop is shutting
    down.

    Requires use of `asyncio.run()` which waits for async generator shutdown by
    default or explicit call of `asyncio.shutdown_asyncgens()`. If the application
    is entered with `asyncio.run_until_complete()` and the user calls
    `asyncio.close()` without the generator shutdown call, this will not trigger
    callbacks.

    asyncio does not provided _any_ other way to clean up a resource when the event
    loop is about to close.
    """

    async def on_shutdown(key: int) -> AsyncGenerator[None, Any]:
        # It appears that EVENT_LOOP_GC_REFS is somehow being garbage collected early.
        # We hold a reference to it so as to preserve it, at least for the lifetime of
        # this coroutine. See the issue below for the initial report/discussion:
        # https://github.com/PrefectHQ/prefect/issues/7709#issuecomment-1560021109
        _ = EVENT_LOOP_GC_REFS
        try:
            yield
        except GeneratorExit:
            # The loop is shutting down its async generators: run the callback now.
            await coroutine_fn()
            # Remove self from the garbage collection set
            EVENT_LOOP_GC_REFS.pop(key)

    # Create the iterator and store it in a global variable so it is not garbage
    # collected. If the iterator is garbage collected before the event loop closes, the
    # callback will not run. Since this function does not know the scope of the event
    # loop that is calling it, a reference with global scope is necessary to ensure
    # garbage collection does not occur until after event loop closure.
    key = id(on_shutdown)
    EVENT_LOOP_GC_REFS[key] = on_shutdown(key)

    # Begin iterating so it will be cleaned up as an incomplete generator
    try:
        await EVENT_LOOP_GC_REFS[key].__anext__()
    # There is a poorly understood edge case we've seen in CI where the key is
    # removed from the dict before we begin generator iteration.
    except KeyError:
        logger.warning("The event loop shutdown callback was not properly registered. ")
        pass

454 

455 

class GatherIncomplete(RuntimeError):
    """Used to indicate retrieving gather results before completion.

    Raised by `GatherTaskGroup.get_result` when a result is requested before
    the task group has exited.
    """

458 

459 

class GatherTaskGroup(anyio.abc.TaskGroup):
    """
    A task group that gathers results.

    AnyIO does not include `gather` support. This class extends the `TaskGroup`
    interface to allow simple gathering.

    See https://github.com/agronholm/anyio/issues/100

    This class should be instantiated with `create_gather_task_group`.
    """

    def __init__(self, task_group: anyio.abc.TaskGroup):
        # Maps each task key to its result (or the GatherIncomplete placeholder
        # while the task is still running).
        self._results: dict[UUID, Any] = {}
        # The concrete task group implementation to use
        self._task_group: anyio.abc.TaskGroup = task_group

    async def _run_and_store(
        self,
        key: UUID,
        fn: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
    ) -> None:
        # Await the task and swap the placeholder for its real result.
        self._results[key] = await fn(*args)

    def start_soon(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> UUID:
        task_key = uuid4()
        # Put a placeholder in-case the result is retrieved earlier
        self._results[task_key] = GatherIncomplete
        self._task_group.start_soon(
            self._run_and_store, task_key, func, *args, name=name
        )
        return task_key

    async def start(self, func: object, *args: object, name: object = None) -> NoReturn:
        """
        Since `start` returns the result of `task_status.started()` but here we must
        return the key instead, we just won't support this method for now.
        """
        raise RuntimeError("`GatherTaskGroup` does not support `start`.")

    def get_result(self, key: UUID) -> Any:
        value = self._results[key]
        if value is GatherIncomplete:
            raise GatherIncomplete(
                "Task is not complete. "
                "Results should not be retrieved until the task group exits."
            )
        return value

    async def __aenter__(self) -> Self:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(self, *tb: Any) -> Optional[bool]:  # pyright: ignore[reportIncompatibleMethodOverride]
        try:
            return await self._task_group.__aexit__(*tb)
        finally:
            # Drop the inner group so this gather group cannot be re-entered.
            del self._task_group

523 

524 

def create_gather_task_group() -> GatherTaskGroup:
    """Create a new task group that gathers results.

    Must be used (entered) inside a running event loop, like
    `anyio.create_task_group`.
    """
    # This function matches the AnyIO API which uses callables since the concrete
    # task group class depends on the async library being used and cannot be
    # determined until runtime
    return GatherTaskGroup(anyio.create_task_group())

531 

532 

async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> list[T]:
    """
    Run calls concurrently and gather their results.

    Unlike `asyncio.gather` this expects to receive _callables_ not _coroutines_.
    This matches `anyio` semantics.
    """
    result_keys: list[UUID] = []
    async with create_gather_task_group() as task_group:
        # Submit every call; keys let us retrieve results after the group exits.
        result_keys = [task_group.start_soon(call) for call in calls]
    return [task_group.get_result(key) for key in result_keys]

545 

546 

class LazySemaphore:
    """
    An `asyncio.Semaphore` whose initial value is computed lazily.

    The underlying semaphore is built on first `async with`, so the initial
    value callable is not invoked (and no event loop is needed) at
    construction time. The same semaphore instance is reused on every entry.
    """

    def __init__(self, initial_value_func: Callable[[], int]) -> None:
        # Deferred: created on first acquisition, then reused.
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._initial_value_func = initial_value_func

    async def __aenter__(self) -> asyncio.Semaphore:
        self._initialize_semaphore()
        if TYPE_CHECKING:
            assert self._semaphore is not None
        await self._semaphore.__aenter__()
        return self._semaphore

    async def __aexit__(self, *args: Any) -> None:
        if TYPE_CHECKING:
            assert self._semaphore is not None
        await self._semaphore.__aexit__(*args)

    def _initialize_semaphore(self) -> None:
        # Build the semaphore exactly once; later entries reuse it.
        if self._semaphore is not None:
            return
        self._semaphore = asyncio.Semaphore(self._initial_value_func())