Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/concurrency/services.py: 30%

258 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import abc 1a

4import asyncio 1a

5import concurrent.futures 1a

6import contextlib 1a

7import logging 1a

8import os 1a

9import queue 1a

10import threading 1a

11import weakref 1a

12from collections.abc import AsyncGenerator, Awaitable, Coroutine, Generator, Hashable 1a

13from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union, cast 1a

14 

15from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack 1a

16 

17from prefect._internal.concurrency import logger 1a

18from prefect._internal.concurrency.api import create_call, from_sync 1a

19from prefect._internal.concurrency.cancellation import get_deadline, get_timeout 1a

20from prefect._internal.concurrency.event_loop import get_running_loop 1a

21from prefect._internal.concurrency.threads import WorkerThread, get_global_loop 1a

22 

T = TypeVar("T")
Ts = TypeVarTuple("Ts")
R = TypeVar("R", infer_variance=True)

# Registry of every live service instance, used by the fork handler below to
# reset state in child processes. A WeakSet so that tracking a service here
# never keeps it alive past its last real reference.
_active_services: weakref.WeakSet[_QueueServiceBase[Any]] = weakref.WeakSet()

29 

30 

def _reset_services_after_fork():
    """
    Reset service state after fork() to prevent multiprocessing deadlocks on Linux.

    Registered via os.register_at_fork() and invoked in the child process
    immediately after a fork().
    """
    # Snapshot the weak set before iterating; each service discards its
    # inherited loop/thread/queue state, which is unusable in the child.
    for service in tuple(_active_services):
        service.reset_for_fork()

    # Drop the class-level instance cache and replace its lock, which may have
    # been held by a thread that does not exist in the child.
    _QueueServiceBase.reset_instances_for_fork()

42 

43 

# Install the fork handler where the platform supports it (POSIX systems).
if hasattr(os, "register_at_fork"):
    try:
        os.register_at_fork(after_in_child=_reset_services_after_fork)
    except RuntimeError as e:
        # Registration can fail in some contexts (e.g., when we are already
        # inside a forked child); treat that as best-effort and move on.
        logger.debug(
            "failed to register fork handler: %s (this may occur in child processes)",
            e,
        )

55 

56 

