Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/threads.py: 26%
193 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1"""
2Utilities for managing worker threads.
3"""
5from __future__ import annotations 1a
7import asyncio 1a
8import atexit 1a
9import concurrent.futures 1a
10import itertools 1a
11import os 1a
12import queue 1a
13import threading 1a
14import weakref 1a
15from typing import Any, Optional 1a
17from typing_extensions import TypeVar 1a
19from prefect._internal.concurrency import logger 1a
20from prefect._internal.concurrency.calls import Call, Portal 1a
21from prefect._internal.concurrency.cancellation import CancelledError 1a
22from prefect._internal.concurrency.event_loop import get_running_loop 1a
23from prefect._internal.concurrency.primitives import Event 1a
25T = TypeVar("T", infer_variance=True) 1a
# Weak registry of every live portal, walked by the fork handler in the child
# process. Weak references mean the registry itself never keeps a portal alive.
_active_instances: weakref.WeakSet[WorkerThread | EventLoopThread] = weakref.WeakSet()
def _reset_after_fork_in_child():
    """
    Reset every tracked portal in a freshly forked child process.

    After fork() only the calling thread continues in the child, yet the child
    inherits the thread state and locks of all the parent's threads. Locks that
    were held by threads which no longer exist would never be released, which
    leads to deadlocks (observed with multiprocessing on Linux).

    Registered through os.register_at_fork() to run in the child after fork().
    """
    # Snapshot first: the weak set may change size while we iterate.
    tracked = list(_active_instances)
    for portal in tracked:
        portal.reset_for_fork()
# Install the child-side fork handler where the platform provides
# os.register_at_fork (POSIX systems).
if hasattr(os, "register_at_fork"):
    try:
        os.register_at_fork(after_in_child=_reset_after_fork_in_child)
    except RuntimeError:
        # Registration can fail in some contexts (e.g. when already running in
        # a child process); fork handling is best-effort, so carry on.
        pass
