Coverage for /usr/local/lib/python3.12/site-packages/prefect/utilities/engine.py: 13%

292 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1import asyncio 1a

2import contextlib 1a

3import os 1a

4import signal 1a

5import time 1a

6from collections.abc import Awaitable, Callable, Generator 1a

7from functools import partial 1a

8from logging import Logger 1a

9from typing import ( 1a

10 TYPE_CHECKING, 

11 Any, 

12 NoReturn, 

13 Optional, 

14 TypeVar, 

15 Union, 

16 cast, 

17) 

18from uuid import UUID 1a

19 

20import anyio 1a

21from opentelemetry import propagate, trace 1a

22from typing_extensions import TypeIs 1a

23 

24import prefect 1a

25import prefect.exceptions 1a

26from prefect._internal.concurrency.cancellation import get_deadline 1a

27from prefect.client.schemas import FlowRunResult, OrchestrationResult, TaskRun 1a

28from prefect.client.schemas.objects import ( 1a

29 RunType, 

30 TaskRunResult, 

31) 

32from prefect.client.schemas.responses import ( 1a

33 SetStateStatus, 

34 StateAbortDetails, 

35 StateRejectDetails, 

36 StateWaitDetails, 

37) 

38from prefect.context import FlowRunContext 1a

39from prefect.events import Event, emit_event 1a

40from prefect.exceptions import ( 1a

41 Pause, 

42 PrefectException, 

43 TerminationSignal, 

44 UpstreamTaskError, 

45) 

46from prefect.flows import Flow 1a

47from prefect.futures import PrefectFuture 1a

48from prefect.logging.loggers import get_logger 1a

49from prefect.results import ResultRecord, should_persist_result 1a

50from prefect.settings import PREFECT_LOGGING_LOG_PRINTS 1a

51from prefect.states import State 1a

52from prefect.tasks import Task 1a

53from prefect.utilities.annotations import allow_failure, quote 1a

54from prefect.utilities.collections import StopVisiting, visit_collection 1a

55from prefect.utilities.text import truncated_to 1a

56 

57if TYPE_CHECKING: 57 ↛ 58line 57 didn't jump to line 58 because the condition on line 57 was never true1a

58 from prefect.client.orchestration import PrefectClient, SyncPrefectClient 

59 

# Healthcheck cache used by `check_api_reachable`: maps an API URL string to the
# deadline (compared against `time.monotonic()`) until which a previous successful
# healthcheck for that URL is trusted.
API_HEALTHCHECKS: dict[str, float] = {}
# Types whose instances `link_state_to_result` treats as singletons and therefore
# cannot track by `id()`.
UNTRACKABLE_TYPES: set[type[Any]] = {bool, type(None), type(...), type(NotImplemented)}
# Logger shared by the engine utilities in this module.
engine_logger: Logger = get_logger("engine")
# Generic type variable.
T = TypeVar("T")

64 

65 

async def collect_task_run_inputs(
    expr: Any, max_depth: int = -1
) -> set[Union[TaskRunResult, FlowRunResult]]:
    """
    Walk `expr` and gather every upstream run result it references.

    The following are recognized while traversing the data structure:

    - `PrefectFuture` objects, recorded by their `task_run_id`
    - `State` objects that carry a `task_run_id` in their state details
    - plain objects previously linked to a state via `link_state_to_result`
      (looked up with `get_state_for_result`)

    Anything wrapped in `quote` is not traversed.

    Args:
        expr: the (possibly nested) value to inspect
        max_depth: traversal depth limit, forwarded to `visit_collection`

    Returns:
        The set of `TaskRunResult` / `FlowRunResult` references found.

    Examples:

        ```python
        task_inputs = {
            k: await collect_task_run_inputs(v) for k, v in parameters.items()
        }
        ```
    """
    # TODO: This function needs to be updated to detect parameters and constants
    collected: set[Union[TaskRunResult, FlowRunResult]] = set()

    def _gather(obj: Any) -> None:
        if isinstance(obj, PrefectFuture):
            collected.add(TaskRunResult(id=obj.task_run_id))
            return

        if isinstance(obj, State):
            upstream_task_run_id = obj.state_details.task_run_id
            if upstream_task_run_id:
                collected.add(TaskRunResult(id=upstream_task_run_id))
            return

        if isinstance(obj, quote):
            # Expressions inside quotes should not be traversed
            raise StopVisiting

        linked = get_state_for_result(obj)
        if not linked:
            return
        linked_state, run_type = linked
        run_result = linked_state.state_details.to_run_result(run_type)
        if run_result:
            collected.add(run_result)

    visit_collection(
        expr,
        visit_fn=_gather,
        return_data=False,
        max_depth=max_depth,
    )
    return collected

