Coverage for /usr/local/lib/python3.12/site-packages/prefect/task_runs.py: 24%
113 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1from __future__ import annotations 1a
3import asyncio 1a
4import atexit 1a
5import threading 1a
6import uuid 1a
7from typing import TYPE_CHECKING, Any, Callable, Dict, Optional 1a
9import anyio 1a
10from cachetools import TTLCache 1a
11from typing_extensions import Self 1a
13from prefect._internal.concurrency.api import create_call, from_async, from_sync 1a
14from prefect._internal.concurrency.threads import get_global_loop 1a
15from prefect.client.schemas.objects import TERMINAL_STATES 1a
16from prefect.events.clients import get_events_subscriber 1a
17from prefect.events.filters import EventFilter, EventNameFilter 1a
18from prefect.logging.loggers import get_logger 1a
19from prefect.states import State 1a
21if TYPE_CHECKING: 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true1a
22 import logging
class TaskRunWaiter:
    """
    A service used for waiting for a task run to finish.

    This service listens for task run events and provides a way to wait for a specific
    task run to finish. This is useful for waiting for a task run to finish before
    continuing execution.

    The service is a singleton and must be started before use. The service will
    automatically start when the first instance is created. A single websocket
    connection is used to listen for task run events.

    The service can be used to wait for a task run to finish by calling
    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method
    will return when the task run has finished or the timeout has elapsed. Any
    number of coroutines may wait on the same task run ID concurrently, and any
    number of callbacks may be registered for it with
    `TaskRunWaiter.add_done_callback`; each registered callback is called once.

    The service will automatically stop when the Python process exits or when the
    global loop thread is stopped.

    Example:
    ```python
    import asyncio
    from uuid import uuid4

    from prefect import task
    from prefect.task_engine import run_task_async
    from prefect.task_runs import TaskRunWaiter


    @task
    async def test_task():
        await asyncio.sleep(5)
        print("Done!")


    async def main():
        task_run_id = uuid4()
        asyncio.create_task(run_task_async(task=test_task, task_run_id=task_run_id))

        await TaskRunWaiter.wait_for_task_run(task_run_id)
        print("Task run finished")


    if __name__ == "__main__":
        asyncio.run(main())
    ```
    """

    _instance: Optional[Self] = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self.logger: "logging.Logger" = get_logger("TaskRunWaiter")
        self._consumer_task: "asyncio.Task[None] | None" = None
        # Recently observed terminal states keyed by task run ID (up to 10,000
        # entries, each kept for 600 seconds) so late waiters return immediately.
        self._observed_completed_task_runs: TTLCache[
            uuid.UUID, Optional[State[Any]]
        ] = TTLCache(maxsize=10000, ttl=600)
        # One event per task run ID, shared by every coroutine currently waiting
        # on that ID. `_completion_event_waiters` counts those coroutines so the
        # event is discarded only when the last of them stops waiting.
        self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
        self._completion_event_waiters: Dict[uuid.UUID, int] = {}
        # Pending callbacks per task run ID. A list, so registering a second
        # callback for the same run does not silently drop the first.
        self._completion_callbacks: Dict[uuid.UUID, list[Callable[[], None]]] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._observed_completed_task_runs_lock = threading.Lock()
        # Guards `_completion_events`, `_completion_event_waiters` and
        # `_completion_callbacks`.
        self._completion_events_lock = threading.Lock()
        self._started = False

    def start(self) -> None:
        """
        Start the TaskRunWaiter service.

        Must be called from the global loop thread. Calling it again after a
        successful start is a no-op.

        Raises:
            RuntimeError: If not called from the global loop thread's running loop.
        """
        if self._started:
            return
        self.logger.debug("Starting TaskRunWaiter")
        loop_thread = get_global_loop()

        if not asyncio.get_running_loop() == loop_thread.loop:
            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

        self._loop = loop_thread.loop
        if TYPE_CHECKING:
            assert self._loop is not None

        consumer_started = asyncio.Event()
        self._consumer_task = self._loop.create_task(
            self._consume_events(consumer_started)
        )
        # NOTE(review): the returned future is discarded, so this does not block
        # until the subscriber is connected. Blocking on it here would deadlock,
        # because this code runs on the very loop that must drive the consumer.
        asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)

        loop_thread.add_shutdown_call(create_call(self.stop))
        atexit.register(self.stop)
        self._started = True

    async def _consume_events(self, consumer_started: asyncio.Event):
        """
        Subscribe to terminal task run events and notify waiters and callbacks.

        Sets `consumer_started` once the subscription is open, then processes
        events until the subscription ends or the task is cancelled. Errors
        while handling an individual event are logged and do not stop the loop.

        Args:
            consumer_started: Event set once the subscriber context is entered.
        """
        async with get_events_subscriber(
            filter=EventFilter(
                event=EventNameFilter(
                    name=[
                        f"prefect.task-run.{state.name.title()}"
                        for state in TERMINAL_STATES
                    ],
                )
            )
        ) as subscriber:
            consumer_started.set()
            async for event in subscriber:
                try:
                    self.logger.debug(
                        f"Received event: {event.resource['prefect.resource.id']}"
                    )
                    task_run_id = uuid.UUID(
                        event.resource["prefect.resource.id"].replace(
                            "prefect.task-run.", ""
                        )
                    )

                    # Extract the state from the event
                    # All events should have validated_state since we don't support
                    # new clients with old servers
                    state_data = State.model_validate(
                        {
                            "id": event.id,
                            "timestamp": event.occurred,
                            **event.payload["validated_state"],
                        }
                    )
                    state_data.state_details.task_run_id = task_run_id

                    with self._observed_completed_task_runs_lock:
                        # Cache the state BEFORE notifying anyone: waiters read it
                        # after waking, and `add_done_callback` relies on this
                        # ordering to avoid losing a callback.
                        self._observed_completed_task_runs[task_run_id] = state_data
                    with self._completion_events_lock:
                        # Wake every coroutine waiting on this task run.
                        run_event = self._completion_events.get(task_run_id)
                        if run_event is not None:
                            run_event.set()
                        # Pop so each callback fires once and is not retained
                        # after the run has finished.
                        callbacks = self._completion_callbacks.pop(task_run_id, [])
                    # Run callbacks outside the lock so a callback that calls back
                    # into this service cannot deadlock on the non-reentrant lock.
                    for callback in callbacks:
                        try:
                            callback()
                        except Exception as exc:
                            self.logger.error(
                                "Error running completion callback for task run "
                                f"{task_run_id}: {exc}"
                            )
                except Exception as exc:
                    self.logger.error(f"Error processing event: {exc}")

    def stop(self) -> None:
        """
        Stop the TaskRunWaiter service.

        Cancels the event consumer. If this object is still the registered
        singleton, it is cleared so the next `instance()` call creates a new one.
        """
        self.logger.debug("Stopping TaskRunWaiter")
        if self._consumer_task:
            self._consumer_task.cancel()
            self._consumer_task = None
        # `stop` is registered both with `atexit` and as a loop shutdown call; a
        # late call on a stale instance must not discard a newer singleton.
        if self.__class__._instance is self:
            self.__class__._instance = None
        self._started = False

    @classmethod
    async def wait_for_task_run(
        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
    ) -> Optional[State[Any]]:
        """
        Wait for a task run to finish and return its final state.

        Note this relies on a websocket connection to receive events from the server
        and will not work with an ephemeral server.

        Multiple coroutines may wait on the same task run ID at the same time;
        all of them are woken by the same terminal event.

        Args:
            task_run_id: The ID of the task run to wait for.
            timeout: The maximum time to wait for the task run to
                finish. Defaults to None.

        Returns:
            The final state of the task run if available, None otherwise (for
            example when the timeout elapses first).
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            if task_run_id in instance._observed_completed_task_runs:
                return instance._observed_completed_task_runs[task_run_id]

        # Need to create event in loop thread to ensure it can be set
        # from the loop thread
        new_event = await from_async.wait_for_call_in_loop_thread(
            create_call(asyncio.Event)
        )
        with instance._completion_events_lock:
            # Reuse the event of any concurrent waiter on this task run instead
            # of replacing it (which would leave that waiter hanging), and count
            # ourselves as a waiter so the event is kept until we are done.
            finished_event = instance._completion_events.setdefault(
                task_run_id, new_event
            )
            instance._completion_event_waiters[task_run_id] = (
                instance._completion_event_waiters.get(task_run_id, 0) + 1
            )

        try:
            # Now check one more time whether the task run arrived before we start to
            # wait on it, in case it came in while we were setting up the event above.
            with instance._observed_completed_task_runs_lock:
                if task_run_id in instance._observed_completed_task_runs:
                    return instance._observed_completed_task_runs[task_run_id]

            with anyio.move_on_after(delay=timeout):
                await from_async.wait_for_call_in_loop_thread(
                    create_call(finished_event.wait)
                )

            # After waiting, retrieve the state from the cache
            with instance._observed_completed_task_runs_lock:
                return instance._observed_completed_task_runs.get(task_run_id)
        finally:
            with instance._completion_events_lock:
                remaining = instance._completion_event_waiters.get(task_run_id, 1) - 1
                if remaining > 0:
                    instance._completion_event_waiters[task_run_id] = remaining
                else:
                    # Last waiter for this task run: discard the shared event.
                    instance._completion_event_waiters.pop(task_run_id, None)
                    instance._completion_events.pop(task_run_id, None)

    @classmethod
    def add_done_callback(
        cls, task_run_id: uuid.UUID, callback: Callable[[], None]
    ) -> None:
        """
        Add a callback to be called when a task run finishes.

        If the task run's terminal state has already been observed, the callback
        is called immediately in the calling thread. Otherwise it is called once,
        from the event consumer on the global loop, when the terminal event
        arrives. Several callbacks may be registered for the same task run.

        Args:
            task_run_id: The ID of the task run to wait for.
            callback: The callback to call when the task run finishes.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            already_finished = task_run_id in instance._observed_completed_task_runs
        if already_finished:
            # Called outside the lock so the callback may use this service.
            callback()
            return

        with instance._completion_events_lock:
            # Cache the callback for the task run ID so the consumer can call it
            # when the event is received
            instance._completion_callbacks.setdefault(task_run_id, []).append(callback)

        # The terminal event may have been processed between the check above and
        # the registration, in which case the consumer never saw this callback.
        # Check again: whichever side removes the callback from the pending list
        # first is the one that runs it, so it fires once.
        with instance._observed_completed_task_runs_lock:
            finished_meanwhile = task_run_id in instance._observed_completed_task_runs
        if finished_meanwhile and instance._claim_callback(task_run_id, callback):
            callback()

    def _claim_callback(
        self, task_run_id: uuid.UUID, callback: Callable[[], None]
    ) -> bool:
        """
        Remove one pending registration of `callback` (matched by identity).

        Returns:
            True if this call removed it (the caller must now run it), False if
            it was no longer pending (the consumer already took it).
        """
        with self._completion_events_lock:
            pending = self._completion_callbacks.get(task_run_id)
            if not pending:
                return False
            for index, registered in enumerate(pending):
                if registered is callback:
                    del pending[index]
                    if not pending:
                        del self._completion_callbacks[task_run_id]
                    return True
            return False

    @classmethod
    def instance(cls) -> Self:
        """
        Get the singleton instance of TaskRunWaiter, creating and starting it on
        first use.
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls._new_instance()
            return cls._instance

    @classmethod
    def _new_instance(cls):
        """
        Create a new instance and run its `start()` on the global loop thread.

        When called from another thread, this blocks until `start()` has run.
        """
        instance = cls()

        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance