Coverage for /usr/local/lib/python3.12/site-packages/prefect/task_engine.py: 14%

795 statements  

coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4import datetime 1a

5import inspect 1a

6import logging 1a

7import threading 1a

8import time 1a

9from asyncio import CancelledError 1a

10from contextlib import ExitStack, asynccontextmanager, contextmanager, nullcontext 1a

11from dataclasses import dataclass, field 1a

12from datetime import timedelta 1a

13from functools import partial 1a

14from textwrap import dedent 1a

15from typing import ( 1a

16 TYPE_CHECKING, 

17 Any, 

18 AsyncGenerator, 

19 Callable, 

20 Coroutine, 

21 Generator, 

22 Generic, 

23 Literal, 

24 Optional, 

25 Sequence, 

26 Type, 

27 TypeVar, 

28 Union, 

29 overload, 

30) 

31from uuid import UUID 1a

32 

33import anyio 1a

34from opentelemetry import trace 1a

35from typing_extensions import ParamSpec, Self 1a

36 

37import prefect.types._datetime 1a

38from prefect._internal.compatibility import deprecated 1a

39from prefect.cache_policies import CachePolicy 1a

40from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client 1a

41from prefect.client.schemas import TaskRun 1a

42from prefect.client.schemas.objects import ConcurrencyLeaseHolder, RunInput, State 1a

43from prefect.concurrency._asyncio import concurrency as _aconcurrency 1a

44from prefect.concurrency._sync import concurrency as _concurrency 1a

45from prefect.concurrency.context import ConcurrencyContext 1a

46from prefect.context import ( 1a

47 AssetContext, 

48 AsyncClientContext, 

49 FlowRunContext, 

50 SyncClientContext, 

51 TaskRunContext, 

52 hydrated_context, 

53) 

54from prefect.events.schemas.events import Event as PrefectEvent 1a

55from prefect.exceptions import ( 1a

56 Abort, 

57 Pause, 

58 PrefectException, 

59 TerminationSignal, 

60 UpstreamTaskError, 

61) 

62from prefect.logging.loggers import get_logger, patch_print, task_run_logger 1a

63from prefect.results import ( 1a

64 ResultRecord, 

65 _format_user_supplied_storage_key, # type: ignore[reportPrivateUsage] 

66 get_result_store, 

67 should_persist_result, 

68) 

69from prefect.settings import ( 1a

70 PREFECT_DEBUG_MODE, 

71 PREFECT_TASKS_REFRESH_CACHE, 

72) 

73from prefect.settings.context import get_current_settings 1a

74from prefect.states import ( 1a

75 AwaitingRetry, 

76 Completed, 

77 Failed, 

78 Pending, 

79 Retrying, 

80 Running, 

81 exception_to_crashed_state, 

82 exception_to_failed_state, 

83 return_value_to_state, 

84) 

85from prefect.telemetry.run_telemetry import RunTelemetry 1a

86from prefect.transactions import ( 1a

87 AsyncTransaction, 

88 IsolationLevel, 

89 Transaction, 

90 atransaction, 

91 transaction, 

92) 

93from prefect.utilities._engine import get_hook_name 1a

94from prefect.utilities.annotations import NotSet 1a

95from prefect.utilities.asyncutils import run_coro_as_sync 1a

96from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs 1a

97from prefect.utilities.collections import visit_collection 1a

98from prefect.utilities.engine import ( 1a

99 emit_task_run_state_change_event, 

100 link_state_to_task_run_result, 

101 resolve_to_final_result, 

102) 

103from prefect.utilities.math import clamped_poisson_interval 1a

104from prefect.utilities.timeout import timeout, timeout_async 1a

105 

106if TYPE_CHECKING: 106 ↛ 107  line 106 didn't jump to line 107 because the condition on line 106 was never true  1a

107 from prefect.tasks import OneOrManyFutureOrResult, Task 

108 

109P = ParamSpec("P") 1a

110R = TypeVar("R") 1a

111 

112BACKOFF_MAX = 10 1a

113 

114 

115class TaskRunTimeoutError(TimeoutError): 1a

116 """Raised when a task run exceeds its timeout.""" 

117 

118 

119@dataclass 1a

120class BaseTaskRunEngine(Generic[P, R]): 1a

121 task: Union["Task[P, R]", "Task[P, Coroutine[Any, Any, R]]"] 1a

122 logger: logging.Logger = field(default_factory=lambda: get_logger("engine")) 1a

123 parameters: Optional[dict[str, Any]] = None 1a

124 task_run: Optional[TaskRun] = None 1a

125 retries: int = 0 1a

126 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None 1a

127 context: Optional[dict[str, Any]] = None 1a

128 # holds the return value from the user code 

129 _return_value: Union[R, Type[NotSet]] = NotSet 1a

130 # holds the exception raised by the user code, if any 

131 _raised: Union[Exception, BaseException, Type[NotSet]] = NotSet 1a

132 _initial_run_context: Optional[TaskRunContext] = None 1a

133 _is_started: bool = False 1a

134 _task_name_set: bool = False 1a

135 _last_event: Optional[PrefectEvent] = None 1a

136 _telemetry: RunTelemetry = field(default_factory=RunTelemetry) 1a

137 

138 def __post_init__(self) -> None: 1a

139 if self.parameters is None: 

140 self.parameters = {} 

141 

142 @property 1a

143 def state(self) -> State: 1a

144 if not self.task_run or not self.task_run.state: 

145 raise ValueError("Task run is not set") 

146 return self.task_run.state 

147 

148 def is_cancelled(self) -> bool: 1a

149 if ( 

150 self.context 

151 and "cancel_event" in self.context 

152 and isinstance(self.context["cancel_event"], threading.Event) 

153 ): 

154 return self.context["cancel_event"].is_set() 

155 return False 

156 

157 def compute_transaction_key(self) -> Optional[str]: 1a

158 key: Optional[str] = None 

159 if self.task.cache_policy and isinstance(self.task.cache_policy, CachePolicy): 

160 flow_run_context = FlowRunContext.get() 

161 task_run_context = TaskRunContext.get() 

162 

163 if flow_run_context: 

164 parameters = flow_run_context.parameters 

165 else: 

166 parameters = None 

167 

168 try: 

169 if not task_run_context: 

170 raise ValueError("Task run context is not set") 

171 key = self.task.cache_policy.compute_key( 

172 task_ctx=task_run_context, 

173 inputs=self.parameters or {}, 

174 flow_parameters=parameters or {}, 

175 ) 

176 except Exception: 

177 self.logger.exception( 

178 "Error encountered when computing cache key - result will not be persisted.", 

179 ) 

180 key = None 

181 elif self.task.result_storage_key is not None: 

182 key = _format_user_supplied_storage_key(self.task.result_storage_key) 

183 return key 

184 

185 def _resolve_parameters(self): 1a

186 if not self.parameters: 

187 return None 

188 

189 resolved_parameters = {} 

190 for parameter, value in self.parameters.items(): 

191 try: 

192 resolved_parameters[parameter] = visit_collection( 

193 value, 

194 visit_fn=resolve_to_final_result, 

195 return_data=True, 

196 max_depth=-1, 

197 remove_annotations=True, 

198 context={"parameter_name": parameter}, 

199 ) 

200 except UpstreamTaskError: 

201 raise 

202 except Exception as exc: 

203 raise PrefectException( 

204 f"Failed to resolve inputs in parameter {parameter!r}. If your" 

205 " parameter type is not supported, consider using the `quote`" 

206 " annotation to skip resolution of inputs." 

207 ) from exc 

208 

209 self.parameters = resolved_parameters 

210 

211 def _set_custom_task_run_name(self): 1a

212 from prefect.utilities._engine import resolve_custom_task_run_name 

213 

214 # update the task run name if necessary 

215 if not self._task_name_set and self.task.task_run_name: 

216 task_run_name = resolve_custom_task_run_name( 

217 task=self.task, parameters=self.parameters or {} 

218 ) 

219 

220 self.logger.extra["task_run_name"] = task_run_name 

221 self.logger.debug( 

222 f"Renamed task run {self.task_run.name!r} to {task_run_name!r}" 

223 ) 

224 self.task_run.name = task_run_name 

225 self._task_name_set = True 

226 self._telemetry.update_run_name(name=task_run_name) 

227 

228 def _wait_for_dependencies(self): 1a

229 if not self.wait_for: 

230 return 

231 

232 visit_collection( 

233 self.wait_for, 

234 visit_fn=resolve_to_final_result, 

235 return_data=False, 

236 max_depth=-1, 

237 remove_annotations=True, 

238 context={"current_task_run": self.task_run, "current_task": self.task}, 

239 ) 

240 

241 @deprecated.deprecated_callable( 1a

242 start_date=datetime.datetime(2025, 8, 21), 

243 end_date=datetime.datetime(2025, 11, 21), 

244 help="This method is no longer used and will be removed in a future version.", 

245 ) 

246 def record_terminal_state_timing(self, state: State) -> None: 1a

247 if self.task_run and self.task_run.start_time and not self.task_run.end_time: 

248 self.task_run.end_time = state.timestamp 

249 

250 if self.state.is_running(): 

251 self.task_run.total_run_time += state.timestamp - self.state.timestamp 

252 

253 def is_running(self) -> bool: 1a

254 """Whether or not the engine is currently running a task.""" 

255 if (task_run := getattr(self, "task_run", None)) is None: 

256 return False 

257 return task_run.state.is_running() or task_run.state.is_scheduled() 

258 

259 def log_finished_message(self) -> None: 1a

260 if not self.task_run: 

261 return 

262 

263 # If debugging, use the more complete `repr` instead of the usual `str` description 

264 display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state) 

265 level = logging.INFO if self.state.is_completed() else logging.ERROR 

266 msg = f"Finished in state {display_state}" 

267 if self.state.is_pending() and self.state.name != "NotReady": 

268 msg += ( 

269 "\nPlease wait for all submitted tasks to complete" 

270 " before exiting your flow by calling `.wait()` on the " 

271 "`PrefectFuture` returned from your `.submit()` calls." 

272 ) 

273 msg += dedent( 

274 """ 

275 

276 Example: 

277 

278 from prefect import flow, task 

279 

280 @task 

281 def say_hello(name): 

282 print(f"Hello, {name}!") 

283 

284 @flow 

285 def example_flow(): 

286 future = say_hello.submit(name="Marvin") 

287 future.wait() 

288 

289 example_flow() 

290 """ 

291 ) 

292 self.logger.log( 

293 level=level, 

294 msg=msg, 

295 ) 

296 

297 def handle_rollback(self, txn: Transaction) -> None: 1a

298 assert self.task_run is not None 

299 

300 rolled_back_state = Completed( 

301 name="RolledBack", 

302 message="Task rolled back as part of transaction", 

303 ) 

304 

305 self._last_event = emit_task_run_state_change_event( 

306 task_run=self.task_run, 

307 initial_state=self.state, 

308 validated_state=rolled_back_state, 

309 follows=self._last_event, 

310 ) 

311 

312 

313@dataclass 1a

314class SyncTaskRunEngine(BaseTaskRunEngine[P, R]): 1a

315 task_run: Optional[TaskRun] = None 1a

316 _client: Optional[SyncPrefectClient] = None 1a

317 

318 @property 1a

319 def client(self) -> SyncPrefectClient: 1a

320 if not self._is_started or self._client is None: 

321 raise RuntimeError("Engine has not started.") 

322 return self._client 

323 

324 def can_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a

325 retry_condition: Optional[ 

326 Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool] 

327 ] = self.task.retry_condition_fn 

328 

329 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state" 

330 

331 if not self.task_run: 

332 raise ValueError("Task run is not set") 

333 try: 

334 self.logger.debug( 

335 f"Running `retry_condition_fn` check {retry_condition!r} for task" 

336 f" {self.task.name!r}" 

337 ) 

338 state = Failed( 

339 data=exc_or_state, 

340 message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}", 

341 ) 

342 if inspect.iscoroutinefunction(retry_condition): 

343 should_retry = run_coro_as_sync( 

344 retry_condition(self.task, self.task_run, state) 

345 ) 

346 elif inspect.isfunction(retry_condition): 

347 should_retry = retry_condition(self.task, self.task_run, state) 

348 else: 

349 should_retry = not retry_condition 

350 return should_retry 

351 except Exception: 

352 self.logger.error( 

353 ( 

354 "An error was encountered while running `retry_condition_fn` check" 

355 f" '{retry_condition!r}' for task {self.task.name!r}" 

356 ), 

357 exc_info=True, 

358 ) 

359 return False 

360 

361 def call_hooks(self, state: Optional[State] = None) -> None: 1a

362 if state is None: 

363 state = self.state 

364 task = self.task 

365 task_run = self.task_run 

366 

367 if not task_run: 

368 raise ValueError("Task run is not set") 

369 

370 if state.is_failed() and task.on_failure_hooks: 

371 hooks = task.on_failure_hooks 

372 elif state.is_completed() and task.on_completion_hooks: 

373 hooks = task.on_completion_hooks 

374 elif state.is_running() and task.on_running_hooks: 

375 hooks = task.on_running_hooks 

376 else: 

377 hooks = None 

378 

379 for hook in hooks or []: 

380 hook_name = get_hook_name(hook) 

381 

382 try: 

383 self.logger.info( 

384 f"Running hook {hook_name!r} in response to entering state" 

385 f" {state.name!r}" 

386 ) 

387 result = hook(task, task_run, state) 

388 if asyncio.iscoroutine(result): 

389 run_coro_as_sync(result) 

390 except Exception: 

391 self.logger.error( 

392 f"An error was encountered while running hook {hook_name!r}", 

393 exc_info=True, 

394 ) 

395 else: 

396 self.logger.info(f"Hook {hook_name!r} finished running successfully") 

397 

398 def begin_run(self) -> None: 1a

399 new_state = Running() 

400 

401 assert self.task_run is not None, "Task run is not set" 

402 self.task_run.start_time = new_state.timestamp 

403 

404 flow_run_context = FlowRunContext.get() 

405 if flow_run_context and flow_run_context.flow_run: 

406 # Carry forward any task run information from the flow run 

407 flow_run = flow_run_context.flow_run 

408 self.task_run.flow_run_run_count = flow_run.run_count 

409 

410 state = self.set_state(new_state) 

411 

412 # TODO: this is temporary until the API stops rejecting state transitions 

413 # and the client / transaction store becomes the source of truth 

414 # this is a bandaid caused by the API storing a Completed state with a bad 

415 # result reference that no longer exists 

416 if state.is_completed(): 

417 try: 

418 state.result(retry_result_failure=False, _sync=True) # type: ignore[reportCallIssue] 

419 except Exception: 

420 state = self.set_state(new_state, force=True) 

421 

422 backoff_count = 0 

423 

424 # TODO: Could this listen for state change events instead of polling? 

425 while state.is_pending() or state.is_paused(): 

426 if backoff_count < BACKOFF_MAX: 

427 backoff_count += 1 

428 interval = clamped_poisson_interval( 

429 average_interval=backoff_count, clamping_factor=0.3 

430 ) 

431 time.sleep(interval) 

432 state = self.set_state(new_state) 

433 

434 # Call on_running hooks after the task has entered the Running state 

435 if state.is_running(): 

436 self.call_hooks(state) 

437 

438 def set_state(self, state: State[R], force: bool = False) -> State[R]: 1a

439 last_state = self.state 

440 if not self.task_run: 

441 raise ValueError("Task run is not set") 

442 

443 self.task_run.state = new_state = state 

444 

445 if last_state.timestamp == new_state.timestamp: 

446 # Ensure that the state timestamp is unique, or at least not equal to the last state. 

447 # This might occur especially on Windows where the timestamp resolution is limited. 

448 new_state.timestamp += timedelta(microseconds=1) 

449 

450 # Ensure that the state_details are populated with the current run IDs 

451 new_state.state_details.task_run_id = self.task_run.id 

452 new_state.state_details.flow_run_id = self.task_run.flow_run_id 

453 

454 # Predictively update the de-normalized task_run.state_* attributes 

455 self.task_run.state_id = new_state.id 

456 self.task_run.state_type = new_state.type 

457 self.task_run.state_name = new_state.name 

458 if last_state.is_running(): 

459 self.task_run.total_run_time += new_state.timestamp - last_state.timestamp 

460 

461 if new_state.is_running(): 

462 self.task_run.run_count += 1 

463 

464 if new_state.is_final(): 

465 if ( 

466 self.task_run 

467 and self.task_run.start_time 

468 and not self.task_run.end_time 

469 ): 

470 self.task_run.end_time = new_state.timestamp 

471 

472 if isinstance(state.data, ResultRecord): 

473 result = state.data.result 

474 else: 

475 result = state.data 

476 

477 link_state_to_task_run_result(new_state, result) 

478 

479 # emit a state change event 

480 self._last_event = emit_task_run_state_change_event( 

481 task_run=self.task_run, 

482 initial_state=last_state, 

483 validated_state=self.task_run.state, 

484 follows=self._last_event, 

485 ) 

486 self._telemetry.update_state(new_state) 

487 return new_state 

488 

489 def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": 1a

490 if self._return_value is not NotSet: 

491 if isinstance(self._return_value, ResultRecord): 

492 return self._return_value.result 

493 # otherwise, return the value as is 

494 return self._return_value 

495 

496 if self._raised is not NotSet: 

497 # if the task raised an exception, raise it 

498 if raise_on_failure: 

499 raise self._raised 

500 

501 # otherwise, return the exception 

502 return self._raised 

503 

504 def handle_success( 1a

505 self, result: R, transaction: Transaction 

506 ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]: 

507 # Handle the case where the task explicitly returns a failed state, in 

508 # which case we should retry the task if it has retries left. 

509 if isinstance(result, State) and result.is_failed(): 

510 if self.handle_retry(result): 

511 return None 

512 

513 if self.task.cache_expiration is not None: 

514 expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration 

515 else: 

516 expiration = None 

517 

518 terminal_state = run_coro_as_sync( 

519 return_value_to_state( 

520 result, 

521 result_store=get_result_store(), 

522 key=transaction.key, 

523 expiration=expiration, 

524 ) 

525 ) 

526 

527 # Avoid logging when running this rollback hook since it is not user-defined 

528 handle_rollback = partial(self.handle_rollback) 

529 handle_rollback.log_on_run = False 

530 

531 transaction.stage( 

532 terminal_state.data, 

533 on_rollback_hooks=[handle_rollback] + self.task.on_rollback_hooks, 

534 on_commit_hooks=self.task.on_commit_hooks, 

535 ) 

536 if transaction.is_committed(): 

537 terminal_state.name = "Cached" 

538 

539 self.set_state(terminal_state) 

540 self._return_value = result 

541 

542 self._telemetry.end_span_on_success() 

543 

544 def handle_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a

545 """Handle any task run retries. 

546 

547 - If the task has retries left, and the retry condition is met, set the task to retrying and return True. 

548 - If the task has a retry delay, place it in an AwaitingRetry state with a delayed scheduled time. 

549 - If the task has no retries left, or the retry condition is not met, return False. 

550 """ 

551 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state" 

552 if self.retries < self.task.retries and self.can_retry(exc_or_state): 

553 if self.task.retry_delay_seconds: 

554 delay = ( 

555 self.task.retry_delay_seconds[ 

556 min(self.retries, len(self.task.retry_delay_seconds) - 1) 

557 ] # repeat final delay value if attempts exceed specified delays 

558 if isinstance(self.task.retry_delay_seconds, Sequence) 

559 else self.task.retry_delay_seconds 

560 ) 

561 new_state = AwaitingRetry( 

562 scheduled_time=prefect.types._datetime.now("UTC") 

563 + timedelta(seconds=delay) 

564 ) 

565 else: 

566 delay = None 

567 new_state = Retrying() 

568 

569 self.logger.info( 

570 "Task run failed with %s: %r - Retry %s/%s will start %s", 

571 failure_type, 

572 exc_or_state, 

573 self.retries + 1, 

574 self.task.retries, 

575 str(delay) + " second(s) from now" if delay else "immediately", 

576 ) 

577 

578 self.set_state(new_state, force=True) 

579 # Call on_running hooks if we transitioned to a Running state (immediate retry) 

580 if new_state.is_running(): 

581 self.call_hooks(new_state) 

582 self.retries: int = self.retries + 1 

583 return True 

584 elif self.retries >= self.task.retries: 

585 if self.task.retries > 0: 

586 self.logger.error( 

587 f"Task run failed with {failure_type}: {exc_or_state!r} - Retries are exhausted", 

588 exc_info=True, 

589 ) 

590 else: 

591 self.logger.error( 

592 f"Task run failed with {failure_type}: {exc_or_state!r}", 

593 exc_info=True, 

594 ) 

595 return False 

596 

597 return False 

598 

599 def handle_exception(self, exc: Exception) -> None: 1a

600 # If the task fails, and we have retries left, set the task to retrying. 

601 self._telemetry.record_exception(exc) 

602 if not self.handle_retry(exc): 

603 # If the task has no retries left, or the retry condition is not met, set the task to failed. 

604 state = run_coro_as_sync( 

605 exception_to_failed_state( 

606 exc, 

607 message="Task run encountered an exception", 

608 result_store=get_result_store(), 

609 write_result=True, 

610 ) 

611 ) 

612 self.set_state(state) 

613 self._raised = exc 

614 self._telemetry.end_span_on_failure(state.message if state else None) 

615 

616 def handle_timeout(self, exc: TimeoutError) -> None: 1a

617 if not self.handle_retry(exc): 

618 if isinstance(exc, TaskRunTimeoutError): 

619 message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)" 

620 else: 

621 message = f"Task run failed due to timeout: {exc!r}" 

622 self.logger.error(message) 

623 state = Failed( 

624 data=exc, 

625 message=message, 

626 name="TimedOut", 

627 ) 

628 self.set_state(state) 

629 self._raised = exc 

630 

631 def handle_crash(self, exc: BaseException) -> None: 1a

632 state = run_coro_as_sync(exception_to_crashed_state(exc)) 

633 self.logger.error(f"Crash detected! {state.message}") 

634 self.logger.debug("Crash details:", exc_info=exc) 

635 self.set_state(state, force=True) 

636 self._raised = exc 

637 self._telemetry.record_exception(exc) 

638 self._telemetry.end_span_on_failure(state.message if state else None) 

639 

640 @contextmanager 1a

641 def setup_run_context(self, client: Optional[SyncPrefectClient] = None): 1a

642 from prefect.utilities.engine import ( 

643 should_log_prints, 

644 ) 

645 

646 settings = get_current_settings() 

647 

648 if client is None: 

649 client = self.client 

650 if not self.task_run: 

651 raise ValueError("Task run is not set") 

652 

653 with ExitStack() as stack: 

654 if log_prints := should_log_prints(self.task): 

655 stack.enter_context(patch_print()) 

656 if self.task.persist_result is not None: 

657 persist_result = self.task.persist_result 

658 elif settings.tasks.default_persist_result is not None: 

659 persist_result = settings.tasks.default_persist_result 

660 else: 

661 persist_result = should_persist_result() 

662 

663 stack.enter_context( 

664 TaskRunContext( 

665 task=self.task, 

666 log_prints=log_prints, 

667 task_run=self.task_run, 

668 parameters=self.parameters, 

669 result_store=get_result_store().update_for_task( 

670 self.task, _sync=True 

671 ), 

672 client=client, 

673 persist_result=persist_result, 

674 ) 

675 ) 

676 stack.enter_context(ConcurrencyContext()) 

677 

678 self.logger: "logging.Logger" = task_run_logger( 

679 task_run=self.task_run, task=self.task 

680 ) # type: ignore 

681 

682 yield 

683 

684 @contextmanager 1a

685 def asset_context(self): 1a

686 parent_asset_ctx = AssetContext.get() 

687 

688 if parent_asset_ctx and parent_asset_ctx.copy_to_child_ctx: 

689 asset_ctx = parent_asset_ctx.model_copy() 

690 asset_ctx.copy_to_child_ctx = False 

691 else: 

692 asset_ctx = AssetContext.from_task_and_inputs( 

693 self.task, self.task_run.id, self.task_run.task_inputs 

694 ) 

695 

696 with asset_ctx as ctx: 

697 try: 

698 yield 

699 finally: 

700 ctx.emit_events(self.state) 

701 

702 @contextmanager 1a

703 def initialize_run( 1a

704 self, 

705 task_run_id: Optional[UUID] = None, 

706 dependencies: Optional[dict[str, set[RunInput]]] = None, 

707 ) -> Generator[Self, Any, Any]: 

708 """ 

709 Enters a client context and creates a task run if needed. 

710 """ 

711 

712 with hydrated_context(self.context): 

713 with SyncClientContext.get_or_create() as client_ctx: 

714 self._client = client_ctx.client 

715 self._is_started = True 

716 parent_flow_run_context = FlowRunContext.get() 

717 parent_task_run_context = TaskRunContext.get() 

718 

719 try: 

720 if not self.task_run: 

721 self.task_run = run_coro_as_sync( 

722 self.task.create_local_run( 

723 id=task_run_id, 

724 parameters=self.parameters, 

725 flow_run_context=parent_flow_run_context, 

726 parent_task_run_context=parent_task_run_context, 

727 wait_for=self.wait_for, 

728 extra_task_inputs=dependencies, 

729 ) 

730 ) 

731 # Emit an event to capture that the task run was in the `PENDING` state. 

732 self._last_event = emit_task_run_state_change_event( 

733 task_run=self.task_run, 

734 initial_state=None, 

735 validated_state=self.task_run.state, 

736 ) 

737 

738 with self.setup_run_context(): 

739 # setup_run_context might update the task run name, so log creation here 

740 self.logger.debug( 

741 f"Created task run {self.task_run.name!r} for task {self.task.name!r}" 

742 ) 

743 

744 self._telemetry.start_span( 

745 run=self.task_run, 

746 client=self.client, 

747 parameters=self.parameters, 

748 ) 

749 

750 yield self 

751 

752 except TerminationSignal as exc: 

753 # TerminationSignals are caught and handled as crashes 

754 self.handle_crash(exc) 

755 raise exc 

756 

757 except Exception: 

758 # regular exceptions are caught and re-raised to the user 

759 raise 

760 except (Pause, Abort) as exc: 

761 # Do not capture internal signals as crashes 

762 if isinstance(exc, Abort): 

763 self.logger.error("Task run was aborted: %s", exc) 

764 raise 

765 except GeneratorExit: 

766 # Do not capture generator exits as crashes 

767 raise 

768 except BaseException as exc: 

769 # BaseExceptions are caught and handled as crashes 

770 self.handle_crash(exc) 

771 raise 

772 finally: 

773 self.log_finished_message() 

774 self._is_started = False 

775 self._client = None 

776 

777 async def wait_until_ready(self) -> None: 1a

778 """Waits until the scheduled time (if its the future), then enters Running.""" 

779 if scheduled_time := self.state.state_details.scheduled_time: 

780 sleep_time = ( 

781 scheduled_time - prefect.types._datetime.now("UTC") 

782 ).total_seconds() 

783 await anyio.sleep(sleep_time if sleep_time > 0 else 0) 

784 new_state = Retrying() if self.state.name == "AwaitingRetry" else Running() 

785 self.set_state( 

786 new_state, 

787 force=True, 

788 ) 

789 # Call on_running hooks if we transitioned to a Running state 

790 if self.state.is_running(): 

791 self.call_hooks() 

792 

793 # -------------------------- 

794 # 

795 # The following methods compose the main task run loop 

796 # 

797 # -------------------------- 

798 

799 @contextmanager 1a

800 def start( 1a

801 self, 

802 task_run_id: Optional[UUID] = None, 

803 dependencies: Optional[dict[str, set[RunInput]]] = None, 

804 ) -> Generator[None, None, None]: 

805 with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies): 

806 with ( 

807 trace.use_span(self._telemetry.span) 

808 if self._telemetry.span 

809 else nullcontext() 

810 ): 

811 try: 

812 self._resolve_parameters() 

813 self._set_custom_task_run_name() 

814 self._wait_for_dependencies() 

815 except UpstreamTaskError as upstream_exc: 

816 self.set_state( 

817 Pending( 

818 name="NotReady", 

819 message=str(upstream_exc), 

820 ), 

821 # if orchestrating a run already in a pending state, force orchestration to 

822 # update the state name 

823 force=self.state.is_pending(), 

824 ) 

825 yield 

826 self.call_hooks() 

827 return 

828 

829 with _concurrency( 

830 names=[f"tag:{tag}" for tag in self.task_run.tags], 

831 occupy=1, 

832 holder=ConcurrencyLeaseHolder(type="task_run", id=self.task_run.id), 

833 lease_duration=60, 

834 suppress_warnings=True, 

835 ): 

836 self.begin_run() 

837 try: 

838 yield 

839 finally: 

840 self.call_hooks() 

841 

842 @contextmanager 1a

843 def transaction_context(self) -> Generator[Transaction, None, None]: 1a

844 # the refresh cache setting is now repurposed to overwrite the transaction record 

845 overwrite = ( 

846 self.task.refresh_cache 

847 if self.task.refresh_cache is not None 

848 else PREFECT_TASKS_REFRESH_CACHE.value() 

849 ) 

850 

851 isolation_level = ( 

852 IsolationLevel(self.task.cache_policy.isolation_level) 

853 if self.task.cache_policy 

854 and self.task.cache_policy is not NotSet 

855 and self.task.cache_policy.isolation_level is not None 

856 else None 

857 ) 

858 

859 with transaction( 

860 key=self.compute_transaction_key(), 

861 store=get_result_store(), 

862 overwrite=overwrite, 

863 logger=self.logger, 

864 write_on_commit=should_persist_result(), 

865 isolation_level=isolation_level, 

866 ) as txn: 

867 yield txn 

868 

869 @contextmanager 1a

870 def run_context(self): 1a

871 # reenter the run context to ensure it is up to date for every run 

872 with self.setup_run_context(): 

873 try: 

874 with timeout( 

875 seconds=self.task.timeout_seconds, 

876 timeout_exc_type=TaskRunTimeoutError, 

877 ): 

878 self.logger.debug( 

879 f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..." 

880 ) 

881 if self.is_cancelled(): 

882 raise CancelledError("Task run cancelled by the task runner") 

883 

884 yield self 

885 except TimeoutError as exc: 

886 self.handle_timeout(exc) 

887 except Exception as exc: 

888 self.handle_exception(exc) 

889 

890 def call_task_fn( 1a

891 self, transaction: Transaction 

892 ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]: 

893 """ 

894 Convenience method to call the task function. Returns a coroutine if the 

895 task is async. 

896 """ 

897 parameters = self.parameters or {} 

898 if transaction.is_committed(): 

899 result = transaction.read() 

900 else: 

901 result = call_with_parameters(self.task.fn, parameters) 

902 self.handle_success(result, transaction=transaction) 

903 return result 

904 

905 

906@dataclass 1a

907class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]): 1a

908 task_run: TaskRun | None = None 1a

909 _client: Optional[PrefectClient] = None 1a

910 

911 @property 1a

912 def client(self) -> PrefectClient: 1a

913 if not self._is_started or self._client is None: 

914 raise RuntimeError("Engine has not started.") 

915 return self._client 

916 

917 async def can_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a

918 retry_condition: Optional[ 

919 Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool] 

920 ] = self.task.retry_condition_fn 

921 

922 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state" 

923 

924 if not self.task_run: 

925 raise ValueError("Task run is not set") 

926 try: 

927 self.logger.debug( 

928 f"Running `retry_condition_fn` check {retry_condition!r} for task" 

929 f" {self.task.name!r}" 

930 ) 

931 state = Failed( 

932 data=exc_or_state, 

933 message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}", 

934 ) 

935 if inspect.iscoroutinefunction(retry_condition): 

936 should_retry = await retry_condition(self.task, self.task_run, state) 

937 elif inspect.isfunction(retry_condition): 

938 should_retry = retry_condition(self.task, self.task_run, state) 

939 else: 

940 should_retry = not retry_condition 

941 return should_retry 

942 

943 except Exception: 

944 self.logger.error( 

945 ( 

946 "An error was encountered while running `retry_condition_fn` check" 

947 f" '{retry_condition!r}' for task {self.task.name!r}" 

948 ), 

949 exc_info=True, 

950 ) 

951 return False 

952 

953 async def call_hooks(self, state: Optional[State] = None) -> None: 1a

954 if state is None: 

955 state = self.state 

956 task = self.task 

957 task_run = self.task_run 

958 

959 if not task_run: 

960 raise ValueError("Task run is not set") 

961 

962 if state.is_failed() and task.on_failure_hooks: 

963 hooks = task.on_failure_hooks 

964 elif state.is_completed() and task.on_completion_hooks: 

965 hooks = task.on_completion_hooks 

966 elif state.is_running() and task.on_running_hooks: 

967 hooks = task.on_running_hooks 

968 else: 

969 hooks = None 

970 

971 for hook in hooks or []: 

972 hook_name = get_hook_name(hook) 

973 

974 try: 

975 self.logger.info( 

976 f"Running hook {hook_name!r} in response to entering state" 

977 f" {state.name!r}" 

978 ) 

979 result = hook(task, task_run, state) 

980 if inspect.isawaitable(result): 

981 await result 

982 except Exception: 

983 self.logger.error( 

984 f"An error was encountered while running hook {hook_name!r}", 

985 exc_info=True, 

986 ) 

987 else: 

988 self.logger.info(f"Hook {hook_name!r} finished running successfully") 

989 

990 async def begin_run(self) -> None: 1a

991 try: 

992 self._resolve_parameters() 

993 self._set_custom_task_run_name() 

994 except UpstreamTaskError as upstream_exc: 

995 state = await self.set_state( 

996 Pending( 

997 name="NotReady", 

998 message=str(upstream_exc), 

999 ), 

1000 # if orchestrating a run already in a pending state, force orchestration to 

1001 # update the state name 

1002 force=self.state.is_pending(), 

1003 ) 

1004 return 

1005 

1006 new_state = Running() 

1007 

1008 self.task_run.start_time = new_state.timestamp 

1009 

1010 flow_run_context = FlowRunContext.get() 

1011 if flow_run_context: 

1012 # Carry forward any task run information from the flow run 

1013 flow_run = flow_run_context.flow_run 

1014 self.task_run.flow_run_run_count = flow_run.run_count 

1015 

1016 state = await self.set_state(new_state) 

1017 

1018 # TODO: this is temporary until the API stops rejecting state transitions 

1019 # and the client / transaction store becomes the source of truth 

1020 # this is a bandaid caused by the API storing a Completed state with a bad 

1021 # result reference that no longer exists 

1022 if state.is_completed(): 

1023 try: 

1024 await state.result(retry_result_failure=False) 

1025 except Exception: 

1026 state = await self.set_state(new_state, force=True) 

1027 

1028 backoff_count = 0 

1029 

1030 # TODO: Could this listen for state change events instead of polling? 

1031 while state.is_pending() or state.is_paused(): 

1032 if backoff_count < BACKOFF_MAX: 

1033 backoff_count += 1 

1034 interval = clamped_poisson_interval( 

1035 average_interval=backoff_count, clamping_factor=0.3 

1036 ) 

1037 await anyio.sleep(interval) 

1038 state = await self.set_state(new_state) 

1039 

1040 # Call on_running hooks after the task has entered the Running state 

1041 if state.is_running(): 

1042 await self.call_hooks(state) 

1043 

1044 async def set_state(self, state: State, force: bool = False) -> State: 1a

1045 last_state = self.state 

1046 if not self.task_run: 

1047 raise ValueError("Task run is not set") 

1048 

1049 self.task_run.state = new_state = state 

1050 

1051 if last_state.timestamp == new_state.timestamp: 

1052 # Ensure that the state timestamp is unique, or at least not equal to the last state. 

1053 # This might occur especially on Windows where the timestamp resolution is limited. 

1054 new_state.timestamp += timedelta(microseconds=1) 

1055 

1056 # Ensure that the state_details are populated with the current run IDs 

1057 new_state.state_details.task_run_id = self.task_run.id 

1058 new_state.state_details.flow_run_id = self.task_run.flow_run_id 

1059 

1060 # Predictively update the de-normalized task_run.state_* attributes 

1061 self.task_run.state_id = new_state.id 

1062 self.task_run.state_type = new_state.type 

1063 self.task_run.state_name = new_state.name 

1064 

1065 if last_state.is_running(): 

1066 self.task_run.total_run_time += new_state.timestamp - last_state.timestamp 

1067 

1068 if new_state.is_running(): 

1069 self.task_run.run_count += 1 

1070 

1071 if new_state.is_final(): 

1072 if ( 

1073 self.task_run 

1074 and self.task_run.start_time 

1075 and not self.task_run.end_time 

1076 ): 

1077 self.task_run.end_time = new_state.timestamp 

1078 

1079 if isinstance(new_state.data, ResultRecord): 

1080 result = new_state.data.result 

1081 else: 

1082 result = new_state.data 

1083 

1084 link_state_to_task_run_result(new_state, result) 

1085 

1086 # emit a state change event 

1087 self._last_event = emit_task_run_state_change_event( 

1088 task_run=self.task_run, 

1089 initial_state=last_state, 

1090 validated_state=self.task_run.state, 

1091 follows=self._last_event, 

1092 ) 

1093 

1094 self._telemetry.update_state(new_state) 

1095 return new_state 

1096 

1097 async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": 1a

1098 if self._return_value is not NotSet: 

1099 if isinstance(self._return_value, ResultRecord): 

1100 return self._return_value.result 

1101 # otherwise, return the value as is 

1102 return self._return_value 

1103 

1104 if self._raised is not NotSet: 

1105 # if the task raised an exception, raise it 

1106 if raise_on_failure: 

1107 raise self._raised 

1108 

1109 # otherwise, return the exception 

1110 return self._raised 

1111 

1112 async def handle_success( 1a

1113 self, result: R, transaction: AsyncTransaction 

1114 ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]: 

1115 if isinstance(result, State) and result.is_failed(): 

1116 if await self.handle_retry(result): 

1117 return None 

1118 

1119 if self.task.cache_expiration is not None: 

1120 expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration 

1121 else: 

1122 expiration = None 

1123 

1124 terminal_state = await return_value_to_state( 

1125 result, 

1126 result_store=get_result_store(), 

1127 key=transaction.key, 

1128 expiration=expiration, 

1129 ) 

1130 

1131 # Avoid logging when running this rollback hook since it is not user-defined 

1132 handle_rollback = partial(self.handle_rollback) 

1133 handle_rollback.log_on_run = False 

1134 

1135 transaction.stage( 

1136 terminal_state.data, 

1137 on_rollback_hooks=[handle_rollback] + self.task.on_rollback_hooks, 

1138 on_commit_hooks=self.task.on_commit_hooks, 

1139 ) 

1140 if transaction.is_committed(): 

1141 terminal_state.name = "Cached" 

1142 

1143 await self.set_state(terminal_state) 

1144 self._return_value = result 

1145 self._telemetry.end_span_on_success() 

1146 

1147 return result 

1148 

1149 async def handle_retry(self, exc_or_state: Exception | State[R]) -> bool: 1a

1150 """Handle any task run retries. 

1151 

1152 - If the task has retries left, and the retry condition is met, set the task to retrying and return True. 

1153 - If the task has a retry delay, place it in an AwaitingRetry state with a delayed scheduled time. 

1154 - If the task has no retries left, or the retry condition is not met, return False. 

1155 """ 

1156 failure_type = "exception" if isinstance(exc_or_state, Exception) else "state" 

1157 

1158 if self.retries < self.task.retries and await self.can_retry(exc_or_state): 

1159 if self.task.retry_delay_seconds: 

1160 delay = ( 

1161 self.task.retry_delay_seconds[ 

1162 min(self.retries, len(self.task.retry_delay_seconds) - 1) 

1163 ] # repeat final delay value if attempts exceed specified delays 

1164 if isinstance(self.task.retry_delay_seconds, Sequence) 

1165 else self.task.retry_delay_seconds 

1166 ) 

1167 new_state = AwaitingRetry( 

1168 scheduled_time=prefect.types._datetime.now("UTC") 

1169 + timedelta(seconds=delay) 

1170 ) 

1171 else: 

1172 delay = None 

1173 new_state = Retrying() 

1174 

1175 self.logger.info( 

1176 "Task run failed with %s: %r - Retry %s/%s will start %s", 

1177 failure_type, 

1178 exc_or_state, 

1179 self.retries + 1, 

1180 self.task.retries, 

1181 str(delay) + " second(s) from now" if delay else "immediately", 

1182 ) 

1183 

1184 await self.set_state(new_state, force=True) 

1185 # Call on_running hooks if we transitioned to a Running state (immediate retry) 

1186 if new_state.is_running(): 

1187 await self.call_hooks(new_state) 

1188 self.retries: int = self.retries + 1 

1189 return True 

1190 elif self.retries >= self.task.retries: 

1191 if self.task.retries > 0: 

1192 self.logger.error( 

1193 f"Task run failed with {failure_type}: {exc_or_state!r} - Retries are exhausted", 

1194 exc_info=True, 

1195 ) 

1196 else: 

1197 self.logger.error( 

1198 f"Task run failed with {failure_type}: {exc_or_state!r}", 

1199 exc_info=True, 

1200 ) 

1201 return False 

1202 

1203 return False 

1204 

1205 async def handle_exception(self, exc: Exception) -> None: 1a

1206 # If the task fails, and we have retries left, set the task to retrying. 

1207 self._telemetry.record_exception(exc) 

1208 if not await self.handle_retry(exc): 

1209 # If the task has no retries left, or the retry condition is not met, set the task to failed. 

1210 state = await exception_to_failed_state( 

1211 exc, 

1212 message="Task run encountered an exception", 

1213 result_store=get_result_store(), 

1214 ) 

1215 await self.set_state(state) 

1216 self._raised = exc 

1217 

1218 self._telemetry.end_span_on_failure(state.message) 

1219 

1220 async def handle_timeout(self, exc: TimeoutError) -> None: 1a

1221 self._telemetry.record_exception(exc) 

1222 if not await self.handle_retry(exc): 

1223 if isinstance(exc, TaskRunTimeoutError): 

1224 message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)" 

1225 else: 

1226 message = f"Task run failed due to timeout: {exc!r}" 

1227 self.logger.error(message) 

1228 state = Failed( 

1229 data=exc, 

1230 message=message, 

1231 name="TimedOut", 

1232 ) 

1233 await self.set_state(state) 

1234 self._raised = exc 

1235 self._telemetry.end_span_on_failure(state.message) 

1236 

1237 async def handle_crash(self, exc: BaseException) -> None: 1a

1238 state = await exception_to_crashed_state(exc) 

1239 self.logger.error(f"Crash detected! {state.message}") 

1240 self.logger.debug("Crash details:", exc_info=exc) 

1241 await self.set_state(state, force=True) 

1242 self._raised = exc 

1243 

1244 self._telemetry.record_exception(exc) 

1245 self._telemetry.end_span_on_failure(state.message) 

1246 

1247 @asynccontextmanager 1a

1248 async def setup_run_context(self, client: Optional[PrefectClient] = None): 1a

1249 from prefect.utilities.engine import ( 

1250 should_log_prints, 

1251 ) 

1252 

1253 settings = get_current_settings() 

1254 

1255 if client is None: 

1256 client = self.client 

1257 if not self.task_run: 

1258 raise ValueError("Task run is not set") 

1259 

1260 with ExitStack() as stack: 

1261 if log_prints := should_log_prints(self.task): 

1262 stack.enter_context(patch_print()) 

1263 if self.task.persist_result is not None: 

1264 persist_result = self.task.persist_result 

1265 elif settings.tasks.default_persist_result is not None: 

1266 persist_result = settings.tasks.default_persist_result 

1267 else: 

1268 persist_result = should_persist_result() 

1269 

1270 stack.enter_context( 

1271 TaskRunContext( 

1272 task=self.task, 

1273 log_prints=log_prints, 

1274 task_run=self.task_run, 

1275 parameters=self.parameters, 

1276 result_store=await get_result_store().update_for_task( 

1277 self.task, _sync=False 

1278 ), 

1279 client=client, 

1280 persist_result=persist_result, 

1281 ) 

1282 ) 

1283 stack.enter_context(ConcurrencyContext()) 

1284 

1285 self.logger: "logging.Logger" = task_run_logger( 

1286 task_run=self.task_run, task=self.task 

1287 ) # type: ignore 

1288 

1289 yield 

1290 

1291 @asynccontextmanager 1a

1292 async def asset_context(self): 1a

1293 parent_asset_ctx = AssetContext.get() 

1294 

1295 if parent_asset_ctx and parent_asset_ctx.copy_to_child_ctx: 

1296 asset_ctx = parent_asset_ctx.model_copy() 

1297 asset_ctx.copy_to_child_ctx = False 

1298 else: 

1299 asset_ctx = AssetContext.from_task_and_inputs( 

1300 self.task, self.task_run.id, self.task_run.task_inputs 

1301 ) 

1302 

1303 with asset_ctx as ctx: 

1304 try: 

1305 yield 

1306 finally: 

1307 ctx.emit_events(self.state) 

1308 

1309 @asynccontextmanager 1a

1310 async def initialize_run( 1a

1311 self, 

1312 task_run_id: Optional[UUID] = None, 

1313 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1314 ) -> AsyncGenerator[Self, Any]: 

1315 """ 

1316 Enters a client context and creates a task run if needed. 

1317 """ 

1318 

1319 with hydrated_context(self.context): 

1320 async with AsyncClientContext.get_or_create(): 

1321 self._client = get_client() 

1322 self._is_started = True 

1323 parent_flow_run_context = FlowRunContext.get() 

1324 parent_task_run_context = TaskRunContext.get() 

1325 

1326 try: 

1327 if not self.task_run: 

1328 self.task_run = await self.task.create_local_run( 

1329 id=task_run_id, 

1330 parameters=self.parameters, 

1331 flow_run_context=parent_flow_run_context, 

1332 parent_task_run_context=parent_task_run_context, 

1333 wait_for=self.wait_for, 

1334 extra_task_inputs=dependencies, 

1335 ) 

1336 # Emit an event to capture that the task run was in the `PENDING` state. 

1337 self._last_event = emit_task_run_state_change_event( 

1338 task_run=self.task_run, 

1339 initial_state=None, 

1340 validated_state=self.task_run.state, 

1341 ) 

1342 

1343 async with self.setup_run_context(): 

1344 # setup_run_context might update the task run name, so log creation here 

1345 self.logger.debug( 

1346 f"Created task run {self.task_run.name!r} for task {self.task.name!r}" 

1347 ) 

1348 

1349 await self._telemetry.async_start_span( 

1350 run=self.task_run, 

1351 client=self.client, 

1352 parameters=self.parameters, 

1353 ) 

1354 

1355 yield self 

1356 

1357 except TerminationSignal as exc: 

1358 # TerminationSignals are caught and handled as crashes 

1359 await self.handle_crash(exc) 

1360 raise exc 

1361 

1362 except Exception: 

1363 # regular exceptions are caught and re-raised to the user 

1364 raise 

1365 except (Pause, Abort) as exc: 

1366 # Do not capture internal signals as crashes 

1367 if isinstance(exc, Abort): 

1368 self.logger.error("Task run was aborted: %s", exc) 

1369 raise 

1370 except GeneratorExit: 

1371 # Do not capture generator exits as crashes 

1372 raise 

1373 except BaseException as exc: 

1374 # BaseExceptions are caught and handled as crashes 

1375 await self.handle_crash(exc) 

1376 raise 

1377 finally: 

1378 self.log_finished_message() 

1379 self._is_started = False 

1380 self._client = None 

1381 

1382 async def wait_until_ready(self) -> None: 1a

1383 """Waits until the scheduled time (if its the future), then enters Running.""" 

1384 if scheduled_time := self.state.state_details.scheduled_time: 

1385 sleep_time = ( 

1386 scheduled_time - prefect.types._datetime.now("UTC") 

1387 ).total_seconds() 

1388 await anyio.sleep(sleep_time if sleep_time > 0 else 0) 

1389 new_state = Retrying() if self.state.name == "AwaitingRetry" else Running() 

1390 await self.set_state( 

1391 new_state, 

1392 force=True, 

1393 ) 

1394 # Call on_running hooks if we transitioned to a Running state 

1395 if self.state.is_running(): 

1396 await self.call_hooks() 

1397 

1398 # -------------------------- 

1399 # 

1400 # The following methods compose the main task run loop 

1401 # 

1402 # -------------------------- 

1403 

1404 @asynccontextmanager 1a

1405 async def start( 1a

1406 self, 

1407 task_run_id: Optional[UUID] = None, 

1408 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1409 ) -> AsyncGenerator[None, None]: 

1410 async with self.initialize_run( 

1411 task_run_id=task_run_id, dependencies=dependencies 

1412 ): 

1413 with ( 

1414 trace.use_span(self._telemetry.span) 

1415 if self._telemetry.span 

1416 else nullcontext() 

1417 ): 

1418 try: 

1419 self._resolve_parameters() 

1420 self._set_custom_task_run_name() 

1421 self._wait_for_dependencies() 

1422 except UpstreamTaskError as upstream_exc: 

1423 await self.set_state( 

1424 Pending( 

1425 name="NotReady", 

1426 message=str(upstream_exc), 

1427 ), 

1428 # if orchestrating a run already in a pending state, force orchestration to 

1429 # update the state name 

1430 force=self.state.is_pending(), 

1431 ) 

1432 yield 

1433 await self.call_hooks() 

1434 return 

1435 

1436 async with _aconcurrency( 

1437 names=[f"tag:{tag}" for tag in self.task_run.tags], 

1438 occupy=1, 

1439 holder=ConcurrencyLeaseHolder(type="task_run", id=self.task_run.id), 

1440 lease_duration=60, 

1441 suppress_warnings=True, 

1442 ): 

1443 await self.begin_run() 

1444 try: 

1445 yield 

1446 finally: 

1447 await self.call_hooks() 

1448 

1449 @asynccontextmanager 1a

1450 async def transaction_context(self) -> AsyncGenerator[AsyncTransaction, None]: 1a

1451 # the refresh cache setting is now repurposed to overwrite the transaction record 

1452 overwrite = ( 

1453 self.task.refresh_cache 

1454 if self.task.refresh_cache is not None 

1455 else PREFECT_TASKS_REFRESH_CACHE.value() 

1456 ) 

1457 isolation_level = ( 

1458 IsolationLevel(self.task.cache_policy.isolation_level) 

1459 if self.task.cache_policy 

1460 and self.task.cache_policy is not NotSet 

1461 and self.task.cache_policy.isolation_level is not None 

1462 else None 

1463 ) 

1464 

1465 async with atransaction( 

1466 key=self.compute_transaction_key(), 

1467 store=get_result_store(), 

1468 overwrite=overwrite, 

1469 logger=self.logger, 

1470 write_on_commit=should_persist_result(), 

1471 isolation_level=isolation_level, 

1472 ) as txn: 

1473 yield txn 

1474 

1475 @asynccontextmanager 1a

1476 async def run_context(self): 1a

1477 # reenter the run context to ensure it is up to date for every run 

1478 async with self.setup_run_context(): 

1479 try: 

1480 with timeout_async( 

1481 seconds=self.task.timeout_seconds, 

1482 timeout_exc_type=TaskRunTimeoutError, 

1483 ): 

1484 self.logger.debug( 

1485 f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..." 

1486 ) 

1487 if self.is_cancelled(): 

1488 raise CancelledError("Task run cancelled by the task runner") 

1489 

1490 yield self 

1491 except TimeoutError as exc: 

1492 await self.handle_timeout(exc) 

1493 except Exception as exc: 

1494 await self.handle_exception(exc) 

1495 

1496 async def call_task_fn( 1a

1497 self, transaction: AsyncTransaction 

1498 ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]: 

1499 """ 

1500 Convenience method to call the task function. Returns a coroutine if the 

1501 task is async. 

1502 """ 

1503 parameters = self.parameters or {} 

1504 if transaction.is_committed(): 

1505 result = await transaction.read() 

1506 else: 

1507 result = await call_with_parameters(self.task.fn, parameters) 

1508 await self.handle_success(result, transaction=transaction) 

1509 return result 

1510 

1511 

1512def run_task_sync( 1a

1513 task: "Task[P, R]", 

1514 task_run_id: Optional[UUID] = None, 

1515 task_run: Optional[TaskRun] = None, 

1516 parameters: Optional[dict[str, Any]] = None, 

1517 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1518 return_type: Literal["state", "result"] = "result", 

1519 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1520 context: Optional[dict[str, Any]] = None, 

1521) -> Union[R, State, None]: 

1522 engine = SyncTaskRunEngine[P, R]( 

1523 task=task, 

1524 parameters=parameters, 

1525 task_run=task_run, 

1526 wait_for=wait_for, 

1527 context=context, 

1528 ) 

1529 

1530 with engine.start(task_run_id=task_run_id, dependencies=dependencies): 

1531 while engine.is_running(): 

1532 run_coro_as_sync(engine.wait_until_ready()) 

1533 with ( 

1534 engine.asset_context(), 

1535 engine.run_context(), 

1536 engine.transaction_context() as txn, 

1537 ): 

1538 engine.call_task_fn(txn) 

1539 

1540 return engine.state if return_type == "state" else engine.result() 

1541 

1542 

1543async def run_task_async( 1a

1544 task: "Task[P, R]", 

1545 task_run_id: Optional[UUID] = None, 

1546 task_run: Optional[TaskRun] = None, 

1547 parameters: Optional[dict[str, Any]] = None, 

1548 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1549 return_type: Literal["state", "result"] = "result", 

1550 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1551 context: Optional[dict[str, Any]] = None, 

1552) -> Union[R, State, None]: 

1553 engine = AsyncTaskRunEngine[P, R]( 

1554 task=task, 

1555 parameters=parameters, 

1556 task_run=task_run, 

1557 wait_for=wait_for, 

1558 context=context, 

1559 ) 

1560 

1561 async with engine.start(task_run_id=task_run_id, dependencies=dependencies): 

1562 while engine.is_running(): 

1563 await engine.wait_until_ready() 

1564 async with ( 

1565 engine.asset_context(), 

1566 engine.run_context(), 

1567 engine.transaction_context() as txn, 

1568 ): 

1569 await engine.call_task_fn(txn) 

1570 

1571 return engine.state if return_type == "state" else await engine.result() 

1572 

1573 

1574def run_generator_task_sync( 1a

1575 task: "Task[P, R]", 

1576 task_run_id: Optional[UUID] = None, 

1577 task_run: Optional[TaskRun] = None, 

1578 parameters: Optional[dict[str, Any]] = None, 

1579 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1580 return_type: Literal["state", "result"] = "result", 

1581 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1582 context: Optional[dict[str, Any]] = None, 

1583) -> Generator[R, None, None]: 

1584 if return_type != "result": 

1585 raise ValueError("The return_type for a generator task must be 'result'") 

1586 

1587 engine = SyncTaskRunEngine[P, R]( 

1588 task=task, 

1589 parameters=parameters, 

1590 task_run=task_run, 

1591 wait_for=wait_for, 

1592 context=context, 

1593 ) 

1594 

1595 with engine.start(task_run_id=task_run_id, dependencies=dependencies): 

1596 while engine.is_running(): 

1597 run_coro_as_sync(engine.wait_until_ready()) 

1598 with ( 

1599 engine.asset_context(), 

1600 engine.run_context(), 

1601 engine.transaction_context() as txn, 

1602 ): 

1603 # TODO: generators should default to commit_mode=OFF 

1604 # because they are dynamic by definition 

1605 # for now we just prevent this branch explicitly 

1606 if False and txn.is_committed(): 

1607 txn.read() 

1608 else: 

1609 call_args, call_kwargs = parameters_to_args_kwargs( 

1610 task.fn, engine.parameters or {} 

1611 ) 

1612 gen = task.fn(*call_args, **call_kwargs) 

1613 try: 

1614 while True: 

1615 gen_result = next(gen) 

1616 # link the current state to the result for dependency tracking 

1617 # 

1618 # TODO: this could grow the task_run_result 

1619 # dictionary in an unbounded way, so finding a 

1620 # way to periodically clean it up (using 

1621 # weakrefs or similar) would be good 

1622 link_state_to_task_run_result(engine.state, gen_result) 

1623 yield gen_result 

1624 except StopIteration as exc: 

1625 engine.handle_success(exc.value, transaction=txn) 

1626 except GeneratorExit as exc: 

1627 engine.handle_success(None, transaction=txn) 

1628 gen.throw(exc) 

1629 

1630 return engine.result() 

1631 

1632 

1633async def run_generator_task_async( 1a

1634 task: "Task[P, R]", 

1635 task_run_id: Optional[UUID] = None, 

1636 task_run: Optional[TaskRun] = None, 

1637 parameters: Optional[dict[str, Any]] = None, 

1638 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1639 return_type: Literal["state", "result"] = "result", 

1640 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1641 context: Optional[dict[str, Any]] = None, 

1642) -> AsyncGenerator[R, None]: 

1643 if return_type != "result": 

1644 raise ValueError("The return_type for a generator task must be 'result'") 

1645 engine = AsyncTaskRunEngine[P, R]( 

1646 task=task, 

1647 parameters=parameters, 

1648 task_run=task_run, 

1649 wait_for=wait_for, 

1650 context=context, 

1651 ) 

1652 

1653 async with engine.start(task_run_id=task_run_id, dependencies=dependencies): 

1654 while engine.is_running(): 

1655 await engine.wait_until_ready() 

1656 async with ( 

1657 engine.asset_context(), 

1658 engine.run_context(), 

1659 engine.transaction_context() as txn, 

1660 ): 

1661 # TODO: generators should default to commit_mode=OFF 

1662 # because they are dynamic by definition 

1663 # for now we just prevent this branch explicitly 

1664 if False and txn.is_committed(): 

1665 txn.read() 

1666 else: 

1667 call_args, call_kwargs = parameters_to_args_kwargs( 

1668 task.fn, engine.parameters or {} 

1669 ) 

1670 gen = task.fn(*call_args, **call_kwargs) 

1671 try: 

1672 while True: 

1673 # can't use anext in Python < 3.10 

1674 gen_result = await gen.__anext__() 

1675 # link the current state to the result for dependency tracking 

1676 # 

1677 # TODO: this could grow the task_run_result 

1678 # dictionary in an unbounded way, so finding a 

1679 # way to periodically clean it up (using 

1680 # weakrefs or similar) would be good 

1681 link_state_to_task_run_result(engine.state, gen_result) 

1682 yield gen_result 

1683 except (StopAsyncIteration, GeneratorExit) as exc: 

1684 await engine.handle_success(None, transaction=txn) 

1685 if isinstance(exc, GeneratorExit): 

1686 gen.throw(exc) 

1687 

1688 # async generators can't return, but we can raise failures here 

1689 if engine.state.is_failed(): 

1690 await engine.result() 

1691 

1692 

1693@overload 1a

1694def run_task( 1694 ↛ exit  line 1694 didn't return from function 'run_task' because 1a

1695 task: "Task[P, R]", 

1696 task_run_id: Optional[UUID] = None, 

1697 task_run: Optional[TaskRun] = None, 

1698 parameters: Optional[dict[str, Any]] = None, 

1699 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1700 return_type: Literal["state"] = "state", 

1701 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1702 context: Optional[dict[str, Any]] = None, 

1703) -> State[R]: ... 

1704 

1705 

1706@overload 1a

1707def run_task( 1707 ↛ exit  line 1707 didn't return from function 'run_task' because 1a

1708 task: "Task[P, R]", 

1709 task_run_id: Optional[UUID] = None, 

1710 task_run: Optional[TaskRun] = None, 

1711 parameters: Optional[dict[str, Any]] = None, 

1712 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1713 return_type: Literal["result"] = "result", 

1714 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1715 context: Optional[dict[str, Any]] = None, 

1716) -> R: ... 

1717 

1718 

1719def run_task( 1a

1720 task: "Task[P, Union[R, Coroutine[Any, Any, R]]]", 

1721 task_run_id: Optional[UUID] = None, 

1722 task_run: Optional[TaskRun] = None, 

1723 parameters: Optional[dict[str, Any]] = None, 

1724 wait_for: Optional["OneOrManyFutureOrResult[Any]"] = None, 

1725 return_type: Literal["state", "result"] = "result", 

1726 dependencies: Optional[dict[str, set[RunInput]]] = None, 

1727 context: Optional[dict[str, Any]] = None, 

1728) -> Union[R, State, None, Coroutine[Any, Any, Union[R, State, None]]]: 

1729 """ 

1730 Runs the provided task. 

1731 

1732 Args: 

1733 task: The task to run 

1734 task_run_id: The ID of the task run; if not provided, a new task run 

1735 will be created 

1736 task_run: The task run object; if not provided, a new task run 

1737 will be created 

1738 parameters: The parameters to pass to the task 

1739 wait_for: A list of futures to wait for before running the task 

1740 return_type: The return type to return; either "state" or "result" 

1741 dependencies: A dictionary of task run inputs to use for dependency tracking 

1742 context: A dictionary containing the context to use for the task run; only 

1743 required if the task is running in a remote environment 

1744 

1745 Returns: 

1746 The result of the task run 

1747 """ 

1748 kwargs: dict[str, Any] = dict( 

1749 task=task, 

1750 task_run_id=task_run_id, 

1751 task_run=task_run, 

1752 parameters=parameters, 

1753 wait_for=wait_for, 

1754 return_type=return_type, 

1755 dependencies=dependencies, 

1756 context=context, 

1757 ) 

1758 if task.isasync and task.isgenerator: 

1759 return run_generator_task_async(**kwargs) 

1760 elif task.isgenerator: 

1761 return run_generator_task_sync(**kwargs) 

1762 elif task.isasync: 

1763 return run_task_async(**kwargs) 

1764 else: 

1765 return run_task_sync(**kwargs)
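
The listing ends with run_task, which dispatches to the sync, async, or generator engine based on the task's characteristics. Below is a minimal usage sketch, not part of the measured module: the add_nums task is a hypothetical example, and only the run_task signature shown above and Prefect's public @task decorator are assumed.

    from prefect import task
    from prefect.task_engine import run_task

    # Hypothetical example task; retries/retry_delay_seconds exercise the
    # handle_retry path described in the engines above.
    @task(retries=2, retry_delay_seconds=[1, 5])
    def add_nums(x: int, y: int) -> int:
        return x + y

    # add_nums is neither async nor a generator, so run_task dispatches to
    # run_task_sync and returns the unwrapped result (return_type="result" is the default).
    result = run_task(task=add_nums, parameters={"x": 1, "y": 2})

    # Passing return_type="state" returns the terminal State object instead.
    state = run_task(task=add_nums, parameters={"x": 1, "y": 2}, return_type="state")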