# prefect/_internal/concurrency/services.py
from __future__ import annotations

import abc
import asyncio
import concurrent.futures
import contextlib
import logging
import os
import queue
import threading
import weakref
from collections.abc import AsyncGenerator, Awaitable, Coroutine, Generator, Hashable
from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union, cast

from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack

from prefect._internal.concurrency import logger
from prefect._internal.concurrency.api import create_call, from_sync
from prefect._internal.concurrency.cancellation import get_deadline, get_timeout
from prefect._internal.concurrency.event_loop import get_running_loop
from prefect._internal.concurrency.threads import WorkerThread, get_global_loop

T = TypeVar("T")
Ts = TypeVarTuple("Ts")
R = TypeVar("R", infer_variance=True)

# Track all active services for fork handling
_active_services: weakref.WeakSet[_QueueServiceBase[Any]] = weakref.WeakSet()


def _reset_services_after_fork():
    """
    Reset service state after fork() to prevent multiprocessing deadlocks on Linux.

    Called by os.register_at_fork() in the child process after fork().
    """
    for service in list(_active_services):
        service.reset_for_fork()

    # Reset the class-level instance tracking
    _QueueServiceBase.reset_instances_for_fork()


# Register fork handler if supported (POSIX systems)
if hasattr(os, "register_at_fork"):
    try:
        os.register_at_fork(after_in_child=_reset_services_after_fork)
    except RuntimeError as e:
        # Might fail in certain contexts (e.g., if already in a child process)
        logger.debug(
            "failed to register fork handler: %s (this may occur in child processes)",
            e,
        )


class _QueueServiceBase(abc.ABC, Generic[T]):
    _instances: dict[int, Self] = {}
    _instance_lock = threading.Lock()

    def __init__(self, *args: Hashable) -> None:
        self._queue: queue.Queue[Optional[T]] = queue.Queue()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._done_event: Optional[asyncio.Event] = None
        self._task: Optional[asyncio.Task[None]] = None
        self._stopped: bool = False
        self._started: bool = False
        self._key = hash((self.__class__, *args))
        self._lock = threading.Lock()
        self._queue_get_thread = WorkerThread(
            # TODO: This thread should not need to be a daemon but when it is not, it
            # can prevent the interpreter from exiting.
            daemon=True,
            name=f"{type(self).__name__}Thread",
        )
        self._logger = logging.getLogger(f"{type(self).__name__}")

        # Track this instance for fork handling
        _active_services.add(self)

    def reset_for_fork(self) -> None:
        """Reset instance state after fork() to prevent deadlocks in the child process."""
        self._stopped = True
        self._started = False
        self._loop = None
        self._done_event = None
        self._task = None
        self._queue = queue.Queue()
        self._lock = threading.Lock()

    @classmethod
    def reset_instances_for_fork(cls) -> None:
        """Reset class-level state after fork() to prevent deadlocks in the child process."""
        cls._instances.clear()
        cls._instance_lock = threading.Lock()

    def start(self) -> None:
        logger.debug("Starting service %r", self)
        loop_thread = get_global_loop()

        if asyncio.get_running_loop() != getattr(loop_thread, "_loop"):
            raise RuntimeError("Services must run on the global loop thread.")

        self._loop = asyncio.get_running_loop()
        self._done_event = asyncio.Event()
        self._task = self._loop.create_task(self._run())
        self._queue_get_thread.start()
        self._started = True

        # Ensure that we wait for worker completion before loop thread shutdown
        loop_thread.add_shutdown_call(create_call(self.drain))

        # Stop at interpreter exit by default.
        # Handling items may require spawning a thread, and in 3.9 new threads
        # cannot be spawned after the interpreter finalizes threads, which
        # happens _before_ the normal `atexit` hook is called, resulting in
        # failure to process items. This is particularly relevant for services
        # which use an httpx client. See the related issue at
        # https://github.com/python/cpython/issues/86813
        threading._register_atexit(self._at_exit)  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]

    def _at_exit(self) -> None:
        self.drain(at_exit=True)

    def _stop(self, at_exit: bool = False) -> None:
        """
        Stop running this instance.

        Does not wait for the instance to finish. See `drain`.
        """

        if self._stopped:
            return

        with self._lock:
            if not at_exit:  # The logger may not be available during interpreter exit
                logger.debug("Stopping service %r", self)

            # Stop sending work to this instance
            self._remove_instance()
            self._stopped = True

            # Allow asyncio task to be garbage-collected. Its context may contain
            # references to all Prefect Task calls made during a flow run, through
            # EngineContext. Issue #10338.
            self._task = None

            # Signal completion to the loop
            self._queue.put_nowait(None)

    @abc.abstractmethod
    def send(self, item: Any) -> Any:
        raise NotImplementedError

    async def _run(self) -> None:
        try:
            async with self._lifespan():
                await self._main_loop()
        except BaseException:
            self._remove_instance()
            # The logging call yields to another thread, so we must remove the instance
            # before reporting the failure to prevent retrieval of a dead instance
            log_traceback = logger.isEnabledFor(logging.DEBUG)
            logger.error(
                "Service %r failed with %s pending items.",
                type(self).__name__,
                self._queue.qsize(),
                exc_info=log_traceback,
            )
        finally:
            self._remove_instance()

            # Shutdown the worker thread
            self._queue_get_thread.shutdown()

            self._stopped = True
            assert self._done_event is not None
            self._done_event.set()

    async def _main_loop(self) -> None:
        last_log_time = 0
        log_interval = 4  # log every 4 seconds

        while True:
            item: Optional[T] = await self._queue_get_thread.submit(
                create_call(self._queue.get)
            ).aresult()

            if self._stopped:
                current_time = asyncio.get_event_loop().time()
                queue_size = self._queue.qsize()

                if current_time - last_log_time >= log_interval and queue_size > 0:
                    self._logger.warning(
                        f"Still processing items: {queue_size} items remaining..."
                    )
                    last_log_time = current_time

            if item is None:
                logger.debug("Exiting service %r", self)
                self._queue.task_done()
                break

            try:
                logger.debug("Service %r handling item %r", self, item)
                await self._handle(item)
            except Exception:
                log_traceback = logger.isEnabledFor(logging.DEBUG)
                logger.error(
                    "Service %r failed to process item %r",
                    type(self).__name__,
                    item,
                    exc_info=log_traceback,
                )
            finally:
                self._queue.task_done()

    @abc.abstractmethod
    async def _handle(self, item: Any) -> Any:
        raise NotImplementedError

    @contextlib.asynccontextmanager
    async def _lifespan(self) -> AsyncGenerator[None, Any]:
        """
        Perform any setup and teardown for the service.
        """
        yield

    def _drain(self, at_exit: bool = False) -> concurrent.futures.Future[bool]:
        """
        Internal implementation for `drain`. Returns a future for sync/async interfaces.
        """
        if not at_exit:  # The logger may not be available during interpreter exit
            logger.debug("Draining service %r", self)

        self._stop(at_exit=at_exit)

        assert self._done_event is not None
        if self._done_event.is_set():
            future: concurrent.futures.Future[bool] = concurrent.futures.Future()
            future.set_result(False)
            return future

        assert self._loop is not None
        task = cast(Coroutine[Any, Any, bool], self._done_event.wait())
        return asyncio.run_coroutine_threadsafe(task, self._loop)

    def drain(self, at_exit: bool = False) -> Union[bool, Awaitable[bool]]:
        """
        Stop this instance of the service and wait for remaining work to be completed.

        Returns an awaitable if called from an async context.
        """
        future = self._drain(at_exit=at_exit)
        if get_running_loop() is not None:
            return asyncio.wrap_future(future)
        else:
            return future.result()
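
    # Usage sketch (the service instance `svc` is hypothetical): `drain`
    # adapts to the calling context.
    #
    #     svc.drain()          # sync caller: blocks, then returns a bool
    #     await svc.drain()    # async caller: returns an awaitable instead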

    @classmethod
    def drain_all(
        cls, timeout: Optional[float] = None, at_exit: bool = True
    ) -> Union[
        tuple[
            set[concurrent.futures.Future[bool]], set[concurrent.futures.Future[bool]]
        ],
        Coroutine[
            Any,
            Any,
            Optional[tuple[set[asyncio.Future[bool]], set[asyncio.Future[bool]]]],
        ],
    ]:
        """
        Stop all instances of the service and wait for all remaining work to be
        completed.

        Returns an awaitable if called from an async context.
        """
        futures: list[concurrent.futures.Future[bool]] = []
        with cls._instance_lock:
            instances = tuple(cls._instances.values())

        for instance in instances:
            futures.append(instance._drain(at_exit=at_exit))

        if get_running_loop() is not None:
            if futures:
                return asyncio.wait(
                    [asyncio.wrap_future(fut) for fut in futures], timeout=timeout
                )
            # `wait` errors if it receives an empty list, but we still need to
            # return a coroutine
            return asyncio.sleep(0)
        else:
            return concurrent.futures.wait(futures, timeout=timeout)

    def wait_until_empty(self) -> None:
        """
        Wait until the queue is empty and all items have been processed.
        """
        self._queue.join()

    @classmethod
    def instance(cls, *args: Hashable) -> Self:
        """
        Get an instance of the service.

        If an instance already exists with the given arguments, it will be returned.
        """
        with cls._instance_lock:
            key = hash((cls, *args))
            if key not in cls._instances:
                cls._instances[key] = cls._new_instance(*args)

            return cls._instances[key]
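
    # Usage sketch: instances are memoized per (class, args) key, so repeated
    # calls with the same arguments return the same running service
    # (`MyService` below is a hypothetical subclass):
    #
    #     a = MyService.instance("tenant-a")
    #     assert a is MyService.instance("tenant-a")
    #     assert a is not MyService.instance("tenant-b")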

    def _remove_instance(self):
        self._instances.pop(self._key, None)

    @classmethod
    def _new_instance(cls, *args: Hashable) -> Self:
        """
        Create and start a new instance of the service.
        """
        instance = cls(*args)

        # If already on the global loop, just start it here to avoid deadlock
        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()

        # Otherwise, bind the service to the global loop
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance


class QueueService(_QueueServiceBase[T]):
    def send(self, item: T) -> None:
        """
        Send an item to this instance of the service.
        """
        with self._lock:
            if self._stopped:
                raise RuntimeError("Cannot put items in a stopped service instance.")

            logger.debug("Service %r enqueuing item %r", self, item)
            self._queue.put_nowait(self._prepare_item(item))

    def _prepare_item(self, item: T) -> T:
        """
        Prepare an item for submission to the service. This is called before
        the item is sent to the service.

        The default implementation returns the item unchanged.
        """
        return item

    @abc.abstractmethod
    async def _handle(self, item: T) -> None:
        """
        Process an item sent to the service.
        """
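

# A minimal sketch of a concrete `QueueService`, added for illustration only:
# subclasses implement the async `_handle` hook and callers submit work with
# `send`. The class name and log message below are hypothetical.
class _ExampleLogService(QueueService[str]):
    """Toy service that logs each string submitted to it."""

    async def _handle(self, item: str) -> None:
        # Runs on the global loop thread, one item at a time, in FIFO order.
        logger.info("processed: %s", item)

# Usage sketch:
#     service = _ExampleLogService.instance()
#     service.send("hello")  # enqueues without blocking
#     service.drain()        # stops the service and waits for remaining items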


class FutureQueueService(
    _QueueServiceBase[tuple[Unpack[Ts], concurrent.futures.Future[R]]]
):
    """Queued service that provides a future that is signalled with the
    acquired result for each item.

    If there was a failure acquiring, the future result is set to the exception.

    Type Parameters:
        Ts: the tuple of types that make up sent arguments
        R: the type returned for each item once acquired
    """

    async def _handle(
        self, item: tuple[Unpack[Ts], concurrent.futures.Future[R]]
    ) -> None:
        send_item, future = item[:-1], item[-1]
        try:
            response = await self.acquire(*send_item)
        except Exception as exc:
            # If the request to the increment endpoint fails in a non-standard
            # way, we need to set the future's result so it'll be re-raised in
            # the context of the caller.
            future.set_exception(exc)
            raise exc
        else:
            future.set_result(response)

    @abc.abstractmethod
    async def acquire(self, *args: Unpack[Ts]) -> R:
        raise NotImplementedError

    def send(self, item: tuple[Unpack[Ts]]) -> concurrent.futures.Future[R]:
        with self._lock:
            if self._stopped:
                raise RuntimeError("Cannot put items in a stopped service instance.")

            logger.debug("Service %r enqueuing item %r", self, item)
            future: concurrent.futures.Future[R] = concurrent.futures.Future()
            self._queue.put_nowait((*self._prepare_item(item), future))

        return future

    def _prepare_item(self, item: tuple[Unpack[Ts]]) -> tuple[Unpack[Ts]]:
        """
        Prepare an item for submission to the service. This is called before
        the item is sent to the service.

        The default implementation returns the item unchanged.
        """
        return item
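

# A minimal sketch of a concrete `FutureQueueService`, added for illustration
# only: `acquire` receives the unpacked item tuple, and its return value (or
# raised exception) resolves the per-item future handed back by `send`. The
# class name below is hypothetical.
class _ExampleSquareService(FutureQueueService[int, int]):
    async def acquire(self, value: int) -> int:
        # An exception raised here would be set on the caller's future.
        return value * value

# Usage sketch:
#     service = _ExampleSquareService.instance()
#     future = service.send((3,))  # concurrent.futures.Future[int]
#     assert future.result() == 9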


class BatchedQueueService(QueueService[T]):
    """
    A queue service that handles a batch of items instead of a single item at a time.

    Items will be processed when the batch reaches the configured `_max_batch_size`
    or after an interval of `_min_interval` seconds (if set).
    """

    _max_batch_size: int
    _min_interval: Optional[float] = None

    @property
    def min_interval(self) -> float | None:
        return self.__class__._min_interval

    @property
    def max_batch_size(self) -> int:
        return self.__class__._max_batch_size

    async def _main_loop(self):
        done = False

        while not done:
            batch: list[T] = []
            batch_size = 0

            # Pull items from the queue until we reach the batch size
            deadline = get_deadline(self.min_interval)
            while batch_size < self.max_batch_size:
                try:
                    item = await self._queue_get_thread.submit(
                        create_call(self._queue.get, timeout=get_timeout(deadline))
                    ).aresult()

                    if item is None:
                        done = True
                        break

                    batch.append(item)
                    batch_size += self._get_size(item)
                    logger.debug(
                        "Service %r added item %r to batch (size %s/%s)",
                        self,
                        item,
                        batch_size,
                        self.max_batch_size,
                    )
                except queue.Empty:
                    # Process the batch after `min_interval` even if it is smaller than
                    # the batch size
                    break

            if not batch:
                continue

            logger.debug(
                "Service %r processing batch of size %s",
                self,
                batch_size,
            )
            try:
                await self._handle_batch(batch)
            except Exception:
                log_traceback = logger.isEnabledFor(logging.DEBUG)
                logger.error(
                    "Service %r failed to process batch of size %s",
                    self,
                    batch_size,
                    exc_info=log_traceback,
                )

    @abc.abstractmethod
    async def _handle_batch(self, items: list[T]) -> None:
        """
        Process a batch of items sent to the service.
        """

    async def _handle(self, item: T) -> NoReturn:
        raise AssertionError(
            "`_handle` should never be called for batched queue services"
        )

    def _get_size(self, item: T) -> int:
        """
        Calculate the size of a single item.
        """
        # By default, batch size is just the number of items
        return 1
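

# A minimal sketch of a concrete `BatchedQueueService`, added for illustration
# only: items are buffered, and `_handle_batch` fires once ten items are queued
# or two seconds have elapsed, whichever comes first. The class name and limits
# below are hypothetical.
class _ExampleBatchLogService(BatchedQueueService[str]):
    _max_batch_size = 10
    _min_interval = 2.0

    async def _handle_batch(self, items: list[str]) -> None:
        logger.info("processed a batch of %s items", len(items))

# Usage sketch:
#     service = _ExampleBatchLogService.instance()
#     for line in ("a", "b", "c"):
#         service.send(line)
#     service.drain()  # a partial batch is still flushed on shutdown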


@contextlib.contextmanager
def drain_on_exit(service: QueueService[Any]) -> Generator[None, Any, None]:
    yield
    service.drain_all(at_exit=True)


@contextlib.asynccontextmanager
async def drain_on_exit_async(service: QueueService[Any]) -> AsyncGenerator[None, Any]:
    yield
    drain_all = service.drain_all(at_exit=True)
    if TYPE_CHECKING:
        assert not isinstance(drain_all, tuple)
    await drain_all
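

# Usage sketch for the exit helpers above (`svc` is a hypothetical service
# instance):
#
#     with drain_on_exit(svc):
#         svc.send("item")
#     # on exit, `drain_all` stops every instance of the service's class
#
#     async with drain_on_exit_async(svc):
#         svc.send("item")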