Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/cancellation.py: 43%

272 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Utilities for cancellation in synchronous and asynchronous contexts. 

3""" 

4 

5import abc 1a

6import asyncio 1a

7import contextlib 1a

8import ctypes 1a

9import math 1a

10import os 1a

11import signal 1a

12import sys 1a

13import threading 1a

14import time 1a

15from types import TracebackType 1a

16from typing import TYPE_CHECKING, Any, Callable, Optional, overload 1a

17 

18import anyio 1a

19 

20from prefect._internal.concurrency import logger 1a

21from prefect._internal.concurrency.event_loop import get_running_loop 1a

22 

23_THREAD_SHIELDS: dict[threading.Thread, "ThreadShield"] = {} 1a

24_THREAD_SHIELDS_LOCK = threading.Lock() 1a

25 

26 

class ThreadShield:
    """
    Protect a thread from exceptions sent by other threads, using a reentrant lock.

    Two usage patterns are supported:

    1. Another thread enters the shield as a context manager. This blocks until the
       owning thread has released the shield, after which an exception can be sent.

    2. `set_exception` records an exception that is raised in the owner thread once
       the shield is released.

    Because the lock is reentrant, shields nest: a pending exception is only raised
    when the outermost context exits.
    """

    def __init__(self, owner: threading.Thread):
        # `threading._RLock` is the pure-Python RLock. Unlike the C implementation it
        # stores its recursion depth in `_count`, which `active` inspects so that
        # exceptions can be delayed while the shield is held (e.g. during alarms).
        self._lock = threading._RLock()  # type: ignore # yes, we want the private version
        self._owner = owner
        self._exception = None

    def __enter__(self) -> None:
        self._lock.__enter__()

    def __exit__(self, *exc_info: Any):
        result = self._lock.__exit__(*exc_info)

        # A pending exception is delivered only by the outermost exit, and only in
        # the owner thread.
        if self.active() or not self._exception:
            return result
        if threading.current_thread().ident != self._owner.ident:
            return result

        # Clear before raising so the exception is delivered exactly once.
        pending, self._exception = self._exception, None
        raise pending from None

    def set_exception(self, exc: BaseException):
        """
        Record `exc` to be raised in the owner thread when the shield is released.
        """
        self._exception = exc

    def active(self) -> bool:
        """
        Return True while the shield's lock is held (recursion count above zero).
        """
        return self._lock._count > 0  # type: ignore[attr-defined]

77 

78 

class CancelledError(asyncio.CancelledError):
    """
    This module's own cancellation error.

    It derives from `asyncio.CancelledError`, so it is treated as a `BaseException`
    rather than an `Exception`. Having a dedicated type lets downstream logic tell
    which cancelled error it is handling.
    """

84 

85 

def _get_thread_shield(thread: threading.Thread) -> ThreadShield:
    """
    Return the `ThreadShield` for `thread`, creating and registering it if needed.

    While the registry lock is held, shields belonging to threads that are no longer
    alive are pruned. The requested thread itself is never pruned, even when it is
    not (or not yet) alive. This guarantees the lookup cannot fail, and it means a
    shield handed out before a thread starts is the same object the thread uses later.
    """
    with _THREAD_SHIELDS_LOCK:
        thread_shield = _THREAD_SHIELDS.get(thread)
        if thread_shield is None:
            thread_shield = _THREAD_SHIELDS[thread] = ThreadShield(thread)

        # Perform garbage collection for old threads, sparing the requested one so we
        # never drop the entry we are about to return.
        stale = [
            known
            for known in _THREAD_SHIELDS
            if known is not thread and not known.is_alive()
        ]
        for known in stale:
            del _THREAD_SHIELDS[known]

        return thread_shield

97 

98 

@contextlib.contextmanager
def shield():
    """
    Protect the enclosed code from cancellation.

    The current thread's `ThreadShield` is held for the duration of the block. This
    defers alarm-based cancellation and makes other threads wait before injecting
    exceptions, as done elsewhere in this module.

    When an event loop is running in this thread, an anyio shielded cancel scope is
    also entered, so asynchronous cancellation is blocked as well.
    """
    if get_running_loop():
        async_guard = anyio.CancelScope(shield=True)
    else:
        async_guard = contextlib.nullcontext()

    # The second context expression is evaluated only after the first is entered,
    # exactly as with nested `with` statements.
    with async_guard, _get_thread_shield(threading.current_thread()):
        yield

117 

118 

class CancelScope(abc.ABC):
    """
    Defines a context where cancellation can be requested.

    If cancelled, any code within the context should be interrupted. The cancellation
    implementation varies depending on the environment and may not interrupt some system
    calls.

    A timeout can be defined to automatically cancel the scope after a given duration if
    it has not exited.

    This base class only tracks state under an internal lock: started, cancelled,
    completed, start and end times, and the deadline. Subclasses extend `__enter__`,
    `__exit__` and `cancel` to actually deliver cancellation.
    """

    def __init__(
        self, name: Optional[str] = None, timeout: Optional[float] = None
    ) -> None:
        self.name = name
        self._deadline = None
        self._cancelled = False
        self._completed = False
        self._started = False
        self._start_time = None
        self._end_time = None
        self._timeout = timeout
        # Non-reentrant lock guarding the state above. `%r` logging happens outside
        # of it because `__repr__` acquires the same lock.
        self._lock = threading.Lock()
        self._callbacks: list[Callable[[], None]] = []
        super().__init__()

    def __enter__(self):
        with self._lock:
            self._deadline = get_deadline(self._timeout)
            self._started = True
            self._start_time = time.monotonic()

        logger.debug("%r entered", self)
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        with self._lock:
            # A cancelled scope is never also marked as completed.
            if not self._cancelled:
                self._completed = True
            self._end_time = time.monotonic()

        logger.debug("%r exited", self)

    @property
    def timeout(self):
        return self._timeout

    def started(self) -> bool:
        with self._lock:
            return self._started

    def cancelled(self) -> bool:
        with self._lock:
            return self._cancelled

    def timedout(self) -> bool:
        """
        Return True if the scope was cancelled and exited after its deadline.
        """
        with self._lock:
            if not self._end_time or not self._deadline:
                return False
            return self._cancelled and self._end_time > self._deadline

    def set_timeout(self, timeout: float):
        """
        Change the timeout. Raises `RuntimeError` once the scope has started.
        """
        with self._lock:
            if self._started:
                raise RuntimeError("Cannot set timeout after scope has started.")
            self._timeout = timeout

    def completed(self):
        with self._lock:
            return self._completed

    def cancel(self, throw: bool = True) -> bool:
        """
        Cancel this scope.

        If `throw` is not set, this will only mark the scope as cancelled and will not
        throw the cancelled error. (The base class never throws; subclasses use the
        flag.)

        Returns:
            False if the scope already completed. True if it is now, or already was,
            cancelled. Callbacks registered with `add_cancel_callback` run only on the
            first successful cancellation, outside the lock.

        Raises:
            RuntimeError: if the scope has not been entered.
        """
        with self._lock:
            if not self._started:
                raise RuntimeError("Scope has not been entered.")

            if self._completed:
                return False

            if self._cancelled:
                return True

            self._cancelled = True

        logger.info("%r cancelling", self)

        for callback in self._callbacks:
            callback()

        return True

    def add_cancel_callback(self, callback: Callable[[], None]):
        """
        Add a callback to execute on cancellation.
        """
        self._callbacks.append(callback)

    def __repr__(self) -> str:
        with self._lock:
            state = (
                "completed"
                if self._completed
                else (
                    "cancelled"
                    if self._cancelled
                    else "running"
                    if self._started
                    else "pending"
                )
            ).upper()
            timeout = f", timeout={self._timeout:.2f}" if self._timeout else ""
            runtime = (
                f", runtime={(self._end_time or time.monotonic()) - self._start_time:.2f}"
                if self._start_time
                else ""
            )
            # The leading space keeps unnamed scopes readable:
            # `<CancelScope at 0x...>` rather than `<CancelScopeat 0x...>`.
            name = f", name={self.name!r}" if self.name else f" at {hex(id(self))}"
            return f"<{type(self).__name__}{name} {state}{timeout}{runtime}>"

246 

247 

class AsyncCancelScope(CancelScope):
    """
    A cancel scope for asynchronous code that delegates enforcement to an
    `anyio.CancelScope`.

    It must be entered while an event loop is running in the current thread:
    `__enter__` calls `asyncio.get_running_loop()`, which raises `RuntimeError`
    otherwise. Cancellation, whether explicit or by deadline, is surfaced as this
    module's `CancelledError` when the scope exits.
    """

    def __init__(
        self, name: Optional[str] = None, timeout: Optional[float] = None
    ) -> None:
        super().__init__(name=name, timeout=timeout)

    def __enter__(self):
        # Remember the loop so `cancel` can tell whether it runs on that loop or
        # must hand the work over from another thread.
        self.loop = asyncio.get_running_loop()

        super().__enter__()

        # Use anyio as the cancellation enforcer because it's very complicated and they
        # have done a good job
        # `self._deadline` was computed by the base class with `time.monotonic()`;
        # `math.inf` means "no deadline".
        # NOTE(review): assumes anyio's clock (the event loop's `time()`) matches
        # `time.monotonic()`, as it does for the default asyncio loop — confirm for
        # other loop implementations.
        self._anyio_scope = anyio.CancelScope(
            deadline=self._deadline if self._deadline is not None else math.inf
        ).__enter__()

        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """
        Exit the anyio scope and then the base scope.

        Returns:
            False, so any in-flight exception propagates when the scope was not
            cancelled.

        Raises:
            CancelledError: if the scope was cancelled (through `cancel`, or by anyio
                when the deadline passed) and the in-flight exception is not already
                this module's `CancelledError`. It is chained from the original
                exception.
        """
        if self._anyio_scope.cancel_called:
            # Mark as cancelled
            self.cancel(throw=False)

        # TODO: Can we also delete the scope?
        # We have to exit this scope to prevent leaking memory. A fix for
        # issue #10952.
        # The anyio scope's return value is ignored; the cancellation outcome is
        # decided below from `self.cancelled()`.
        self._anyio_scope.__exit__(exc_type, exc_val, exc_tb)

        super().__exit__(exc_type, exc_val, exc_tb)

        if self.cancelled() and exc_type is not CancelledError:
            # Ensure cancellation error is propagated on exit
            raise CancelledError() from exc_val

        return False

    def cancel(self, throw: bool = True) -> bool:
        """
        Cancel the scope.

        The base class `cancel` is called without forwarding `throw`; the base
        implementation does not use it. If the base refuses because the scope already
        completed, this returns False. Otherwise, when `throw` is set, the underlying
        anyio scope is cancelled: directly when called on the scope's own loop, or
        scheduled onto that loop with `call_soon_threadsafe` otherwise.

        Raises:
            RuntimeError: from the base class if the scope has not been entered.
        """
        if not super().cancel():
            return False

        if throw:
            # `get_running_loop` here is the prefect helper imported at module level;
            # presumably it returns a falsy value when no loop is running — verify.
            if self.loop is get_running_loop():
                self._anyio_scope.cancel()
            else:
                # `Task.cancel` is not thread safe
                self.loop.call_soon_threadsafe(self._anyio_scope.cancel)

        return True

299 

300 

class NullCancelScope(CancelScope):
    """
    A cancel scope whose `cancel` is a no-op.

    It is meant for environments where cancellation is not supported. `reason`
    describes why and appears in the warning logged when cancellation is attempted.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        timeout: Optional[float] = None,
        reason: Optional[str] = None,
    ) -> None:
        super().__init__(name=name, timeout=timeout)
        if reason:
            self.reason = reason
        else:
            self.reason = "null cancel scope"

    def cancel(self, throw: bool = True) -> bool:
        # Nothing can be cancelled here; warn and report failure to the caller.
        logger.warning("%r cannot cancel %s.", self, self.reason)
        return False

320 

321 

class AlarmCancelScope(CancelScope):
    """
    A cancel scope that uses an alarm signal which can interrupt long-running system
    calls.

    Only the main thread can be cancelled with an alarm signal, so this scope is only
    available in the main thread.
    """

    # `setitimer` treats a delay of 0 as "clear the timer", so an already-expired
    # timeout (e.g. 0 from an elapsed deadline) would otherwise never fire. Delays are
    # clamped to this value, one microsecond, which is setitimer's resolution.
    _MIN_ALARM_DELAY = 1e-6

    def __enter__(self):
        # Validate before mutating scope state or signal handlers, so a failed entry
        # does not leave the scope marked as started.
        if threading.current_thread() is not threading.main_thread():
            raise ValueError(
                "Alarm based timeouts can only be used in the main thread."
            )

        super().__enter__()
        self._previous_timer = None

        self._previous_alarm_handler = signal.getsignal(signal.SIGALRM)

        if self._previous_alarm_handler != signal.SIG_DFL:
            logger.warning(
                "%r overriding existing alarm handler %s",
                self,
                self._previous_alarm_handler,
            )

        # Capture alarm signals and raise a timeout
        signal.signal(signal.SIGALRM, self._sigalarm_to_error)

        # Set a timer to raise an alarm signal
        if self.timeout is not None:
            # Use `setitimer` instead of `signal.alarm` for float support; raises a
            # SIGALRM. Clamp so a zero (or negative) timeout fires promptly instead of
            # disarming the timer.
            delay = max(self.timeout, self._MIN_ALARM_DELAY)
            logger.debug("%r set alarm timer for %f seconds", self, delay)
            self._previous_timer = signal.setitimer(signal.ITIMER_REAL, delay)

        return self

    def _sigalarm_to_error(self, *args: object) -> None:
        """
        SIGALRM handler: mark the scope cancelled and raise `CancelledError`.

        If the main thread's shield is active, the error is stored on the shield and
        raised when the shield is released. If the scope already completed, `cancel`
        returns False and nothing is raised.
        """
        logger.debug("%r captured alarm raising as cancelled error", self)
        if self.cancel(throw=False):
            shield = _get_thread_shield(threading.main_thread())
            if shield.active():
                logger.debug("%r thread shield active; delaying exception", self)
                shield.set_exception(CancelledError())
            else:
                raise CancelledError()

    def __exit__(self, *_: Any) -> Optional[bool]:
        # Mark the scope completed first, so an alarm arriving during cleanup is
        # ignored by `_sigalarm_to_error`.
        retval = super().__exit__(*_)

        if self.timeout is not None:
            # Restore the previous timer
            if TYPE_CHECKING:
                assert self._previous_timer is not None
            signal.setitimer(signal.ITIMER_REAL, *self._previous_timer)

        # Restore the previous signal handler
        signal.signal(signal.SIGALRM, self._previous_alarm_handler)

        return retval

    def cancel(self, throw: bool = True) -> bool:
        """
        Cancel the scope. When `throw` is set, deliver SIGALRM to this process so that
        the handler raises in the main thread.
        """
        if not super().cancel():
            return False

        if throw:
            logger.debug("%r sending alarm signal to main thread", self)
            os.kill(os.getpid(), signal.SIGALRM)

        return True

