Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/threads.py: 26%

193 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1""" 

2Utilities for managing worker threads. 

3""" 

4 

5from __future__ import annotations 1a

6 

7import asyncio 1a

8import atexit 1a

9import concurrent.futures 1a

10import itertools 1a

11import os 1a

12import queue 1a

13import threading 1a

14import weakref 1a

15from typing import Any, Optional 1a

16 

17from typing_extensions import TypeVar 1a

18 

19from prefect._internal.concurrency import logger 1a

20from prefect._internal.concurrency.calls import Call, Portal 1a

21from prefect._internal.concurrency.cancellation import CancelledError 1a

22from prefect._internal.concurrency.event_loop import get_running_loop 1a

23from prefect._internal.concurrency.primitives import Event 1a

24 

# Generic result type carried by the `Call` objects handled in this module.
T = TypeVar("T", infer_variance=True)

# Registry of every worker / event loop thread wrapper that has been created.
# Only weak references are held, so the registry itself never keeps an
# instance alive; it exists so the fork handler can find instances to reset.
_active_instances: weakref.WeakSet[WorkerThread | EventLoopThread] = weakref.WeakSet()

29 

30 

def _reset_after_fork_in_child():
    """
    Reset every tracked thread wrapper in a freshly forked child process.

    After fork() only the calling thread exists in the child, yet the child
    still inherits the locks and bookkeeping that belonged to the parent's other
    threads. Left untouched, that inherited state can deadlock the child (for
    example under multiprocessing on Linux).

    Registered through os.register_at_fork() so that it runs in the child
    process right after fork().
    """
    # Work on a snapshot of the registry rather than the live weak set.
    tracked_instances = list(_active_instances)
    for tracked in tracked_instances:
        tracked.reset_for_fork()

43 

44 

# os.register_at_fork is only available on platforms that support fork()
# (POSIX); on other platforms there is nothing to register.
if hasattr(os, "register_at_fork"):
    try:
        os.register_at_fork(after_in_child=_reset_after_fork_in_child)
    except RuntimeError:
        # Best effort: registration might fail in certain contexts (e.g., if
        # already in a child process). The module must still import cleanly.
        pass

52 

53 

class WorkerThread(Portal):
    """
    A portal to a worker running on a thread.

    Submitted calls are placed on an internal queue and executed one at a time
    by a dedicated thread. A `None` item on the queue is the shutdown sentinel.
    """

    # Used for unique thread names by default
    _counter = itertools.count().__next__

    def __init__(
        self, name: Optional[str] = None, daemon: bool = False, run_once: bool = False
    ):
        """
        Create the worker; the underlying thread is not started yet.

        Args:
            name: Thread name. Defaults to "WorkerThread-<n>" using a class-wide
                counter.
            daemon: Whether the underlying thread is a daemon thread.
            run_once: If set, only a single call may ever be submitted and the
                worker is asked to shut down once that call's future is done.
        """
        name = name or f"WorkerThread-{self._counter()}"

        self.thread = threading.Thread(
            name=name, daemon=daemon, target=self._entrypoint
        )
        self._queue: queue.Queue[Optional[Call[Any]]] = queue.Queue()
        self._run_once: bool = run_once
        self._started: bool = False
        self._submitted_count: int = 0
        self._lock = threading.Lock()

        # Track this instance for fork handling
        _active_instances.add(self)

        if not daemon:
            # Request shutdown at interpreter exit for non-daemon workers
            atexit.register(self.shutdown)

    def reset_for_fork(self) -> None:
        """Reset state after fork() to prevent deadlocks in child process."""
        # The parent's worker thread does not exist in the child, and a
        # `threading.Thread` object can be started at most once. Swap in a fresh
        # thread object so the worker can be started again in the child. A thread
        # that is still alive (i.e. the thread that called fork()) is kept.
        if not self.thread.is_alive():
            self.thread = threading.Thread(
                name=self.thread.name,
                daemon=self.thread.daemon,
                target=self._entrypoint,
            )
        self._started = False
        self._queue = queue.Queue()
        self._lock = threading.Lock()
        self._submitted_count = 0

    def start(self) -> None:
        """
        Start the worker thread.

        Safe to call more than once; only the first call starts the thread.
        """
        with self._lock:
            if not self._started:
                self._started = True
                self.thread.start()

    def submit(self, call: Call[T]) -> Call[T]:
        """
        Queue a call for execution on the worker thread and return it.

        The thread is started on first submission.

        Raises:
            RuntimeError: If the worker was configured with `run_once` and a call
                has already been submitted.
        """
        # Start on first submission if not started. `start` acquires the lock
        # itself, so this must happen before the lock is taken below.
        if not self._started:
            self.start()

        # Check and update the submission counter atomically so concurrent
        # submitters cannot both pass the `run_once` check.
        with self._lock:
            if self._submitted_count > 0 and self._run_once:
                raise RuntimeError(
                    "Worker configured to only run once. A call has already been submitted."
                )

            # Track the portal running the call
            call.set_runner(self)

            # Put the call in the queue
            self._queue.put_nowait(call)

            self._submitted_count += 1

        if self._run_once:
            call.future.add_done_callback(lambda _: self.shutdown())

        return call

    def shutdown(self) -> None:
        """
        Shutdown the worker thread. Does not wait for the thread to stop.
        """
        # `None` is the sentinel that makes the run loop exit
        self._queue.put_nowait(None)

    @property
    def name(self) -> str:
        """The name of the underlying thread."""
        return self.thread.name

    def _entrypoint(self) -> None:
        """
        Entrypoint for the thread.

        A `CancelledError` is logged and swallowed; any other exception is
        logged and re-raised.
        """
        try:
            self._run_until_shutdown()
        except CancelledError:
            logger.exception("%s was cancelled", self.name)
        except BaseException:
            # Log exceptions that crash the thread
            logger.exception("%s encountered exception", self.name)
            raise

    def _run_until_shutdown(self):
        """
        Run queued calls one after another until the shutdown sentinel arrives.
        """
        while True:
            call = self._queue.get()
            if call is None:
                logger.info("Exiting worker thread %r", self.name)
                break  # shutdown requested

            task = call.run()
            assert task is None  # calls should never return a coroutine in this worker
            # Drop the reference before blocking on the queue again
            del call

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.shutdown()

