Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/calls.py: 48%

241 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Implementation of the `Call` data structure for transport of deferred function calls 

3and low-level management of call execution. 

4""" 

5 

6import abc 1d

7import asyncio 1d

8import concurrent.futures 1d

9import contextlib 1d

10import contextvars 1d

11import dataclasses 1d

12import inspect 1d

13import threading 1d

14import weakref 1d

15from collections.abc import Awaitable, Generator 1d

16from concurrent.futures._base import ( 1d

17 CANCELLED, 

18 CANCELLED_AND_NOTIFIED, 

19 FINISHED, 

20 RUNNING, 

21) 

22from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union 1d

23 

24from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, TypeVarTuple 1d

25 

26from prefect._internal.concurrency import logger 1d

27from prefect._internal.concurrency.cancellation import ( 1d

28 AsyncCancelScope, 

29 CancelledError, 

30 cancel_async_at, 

31 cancel_sync_at, 

32 get_deadline, 

33) 

34from prefect._internal.concurrency.event_loop import get_running_loop 1d

35 

# Type parameters shared by the generic definitions in this module.
T = TypeVar("T", infer_variance=True)  # result type of a call
Ts = TypeVarTuple("Ts")
P = ParamSpec("P")  # parameter specification of the wrapped callable

# A callable taking `P` that returns either `T` or an awaitable producing `T`.
_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]]

41 

42 

# Holds a weak reference to the `Call` that is currently executing.
#
# Placing the `Call` object itself in the contextvar for an async call appears
# to leak memory, even though the variable is `reset` when the context that set
# it is left. Storing a weakref instead avoids the leak. That is acceptable
# because a) strong references to the `Call` objects are already held elsewhere
# and b) this value only drives performance optimizations that have fallback
# behavior when the weakref has been garbage collected. A fix for issue #10952.
current_call: contextvars.ContextVar["weakref.ref[Call[Any]]"] = (  # novm
    contextvars.ContextVar("current_call")
)

# Strong references to in-flight tasks so that they are not destroyed while
# still executing (for example during execution errors).
_ASYNC_TASK_REFS: set[asyncio.Task[None]] = set()

56 

57 

@contextlib.contextmanager
def set_current_call(call: "Call[Any]") -> Generator[None, Any, None]:
    """
    Record `call` as the current call for the duration of the `with` block.

    Only a weak reference to `call` is stored in the `current_call` context
    variable. The previous value is restored on exit, including when the body
    raises.
    """
    reset_token = current_call.set(weakref.ref(call))
    try:
        yield
    finally:
        current_call.reset(reset_token)

65 

66 

class Future(concurrent.futures.Future[T]):
    """
    A `concurrent.futures.Future` that additionally supports cancelling a
    future that is already running.

    Cancellation of a running future is delegated to the cancel scope that is
    installed by `enforce_sync_deadline` or `enforce_async_deadline`.

    Used by `Call`.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        self._name = name
        self._deadline = None
        self._timed_out = False
        # The scope is only present while one of the `enforce_*_deadline`
        # context managers is active.
        self._cancel_scope = None
        self._cancel_callbacks: list[Callable[[], None]] = []

    def set_running_or_notify_cancel(self, timeout: Optional[float] = None):
        """
        Store the deadline derived from `timeout`, then defer to the standard
        library implementation and return its result.
        """
        self._deadline = get_deadline(timeout)
        return super().set_running_or_notify_cancel()

    @contextlib.contextmanager
    def enforce_async_deadline(self) -> Generator[AsyncCancelScope]:
        """
        Open an async cancel scope bound to this future's deadline.

        Cancel callbacks registered so far are attached to the new scope before
        it is yielded.
        """
        with cancel_async_at(self._deadline, name=self._name) as self._cancel_scope:
            for registered in self._cancel_callbacks:
                self._cancel_scope.add_cancel_callback(registered)
            yield self._cancel_scope

    @contextlib.contextmanager
    def enforce_sync_deadline(self):
        """
        Open a sync cancel scope bound to this future's deadline.

        Cancel callbacks registered so far are attached to the new scope before
        it is yielded.
        """
        with cancel_sync_at(self._deadline, name=self._name) as self._cancel_scope:
            for registered in self._cancel_callbacks:
                self._cancel_scope.add_cancel_callback(registered)
            yield self._cancel_scope

    def add_cancel_callback(self, callback: Callable[[], Any]) -> None:
        """
        Register a callback to run on cancellation.

        In contrast to "done" callbacks, a cancel callback runs _before_ the
        future is cancelled. Registering one after the future has been cancelled
        has no effect.
        """
        # Running cancel callbacks at the same point as "done" callbacks would
        # be too late to propagate chained cancellation to waiters and actually
        # interrupt calls.
        if self._cancel_scope:
            # A scope is active, so attach the callback to it directly
            self._cancel_scope.add_cancel_callback(callback)

        # Always remember the callback as well
        self._cancel_callbacks.append(callback)

    def timedout(self) -> bool:
        """Return whether the future has been flagged as timed out."""
        with self._condition:
            return self._timed_out

    def cancel(self) -> bool:
        """Attempt to cancel the future.

        Returns True when the future is (or already was) cancelled and False
        otherwise. A future that has already completed cannot be cancelled.
        """
        with self._condition:
            # The stdlib refuses to cancel RUNNING futures; here an attempt is
            # made through the cancel scope.
            if self._state == RUNNING:
                if self._cancel_scope is None:
                    return False
                elif not self._cancel_scope.cancelled():
                    # Ask the scope to cancel; give up if it refuses
                    if not self._cancel_scope.cancel():
                        return False

            if self._state == FINISHED:
                return False

            if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                return True

            # The cancel scope usually takes care of cancel callbacks; without
            # a scope they are invoked here so they are still respected.
            if not self._cancel_scope:
                for registered in self._cancel_callbacks:
                    registered()

            self._state = CANCELLED
            self._condition.notify_all()

        self._invoke_callbacks()
        return True

    if TYPE_CHECKING:
        # Typing-only stub. At runtime `self.__get_result` resolves to the
        # private method inherited from the base class.

        def __get_result(self) -> T: ...

    def result(self, timeout: Optional[float] = None) -> T:
        """Return the outcome of the call represented by this future.

        Args:
            timeout: Seconds to wait for the result when the future is not yet
                done. With None the wait is unbounded.

        Returns:
            The value produced by the call.

        Raises:
            CancelledError: The future was cancelled.
            TimeoutError: The future did not finish within the timeout.
            Exception: Whatever exception the call itself raised.
        """
        try:
            with self._condition:
                if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                    # Surface the Prefect cancelled error rather than
                    # `concurrent.futures._base.CancelledError`
                    raise CancelledError()
                elif self._state == FINISHED:
                    return self.__get_result()

                self._condition.wait(timeout)

                if self._state in (CANCELLED, CANCELLED_AND_NOTIFIED):
                    # Surface the Prefect cancelled error rather than
                    # `concurrent.futures._base.CancelledError`
                    raise CancelledError()
                elif self._state == FINISHED:
                    return self.__get_result()
                else:
                    raise TimeoutError()
        finally:
            # Break a reference cycle with the exception in self._exception
            self = None

    _done_callbacks: list[Callable[[Self], object]]

    def _invoke_callbacks(self) -> None:
        """
        Run the done callbacks, then drop the cancel callbacks and the cancel
        scope.

        Clearing these references fixes a memory leak in which Call objects
        were kept alive and Futures could not be garbage collected.

        A fix for #10952.
        """
        if self._done_callbacks:
            pending = list(self._done_callbacks)
            self._done_callbacks.clear()

            for done_callback in pending:
                try:
                    done_callback(self)
                except Exception:
                    logger.exception("exception calling callback for %r", self)

        self._cancel_callbacks = []
        if self._cancel_scope:
            setattr(self._cancel_scope, "_callbacks", [])
            self._cancel_scope = None

