Coverage for /usr/local/lib/python3.12/site-packages/prefect/task_engine.py: 14%
795 statements
coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1from __future__ import annotations 1a
3import asyncio 1a
4import datetime 1a
5import inspect 1a
6import logging 1a
7import threading 1a
8import time 1a
9from asyncio import CancelledError 1a
10from contextlib import ExitStack, asynccontextmanager, contextmanager, nullcontext 1a
11from dataclasses import dataclass, field 1a
12from datetime import timedelta 1a
13from functools import partial 1a
14from textwrap import dedent 1a
15from typing import ( 1a
16 TYPE_CHECKING,
17 Any,
18 AsyncGenerator,
19 Callable,
20 Coroutine,
21 Generator,
22 Generic,
23 Literal,
24 Optional,
25 Sequence,
26 Type,
27 TypeVar,
28 Union,
29 overload,
30)
31from uuid import UUID 1a
33import anyio 1a
34from opentelemetry import trace 1a
35from typing_extensions import ParamSpec, Self 1a
37import prefect.types._datetime 1a
38from prefect._internal.compatibility import deprecated 1a
39from prefect.cache_policies import CachePolicy 1a
40from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client 1a
41from prefect.client.schemas import TaskRun 1a
42from prefect.client.schemas.objects import ConcurrencyLeaseHolder, RunInput, State 1a
43from prefect.concurrency._asyncio import concurrency as _aconcurrency 1a
44from prefect.concurrency._sync import concurrency as _concurrency 1a
45from prefect.concurrency.context import ConcurrencyContext 1a
46from prefect.context import ( 1a
47 AssetContext,
48 AsyncClientContext,
49 FlowRunContext,
50 SyncClientContext,
51 TaskRunContext,
52 hydrated_context,
53)
54from prefect.events.schemas.events import Event as PrefectEvent 1a
55from prefect.exceptions import ( 1a
56 Abort,
57 Pause,
58 PrefectException,
59 TerminationSignal,
60 UpstreamTaskError,
61)
62from prefect.logging.loggers import get_logger, patch_print, task_run_logger 1a
63from prefect.results import ( 1a
64 ResultRecord,
65 _format_user_supplied_storage_key, # type: ignore[reportPrivateUsage]
66 get_result_store,
67 should_persist_result,
68)
69from prefect.settings import ( 1a
70 PREFECT_DEBUG_MODE,
71 PREFECT_TASKS_REFRESH_CACHE,
72)
73from prefect.settings.context import get_current_settings 1a
74from prefect.states import ( 1a
75 AwaitingRetry,
76 Completed,
77 Failed,
78 Pending,
79 Retrying,
80 Running,
81 exception_to_crashed_state,
82 exception_to_failed_state,
83 return_value_to_state,
84)
85from prefect.telemetry.run_telemetry import RunTelemetry 1a
86from prefect.transactions import ( 1a
87 AsyncTransaction,
88 IsolationLevel,
89 Transaction,
90 atransaction,
91 transaction,
92)
93from prefect.utilities._engine import get_hook_name 1a
94from prefect.utilities.annotations import NotSet 1a
95from prefect.utilities.asyncutils import run_coro_as_sync 1a
96from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs 1a
97from prefect.utilities.collections import visit_collection 1a
98from prefect.utilities.engine import ( 1a
99 emit_task_run_state_change_event,
100 link_state_to_task_run_result,
101 resolve_to_final_result,
102)
103from prefect.utilities.math import clamped_poisson_interval 1a
104from prefect.utilities.timeout import timeout, timeout_async 1a
106if TYPE_CHECKING:  # coverage: 106 ↛ 107, line 106 didn't jump to line 107 because the condition on line 106 was never true  1a
107 from prefect.tasks import OneOrManyFutureOrResult, Task
109P = ParamSpec("P") 1a
110R = TypeVar("R") 1a
112BACKOFF_MAX = 10 1a
115class TaskRunTimeoutError(TimeoutError): 1a
116 """Raised when a task run exceeds its timeout."""
119@dataclass 1a
120class BaseTaskRunEngine(Generic[P, R]): 1a
121 task: Union["Task[P, R]", "Task[P, Coroutine[Any, Any, R]]"] 1a
122 logger: logging.Logger = field(default_factory=lambda: get_logger("engine")) 1a
123 parameters: Optional[dict[str, Any]] = None 1a
124 task_run: Optional[TaskRun] = None 1a
125 retries: int = 0 1a
126 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None 1a
127 context: Optional[dict[str, Any]] = None 1a
128 # holds the return value from the user code
129 _return_value: Union[R, Type[NotSet]] = NotSet 1a
130 # holds the exception raised by the user code, if any
131 _raised: Union[Exception, BaseException, Type[NotSet]] = NotSet 1a
132 _initial_run_context: Optional[TaskRunContext] = None 1a
133 _is_started: bool = False 1a
134 _task_name_set: bool = False 1a
135 _last_event: Optional[PrefectEvent] = None 1a
136 _telemetry: RunTelemetry = field(default_factory=RunTelemetry) 1a
138 def __post_init__(self) -> None: 1a
139 if self.parameters is None:
140 self.parameters = {}
142 @property 1a
143 def state(self) -> State: 1a
144 if not self.task_run or not self.task_run.state:
145 raise ValueError("Task run is not set")
146 return self.task_run.state
148 def is_cancelled(self) -> bool: 1a
149 if (
150 self.context
151 and "cancel_event" in self.context
152 and isinstance(self.context["cancel_event"], threading.Event)
153 ):
154 return self.context["cancel_event"].is_set()
155 return False
157 def compute_transaction_key(self) -> Optional[str]: 1a
158 key: Optional[str] = None
159 if self.task.cache_policy and isinstance(self.task.cache_policy, CachePolicy):
160 flow_run_context = FlowRunContext.get()
161 task_run_context = TaskRunContext.get()
163 if flow_run_context:
164 parameters = flow_run_context.parameters
165 else:
166 parameters = None
168 try:
169 if not task_run_context:
170 raise ValueError("Task run context is not set")
171 key = self.task.cache_policy.compute_key(
172 task_ctx=task_run_context,
173 inputs=self.parameters or {},
174 flow_parameters=parameters or {},
175 )
176 except Exception:
177 self.logger.exception(
178 "Error encountered when computing cache key - result will not be persisted.",
179 )
180 key = None
181 elif self.task.result_storage_key is not None:
182 key = _format_user_supplied_storage_key(self.task.result_storage_key)
183 return key
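    # Illustrative sketch (not part of this module): the transaction key above is
    # derived from the task's `cache_policy` (or its `result_storage_key`); assuming
    # the built-in INPUTS policy from `prefect.cache_policies`:
    #
    #     from prefect import task
    #     from prefect.cache_policies import INPUTS
    #
    #     @task(cache_policy=INPUTS)
    #     def transform(x):
    #         return x * 2  # repeat calls with the same `x` reuse the cached result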
185 def _resolve_parameters(self): 1a
186 if not self.parameters:
187 return None
189 resolved_parameters = {}
190 for parameter, value in self.parameters.items():
191 try:
192 resolved_parameters[parameter] = visit_collection(
193 value,
194 visit_fn=resolve_to_final_result,
195 return_data=True,
196 max_depth=-1,
197 remove_annotations=True,
198 context={"parameter_name": parameter},
199 )
200 except UpstreamTaskError:
201 raise
202 except Exception as exc:
203 raise PrefectException(
204 f"Failed to resolve inputs in parameter {parameter!r}. If your"
205 " parameter type is not supported, consider using the `quote`"
206 " annotation to skip resolution of inputs."
207 ) from exc
209 self.parameters = resolved_parameters
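    # Illustrative sketch (not part of this module): the `quote` annotation mentioned
    # in the error above wraps a parameter so its inputs are not resolved; assuming
    # `prefect.utilities.annotations.quote` and a hypothetical task `my_task`:
    #
    #     from prefect.utilities.annotations import quote
    #
    #     my_task.submit(quote(huge_nested_value))  # passed through without resolution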
211 def _set_custom_task_run_name(self): 1a
212 from prefect.utilities._engine import resolve_custom_task_run_name
214 # update the task run name if necessary
215 if not self._task_name_set and self.task.task_run_name:
216 task_run_name = resolve_custom_task_run_name(
217 task=self.task, parameters=self.parameters or {}
218 )
220 self.logger.extra["task_run_name"] = task_run_name
221 self.logger.debug(
222 f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
223 )
224 self.task_run.name = task_run_name
225 self._task_name_set = True
226 self._telemetry.update_run_name(name=task_run_name)
228 def _wait_for_dependencies(self): 1a
229 if not self.wait_for:
230 return
232 visit_collection(
233 self.wait_for,
234 visit_fn=resolve_to_final_result,
235 return_data=False,
236 max_depth=-1,
237 remove_annotations=True,
238 context={"current_task_run": self.task_run, "current_task": self.task},
239 )
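    # Illustrative sketch (not part of this module): `wait_for` holds upstream futures
    # that are resolved above before the task starts, e.g. via the documented
    # `submit(..., wait_for=[...])` pattern (`extract` and `load` are hypothetical tasks):
    #
    #     first = extract.submit()
    #     second = load.submit(wait_for=[first])  # load waits for extract to finish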
241 @deprecated.deprecated_callable( 1a
242 start_date=datetime.datetime(2025, 8, 21),
243 end_date=datetime.datetime(2025, 11, 21),
244 help="This method is no longer used and will be removed in a future version.",
245 )
246 def record_terminal_state_timing(self, state: State) -> None: 1a
247 if self.task_run and self.task_run.start_time and not self.task_run.end_time:
248 self.task_run.end_time = state.timestamp
250 if self.state.is_running():
251 self.task_run.total_run_time += state.timestamp - self.state.timestamp
253 def is_running(self) -> bool: 1a
254 """Whether or not the engine is currently running a task."""
255 if (task_run := getattr(self, "task_run", None)) is None:
256 return False
257 return task_run.state.is_running() or task_run.state.is_scheduled()
259 def log_finished_message(self) -> None: 1a
260 if not self.task_run:
261 return
 263 # If debugging, use the more complete `repr` rather than the usual `str` description
264 display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
265 level = logging.INFO if self.state.is_completed() else logging.ERROR
266 msg = f"Finished in state {display_state}"
267 if self.state.is_pending() and self.state.name != "NotReady":
268 msg += (
269 "\nPlease wait for all submitted tasks to complete"
270 " before exiting your flow by calling `.wait()` on the "
271 "`PrefectFuture` returned from your `.submit()` calls."
272 )
273 msg += dedent(
274 """
276 Example:
278 from prefect import flow, task
280 @task
281 def say_hello(name):
282 print(f"Hello, {name}!")
284 @flow
285 def example_flow():
286 future = say_hello.submit(name="Marvin")
287 future.wait()
289 example_flow()
290 """
291 )
292 self.logger.log(
293 level=level,
294 msg=msg,
295 )
297 def handle_rollback(self, txn: Transaction) -> None: 1a
298 assert self.task_run is not None
300 rolled_back_state = Completed(
301 name="RolledBack",
302 message="Task rolled back as part of transaction",
303 )
305 self._last_event = emit_task_run_state_change_event(
306 task_run=self.task_run,
307 initial_state=self.state,
308 validated_state=rolled_back_state,
309 follows=self._last_event,
310 )
313@dataclass 1a
314class SyncTaskRunEngine(BaseTaskRunEngine[P, R]): 1a
315 task_run: Optional[TaskRun] = None 1a
316 _client: Optional[SyncPrefectClient] = None 1a
318 @property 1a
319 def client(self) -> SyncPrefectClient: 1a
320 if not self._is_started or self._client is None:
321 raise RuntimeError("Engine has not started.")
322 return self._client
324 def can_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a
325 retry_condition: Optional[
326 Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool]
327 ] = self.task.retry_condition_fn
329 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
331 if not self.task_run:
332 raise ValueError("Task run is not set")
333 try:
334 self.logger.debug(
335 f"Running `retry_condition_fn` check {retry_condition!r} for task"
336 f" {self.task.name!r}"
337 )
338 state = Failed(
339 data=exc_or_state,
340 message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}",
341 )
342 if inspect.iscoroutinefunction(retry_condition):
343 should_retry = run_coro_as_sync(
344 retry_condition(self.task, self.task_run, state)
345 )
346 elif inspect.isfunction(retry_condition):
347 should_retry = retry_condition(self.task, self.task_run, state)
348 else:
349 should_retry = not retry_condition
350 return should_retry
351 except Exception:
352 self.logger.error(
353 (
354 "An error was encountered while running `retry_condition_fn` check"
355 f" '{retry_condition!r}' for task {self.task.name!r}"
356 ),
357 exc_info=True,
358 )
359 return False
361 def call_hooks(self, state: Optional[State] = None) -> None: 1a
362 if state is None:
363 state = self.state
364 task = self.task
365 task_run = self.task_run
367 if not task_run:
368 raise ValueError("Task run is not set")
370 if state.is_failed() and task.on_failure_hooks:
371 hooks = task.on_failure_hooks
372 elif state.is_completed() and task.on_completion_hooks:
373 hooks = task.on_completion_hooks
374 elif state.is_running() and task.on_running_hooks:
375 hooks = task.on_running_hooks
376 else:
377 hooks = None
379 for hook in hooks or []:
380 hook_name = get_hook_name(hook)
382 try:
383 self.logger.info(
384 f"Running hook {hook_name!r} in response to entering state"
385 f" {state.name!r}"
386 )
387 result = hook(task, task_run, state)
388 if asyncio.iscoroutine(result):
389 run_coro_as_sync(result)
390 except Exception:
391 self.logger.error(
392 f"An error was encountered while running hook {hook_name!r}",
393 exc_info=True,
394 )
395 else:
396 self.logger.info(f"Hook {hook_name!r} finished running successfully")
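    # Illustrative sketch (not part of this module): the hooks run above are registered
    # on the task; assuming the documented `on_completion`/`on_failure` options, each
    # hook receives (task, task_run, state):
    #
    #     from prefect import task
    #
    #     def notify(task, task_run, state):
    #         print(f"{task_run.name} entered {state.name}")
    #
    #     @task(on_completion=[notify], on_failure=[notify])
    #     def my_task():
    #         ...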
398 def begin_run(self) -> None: 1a
399 new_state = Running()
401 assert self.task_run is not None, "Task run is not set"
402 self.task_run.start_time = new_state.timestamp
404 flow_run_context = FlowRunContext.get()
405 if flow_run_context and flow_run_context.flow_run:
406 # Carry forward any task run information from the flow run
407 flow_run = flow_run_context.flow_run
408 self.task_run.flow_run_run_count = flow_run.run_count
410 state = self.set_state(new_state)
412 # TODO: this is temporary until the API stops rejecting state transitions
413 # and the client / transaction store becomes the source of truth
414 # this is a bandaid caused by the API storing a Completed state with a bad
415 # result reference that no longer exists
416 if state.is_completed():
417 try:
418 state.result(retry_result_failure=False, _sync=True) # type: ignore[reportCallIssue]
419 except Exception:
420 state = self.set_state(new_state, force=True)
422 backoff_count = 0
424 # TODO: Could this listen for state change events instead of polling?
425 while state.is_pending() or state.is_paused():
426 if backoff_count < BACKOFF_MAX:
427 backoff_count += 1
428 interval = clamped_poisson_interval(
429 average_interval=backoff_count, clamping_factor=0.3
430 )
431 time.sleep(interval)
432 state = self.set_state(new_state)
434 # Call on_running hooks after the task has entered the Running state
435 if state.is_running():
436 self.call_hooks(state)
438 def set_state(self, state: State[R], force: bool = False) -> State[R]: 1a
439 last_state = self.state
440 if not self.task_run:
441 raise ValueError("Task run is not set")
443 self.task_run.state = new_state = state
445 if last_state.timestamp == new_state.timestamp:
446 # Ensure that the state timestamp is unique, or at least not equal to the last state.
447 # This might occur especially on Windows where the timestamp resolution is limited.
448 new_state.timestamp += timedelta(microseconds=1)
450 # Ensure that the state_details are populated with the current run IDs
451 new_state.state_details.task_run_id = self.task_run.id
452 new_state.state_details.flow_run_id = self.task_run.flow_run_id
454 # Predictively update the de-normalized task_run.state_* attributes
455 self.task_run.state_id = new_state.id
456 self.task_run.state_type = new_state.type
457 self.task_run.state_name = new_state.name
458 if last_state.is_running():
459 self.task_run.total_run_time += new_state.timestamp - last_state.timestamp
461 if new_state.is_running():
462 self.task_run.run_count += 1
464 if new_state.is_final():
465 if (
466 self.task_run
467 and self.task_run.start_time
468 and not self.task_run.end_time
469 ):
470 self.task_run.end_time = new_state.timestamp
472 if isinstance(state.data, ResultRecord):
473 result = state.data.result
474 else:
475 result = state.data
477 link_state_to_task_run_result(new_state, result)
479 # emit a state change event
480 self._last_event = emit_task_run_state_change_event(
481 task_run=self.task_run,
482 initial_state=last_state,
483 validated_state=self.task_run.state,
484 follows=self._last_event,
485 )
486 self._telemetry.update_state(new_state)
487 return new_state
489 def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": 1a
490 if self._return_value is not NotSet:
491 if isinstance(self._return_value, ResultRecord):
492 return self._return_value.result
493 # otherwise, return the value as is
494 return self._return_value
496 if self._raised is not NotSet:
497 # if the task raised an exception, raise it
498 if raise_on_failure:
499 raise self._raised
501 # otherwise, return the exception
502 return self._raised
504 def handle_success( 1a
505 self, result: R, transaction: Transaction
506 ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]:
507 # Handle the case where the task explicitly returns a failed state, in
508 # which case we should retry the task if it has retries left.
509 if isinstance(result, State) and result.is_failed():
510 if self.handle_retry(result):
511 return None
513 if self.task.cache_expiration is not None:
514 expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration
515 else:
516 expiration = None
518 terminal_state = run_coro_as_sync(
519 return_value_to_state(
520 result,
521 result_store=get_result_store(),
522 key=transaction.key,
523 expiration=expiration,
524 )
525 )
527 # Avoid logging when running this rollback hook since it is not user-defined
528 handle_rollback = partial(self.handle_rollback)
529 handle_rollback.log_on_run = False
531 transaction.stage(
532 terminal_state.data,
533 on_rollback_hooks=[handle_rollback] + self.task.on_rollback_hooks,
534 on_commit_hooks=self.task.on_commit_hooks,
535 )
536 if transaction.is_committed():
537 terminal_state.name = "Cached"
539 self.set_state(terminal_state)
540 self._return_value = result
542 self._telemetry.end_span_on_success()
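    # Illustrative sketch (not part of this module): `cache_expiration` above puts a
    # TTL on cached results; assuming the documented task options:
    #
    #     import datetime
    #     from prefect import task
    #     from prefect.cache_policies import INPUTS
    #
    #     @task(cache_policy=INPUTS, cache_expiration=datetime.timedelta(hours=1))
    #     def fetch_rates():
    #         ...  # the cached value is reused for up to one hour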
544 def handle_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a
545 """Handle any task run retries.
547 - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
548 - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
549 - If the task has no retries left, or the retry condition is not met, return False.
550 """
551 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
552 if self.retries < self.task.retries and self.can_retry(exc_or_state):
553 if self.task.retry_delay_seconds:
554 delay = (
555 self.task.retry_delay_seconds[
556 min(self.retries, len(self.task.retry_delay_seconds) - 1)
557 ] # repeat final delay value if attempts exceed specified delays
558 if isinstance(self.task.retry_delay_seconds, Sequence)
559 else self.task.retry_delay_seconds
560 )
561 new_state = AwaitingRetry(
562 scheduled_time=prefect.types._datetime.now("UTC")
563 + timedelta(seconds=delay)
564 )
565 else:
566 delay = None
567 new_state = Retrying()
569 self.logger.info(
570 "Task run failed with %s: %r - Retry %s/%s will start %s",
571 failure_type,
572 exc_or_state,
573 self.retries + 1,
574 self.task.retries,
575 str(delay) + " second(s) from now" if delay else "immediately",
576 )
578 self.set_state(new_state, force=True)
579 # Call on_running hooks if we transitioned to a Running state (immediate retry)
580 if new_state.is_running():
581 self.call_hooks(new_state)
582 self.retries: int = self.retries + 1
583 return True
584 elif self.retries >= self.task.retries:
585 if self.task.retries > 0:
586 self.logger.error(
587 f"Task run failed with {failure_type}: {exc_or_state!r} - Retries are exhausted",
588 exc_info=True,
589 )
590 else:
591 self.logger.error(
592 f"Task run failed with {failure_type}: {exc_or_state!r}",
593 exc_info=True,
594 )
595 return False
597 return False
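    # Illustrative sketch (not part of this module): the retry bookkeeping above is
    # driven by the task's retry options; assuming the documented `retries`,
    # `retry_delay_seconds`, and `retry_condition_fn` parameters (here `state.data`
    # carries the raised exception or returned state, per `can_retry` above):
    #
    #     from prefect import task
    #
    #     def retry_on_timeout(task, task_run, state):
    #         return isinstance(state.data, TimeoutError)
    #
    #     @task(retries=3, retry_delay_seconds=[1, 2, 4], retry_condition_fn=retry_on_timeout)
    #     def flaky_call():
    #         ...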
599 def handle_exception(self, exc: Exception) -> None: 1a
600 # If the task fails, and we have retries left, set the task to retrying.
601 self._telemetry.record_exception(exc)
602 if not self.handle_retry(exc):
603 # If the task has no retries left, or the retry condition is not met, set the task to failed.
604 state = run_coro_as_sync(
605 exception_to_failed_state(
606 exc,
607 message="Task run encountered an exception",
608 result_store=get_result_store(),
609 write_result=True,
610 )
611 )
612 self.set_state(state)
613 self._raised = exc
614 self._telemetry.end_span_on_failure(state.message if state else None)
616 def handle_timeout(self, exc: TimeoutError) -> None: 1a
617 if not self.handle_retry(exc):
618 if isinstance(exc, TaskRunTimeoutError):
619 message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
620 else:
621 message = f"Task run failed due to timeout: {exc!r}"
622 self.logger.error(message)
623 state = Failed(
624 data=exc,
625 message=message,
626 name="TimedOut",
627 )
628 self.set_state(state)
629 self._raised = exc
631 def handle_crash(self, exc: BaseException) -> None: 1a
632 state = run_coro_as_sync(exception_to_crashed_state(exc))
633 self.logger.error(f"Crash detected! {state.message}")
634 self.logger.debug("Crash details:", exc_info=exc)
635 self.set_state(state, force=True)
636 self._raised = exc
637 self._telemetry.record_exception(exc)
638 self._telemetry.end_span_on_failure(state.message if state else None)
640 @contextmanager 1a
641 def setup_run_context(self, client: Optional[SyncPrefectClient] = None): 1a
642 from prefect.utilities.engine import (
643 should_log_prints,
644 )
646 settings = get_current_settings()
648 if client is None:
649 client = self.client
650 if not self.task_run:
651 raise ValueError("Task run is not set")
653 with ExitStack() as stack:
654 if log_prints := should_log_prints(self.task):
655 stack.enter_context(patch_print())
656 if self.task.persist_result is not None:
657 persist_result = self.task.persist_result
658 elif settings.tasks.default_persist_result is not None:
659 persist_result = settings.tasks.default_persist_result
660 else:
661 persist_result = should_persist_result()
663 stack.enter_context(
664 TaskRunContext(
665 task=self.task,
666 log_prints=log_prints,
667 task_run=self.task_run,
668 parameters=self.parameters,
669 result_store=get_result_store().update_for_task(
670 self.task, _sync=True
671 ),
672 client=client,
673 persist_result=persist_result,
674 )
675 )
676 stack.enter_context(ConcurrencyContext())
678 self.logger: "logging.Logger" = task_run_logger(
679 task_run=self.task_run, task=self.task
680 ) # type: ignore
682 yield
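    # Illustrative sketch (not part of this module): the `persist_result` resolution
    # above prefers the task-level option, then the `tasks.default_persist_result`
    # setting, then the global default; assuming the documented task option:
    #
    #     from prefect import task
    #
    #     @task(persist_result=True)
    #     def build_report(day):
    #         ...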
684 @contextmanager 1a
685 def asset_context(self): 1a
686 parent_asset_ctx = AssetContext.get()
688 if parent_asset_ctx and parent_asset_ctx.copy_to_child_ctx:
689 asset_ctx = parent_asset_ctx.model_copy()
690 asset_ctx.copy_to_child_ctx = False
691 else:
692 asset_ctx = AssetContext.from_task_and_inputs(
693 self.task, self.task_run.id, self.task_run.task_inputs
694 )
696 with asset_ctx as ctx:
697 try:
698 yield
699 finally:
700 ctx.emit_events(self.state)
702 @contextmanager 1a
703 def initialize_run( 1a
704 self,
705 task_run_id: Optional[UUID] = None,
706 dependencies: Optional[dict[str, set[RunInput]]] = None,
707 ) -> Generator[Self, Any, Any]:
708 """
709 Enters a client context and creates a task run if needed.
710 """
712 with hydrated_context(self.context):
713 with SyncClientContext.get_or_create() as client_ctx:
714 self._client = client_ctx.client
715 self._is_started = True
716 parent_flow_run_context = FlowRunContext.get()
717 parent_task_run_context = TaskRunContext.get()
719 try:
720 if not self.task_run:
721 self.task_run = run_coro_as_sync(
722 self.task.create_local_run(
723 id=task_run_id,
724 parameters=self.parameters,
725 flow_run_context=parent_flow_run_context,
726 parent_task_run_context=parent_task_run_context,
727 wait_for=self.wait_for,
728 extra_task_inputs=dependencies,
729 )
730 )
731 # Emit an event to capture that the task run was in the `PENDING` state.
732 self._last_event = emit_task_run_state_change_event(
733 task_run=self.task_run,
734 initial_state=None,
735 validated_state=self.task_run.state,
736 )
738 with self.setup_run_context():
739 # setup_run_context might update the task run name, so log creation here
740 self.logger.debug(
741 f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
742 )
744 self._telemetry.start_span(
745 run=self.task_run,
746 client=self.client,
747 parameters=self.parameters,
748 )
750 yield self
752 except TerminationSignal as exc:
753 # TerminationSignals are caught and handled as crashes
754 self.handle_crash(exc)
755 raise exc
757 except Exception:
758 # regular exceptions are caught and re-raised to the user
759 raise
760 except (Pause, Abort) as exc:
761 # Do not capture internal signals as crashes
762 if isinstance(exc, Abort):
763 self.logger.error("Task run was aborted: %s", exc)
764 raise
765 except GeneratorExit:
766 # Do not capture generator exits as crashes
767 raise
768 except BaseException as exc:
769 # BaseExceptions are caught and handled as crashes
770 self.handle_crash(exc)
771 raise
772 finally:
773 self.log_finished_message()
774 self._is_started = False
775 self._client = None
777 async def wait_until_ready(self) -> None: 1a
778 """Waits until the scheduled time (if its the future), then enters Running."""
779 if scheduled_time := self.state.state_details.scheduled_time:
780 sleep_time = (
781 scheduled_time - prefect.types._datetime.now("UTC")
782 ).total_seconds()
783 await anyio.sleep(sleep_time if sleep_time > 0 else 0)
784 new_state = Retrying() if self.state.name == "AwaitingRetry" else Running()
785 self.set_state(
786 new_state,
787 force=True,
788 )
789 # Call on_running hooks if we transitioned to a Running state
790 if self.state.is_running():
791 self.call_hooks()
793 # --------------------------
794 #
795 # The following methods compose the main task run loop
796 #
797 # --------------------------
799 @contextmanager 1a
800 def start( 1a
801 self,
802 task_run_id: Optional[UUID] = None,
803 dependencies: Optional[dict[str, set[RunInput]]] = None,
804 ) -> Generator[None, None, None]:
805 with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
806 with (
807 trace.use_span(self._telemetry.span)
808 if self._telemetry.span
809 else nullcontext()
810 ):
811 try:
812 self._resolve_parameters()
813 self._set_custom_task_run_name()
814 self._wait_for_dependencies()
815 except UpstreamTaskError as upstream_exc:
816 self.set_state(
817 Pending(
818 name="NotReady",
819 message=str(upstream_exc),
820 ),
821 # if orchestrating a run already in a pending state, force orchestration to
822 # update the state name
823 force=self.state.is_pending(),
824 )
825 yield
826 self.call_hooks()
827 return
829 with _concurrency(
830 names=[f"tag:{tag}" for tag in self.task_run.tags],
831 occupy=1,
832 holder=ConcurrencyLeaseHolder(type="task_run", id=self.task_run.id),
833 lease_duration=60,
834 suppress_warnings=True,
835 ):
836 self.begin_run()
837 try:
838 yield
839 finally:
840 self.call_hooks()
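    # Illustrative sketch (not part of this module): the `tag:{tag}` concurrency slots
    # acquired above come from the task's tags, so a server-side concurrency limit on a
    # tag bounds how many tagged runs execute at once; assuming the documented `tags`
    # option and the `prefect concurrency-limit` CLI:
    #
    #     from prefect import task
    #
    #     @task(tags=["database"])
    #     def write_row():
    #         ...
    #
    #     # shell: prefect concurrency-limit create database 5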
842 @contextmanager 1a
843 def transaction_context(self) -> Generator[Transaction, None, None]: 1a
 844 # the refresh cache setting is now repurposed as the overwrite-transaction-record flag
845 overwrite = (
846 self.task.refresh_cache
847 if self.task.refresh_cache is not None
848 else PREFECT_TASKS_REFRESH_CACHE.value()
849 )
851 isolation_level = (
852 IsolationLevel(self.task.cache_policy.isolation_level)
853 if self.task.cache_policy
854 and self.task.cache_policy is not NotSet
855 and self.task.cache_policy.isolation_level is not None
856 else None
857 )
859 with transaction(
860 key=self.compute_transaction_key(),
861 store=get_result_store(),
862 overwrite=overwrite,
863 logger=self.logger,
864 write_on_commit=should_persist_result(),
865 isolation_level=isolation_level,
866 ) as txn:
867 yield txn
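    # Illustrative sketch (not part of this module): the `overwrite` flag above maps to
    # the task-level `refresh_cache` option, falling back to the
    # PREFECT_TASKS_REFRESH_CACHE setting, and forces the transaction record (the
    # cached result) to be recomputed:
    #
    #     from prefect import task
    #     from prefect.cache_policies import INPUTS
    #
    #     @task(cache_policy=INPUTS, refresh_cache=True)
    #     def rebuild_summary(day):
    #         ...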
869 @contextmanager 1a
870 def run_context(self): 1a
871 # reenter the run context to ensure it is up to date for every run
872 with self.setup_run_context():
873 try:
874 with timeout(
875 seconds=self.task.timeout_seconds,
876 timeout_exc_type=TaskRunTimeoutError,
877 ):
878 self.logger.debug(
879 f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
880 )
881 if self.is_cancelled():
882 raise CancelledError("Task run cancelled by the task runner")
884 yield self
885 except TimeoutError as exc:
886 self.handle_timeout(exc)
887 except Exception as exc:
888 self.handle_exception(exc)
890 def call_task_fn( 1a
891 self, transaction: Transaction
892 ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]:
893 """
894 Convenience method to call the task function. Returns a coroutine if the
895 task is async.
896 """
897 parameters = self.parameters or {}
898 if transaction.is_committed():
899 result = transaction.read()
900 else:
901 result = call_with_parameters(self.task.fn, parameters)
902 self.handle_success(result, transaction=transaction)
903 return result
906@dataclass 1a
907class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]): 1a
908 task_run: TaskRun | None = None 1a
909 _client: Optional[PrefectClient] = None 1a
911 @property 1a
912 def client(self) -> PrefectClient: 1a
913 if not self._is_started or self._client is None:
914 raise RuntimeError("Engine has not started.")
915 return self._client
917 async def can_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a
918 retry_condition: Optional[
919 Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool]
920 ] = self.task.retry_condition_fn
922 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
924 if not self.task_run:
925 raise ValueError("Task run is not set")
926 try:
927 self.logger.debug(
928 f"Running `retry_condition_fn` check {retry_condition!r} for task"
929 f" {self.task.name!r}"
930 )
931 state = Failed(
932 data=exc_or_state,
933 message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}",
934 )
935 if inspect.iscoroutinefunction(retry_condition):
936 should_retry = await retry_condition(self.task, self.task_run, state)
937 elif inspect.isfunction(retry_condition):
938 should_retry = retry_condition(self.task, self.task_run, state)
939 else:
940 should_retry = not retry_condition
941 return should_retry
943 except Exception:
944 self.logger.error(
945 (
946 "An error was encountered while running `retry_condition_fn` check"
947 f" '{retry_condition!r}' for task {self.task.name!r}"
948 ),
949 exc_info=True,
950 )
951 return False
953 async def call_hooks(self, state: Optional[State] = None) -> None: 1a
954 if state is None:
955 state = self.state
956 task = self.task
957 task_run = self.task_run
959 if not task_run:
960 raise ValueError("Task run is not set")
962 if state.is_failed() and task.on_failure_hooks:
963 hooks = task.on_failure_hooks
964 elif state.is_completed() and task.on_completion_hooks:
965 hooks = task.on_completion_hooks
966 elif state.is_running() and task.on_running_hooks:
967 hooks = task.on_running_hooks
968 else:
969 hooks = None
971 for hook in hooks or []:
972 hook_name = get_hook_name(hook)
974 try:
975 self.logger.info(
976 f"Running hook {hook_name!r} in response to entering state"
977 f" {state.name!r}"
978 )
979 result = hook(task, task_run, state)
980 if inspect.isawaitable(result):
981 await result
982 except Exception:
983 self.logger.error(
984 f"An error was encountered while running hook {hook_name!r}",
985 exc_info=True,
986 )
987 else:
988 self.logger.info(f"Hook {hook_name!r} finished running successfully")
990 async def begin_run(self) -> None: 1a
991 try:
992 self._resolve_parameters()
993 self._set_custom_task_run_name()
994 except UpstreamTaskError as upstream_exc:
995 state = await self.set_state(
996 Pending(
997 name="NotReady",
998 message=str(upstream_exc),
999 ),
1000 # if orchestrating a run already in a pending state, force orchestration to
1001 # update the state name
1002 force=self.state.is_pending(),
1003 )
1004 return
1006 new_state = Running()
1008 self.task_run.start_time = new_state.timestamp
1010 flow_run_context = FlowRunContext.get()
1011 if flow_run_context:
1012 # Carry forward any task run information from the flow run
1013 flow_run = flow_run_context.flow_run
1014 self.task_run.flow_run_run_count = flow_run.run_count
1016 state = await self.set_state(new_state)
1018 # TODO: this is temporary until the API stops rejecting state transitions
1019 # and the client / transaction store becomes the source of truth
1020 # this is a bandaid caused by the API storing a Completed state with a bad
1021 # result reference that no longer exists
1022 if state.is_completed():
1023 try:
1024 await state.result(retry_result_failure=False)
1025 except Exception:
1026 state = await self.set_state(new_state, force=True)
1028 backoff_count = 0
1030 # TODO: Could this listen for state change events instead of polling?
1031 while state.is_pending() or state.is_paused():
1032 if backoff_count < BACKOFF_MAX:
1033 backoff_count += 1
1034 interval = clamped_poisson_interval(
1035 average_interval=backoff_count, clamping_factor=0.3
1036 )
1037 await anyio.sleep(interval)
1038 state = await self.set_state(new_state)
1040 # Call on_running hooks after the task has entered the Running state
1041 if state.is_running():
1042 await self.call_hooks(state)
1044 async def set_state(self, state: State, force: bool = False) -> State: 1a
1045 last_state = self.state
1046 if not self.task_run:
1047 raise ValueError("Task run is not set")
1049 self.task_run.state = new_state = state
1051 if last_state.timestamp == new_state.timestamp:
1052 # Ensure that the state timestamp is unique, or at least not equal to the last state.
1053 # This might occur especially on Windows where the timestamp resolution is limited.
1054 new_state.timestamp += timedelta(microseconds=1)
1056 # Ensure that the state_details are populated with the current run IDs
1057 new_state.state_details.task_run_id = self.task_run.id
1058 new_state.state_details.flow_run_id = self.task_run.flow_run_id
1060 # Predictively update the de-normalized task_run.state_* attributes
1061 self.task_run.state_id = new_state.id
1062 self.task_run.state_type = new_state.type
1063 self.task_run.state_name = new_state.name
1065 if last_state.is_running():
1066 self.task_run.total_run_time += new_state.timestamp - last_state.timestamp
1068 if new_state.is_running():
1069 self.task_run.run_count += 1
1071 if new_state.is_final():
1072 if (
1073 self.task_run
1074 and self.task_run.start_time
1075 and not self.task_run.end_time
1076 ):
1077 self.task_run.end_time = new_state.timestamp
1079 if isinstance(new_state.data, ResultRecord):
1080 result = new_state.data.result
1081 else:
1082 result = new_state.data
1084 link_state_to_task_run_result(new_state, result)
1086 # emit a state change event
1087 self._last_event = emit_task_run_state_change_event(
1088 task_run=self.task_run,
1089 initial_state=last_state,
1090 validated_state=self.task_run.state,
1091 follows=self._last_event,
1092 )
1094 self._telemetry.update_state(new_state)
1095 return new_state
1097 async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": 1a
1098 if self._return_value is not NotSet:
1099 if isinstance(self._return_value, ResultRecord):
1100 return self._return_value.result
1101 # otherwise, return the value as is
1102 return self._return_value
1104 if self._raised is not NotSet:
1105 # if the task raised an exception, raise it
1106 if raise_on_failure:
1107 raise self._raised
1109 # otherwise, return the exception
1110 return self._raised
1112 async def handle_success( 1a
1113 self, result: R, transaction: AsyncTransaction
1114 ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]:
1115 if isinstance(result, State) and result.is_failed():
1116 if await self.handle_retry(result):
1117 return None
1119 if self.task.cache_expiration is not None:
1120 expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration
1121 else:
1122 expiration = None
1124 terminal_state = await return_value_to_state(
1125 result,
1126 result_store=get_result_store(),
1127 key=transaction.key,
1128 expiration=expiration,
1129 )
1131 # Avoid logging when running this rollback hook since it is not user-defined
1132 handle_rollback = partial(self.handle_rollback)
1133 handle_rollback.log_on_run = False
1135 transaction.stage(
1136 terminal_state.data,
1137 on_rollback_hooks=[handle_rollback] + self.task.on_rollback_hooks,
1138 on_commit_hooks=self.task.on_commit_hooks,
1139 )
1140 if transaction.is_committed():
1141 terminal_state.name = "Cached"
1143 await self.set_state(terminal_state)
1144 self._return_value = result
1145 self._telemetry.end_span_on_success()
1147 return result
1149 async def handle_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a
1150 """Handle any task run retries.
1152 - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
1153 - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
1154 - If the task has no retries left, or the retry condition is not met, return False.
1155 """
1156 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
1158 if self.retries < self.task.retries and await self.can_retry(exc_or_state):
1159 if self.task.retry_delay_seconds:
1160 delay = (
1161 self.task.retry_delay_seconds[
1162 min(self.retries, len(self.task.retry_delay_seconds) - 1)
1163 ] # repeat final delay value if attempts exceed specified delays
1164 if isinstance(self.task.retry_delay_seconds, Sequence)
1165 else self.task.retry_delay_seconds
1166 )
1167 new_state = AwaitingRetry(
1168 scheduled_time=prefect.types._datetime.now("UTC")
1169 + timedelta(seconds=delay)
1170 )
1171 else:
1172 delay = None
1173 new_state = Retrying()
1175 self.logger.info(
1176 "Task run failed with %s: %r - Retry %s/%s will start %s",
1177 failure_type,
1178 exc_or_state,
1179 self.retries + 1,
1180 self.task.retries,
1181 str(delay) + " second(s) from now" if delay else "immediately",
1182 )
1184 await self.set_state(new_state, force=True)
1185 # Call on_running hooks if we transitioned to a Running state (immediate retry)
1186 if new_state.is_running():
1187 await self.call_hooks(new_state)
1188 self.retries: int = self.retries + 1
1189 return True
1190 elif self.retries >= self.task.retries:
1191 if self.task.retries > 0:
1192 self.logger.error(
1193 f"Task run failed with {failure_type}: {exc_or_state!r} - Retries are exhausted",
1194 exc_info=True,
1195 )
1196 else:
1197 self.logger.error(
1198 f"Task run failed with {failure_type}: {exc_or_state!r}",
1199 exc_info=True,
1200 )
1201 return False
1203 return False
1205 async def handle_exception(self, exc: Exception) -> None: 1a
1206 # If the task fails, and we have retries left, set the task to retrying.
1207 self._telemetry.record_exception(exc)
1208 if not await self.handle_retry(exc):
1209 # If the task has no retries left, or the retry condition is not met, set the task to failed.
1210 state = await exception_to_failed_state(
1211 exc,
1212 message="Task run encountered an exception",
1213 result_store=get_result_store(),
1214 )
1215 await self.set_state(state)
1216 self._raised = exc
1218 self._telemetry.end_span_on_failure(state.message)
1220 async def handle_timeout(self, exc: TimeoutError) -> None: 1a
1221 self._telemetry.record_exception(exc)
1222 if not await self.handle_retry(exc):
1223 if isinstance(exc, TaskRunTimeoutError):
1224 message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
1225 else:
1226 message = f"Task run failed due to timeout: {exc!r}"
1227 self.logger.error(message)
1228 state = Failed(
1229 data=exc,
1230 message=message,
1231 name="TimedOut",
1232 )
1233 await self.set_state(state)
1234 self._raised = exc
1235 self._telemetry.end_span_on_failure(state.message)
1237 async def handle_crash(self, exc: BaseException) -> None: 1a
1238 state = await exception_to_crashed_state(exc)
1239 self.logger.error(f"Crash detected! {state.message}")
1240 self.logger.debug("Crash details:", exc_info=exc)
1241 await self.set_state(state, force=True)
1242 self._raised = exc
1244 self._telemetry.record_exception(exc)
1245 self._telemetry.end_span_on_failure(state.message)
1247 @asynccontextmanager 1a
1248 async def setup_run_context(self, client: Optional[PrefectClient] = None): 1a
1249 from prefect.utilities.engine import (
1250 should_log_prints,
1251 )
1253 settings = get_current_settings()
1255 if client is None:
1256 client = self.client
1257 if not self.task_run:
1258 raise ValueError("Task run is not set")
1260 with ExitStack() as stack:
1261 if log_prints := should_log_prints(self.task):
1262 stack.enter_context(patch_print())
1263 if self.task.persist_result is not None:
1264 persist_result = self.task.persist_result
1265 elif settings.tasks.default_persist_result is not None:
1266 persist_result = settings.tasks.default_persist_result
1267 else:
1268 persist_result = should_persist_result()
1270 stack.enter_context(
1271 TaskRunContext(
1272 task=self.task,
1273 log_prints=log_prints,
1274 task_run=self.task_run,
1275 parameters=self.parameters,
1276 result_store=await get_result_store().update_for_task(
1277 self.task, _sync=False
1278 ),
1279 client=client,
1280 persist_result=persist_result,
1281 )
1282 )
1283 stack.enter_context(ConcurrencyContext())
1285 self.logger: "logging.Logger" = task_run_logger(
1286 task_run=self.task_run, task=self.task
1287 ) # type: ignore
1289 yield
1291 @asynccontextmanager 1a
1292 async def asset_context(self): 1a
1293 parent_asset_ctx = AssetContext.get()
1295 if parent_asset_ctx and parent_asset_ctx.copy_to_child_ctx:
1296 asset_ctx = parent_asset_ctx.model_copy()
1297 asset_ctx.copy_to_child_ctx = False
1298 else:
1299 asset_ctx = AssetContext.from_task_and_inputs(
1300 self.task, self.task_run.id, self.task_run.task_inputs
1301 )
1303 with asset_ctx as ctx:
1304 try:
1305 yield
1306 finally:
1307 ctx.emit_events(self.state)
1309 @asynccontextmanager 1a
1310 async def initialize_run( 1a
1311 self,
1312 task_run_id: Optional[UUID] = None,
1313 dependencies: Optional[dict[str, set[RunInput]]] = None,
1314 ) -> AsyncGenerator[Self, Any]:
1315 """
1316 Enters a client context and creates a task run if needed.
1317 """
1319 with hydrated_context(self.context):
1320 async with AsyncClientContext.get_or_create():
1321 self._client = get_client()
1322 self._is_started = True
1323 parent_flow_run_context = FlowRunContext.get()
1324 parent_task_run_context = TaskRunContext.get()
1326 try:
1327 if not self.task_run:
1328 self.task_run = await self.task.create_local_run(
1329 id=task_run_id,
1330 parameters=self.parameters,
1331 flow_run_context=parent_flow_run_context,
1332 parent_task_run_context=parent_task_run_context,
1333 wait_for=self.wait_for,
1334 extra_task_inputs=dependencies,
1335 )
1336 # Emit an event to capture that the task run was in the `PENDING` state.
1337 self._last_event = emit_task_run_state_change_event(
1338 task_run=self.task_run,
1339 initial_state=None,
1340 validated_state=self.task_run.state,
1341 )
1343 async with self.setup_run_context():
1344 # setup_run_context might update the task run name, so log creation here
1345 self.logger.debug(
1346 f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
1347 )
1349 await self._telemetry.async_start_span(
1350 run=self.task_run,
1351 client=self.client,
1352 parameters=self.parameters,
1353 )
1355 yield self
1357 except TerminationSignal as exc:
1358 # TerminationSignals are caught and handled as crashes
1359 await self.handle_crash(exc)
1360 raise exc
1362 except Exception:
1363 # regular exceptions are caught and re-raised to the user
1364 raise
1365 except (Pause, Abort) as exc:
1366 # Do not capture internal signals as crashes
1367 if isinstance(exc, Abort):
1368 self.logger.error("Task run was aborted: %s", exc)
1369 raise
1370 except GeneratorExit:
1371 # Do not capture generator exits as crashes
1372 raise
1373 except BaseException as exc:
1374 # BaseExceptions are caught and handled as crashes
1375 await self.handle_crash(exc)
1376 raise
1377 finally:
1378 self.log_finished_message()
1379 self._is_started = False
1380 self._client = None
1382 async def wait_until_ready(self) -> None: 1a
1383 """Waits until the scheduled time (if its the future), then enters Running."""
1384 if scheduled_time := self.state.state_details.scheduled_time:
1385 sleep_time = (
1386 scheduled_time - prefect.types._datetime.now("UTC")
1387 ).total_seconds()
1388 await anyio.sleep(sleep_time if sleep_time > 0 else 0)
1389 new_state = Retrying() if self.state.name == "AwaitingRetry" else Running()
1390 await self.set_state(
1391 new_state,
1392 force=True,
1393 )
1394 # Call on_running hooks if we transitioned to a Running state
1395 if self.state.is_running():
1396 await self.call_hooks()
1398 # --------------------------
1399 #
1400 # The following methods compose the main task run loop
1401 #
1402 # --------------------------
1404 @asynccontextmanager 1a
1405 async def start( 1a
1406 self,
1407 task_run_id: Optional[UUID] = None,
1408 dependencies: Optional[dict[str, set[RunInput]]] = None,
1409 ) -> AsyncGenerator[None, None]:
1410 async with self.initialize_run(
1411 task_run_id=task_run_id, dependencies=dependencies
1412 ):
1413 with (
1414 trace.use_span(self._telemetry.span)
1415 if self._telemetry.span
1416 else nullcontext()
1417 ):
1418 try:
1419 self._resolve_parameters()
1420 self._set_custom_task_run_name()
1421 self._wait_for_dependencies()
1422 except UpstreamTaskError as upstream_exc:
1423 await self.set_state(
1424 Pending(
1425 name="NotReady",
1426 message=str(upstream_exc),
1427 ),
1428 # if orchestrating a run already in a pending state, force orchestration to
1429 # update the state name
1430 force=self.state.is_pending(),
1431 )
1432 yield
1433 await self.call_hooks()
1434 return
1436 async with _aconcurrency(
1437 names=[f"tag:{tag}" for tag in self.task_run.tags],
1438 occupy=1,
1439 holder=ConcurrencyLeaseHolder(type="task_run", id=self.task_run.id),
1440 lease_duration=60,
1441 suppress_warnings=True,
1442 ):
1443 await self.begin_run()
1444 try:
1445 yield
1446 finally:
1447 await self.call_hooks()
1449 @asynccontextmanager 1a
1450 async def transaction_context(self) -> AsyncGenerator[AsyncTransaction, None]: 1a
 1451 # the refresh cache setting is now repurposed as the overwrite-transaction-record flag
1452 overwrite = (
1453 self.task.refresh_cache
1454 if self.task.refresh_cache is not None
1455 else PREFECT_TASKS_REFRESH_CACHE.value()
1456 )
1457 isolation_level = (
1458 IsolationLevel(self.task.cache_policy.isolation_level)
1459 if self.task.cache_policy
1460 and self.task.cache_policy is not NotSet
1461 and self.task.cache_policy.isolation_level is not None
1462 else None
1463 )
1465 async with atransaction(
1466 key=self.compute_transaction_key(),
1467 store=get_result_store(),
1468 overwrite=overwrite,
1469 logger=self.logger,
1470 write_on_commit=should_persist_result(),
1471 isolation_level=isolation_level,
1472 ) as txn:
1473 yield txn
1475 @asynccontextmanager 1a
1476 async def run_context(self): 1a
1477 # reenter the run context to ensure it is up to date for every run
1478 async with self.setup_run_context():
1479 try:
1480 with timeout_async(
1481 seconds=self.task.timeout_seconds,
1482 timeout_exc_type=TaskRunTimeoutError,
1483 ):
1484 self.logger.debug(
1485 f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
1486 )
1487 if self.is_cancelled():
1488 raise CancelledError("Task run cancelled by the task runner")
1490 yield self
1491 except TimeoutError as exc:
1492 await self.handle_timeout(exc)
1493 except Exception as exc:
1494 await self.handle_exception(exc)
1496 async def call_task_fn( 1a
1497 self, transaction: AsyncTransaction
1498 ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]:
1499 """
1500 Convenience method to call the task function. Returns a coroutine if the
1501 task is async.
1502 """
1503 parameters = self.parameters or {}
1504 if transaction.is_committed():
1505 result = await transaction.read()
1506 else:
1507 result = await call_with_parameters(self.task.fn, parameters)
1508 await self.handle_success(result, transaction=transaction)
1509 return result
1512def run_task_sync( 1a
1513 task: "Task[P, R]",
1514 task_run_id: Optional[UUID] = None,
1515 task_run: Optional[TaskRun] = None,
1516 parameters: Optional[dict[str, Any]] = None,
1517 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1518 return_type: Literal["state", "result"] = "result",
1519 dependencies: Optional[dict[str, set[RunInput]]] = None,
1520 context: Optional[dict[str, Any]] = None,
1521) -> Union[R, State, None]:
1522 engine = SyncTaskRunEngine[P, R](
1523 task=task,
1524 parameters=parameters,
1525 task_run=task_run,
1526 wait_for=wait_for,
1527 context=context,
1528 )
1530 with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1531 while engine.is_running():
1532 run_coro_as_sync(engine.wait_until_ready())
1533 with (
1534 engine.asset_context(),
1535 engine.run_context(),
1536 engine.transaction_context() as txn,
1537 ):
1538 engine.call_task_fn(txn)
1540 return engine.state if return_type == "state" else engine.result()
1543async def run_task_async( 1a
1544 task: "Task[P, R]",
1545 task_run_id: Optional[UUID] = None,
1546 task_run: Optional[TaskRun] = None,
1547 parameters: Optional[dict[str, Any]] = None,
1548 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1549 return_type: Literal["state", "result"] = "result",
1550 dependencies: Optional[dict[str, set[RunInput]]] = None,
1551 context: Optional[dict[str, Any]] = None,
1552) -> Union[R, State, None]:
1553 engine = AsyncTaskRunEngine[P, R](
1554 task=task,
1555 parameters=parameters,
1556 task_run=task_run,
1557 wait_for=wait_for,
1558 context=context,
1559 )
1561 async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1562 while engine.is_running():
1563 await engine.wait_until_ready()
1564 async with (
1565 engine.asset_context(),
1566 engine.run_context(),
1567 engine.transaction_context() as txn,
1568 ):
1569 await engine.call_task_fn(txn)
1571 return engine.state if return_type == "state" else await engine.result()
1574def run_generator_task_sync( 1a
1575 task: "Task[P, R]",
1576 task_run_id: Optional[UUID] = None,
1577 task_run: Optional[TaskRun] = None,
1578 parameters: Optional[dict[str, Any]] = None,
1579 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1580 return_type: Literal["state", "result"] = "result",
1581 dependencies: Optional[dict[str, set[RunInput]]] = None,
1582 context: Optional[dict[str, Any]] = None,
1583) -> Generator[R, None, None]:
1584 if return_type != "result":
1585 raise ValueError("The return_type for a generator task must be 'result'")
1587 engine = SyncTaskRunEngine[P, R](
1588 task=task,
1589 parameters=parameters,
1590 task_run=task_run,
1591 wait_for=wait_for,
1592 context=context,
1593 )
1595 with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1596 while engine.is_running():
1597 run_coro_as_sync(engine.wait_until_ready())
1598 with (
1599 engine.asset_context(),
1600 engine.run_context(),
1601 engine.transaction_context() as txn,
1602 ):
1603 # TODO: generators should default to commit_mode=OFF
1604 # because they are dynamic by definition
1605 # for now we just prevent this branch explicitly
1606 if False and txn.is_committed():
1607 txn.read()
1608 else:
1609 call_args, call_kwargs = parameters_to_args_kwargs(
1610 task.fn, engine.parameters or {}
1611 )
1612 gen = task.fn(*call_args, **call_kwargs)
1613 try:
1614 while True:
1615 gen_result = next(gen)
1616 # link the current state to the result for dependency tracking
1617 #
1618 # TODO: this could grow the task_run_result
1619 # dictionary in an unbounded way, so finding a
1620 # way to periodically clean it up (using
1621 # weakrefs or similar) would be good
1622 link_state_to_task_run_result(engine.state, gen_result)
1623 yield gen_result
1624 except StopIteration as exc:
1625 engine.handle_success(exc.value, transaction=txn)
1626 except GeneratorExit as exc:
1627 engine.handle_success(None, transaction=txn)
1628 gen.throw(exc)
1630 return engine.result()
1633async def run_generator_task_async( 1a
1634 task: "Task[P, R]",
1635 task_run_id: Optional[UUID] = None,
1636 task_run: Optional[TaskRun] = None,
1637 parameters: Optional[dict[str, Any]] = None,
1638 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1639 return_type: Literal["state", "result"] = "result",
1640 dependencies: Optional[dict[str, set[RunInput]]] = None,
1641 context: Optional[dict[str, Any]] = None,
1642) -> AsyncGenerator[R, None]:
1643 if return_type != "result":
1644 raise ValueError("The return_type for a generator task must be 'result'")
1645 engine = AsyncTaskRunEngine[P, R](
1646 task=task,
1647 parameters=parameters,
1648 task_run=task_run,
1649 wait_for=wait_for,
1650 context=context,
1651 )
1653 async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1654 while engine.is_running():
1655 await engine.wait_until_ready()
1656 async with (
1657 engine.asset_context(),
1658 engine.run_context(),
1659 engine.transaction_context() as txn,
1660 ):
1661 # TODO: generators should default to commit_mode=OFF
1662 # because they are dynamic by definition
1663 # for now we just prevent this branch explicitly
1664 if False and txn.is_committed():
1665 txn.read()
1666 else:
1667 call_args, call_kwargs = parameters_to_args_kwargs(
1668 task.fn, engine.parameters or {}
1669 )
1670 gen = task.fn(*call_args, **call_kwargs)
1671 try:
1672 while True:
1673 # can't use anext in Python < 3.10
1674 gen_result = await gen.__anext__()
1675 # link the current state to the result for dependency tracking
1676 #
1677 # TODO: this could grow the task_run_result
1678 # dictionary in an unbounded way, so finding a
1679 # way to periodically clean it up (using
1680 # weakrefs or similar) would be good
1681 link_state_to_task_run_result(engine.state, gen_result)
1682 yield gen_result
1683 except (StopAsyncIteration, GeneratorExit) as exc:
1684 await engine.handle_success(None, transaction=txn)
1685 if isinstance(exc, GeneratorExit):
1686 gen.throw(exc)
1688 # async generators can't return, but we can raise failures here
1689 if engine.state.is_failed():
1690 await engine.result()
1693@overload 1a
1694def run_task( 1694 ↛ exitline 1694 didn't return from function 'run_task' because 1a
1695 task: "Task[P, R]",
1696 task_run_id: Optional[UUID] = None,
1697 task_run: Optional[TaskRun] = None,
1698 parameters: Optional[dict[str, Any]] = None,
1699 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1700 return_type: Literal["state"] = "state",
1701 dependencies: Optional[dict[str, set[RunInput]]] = None,
1702 context: Optional[dict[str, Any]] = None,
1703) -> State[R]: ...
1706@overload 1a
1707def run_task( 1707 ↛ exitline 1707 didn't return from function 'run_task' because 1a
1708 task: "Task[P, R]",
1709 task_run_id: Optional[UUID] = None,
1710 task_run: Optional[TaskRun] = None,
1711 parameters: Optional[dict[str, Any]] = None,
1712 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1713 return_type: Literal["result"] = "result",
1714 dependencies: Optional[dict[str, set[RunInput]]] = None,
1715 context: Optional[dict[str, Any]] = None,
1716) -> R: ...
1719def run_task( 1a
1720 task: "Task[P, Union[R, Coroutine[Any, Any, R]]]",
1721 task_run_id: Optional[UUID] = None,
1722 task_run: Optional[TaskRun] = None,
1723 parameters: Optional[dict[str, Any]] = None,
1724 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None,
1725 return_type: Literal["state", "result"] = "result",
1726 dependencies: Optional[dict[str, set[RunInput]]] = None,
1727 context: Optional[dict[str, Any]] = None,
1728) -> Union[R, State, None, Coroutine[Any, Any, Union[R, State, None]]]:
1729 """
1730 Runs the provided task.
1732 Args:
1733 task: The task to run
1734 task_run_id: The ID of the task run; if not provided, a new task run
1735 will be created
1736 task_run: The task run object; if not provided, a new task run
1737 will be created
1738 parameters: The parameters to pass to the task
1739 wait_for: A list of futures to wait for before running the task
1740 return_type: The return type to return; either "state" or "result"
1741 dependencies: A dictionary of task run inputs to use for dependency tracking
1742 context: A dictionary containing the context to use for the task run; only
 1743 required if the task is running in a remote environment
1745 Returns:
1746 The result of the task run
1747 """
1748 kwargs: dict[str, Any] = dict(
1749 task=task,
1750 task_run_id=task_run_id,
1751 task_run=task_run,
1752 parameters=parameters,
1753 wait_for=wait_for,
1754 return_type=return_type,
1755 dependencies=dependencies,
1756 context=context,
1757 )
1758 if task.isasync and task.isgenerator:
1759 return run_generator_task_async(**kwargs)
1760 elif task.isgenerator:
1761 return run_generator_task_sync(**kwargs)
1762 elif task.isasync:
1763 return run_task_async(**kwargs)
1764 else:
1765 return run_task_sync(**kwargs)
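# Illustrative sketch (not part of this module): `run_task` dispatches to the
# sync/async/generator variants above based on how the task is defined. A minimal
# direct invocation, assuming a plain synchronous task:
#
#     from prefect import task
#     from prefect.task_engine import run_task
#
#     @task
#     def add(x, y):
#         return x + y
#
#     state = run_task(add, parameters={"x": 1, "y": 2}, return_type="state")
#     assert state.is_completed()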