395 

396 

class WatcherThreadCancelScope(CancelScope):
    """
    A cancel scope that uses a watcher thread and an injected exception to enforce
    cancellation.

    The injected exception cannot interrupt calls and will be raised on the ~next
    instruction. This can raise exceptions in unexpected places. See `shield` for
    guarding against interruption.

    If a timeout is specified, a watcher thread is spawned that will run for `timeout`
    seconds then send the exception to the supervised thread.
    """

    def __enter__(self):
        super().__enter__()
        # Set on exit; wakes the enforcer thread early and releases its final wait.
        self._event = threading.Event()
        self._enforcer_thread = None
        # The thread that entered the scope is the one that receives the exception.
        self._supervised_thread = threading.current_thread()

        if self.timeout is not None:
            name = self.name or f"for scope {hex(id(self))}"
            self._enforcer_thread = threading.Thread(
                target=self._timeout_enforcer,
                name=f"timeout-watcher {name} {self.timeout:.2f}",
            )
            self._enforcer_thread.start()

        return self

    def __exit__(self, *_: Any) -> Optional[bool]:
        retval = super().__exit__(*_)
        # Wake the enforcer so it stops waiting for the timeout.
        self._event.set()
        if self._enforcer_thread:
            logger.debug("%r joining enforcer thread %r", self, self._enforcer_thread)
            # NOTE(review): if the timeout already fired and the enforcer is blocked
            # acquiring this thread's shield while this thread still holds it (e.g. the
            # scope is nested inside `shield()`), this join looks like it can
            # deadlock — confirm.
            self._enforcer_thread.join()
        return retval

    def _send_cancelled_error(self):
        """
        Send a cancelled error to the supervised thread.

        Does nothing if the supervised thread is no longer alive. The send happens
        while holding the supervised thread's shield, so it waits until that thread
        is not inside a `shield()` block. A `ValueError` from the injection (thread
        missing) is logged and ignored.
        """
        if self._supervised_thread.is_alive():
            logger.debug(
                "%r sending exception to supervised thread %r",
                self,
                self._supervised_thread,
            )
            with _get_thread_shield(self._supervised_thread):
                try:
                    _send_exception_to_thread(self._supervised_thread, CancelledError)
                except ValueError:
                    # If the thread is gone; just move on without error
                    logger.debug("Thread missing!")

    def _timeout_enforcer(self):
        """
        Target for a thread that enforces a timeout.

        If the event is not set within `timeout` seconds, the scope is marked
        cancelled and the error is injected into the supervised thread. Either way,
        the thread then waits for the event, i.e. until the supervised thread exits
        the scope.
        """
        if not self._event.wait(self.timeout):
            logger.debug("%r enforcer detected timeout!", self)
            if self.cancel(throw=False):
                # The shield is reentrant, so `_send_cancelled_error` re-acquiring it
                # from this same thread does not block.
                with _get_thread_shield(self._supervised_thread):
                    self._send_cancelled_error()

        # Wait for the supervised thread to exit its context
        logger.debug("%r waiting for supervised thread to exit", self)
        self._event.wait()

    def cancel(self, throw: bool = True) -> bool:
        """
        Cancel the scope. When `throw` is set, inject `CancelledError` into the
        supervised thread.

        Returns False if the base class refuses because the scope already completed.
        """
        if not super().cancel():
            return False

        if throw:
            self._send_cancelled_error()

        return True

