Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/calls.py: 21%

241 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1""" 

2Implementation of the `Call` data structure for transport of deferred function calls 

3and low-level management of call execution. 

4""" 

5 

6import abc 1a

7import asyncio 1a

8import concurrent.futures 1a

9import contextlib 1a

10import contextvars 1a

11import dataclasses 1a

12import inspect 1a

13import threading 1a

14import weakref 1a

15from collections.abc import Awaitable, Generator 1a

16from concurrent.futures._base import ( 1a

17 CANCELLED, 

18 CANCELLED_AND_NOTIFIED, 

19 FINISHED, 

20 RUNNING, 

21) 

22from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union 1a

23 

24from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, TypeVarTuple 1a

25 

26from prefect._internal.concurrency import logger 1a

27from prefect._internal.concurrency.cancellation import ( 1a

28 AsyncCancelScope, 

29 CancelledError, 

30 cancel_async_at, 

31 cancel_sync_at, 

32 get_deadline, 

33) 

34from prefect._internal.concurrency.event_loop import get_running_loop 1a

35 

# Generic result type shared by `Future` and `Call`; variance is inferred.
T = TypeVar("T", infer_variance=True)
# Variadic type variable (not referenced by the definitions in this module view).
Ts = TypeVarTuple("Ts")
# Captures the parameter signature of the callable wrapped by a `Call`.
P = ParamSpec("P")

# A callable that returns either a plain value or an awaitable producing one.
_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]]


# The call currently being executed, if any. The context variable holds a
# weak reference rather than the `Call` itself: storing the `Call` for an
# async call directly appears to leak memory even though the variable is
# `reset` on leaving the context that set it. A weak reference is sufficient
# because a) strong references to the `Call` objects already exist elsewhere
# and b) this value only feeds performance optimizations that fall back
# gracefully if the referent has been garbage collected.
# A fix for issue #10952.
current_call: contextvars.ContextVar["weakref.ref[Call[Any]]"] = (  # novm
    contextvars.ContextVar("current_call")
)

# Strong references to in-flight tasks so they are not destroyed while still
# executing (including when execution errors occur).
_ASYNC_TASK_REFS: set[asyncio.Task[None]] = set()

56 

57 

@contextlib.contextmanager
def set_current_call(call: "Call[Any]") -> Generator[None, Any, None]:
    """
    Mark `call` as the current call for the duration of the `with` block.

    Only a weak reference to `call` is stored in the `current_call` context
    variable. The previous value is restored when the block exits, including
    when the body raises.
    """
    call_ref = weakref.ref(call)
    reset_token = current_call.set(call_ref)
    try:
        yield
    finally:
        current_call.reset(reset_token)

65 

66 