class WorkerThread(Portal):
    """
    A portal to a worker running on a thread.

    Submitted calls are placed on an internal queue and run one at a time, in
    submission order, on a single thread. The thread is started on the first
    `submit()` (or explicitly via `start()` or by entering the context manager)
    and exits when it dequeues the `None` sentinel enqueued by `shutdown()`.

    Args:
        name: Name for the underlying thread; defaults to `WorkerThread-<n>`
            using a class-level counter.
        daemon: Whether the underlying thread is a daemon thread. For
            non-daemon workers, `shutdown` is registered with `atexit`.
        run_once: If set, only a single call may be submitted and the worker
            is shut down when that call's future completes.
    """

    # Used for unique thread names by default
    _counter = itertools.count().__next__

    def __init__(
        self, name: Optional[str] = None, daemon: bool = False, run_once: bool = False
    ):
        name = name or f"WorkerThread-{self._counter()}"

        self.thread = threading.Thread(
            name=name, daemon=daemon, target=self._entrypoint
        )
        # `None` on the queue is the shutdown sentinel
        self._queue: queue.Queue[Optional[Call[Any]]] = queue.Queue()
        self._run_once: bool = run_once
        self._started: bool = False
        self._submitted_count: int = 0
        self._lock = threading.Lock()

        # Track this instance for fork handling
        _active_instances.add(self)

        if not daemon:
            atexit.register(self.shutdown)

    def reset_for_fork(self) -> None:
        """
        Reset state after fork() to prevent deadlocks in child process.

        The queue, lock and counters are replaced with fresh objects. The
        thread object is replaced as well: a `threading.Thread` can only be
        started once, so keeping the parent's already-started thread would make
        the next `start()` / `submit()` in the child raise `RuntimeError`.
        """
        forked_from_worker = self.thread is threading.current_thread()
        if not forked_from_worker:
            # The parent's worker thread does not exist in the child; swap in
            # an unstarted thread with the same name and daemon flag.
            self.thread = threading.Thread(
                name=self.thread.name,
                daemon=self.thread.daemon,
                target=self._entrypoint,
            )
        # If fork() was called from the worker thread itself, that thread is
        # the one that survives in the child and it cannot be started again,
        # so it stays marked as started.
        self._started = forked_from_worker
        self._queue = queue.Queue()
        self._lock = threading.Lock()
        self._submitted_count = 0

    def start(self) -> None:
        """
        Start the worker thread.

        Safe to call repeatedly; the thread is only started once.
        """
        with self._lock:
            if not self._started:
                self._started = True
                self.thread.start()

    def submit(self, call: Call[T]) -> Call[T]:
        """
        Queue a call to run on the worker thread, starting the thread if needed.

        Returns:
            The same `call` that was passed in.

        Raises:
            RuntimeError: If the worker was created with `run_once=True` and a
                call has already been submitted.
        """
        # Start on first submission if not started. This happens before taking
        # the lock because `start()` acquires the same non-reentrant lock.
        if not self._started:
            self.start()

        # The check and the increment are done under the lock so that two
        # concurrent submitters cannot both pass the `run_once` check.
        with self._lock:
            if self._submitted_count > 0 and self._run_once:
                raise RuntimeError(
                    "Worker configured to only run once. A call has already been submitted."
                )

            # Track the portal running the call
            call.set_runner(self)

            # Put the call in the queue
            self._queue.put_nowait(call)

            self._submitted_count += 1

        if self._run_once:
            call.future.add_done_callback(lambda _: self.shutdown())

        return call

    def shutdown(self) -> None:
        """
        Shutdown the worker thread. Does not wait for the thread to stop.

        Calls queued before the shutdown request still run; the thread exits
        when it reaches the sentinel.
        """
        self._queue.put_nowait(None)

    @property
    def name(self) -> str:
        """Name of the underlying thread."""
        return self.thread.name

    def _entrypoint(self) -> None:
        """
        Entrypoint for the thread.

        A `CancelledError` is logged and swallowed; any other exception is
        logged and re-raised.
        """
        try:
            self._run_until_shutdown()
        except CancelledError:
            logger.exception("%s was cancelled", self.name)
        except BaseException:
            # Log exceptions that crash the thread
            logger.exception("%s encountered exception", self.name)
            raise

    def _run_until_shutdown(self):
        """Run queued calls one at a time until the shutdown sentinel is dequeued."""
        while True:
            call = self._queue.get()
            if call is None:
                logger.info("Exiting worker thread %r", self.name)
                break  # shutdown requested

            task = call.run()
            assert task is None  # calls should never return a coroutine in this worker
            # Drop the reference before blocking on the queue again
            del call

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.shutdown()
class EventLoopThread(Portal):
    """
    A portal to a worker running on a thread with an event loop.

    The thread runs `asyncio.run()` on a coroutine that publishes its event
    loop, waits for the shutdown event, and then runs any calls registered with
    `add_shutdown_call()`. Submitted calls are scheduled onto that loop.

    Args:
        name: Name for the underlying thread.
        daemon: Whether the underlying thread is a daemon thread. For
            non-daemon workers, `shutdown` is registered with `atexit`.
        run_once: If set, only a single call may be submitted and the worker
            is shut down when that call's future completes.
    """

    def __init__(
        self,
        name: str = "EventLoopThread",
        daemon: bool = False,
        run_once: bool = False,
    ):
        self.thread = threading.Thread(
            name=name, daemon=daemon, target=self._entrypoint
        )
        # Resolved by the thread once its event loop is available
        self._ready_future: concurrent.futures.Future[bool] = (
            concurrent.futures.Future()
        )
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._shutdown_event: Event = Event()
        self._run_once: bool = run_once
        self._submitted_count: int = 0
        self._on_shutdown: list[Call[Any]] = []
        self._lock = threading.Lock()

        # Track this instance for fork handling
        _active_instances.add(self)

        if not daemon:
            atexit.register(self.shutdown)

    def reset_for_fork(self) -> None:
        """
        Reset state after fork() to prevent deadlocks in child process.

        All synchronization state is replaced with fresh objects. The thread
        object is replaced as well: a `threading.Thread` can only be started
        once, so keeping the parent's already-started thread would make the
        next `start()` / `submit()` in the child raise `RuntimeError`.
        """
        if self.thread is not threading.current_thread():
            # The parent's loop thread does not exist in the child; swap in an
            # unstarted thread with the same name and daemon flag.
            self.thread = threading.Thread(
                name=self.thread.name,
                daemon=self.thread.daemon,
                target=self._entrypoint,
            )
        self._loop = None
        self._ready_future = concurrent.futures.Future()
        self._shutdown_event = Event()
        self._lock = threading.Lock()
        self._submitted_count = 0
        self._on_shutdown = []

    def start(self):
        """
        Start the worker thread; raises any exceptions encountered during startup.

        Blocks until the thread has published its event loop. Safe to call
        repeatedly; the thread is only started while no loop has been set.
        """
        with self._lock:
            if self._loop is None:
                self.thread.start()
                self._ready_future.result()

    def submit(self, call: Call[T]) -> Call[T]:
        """
        Schedule a call on the worker's event loop, starting the thread if needed.

        Returns:
            The same `call` that was passed in.

        Raises:
            RuntimeError: If the worker was created with `run_once=True` and a
                call has already been submitted, or if the worker has been
                shut down.
        """
        if self._loop is None:
            self.start()

        with self._lock:
            if self._submitted_count > 0 and self._run_once:
                raise RuntimeError(
                    "Worker configured to only run once. A call has already been"
                    " submitted."
                )

            if self._shutdown_event.is_set():
                raise RuntimeError("Worker is shutdown.")

            # Track the portal running the call
            call.set_runner(self)

            # Submit the call to the event loop
            assert self._loop is not None
            asyncio.run_coroutine_threadsafe(self._run_call(call), self._loop)
            self._submitted_count += 1

        if self._run_once:
            # Registered after releasing the lock: the callback calls
            # `shutdown()`, which acquires the same non-reentrant lock. With
            # concurrent.futures semantics, a future that is already done runs
            # the callback immediately in this thread, which would deadlock if
            # the lock were still held.
            call.future.add_done_callback(lambda _: self.shutdown())

        return call

    def shutdown(self) -> None:
        """
        Shutdown the worker thread. Does not wait for the thread to stop.
        """
        with self._lock:
            self._shutdown_event.set()

    @property
    def name(self) -> str:
        """Name of the underlying thread."""
        return self.thread.name

    @property
    def running(self) -> bool:
        """Whether shutdown has not been requested yet."""
        return not self._shutdown_event.is_set()

    @property
    def loop(self) -> asyncio.AbstractEventLoop | None:
        """The worker's event loop, or `None` if it has not been set yet."""
        return self._loop

    def _entrypoint(self):
        """
        Entrypoint for the thread.

        Immediately create a new event loop and pass control to `run_until_shutdown`.
        """
        try:
            asyncio.run(self._run_until_shutdown())
        except BaseException as exc:
            # Log exceptions that crash the thread
            logger.exception("%s encountered exception", self.name)
            # If the failure happened before the loop was published, hand the
            # exception to `start()`, which would otherwise wait forever on
            # the ready future.
            if not self._ready_future.done():
                self._ready_future.set_exception(exc)
            raise

    async def _run_until_shutdown(self):
        """Publish the loop, wait for the shutdown event, then run shutdown calls."""
        try:
            self._loop = asyncio.get_running_loop()
            self._ready_future.set_result(True)
        except Exception as exc:
            self._ready_future.set_exception(exc)
            return

        await self._shutdown_event.wait()

        for call in self._on_shutdown:
            await self._run_call(call)

        # Empty the list to allow calls to be garbage collected. Issue #10338.
        self._on_shutdown = []

    async def _run_call(self, call: Call[Any]) -> None:
        """Run a call, awaiting the result if `call.run()` returns something to await."""
        task = call.run()
        if task is not None:
            await task

    def add_shutdown_call(self, call: Call[Any]) -> None:
        """Register a call to run on the loop after shutdown has been requested."""
        self._on_shutdown.append(call)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.shutdown()