160 

161 

class EventLoopThread(Portal):
    """
    A portal to a worker running on a thread with an event loop.

    The thread runs `asyncio.run` on a coroutine that stays alive until the
    shutdown event is set; submitted calls are scheduled onto that loop.
    """

    def __init__(
        self,
        name: str = "EventLoopThread",
        daemon: bool = False,
        run_once: bool = False,
    ):
        """
        Create the worker; the underlying thread is not started yet.

        Args:
            name: Thread name.
            daemon: Whether the underlying thread is a daemon thread.
            run_once: If set, only a single call may ever be submitted and the
                worker is asked to shut down once that call's future is done.
        """
        self.thread = threading.Thread(
            name=name, daemon=daemon, target=self._entrypoint
        )
        # Resolved by the thread once its event loop is available
        self._ready_future: concurrent.futures.Future[bool] = (
            concurrent.futures.Future()
        )
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._shutdown_event: Event = Event()
        self._run_once: bool = run_once
        self._submitted_count: int = 0
        self._on_shutdown: list[Call[Any]] = []
        self._lock = threading.Lock()

        # Track this instance for fork handling
        _active_instances.add(self)

        if not daemon:
            # Request shutdown at interpreter exit for non-daemon workers
            atexit.register(self.shutdown)

    def reset_for_fork(self) -> None:
        """Reset state after fork() to prevent deadlocks in child process."""
        # The parent's loop thread does not exist in the child, and a
        # `threading.Thread` object can be started at most once. Swap in a fresh
        # thread object so the worker can be started again in the child. A thread
        # that is still alive (i.e. the thread that called fork()) is kept.
        if not self.thread.is_alive():
            self.thread = threading.Thread(
                name=self.thread.name,
                daemon=self.thread.daemon,
                target=self._entrypoint,
            )
        self._loop = None
        self._ready_future = concurrent.futures.Future()
        self._shutdown_event = Event()
        self._lock = threading.Lock()
        self._submitted_count = 0
        self._on_shutdown = []

    def start(self):
        """
        Start the worker thread; raises any exceptions encountered during startup.

        Blocks until the thread has reported that its event loop is ready.
        """
        with self._lock:
            if self._loop is None:
                self.thread.start()
                self._ready_future.result()

    def submit(self, call: Call[T]) -> Call[T]:
        """
        Schedule a call on the worker's event loop and return it.

        The thread is started on first submission.

        Raises:
            RuntimeError: If the worker was configured with `run_once` and a call
                has already been submitted, or if the worker has been shut down.
        """
        if self._loop is None:
            self.start()

        with self._lock:
            if self._submitted_count > 0 and self._run_once:
                raise RuntimeError(
                    "Worker configured to only run once. A call has already been"
                    " submitted."
                )

            if self._shutdown_event.is_set():
                raise RuntimeError("Worker is shutdown.")

            # Track the portal running the call
            call.set_runner(self)

            # Submit the call to the event loop
            assert self._loop is not None
            asyncio.run_coroutine_threadsafe(self._run_call(call), self._loop)
            self._submitted_count += 1

        # Register the shutdown callback only after releasing the lock: the
        # callback calls `shutdown`, which acquires the same non-reentrant lock.
        # NOTE(review): assumes `call.future` follows `concurrent.futures.Future`
        # semantics, where a callback added to a finished future runs immediately
        # in the calling thread - that would deadlock if done under the lock.
        if self._run_once:
            call.future.add_done_callback(lambda _: self.shutdown())

        return call

    def shutdown(self) -> None:
        """
        Shutdown the worker thread. Does not wait for the thread to stop.
        """
        with self._lock:
            self._shutdown_event.set()

    @property
    def name(self) -> str:
        """The name of the underlying thread."""
        return self.thread.name

    @property
    def running(self) -> bool:
        """Whether shutdown has not been requested yet."""
        return not self._shutdown_event.is_set()

    @property
    def loop(self) -> asyncio.AbstractEventLoop | None:
        """The event loop of the thread, or `None` if it has not been set yet."""
        return self._loop

    def _entrypoint(self):
        """
        Entrypoint for the thread.

        Immediately create a new event loop and pass control to `run_until_shutdown`.
        """
        try:
            asyncio.run(self._run_until_shutdown())
        except BaseException:
            # Log exceptions that crash the thread
            logger.exception("%s encountered exception", self.name)
            raise

    async def _run_until_shutdown(self):
        """
        Publish the running loop, wait for shutdown, then run the shutdown calls.
        """
        try:
            self._loop = asyncio.get_running_loop()
            self._ready_future.set_result(True)
        except Exception as exc:
            # Surface startup failures to the thread blocked in `start`
            self._ready_future.set_exception(exc)
            return

        await self._shutdown_event.wait()

        for call in self._on_shutdown:
            await self._run_call(call)

        # Empty the list to allow calls to be garbage collected. Issue #10338.
        self._on_shutdown = []

    async def _run_call(self, call: Call[Any]) -> None:
        """Run a call, awaiting the result if running it produces an awaitable."""
        task = call.run()
        if task is not None:
            await task

    def add_shutdown_call(self, call: Call[Any]) -> None:
        """Register a call to be run on the loop after shutdown is requested."""
        self._on_shutdown.append(call)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.shutdown()

299 

300 

# Lazily created singletons; both start out unset.
#
# The global loop is used for background services, like logs.
_global_loop: Optional[EventLoopThread] = None
# The run_sync loop is used exclusively for running async functions in a sync
# context via asyncutils.run_sync.
_run_sync_loop: Optional[EventLoopThread] = None

305 

306 

def get_global_loop() -> EventLoopThread:
    """
    Return the global loop thread, creating and starting one when needed.

    A new daemon thread named "GlobalEventLoopThread" replaces the current one
    when none exists yet, when the existing thread is no longer alive, or when
    the existing worker has been asked to shut down.
    """
    global _global_loop

    existing = _global_loop
    usable = (
        existing is not None and existing.thread.is_alive() and existing.running
    )
    if not usable:
        _global_loop = EventLoopThread(daemon=True, name="GlobalEventLoopThread")
        _global_loop.start()

    return _global_loop

325 

326 

def in_global_loop() -> bool:
    """
    Report whether the caller is running on the global loop.

    Returns False without side effects when no global loop has been created.
    """
    # Avoid creating a global loop if there isn't one
    if _global_loop is None:
        return False

    loop_thread = get_global_loop()
    return loop_thread.loop == get_running_loop()

336 

337 

def get_run_sync_loop() -> EventLoopThread:
    """
    Return the run_sync loop thread, creating and starting one when needed.

    A new daemon thread named "RunSyncEventLoopThread" replaces the current one
    when none exists yet, when the existing thread is no longer alive, or when
    the existing worker has been asked to shut down.
    """
    global _run_sync_loop

    existing = _run_sync_loop
    usable = (
        existing is not None and existing.thread.is_alive() and existing.running
    )
    if not usable:
        _run_sync_loop = EventLoopThread(daemon=True, name="RunSyncEventLoopThread")
        _run_sync_loop.start()

    return _run_sync_loop

356 

357 

def in_run_sync_loop() -> bool:
    """
    Report whether the caller is running on the run_sync loop.

    Returns False without side effects when no run_sync loop has been created.
    """
    # Avoid creating a run_sync loop if there isn't one
    if _run_sync_loop is None:
        return False

    loop_thread = get_run_sync_loop()
    return loop_thread.loop == get_running_loop()

367 

368 

def wait_for_global_loop_exit(timeout: Optional[float] = None) -> None:
    """
    Shutdown the global loop and wait for it to exit.

    Args:
        timeout: Maximum number of seconds to wait for the loop thread to
            finish; `None` waits without a limit.

    Raises:
        RuntimeError: If called from the global loop thread itself. Shutdown has
            already been requested by the time this is raised.
    """
    global_thread = get_global_loop()
    global_thread.shutdown()

    # A thread cannot join itself
    called_from_loop_thread = threading.get_ident() == global_thread.thread.ident
    if called_from_loop_thread:
        raise RuntimeError("Cannot wait for the loop thread from inside itself.")

    global_thread.thread.join(timeout)