class Future(concurrent.futures.Future[T]):
    """
    A `concurrent.futures.Future` that additionally allows cancelling a future
    that is already running.

    Cancellation of a running future is delegated to a cancel scope, which is
    only available while one of the `enforce_*_deadline` contexts is active.

    Used by `Call`.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        self._name = name
        self._deadline = None
        self._timed_out = False
        # Set while one of the `enforce_*_deadline` contexts is active
        self._cancel_scope = None
        # Tracked separately so callbacks registered before a cancel scope
        # exists can be attached once one is created
        self._cancel_callbacks: list[Callable[[], None]] = []

    def set_running_or_notify_cancel(self, timeout: Optional[float] = None):
        """
        Store the deadline computed from `timeout`, then defer to the stdlib
        implementation and return its result.
        """
        self._deadline = get_deadline(timeout)
        return super().set_running_or_notify_cancel()

    @contextlib.contextmanager
    def enforce_async_deadline(self) -> Generator[AsyncCancelScope]:
        """
        Enter an async cancel scope for this future's deadline.

        The scope is stored on the future, every tracked cancel callback is
        attached to it, and the scope is yielded to the caller.
        """
        with cancel_async_at(self._deadline, name=self._name) as scope:
            self._cancel_scope = scope
            for registered in self._cancel_callbacks:
                self._cancel_scope.add_cancel_callback(registered)
            yield self._cancel_scope

    @contextlib.contextmanager
    def enforce_sync_deadline(self):
        """
        Enter a sync cancel scope for this future's deadline.

        The scope is stored on the future, every tracked cancel callback is
        attached to it, and the scope is yielded to the caller.
        """
        with cancel_sync_at(self._deadline, name=self._name) as scope:
            self._cancel_scope = scope
            for registered in self._cancel_callbacks:
                self._cancel_scope.add_cancel_callback(registered)
            yield self._cancel_scope

    def add_cancel_callback(self, callback: Callable[[], Any]) -> None:
        """
        Register a callback to run on cancellation.

        In contrast to "done" callbacks, cancel callbacks run _before_ the
        future is cancelled. Registering one after the future has been
        cancelled has no effect.
        """
        # Running these like "done" callbacks would be too late: chained
        # cancellation would not reach waiters in time to interrupt calls.
        if self._cancel_scope:
            # A scope is active, so attach the callback to it right away
            self._cancel_scope.add_cancel_callback(callback)

        # Always track the callback as well
        self._cancel_callbacks.append(callback)

    def timedout(self) -> bool:
        """
        Return the value of the timed-out flag, read under the condition lock.
        """
        with self._condition:
            timed_out = self._timed_out
        return timed_out

    def cancel(self) -> bool:
        """Attempt to cancel the future.

        Returns True when the future was cancelled (or already was), and False
        when it could not be cancelled. A completed future cannot be cancelled.
        """
        with self._condition:
            # Unlike the stdlib, a RUNNING future may be cancelled, but only
            # through its cancel scope
            if self._state == RUNNING:
                if self._cancel_scope is None:
                    return False
                if (
                    not self._cancel_scope.cancelled()
                    and not self._cancel_scope.cancel()
                ):
                    # The scope refused the cancellation
                    return False

            if self._state == FINISHED:
                return False

            if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                return True

            # The cancel scope normally takes care of cancel callbacks; when
            # no scope exists, run them here so they are still respected
            if not self._cancel_scope:
                for registered in self._cancel_callbacks:
                    registered()

            self._state = CANCELLED
            self._condition.notify_all()

        self._invoke_callbacks()
        return True

    if TYPE_CHECKING:

        def __get_result(self) -> T: ...

    def result(self, timeout: Optional[float] = None) -> T:
        """Return the result of the call represented by this future.

        Args:
            timeout: Seconds to wait for the result when the future is not yet
                done. With None, the wait is unbounded.

        Returns:
            The result of the call represented by this future.

        Raises:
            CancelledError: If the future was cancelled.
            TimeoutError: If the future did not finish within the timeout.
            Exception: If the call raised, that exception is raised.
        """
        try:
            with self._condition:
                # The Prefect `CancelledError` is raised in place of
                # `concurrent.futures._base.CancelledError`
                if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                    raise CancelledError()
                if self._state == FINISHED:
                    # Name mangling resolves this to the private helper of the
                    # identically named stdlib base class
                    return self.__get_result()

                self._condition.wait(timeout)

                # Same checks again now that the wait is over
                if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                    raise CancelledError()
                if self._state == FINISHED:
                    return self.__get_result()
                raise TimeoutError()
        finally:
            # Break a reference cycle with the exception in self._exception
            self = None

    _done_callbacks: list[Callable[[Self], object]]

    def _invoke_callbacks(self) -> None:
        """
        Run the done callbacks, then drop the cancel callbacks and the cancel
        scope. Releasing them fixes a memory leak in which Call objects were
        retained and Futures could not be garbage collected.

        A fix for #10952.
        """
        # Take a snapshot and empty the list in place before invoking anything
        pending = list(self._done_callbacks)
        self._done_callbacks.clear()

        for done_callback in pending:
            try:
                done_callback(self)
            except Exception:
                logger.exception("exception calling callback for %r", self)

        self._cancel_callbacks = []
        if self._cancel_scope:
            setattr(self._cancel_scope, "_callbacks", [])
            self._cancel_scope = None

223 

224 

@dataclasses.dataclass(eq=False)
class Call(Generic[T]):
    """
    A function call captured for deferred execution.

    The outcome of the call (result, exception or cancellation) is reported
    through `future`.
    """

    future: Future[T]
    fn: "_SyncOrAsyncCallable[..., T]"
    args: tuple[Any, ...]
    kwargs: dict[str, Any]
    context: contextvars.Context
    timeout: Optional[float]
    runner: Optional["Portal"] = None

    def __eq__(self, other: object) -> bool:
        """Compare two calls field by field.

        Defined by hand because the dataclass-generated `__eq__` would attempt
        invalid access of args/kwargs on Python <3.13 once they have been
        deleted from an instance.

        Not needed on 3.13+, see https://github.com/python/cpython/issues/128294
        """
        if self is other:
            return True
        if not isinstance(other, Call):
            return NotImplemented

        try:
            # Reading these raises AttributeError when either call has
            # already dropped its arguments
            own_args, own_kwargs = self.args, self.kwargs
            their_args, their_kwargs = other.args, other.kwargs
        except AttributeError:
            # Two distinct calls where at least one has lost its arguments
            # are never equal
            return False

        # Arguments are present on both sides, so compare every field.
        # Note: comparing the futures uses Future's __eq__ (identity by default).
        return (
            (self.future == other.future)
            and (self.fn == other.fn)
            and (own_args == their_args)
            and (own_kwargs == their_kwargs)
            and (self.context == other.context)
            and (self.timeout == other.timeout)
            and (self.runner == other.runner)
        )

    __hash__ = None  # type: ignore

    @classmethod
    def new(
        cls,
        __fn: _SyncOrAsyncCallable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Self:
        """
        Build a call for `__fn` with the given arguments.

        The future is named after the function, the current context is copied,
        and no timeout is set.
        """
        future_name = getattr(__fn, "__name__", str(__fn))
        return cls(
            future=Future(name=future_name),
            fn=__fn,
            args=args,
            kwargs=kwargs,
            context=contextvars.copy_context(),
            timeout=None,
        )

    def set_timeout(self, timeout: Optional[float] = None) -> None:
        """
        Set the timeout for the call.

        The timeout is measured from the moment the call starts, so it can no
        longer be changed once the call is running or done.
        """
        already_started = self.future.done() or self.future.running()
        if already_started:
            raise RuntimeError("Timeouts cannot be added when the call has started.")

        self.timeout = timeout

    def set_runner(self, portal: "Portal") -> None:
        """
        Assign the portal used to run this call. May only be done once.
        """
        if self.runner is not None:
            raise RuntimeError("The portal is already set for this call.")

        self.runner = portal

    def run(self) -> Optional[Awaitable[None]]:
        """
        Execute the call and place the result on the future.

        All exceptions during execution of the call are captured.

        Returns None when the call was skipped or completed here; returns a
        task to await when a coroutine was scheduled on a running event loop.
        """
        # A cancelled future must not be executed
        if not self.future.set_running_or_notify_cancel(self.timeout):
            logger.debug("Skipping execution of cancelled call %r", self)
            return None

        if self.timeout is not None:
            timeout_note = f" with timeout of {self.timeout}s"
        else:
            timeout_note = ""
        logger.debug(
            "Running call %r in thread %r%s",
            self,
            threading.current_thread().name,
            timeout_note,
        )

        coro = self.context.run(self._run_sync)
        if coro is None:
            # Nothing left to await
            return None

        loop = get_running_loop()
        if not loop:
            # No event loop is running, so drive the coroutine to completion here
            logger.debug("Executing coroutine for call %r in new loop", self)
            return self.context.run(asyncio.run, self._run_async(coro))

        # An event loop is available: hand back a task to be awaited.
        # A task must be created for context variables to propagate.
        logger.debug(
            "Scheduling coroutine for call %r in running loop %r",
            self,
            loop,
        )
        task = self.context.run(loop.create_task, self._run_async(coro))

        # Hold a strong reference so the task is not garbage collected before
        # it completes.
        # See https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task
        _ASYNC_TASK_REFS.add(task)
        asyncio.ensure_future(task).add_done_callback(
            lambda _: _ASYNC_TASK_REFS.remove(task)
        )

        return task

    def result(self, timeout: Optional[float] = None) -> T:
        """
        Wait for the result of the call.

        Not safe for use from asynchronous contexts.
        """
        return self.future.result(timeout=timeout)

    async def aresult(self):
        """
        Wait for the result of the call.

        For use from asynchronous contexts. An `asyncio.CancelledError` is
        re-raised as the Prefect `CancelledError`.
        """
        try:
            return await asyncio.wrap_future(self.future)
        except asyncio.CancelledError as exc:
            raise CancelledError() from exc

    def cancelled(self) -> bool:
        """
        Report whether the call was cancelled.
        """
        return self.future.cancelled()

    def timedout(self) -> bool:
        """
        Report whether the call timed out.
        """
        return self.future.timedout()

    def cancel(self) -> bool:
        """
        Request cancellation of the call through its future.
        """
        return self.future.cancel()

    def _run_sync(self) -> Optional[Awaitable[T]]:
        """
        Invoke `fn` under the sync deadline and settle the future.

        When `fn` returns an awaitable it is returned unresolved, and the
        future is left for `_run_async` to settle.
        """
        cancel_scope = None
        try:
            with set_current_call(self):
                with self.future.enforce_sync_deadline() as cancel_scope:
                    try:
                        result = self.fn(*self.args, **self.kwargs)
                    finally:
                        # Once the call has happened its arguments are no
                        # longer needed; drop them so any memory they
                        # reference can be freed
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs

            # Hand the coroutine back for async execution
            if inspect.isawaitable(result):
                return result

        except CancelledError:
            # Report cancellation.
            # In rare cases, enforce_sync_deadline raises CancelledError
            # prior to yielding
            if cancel_scope is None:
                self.future.cancel()
                return None

            timed_out = cancel_scope.timedout()
            if not timed_out and not cancel_scope.cancelled():
                # The cancellation did not originate from our scope
                raise
            if timed_out:
                setattr(self.future, "_timed_out", True)
            self.future.cancel()
        except BaseException as exc:
            logger.debug("Encountered exception in call %r", self, exc_info=True)
            self.future.set_exception(exc)

            # Prevent reference cycle in `exc`
            del self
        else:
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished call %r", self)  # noqa: F821

    async def _run_async(self, coro: Awaitable[T]) -> None:
        """
        Await `coro` under the async deadline and settle the future.
        """
        cancel_scope = None
        result = None
        try:
            with set_current_call(self):
                with self.future.enforce_async_deadline() as cancel_scope:
                    try:
                        result = await coro
                    finally:
                        # Once the call has happened its arguments are no
                        # longer needed; drop them so any memory they
                        # reference can be freed
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs
        except CancelledError:
            # Report cancellation.
            # NOTE(review): `_run_sync` handles a `None` scope at this point;
            # this path assumes a scope was always yielded - confirm.
            if TYPE_CHECKING:
                assert cancel_scope is not None

            timed_out = cancel_scope.timedout()
            if not timed_out and not cancel_scope.cancelled():
                # The cancellation did not originate from our scope
                raise
            if timed_out:
                setattr(self.future, "_timed_out", True)
            self.future.cancel()
        except BaseException as exc:
            logger.debug("Encountered exception in async call %r", self, exc_info=True)

            self.future.set_exception(exc)
            # Prevent reference cycle in `exc`
            del self
        else:
            # F821 ignored because Ruff gets confused about the del self above.
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished async call %r", self)  # noqa: F821

    def __call__(self) -> Union[T, Awaitable[T]]:
        """
        Execute the call and return its result.

        All exceptions raised during execution of the call are re-raised.
        """
        coro = self.run()

        if coro is None:
            # Executed synchronously, so the result is already available
            return self.result()

        # In an async context, hand back an awaitable that yields the result
        async def run_and_return_result() -> T:
            await coro
            return self.result()

        return run_and_return_result()

    def __repr__(self) -> str:
        """
        Render the call as `name(arguments)`, truncating long argument lists.
        """
        name = getattr(self.fn, "__name__", str(self.fn))

        try:
            args, kwargs = self.args, self.kwargs
        except AttributeError:
            # The arguments have been deleted from this call
            call_args = "<dropped>"
        else:
            positional = [repr(value) for value in args]
            keyword = [f"{key}={value!r}" for key, value in kwargs.items()]
            call_args = ", ".join(positional + keyword)

        # Enforce a maximum length
        if len(call_args) > 100:
            call_args = call_args[:100] + "..."

        return f"{name}({call_args})"

507 

508 

class Portal(abc.ABC):
    """
    Abstract interface for handing calls off to be executed elsewhere.
    """

    @abc.abstractmethod
    def submit(self, call: "Call[T]") -> "Call[T]":
        """
        Hand `call` off to be executed elsewhere.

        The outcome is available afterwards through `call.result()`.

        The submitted call is returned for convenience.
        """