Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/waiters.py: 22%

158 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Implementations of `Waiter`s, which allow work to be sent back to a thread while it 

3waits for the result of the call. 

4""" 

5 

6import abc 1a

7import asyncio 1a

8import contextlib 1a

9import inspect 1a

10import queue 1a

11import threading 1a

12from collections import deque 1a

13from collections.abc import AsyncGenerator, Awaitable, Generator 1a

14from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar 1a

15from weakref import WeakKeyDictionary 1a

16 

17import anyio 1a

18 

19from prefect._internal.concurrency import logger 1a

20from prefect._internal.concurrency.calls import Call, Portal 1a

21from prefect._internal.concurrency.event_loop import call_soon_in_loop 1a

22from prefect._internal.concurrency.primitives import Event 1a

23 

# Type variable for the result type of the call a waiter is waiting on.
T = TypeVar("T")


# Per-thread registry of waiters: each thread maps to a deque of its waiters,
# in registration order. Threads are held as weak keys, so an entry does not
# keep its thread object alive.
_WAITERS_BY_THREAD: "WeakKeyDictionary[threading.Thread, deque[Waiter[Any]]]"
_WAITERS_BY_THREAD = WeakKeyDictionary()

31 

32 

def add_waiter_for_thread(waiter: "Waiter[Any]", thread: threading.Thread) -> None:
    """
    Register `waiter` as the newest waiter for `thread`.

    The waiter is appended to the right end of the thread's deque in the
    module-level registry; the deque is created the first time a thread is seen.
    """
    thread_waiters = _WAITERS_BY_THREAD.setdefault(thread, deque())
    thread_waiters.append(waiter)

41 

42 

class Waiter(Portal, abc.ABC, Generic[T]):
    """
    A waiter allows waiting for a call while routing callbacks to the current
    thread.

    Calls sent back to the waiter will be executed when waiting for the result.
    """

    def __init__(self, call: Call[T]) -> None:
        """
        Create a waiter for `call`, owned by the thread that constructs it.

        As a side effect, the new waiter is registered for the owning thread via
        `add_waiter_for_thread`.

        Raises:
            TypeError: If `call` is not a `Call` instance (checked at runtime only).
        """
        if not TYPE_CHECKING:
            if not isinstance(call, Call):  # Guard against common mistake
                raise TypeError(f"Expected call of type `Call`; got {call!r}.")

        self._call = call
        self._owner_thread = threading.current_thread()

        # Set the waiter for the current thread
        add_waiter_for_thread(self, self._owner_thread)
        super().__init__()

    def call_is_done(self) -> bool:
        """Return whether the future of the waited-on call has completed."""
        return self._call.future.done()

    @abc.abstractmethod
    def wait(self) -> "Call[T] | Awaitable[Call[T]]":
        """
        Wait for the call to finish.

        Watch for and execute any waiting callbacks.

        Returns:
            The waited-on call. Asynchronous implementations return it from a
            coroutine, so the result must be awaited.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def add_done_callback(self, callback: Call[Any]) -> None:
        """
        Schedule a call to run when the waiter is done waiting.

        Raises:
            RuntimeError: If the waiter is already done.
        """
        raise NotImplementedError()

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} call={self._call},"
            f" owner={self._owner_thread.name!r}>"
        )

89 

90 