223 

224 

@dataclasses.dataclass(eq=False)
class Call(Generic[T]):
    """
    A deferred function call.

    Bundles a callable with its arguments, the context to run it in, an
    optional timeout, and the `Future` that receives the outcome.
    """

    future: Future[T]
    fn: "_SyncOrAsyncCallable[..., T]"
    args: tuple[Any, ...]
    kwargs: dict[str, Any]
    context: contextvars.Context
    timeout: Optional[float]
    runner: Optional["Portal"] = None

    def __eq__(self, other: object) -> bool:
        """this is to avoid attempts at invalid access of args/kwargs in <3.13 stemming from the
        auto-generated __eq__ method on the dataclass.

        this will no longer be required in 3.13+, see https://github.com/python/cpython/issues/128294
        """
        if self is other:
            return True
        if not isinstance(other, Call):
            return NotImplemented

        try:
            # Attempt to access args/kwargs. If any are missing on self or other,
            # an AttributeError will be raised by the access attempt on one of them.
            s_args, s_kwargs = self.args, self.kwargs
            o_args, o_kwargs = other.args, other.kwargs
        except AttributeError:
            # If args/kwargs are missing on self or other (and self is not other),
            # they are considered not equal. This ensures that a Call with deleted
            # args/kwargs compares as different from one that still has them
            return False

        # If all args/kwargs were accessible on both, proceed with full field comparison.
        # Note: self.future == other.future will use Future's __eq__ (default is identity).
        return (
            (self.future == other.future)
            and (self.fn == other.fn)
            and (s_args == o_args)
            and (s_kwargs == o_kwargs)
            and (self.context == other.context)
            and (self.timeout == other.timeout)
            and (self.runner == other.runner)
        )

    # A custom `__eq__` is defined, so instances are explicitly unhashable.
    __hash__ = None  # type: ignore

    @classmethod
    def new(
        cls,
        __fn: _SyncOrAsyncCallable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Self:
        """
        Create a call for `__fn` with the given arguments.

        The current context is copied for the call, the future is named after
        the function, and no timeout is set.
        """
        return cls(
            future=Future(name=getattr(__fn, "__name__", str(__fn))),
            fn=__fn,
            args=args,
            kwargs=kwargs,
            context=contextvars.copy_context(),
            timeout=None,
        )

    def set_timeout(self, timeout: Optional[float] = None) -> None:
        """
        Set the timeout for the call.

        The timeout begins when the call starts.

        Raises:
            RuntimeError: If the call is already running or done.
        """
        if self.future.done() or self.future.running():
            raise RuntimeError("Timeouts cannot be added when the call has started.")

        self.timeout = timeout

    def set_runner(self, portal: "Portal") -> None:
        """
        Update the portal used to run this call.

        Raises:
            RuntimeError: If a portal has already been set.
        """
        if self.runner is not None:
            raise RuntimeError("The portal is already set for this call.")

        self.runner = portal

    def run(self) -> Optional[Awaitable[None]]:
        """
        Execute the call and place the result on the future.

        All exceptions during execution of the call are captured.

        Returns:
            A task to await when the function returned an awaitable and an
            event loop is running in this thread; otherwise None.
        """
        # Do not execute if the future is cancelled
        if not self.future.set_running_or_notify_cancel(self.timeout):
            logger.debug("Skipping execution of cancelled call %r", self)
            return None

        logger.debug(
            "Running call %r in thread %r%s",
            self,
            threading.current_thread().name,
            f" with timeout of {self.timeout}s" if self.timeout is not None else "",
        )

        coro = self.context.run(self._run_sync)

        if coro is not None:
            loop = get_running_loop()
            if loop:
                # If an event loop is available, return a task to be awaited
                # Note we must create a task for context variables to propagate
                logger.debug(
                    "Scheduling coroutine for call %r in running loop %r",
                    self,
                    loop,
                )
                task = self.context.run(loop.create_task, self._run_async(coro))

                # Prevent tasks from being garbage collected before completion
                # See https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task
                _ASYNC_TASK_REFS.add(task)
                # The done callback receives the task itself. `discard` does
                # not raise if the entry is already gone and avoids a closure
                # that references the task.
                task.add_done_callback(_ASYNC_TASK_REFS.discard)

                return task

            else:
                # Otherwise, execute the function here
                logger.debug("Executing coroutine for call %r in new loop", self)
                return self.context.run(asyncio.run, self._run_async(coro))

        return None

    def result(self, timeout: Optional[float] = None) -> T:
        """
        Wait for the result of the call.

        Not safe for use from asynchronous contexts.
        """
        return self.future.result(timeout=timeout)

    async def aresult(self) -> T:
        """
        Wait for the result of the call.

        For use from asynchronous contexts.

        Raises:
            CancelledError: If waiting raised `asyncio.CancelledError`.
        """
        try:
            return await asyncio.wrap_future(self.future)
        except asyncio.CancelledError as exc:
            raise CancelledError() from exc

    def cancelled(self) -> bool:
        """
        Check if the call was cancelled.
        """
        return self.future.cancelled()

    def timedout(self) -> bool:
        """
        Check if the call timed out.
        """
        return self.future.timedout()

    def cancel(self) -> bool:
        """
        Attempt to cancel the call by cancelling its future.
        """
        return self.future.cancel()

    def _report_cancellation(self, cancel_scope: Optional[Any]) -> bool:
        """
        Record a `CancelledError` raised during execution on the future.

        Returns True when the future was cancelled here: either no cancel
        scope had been established yet, or the scope reports a timeout or a
        cancellation. Returns False when the scope reports neither, in which
        case the caller should re-raise the error.
        """
        if cancel_scope is None:
            # In rare cases the deadline context manager raises CancelledError
            # prior to yielding, so there is no scope to inspect.
            self.future.cancel()
            return True
        if cancel_scope.timedout():
            setattr(self.future, "_timed_out", True)
            self.future.cancel()
            return True
        if cancel_scope.cancelled():
            self.future.cancel()
            return True
        return False

    def _run_sync(self) -> Optional[Awaitable[T]]:
        """
        Call `fn` synchronously under the sync deadline.

        Returns the awaitable when `fn` produced one so the caller can finish
        the call asynchronously. Otherwise the result, exception, or
        cancellation is recorded on the future and None is returned.
        """
        cancel_scope = None
        try:
            with set_current_call(self):
                with self.future.enforce_sync_deadline() as cancel_scope:
                    try:
                        result = self.fn(*self.args, **self.kwargs)
                    finally:
                        # Forget this call's arguments in order to free up any memory
                        # that may be referenced by them; after a call has happened,
                        # there's no need to keep a reference to them
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs

            # Return the coroutine for async execution
            if inspect.isawaitable(result):
                return result

        except CancelledError:
            # Report cancellation; re-raise if it did not come from our scope
            if not self._report_cancellation(cancel_scope):
                raise
        except BaseException as exc:
            logger.debug("Encountered exception in call %r", self, exc_info=True)
            self.future.set_exception(exc)

            # Prevent reference cycle in `exc`
            del self
        else:
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished call %r", self)  # noqa: F821

    async def _run_async(self, coro: Awaitable[T]) -> None:
        """
        Await `coro` under the async deadline and record the outcome on the
        future.
        """
        cancel_scope = result = None
        try:
            with set_current_call(self):
                with self.future.enforce_async_deadline() as cancel_scope:
                    try:
                        result = await coro
                    finally:
                        # Forget this call's arguments in order to free up any memory
                        # that may be referenced by them; after a call has happened,
                        # there's no need to keep a reference to them
                        with contextlib.suppress(AttributeError):
                            del self.args, self.kwargs
        except CancelledError:
            # Report cancellation. As in `_run_sync`, a missing scope is
            # tolerated instead of failing on `None` inside this handler,
            # which would leave the future unresolved.
            if not self._report_cancellation(cancel_scope):
                raise
        except BaseException as exc:
            logger.debug("Encountered exception in async call %r", self, exc_info=True)

            self.future.set_exception(exc)
            # Prevent reference cycle in `exc`
            del self
        else:
            # F821 ignored because Ruff gets confused about the del self above.
            self.future.set_result(result)  # noqa: F821
            logger.debug("Finished async call %r", self)  # noqa: F821

    def __call__(self) -> Union[T, Awaitable[T]]:
        """
        Execute the call and return its result.

        All exceptions during execution of the call are re-raised.
        """
        coro = self.run()

        # Return an awaitable if in an async context
        if coro is not None:

            async def run_and_return_result() -> T:
                await coro
                return self.result()

            return run_and_return_result()
        else:
            return self.result()

    def __repr__(self) -> str:
        name = getattr(self.fn, "__name__", str(self.fn))

        try:
            args, kwargs = self.args, self.kwargs
        except AttributeError:
            # The arguments are deleted once the call has run
            call_args = "<dropped>"
        else:
            call_args = ", ".join(
                [repr(arg) for arg in args]
                + [f"{key}={repr(val)}" for key, val in kwargs.items()]
            )

        # Enforce a maximum length
        if len(call_args) > 100:
            call_args = call_args[:100] + "..."

        return f"{name}({call_args})"

507 

508 

class Portal(abc.ABC):
    """
    Abstract interface for handing calls off to be executed elsewhere.
    """

    @abc.abstractmethod
    def submit(self, call: "Call[T]") -> "Call[T]":
        """
        Hand `call` off for execution elsewhere.

        Use `call.result()` to retrieve the outcome.

        The same call is returned for convenience.
        """