115 

116 

def collect_task_run_inputs_sync(
    expr: Any, future_cls: Any = PrefectFuture, max_depth: int = -1
) -> set[Union[TaskRunResult, FlowRunResult]]:
    """
    Synchronously walk `expr` and gather every upstream run result it references.

    The following are recognized while traversing the data structure:

    - instances of `future_cls` that expose a `task_run_id` attribute
    - `State` objects that carry a `task_run_id` in their state details
    - plain objects previously linked to a state via `link_state_to_result`
      (looked up with `get_state_for_result`)

    Anything wrapped in `quote` is not traversed.

    Args:
        expr: the (possibly nested) value to inspect
        future_cls: class (or anything else `isinstance` accepts) whose instances
            are treated as futures
        max_depth: traversal depth limit, forwarded to `visit_collection`

    Returns:
        The set of `TaskRunResult` / `FlowRunResult` references found.

    Examples:
        ```python
        task_inputs = {
            k: collect_task_run_inputs_sync(v) for k, v in parameters.items()
        }
        ```
    """
    # TODO: This function needs to be updated to detect parameters and constants
    collected: set[Union[TaskRunResult, FlowRunResult]] = set()

    def _gather(obj: Any) -> None:
        if isinstance(obj, future_cls) and hasattr(obj, "task_run_id"):
            collected.add(TaskRunResult(id=obj.task_run_id))
            return

        if isinstance(obj, State):
            upstream_task_run_id = obj.state_details.task_run_id
            if upstream_task_run_id:
                collected.add(TaskRunResult(id=upstream_task_run_id))
            return

        if isinstance(obj, quote):
            # Expressions inside quotes should not be traversed
            raise StopVisiting

        linked = get_state_for_result(obj)
        if not linked:
            return
        linked_state, run_type = linked
        run_result = linked_state.state_details.to_run_result(run_type)
        if run_result:
            collected.add(run_result)

    visit_collection(
        expr,
        visit_fn=_gather,
        return_data=False,
        max_depth=max_depth,
    )
    return collected

169 

170 

@contextlib.contextmanager
def capture_sigterm() -> Generator[None, Any, None]:
    """
    Turn SIGTERM into a `TerminationSignal` exception for the duration of the block.

    A handler that raises `TerminationSignal` is installed for SIGTERM. Installation
    only works from the main thread; elsewhere `signal.signal` raises `ValueError`
    and no handler is installed.

    If a `TerminationSignal` escapes the block, the previously installed handler is
    reinstated and SIGTERM is re-sent to this process, so that handler (likely the
    Python default) still runs. The exception is then re-raised.

    The previous handler is always restored on exit.
    """

    def _raise_termination(*args: object) -> NoReturn:
        raise TerminationSignal(signal=signal.SIGTERM)

    try:
        previous_handler = signal.signal(signal.SIGTERM, _raise_termination)
    except ValueError:
        # Signals only work in the main thread
        previous_handler = None

    try:
        yield
    except TerminationSignal as exc:
        if previous_handler is not None:
            # Give the previous handler a chance to act on the signal as well
            signal.signal(exc.signal, previous_handler)
            os.kill(os.getpid(), exc.signal)
        raise
    finally:
        if previous_handler is not None:
            signal.signal(signal.SIGTERM, previous_handler)

199 

200 

async def resolve_inputs(
    parameters: dict[str, Any], return_data: bool = True, max_depth: int = -1
) -> dict[str, Any]:
    """
    Resolve any `Quote`, `PrefectFuture`, or `State` types nested in parameters into
    data.

    Args:
        parameters: mapping of parameter name to its (possibly nested) value
        return_data: when `True`, final upstream states have their results retrieved
            and substituted for the future/state; when `False`, upstream states are
            only checked and results are not retrieved
        max_depth: traversal depth limit, forwarded to `visit_collection`

    Returns:
        A copy of the parameters with resolved data

    Raises:
        UpstreamTaskError: If any of the upstream states are not `COMPLETED`
            (a failed upstream is tolerated when wrapped in `allow_failure`)
        PrefectException: If traversing a parameter value fails for any other reason
    """
    if not parameters:
        return {}

    futures: set[PrefectFuture[Any]] = set()
    states: set[State[Any]] = set()
    result_by_state: dict[State[Any], Any] = {}
    # State observed for each collected future, keyed by `id(future)`. Each future's
    # `state` is read once and that same object is used both for result retrieval
    # and for the lookup in `resolve_input`.
    state_by_future_id: dict[int, State[Any]] = {}

    def collect_futures_and_states(expr: Any, context: dict[str, Any]) -> Any:
        # Expressions inside quotes should not be traversed
        if isinstance(context.get("annotation"), quote):
            raise StopVisiting()

        if isinstance(expr, PrefectFuture):
            fut: PrefectFuture[Any] = expr
            futures.add(fut)
        if isinstance(expr, State):
            state: State[Any] = expr
            states.add(state)

        return cast(Any, expr)

    visit_collection(
        parameters,
        visit_fn=collect_futures_and_states,
        return_data=False,
        max_depth=max_depth,
        context={},
    )

    # Without this, a completed future would pass the upstream check in
    # `resolve_input` but resolve to `None`, because its result was never retrieved.
    for future in futures:
        future_state = future.state
        state_by_future_id[id(future)] = future_state
        states.add(future_state)

    # Only retrieve the result if requested as it may be expensive
    if return_data:
        finished_states = [state for state in states if state.is_final()]

        state_results = await asyncio.gather(
            *[state.aresult(raise_on_failure=False) for state in finished_states]
        )

        for state, result in zip(finished_states, state_results):
            result_by_state[state] = result

    def resolve_input(expr: Any, context: dict[str, Any]) -> Any:
        state: Optional[State[Any]] = None

        # Expressions inside quotes should not be modified
        if isinstance(context.get("annotation"), quote):
            raise StopVisiting()

        if isinstance(expr, PrefectFuture):
            # Prefer the state captured during collection so the result lookup
            # below hits the same state object that was resolved
            state = state_by_future_id.get(id(expr))
            if state is None:
                state = expr.state
        elif isinstance(expr, State):
            state = expr
        else:
            return expr

        # Do not allow uncompleted upstreams except failures when `allow_failure` has
        # been used
        if not state.is_completed() and not (
            # TODO: Note that the contextual annotation here is only at the current level
            # if `allow_failure` is used then another annotation is used, this will
            # incorrectly evaluate to false — to resolve this, we must track all
            # annotations wrapping the current expression but this is not yet
            # implemented.
            isinstance(context.get("annotation"), allow_failure) and state.is_failed()
        ):
            raise UpstreamTaskError(
                f"Upstream task run '{state.state_details.task_run_id}' did not reach a"
                " 'COMPLETED' state."
            )

        return result_by_state.get(state)

    resolved_parameters: dict[str, Any] = {}
    for parameter, value in parameters.items():
        try:
            resolved_parameters[parameter] = visit_collection(
                value,
                visit_fn=resolve_input,
                return_data=return_data,
                # we're manually going 1 layer deeper here
                max_depth=max_depth - 1,
                remove_annotations=True,
                context={},
            )
        except UpstreamTaskError:
            raise
        except Exception as exc:
            raise PrefectException(
                f"Failed to resolve inputs in parameter {parameter!r}. If your"
                " parameter type is not supported, consider using the `quote`"
                " annotation to skip resolution of inputs."
            ) from exc

    return resolved_parameters