class SyncWaiter(Waiter[T]):
    # Implementation of `Waiter` for use in synchronous contexts.
    # Callbacks submitted from any thread are handed to the owner thread through a
    # thread-safe `queue.Queue` and executed while the owner is blocked in `wait`.

    def __init__(self, call: Call[T]) -> None:
        super().__init__(call=call)
        # Pending callbacks; a `None` entry tells the owner to stop reading.
        self._queue: queue.Queue[Optional[Call[T]]] = queue.Queue()
        self._done_callbacks: list[Call[Any]] = []
        # Set once the waited-on call's future is done.
        self._done_event = threading.Event()

    def submit(self, call: Call[T]) -> Call[T]:
        """
        Submit a callback to execute while waiting.

        Returns:
            The submitted call.

        Raises:
            RuntimeError: If the waited-on call is already done.
        """
        if self.call_is_done():
            raise RuntimeError(f"The call {self._call} is already done.")

        # Attach the runner before the call becomes visible to the owner thread,
        # so the owner can never dequeue and run a call that has no runner yet.
        call.set_runner(self)
        self._queue.put_nowait(call)
        return call

    def _handle_waiting_callbacks(self) -> None:
        """Run queued callbacks on this thread until a `None` sentinel arrives."""
        logger.debug("Waiter %r watching for callbacks", self)
        while True:
            callback = self._queue.get()
            if callback is None:
                break

            # Ensure that callbacks are cancelled if the parent call is cancelled so
            # waiting never runs longer than the call
            self._call.future.add_cancel_callback(callback.future.cancel)
            callback.run()
            del callback

    @contextlib.contextmanager
    def _handle_done_callbacks(self) -> Generator[None, Any, None]:
        """On exit, run registered done callbacks, most recently added first."""
        try:
            yield
        finally:
            # Call done callbacks
            while self._done_callbacks:
                callback = self._done_callbacks.pop()
                if callback:
                    callback.run()

    def add_done_callback(self, callback: Call[Any]) -> None:
        """
        Schedule `callback` to run when the waiter is done waiting.

        Raises:
            RuntimeError: If the waiter is already done.
        """
        if self._done_event.is_set():
            raise RuntimeError("Cannot add done callbacks to done waiters.")
        else:
            self._done_callbacks.append(callback)

    def wait(self) -> Call[T]:
        """
        Block this thread, running submitted callbacks, until the call finishes.

        Returns:
            The waited-on call.
        """
        try:
            # Stop watching for work once the future is done
            self._call.future.add_done_callback(
                lambda _: self._queue.put_nowait(None)
            )
            self._call.future.add_done_callback(lambda _: self._done_event.set())

            with self._handle_done_callbacks():
                self._handle_waiting_callbacks()

            # Wait for the future to be done
            self._done_event.wait()
        finally:
            # Deregister even when waiting is interrupted by an exception;
            # otherwise this waiter would stay registered for the thread with a
            # queue that nothing reads any more.
            _WAITERS_BY_THREAD[self._owner_thread].remove(self)

        return self._call

154 

155 