473 

474 

@overload
def get_deadline(timeout: float) -> float: ...


@overload
def get_deadline(timeout: None) -> None: ...


def get_deadline(timeout: Optional[float]) -> Optional[float]:
    """
    Convert a relative timeout into an absolute deadline.

    The deadline is expressed on the `time.monotonic()` clock. A `None` timeout
    yields a `None` deadline.
    """
    return None if timeout is None else time.monotonic() + timeout

493 

494 

def get_timeout(deadline: Optional[float]):
    """
    Convert an absolute deadline into the time remaining until it passes.

    The deadline must be on the `time.monotonic()` clock. An expired deadline gives
    0, and a `None` deadline gives `None`.
    """
    return None if deadline is None else max(0, deadline - time.monotonic())

505 

506 

@contextlib.contextmanager
def cancel_async_at(deadline: Optional[float], name: Optional[str] = None):
    """
    Cancel async calls inside the context if it has not exited by `deadline`.

    `deadline` is an absolute time on the monotonic clock (see `get_deadline`); `None`
    means no deadline. It is converted to a relative timeout with `get_timeout` and
    handed to `cancel_async_after`.

    When the deadline expires, cancellation is delivered at the next `await`.

    Yields the cancel scope entered by `cancel_async_after`.
    """
    remaining = get_timeout(deadline)
    with cancel_async_after(remaining, name=name) as scope:
        yield scope