# Lazily created singleton used for background services, like logs.
_global_loop: EventLoopThread | None = None
# Lazily created singleton used exclusively for running async functions in a
# sync context via asyncutils.run_sync.
_run_sync_loop: EventLoopThread | None = None
def get_global_loop() -> EventLoopThread:
    """
    Return the global loop thread, creating and starting one when needed.

    A replacement is created on the first call, when the existing thread is no
    longer alive, or when the existing worker has been asked to shut down.
    """
    global _global_loop

    usable = (
        _global_loop is not None
        and _global_loop.thread.is_alive()
        and _global_loop.running
    )
    if not usable:
        _global_loop = EventLoopThread(daemon=True, name="GlobalEventLoopThread")
        _global_loop.start()

    return _global_loop
def in_global_loop() -> bool:
    """
    Report whether the caller is running on the global loop.
    """
    # Only consult the loop when a global loop has been created before, so
    # that a first-time check does not spin one up.
    if _global_loop is not None:
        return get_global_loop().loop == get_running_loop()

    return False
def get_run_sync_loop() -> EventLoopThread:
    """
    Return the run_sync loop thread, creating and starting one when needed.

    A replacement is created on the first call, when the existing thread is no
    longer alive, or when the existing worker has been asked to shut down.
    """
    global _run_sync_loop

    usable = (
        _run_sync_loop is not None
        and _run_sync_loop.thread.is_alive()
        and _run_sync_loop.running
    )
    if not usable:
        _run_sync_loop = EventLoopThread(daemon=True, name="RunSyncEventLoopThread")
        _run_sync_loop.start()

    return _run_sync_loop
def in_run_sync_loop() -> bool:
    """
    Check if called from the run_sync loop.

    Returns:
        True when the currently running event loop is the run_sync loop's
        event loop; False otherwise, including when no run_sync loop has been
        created yet.
    """
    if _run_sync_loop is None:
        # Avoid creating a run_sync loop if there isn't one
        return False

    return get_run_sync_loop().loop == get_running_loop()
def wait_for_global_loop_exit(timeout: Optional[float] = None) -> None:
    """
    Request shutdown of the global loop, then block until its thread exits.

    Args:
        timeout: Maximum number of seconds to wait for the thread to exit;
            `None` waits indefinitely.

    Raises:
        RuntimeError: If called from the global loop's own thread. Shutdown
            has already been requested by the time this is raised.
    """
    global_loop = get_global_loop()
    global_loop.shutdown()

    worker = global_loop.thread
    if worker.ident == threading.get_ident():
        raise RuntimeError("Cannot wait for the loop thread from inside itself.")

    worker.join(timeout)