Coverage for /usr/local/lib/python3.12/site-packages/prefect/task_worker.py: 25%

268 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4import inspect 1a

5import os 1a

6import signal 1a

7import socket 1a

8import sys 1a

9from concurrent.futures import ThreadPoolExecutor 1a

10from contextlib import AsyncExitStack 1a

11from contextvars import copy_context 1a

12from typing import TYPE_CHECKING, Any, Optional 1a

13from uuid import UUID 1a

14 

15import anyio 1a

16import anyio.abc 1a

17import uvicorn 1a

18from exceptiongroup import BaseExceptionGroup # novermin 1a

19from fastapi import FastAPI 1a

20from typing_extensions import ParamSpec, Self, TypeVar 1a

21from websockets.exceptions import InvalidStatus 1a

22 

23import prefect.types._datetime 1a

24from prefect import Task 1a

25from prefect._internal.compatibility.blocks import call_explicitly_async_block_method 1a

26from prefect._internal.concurrency.api import create_call, from_sync 1a

27from prefect.cache_policies import DEFAULT, NO_CACHE 1a

28from prefect.client.orchestration import get_client 1a

29from prefect.client.schemas.objects import TaskRun 1a

30from prefect.client.subscriptions import Subscription 1a

31from prefect.logging.loggers import get_logger 1a

32from prefect.results import ( 1a

33 ResultRecord, 

34 ResultRecordMetadata, 

35 ResultStore, 

36 get_or_create_default_task_scheduling_storage, 

37) 

38from prefect.settings import get_current_settings 1a

39from prefect.states import Pending 1a

40from prefect.task_engine import run_task_async, run_task_sync 1a

41from prefect.types import DateTime 1a

42from prefect.utilities.annotations import NotSet 1a

43from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible 1a

44from prefect.utilities.engine import emit_task_run_state_change_event 1a

45from prefect.utilities.processutils import ( 1a

46 _register_signal, # pyright: ignore[reportPrivateUsage] 

47) 

48from prefect.utilities.services import start_client_metrics_server 1a

49from prefect.utilities.timeout import timeout_async 1a

50from prefect.utilities.urls import url_for 1a

51 

if TYPE_CHECKING:
    import logging

# Module-level logger shared by the worker class and the `serve` entrypoint.
logger: "logging.Logger" = get_logger("task_worker")

# Type variables used to preserve task signatures through the worker API.
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)

59 

60 

class StopTaskWorker(Exception):
    """Raised when the task worker is stopped."""
    # NOTE: no body needed — the docstring alone is a valid class body, so the
    # previously redundant `pass` statement has been removed.

65 

66 

def should_try_to_read_parameters(task: Task[P, R], task_run: TaskRun) -> bool:
    """Determines whether a task run should read parameters from the result store."""
    if TYPE_CHECKING:
        assert task_run.state is not None
    # Older servers may produce state details without a `task_parameters_id`
    # attribute at all; only newer payloads are worth a parameter lookup.
    details = task_run.state.state_details
    has_parameters_id_field = hasattr(details, "task_parameters_id")
    # There is nothing to fetch for a task whose function takes no arguments.
    accepts_arguments = bool(inspect.signature(task.fn).parameters)
    return has_parameters_id_field and accepts_arguments

77 

78 

class TaskWorker:
    """This class is responsible for serving tasks that may be executed in the background
    by a task runner via the traditional engine machinery.

    When `start()` is called, the task worker will open a websocket connection to a
    server-side queue of scheduled task runs. When a scheduled task run is found, the
    scheduled task run is submitted to the engine for execution with a minimal `EngineContext`
    so that the task run can be governed by orchestration rules.

    Args:
        - tasks: A list of tasks to serve. These tasks will be submitted to the engine
            when a scheduled task run is found.
        - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
            Pass `None` to remove the limit.
    """

    def __init__(
        self,
        *tasks: Task[P, R],
        limit: int | None = 10,
    ):
        self.tasks: list["Task[..., Any]"] = []
        for t in tasks:
            # At runtime, silently skip anything that is not actually a Task.
            if not TYPE_CHECKING:
                if not isinstance(t, Task):
                    continue

            # Served tasks must persist results; tasks without an explicit cache
            # policy are given the default one alongside result persistence.
            if t.cache_policy in [None, NO_CACHE, NotSet]:
                self.tasks.append(
                    t.with_options(persist_result=True, cache_policy=DEFAULT)
                )
            else:
                self.tasks.append(t.with_options(persist_result=True))

        self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task))  # pyright: ignore[reportUnnecessaryIsInstance]

        self._started_at: Optional[DateTime] = None
        self.stopping: bool = False

        self._client = get_client()
        self._exit_stack = AsyncExitStack()

        # Fail fast when constructed outside an event loop: the worker's async
        # primitives (task group, capacity limiter) require a running loop.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            raise RuntimeError(
                "TaskWorker must be initialized within an async context."
            )

        self._runs_task_group: Optional[anyio.abc.TaskGroup] = None
        self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
        self._limiter = anyio.CapacityLimiter(limit) if limit else None

        # task_key -> {task_run_id: time the run was received}
        self.in_flight_task_runs: dict[str, dict[UUID, DateTime]] = {
            task_key: {} for task_key in self.task_keys
        }
        # task_key -> count of runs whose submission has finished (success or not)
        self.finished_task_runs: dict[str, int] = {
            task_key: 0 for task_key in self.task_keys
        }

    @property
    def client_id(self) -> str:
        """A stable identifier for this worker process (hostname + pid)."""
        return f"{socket.gethostname()}-{os.getpid()}"

    @property
    def started_at(self) -> Optional[DateTime]:
        """When the worker was started, or None if it is not running."""
        return self._started_at

    @property
    def started(self) -> bool:
        """Whether the worker is currently started."""
        return self._started_at is not None

    @property
    def limit(self) -> Optional[int]:
        """The configured concurrency limit, or None when unlimited."""
        return int(self._limiter.total_tokens) if self._limiter else None

    @property
    def current_tasks(self) -> Optional[int]:
        """The number of task runs currently being worked on."""
        return (
            int(self._limiter.borrowed_tokens)
            if self._limiter
            # With no limiter, fall back to counting the in-flight bookkeeping.
            else sum(len(runs) for runs in self.in_flight_task_runs.values())
        )

    @property
    def available_tasks(self) -> Optional[int]:
        """How many more task runs may start now, or None when unlimited."""
        return int(self._limiter.available_tokens) if self._limiter else None

    def handle_sigterm(self, signum: int, frame: object) -> None:
        """
        Shuts down the task worker when a SIGTERM is received.
        """
        logger.info("SIGTERM received, initiating graceful shutdown...")
        # `stop` is async; schedule it on the loop thread from this sync handler.
        from_sync.call_in_loop_thread(create_call(self.stop))

        sys.exit(0)

    @sync_compatible
    async def start(self, timeout: Optional[float] = None) -> None:
        """
        Starts a task worker, which runs the tasks provided in the constructor.

        Args:
            timeout: If provided, the task worker will exit after the given number of
                seconds. Defaults to None, meaning the task worker will run indefinitely.
        """
        _register_signal(signal.SIGTERM, self.handle_sigterm)

        start_client_metrics_server()

        # Enter our own context manager only if we are not already started.
        async with asyncnullcontext() if self.started else self:
            logger.info("Starting task worker...")
            try:
                with timeout_async(timeout):
                    await self._subscribe_to_task_scheduling()
            except InvalidStatus as exc:
                # A 403 on the websocket handshake almost always means bad
                # credentials/URL, so surface a targeted hint instead of a traceback.
                if exc.response.status_code == 403:
                    logger.error(
                        "403: Could not establish a connection to the `/task_runs/subscriptions/scheduled`"
                        f" endpoint found at:\n\n {get_current_settings().api.url}"
                        "\n\nPlease double-check the values of"
                        " `PREFECT_API_AUTH_STRING` and `PREFECT_SERVER_API_AUTH_STRING` if running a Prefect server "
                        "or `PREFECT_API_URL` and `PREFECT_API_KEY` environment variables if using Prefect Cloud."
                    )
                else:
                    raise

    @sync_compatible
    async def stop(self):
        """Stops the task worker's polling cycle."""
        if not self.started:
            raise RuntimeError(
                "Task worker has not yet started. Please start the task worker by"
                " calling .start()"
            )

        self._started_at = None
        self.stopping = True

        # Raising unwinds `start()`'s subscription loop; `serve()` catches this.
        raise StopTaskWorker

    async def _acquire_token(self, task_run_id: UUID) -> bool:
        """Reserve a concurrency slot for the run; False if one is already held."""
        try:
            if self._limiter:
                await self._limiter.acquire_on_behalf_of(task_run_id)
        except RuntimeError:
            # The limiter raises RuntimeError when this borrower already holds
            # a token (e.g. the same run was received twice).
            logger.debug(f"Token already acquired for task run: {task_run_id!r}")
            return False

        return True

    def _release_token(self, task_run_id: UUID) -> bool:
        """Release the run's concurrency slot; False if none was held."""
        try:
            if self._limiter:
                self._limiter.release_on_behalf_of(task_run_id)
        except RuntimeError:
            logger.debug(f"No token to release for task run: {task_run_id!r}")
            return False

        return True

    async def _subscribe_to_task_scheduling(self):
        """Consume the server's scheduled-run subscription and dispatch each run."""
        base_url = get_current_settings().api.url
        if base_url is None:
            raise ValueError(
                "`PREFECT_API_URL` must be set to use the task worker. "
                "Task workers are not compatible with the ephemeral API."
            )
        # Short display names for logging: last dotted segment, trimmed at "-".
        task_keys_repr = " | ".join(
            task_key.split(".")[-1].split("-")[0] for task_key in sorted(self.task_keys)
        )
        logger.info(f"Subscribing to runs of task(s): {task_keys_repr}")
        async for task_run in Subscription(
            model=TaskRun,
            path="/task_runs/subscriptions/scheduled",
            keys=self.task_keys,
            client_id=self.client_id,
            base_url=base_url,
        ):
            logger.info(f"Received task run: {task_run.id} - {task_run.name}")

            # Only dispatch if a concurrency slot is available for this run.
            token_acquired = await self._acquire_token(task_run.id)
            if token_acquired:
                assert self._runs_task_group is not None, (
                    "Task group was not initialized"
                )
                self._runs_task_group.start_soon(
                    self._safe_submit_scheduled_task_run, task_run
                )

    async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
        """Submit a run, logging any failure; always release bookkeeping and the token."""
        self.in_flight_task_runs[task_run.task_key][task_run.id] = (
            prefect.types._datetime.now("UTC")
        )
        try:
            await self._submit_scheduled_task_run(task_run)
        except BaseException as exc:
            # Never let one bad run take down the worker loop.
            logger.exception(
                f"Failed to submit task run {task_run.id!r}",
                exc_info=exc,
            )
        finally:
            self.in_flight_task_runs[task_run.task_key].pop(task_run.id, None)
            self.finished_task_runs[task_run.task_key] += 1
            self._release_token(task_run.id)

    async def _submit_scheduled_task_run(self, task_run: TaskRun):
        """Resolve the task, read its stored parameters, and run it to completion."""
        if TYPE_CHECKING:
            assert task_run.state is not None
        logger.debug(
            f"Found task run: {task_run.name!r} in state: {task_run.state.name!r}"
        )

        task = next((t for t in self.tasks if t.task_key == task_run.task_key), None)

        if not task:
            # Run belongs to a task this worker does not serve; optionally clean it up.
            if get_current_settings().tasks.scheduling.delete_failed_submissions:
                logger.warning(
                    f"Task {task_run.name!r} not found in task worker registry."
                )
                await self._client._client.delete(f"/task_runs/{task_run.id}")  # type: ignore

            return

        # The ID of the parameters for this run are stored in the Scheduled state's
        # state_details. If there is no parameters_id, then the task was created
        # without parameters.
        parameters = {}
        wait_for = []
        run_context = None
        if should_try_to_read_parameters(task, task_run):
            parameters_id = task_run.state.state_details.task_parameters_id
            if parameters_id is None:
                logger.warning(
                    f"Task run {task_run.id!r} has no parameters ID. Skipping parameter retrieval."
                )
                return

            task.persist_result = True
            store = await ResultStore(
                result_storage=await get_or_create_default_task_scheduling_storage()
            ).update_for_task(task)
            try:
                run_data: dict[str, Any] = await read_parameters(store, parameters_id)
                parameters = run_data.get("parameters", {})
                wait_for = run_data.get("wait_for", [])
                run_context = run_data.get("context", None)
            except Exception as exc:
                logger.exception(
                    f"Failed to read parameters for task run {task_run.id!r}",
                    exc_info=exc,
                )
                if get_current_settings().tasks.scheduling.delete_failed_submissions:
                    logger.info(
                        f"Deleting task run {task_run.id!r} because it failed to submit"
                    )
                    await self._client._client.delete(f"/task_runs/{task_run.id}")
                return

        # Move the run into a deferred Pending state and emit the state-change event.
        initial_state = task_run.state
        new_state = Pending()
        new_state.state_details.deferred = True
        new_state.state_details.task_run_id = task_run.id
        new_state.state_details.flow_run_id = task_run.flow_run_id
        state = new_state
        task_run.state = state

        emit_task_run_state_change_event(
            task_run=task_run,
            initial_state=initial_state,
            validated_state=state,
        )

        if task_run_url := url_for(task_run):
            logger.info(
                f"Submitting task run {task_run.name!r} to engine. View in the UI: {task_run_url}"
            )

        if task.isasync:
            await run_task_async(
                task=task,
                task_run_id=task_run.id,
                task_run=task_run,
                parameters=parameters,
                wait_for=wait_for,
                return_type="state",
                context=run_context,
            )
        else:
            # Sync tasks run on the thread pool; copy the current context so
            # context variables propagate to the worker thread.
            context = copy_context()
            future = self._executor.submit(
                context.run,
                run_task_sync,
                task=task,
                task_run_id=task_run.id,
                task_run=task_run,
                parameters=parameters,
                wait_for=wait_for,
                return_type="state",
                context=run_context,
            )
            await asyncio.wrap_future(future)

    async def execute_task_run(self, task_run: TaskRun) -> None:
        """Execute a task run in the task worker."""
        # Enter our own context manager only if we are not already started.
        async with self if not self.started else asyncnullcontext():
            token_acquired = await self._acquire_token(task_run.id)
            if token_acquired:
                await self._safe_submit_scheduled_task_run(task_run)

    async def __aenter__(self) -> Self:
        """Set up the client, task group, and thread pool on the exit stack."""
        logger.debug("Starting task worker...")

        # Re-create the client if a previous run closed it.
        if self._client._closed:  # pyright: ignore[reportPrivateUsage]
            self._client = get_client()
        self._runs_task_group = anyio.create_task_group()

        await self._exit_stack.__aenter__()
        await self._exit_stack.enter_async_context(self._client)
        await self._exit_stack.enter_async_context(self._runs_task_group)
        self._exit_stack.enter_context(self._executor)

        self._started_at = prefect.types._datetime.now("UTC")
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """Tear down everything registered on the exit stack."""
        logger.debug("Stopping task worker...")
        self._started_at = None
        await self._exit_stack.__aexit__(*exc_info)

408 

409 

def create_status_server(task_worker: TaskWorker) -> FastAPI:
    """Build a small FastAPI app exposing the worker's state at `GET /status`."""
    app = FastAPI()

    @app.get("/status")
    def status():  # pyright: ignore[reportUnusedFunction]
        if TYPE_CHECKING:
            assert task_worker.started_at is not None
        # Stringify run IDs and timestamps so the payload is JSON-friendly.
        in_flight = {
            key: {str(run): start.isoformat() for run, start in tasks.items()}
            for key, tasks in task_worker.in_flight_task_runs.items()
        }
        return {
            "client_id": task_worker.client_id,
            "started_at": task_worker.started_at.isoformat(),
            "stopping": task_worker.stopping,
            "limit": task_worker.limit,
            "current": task_worker.current_tasks,
            "available": task_worker.available_tasks,
            "tasks": sorted(task_worker.task_keys),
            "finished": task_worker.finished_task_runs,
            "in_flight": in_flight,
        }

    return app

433 

434 

@sync_compatible
async def serve(
    *tasks: Task[P, R],
    limit: Optional[int] = 10,
    status_server_port: Optional[int] = None,
    timeout: Optional[float] = None,
):
    """Serve the provided tasks so that their runs may be submitted to
    and executed in the engine. Tasks do not need to be within a flow run context to be
    submitted. You must `.submit` the same task object that you pass to `serve`.

    Args:
        - tasks: A list of tasks to serve. When a scheduled task run is found for a
            given task, the task run will be submitted to the engine for execution.
        - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
            Pass `None` to remove the limit.
        - status_server_port: An optional port on which to start an HTTP server
            exposing status information about the task worker. If not provided, no
            status server will run.
        - timeout: If provided, the task worker will exit after the given number of
            seconds. Defaults to None, meaning the task worker will run indefinitely.

    Example:
        ```python
        from prefect import task
        from prefect.task_worker import serve

        @task(log_prints=True)
        def say(message: str):
            print(message)

        @task(log_prints=True)
        def yell(message: str):
            print(message.upper())

        # starts a long-lived process that listens for scheduled runs of these tasks
        serve(say, yell)
        ```
    """
    task_worker = TaskWorker(*tasks, limit=limit)

    status_server_task = None
    if status_server_port is not None:
        server = uvicorn.Server(
            uvicorn.Config(
                app=create_status_server(task_worker),
                host="127.0.0.1",
                port=status_server_port,
                access_log=False,
                log_level="warning",
            )
        )
        # We are inside a coroutine, so a loop is guaranteed to be running;
        # `asyncio.get_event_loop()` is deprecated in this context, use the
        # running loop directly.
        loop = asyncio.get_running_loop()
        status_server_task = loop.create_task(server.serve())

    try:
        await task_worker.start(timeout=timeout)

    except TimeoutError:
        # Only swallow the timeout when the caller actually requested one.
        if timeout is not None:
            logger.info(f"Task worker timed out after {timeout} seconds. Exiting...")
        else:
            raise

    except BaseExceptionGroup as exc:  # novermin
        # The worker's task group may surface several failures at once.
        exceptions = exc.exceptions
        n_exceptions = len(exceptions)
        logger.error(
            f"Task worker stopped with {n_exceptions} exception{'s' if n_exceptions != 1 else ''}:"
            f"\n" + "\n".join(str(e) for e in exceptions)
        )

    except StopTaskWorker:
        logger.info("Task worker stopped.")

    except (asyncio.CancelledError, KeyboardInterrupt):
        logger.info("Task worker interrupted, stopping...")

    finally:
        # Always shut the status server down, even on error paths.
        if status_server_task:
            status_server_task.cancel()
            try:
                await status_server_task
            except asyncio.CancelledError:
                pass

520 

521 

async def store_parameters(
    result_store: ResultStore, identifier: UUID, parameters: dict[str, Any]
) -> None:
    """Store parameters for a task run in the result store.

    Args:
        result_store: The result store to store the parameters in.
        identifier: The identifier of the task run.
        parameters: The parameters to store.

    Raises:
        ValueError: If the result store has no result storage block configured.
    """
    storage = result_store.result_storage
    if storage is None:
        raise ValueError(
            "Result store is not configured - must have a result storage block to store parameters"
        )

    # Wrap the parameters in a result record keyed by the run identifier.
    metadata = ResultRecordMetadata(
        serializer=result_store.serializer, storage_key=str(identifier)
    )
    record = ResultRecord(result=parameters, metadata=metadata)

    await call_explicitly_async_block_method(
        storage,
        "write_path",
        (f"parameters/{identifier}",),
        {"content": record.serialize()},
    )

549 

550 

async def read_parameters(
    result_store: ResultStore, identifier: UUID
) -> dict[str, Any]:
    """Read parameters for a task run from the result store.

    Args:
        result_store: The result store to read the parameters from.
        identifier: The identifier of the task run.

    Returns:
        The parameters for the task run.

    Raises:
        ValueError: If the result store has no result storage block configured.
    """
    storage = result_store.result_storage
    if storage is None:
        raise ValueError(
            "Result store is not configured - must have a result storage block to read parameters"
        )

    # Fetch the serialized record from the same well-known path used by
    # `store_parameters`, then deserialize it back into a result record.
    payload = await call_explicitly_async_block_method(
        storage,
        "read_path",
        (f"parameters/{identifier}",),
        {},
    )
    record: ResultRecord[Any] = ResultRecord[Any].deserialize(payload)
    return record.result