Coverage for /usr/local/lib/python3.12/site-packages/prefect/task_worker.py: 25%
268 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1from __future__ import annotations 1a
3import asyncio 1a
4import inspect 1a
5import os 1a
6import signal 1a
7import socket 1a
8import sys 1a
9from concurrent.futures import ThreadPoolExecutor 1a
10from contextlib import AsyncExitStack 1a
11from contextvars import copy_context 1a
12from typing import TYPE_CHECKING, Any, Optional 1a
13from uuid import UUID 1a
15import anyio 1a
16import anyio.abc 1a
17import uvicorn 1a
18from exceptiongroup import BaseExceptionGroup # novermin 1a
19from fastapi import FastAPI 1a
20from typing_extensions import ParamSpec, Self, TypeVar 1a
21from websockets.exceptions import InvalidStatus 1a
23import prefect.types._datetime 1a
24from prefect import Task 1a
25from prefect._internal.compatibility.blocks import call_explicitly_async_block_method 1a
26from prefect._internal.concurrency.api import create_call, from_sync 1a
27from prefect.cache_policies import DEFAULT, NO_CACHE 1a
28from prefect.client.orchestration import get_client 1a
29from prefect.client.schemas.objects import TaskRun 1a
30from prefect.client.subscriptions import Subscription 1a
31from prefect.logging.loggers import get_logger 1a
32from prefect.results import ( 1a
33 ResultRecord,
34 ResultRecordMetadata,
35 ResultStore,
36 get_or_create_default_task_scheduling_storage,
37)
38from prefect.settings import get_current_settings 1a
39from prefect.states import Pending 1a
40from prefect.task_engine import run_task_async, run_task_sync 1a
41from prefect.types import DateTime 1a
42from prefect.utilities.annotations import NotSet 1a
43from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible 1a
44from prefect.utilities.engine import emit_task_run_state_change_event 1a
45from prefect.utilities.processutils import ( 1a
46 _register_signal, # pyright: ignore[reportPrivateUsage]
47)
48from prefect.utilities.services import start_client_metrics_server 1a
49from prefect.utilities.timeout import timeout_async 1a
50from prefect.utilities.urls import url_for 1a
52if TYPE_CHECKING: 52 ↛ 53line 52 didn't jump to line 53 because the condition on line 52 was never true1a
53 import logging
55logger: "logging.Logger" = get_logger("task_worker") 1a
57P = ParamSpec("P") 1a
58R = TypeVar("R", infer_variance=True) 1a
61class StopTaskWorker(Exception): 1a
62 """Raised when the task worker is stopped."""
64 pass 1a
67def should_try_to_read_parameters(task: Task[P, R], task_run: TaskRun) -> bool: 1a
68 """Determines whether a task run should read parameters from the result store."""
69 if TYPE_CHECKING:
70 assert task_run.state is not None
71 new_enough_state_details = hasattr(
72 task_run.state.state_details, "task_parameters_id"
73 )
74 task_accepts_parameters = bool(inspect.signature(task.fn).parameters)
76 return new_enough_state_details and task_accepts_parameters
79class TaskWorker: 1a
80 """This class is responsible for serving tasks that may be executed in the background
81 by a task runner via the traditional engine machinery.
83 When `start()` is called, the task worker will open a websocket connection to a
84 server-side queue of scheduled task runs. When a scheduled task run is found, the
85 scheduled task run is submitted to the engine for execution with a minimal `EngineContext`
86 so that the task run can be governed by orchestration rules.
88 Args:
89 - tasks: A list of tasks to serve. These tasks will be submitted to the engine
90 when a scheduled task run is found.
91 - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
92 Pass `None` to remove the limit.
93 """
95 def __init__( 1a
96 self,
97 *tasks: Task[P, R],
98 limit: int | None = 10,
99 ):
100 self.tasks: list["Task[..., Any]"] = []
101 for t in tasks:
102 if not TYPE_CHECKING:
103 if not isinstance(t, Task):
104 continue
106 if t.cache_policy in [None, NO_CACHE, NotSet]:
107 self.tasks.append(
108 t.with_options(persist_result=True, cache_policy=DEFAULT)
109 )
110 else:
111 self.tasks.append(t.with_options(persist_result=True))
113 self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
115 self._started_at: Optional[DateTime] = None
116 self.stopping: bool = False
118 self._client = get_client()
119 self._exit_stack = AsyncExitStack()
121 try:
122 asyncio.get_running_loop()
123 except RuntimeError:
124 raise RuntimeError(
125 "TaskWorker must be initialized within an async context."
126 )
128 self._runs_task_group: Optional[anyio.abc.TaskGroup] = None
129 self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
130 self._limiter = anyio.CapacityLimiter(limit) if limit else None
132 self.in_flight_task_runs: dict[str, dict[UUID, DateTime]] = {
133 task_key: {} for task_key in self.task_keys
134 }
135 self.finished_task_runs: dict[str, int] = {
136 task_key: 0 for task_key in self.task_keys
137 }
139 @property 1a
140 def client_id(self) -> str: 1a
141 return f"{socket.gethostname()}-{os.getpid()}"
143 @property 1a
144 def started_at(self) -> Optional[DateTime]: 1a
145 return self._started_at
147 @property 1a
148 def started(self) -> bool: 1a
149 return self._started_at is not None
151 @property 1a
152 def limit(self) -> Optional[int]: 1a
153 return int(self._limiter.total_tokens) if self._limiter else None
155 @property 1a
156 def current_tasks(self) -> Optional[int]: 1a
157 return (
158 int(self._limiter.borrowed_tokens)
159 if self._limiter
160 else sum(len(runs) for runs in self.in_flight_task_runs.values())
161 )
163 @property 1a
164 def available_tasks(self) -> Optional[int]: 1a
165 return int(self._limiter.available_tokens) if self._limiter else None
167 def handle_sigterm(self, signum: int, frame: object) -> None: 1a
168 """
169 Shuts down the task worker when a SIGTERM is received.
170 """
171 logger.info("SIGTERM received, initiating graceful shutdown...")
172 from_sync.call_in_loop_thread(create_call(self.stop))
174 sys.exit(0)
176 @sync_compatible 1a
177 async def start(self, timeout: Optional[float] = None) -> None: 1a
178 """
179 Starts a task worker, which runs the tasks provided in the constructor.
181 Args:
182 timeout: If provided, the task worker will exit after the given number of
183 seconds. Defaults to None, meaning the task worker will run indefinitely.
184 """
185 _register_signal(signal.SIGTERM, self.handle_sigterm)
187 start_client_metrics_server()
189 async with asyncnullcontext() if self.started else self:
190 logger.info("Starting task worker...")
191 try:
192 with timeout_async(timeout):
193 await self._subscribe_to_task_scheduling()
194 except InvalidStatus as exc:
195 if exc.response.status_code == 403:
196 logger.error(
197 "403: Could not establish a connection to the `/task_runs/subscriptions/scheduled`"
198 f" endpoint found at:\n\n {get_current_settings().api.url}"
199 "\n\nPlease double-check the values of"
200 " `PREFECT_API_AUTH_STRING` and `PREFECT_SERVER_API_AUTH_STRING` if running a Prefect server "
201 "or `PREFECT_API_URL` and `PREFECT_API_KEY` environment variables if using Prefect Cloud."
202 )
203 else:
204 raise
206 @sync_compatible 1a
207 async def stop(self): 1a
208 """Stops the task worker's polling cycle."""
209 if not self.started:
210 raise RuntimeError(
211 "Task worker has not yet started. Please start the task worker by"
212 " calling .start()"
213 )
215 self._started_at = None
216 self.stopping = True
218 raise StopTaskWorker
220 async def _acquire_token(self, task_run_id: UUID) -> bool: 1a
221 try:
222 if self._limiter:
223 await self._limiter.acquire_on_behalf_of(task_run_id)
224 except RuntimeError:
225 logger.debug(f"Token already acquired for task run: {task_run_id!r}")
226 return False
228 return True
230 def _release_token(self, task_run_id: UUID) -> bool: 1a
231 try:
232 if self._limiter:
233 self._limiter.release_on_behalf_of(task_run_id)
234 except RuntimeError:
235 logger.debug(f"No token to release for task run: {task_run_id!r}")
236 return False
238 return True
240 async def _subscribe_to_task_scheduling(self): 1a
241 base_url = get_current_settings().api.url
242 if base_url is None:
243 raise ValueError(
244 "`PREFECT_API_URL` must be set to use the task worker. "
245 "Task workers are not compatible with the ephemeral API."
246 )
247 task_keys_repr = " | ".join(
248 task_key.split(".")[-1].split("-")[0] for task_key in sorted(self.task_keys)
249 )
250 logger.info(f"Subscribing to runs of task(s): {task_keys_repr}")
251 async for task_run in Subscription(
252 model=TaskRun,
253 path="/task_runs/subscriptions/scheduled",
254 keys=self.task_keys,
255 client_id=self.client_id,
256 base_url=base_url,
257 ):
258 logger.info(f"Received task run: {task_run.id} - {task_run.name}")
260 token_acquired = await self._acquire_token(task_run.id)
261 if token_acquired:
262 assert self._runs_task_group is not None, (
263 "Task group was not initialized"
264 )
265 self._runs_task_group.start_soon(
266 self._safe_submit_scheduled_task_run, task_run
267 )
269 async def _safe_submit_scheduled_task_run(self, task_run: TaskRun): 1a
270 self.in_flight_task_runs[task_run.task_key][task_run.id] = (
271 prefect.types._datetime.now("UTC")
272 )
273 try:
274 await self._submit_scheduled_task_run(task_run)
275 except BaseException as exc:
276 logger.exception(
277 f"Failed to submit task run {task_run.id!r}",
278 exc_info=exc,
279 )
280 finally:
281 self.in_flight_task_runs[task_run.task_key].pop(task_run.id, None)
282 self.finished_task_runs[task_run.task_key] += 1
283 self._release_token(task_run.id)
285 async def _submit_scheduled_task_run(self, task_run: TaskRun): 1a
286 if TYPE_CHECKING:
287 assert task_run.state is not None
288 logger.debug(
289 f"Found task run: {task_run.name!r} in state: {task_run.state.name!r}"
290 )
292 task = next((t for t in self.tasks if t.task_key == task_run.task_key), None)
294 if not task:
295 if get_current_settings().tasks.scheduling.delete_failed_submissions:
296 logger.warning(
297 f"Task {task_run.name!r} not found in task worker registry."
298 )
299 await self._client._client.delete(f"/task_runs/{task_run.id}") # type: ignore
301 return
303 # The ID of the parameters for this run are stored in the Scheduled state's
304 # state_details. If there is no parameters_id, then the task was created
305 # without parameters.
306 parameters = {}
307 wait_for = []
308 run_context = None
309 if should_try_to_read_parameters(task, task_run):
310 parameters_id = task_run.state.state_details.task_parameters_id
311 if parameters_id is None:
312 logger.warning(
313 f"Task run {task_run.id!r} has no parameters ID. Skipping parameter retrieval."
314 )
315 return
317 task.persist_result = True
318 store = await ResultStore(
319 result_storage=await get_or_create_default_task_scheduling_storage()
320 ).update_for_task(task)
321 try:
322 run_data: dict[str, Any] = await read_parameters(store, parameters_id)
323 parameters = run_data.get("parameters", {})
324 wait_for = run_data.get("wait_for", [])
325 run_context = run_data.get("context", None)
326 except Exception as exc:
327 logger.exception(
328 f"Failed to read parameters for task run {task_run.id!r}",
329 exc_info=exc,
330 )
331 if get_current_settings().tasks.scheduling.delete_failed_submissions:
332 logger.info(
333 f"Deleting task run {task_run.id!r} because it failed to submit"
334 )
335 await self._client._client.delete(f"/task_runs/{task_run.id}")
336 return
338 initial_state = task_run.state
339 new_state = Pending()
340 new_state.state_details.deferred = True
341 new_state.state_details.task_run_id = task_run.id
342 new_state.state_details.flow_run_id = task_run.flow_run_id
343 state = new_state
344 task_run.state = state
346 emit_task_run_state_change_event(
347 task_run=task_run,
348 initial_state=initial_state,
349 validated_state=state,
350 )
352 if task_run_url := url_for(task_run):
353 logger.info(
354 f"Submitting task run {task_run.name!r} to engine. View in the UI: {task_run_url}"
355 )
357 if task.isasync:
358 await run_task_async(
359 task=task,
360 task_run_id=task_run.id,
361 task_run=task_run,
362 parameters=parameters,
363 wait_for=wait_for,
364 return_type="state",
365 context=run_context,
366 )
367 else:
368 context = copy_context()
369 future = self._executor.submit(
370 context.run,
371 run_task_sync,
372 task=task,
373 task_run_id=task_run.id,
374 task_run=task_run,
375 parameters=parameters,
376 wait_for=wait_for,
377 return_type="state",
378 context=run_context,
379 )
380 await asyncio.wrap_future(future)
382 async def execute_task_run(self, task_run: TaskRun) -> None: 1a
383 """Execute a task run in the task worker."""
384 async with self if not self.started else asyncnullcontext():
385 token_acquired = await self._acquire_token(task_run.id)
386 if token_acquired:
387 await self._safe_submit_scheduled_task_run(task_run)
389 async def __aenter__(self) -> Self: 1a
390 logger.debug("Starting task worker...")
392 if self._client._closed: # pyright: ignore[reportPrivateUsage]
393 self._client = get_client()
394 self._runs_task_group = anyio.create_task_group()
396 await self._exit_stack.__aenter__()
397 await self._exit_stack.enter_async_context(self._client)
398 await self._exit_stack.enter_async_context(self._runs_task_group)
399 self._exit_stack.enter_context(self._executor)
401 self._started_at = prefect.types._datetime.now("UTC")
402 return self
404 async def __aexit__(self, *exc_info: Any) -> None: 1a
405 logger.debug("Stopping task worker...")
406 self._started_at = None
407 await self._exit_stack.__aexit__(*exc_info)
410def create_status_server(task_worker: TaskWorker) -> FastAPI: 1a
411 status_app = FastAPI()
413 @status_app.get("/status")
414 def status(): # pyright: ignore[reportUnusedFunction]
415 if TYPE_CHECKING:
416 assert task_worker.started_at is not None
417 return {
418 "client_id": task_worker.client_id,
419 "started_at": task_worker.started_at.isoformat(),
420 "stopping": task_worker.stopping,
421 "limit": task_worker.limit,
422 "current": task_worker.current_tasks,
423 "available": task_worker.available_tasks,
424 "tasks": sorted(task_worker.task_keys),
425 "finished": task_worker.finished_task_runs,
426 "in_flight": {
427 key: {str(run): start.isoformat() for run, start in tasks.items()}
428 for key, tasks in task_worker.in_flight_task_runs.items()
429 },
430 }
432 return status_app
435@sync_compatible 1a
436async def serve( 1a
437 *tasks: Task[P, R],
438 limit: Optional[int] = 10,
439 status_server_port: Optional[int] = None,
440 timeout: Optional[float] = None,
441):
442 """Serve the provided tasks so that their runs may be submitted to
443 and executed in the engine. Tasks do not need to be within a flow run context to be
444 submitted. You must `.submit` the same task object that you pass to `serve`.
446 Args:
447 - tasks: A list of tasks to serve. When a scheduled task run is found for a
448 given task, the task run will be submitted to the engine for execution.
449 - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
450 Pass `None` to remove the limit.
451 - status_server_port: An optional port on which to start an HTTP server
452 exposing status information about the task worker. If not provided, no
453 status server will run.
454 - timeout: If provided, the task worker will exit after the given number of
455 seconds. Defaults to None, meaning the task worker will run indefinitely.
457 Example:
458 ```python
459 from prefect import task
460 from prefect.task_worker import serve
462 @task(log_prints=True)
463 def say(message: str):
464 print(message)
466 @task(log_prints=True)
467 def yell(message: str):
468 print(message.upper())
470 # starts a long-lived process that listens for scheduled runs of these tasks
471 serve(say, yell)
472 ```
473 """
474 task_worker = TaskWorker(*tasks, limit=limit)
476 status_server_task = None
477 if status_server_port is not None:
478 server = uvicorn.Server(
479 uvicorn.Config(
480 app=create_status_server(task_worker),
481 host="127.0.0.1",
482 port=status_server_port,
483 access_log=False,
484 log_level="warning",
485 )
486 )
487 loop = asyncio.get_event_loop()
488 status_server_task = loop.create_task(server.serve())
490 try:
491 await task_worker.start(timeout=timeout)
493 except TimeoutError:
494 if timeout is not None:
495 logger.info(f"Task worker timed out after {timeout} seconds. Exiting...")
496 else:
497 raise
499 except BaseExceptionGroup as exc: # novermin
500 exceptions = exc.exceptions
501 n_exceptions = len(exceptions)
502 logger.error(
503 f"Task worker stopped with {n_exceptions} exception{'s' if n_exceptions != 1 else ''}:"
504 f"\n" + "\n".join(str(e) for e in exceptions)
505 )
507 except StopTaskWorker:
508 logger.info("Task worker stopped.")
510 except (asyncio.CancelledError, KeyboardInterrupt):
511 logger.info("Task worker interrupted, stopping...")
513 finally:
514 if status_server_task:
515 status_server_task.cancel()
516 try:
517 await status_server_task
518 except asyncio.CancelledError:
519 pass
522async def store_parameters( 1a
523 result_store: ResultStore, identifier: UUID, parameters: dict[str, Any]
524) -> None:
525 """Store parameters for a task run in the result store.
527 Args:
528 result_store: The result store to store the parameters in.
529 identifier: The identifier of the task run.
530 parameters: The parameters to store.
531 """
532 if result_store.result_storage is None:
533 raise ValueError(
534 "Result store is not configured - must have a result storage block to store parameters"
535 )
536 record = ResultRecord(
537 result=parameters,
538 metadata=ResultRecordMetadata(
539 serializer=result_store.serializer, storage_key=str(identifier)
540 ),
541 )
543 await call_explicitly_async_block_method(
544 result_store.result_storage,
545 "write_path",
546 (f"parameters/{identifier}",),
547 {"content": record.serialize()},
548 )
551async def read_parameters( 1a
552 result_store: ResultStore, identifier: UUID
553) -> dict[str, Any]:
554 """Read parameters for a task run from the result store.
556 Args:
557 result_store: The result store to read the parameters from.
558 identifier: The identifier of the task run.
560 Returns:
561 The parameters for the task run.
562 """
563 if result_store.result_storage is None:
564 raise ValueError(
565 "Result store is not configured - must have a result storage block to read parameters"
566 )
567 record: ResultRecord[Any] = ResultRecord[Any].deserialize(
568 await call_explicitly_async_block_method(
569 result_store.result_storage,
570 "read_path",
571 (f"parameters/{identifier}",),
572 {},
573 )
574 )
575 return record.result