class _QueueServiceBase(abc.ABC, Generic[T]):
    """
    Base for singleton-per-key services that consume items from a queue on the
    global loop thread. Subclasses implement `send` and `_handle`.
    """

    # Cache of live instances keyed by hash((cls, *args)); see `instance`.
    _instances: dict[int, Self] = {}
    _instance_lock = threading.Lock()

    def __init__(self, *args: Hashable) -> None:
        self._queue: queue.Queue[Optional[T]] = queue.Queue()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._done_event: Optional[asyncio.Event] = None
        self._task: Optional[asyncio.Task[None]] = None
        self._stopped: bool = False
        self._started: bool = False
        self._key = hash((self.__class__, *args))
        self._lock = threading.Lock()
        self._queue_get_thread = WorkerThread(
            # TODO: This thread should not need to be a daemon but when it is not, it
            # can prevent the interpreter from exiting.
            daemon=True,
            name=f"{type(self).__name__}Thread",
        )
        self._logger = logging.getLogger(f"{type(self).__name__}")

        # Register with the module-level set so the fork handler can find us.
        _active_services.add(self)

    def reset_for_fork(self) -> None:
        """Reset instance state after fork() to prevent deadlocks in child process."""
        # Mark stopped so no new work is accepted; replace the queue and lock,
        # whose inherited state is unusable in the forked child.
        self._stopped = True
        self._started = False
        self._loop = None
        self._done_event = None
        self._task = None
        self._queue = queue.Queue()
        self._lock = threading.Lock()

    @classmethod
    def reset_instances_for_fork(cls) -> None:
        """Reset class-level state after fork() to prevent deadlocks in child process."""
        cls._instances.clear()
        cls._instance_lock = threading.Lock()

    def start(self) -> None:
        """Bind this service to the global loop and begin consuming items."""
        logger.debug("Starting service %r", self)
        loop_thread = get_global_loop()

        # Services may only run on the global loop thread's event loop.
        if not asyncio.get_running_loop() == getattr(loop_thread, "_loop"):
            raise RuntimeError("Services must run on the global loop thread.")

        self._loop = asyncio.get_running_loop()
        self._done_event = asyncio.Event()
        self._task = self._loop.create_task(self._run())
        self._queue_get_thread.start()
        self._started = True

        # Ensure that we wait for worker completion before loop thread shutdown
        loop_thread.add_shutdown_call(create_call(self.drain))

        # Stop at interpreter exit by default.
        # Handling items may require spawning a thread and in 3.9 new threads
        # cannot be spawned after the interpreter finalizes threads, which
        # happens _before_ the normal `atexit` hook runs, resulting in failure
        # to process items. This is particularly relevant for services which
        # use an httpx client. See https://github.com/python/cpython/issues/86813
        threading._register_atexit(self._at_exit)  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]

    def _at_exit(self) -> None:
        self.drain(at_exit=True)

    def _stop(self, at_exit: bool = False) -> None:
        """
        Stop running this instance.

        Does not wait for the instance to finish. See `drain`.
        """

        if self._stopped:
            return

        with self._lock:
            if not at_exit:  # The logger may not be available during interpreter exit
                logger.debug("Stopping service %r", self)

            # Stop sending work to this instance
            self._remove_instance()
            self._stopped = True

            # Allow asyncio task to be garbage-collected. Its context may contain
            # references to all Prefect Task calls made during a flow run, through
            # EngineContext. Issue #10338.
            self._task = None

            # A `None` sentinel on the queue tells the main loop to exit.
            self._queue.put_nowait(None)

    @abc.abstractmethod
    def send(self, item: Any) -> Any:
        """Enqueue an item for processing; defined by subclasses."""
        raise NotImplementedError

    async def _run(self) -> None:
        """Drive the main loop inside the lifespan; always signal completion."""
        try:
            async with self._lifespan():
                await self._main_loop()
        except BaseException:
            self._remove_instance()
            # The logging call yields to another thread, so the instance must be
            # removed first to prevent retrieval of a dead instance
            include_tb = logger.isEnabledFor(logging.DEBUG)
            logger.error(
                "Service %r failed with %s pending items.",
                type(self).__name__,
                self._queue.qsize(),
                exc_info=include_tb,
            )
        finally:
            self._remove_instance()

            # Shutdown the worker thread
            self._queue_get_thread.shutdown()

            self._stopped = True
            assert self._done_event is not None
            self._done_event.set()

    async def _main_loop(self) -> None:
        """Consume queue items one at a time until the `None` sentinel arrives."""
        last_log_time = 0
        log_interval = 4  # log every 4 seconds

        while True:
            # The blocking `queue.get` runs on the worker thread so this
            # coroutine never blocks the event loop.
            item: Optional[T] = await self._queue_get_thread.submit(
                create_call(self._queue.get)
            ).aresult()

            if self._stopped:
                # After a stop, periodically surface progress on the backlog.
                current_time = asyncio.get_event_loop().time()
                queue_size = self._queue.qsize()

                if current_time - last_log_time >= log_interval and queue_size > 0:
                    self._logger.warning(
                        f"Still processing items: {queue_size} items remaining..."
                    )
                    last_log_time = current_time

            if item is None:
                logger.debug("Exiting service %r", self)
                self._queue.task_done()
                break

            try:
                logger.debug("Service %r handling item %r", self, item)
                await self._handle(item)
            except Exception:
                include_tb = logger.isEnabledFor(logging.DEBUG)
                logger.error(
                    "Service %r failed to process item %r",
                    type(self).__name__,
                    item,
                    exc_info=include_tb,
                )
            finally:
                self._queue.task_done()

    @abc.abstractmethod
    async def _handle(self, item: Any) -> Any:
        """Process a single item; defined by subclasses."""
        raise NotImplementedError

    @contextlib.asynccontextmanager
    async def _lifespan(self) -> AsyncGenerator[None, Any]:
        """
        Perform any setup and teardown for the service.
        """
        yield

    def _drain(self, at_exit: bool = False) -> concurrent.futures.Future[bool]:
        """
        Internal implementation for `drain`. Returns a future for sync/async interfaces.
        """
        if not at_exit:  # The logger may not be available during interpreter exit
            logger.debug("Draining service %r", self)

        self._stop(at_exit=at_exit)

        assert self._done_event is not None
        if self._done_event.is_set():
            # Already finished: hand back a pre-resolved future.
            future: concurrent.futures.Future[bool] = concurrent.futures.Future()
            future.set_result(False)
            return future

        assert self._loop is not None
        task = cast(Coroutine[Any, Any, bool], self._done_event.wait())
        return asyncio.run_coroutine_threadsafe(task, self._loop)

    def drain(self, at_exit: bool = False) -> Union[bool, Awaitable[bool]]:
        """
        Stop this instance of the service and wait for remaining work to be completed.

        Returns an awaitable if called from an async context.
        """
        future = self._drain(at_exit=at_exit)
        if get_running_loop() is None:
            return future.result()
        return asyncio.wrap_future(future)

    @classmethod
    def drain_all(
        cls, timeout: Optional[float] = None, at_exit: bool = True
    ) -> Union[
        tuple[
            set[concurrent.futures.Future[bool]], set[concurrent.futures.Future[bool]]
        ],
        Coroutine[
            Any,
            Any,
            Optional[tuple[set[asyncio.Future[bool]], set[asyncio.Future[bool]]]],
        ],
    ]:
        """
        Stop all instances of the service and wait for all remaining work to be
        completed.

        Returns an awaitable if called from an async context.
        """
        futures: list[concurrent.futures.Future[bool]] = []
        with cls._instance_lock:
            instances = tuple(cls._instances.values())

        for instance in instances:
            futures.append(instance._drain(at_exit=at_exit))

        if get_running_loop() is not None:
            if futures:
                return asyncio.wait(
                    [asyncio.wrap_future(fut) for fut in futures], timeout=timeout
                )
            # `wait` errors if it receives an empty list but we need to return a
            # coroutine still
            return asyncio.sleep(0)
        else:
            return concurrent.futures.wait(futures, timeout=timeout)

    def wait_until_empty(self) -> None:
        """
        Wait until the queue is empty and all items have been processed.
        """
        self._queue.join()

    @classmethod
    def instance(cls, *args: Hashable) -> Self:
        """
        Get an instance of the service.

        If an instance already exists with the given arguments, it will be returned.
        """
        with cls._instance_lock:
            key = hash((cls, *args))
            if key not in cls._instances:
                cls._instances[key] = cls._new_instance(*args)

            return cls._instances[key]

    def _remove_instance(self):
        # Forget this instance so `instance()` will create a fresh one next time.
        self._instances.pop(self._key, None)

    @classmethod
    def _new_instance(cls, *args: Hashable) -> Self:
        """
        Create and start a new instance of the service.
        """
        instance = cls(*args)

        # If already on the global loop, just start it here to avoid deadlock
        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()

        # Otherwise, bind the service to the global loop
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance

336 

337 

class QueueService(_QueueServiceBase[T]):
    """Queue service whose items are handled one at a time with no reply."""

    def send(self, item: T) -> None:
        """
        Send an item to this instance of the service.
        """
        with self._lock:
            if self._stopped:
                raise RuntimeError("Cannot put items in a stopped service instance.")

            logger.debug("Service %r enqueuing item %r", self, item)
            self._queue.put_nowait(self._prepare_item(item))

    def _prepare_item(self, item: T) -> T:
        """
        Prepare an item for submission to the service. This is called before
        the item is sent to the service.

        The default implementation returns the item unchanged.
        """
        return item

    @abc.abstractmethod
    async def _handle(self, item: T) -> None:
        """
        Process an item sent to the service.
        """

365 

class FutureQueueService(
    _QueueServiceBase[tuple[Unpack[Ts], concurrent.futures.Future[R]]]
):
    """Queued service that provides a future that is signalled with the acquired result for each item

    If there was a failure acquiring, the future result is set to the exception.

    Type Parameters:
        Ts: the tuple of types that make up sent arguments
        R: the type returned for each item once acquired

    """

    async def _handle(
        self, item: tuple[Unpack[Ts], concurrent.futures.Future[R]]
    ) -> None:
        # Each queued item is the send arguments plus a trailing future to
        # resolve with the acquisition result.
        *send_args, future = item
        try:
            response = await self.acquire(*send_args)
        except Exception as exc:
            # If the request to the increment endpoint fails in a non-standard
            # way, we need to set the future's result so it'll be re-raised in
            # the context of the caller.
            future.set_exception(exc)
            raise exc
        else:
            future.set_result(response)

    @abc.abstractmethod
    async def acquire(self, *args: Unpack[Ts]) -> R:
        """Acquire the result for one set of sent arguments; defined by subclasses."""
        raise NotImplementedError

    def send(self, item: tuple[Unpack[Ts]]) -> concurrent.futures.Future[R]:
        """Enqueue `item` and return a future resolved once it is acquired."""
        with self._lock:
            if self._stopped:
                raise RuntimeError("Cannot put items in a stopped service instance.")

            logger.debug("Service %r enqueuing item %r", self, item)
            future: concurrent.futures.Future[R] = concurrent.futures.Future()
            self._queue.put_nowait((*self._prepare_item(item), future))

        return future

    def _prepare_item(self, item: tuple[Unpack[Ts]]) -> tuple[Unpack[Ts]]:
        """
        Prepare an item for submission to the service. This is called before
        the item is sent to the service.

        The default implementation returns the item unchanged.
        """
        return item

417 

418 

class BatchedQueueService(QueueService[T]):
    """
    A queue service that handles a batch of items instead of a single item at a time.

    Items will be processed when the batch reaches the configured `_max_batch_size`
    or after an interval of `_min_interval` seconds (if set).
    """

    _max_batch_size: int
    _min_interval: Optional[float] = None

    @property
    def min_interval(self) -> float | None:
        return self.__class__._min_interval

    @property
    def max_batch_size(self) -> int:
        return self.__class__._max_batch_size

    async def _main_loop(self):
        """Accumulate items into batches and dispatch each to `_handle_batch`."""
        done = False

        while not done:
            batch: list[T] = []
            batch_size = 0

            # Pull items from the queue until we reach the batch size
            deadline = get_deadline(self.min_interval)
            while batch_size < self.max_batch_size:
                try:
                    item = await self._queue_get_thread.submit(
                        create_call(self._queue.get, timeout=get_timeout(deadline))
                    ).aresult()

                    # `None` is the stop sentinel; flush what we have and exit.
                    if item is None:
                        done = True
                        break

                    batch.append(item)
                    batch_size += self._get_size(item)
                    logger.debug(
                        "Service %r added item %r to batch (size %s/%s)",
                        self,
                        item,
                        batch_size,
                        self.max_batch_size,
                    )
                except queue.Empty:
                    # Process the batch after `min_interval` even if it is smaller than
                    # the batch size
                    break

            if not batch:
                continue

            logger.debug(
                "Service %r processing batch of size %s",
                self,
                batch_size,
            )
            try:
                await self._handle_batch(batch)
            except Exception:
                include_tb = logger.isEnabledFor(logging.DEBUG)
                logger.error(
                    "Service %r failed to process batch of size %s",
                    self,
                    batch_size,
                    exc_info=include_tb,
                )

    @abc.abstractmethod
    async def _handle_batch(self, items: list[T]) -> None:
        """
        Process a batch of items sent to the service.
        """

    async def _handle(self, item: T) -> NoReturn:
        raise AssertionError(
            "`_handle` should never be called for batched queue services"
        )

    def _get_size(self, item: T) -> int:
        """
        Calculate the size of a single item.
        """
        # By default, batch size is just the number of items
        return 1

507 

508 

@contextlib.contextmanager
def drain_on_exit(service: QueueService[Any]) -> Generator[None, Any, None]:
    """Context manager that drains all service instances on exit."""
    yield
    service.drain_all(at_exit=True)

513 

514 

@contextlib.asynccontextmanager
async def drain_on_exit_async(service: QueueService[Any]) -> AsyncGenerator[None, Any]:
    """Async context manager that drains all service instances on exit."""
    yield
    drain_all = service.drain_all(at_exit=True)
    if TYPE_CHECKING:
        # In an async context `drain_all` returns an awaitable, not a tuple.
        assert not isinstance(drain_all, tuple)
    await drain_all