520 

521 

@contextlib.contextmanager
def cancel_async_after(timeout: Optional[float], name: Optional[str] = None):
    """
    Cancel async calls inside the context if it has not exited after `timeout`
    seconds (`None` means no timeout).

    When the timeout expires, cancellation is delivered at the next `await`.

    Yields the entered `AsyncCancelScope`.
    """
    scope = AsyncCancelScope(timeout=timeout, name=name)
    # `AsyncCancelScope.__enter__` returns the scope itself.
    with scope:
        yield scope

534 

535 

@contextlib.contextmanager
def cancel_sync_at(deadline: Optional[float], name: Optional[str] = None):
    """
    Cancel any sync calls within the context if it does not exit by the given deadline.

    Deadlines must be computed with the monotonic clock. See `get_deadline`. A `None`
    deadline means no timeout.

    The cancel method varies depending on if this is called in the main thread or not.
    See `cancel_sync_after` for details.

    Yields the cancel scope entered by `cancel_sync_after`.
    """
    # Reuse `get_timeout` (as `cancel_async_at` does) instead of duplicating the
    # deadline-to-timeout conversion; an expired deadline becomes a timeout of 0.
    with cancel_sync_after(get_timeout(deadline), name=name) as ctx:
        yield ctx

552 

553 

@contextlib.contextmanager
def cancel_sync_after(timeout: Optional[float], name: Optional[str] = None):
    """
    Cancel any sync calls within the context if it does not exit after the given
    timeout.

    The timeout method varies depending on if this is called in the main thread or not.
    See `AlarmCancelScope` and `WatcherThreadCancelScope` for details.

    - On Windows, a `NullCancelScope` is used; it cannot cancel anything.
    - In the main thread, when a timeout is given and no SIGALRM handler is already
      installed, an `AlarmCancelScope` is used.
    - Otherwise, a `WatcherThreadCancelScope` is used.

    Yields the entered cancel scope.
    """

    if sys.platform.startswith("win"):
        # SIGALRM is unavailable on Windows, so return before touching it. Carry the
        # requested name/timeout and enter the scope like the other branches do, so
        # its state (`started`, `completed`) and repr are consistent.
        null_scope = NullCancelScope(
            name=name,
            timeout=timeout,
            reason="cancellation is not supported on Windows",
        )
        with null_scope:
            yield null_scope
        return

    thread = threading.current_thread()
    # `getsignal` returns None for handlers installed outside Python; that also
    # counts as an existing handler here.
    existing_alarm_handler = signal.getsignal(signal.SIGALRM) != signal.SIG_DFL

    if (
        thread is threading.main_thread()
        # Avoid nested alarm handlers; it's hard to follow and they will interfere with
        # each other
        and not existing_alarm_handler
        # Avoid using an alarm when there is no timeout; it's better saved for that case
        and timeout is not None
    ):
        scope = AlarmCancelScope(name=name, timeout=timeout)
    else:
        scope = WatcherThreadCancelScope(name=name, timeout=timeout)

    with scope:
        yield scope

587 

588 

589def _send_exception_to_thread(thread: threading.Thread, exc_type: type[BaseException]): 1a

590 """ 

591 Raise an exception in a thread. 

592 

593 This will not interrupt long-running system calls like `sleep` or `wait`. 

594 """ 

595 if not thread.ident: 

596 raise ValueError("Thread is not started.") 

597 ret = ctypes.pythonapi.PyThreadState_SetAsyncExc( 

598 ctypes.c_long(thread.ident), ctypes.py_object(exc_type) 

599 ) 

600 if ret == 0: 

601 raise ValueError("Thread not found.")