309 

310 

def _is_result_record(data: Any) -> TypeIs[ResultRecord[Any]]:
    """Type guard: report whether `data` is a `ResultRecord`, narrowing it if so."""
    is_record = isinstance(data, ResultRecord)
    return is_record

313 

314 

async def propose_state(
    client: "PrefectClient",
    state: State[Any],
    flow_run_id: UUID,
    force: bool = False,
) -> State[Any]:
    """
    Propose a new state for a flow run, invoking Prefect orchestration logic.

    If the proposed state is accepted, the provided `state` will be augmented with
    details and returned.

    If the proposed state is rejected, a new state returned by the Prefect API will be
    returned.

    If the proposed state results in a WAIT instruction from the Prefect API, the
    function will sleep and attempt to propose the state again.

    If the proposed state results in an ABORT instruction from the Prefect API, an
    error will be raised.

    If `state` is final, its result is first linked to it as a flow run result, so
    that the result can be traced back when it is passed to other runs.

    Args:
        client: the async client used to call `set_flow_run_state`
        state: a new state for a flow run
        flow_run_id: the id of the flow run whose state is proposed; required
        force: forwarded to `client.set_flow_run_state`

    Returns:
        a State model representation of the flow run state

    Raises:
        ValueError: if `flow_run_id` is not provided, or if the API responds with an
            unexpected `SetStateStatus`
        prefect.exceptions.Abort: if an ABORT instruction is received from
            the Prefect API
        prefect.exceptions.Pause: if the proposal is rejected in favor of a paused
            state
    """

    if not flow_run_id:
        raise ValueError("You must provide a `flow_run_id`")

    # Handle sub-flow tracing
    if state.is_final():
        result: Any
        if _is_result_record(state.data):
            result = state.data.result
        else:
            result = state.data

        link_state_to_flow_run_result(state, result)

    # Handle repeated WAITs in a loop instead of recursively, to avoid
    # reaching max recursion depth in extreme cases.
    async def set_state_and_handle_waits(
        set_state_func: Callable[[], Awaitable[OrchestrationResult[Any]]],
    ) -> OrchestrationResult[Any]:
        response = await set_state_func()
        while response.status == SetStateStatus.WAIT:
            if TYPE_CHECKING:
                assert isinstance(response.details, StateWaitDetails)
            engine_logger.debug(
                f"Received wait instruction for {response.details.delay_seconds}s: "
                f"{response.details.reason}"
            )
            await anyio.sleep(response.details.delay_seconds)
            response = await set_state_func()
        return response

    set_state = partial(client.set_flow_run_state, flow_run_id, state, force=force)
    response = await set_state_and_handle_waits(set_state)

    # Parse the response to return the new state
    if response.status == SetStateStatus.ACCEPT:
        # Update the state with the details if provided
        if TYPE_CHECKING:
            assert response.state is not None
        state.id = response.state.id
        state.timestamp = response.state.timestamp
        if response.state.state_details:
            state.state_details = response.state.state_details
        return state

    elif response.status == SetStateStatus.ABORT:
        if TYPE_CHECKING:
            assert isinstance(response.details, StateAbortDetails)

        raise prefect.exceptions.Abort(response.details.reason)

    elif response.status == SetStateStatus.REJECT:
        if TYPE_CHECKING:
            assert response.state is not None
            assert isinstance(response.details, StateRejectDetails)

        if response.state.is_paused():
            raise Pause(response.details.reason, state=response.state)
        return response.state

    else:
        raise ValueError(
            f"Received unexpected `SetStateStatus` from server: {response.status!r}"
        )

411 

412 

def propose_state_sync(
    client: "SyncPrefectClient",
    state: State[Any],
    flow_run_id: UUID,
    force: bool = False,
) -> State[Any]:
    """
    Propose a new state for a flow run, invoking Prefect orchestration logic.

    If the proposed state is accepted, the provided `state` will be augmented with
    details and returned.

    If the proposed state is rejected, a new state returned by the Prefect API will be
    returned.

    If the proposed state results in a WAIT instruction from the Prefect API, the
    function will sleep and attempt to propose the state again.

    If the proposed state results in an ABORT instruction from the Prefect API, an
    error will be raised.

    If `state` is final, its result is first linked to it as a flow run result, so
    that the result can be traced back when it is passed to other runs.

    Args:
        client: the sync client used to call `set_flow_run_state`
        state: a new state for the flow run
        flow_run_id: the id of the flow run whose state is proposed; required
        force: forwarded to `client.set_flow_run_state`

    Returns:
        a State model representation of the flow run state

    Raises:
        ValueError: if `flow_run_id` is not provided, or if the API responds with an
            unexpected `SetStateStatus`
        prefect.exceptions.Abort: if an ABORT instruction is received from
            the Prefect API
        prefect.exceptions.Pause: if the proposal is rejected in favor of a paused
            state
    """

    # Validate up front, matching the async `propose_state`; previously a missing id
    # was only discovered by the client call despite the documented ValueError.
    if not flow_run_id:
        raise ValueError("You must provide a `flow_run_id`")

    # Handle sub-flow tracing
    if state.is_final():
        result: Any
        if _is_result_record(state.data):
            result = state.data.result
        else:
            result = state.data

        link_state_to_flow_run_result(state, result)

    # Handle repeated WAITs in a loop instead of recursively, to avoid
    # reaching max recursion depth in extreme cases.
    def set_state_and_handle_waits(
        set_state_func: Callable[[], OrchestrationResult[Any]],
    ) -> OrchestrationResult[Any]:
        response = set_state_func()
        while response.status == SetStateStatus.WAIT:
            if TYPE_CHECKING:
                assert isinstance(response.details, StateWaitDetails)
            engine_logger.debug(
                f"Received wait instruction for {response.details.delay_seconds}s: "
                f"{response.details.reason}"
            )
            time.sleep(response.details.delay_seconds)
            response = set_state_func()
        return response

    # Attempt to set the state
    set_state = partial(client.set_flow_run_state, flow_run_id, state, force=force)
    response = set_state_and_handle_waits(set_state)

    # Parse the response to return the new state
    if response.status == SetStateStatus.ACCEPT:
        if TYPE_CHECKING:
            assert response.state is not None
        # Update the state with the details if provided
        state.id = response.state.id
        state.timestamp = response.state.timestamp
        if response.state.state_details:
            state.state_details = response.state.state_details
        return state

    elif response.status == SetStateStatus.ABORT:
        if TYPE_CHECKING:
            assert isinstance(response.details, StateAbortDetails)
        raise prefect.exceptions.Abort(response.details.reason)

    elif response.status == SetStateStatus.REJECT:
        if TYPE_CHECKING:
            assert response.state is not None
            assert isinstance(response.details, StateRejectDetails)
        if response.state.is_paused():
            raise Pause(response.details.reason, state=response.state)
        return response.state

    else:
        raise ValueError(
            f"Received unexpected `SetStateStatus` from server: {response.status!r}"
        )

505 

506 

def get_state_for_result(obj: Any) -> Optional[tuple[State, RunType]]:
    """
    Look up the state (and run type) previously linked to a result object.

    `link_state_to_result` must have been called first.

    Returns:
        The `(state, run_type)` pair stored under `id(obj)` in the current flow run
        context, or `None` when there is no flow run context or no link exists.
    """
    ctx = FlowRunContext.get()
    if not ctx:
        return None
    return ctx.run_results.get(id(obj))

516 

517 

def link_state_to_flow_run_result(state: State, result: Any) -> None:
    """Link `result` to `state`, recording it as the result of a flow run."""
    link_state_to_result(state=state, result=result, run_type=RunType.FLOW_RUN)

521 

522 

def link_state_to_task_run_result(state: State, result: Any) -> None:
    """Link `result` to `state`, recording it as the result of a task run."""
    link_state_to_result(state=state, result=result, run_type=RunType.TASK_RUN)

526 

527 

def link_state_to_result(state: State, result: Any, run_type: RunType) -> None:
    """
    Record, in the current flow run context, which state produced `result`.

    `result` and each of its immediate components are stored by `id()` in
    `FlowRunContext.run_results`. Each entry maps to a copy of `state` without its
    data, together with `run_type`. This lets dependency tracking recognize results
    as they are passed around. Nothing is recorded when there is no flow run
    context, since task relationships are limited to within a flow run.

    Note: Because `id` is used, links cannot be cached for singleton objects.

    Only components one layer deep are linked. For example, given the result
    `[1, ["a", "b"], ("c",)]`, these objects are mapped to the state:

    - `[1, ["a", "b"], ("c",)]`
    - `["a", "b"]`
    - `("c",)`

    The int `1` is not mapped because it is a singleton. Singletons are `bool`,
    `None`, `Ellipsis`, `NotImplemented`, and ints from -5 to 256. Encountering one
    sets `state.state_details.untrackable_result = True` on the passed-in state.

    Other Notes:
        The result is not hashed because:
        - If changes are made to the object in the flow between task calls, we can
          still track that they are related.
        - Hashing can be expensive.
        - Not all objects are hashable.

        No attribute (e.g. `__prefect_state__`) is set on the result because:
        - Mutating user's objects is dangerous.
        - Unrelated equality comparisons can break unexpectedly.
        - The field can be preserved on copy.
        - We cannot set this attribute on Python built-ins.
    """
    flow_run_context = FlowRunContext.get()
    # Drop the data field to avoid holding a strong reference to the result;
    # holding large user objects in memory can cause memory bloat
    state_without_data = state.model_copy(update={"data": None})

    if not flow_run_context:
        return

    def link_if_trackable(obj: Any) -> None:
        is_singleton = type(obj) in UNTRACKABLE_TYPES or (
            isinstance(obj, int) and -5 <= obj <= 256
        )
        if is_singleton:
            # Mutates the caller's state to flag the result as untrackable
            state.state_details.untrackable_result = True
        else:
            flow_run_context.run_results[id(obj)] = (state_without_data, run_type)

    visit_collection(expr=result, visit_fn=link_if_trackable, max_depth=1)

587 

588 

def should_log_prints(flow_or_task: Union["Flow[..., Any]", "Task[..., Any]"]) -> bool:
    """
    Decide whether `print` output should be logged for a flow or task.

    Precedence:
    - the object's own `log_prints` setting, when it is not `None`
    - otherwise the current flow run context's `log_prints`
    - outside a flow run, the `PREFECT_LOGGING_LOG_PRINTS` setting
    """
    flow_run_context = FlowRunContext.get()

    explicit_setting = flow_or_task.log_prints
    if explicit_setting is not None:
        return explicit_setting

    if flow_run_context:
        return flow_run_context.log_prints
    return PREFECT_LOGGING_LOG_PRINTS.value()

599 

600 

async def check_api_reachable(client: "PrefectClient", fail_message: str) -> None:
    """
    Ensure the API behind `client` responds to a healthcheck.

    A successful check is cached per API URL for 10 minutes. During that window,
    further calls for the same URL return without contacting the API.

    Args:
        client: the client whose `api_url` is checked
        fail_message: prefix of the error message raised on failure

    Raises:
        RuntimeError: if `client.api_healthcheck()` reports an error; the reported
            error is chained as the cause
    """
    api_url = str(client.api_url)

    cached_until = API_HEALTHCHECKS.get(api_url)
    if cached_until is not None and cached_until > time.monotonic():
        # A recent healthcheck for this URL succeeded; skip the request
        return

    healthcheck_error = await client.api_healthcheck()
    if healthcheck_error:
        raise RuntimeError(
            f"{fail_message}. Failed to reach API at {api_url}."
        ) from healthcheck_error

    # Create a 10 minute cache for the healthy response
    API_HEALTHCHECKS[api_url] = get_deadline(60 * 10)

617 

618 

def emit_task_run_state_change_event(
    task_run: TaskRun,
    initial_state: Optional[State[Any]],
    validated_state: State[Any],
    follows: Optional[Event] = None,
) -> Optional[Event]:
    """
    Emit a `prefect.task-run.<state name>` event for a task run state transition.

    The payload contains:
    - the intended transition
    - serialized forms of the initial and validated states (messages truncated to
      100,000 characters)
    - the task run itself

    The event's resource is the task run. Each of its tags is attached as a related
    `tag` resource.

    Args:
        task_run: the task run whose state changed
        initial_state: the state before the transition, if any
        validated_state: the state after the transition
        follows: optional event forwarded to `emit_event`

    Returns:
        Whatever `emit_event` returns.
    """
    message_limit = 100_000

    def _serialize_state(state: State[Any]) -> dict[str, Any]:
        # Shared shape of the `initial_state` / `validated_state` payload entries;
        # run ids are excluded from the dumped state details
        return {
            "type": str(state.type.value),
            "name": state.name,
            "message": truncated_to(message_limit, state.message),
            "state_details": state.state_details.model_dump(
                mode="json",
                exclude_none=True,
                exclude_unset=True,
                exclude={"flow_run_id", "task_run_id"},
            ),
        }

    # Result metadata is attached only when the validated state's data is a
    # `ResultRecord` and `should_persist_result()` is true
    result_metadata = (
        validated_state.data.metadata.model_dump(mode="json")
        if _is_result_record(validated_state.data) and should_persist_result()
        else None
    )

    payload = {
        "intended": {
            "from": str(initial_state.type.value) if initial_state else None,
            "to": str(validated_state.type.value) if validated_state else None,
        },
        "initial_state": _serialize_state(initial_state) if initial_state else None,
        "validated_state": {
            **_serialize_state(validated_state),
            "data": result_metadata,
        },
        "task_run": task_run.model_dump(
            mode="json",
            exclude_none=True,
            exclude={
                "id",
                "created",
                "updated",
                "flow_run_id",
                "state_id",
                "state_type",
                "state_name",
                "state",
                # server materialized fields
                "estimated_start_time_delta",
                "estimated_run_time",
            },
        ),
    }

    timestamp_text = (
        validated_state.timestamp.isoformat()
        if validated_state and validated_state.timestamp
        else ""
    )
    resource = {
        "prefect.resource.id": f"prefect.task-run.{task_run.id}",
        "prefect.resource.name": task_run.name,
        "prefect.run-count": str(task_run.run_count),
        "prefect.state-message": truncated_to(message_limit, validated_state.message),
        "prefect.state-name": validated_state.name or "",
        "prefect.state-timestamp": timestamp_text,
        "prefect.state-type": str(validated_state.type.value),
        "prefect.orchestration": "client",
    }

    related = [
        {"prefect.resource.id": f"prefect.tag.{tag}", "prefect.resource.role": "tag"}
        for tag in sorted(task_run.tags)
    ]

    return emit_event(
        id=validated_state.id,
        occurred=validated_state.timestamp,
        event=f"prefect.task-run.{validated_state.name}",
        payload=payload,
        resource=resource,
        related=related,
        follows=follows,
    )

715 

716 

def resolve_to_final_result(expr: Any, context: dict[str, Any]) -> Any:
    """
    Resolve a `PrefectFuture` or `State` into its final result data.

    Designed to be used as the `visit_fn` of `visit_collection`. Any expression that
    is neither a future nor a state is returned unchanged.

    Futures are waited on before their state is read. If the resolved state's
    details carry a `traceparent`, a link to that trace context is added to the
    current OpenTelemetry span.

    Args:
        expr: the expression currently being visited
        context: the `visit_collection` context; the keys read here are
            `"annotation"`, `"current_task"`, `"current_task_run"`, and
            `"parameter_name"`

    Returns:
        The resolved state's result, or `expr` itself when nothing needs resolving.

    Raises:
        StopVisiting: when the expression is wrapped in a `quote` annotation
        ValueError: when a future refers to the current task run (a task depending
            on itself)
        UpstreamTaskError: when the upstream state is not completed, unless it
            failed and is wrapped in `allow_failure`
    """
    annotation = context.get("annotation")

    # Expressions inside quotes should not be modified
    if isinstance(annotation, quote):
        raise StopVisiting()

    if isinstance(expr, PrefectFuture):
        current_task_run: Optional[TaskRun] = context.get("current_task_run")
        current_task: Optional["Task[..., Any]"] = context.get("current_task")
        if (
            current_task
            and current_task_run
            and expr.task_run_id == current_task_run.id
        ):
            raise ValueError(
                f"Discovered a task depending on itself. Raising to avoid a deadlock. Please inspect the inputs and dependencies of {current_task.name}."
            )

        expr.wait()
        state: State[Any] = expr.state
    elif isinstance(expr, State):
        state = expr
    else:
        return expr

    # Do not allow uncompleted upstreams except failures when `allow_failure` has
    # been used.
    # TODO: Note that the contextual annotation here is only at the current level;
    # if `allow_failure` is used then another annotation is used, this will
    # incorrectly evaluate to false — to resolve this, we must track all
    # annotations wrapping the current expression but this is not yet implemented.
    if not (
        state.is_completed()
        or (isinstance(annotation, allow_failure) and state.is_failed())
    ):
        raise UpstreamTaskError(
            f"Upstream task run '{state.state_details.task_run_id}' did not reach a"
            " 'COMPLETED' state."
        )

    result: Any = state.result(raise_on_failure=False, _sync=True)  # pyright: ignore[reportCallIssue] _sync messes up type inference and can be removed once async_dispatch is removed

    traceparent = state.state_details.traceparent
    if traceparent:
        upstream_context = propagate.extract({"traceparent": traceparent})

        # When the future is used as a parameter (rather than only in wait_for),
        # record the parameter name and result type on the link
        link_attributes: dict[str, Any] = (
            {
                "prefect.input.name": context["parameter_name"],
                "prefect.input.type": type(result).__name__,
            }
            if "parameter_name" in context
            else {}
        )

        trace.get_current_span().add_link(
            context=trace.get_current_span(upstream_context).get_span_context(),
            attributes=link_attributes,
        )

    return result

786 

787 

def resolve_inputs_sync(
    parameters: dict[str, Any], return_data: bool = True, max_depth: int = -1
) -> dict[str, Any]:
    """
    Synchronously resolve any `Quote`, `PrefectFuture`, or `State` types nested in
    parameters into data.

    Each parameter value is traversed with `resolve_to_final_result`. The parameter
    name is passed in the visit context, and annotations are removed from the
    output.

    Args:
        parameters: mapping of parameter name to its (possibly nested) value
        return_data: forwarded to `visit_collection`
        max_depth: traversal depth limit, forwarded to `visit_collection`

    Returns:
        A copy of the parameters with resolved data

    Raises:
        UpstreamTaskError: If any of the upstream states are not `COMPLETED`
        PrefectException: If resolving a parameter fails for any other reason
    """
    if not parameters:
        return {}

    def _resolve_one(name: str, value: Any) -> Any:
        try:
            return visit_collection(
                value,
                visit_fn=resolve_to_final_result,
                return_data=return_data,
                max_depth=max_depth,
                remove_annotations=True,
                context={"parameter_name": name},
            )
        except UpstreamTaskError:
            raise
        except Exception as exc:
            raise PrefectException(
                f"Failed to resolve inputs in parameter {name!r}. If your"
                " parameter type is not supported, consider using the `quote`"
                " annotation to skip resolution of inputs."
            ) from exc

    return {name: _resolve_one(name, value) for name, value in parameters.items()}