class AsyncWaiter(Waiter[T]):
    # Implementation of `Waiter` for use in asynchronous contexts.
    # Callbacks are delivered to an `asyncio.Queue` owned by the event loop that
    # runs `wait`; submissions made before `wait` starts are buffered and
    # re-queued once the loop and queue exist.

    def __init__(self, call: Call[T]) -> None:
        super().__init__(call=call)

        # Delay instantiating loop and queue as there may not be a loop present yet
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._queue: Optional[asyncio.Queue[Optional[Call[T]]]] = None
        # Calls submitted before `wait` created the queue, in submission order.
        self._early_submissions: list[Call[T]] = []
        self._done_callbacks: list[Call[Any]] = []
        self._done_event = Event()
        # True once `_handle_waiting_callbacks` has stopped reading the queue.
        self._done_waiting = False

    def submit(self, call: Call[T]) -> Call[T]:
        """
        Submit a callback to execute while waiting.

        Returns:
            The submitted call.

        Raises:
            RuntimeError: If the waited-on call is already done.
        """
        if self.call_is_done():
            raise RuntimeError(f"The call {self._call} is already done.")

        call.set_runner(self)

        if self._queue is None:
            # The loop is not available yet: buffer the call. It is re-queued in
            # submission order once `wait` creates the queue.
            self._early_submissions.append(call)
            return call

        # We must put items in the queue from the event loop that owns it
        if TYPE_CHECKING:
            assert self._loop is not None
        call_soon_in_loop(self._loop, self._queue.put_nowait, call)
        return call

    def _resubmit_early_submissions(self) -> None:
        """Move buffered early submissions onto the now-available queue."""
        if TYPE_CHECKING:
            assert self._queue is not None
            assert self._loop is not None
        for call in self._early_submissions:
            # We must put items in the queue from the event loop that owns it
            call_soon_in_loop(self._loop, self._queue.put_nowait, call)
        self._early_submissions = []

    async def _handle_waiting_callbacks(self) -> None:
        """
        Run queued callbacks until a `None` sentinel arrives, then await every
        awaitable they returned.
        """
        logger.debug("Waiter %r watching for callbacks", self)
        tasks: list[Awaitable[None]] = []

        if TYPE_CHECKING:
            assert self._queue is not None

        try:
            while True:
                callback = await self._queue.get()
                if callback is None:
                    break

                # Ensure that callbacks are cancelled if the parent call is cancelled so
                # waiting never runs longer than the call
                self._call.future.add_cancel_callback(callback.future.cancel)
                retval = callback.run()
                if inspect.isawaitable(retval):
                    tasks.append(retval)

                del callback

            # Tasks are collected and awaited as a group; if each task was awaited in
            # the above loop, async work would not be executed concurrently
            await asyncio.gather(*tasks)
        finally:
            self._done_waiting = True

    @contextlib.asynccontextmanager
    async def _handle_done_callbacks(self) -> AsyncGenerator[None, Any]:
        """On exit, run registered done callbacks, most recently added first."""
        try:
            yield
        finally:
            # Call done callbacks
            while self._done_callbacks:
                callback = self._done_callbacks.pop()
                if callback:
                    # We shield against cancellation so we can run the callback
                    with anyio.CancelScope(shield=True):
                        await self._run_done_callback(callback)

    async def _run_done_callback(self, callback: Call[Any]) -> None:
        """Run `callback`, awaiting its result if it returned an awaitable."""
        retval = callback.run()
        # Same awaitable check as `_handle_waiting_callbacks`, rather than relying
        # on the truthiness of the returned value.
        if inspect.isawaitable(retval):
            await retval

    def add_done_callback(self, callback: Call[Any]) -> None:
        """
        Schedule `callback` to run when the waiter is done waiting.

        Raises:
            RuntimeError: If the waiter is already done.
        """
        if self._done_event.is_set():
            raise RuntimeError("Cannot add done callbacks to done waiters.")
        else:
            self._done_callbacks.append(callback)

    def _signal_stop_waiting(self) -> None:
        """Enqueue the `None` sentinel that ends `_handle_waiting_callbacks`."""
        # Only send a `None` to the queue if the waiter is still blocked reading from
        # the queue. Otherwise, it's possible that the event loop is stopped.
        if not self._done_waiting:
            assert self._loop is not None
            assert self._queue is not None
            call_soon_in_loop(self._loop, self._queue.put_nowait, None)

    async def wait(self) -> Call[T]:
        """
        Run submitted callbacks on the current event loop until the call finishes.

        Returns:
            The waited-on call.
        """
        try:
            # Assign the loop
            self._loop = asyncio.get_running_loop()
            self._queue = asyncio.Queue()
            self._resubmit_early_submissions()

            # Stop watching for work once the future is done
            self._call.future.add_done_callback(
                lambda _: self._signal_stop_waiting()
            )
            self._call.future.add_done_callback(lambda _: self._done_event.set())

            async with self._handle_done_callbacks():
                await self._handle_waiting_callbacks()

            # Wait for the future to be done
            await self._done_event.wait()
        finally:
            # Deregister even when waiting is cancelled or raises; otherwise this
            # waiter would stay registered for the thread with a queue that
            # nothing reads any more.
            _WAITERS_BY_THREAD[self._owner_thread].remove(self)

        return self._call