Coverage for /usr/local/lib/python3.12/site-packages/prefect/task_runs.py: 24%

113 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4import atexit 1a

5import threading 1a

6import uuid 1a

7from typing import TYPE_CHECKING, Any, Callable, Dict, Optional 1a

8 

9import anyio 1a

10from cachetools import TTLCache 1a

11from typing_extensions import Self 1a

12 

13from prefect._internal.concurrency.api import create_call, from_async, from_sync 1a

14from prefect._internal.concurrency.threads import get_global_loop 1a

15from prefect.client.schemas.objects import TERMINAL_STATES 1a

16from prefect.events.clients import get_events_subscriber 1a

17from prefect.events.filters import EventFilter, EventNameFilter 1a

18from prefect.logging.loggers import get_logger 1a

19from prefect.states import State 1a

20 

21if TYPE_CHECKING: 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true1a

22 import logging 

23 

24 

class TaskRunWaiter:
    """
    A service used for waiting for a task run to finish.

    This service listens for task run events and provides a way to wait for a specific
    task run to finish. This is useful for waiting for a task run to finish before
    continuing execution.

    The service is a singleton and must be started before use. The service will
    automatically start when the first instance is created. A single websocket
    connection is used to listen for task run events.

    The service can be used to wait for a task run to finish by calling
    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method
    will return when the task run has finished or the timeout has elapsed.

    Any number of coroutines may wait on the same task run concurrently, and any
    number of callbacks may be registered for the same task run with
    `TaskRunWaiter.add_done_callback`; all of them are notified.

    The service will automatically stop when the Python process exits or when the
    global loop thread is stopped.

    Example:
        ```python
        import asyncio
        from uuid import uuid4

        from prefect import task
        from prefect.task_engine import run_task_async
        from prefect.task_runs import TaskRunWaiter


        @task
        async def test_task():
            await asyncio.sleep(5)
            print("Done!")


        async def main():
            task_run_id = uuid4()
            asyncio.create_task(run_task_async(task=test_task, task_run_id=task_run_id))

            await TaskRunWaiter.wait_for_task_run(task_run_id)
            print("Task run finished")


        if __name__ == "__main__":
            asyncio.run(main())
        ```
    """

    _instance: Optional[Self] = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self.logger: "logging.Logger" = get_logger("TaskRunWaiter")
        self._consumer_task: "asyncio.Task[None] | None" = None
        # Final states of recently finished task runs (bounded size, 10 minute TTL)
        # so callers arriving after completion don't have to wait.
        self._observed_completed_task_runs: TTLCache[
            uuid.UUID, Optional[State[Any]]
        ] = TTLCache(maxsize=10000, ttl=600)
        # One event per pending `wait_for_task_run` call, grouped by task run ID.
        # Each waiter owns its own event so concurrent waiters never clobber or
        # unregister one another.
        self._completion_events: Dict[uuid.UUID, set[asyncio.Event]] = {}
        # Every callback registered for a task run; the list is removed once the
        # callbacks have been invoked.
        self._completion_callbacks: Dict[uuid.UUID, list[Callable[[], None]]] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._observed_completed_task_runs_lock = threading.Lock()
        self._completion_events_lock = threading.Lock()
        self._started = False

    def start(self) -> None:
        """
        Start the TaskRunWaiter service.

        Schedules the event consumer on the global loop and registers `stop` to run
        when the loop thread shuts down and when the interpreter exits. Calling this
        on an already-started instance does nothing.

        Raises:
            RuntimeError: If not called from the global loop thread.
        """
        if self._started:
            return
        self.logger.debug("Starting TaskRunWaiter")
        loop_thread = get_global_loop()

        # `get_running_loop` itself raises when no loop is running; normalize that
        # case to the same, more descriptive error.
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None
        if running_loop is not loop_thread.loop:
            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

        self._loop = loop_thread.loop
        if TYPE_CHECKING:
            assert self._loop is not None

        # The consumer sets this event once its subscriber is open. We are running
        # on the loop thread itself, so we cannot block here waiting for it.
        consumer_started = asyncio.Event()
        self._consumer_task = self._loop.create_task(
            self._consume_events(consumer_started)
        )

        loop_thread.add_shutdown_call(create_call(self.stop))
        atexit.register(self.stop)
        self._started = True

    async def _consume_events(self, consumer_started: asyncio.Event):
        """
        Subscribe to terminal task run events and record each one as it arrives.

        Args:
            consumer_started: Set once the events subscriber has been opened.
        """
        async with get_events_subscriber(
            filter=EventFilter(
                event=EventNameFilter(
                    name=[
                        f"prefect.task-run.{state.name.title()}"
                        for state in TERMINAL_STATES
                    ],
                )
            )
        ) as subscriber:
            consumer_started.set()
            async for event in subscriber:
                try:
                    self.logger.debug(
                        f"Received event: {event.resource['prefect.resource.id']}"
                    )
                    task_run_id = uuid.UUID(
                        event.resource["prefect.resource.id"].replace(
                            "prefect.task-run.", ""
                        )
                    )

                    # Extract the state from the event
                    # All events should have validated_state since we don't support
                    # new clients with old servers
                    state_data = State.model_validate(
                        {
                            "id": event.id,
                            "timestamp": event.occurred,
                            **event.payload["validated_state"],
                        }
                    )
                    state_data.state_details.task_run_id = task_run_id

                    self._record_finished_task_run(task_run_id, state_data)
                except Exception as exc:
                    self.logger.error(f"Error processing event: {exc}")

    def _record_finished_task_run(
        self, task_run_id: uuid.UUID, state: Optional[State[Any]]
    ) -> None:
        """
        Cache a task run's final state and notify everything waiting on it.

        Sets every pending waiter's event. Invokes each registered callback once and
        then forgets it, so callbacks do not accumulate. Callbacks run outside the
        locks so they may call back into this service. A callback that raises is
        logged and does not prevent the remaining callbacks from running.

        Args:
            task_run_id: The ID of the task run that finished.
            state: The task run's final state.
        """
        with self._observed_completed_task_runs_lock:
            # Cache the state for a short period of time to avoid
            # unnecessary waits
            self._observed_completed_task_runs[task_run_id] = state
        with self._completion_events_lock:
            # Wake every coroutine currently waiting on this task run.
            for finished_event in self._completion_events.get(task_run_id, ()):
                finished_event.set()
            callbacks = self._completion_callbacks.pop(task_run_id, [])
        for callback in callbacks:
            try:
                callback()
            except Exception as exc:
                self.logger.error(f"Error in task run completion callback: {exc}")

    def stop(self) -> None:
        """
        Stop the TaskRunWaiter service.

        Cancels the event consumer task, if any, and clears the singleton (when it
        still refers to this instance) so the next `instance()` call creates and
        starts a fresh service.
        """
        self.logger.debug("Stopping TaskRunWaiter")
        if self._consumer_task:
            # NOTE(review): when invoked via atexit this may run off the loop
            # thread, where Task.cancel() is not thread-safe; consider
            # self._loop.call_soon_threadsafe if that proves to be an issue.
            self._consumer_task.cancel()
            self._consumer_task = None
        # A stale instance's shutdown hook must not discard a newer, running one.
        if self.__class__._instance is self:
            self.__class__._instance = None
        self._started = False

    @classmethod
    async def wait_for_task_run(
        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
    ) -> Optional[State[Any]]:
        """
        Wait for a task run to finish and return its final state.

        Note this relies on a websocket connection to receive events from the server
        and will not work with an ephemeral server. Any number of coroutines may wait
        on the same task run concurrently; all of them are woken when it finishes.

        Args:
            task_run_id: The ID of the task run to wait for.
            timeout: The maximum time to wait for the task run to
                finish. Defaults to None.

        Returns:
            The final state of the task run if available, None otherwise.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            if task_run_id in instance._observed_completed_task_runs:
                return instance._observed_completed_task_runs[task_run_id]

        # Need to create event in loop thread to ensure it can be set
        # from the loop thread
        finished_event = await from_async.wait_for_call_in_loop_thread(
            create_call(asyncio.Event)
        )
        with instance._completion_events_lock:
            # Register this waiter's own event alongside any other waiters for the
            # same task run, so the consumer can set all of them.
            instance._completion_events.setdefault(task_run_id, set()).add(
                finished_event
            )

        try:
            # Now check one more time whether the task run arrived before we start to
            # wait on it, in case it came in while we were setting up the event above.
            with instance._observed_completed_task_runs_lock:
                if task_run_id in instance._observed_completed_task_runs:
                    return instance._observed_completed_task_runs[task_run_id]

            with anyio.move_on_after(delay=timeout):
                await from_async.wait_for_call_in_loop_thread(
                    create_call(finished_event.wait)
                )

            # After waiting, retrieve the state from the cache
            with instance._observed_completed_task_runs_lock:
                return instance._observed_completed_task_runs.get(task_run_id)
        finally:
            with instance._completion_events_lock:
                # Unregister only this waiter's event; drop the entry once no
                # waiters remain so the mapping does not grow without bound.
                waiters = instance._completion_events.get(task_run_id)
                if waiters is not None:
                    waiters.discard(finished_event)
                    if not waiters:
                        del instance._completion_events[task_run_id]

    @classmethod
    def add_done_callback(
        cls, task_run_id: uuid.UUID, callback: Callable[[], None]
    ) -> None:
        """
        Add a callback to be called when a task run finishes.

        If the task run's final state has already been observed, the callback is
        called immediately. Otherwise it is called exactly once, by the event
        consumer, when the task run's terminal event arrives. Multiple callbacks may
        be registered for the same task run.

        Args:
            task_run_id: The ID of the task run to wait for.
            callback: The callback to call when the task run finishes.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            already_finished = task_run_id in instance._observed_completed_task_runs
            if not already_finished:
                # Register while still holding the observed-runs lock. The consumer
                # records the final state under that lock *before* it collects
                # callbacks, so a completion cannot slip in between the check above
                # and this registration and leave the callback un-called.
                with instance._completion_events_lock:
                    instance._completion_callbacks.setdefault(task_run_id, []).append(
                        callback
                    )
        if already_finished:
            # Invoked outside the lock so the callback may use this service.
            callback()

    @classmethod
    def instance(cls) -> Self:
        """
        Get the singleton instance of TaskRunWaiter.

        Creates and starts the instance on first use; creation is serialized by
        `_instance_lock`.
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls._new_instance()
            return cls._instance

    @classmethod
    def _new_instance(cls):
        """Create a new instance and start it on the global loop thread."""
        instance = cls()

        # `start` must run on the global loop thread: call it directly when already
        # there, otherwise schedule it there and wait for it to complete.
        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance