Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/orchestration/core_policy.py: 22%

547 statements  

coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Orchestration logic that fires on state transitions. 

3 

4`CoreFlowPolicy` and `CoreTaskPolicy` contain all default orchestration rules that 

5Prefect enforces on a state transition. 

6""" 

7 

8from __future__ import annotations 1a

9 

10import datetime 1a

11import logging 1a

12import math 1a

13from typing import Any, Union, cast 1a

14from uuid import UUID, uuid4 1a

15 

16import sqlalchemy as sa 1a

17from packaging.version import Version 1a

18from sqlalchemy import select 1a

19 

20from prefect.logging import get_logger 1a

21from prefect.server import models 1a

22from prefect.server.concurrency.lease_storage import ( 1a

23 ConcurrencyLeaseHolder, 

24 ConcurrencyLimitLeaseMetadata, 

25 get_concurrency_lease_storage, 

26) 

27from prefect.server.database import PrefectDBInterface, orm_models 1a

28from prefect.server.database.dependencies import db_injector, provide_database_interface 1a

29from prefect.server.exceptions import ObjectNotFoundError 1a

30from prefect.server.models import concurrency_limits, concurrency_limits_v2, deployments 1a

31from prefect.server.orchestration.dependencies import ( 1a

32 MIN_CLIENT_VERSION_FOR_CONCURRENCY_LIMIT_LEASING, 

33 WORKER_VERSIONS_THAT_MANAGE_DEPLOYMENT_CONCURRENCY, 

34) 

35from prefect.server.orchestration.policies import ( 1a

36 FlowRunOrchestrationPolicy, 

37 TaskRunOrchestrationPolicy, 

38) 

39from prefect.server.orchestration.rules import ( 1a

40 ALL_ORCHESTRATION_STATES, 

41 TERMINAL_STATES, 

42 BaseOrchestrationRule, 

43 BaseUniversalTransform, 

44 FlowOrchestrationContext, 

45 FlowRunOrchestrationRule, 

46 FlowRunUniversalTransform, 

47 GenericOrchestrationRule, 

48 OrchestrationContext, 

49 TaskRunOrchestrationRule, 

50 TaskRunUniversalTransform, 

51) 

52from prefect.server.schemas import core, filters, states 1a

53from prefect.server.schemas.states import StateType 1a

54from prefect.server.task_queue import TaskQueue 1a

55from prefect.settings import ( 1a

56 get_current_settings, 

57) 

58from prefect.types._datetime import now 1a

59from prefect.utilities.math import clamped_poisson_interval 1a

60 

61from .instrumentation_policies import InstrumentFlowRunStateTransitions 1a

62 

63logger: logging.Logger = get_logger(__name__) 1a

64 

65 

66class CoreFlowPolicy(FlowRunOrchestrationPolicy): 1a

67 """ 

68 Orchestration rules that run against flow-run-state transitions in priority order. 

69 """ 

70 

71 @staticmethod 1a

72 def priority() -> list[ 1a

73 Union[ 

74 type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], 

75 type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], 

76 ] 

77 ]: 

78 return cast( 

79 list[ 

80 Union[ 

81 type[ 

82 BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy] 

83 ], 

84 type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], 

85 ] 

86 ], 

87 [ 

88 PreventDuplicateTransitions, 

89 HandleFlowTerminalStateTransitions, 

90 EnforceCancellingToCancelledTransition, 

91 BypassCancellingFlowRunsWithNoInfra, 

92 PreventPendingTransitions, 

93 CopyDeploymentConcurrencyLeaseID, 

94 SecureFlowConcurrencySlots, 

95 RemoveDeploymentConcurrencyLeaseForOldClientVersions, 

96 EnsureOnlyScheduledFlowsMarkedLate, 

97 HandlePausingFlows, 

98 HandleResumingPausedFlows, 

99 CopyScheduledTime, 

100 WaitForScheduledTime, 

101 RetryFailedFlows, 

102 InstrumentFlowRunStateTransitions, 

103 ReleaseFlowConcurrencySlots, 

104 ], 

105 ) 

106 

107 

108class CoreTaskPolicy(TaskRunOrchestrationPolicy): 1a

109 """ 

110 Orchestration rules that run against task-run-state transitions in priority order. 

111 """ 

112 

113 @staticmethod 1a

114 def priority() -> list[ 1a

115 Union[ 

116 type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], 

117 type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], 

118 ] 

119 ]: 

120 return cast( 

121 list[ 

122 Union[ 

123 type[ 

124 BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy] 

125 ], 

126 type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], 

127 ] 

128 ], 

129 [ 

130 CacheRetrieval, 

131 HandleTaskTerminalStateTransitions, 

132 PreventRunningTasksFromStoppedFlows, 

133 SecureTaskConcurrencySlots, # retrieve cached states even if slots are full 

134 CopyScheduledTime, 

135 WaitForScheduledTime, 

136 RetryFailedTasks, 

137 RenameReruns, 

138 UpdateFlowRunTrackerOnTasks, 

139 CacheInsertion, 

140 ReleaseTaskConcurrencySlots, 

141 ], 

142 ) 

143 

144 
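# Editor's note (illustrative sketch, not part of the Prefect source): the policy classes above
# follow one pattern -- subclass an orchestration policy and return, from a static `priority()`
# method, the rule classes to apply in priority order. A hypothetical trimmed-down policy might
# look like the following; `MyCustomTaskPolicy` and the chosen rule subset are assumptions for
# illustration only.
#
#     class MyCustomTaskPolicy(TaskRunOrchestrationPolicy):
#         @staticmethod
#         def priority():
#             # Rules run top-to-bottom against each proposed task-run state transition.
#             return [
#                 HandleTaskTerminalStateTransitions,
#                 WaitForScheduledTime,
#                 RetryFailedTasks,
#             ]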

145class ClientSideTaskOrchestrationPolicy(TaskRunOrchestrationPolicy): 1a

146 """ 

147 Orchestration rules that run against task-run-state transitions in priority order, 

148 specifically for clients doing client-side orchestration. 

149 """ 

150 

151 @staticmethod 1a

152 def priority() -> list[ 1a

153 Union[ 

154 type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], 

155 type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], 

156 ] 

157 ]: 

158 return cast( 

159 list[ 

160 Union[ 

161 type[ 

162 BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy] 

163 ], 

164 type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], 

165 ] 

166 ], 

167 [ 

168 CacheRetrieval, 

169 HandleTaskTerminalStateTransitions, 

170 PreventRunningTasksFromStoppedFlows, 

171 CopyScheduledTime, 

172 WaitForScheduledTime, 

173 RetryFailedTasks, 

174 RenameReruns, 

175 UpdateFlowRunTrackerOnTasks, 

176 CacheInsertion, 

177 ReleaseTaskConcurrencySlots, 

178 ], 

179 ) 

180 

181 

182class BackgroundTaskPolicy(TaskRunOrchestrationPolicy): 1a

183 """ 

184 Orchestration rules that run against task-run-state transitions in priority order. 

185 """ 

186 

187 @staticmethod 1a

188 def priority() -> list[ 1a

189 type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]] 

190 | type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]] 

191 ]: 

192 return cast( 

193 list[ 

194 Union[ 

195 type[ 

196 BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy] 

197 ], 

198 type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], 

199 ] 

200 ], 

201 [ 

202 PreventPendingTransitions, 

203 CacheRetrieval, 

204 HandleTaskTerminalStateTransitions, 

205 # SecureTaskConcurrencySlots, # retrieve cached states even if slots are full 

206 CopyScheduledTime, 

207 CopyTaskParametersID, 

208 WaitForScheduledTime, 

209 RetryFailedTasks, 

210 RenameReruns, 

211 UpdateFlowRunTrackerOnTasks, 

212 CacheInsertion, 

213 ReleaseTaskConcurrencySlots, 

214 EnqueueScheduledTasks, 

215 ], 

216 ) 

217 

218 

219class MinimalFlowPolicy(FlowRunOrchestrationPolicy): 1a

220 @staticmethod 1a

221 def priority() -> list[ 1a

222 Union[ 

223 type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], 

224 type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], 

225 ] 

226 ]: 

227 return [ 1bc

228 BypassCancellingFlowRunsWithNoInfra, # cancel scheduled or suspended runs from the UI 

229 InstrumentFlowRunStateTransitions, 

230 ReleaseFlowConcurrencySlots, 

231 ] 

232 

233 

234class MarkLateRunsPolicy(FlowRunOrchestrationPolicy): 1a

235 @staticmethod 1a

236 def priority() -> list[ 1a

237 Union[ 

238 type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], 

239 type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], 

240 ] 

241 ]: 

242 return [ 

243 EnsureOnlyScheduledFlowsMarkedLate, 

244 InstrumentFlowRunStateTransitions, 

245 ] 

246 

247 

248class MinimalTaskPolicy(TaskRunOrchestrationPolicy): 1a

249 @staticmethod 1a

250 def priority() -> list[ 1a

251 Union[ 

252 type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], 

253 type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], 

254 ] 

255 ]: 

256 return [ 1bdc

257 ReleaseTaskConcurrencySlots, # always release concurrency slots 

258 ] 

259 

260 

261class SecureTaskConcurrencySlots(TaskRunOrchestrationRule): 1a

262 """ 

263 Checks relevant concurrency slots are available before entering a Running state. 

264 

265 This rule checks if concurrency limits have been set on the tags associated with a 

266 TaskRun. If so, a concurrency slot will be secured against each concurrency limit 

267 before being allowed to transition into a running state. If a concurrency limit has 

268 been reached, the client will be instructed to delay the transition for the duration 

269 specified by the "PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS" setting 

270 before trying again. If the concurrency limit set on a tag is 0, the transition will 

271 be aborted to prevent deadlocks. 

272 """ 

273 

274 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

275 TO_STATES = {StateType.RUNNING} 1a

276 

277 async def before_transition( 1a

278 self, 

279 initial_state: states.State[Any] | None, 

280 proposed_state: states.State[Any] | None, 

281 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

282 ) -> None: 

283 settings = get_current_settings() 

284 self._applied_limits: list[str] = [] 

285 self._acquired_v2_lease_ids: list[UUID] = [] 

286 v1_limits = ( 

287 await concurrency_limits.filter_concurrency_limits_for_orchestration( 

288 context.session, tags=context.run.tags 

289 ) 

290 ) 

291 v2_names = [f"tag:{tag}" for tag in context.run.tags] 

292 v2_limits = await concurrency_limits_v2.bulk_read_concurrency_limits( 

293 context.session, names=v2_names 

294 ) 

295 

296 # Handle V2 limits first (if they exist) 

297 v2_tags: set[str] = set() # Track which tags have V2 limits 

298 if v2_limits: 

299 lease_storage = get_concurrency_lease_storage() 

300 # Track which tags have V2 limits to exclude from V1 processing 

301 v2_tags = { 

302 limit.name.removeprefix("tag:") for limit in v2_limits if limit.active 

303 } 

304 

305 # Check for zero limits that would deadlock 

306 for limit in v2_limits: 

307 if limit.active and limit.limit == 0: 

308 # Clean up any already acquired V2 leases 

309 for lease_id in self._acquired_v2_lease_ids: 

310 try: 

311 lease = await lease_storage.read_lease( 

312 lease_id=lease_id, 

313 ) 

314 if lease: 

315 await concurrency_limits_v2.bulk_decrement_active_slots( 

316 session=context.session, 

317 concurrency_limit_ids=lease.resource_ids, 

318 slots=lease.metadata.slots if lease.metadata else 1, 

319 ) 

320 await lease_storage.revoke_lease( 

321 lease_id=lease.id, 

322 ) 

323 except Exception: 

324 logger.warning( 

325 f"Failed to clean up lease {lease_id} during abort", 

326 exc_info=True, 

327 ) 

328 

329 await self.abort_transition( 

330 reason=f'The concurrency limit on tag "{limit.name.removeprefix("tag:")}" is 0 and will deadlock if the task tries to run again.', 

331 ) 

332 

333 # Try to acquire V2 slots with lease (exclude zero limits as they're handled above) 

334 active_v2_limits = [ 

335 limit for limit in v2_limits if limit.active and limit.limit > 0 

336 ] 

337 if active_v2_limits: 

338 # Attempt to acquire slots 

339 async with provide_database_interface().session_context( 

340 begin_transaction=True 

341 ) as session: 

342 acquired = await concurrency_limits_v2.bulk_increment_active_slots( 

343 session=session, 

344 concurrency_limit_ids=[limit.id for limit in active_v2_limits], 

345 slots=1, 

346 ) 

347 if not acquired: 

348 await session.rollback() 

349 # Slots not available, delay transition 

350 delay_seconds = clamped_poisson_interval( 

351 average_interval=settings.server.tasks.tag_concurrency_slot_wait_seconds, 

352 ) 

353 await self.delay_transition( 

354 delay_seconds=round(delay_seconds), 

355 reason=f"Concurrency limit reached for tags: {', '.join([limit.name.removeprefix('tag:') for limit in active_v2_limits])}", 

356 ) 

357 return 

358 

359 # Create lease for acquired slots with minimal metadata first 

360 lease = await lease_storage.create_lease( 

361 resource_ids=[limit.id for limit in active_v2_limits], 

362 ttl=concurrency_limits.V1_LEASE_TTL, 

363 metadata=ConcurrencyLimitLeaseMetadata( 

364 slots=1, 

365 holder=ConcurrencyLeaseHolder( 

366 type="task_run", 

367 id=context.run.id, 

368 ), 

369 ), 

370 ) 

371 

372 self._acquired_v2_lease_ids.append(lease.id) 

373 

374 remaining_v1_limits = [limit for limit in v1_limits if limit.tag not in v2_tags] 

375 if remaining_v1_limits: 

376 run_limits = {limit.tag: limit for limit in v1_limits} 

377 for tag, cl in run_limits.items(): 

378 limit = cl.concurrency_limit 

379 if limit == 0: 

380 # limits of 0 will deadlock, and the transition needs to abort 

381 for stale_tag in self._applied_limits: 

382 stale_limit = run_limits.get(stale_tag, None) 

383 if stale_limit: 

384 active_slots: set[str] = set(stale_limit.active_slots) 

385 active_slots.discard(str(context.run.id)) 

386 stale_limit.active_slots = list(active_slots) 

387 

388 await self.abort_transition( 

389 reason=( 

390 f'The concurrency limit on tag "{tag}" is 0 and will deadlock' 

391 " if the task tries to run again." 

392 ), 

393 ) 

394 elif len(cl.active_slots) >= limit: 

395 # if the limit has already been reached, delay the transition 

396 for stale_tag in self._applied_limits: 

397 stale_limit = run_limits.get(stale_tag, None) 

398 if stale_limit: 

399 active_slots = set(stale_limit.active_slots) 

400 active_slots.discard(str(context.run.id)) 

401 stale_limit.active_slots = list(active_slots) 

402 

403 await self.delay_transition( 

404 delay_seconds=int( 

405 settings.server.tasks.tag_concurrency_slot_wait_seconds 

406 ), 

407 # PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS.value(), 

408 reason=f"Concurrency limit for the {tag} tag has been reached", 

409 ) 

410 else: 

411 # log the TaskRun ID to active_slots 

412 self._applied_limits.append(tag) 

413 active_slots = set(cl.active_slots) 

414 active_slots.add(str(context.run.id)) 

415 cl.active_slots = list(active_slots) 

416 

417 async def cleanup( 1a

418 self, 

419 initial_state: states.State[Any] | None, 

420 validated_state: states.State[Any] | None, 

421 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

422 ) -> None: 

423 lease_storage = get_concurrency_lease_storage() 

424 # Clean up V2 leases 

425 for lease_id in self._acquired_v2_lease_ids: 

426 try: 

427 lease = await lease_storage.read_lease( 

428 lease_id=lease_id, 

429 ) 

430 if lease: 

431 await concurrency_limits_v2.bulk_decrement_active_slots( 

432 session=context.session, 

433 concurrency_limit_ids=lease.resource_ids, 

434 slots=lease.metadata.slots if lease.metadata else 1, 

435 ) 

436 await lease_storage.revoke_lease( 

437 lease_id=lease.id, 

438 ) 

439 else: 

440 logger.warning(f"Lease {lease_id} not found during cleanup") 

441 except Exception: 

442 logger.warning(f"Failed to clean up lease {lease_id}", exc_info=True) 

443 

444 for tag in self._applied_limits: 

445 cl = await concurrency_limits.read_concurrency_limit_by_tag( 

446 context.session, tag 

447 ) 

448 if cl: 

449 active_slots = set(cl.active_slots) 

450 active_slots.discard(str(context.run.id)) 

451 cl.active_slots = list(active_slots) 

452 

453 
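# Editor's note (illustrative sketch, not part of the Prefect source): the rule above maps each
# task-run tag to a V2 concurrency limit name by prefixing it with "tag:" and strips that prefix
# again when reporting which tag caused a delay or abort. For example, assuming a run tagged "etl":
#
#     tag = "etl"
#     limit_name = f"tag:{tag}"            # "tag:etl" -- the V2 limit looked up by name
#     limit_name.removeprefix("tag:")      # "etl"     -- the tag shown in delay/abort reasons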

454class ReleaseTaskConcurrencySlots(TaskRunUniversalTransform): 1a

455 """ 

456 Releases any concurrency slots held by a run upon exiting a Running or 

457 Cancelling state. 

458 """ 

459 

460 async def after_transition( 1a

461 self, 

462 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

463 ) -> None: 

464 if self.nullified_transition(): 1bdc

465 return 

466 

467 if context.validated_state and context.validated_state.type not in [ 1bdc

468 states.StateType.RUNNING, 

469 states.StateType.CANCELLING, 

470 ]: 

471 v2_names = [f"tag:{tag}" for tag in context.run.tags] 1bdc

472 v2_limits = await concurrency_limits_v2.bulk_read_concurrency_limits( 1bdc

473 context.session, names=v2_names 

474 ) 

475 # Release V2 leases for this task run 

476 if v2_limits: 

477 lease_storage = get_concurrency_lease_storage() 

478 lease_ids_to_reconcile: set[UUID] = set() 

479 for v2_limit in v2_limits: 

480 # Find holders for this limit 

481 holders_with_leases: list[ 

482 tuple[UUID, ConcurrencyLeaseHolder] 

483 ] = await lease_storage.list_holders_for_limit( 

484 limit_id=v2_limit.id, 

485 ) 

486 # Find leases that belong to this task run 

487 for lease_id, holder in holders_with_leases: 

488 if holder.id == context.run.id: 

489 lease_ids_to_reconcile.add(lease_id) 

490 

491 # Reconcile all found leases 

492 for lease_id in lease_ids_to_reconcile: 

493 try: 

494 lease = await lease_storage.read_lease( 

495 lease_id=lease_id, 

496 ) 

497 if lease: 

498 await concurrency_limits_v2.bulk_decrement_active_slots( 

499 session=context.session, 

500 concurrency_limit_ids=lease.resource_ids, 

501 slots=lease.metadata.slots if lease.metadata else 1, 

502 ) 

503 await lease_storage.revoke_lease( 

504 lease_id=lease.id, 

505 ) 

506 else: 

507 logger.warning(f"Lease {lease_id} not found during release") 

508 except Exception: 

509 logger.warning( 

510 f"Failed to reconcile lease {lease_id} during release", 

511 exc_info=True, 

512 ) 

513 

514 v1_limits = ( 

515 await concurrency_limits.filter_concurrency_limits_for_orchestration( 

516 context.session, tags=context.run.tags 

517 ) 

518 ) 

519 for cl in v1_limits: 

520 active_slots = set(cl.active_slots) 

521 active_slots.discard(str(context.run.id)) 

522 cl.active_slots = list(active_slots) 

523 

524 

525class SecureFlowConcurrencySlots(FlowRunOrchestrationRule): 1a

526 """ 

527 Enforce deployment concurrency limits. 

528 

529 This rule enforces concurrency limits on deployments. If a deployment has a concurrency limit, 

530 this rule will prevent more than that number of flow runs from being submitted concurrently 

531 based on the concurrency limit behavior configured for the deployment. 

532 

533 We use the PENDING state as the target transition because this allows workers to secure a slot 

534 before provisioning dynamic infrastructure to run a flow. If a slot isn't available, the worker 

535 won't provision infrastructure. 

536 

537 A lease is created for the concurrency limit. The client will be responsible for maintaining the lease. 

538 """ 

539 

540 FROM_STATES = ALL_ORCHESTRATION_STATES - { 1a

541 states.StateType.PENDING, 

542 states.StateType.RUNNING, 

543 states.StateType.CANCELLING, 

544 } 

545 TO_STATES = {states.StateType.PENDING} 1a

546 

547 async def before_transition( 1a

548 self, 

549 initial_state: states.State[Any] | None, 

550 proposed_state: states.State[Any] | None, 

551 context: FlowOrchestrationContext, 

552 ) -> None: 

553 if ( 

554 not context.session 

555 or not context.run.deployment_id 

556 or not proposed_state 

557 or context.client_version 

558 in WORKER_VERSIONS_THAT_MANAGE_DEPLOYMENT_CONCURRENCY 

559 ): 

560 return 

561 

562 deployment = await deployments.read_deployment( 

563 session=context.session, 

564 deployment_id=context.run.deployment_id, 

565 ) 

566 if not deployment: 

567 await self.abort_transition("Deployment not found.") 

568 return 

569 

570 if ( 

571 not deployment.global_concurrency_limit 

572 or not deployment.concurrency_limit_id 

573 ): 

574 return 

575 

576 if deployment.global_concurrency_limit.limit == 0: 

577 await self.abort_transition( 

578 "The deployment concurrency limit is 0. The flow will deadlock if submitted again." 

579 ) 

580 return 

581 

582 acquired = await concurrency_limits_v2.bulk_increment_active_slots( 

583 session=context.session, 

584 concurrency_limit_ids=[deployment.concurrency_limit_id], 

585 slots=1, 

586 ) 

587 if acquired: 

588 lease_storage = get_concurrency_lease_storage() 

589 settings = get_current_settings() 

590 lease = await lease_storage.create_lease( 

591 resource_ids=[deployment.concurrency_limit_id], 

592 metadata=ConcurrencyLimitLeaseMetadata( 

593 slots=1, 

594 ), 

595 ttl=datetime.timedelta( 

596 seconds=settings.server.concurrency.initial_deployment_lease_duration 

597 ), 

598 ) 

599 proposed_state.state_details.deployment_concurrency_lease_id = lease.id 

600 

601 else: 

602 concurrency_options = ( 

603 deployment.concurrency_options 

604 or core.ConcurrencyOptions( 

605 collision_strategy=core.ConcurrencyLimitStrategy.ENQUEUE 

606 ) 

607 ) 

608 

609 if ( 

610 concurrency_options.collision_strategy 

611 == core.ConcurrencyLimitStrategy.ENQUEUE 

612 ): 

613 settings = get_current_settings() 

614 await self.reject_transition( 

615 state=states.Scheduled( 

616 name="AwaitingConcurrencySlot", 

617 scheduled_time=now("UTC") 

618 + datetime.timedelta( 

619 seconds=settings.server.deployments.concurrency_slot_wait_seconds 

620 ), 

621 ), 

622 reason="Deployment concurrency limit reached.", 

623 ) 

624 elif ( 

625 concurrency_options.collision_strategy 

626 == core.ConcurrencyLimitStrategy.CANCEL_NEW 

627 ): 

628 await self.reject_transition( 

629 state=states.Cancelled( 

630 message="Deployment concurrency limit reached." 

631 ), 

632 reason="Deployment concurrency limit reached.", 

633 ) 

634 

635 async def cleanup( # type: ignore 1a

636 self, 

637 initial_state: states.State[Any] | None, 

638 validated_state: states.State[Any] | None, 

639 context: FlowOrchestrationContext, 

640 ) -> None: 

641 logger = get_logger() 

642 if not context.session or not context.run.deployment_id: 

643 return 

644 

645 try: 

646 deployment = await deployments.read_deployment( 

647 session=context.session, 

648 deployment_id=context.run.deployment_id, 

649 ) 

650 

651 if not deployment or not deployment.concurrency_limit_id: 

652 return 

653 

654 await concurrency_limits_v2.bulk_decrement_active_slots( 

655 session=context.session, 

656 concurrency_limit_ids=[deployment.concurrency_limit_id], 

657 slots=1, 

658 ) 

659 if ( 

660 validated_state 

661 and validated_state.state_details.deployment_concurrency_lease_id 

662 ): 

663 lease_storage = get_concurrency_lease_storage() 

664 await lease_storage.revoke_lease( 

665 lease_id=validated_state.state_details.deployment_concurrency_lease_id, 

666 ) 

667 validated_state.state_details.deployment_concurrency_lease_id = None 

668 

669 except Exception as e: 

670 logger.error(f"Error releasing concurrency slots on cleanup: {e}") 

671 

672 

673class CopyDeploymentConcurrencyLeaseID(FlowRunOrchestrationRule): 1a

674 """ 

675 Copies the deployment concurrency lease ID to the proposed state. 

676 """ 

677 

678 FROM_STATES = {states.StateType.PENDING} 1a

679 TO_STATES = {states.StateType.RUNNING} 1a

680 

681 async def before_transition( 1a

682 self, 

683 initial_state: states.State[Any] | None, 

684 proposed_state: states.State[Any] | None, 

685 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

686 ) -> None: 

687 if initial_state is None or proposed_state is None: 

688 return 

689 

690 if not proposed_state.state_details.deployment_concurrency_lease_id: 

691 proposed_state.state_details.deployment_concurrency_lease_id = ( 

692 initial_state.state_details.deployment_concurrency_lease_id 

693 ) 

694 

695 

696class RemoveDeploymentConcurrencyLeaseForOldClientVersions(FlowRunOrchestrationRule): 1a

697 """ 

698 Removes a deployment concurrency lease if the client version is less than the minimum version for leasing. 

699 """ 

700 

701 FROM_STATES = {states.StateType.PENDING} 1a

702 TO_STATES = {states.StateType.RUNNING, states.StateType.CANCELLING} 1a

703 

704 async def after_transition( 1a

705 self, 

706 initial_state: states.State[Any] | None, 

707 validated_state: states.State[Any] | None, 

708 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

709 ) -> None: 

710 if not initial_state or ( 

711 context.client_version 

712 and Version(context.client_version) 

713 >= MIN_CLIENT_VERSION_FOR_CONCURRENCY_LIMIT_LEASING 

714 ): 

715 return 

716 

717 if lease_id := initial_state.state_details.deployment_concurrency_lease_id: 

718 lease_storage = get_concurrency_lease_storage() 

719 await lease_storage.revoke_lease( 

720 lease_id=lease_id, 

721 ) 

722 

723 

724class ReleaseFlowConcurrencySlots(FlowRunUniversalTransform): 1a

725 """ 

726 Releases deployment concurrency slots held by a flow run. 

727 

728 This rule releases a concurrency slot for a deployment when a flow run 

729 transitions out of the Running or Cancelling state. 

730 """ 

731 

732 async def after_transition( 1a

733 self, 

734 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

735 ) -> None: 

736 if self.nullified_transition(): 1bc

737 return 

738 

739 initial_state_type = ( 1bc

740 context.initial_state.type if context.initial_state else None 

741 ) 

742 proposed_state_type = ( 1bc

743 context.proposed_state.type if context.proposed_state else None 

744 ) 

745 

746 # Check if the transition is valid for releasing concurrency slots. 

747 # This should happen within `after_transition` because BaseUniversalTransforms 

748 # don't know how to "fizzle" themselves if they encounter a transition that 

749 # shouldn't apply to them, even if they use FROM_STATES and TO_STATES. 

750 if not (

751 initial_state_type 

752 in { 

753 states.StateType.RUNNING, 

754 states.StateType.CANCELLING, 

755 states.StateType.PENDING, 

756 } 

757 and proposed_state_type 

758 not in { 

759 states.StateType.PENDING, 

760 states.StateType.RUNNING, 

761 states.StateType.CANCELLING, 

762 } 

763 ): 

764 return 1bc

765 if not context.session or not context.run.deployment_id: 

766 return 

767 

768 lease_storage = get_concurrency_lease_storage() 

769 if ( 

770 context.initial_state 

771 and context.initial_state.state_details.deployment_concurrency_lease_id 

772 and ( 

773 lease := await lease_storage.read_lease( 

774 lease_id=context.initial_state.state_details.deployment_concurrency_lease_id, 

775 ) 

776 ) 

777 and lease.metadata 

778 ): 

779 await concurrency_limits_v2.bulk_decrement_active_slots( 

780 session=context.session, 

781 concurrency_limit_ids=lease.resource_ids, 

782 slots=lease.metadata.slots, 

783 ) 

784 await lease_storage.revoke_lease( 

785 lease_id=lease.id, 

786 ) 

787 else: 

788 deployment = await deployments.read_deployment( 

789 session=context.session, 

790 deployment_id=context.run.deployment_id, 

791 ) 

792 if not deployment or not deployment.concurrency_limit_id: 

793 return 

794 

795 await concurrency_limits_v2.bulk_decrement_active_slots( 

796 session=context.session, 

797 concurrency_limit_ids=[deployment.concurrency_limit_id], 

798 slots=1, 

799 ) 

800 

801 
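# Editor's note (illustrative sketch, not part of the Prefect source): the check above only
# releases deployment slots when a run moves out of the states that occupy a slot. A minimal
# standalone predicate mirroring that condition might look like this; the helper name
# `should_release_slot` is an assumption for illustration.
#
#     OCCUPYING = {StateType.PENDING, StateType.RUNNING, StateType.CANCELLING}
#
#     def should_release_slot(initial_state_type, proposed_state_type) -> bool:
#         # Release only when moving from an occupying state to a non-occupying one.
#         return initial_state_type in OCCUPYING and proposed_state_type not in OCCUPYING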

802class CacheInsertion(TaskRunOrchestrationRule): 1a

803 """ 

804 Caches completed states with cache keys after they are validated. 

805 """ 

806 

807 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

808 TO_STATES = {StateType.COMPLETED} 1a

809 

810 async def before_transition( 1a

811 self, 

812 initial_state: states.State[Any] | None, 

813 proposed_state: states.State[Any] | None, 

814 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

815 ) -> None: 

816 if proposed_state is None: 

817 return 

818 

819 settings = get_current_settings() 

820 cache_key = proposed_state.state_details.cache_key 

821 if cache_key and len(cache_key) > settings.server.tasks.max_cache_key_length: 

822 await self.reject_transition( 

823 state=proposed_state, 

824 reason=f"Cache key exceeded maximum allowed length of {settings.server.tasks.max_cache_key_length} characters.", 

825 ) 

826 return 

827 

828 @db_injector 1a

829 async def after_transition( 1a

830 self, 

831 db: PrefectDBInterface, 

832 initial_state: states.State[Any] | None, 

833 validated_state: states.State[Any] | None, 

834 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

835 ) -> None: 

836 if not validated_state or not context.session: 

837 return 

838 

839 cache_key = validated_state.state_details.cache_key 

840 if cache_key: 

841 new_cache_item = db.TaskRunStateCache( 

842 cache_key=cache_key, 

843 cache_expiration=validated_state.state_details.cache_expiration, 

844 task_run_state_id=validated_state.id, 

845 ) 

846 context.session.add(new_cache_item) 

847 

848 

849class CacheRetrieval(TaskRunOrchestrationRule): 1a

850 """ 

851 Rejects running states if a completed state has been cached. 

852 

853 This rule rejects transitions into a running state with a cache key if the key 

854 has already been associated with a completed state in the cache table. The client 

855 will be instructed to transition into the cached completed state instead. 

856 """ 

857 

858 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

859 TO_STATES = {StateType.RUNNING} 1a

860 

861 @db_injector 1a

862 async def before_transition( 1a

863 self, 

864 db: PrefectDBInterface, 

865 initial_state: states.State[Any] | None, 

866 proposed_state: states.State[Any] | None, 

867 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

868 ) -> None: 

869 if not proposed_state: 

870 return 

871 

872 cache_key = proposed_state.state_details.cache_key 

873 if cache_key and not proposed_state.state_details.refresh_cache: 

874 # Check for cached states matching the cache key 

875 cached_state_id = ( 

876 select(db.TaskRunStateCache.task_run_state_id) 

877 .where( 

878 sa.and_( 

879 db.TaskRunStateCache.cache_key == cache_key, 

880 sa.or_( 

881 db.TaskRunStateCache.cache_expiration.is_(None), 

882 db.TaskRunStateCache.cache_expiration > now("UTC"), 

883 ), 

884 ), 

885 ) 

886 .order_by(db.TaskRunStateCache.created.desc()) 

887 .limit(1) 

888 ).scalar_subquery() 

889 query = select(db.TaskRunState).where(db.TaskRunState.id == cached_state_id) 

890 cached_state = (await context.session.execute(query)).scalar() 

891 if cached_state: 

892 new_state = cached_state.as_state().fresh_copy() 

893 new_state.name = "Cached" 

894 await self.reject_transition( 

895 state=new_state, reason="Retrieved state from cache" 

896 ) 

897 

898 

899class RetryFailedFlows(FlowRunOrchestrationRule): 1a

900 """ 

901 Rejects failed states and schedules a retry if the retry limit has not been reached. 

902 

903 This rule rejects transitions into a failed state if `retries` has been 

904 set and the run count has not reached the specified limit. The client will be 

905 instructed to transition into a scheduled state to retry flow execution. 

906 """ 

907 

908 FROM_STATES = {StateType.RUNNING} 1a

909 TO_STATES = {StateType.FAILED} 1a

910 

911 async def before_transition( 1a

912 self, 

913 initial_state: states.State[Any] | None, 

914 proposed_state: states.State[Any] | None, 

915 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

916 ) -> None: 

917 if initial_state is None or proposed_state is None: 

918 return 

919 

920 run_settings = context.run_settings 

921 run_count = context.run.run_count 

922 

923 if run_settings.retries is None or run_count > run_settings.retries: 

924 # Clear retry type to allow for future infrastructure level retries (e.g. via the UI) 

925 updated_policy = context.run.empirical_policy.model_dump() 

926 updated_policy["retry_type"] = None 

927 context.run.empirical_policy = core.FlowRunPolicy(**updated_policy) 

928 

929 return # Retry count exceeded, allow transition to failed 

930 

931 scheduled_start_time = now("UTC") + datetime.timedelta( 

932 seconds=run_settings.retry_delay or 0 

933 ) 

934 

935 # support old-style flow run retries for older clients 

936 # older flow retries require us to loop over failed tasks to update their state 

937 # this is not required after API version 0.8.3 

938 api_version = context.parameters.get("api-version", None) 

939 if api_version and api_version < Version("0.8.3"): 

940 failed_task_runs = await models.task_runs.read_task_runs( 

941 context.session, 

942 flow_run_filter=filters.FlowRunFilter( 

943 id=filters.FlowRunFilterId(any_=[context.run.id]) 

944 ), 

945 task_run_filter=filters.TaskRunFilter( 

946 state=filters.TaskRunFilterState( 

947 type=filters.TaskRunFilterStateType(any_=[StateType.FAILED]) 

948 ) 

949 ), 

950 ) 

951 for run in failed_task_runs: 

952 await models.task_runs.set_task_run_state( 

953 context.session, 

954 run.id, 

955 state=states.AwaitingRetry(scheduled_time=scheduled_start_time), 

956 force=True, 

957 ) 

958 # Reset the run count so that the task run retries still work correctly 

959 run.run_count = 0 

960 

961 # Reset pause metadata on retry 

962 # Pauses as a concept only exist after API version 0.8.4 

963 api_version = context.parameters.get("api-version", None) 

964 if api_version is None or api_version >= Version("0.8.4"): 

965 updated_policy = context.run.empirical_policy.model_dump() 

966 updated_policy["resuming"] = False 

967 updated_policy["pause_keys"] = set() 

968 updated_policy["retry_type"] = "in_process" 

969 context.run.empirical_policy = core.FlowRunPolicy(**updated_policy) 

970 

971 # Generate a new state for the flow 

972 retry_state = states.AwaitingRetry( 

973 scheduled_time=scheduled_start_time, 

974 message=proposed_state.message, 

975 data=proposed_state.data, 

976 ) 

977 await self.reject_transition(state=retry_state, reason="Retrying") 

978 

979 

980class RetryFailedTasks(TaskRunOrchestrationRule): 1a

981 """ 

982 Rejects failed states and schedules a retry if the retry limit has not been reached. 

983 

984 This rule rejects transitions into a failed state if `retries` has been 

985 set, the run count has not reached the specified limit, and the client 

986 asserts it is a retriable task run. The client will be instructed to 

987 transition into a scheduled state to retry task execution. 

988 """ 

989 

990 FROM_STATES = {StateType.RUNNING} 1a

991 TO_STATES = {StateType.FAILED} 1a

992 

993 async def before_transition( 1a

994 self, 

995 initial_state: states.State[Any] | None, 

996 proposed_state: states.State[Any] | None, 

997 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

998 ) -> None: 

999 if initial_state is None or proposed_state is None: 

1000 return 

1001 

1002 run_settings = context.run_settings 

1003 run_count = context.run.run_count 

1004 delay = run_settings.retry_delay 

1005 

1006 if isinstance(delay, list): 

1007 base_delay = delay[min(run_count - 1, len(delay) - 1)] 

1008 else: 

1009 base_delay = delay or 0 

1010 

1011 # guard against negative relative jitter inputs 

1012 if run_settings.retry_jitter_factor: 

1013 delay = clamped_poisson_interval( 

1014 base_delay, clamping_factor=run_settings.retry_jitter_factor 

1015 ) 

1016 else: 

1017 delay = base_delay 

1018 

1019 # set by user to conditionally retry a task using @task(retry_condition_fn=...) 

1020 if getattr(proposed_state.state_details, "retriable", True) is False: 

1021 return 

1022 

1023 if run_settings.retries is not None and run_count <= run_settings.retries: 

1024 retry_state = states.AwaitingRetry( 

1025 scheduled_time=now("UTC") + datetime.timedelta(seconds=delay), 

1026 message=proposed_state.message, 

1027 data=proposed_state.data, 

1028 ) 

1029 await self.reject_transition(state=retry_state, reason="Retrying") 

1030 

1031 
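# Editor's note (illustrative sketch, not part of the Prefect source): the retry delay above
# accepts either a single number or a per-attempt list. With a list, the current attempt indexes
# into it and the last entry is reused once attempts run past the end; when a jitter factor is
# set, the delay is then spread via clamped_poisson_interval. Assumed example values:
#
#     retry_delay = [1, 10, 60]   # seconds for the first, second, and later retries
#     run_count = 5               # fifth attempt
#     base_delay = retry_delay[min(run_count - 1, len(retry_delay) - 1)]   # -> 60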

1032class EnqueueScheduledTasks(TaskRunOrchestrationRule): 1a

1033 """ 

1034 Enqueues background task runs when they are scheduled 

1035 """ 

1036 

1037 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

1038 TO_STATES = {StateType.SCHEDULED} 1a

1039 

1040 async def after_transition( 1a

1041 self, 

1042 initial_state: states.State[Any] | None, 

1043 validated_state: states.State[Any] | None, 

1044 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1045 ) -> None: 

1046 if not validated_state: 

1047 # Only if the transition was valid 

1048 return 

1049 

1050 if not validated_state.state_details.deferred: 

1051 # Only for tasks that are deferred 

1052 return 

1053 

1054 task_run: core.TaskRun = core.TaskRun.model_validate(context.run) 

1055 queue: TaskQueue = TaskQueue.for_key(task_run.task_key) 

1056 

1057 if validated_state.name == "AwaitingRetry": 

1058 await queue.retry(task_run) 

1059 else: 

1060 await queue.enqueue(task_run) 

1061 

1062 

1063class RenameReruns(GenericOrchestrationRule): 1a

1064 """ 

1065 Renames states if they have run more than once.

1066 

1067 In the special case where the initial state is an "AwaitingRetry" scheduled state, 

1068 the proposed state will be renamed to "Retrying" instead. 

1069 """ 

1070 

1071 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

1072 TO_STATES = {StateType.RUNNING} 1a

1073 

1074 async def before_transition( 1a

1075 self, 

1076 initial_state: states.State[Any] | None, 

1077 proposed_state: states.State[Any] | None, 

1078 context: OrchestrationContext[ 

1079 orm_models.Run, core.TaskRunPolicy | core.FlowRunPolicy 

1080 ], 

1081 ) -> None: 

1082 if initial_state is None or proposed_state is None: 

1083 return 

1084 

1085 run_count = context.run.run_count 

1086 if run_count > 0: 

1087 if initial_state.name == "AwaitingRetry": 

1088 await self.rename_state("Retrying") 

1089 else: 

1090 await self.rename_state("Rerunning") 

1091 

1092 

1093class CopyScheduledTime( 1a

1094 BaseOrchestrationRule[orm_models.Run, Union[core.TaskRunPolicy, core.FlowRunPolicy]] 

1095): 

1096 """ 

1097 Ensures scheduled time is copied from scheduled states to pending states. 

1098 

1099 If a new scheduled time has been proposed on the pending state, the scheduled time 

1100 on the scheduled state will be ignored. 

1101 """ 

1102 

1103 FROM_STATES = {StateType.SCHEDULED} 1a

1104 TO_STATES = {StateType.PENDING} 1a

1105 

1106 async def before_transition( 1a

1107 self, 

1108 initial_state: states.State[Any] | None, 

1109 proposed_state: states.State[Any] | None, 

1110 context: OrchestrationContext[ 

1111 orm_models.Run, core.TaskRunPolicy | core.FlowRunPolicy 

1112 ], 

1113 ) -> None: 

1114 if initial_state is None or proposed_state is None: 

1115 return 

1116 

1117 if not proposed_state.state_details.scheduled_time: 

1118 proposed_state.state_details.scheduled_time = ( 

1119 initial_state.state_details.scheduled_time 

1120 ) 

1121 

1122 

1123class WaitForScheduledTime( 1a

1124 BaseOrchestrationRule[orm_models.Run, Union[core.TaskRunPolicy, core.FlowRunPolicy]] 

1125): 

1126 """ 

1127 Prevents transitions to running states from happening too early. 

1128 

1129 This rule enforces that scheduled states only start according to the machine clock

1130 used by the Prefect REST API instance. This rule will identify transitions from scheduled

1131 states that are too early and nullify them. Instead, no state will be written to the 

1132 database and the client will be sent an instruction to wait for `delay_seconds` 

1133 before attempting the transition again. 

1134 """ 

1135 

1136 FROM_STATES = {StateType.SCHEDULED, StateType.PENDING} 1a

1137 TO_STATES = {StateType.RUNNING} 1a

1138 

1139 async def before_transition( 1a

1140 self, 

1141 initial_state: states.State[Any] | None, 

1142 proposed_state: states.State[Any] | None, 

1143 context: OrchestrationContext[ 

1144 orm_models.Run, core.TaskRunPolicy | core.FlowRunPolicy 

1145 ], 

1146 ) -> None: 

1147 if initial_state is None or proposed_state is None: 

1148 return 

1149 

1150 scheduled_time = initial_state.state_details.scheduled_time 

1151 if not scheduled_time: 

1152 return 

1153 

1154 # At this moment, we round delay to the nearest second as the API schema 

1155 # specifies an integer return value. 

1156 delay = scheduled_time - now("UTC") 

1157 delay_seconds = math.floor(delay.total_seconds()) 

1158 delay_seconds += round(delay.microseconds / 1e6) 

1159 if delay_seconds > 0: 

1160 await self.delay_transition( 

1161 delay_seconds, reason="Scheduled time is in the future" 

1162 ) 

1163 

1164 
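# Editor's note (illustrative sketch, not part of the Prefect source): the delay above is reported
# as a whole number of seconds because the API schema expects an integer. Roughly:
#
#     delay = scheduled_time - now("UTC")                  # a datetime.timedelta
#     delay_seconds = math.floor(delay.total_seconds())    # whole seconds until the scheduled time
#     delay_seconds += round(delay.microseconds / 1e6)     # round the fractional second up or down
#     # A positive result instructs the client to wait; zero or negative lets the run proceed.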

1165class CopyTaskParametersID(TaskRunOrchestrationRule): 1a

1166 """ 

1167 Ensures a task's parameters ID is copied from Scheduled to Pending and from 

1168 Pending to Running states. 

1169 

1170 If a parameters ID has been included on the proposed state, the parameters ID 

1171 on the initial state will be ignored. 

1172 """ 

1173 

1174 FROM_STATES = {StateType.SCHEDULED, StateType.PENDING} 1a

1175 TO_STATES = {StateType.PENDING, StateType.RUNNING} 1a

1176 

1177 async def before_transition( 1a

1178 self, 

1179 initial_state: states.State[Any] | None, 

1180 proposed_state: states.State[Any] | None, 

1181 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1182 ) -> None: 

1183 if initial_state is None or proposed_state is None: 

1184 return 

1185 

1186 if not proposed_state.state_details.task_parameters_id: 

1187 proposed_state.state_details.task_parameters_id = ( 

1188 initial_state.state_details.task_parameters_id 

1189 ) 

1190 

1191 

1192class HandlePausingFlows(FlowRunOrchestrationRule): 1a

1193 """ 

1194 Governs runs attempting to enter a Paused/Suspended state 

1195 """ 

1196 

1197 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

1198 TO_STATES = {StateType.PAUSED} 1a

1199 

1200 async def before_transition( 1a

1201 self, 

1202 initial_state: states.State[Any] | None, 

1203 proposed_state: states.State[Any] | None, 

1204 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1205 ) -> None: 

1206 if proposed_state is None: 

1207 return 

1208 

1209 verb = "suspend" if proposed_state.name == "Suspended" else "pause" 

1210 

1211 if initial_state is None: 

1212 await self.abort_transition(f"Cannot {verb} flows with no state.") 

1213 return 

1214 

1215 if not initial_state.is_running(): 

1216 await self.reject_transition( 

1217 state=None, 

1218 reason=f"Cannot {verb} flows that are not currently running.", 

1219 ) 

1220 return 

1221 

1222 self.key = proposed_state.state_details.pause_key 

1223 if self.key is None: 

1224 # if no pause key is provided, default to a UUID 

1225 self.key = str(uuid4()) 

1226 

1227 pause_keys = context.run.empirical_policy.pause_keys or set() 

1228 if self.key in pause_keys: 

1229 await self.reject_transition( 

1230 state=None, reason=f"This {verb} has already fired." 

1231 ) 

1232 return 

1233 

1234 if proposed_state.state_details.pause_reschedule: 

1235 if context.run.parent_task_run_id: 

1236 await self.abort_transition( 

1237 reason=f"Cannot {verb} subflows.", 

1238 ) 

1239 return 

1240 

1241 if context.run.deployment_id is None: 

1242 await self.abort_transition( 

1243 reason=f"Cannot {verb} flows without a deployment.", 

1244 ) 

1245 return 

1246 

1247 async def after_transition( 1a

1248 self, 

1249 initial_state: states.State[Any] | None, 

1250 validated_state: states.State[Any] | None, 

1251 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1252 ) -> None: 

1253 updated_policy = context.run.empirical_policy.model_dump() 

1254 updated_policy["pause_keys"].add(self.key) 

1255 context.run.empirical_policy = core.FlowRunPolicy(**updated_policy) 

1256 

1257 

1258class HandleResumingPausedFlows(FlowRunOrchestrationRule): 1a

1259 """ 

1260 Governs runs attempting to leave a Paused state 

1261 """ 

1262 

1263 FROM_STATES = {StateType.PAUSED} 1a

1264 TO_STATES = ALL_ORCHESTRATION_STATES 1a

1265 

1266 async def before_transition( 1a

1267 self, 

1268 initial_state: states.State[Any] | None, 

1269 proposed_state: states.State[Any] | None, 

1270 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1271 ) -> None: 

1272 if initial_state is None or proposed_state is None: 

1273 return 

1274 

1275 if not ( 

1276 proposed_state 

1277 and ( 

1278 proposed_state.is_running() 

1279 or proposed_state.is_scheduled() 

1280 or proposed_state.is_final() 

1281 ) 

1282 ): 

1283 await self.reject_transition( 

1284 state=None, 

1285 reason=( 

1286 f"This run cannot transition to the {proposed_state.type} state" 

1287 f" from the {initial_state.type} state." 

1288 ), 

1289 ) 

1290 return 

1291 

1292 verb = "suspend" if proposed_state.name == "Suspended" else "pause" 

1293 

1294 display_state_name = ( 

1295 proposed_state.name.lower() 

1296 if proposed_state.name 

1297 else proposed_state.type.value.lower() 

1298 ) 

1299 

1300 if initial_state.state_details.pause_reschedule: 

1301 if not context.run.deployment_id: 

1302 await self.reject_transition( 

1303 state=None, 

1304 reason=( 

1305 f"Cannot reschedule a {display_state_name} flow run" 

1306 " without a deployment." 

1307 ), 

1308 ) 

1309 return 

1310 pause_timeout = initial_state.state_details.pause_timeout 

1311 if pause_timeout and pause_timeout < now("UTC"): 

1312 pause_timeout_failure = states.Failed( 

1313 message=(f"The flow was {display_state_name} and never resumed."), 

1314 ) 

1315 await self.reject_transition( 

1316 state=pause_timeout_failure, 

1317 reason=f"The flow run {verb} has timed out and can no longer resume.", 

1318 ) 

1319 return 

1320 

1321 async def after_transition( 1a

1322 self, 

1323 initial_state: states.State[Any] | None, 

1324 validated_state: states.State[Any] | None, 

1325 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1326 ) -> None: 

1327 updated_policy = context.run.empirical_policy.model_dump() 

1328 updated_policy["resuming"] = True 

1329 context.run.empirical_policy = core.FlowRunPolicy(**updated_policy) 

1330 

1331 

1332class UpdateFlowRunTrackerOnTasks(TaskRunOrchestrationRule): 1a

1333 """ 

1334 Tracks the flow run attempt a task run state is associated with. 

1335 """ 

1336 

1337 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

1338 TO_STATES = {StateType.RUNNING} 1a

1339 

1340 async def after_transition( 1a

1341 self, 

1342 initial_state: states.State[Any] | None, 

1343 validated_state: states.State[Any] | None, 

1344 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1345 ) -> None: 

1346 if context.run.flow_run_id is not None: 

1347 self.flow_run: orm_models.FlowRun | None = await context.flow_run() 

1348 if self.flow_run: 

1349 context.run.flow_run_run_count = self.flow_run.run_count 

1350 else: 

1351 raise ObjectNotFoundError( 

1352 ( 

1353 "Unable to read flow run associated with task run:" 

1354 f" {context.run.id}, this flow run might have been deleted" 

1355 ), 

1356 ) 

1357 

1358 

1359class HandleTaskTerminalStateTransitions(TaskRunOrchestrationRule): 1a

1360 """ 

1361 We do not allow tasks to leave terminal states if: 

1362 - The task is completed and has a persisted result 

1363 - The task is going to CANCELLING / PAUSED / CRASHED 

1364 

1365 We reset the run count when a task leaves a terminal state for a non-terminal state 

1366 which resets task run retries; this is particularly relevant for flow run retries. 

1367 """ 

1368 

1369 FROM_STATES: set[states.StateType | None] = TERMINAL_STATES # pyright: ignore[reportAssignmentType] technically TERMINAL_STATES doesn't contain None 1a

1370 TO_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES 1a

1371 

1372 async def before_transition( 1a

1373 self, 

1374 initial_state: states.State[Any] | None, 

1375 proposed_state: states.State[Any] | None, 

1376 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1377 ) -> None: 

1378 if initial_state is None or proposed_state is None: 

1379 return 

1380 

1381 self.original_run_count: int = context.run.run_count 

1382 

1383 # Do not allow runs to be marked as crashed, paused, or cancelling if already terminal 

1384 if proposed_state.type in { 

1385 StateType.CANCELLING, 

1386 StateType.PAUSED, 

1387 StateType.CRASHED, 

1388 }: 

1389 await self.abort_transition(f"Run is already {initial_state.type.value}.") 

1390 return 

1391 

1392 # Only allow departure from a happily completed state if the result is not persisted 

1393 if ( 

1394 initial_state.is_completed() 

1395 and initial_state.data 

1396 and initial_state.data.get("type") != "unpersisted" 

1397 ): 

1398 await self.reject_transition(None, "This run is already completed.") 

1399 return 

1400 

1401 if not proposed_state.is_final(): 

1402 # Reset run count to reset retries 

1403 context.run.run_count = 0 

1404 

1405 # Change the name of the state to Retrying if it's a flow run retry

1406 if proposed_state.is_running() and context.run.flow_run_id is not None: 

1407 self.flow_run: orm_models.FlowRun | None = await context.flow_run() 

1408 if self.flow_run is not None: 

1409 flow_retrying = context.run.flow_run_run_count < self.flow_run.run_count 

1410 if flow_retrying: 

1411 await self.rename_state("Retrying") 

1412 

1413 async def cleanup( 1a

1414 self, 

1415 initial_state: states.State[Any] | None, 

1416 validated_state: states.State[Any] | None, 

1417 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1418 ) -> None: 

1419 # reset run count 

1420 context.run.run_count = self.original_run_count 

1421 

1422 

1423class HandleFlowTerminalStateTransitions(FlowRunOrchestrationRule): 1a

1424 """ 

1425 We do not allow flows to leave terminal states if: 

1426 - The flow is completed and has a persisted result 

1427 - The flow is going to CANCELLING / PAUSED / CRASHED 

1428 - The flow is going to scheduled and has no deployment 

1429 

1430 We reset the pause metadata when a flow leaves a terminal state for a non-terminal 

1431 state. This resets pause behavior during manual flow run retries. 

1432 """ 

1433 

1434 FROM_STATES: set[states.StateType | None] = TERMINAL_STATES # pyright: ignore[reportAssignmentType] technically TERMINAL_STATES doesn't contain None 1a

1435 TO_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES 1a

1436 

1437 async def before_transition( 1a

1438 self, 

1439 initial_state: states.State[Any] | None, 

1440 proposed_state: states.State[Any] | None, 

1441 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1442 ) -> None: 

1443 if initial_state is None or proposed_state is None: 

1444 return 

1445 

1446 self.original_flow_policy: dict[str, Any] = ( 

1447 context.run.empirical_policy.model_dump() 

1448 ) 

1449 

1450 # Do not allow runs to be marked as crashed, paused, or cancelling if already terminal 

1451 if proposed_state.type in { 

1452 StateType.CANCELLING, 

1453 StateType.PAUSED, 

1454 StateType.CRASHED, 

1455 }: 

1456 await self.abort_transition( 

1457 f"Run is already in terminal state {initial_state.type.value}." 

1458 ) 

1459 return 

1460 

1461 # Only allow departure from a happily completed state if the result is not 

1462 # persisted and a rerun is being proposed

1463 if ( 

1464 initial_state.is_completed() 

1465 and not proposed_state.is_final() 

1466 and initial_state.data 

1467 and initial_state.data.get("type") != "unpersisted" 

1468 ): 

1469 await self.reject_transition(None, "Run is already COMPLETED.") 

1470 return 

1471 

1472 # Do not allow runs to be rescheduled without a deployment

1473 if proposed_state.is_scheduled() and not context.run.deployment_id: 

1474 await self.abort_transition( 

1475 "Cannot reschedule a run without an associated deployment." 

1476 ) 

1477 return 

1478 

1479 if not proposed_state.is_final(): 

1480 # Reset pause metadata when leaving a terminal state 

1481 api_version = context.parameters.get("api-version", None) 

1482 if api_version is None or api_version >= Version("0.8.4"): 

1483 updated_policy = context.run.empirical_policy.model_dump() 

1484 updated_policy["resuming"] = False 

1485 updated_policy["pause_keys"] = set() 

1486 if proposed_state.is_scheduled(): 

1487 updated_policy["retry_type"] = "reschedule" 

1488 else: 

1489 updated_policy["retry_type"] = None 

1490 context.run.empirical_policy = core.FlowRunPolicy(**updated_policy) 

1491 

1492 async def cleanup( 1a

1493 self, 

1494 initial_state: states.State[Any] | None, 

1495 validated_state: states.State[Any] | None, 

1496 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1497 ) -> None: 

1498 context.run.empirical_policy = core.FlowRunPolicy(**self.original_flow_policy) 

1499 

1500 

1501class PreventPendingTransitions(GenericOrchestrationRule): 1a

1502 """ 

1503 Prevents transitions to PENDING. 

1504 

1505 This rule is only used for flow runs. 

1506 

1507 This is intended to prevent race conditions during duplicate submissions of runs. 

1508 Before a run is submitted to its execution environment, it should be placed in a 

1509 PENDING state. If two workers attempt to submit the same run, one of them should 

1510 encounter a PENDING -> PENDING transition and abort orchestration of the run. 

1511 

1512 Similarly, if the execution environment starts quickly the run may be in a RUNNING 

1513 state when the second worker attempts the PENDING transition. We deny these state 

1514 changes as well to prevent duplicate submission. If a run has transitioned to a 

1515 RUNNING state a worker should not attempt to submit it again unless it has moved 

1516 into a terminal state. 

1517 

1518 CANCELLING and CANCELLED runs should not be allowed to transition to PENDING. 

1519 For re-runs of deployed runs, they should transition to SCHEDULED first. 

1520 For re-runs of ad-hoc runs, they should transition directly to RUNNING. 

1521 """ 

1522 

1523 FROM_STATES = { 1a

1524 StateType.PENDING, 

1525 StateType.CANCELLING, 

1526 StateType.RUNNING, 

1527 StateType.CANCELLED, 

1528 } 

1529 TO_STATES = {StateType.PENDING} 1a

1530 

1531 async def before_transition( 1a

1532 self, 

1533 initial_state: states.State[Any] | None, 

1534 proposed_state: states.State[Any] | None, 

1535 context: OrchestrationContext[ 

1536 orm_models.Run, Union[core.FlowRunPolicy, core.TaskRunPolicy] 

1537 ], 

1538 ) -> None: 

1539 if initial_state is None or proposed_state is None: 

1540 return 

1541 

1542 await self.abort_transition( 

1543 reason=( 

1544 f"This run is in a {initial_state.type.name} state and cannot" 

1545 " transition to a PENDING state." 

1546 ) 

1547 ) 

1548 

1549 

1550class EnsureOnlyScheduledFlowsMarkedLate(FlowRunOrchestrationRule): 1a

1551 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

1552 TO_STATES = {StateType.SCHEDULED} 1a

1553 

1554 async def before_transition( 1a

1555 self, 

1556 initial_state: states.State[Any] | None, 

1557 proposed_state: states.State[Any] | None, 

1558 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1559 ) -> None: 

1560 if initial_state is None or proposed_state is None: 

1561 return 

1562 

1563 marking_flow_late = ( 

1564 proposed_state.is_scheduled() and proposed_state.name == "Late" 

1565 ) 

1566 if marking_flow_late and not initial_state.is_scheduled(): 

1567 await self.reject_transition( 

1568 state=None, reason="Only scheduled flows can be marked late." 

1569 ) 

1570 

1571 

1572class PreventRunningTasksFromStoppedFlows(TaskRunOrchestrationRule): 1a

1573 """ 

1574 Prevents running tasks from stopped flows. 

1575 

1576 A running state implies execution, but also the converse. This rule ensures that a 

1577 flow's tasks cannot be run unless the flow is also running. 

1578 """ 

1579 

1580 FROM_STATES = ALL_ORCHESTRATION_STATES 1a

1581 TO_STATES = {StateType.RUNNING} 1a

1582 

1583 async def before_transition( 1a

1584 self, 

1585 initial_state: states.State[Any] | None, 

1586 proposed_state: states.State[Any] | None, 

1587 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1588 ) -> None: 

1589 flow_run = await context.flow_run() 

1590 if flow_run is not None: 

1591 if flow_run.state is None: 

1592 await self.abort_transition( 

1593 reason="The enclosing flow must be running to begin task execution." 

1594 ) 

1595 elif flow_run.state.type == StateType.PAUSED: 

1596 # Use the flow run's Paused state details to preserve data like 

1597 # timeouts. 

1598 paused_state = states.Paused( 

1599 name="NotReady", 

1600 pause_expiration_time=flow_run.state.state_details.pause_timeout, 

1601 reschedule=flow_run.state.state_details.pause_reschedule, 

1602 ) 

1603 await self.reject_transition( 

1604 state=paused_state, 

1605 reason=( 

1606 "The flow is paused, new tasks can execute after resuming flow" 

1607 f" run: {flow_run.id}." 

1608 ), 

1609 ) 

1610 elif not flow_run.state.type == StateType.RUNNING: 

1611 # task runners should abort task run execution 

1612 await self.abort_transition( 

1613 reason=( 

1614 "The enclosing flow must be running to begin task execution." 

1615 ), 

1616 ) 

1617 

1618 

1619class EnforceCancellingToCancelledTransition(TaskRunOrchestrationRule): 1a

1620 """ 

1621 Rejects transitions from Cancelling to any terminal state except for Cancelled. 

1622 """ 

1623 

1624 FROM_STATES = {StateType.CANCELLED, StateType.CANCELLING} 1a

1625 TO_STATES = ALL_ORCHESTRATION_STATES - {StateType.CANCELLED} 1a

1626 

1627 async def before_transition( 1a

1628 self, 

1629 initial_state: states.State[Any] | None, 

1630 proposed_state: states.State[Any] | None, 

1631 context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy], 

1632 ) -> None: 

1633 await self.reject_transition( 

1634 state=None, 

1635 reason=( 

1636 "Cannot transition flows that are cancelling to a state other " 

1637 "than Cancelled." 

1638 ), 

1639 ) 

1640 return 

1641 

1642 

1643class BypassCancellingFlowRunsWithNoInfra(FlowRunOrchestrationRule): 1a

1644 """Rejects transitions from Scheduled to Cancelling, and instead sets the state to Cancelled, 

1645 if the flow run has no associated infrastructure process ID. Also rejects transitions from

1646 Paused to Cancelling if the Paused state's details indicate the flow run has been suspended,

1647 which exits the flow and tears down its infrastructure.

1648 

1649 The `Cancelling` state is used to clean up infrastructure. If there is no infrastructure

1650 to clean up, we can transition directly to `Cancelled`. Runs that are `Resuming` are

1651 `Scheduled` runs that were previously `Suspended` and do not yet have infrastructure.

1652 

1653 Runs that are `AwaitingRetry` are `Scheduled` runs that may have associated infrastructure.

1654 """ 

1655 

1656 FROM_STATES = {StateType.SCHEDULED, StateType.PAUSED} 1a

1657 TO_STATES = {StateType.CANCELLING} 1a

1658 

1659 async def before_transition( 1a

1660 self, 

1661 initial_state: states.State[Any] | None, 

1662 proposed_state: states.State[Any] | None, 

1663 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1664 ) -> None: 

1665 if initial_state is None or proposed_state is None: 

1666 return 

1667 

1668 if ( 

1669 initial_state.type == states.StateType.SCHEDULED 

1670 and not context.run.infrastructure_pid 

1671 or initial_state.name == "Resuming" 

1672 ): 

1673 await self.reject_transition( 

1674 state=states.Cancelled(), 

1675 reason="Scheduled flow run has no infrastructure to terminate.", 

1676 ) 

1677 elif ( 

1678 initial_state.type == states.StateType.PAUSED 

1679 and initial_state.state_details.pause_reschedule 

1680 ): 

1681 await self.reject_transition( 

1682 state=states.Cancelled(), 

1683 reason="Suspended flow run has no infrastructure to terminate.", 

1684 ) 

1685 

1686 

1687class PreventDuplicateTransitions(FlowRunOrchestrationRule): 1a

1688 """ 

1689 Prevent duplicate transitions from being made right after one another. 

1690 

1691 This rule allows for clients to set an optional transition_id on a state. If the 

1692 run's next transition has the same transition_id, the transition will be 

1693 rejected and the existing state will be returned. 

1694 

1695 This allows for clients to make state transition requests without worrying about 

1696 the following case: 

1697 - A client making a state transition request 

1698 - The server accepts the transition and commits it

1699 - The client is unable to receive the response and retries the request 

1700 """ 

1701 

1702 FROM_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES 1a

1703 TO_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES 1a

1704 

1705 async def before_transition( 1a

1706 self, 

1707 initial_state: states.State[Any] | None, 

1708 proposed_state: states.State[Any] | None, 

1709 context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy], 

1710 ) -> None: 

1711 if initial_state is None or proposed_state is None: 

1712 return 

1713 

1714 initial_transition_id = getattr( 

1715 initial_state.state_details, "transition_id", None 

1716 ) 

1717 proposed_transition_id = getattr( 

1718 proposed_state.state_details, "transition_id", None 

1719 ) 

1720 if ( 

1721 initial_transition_id is not None 

1722 and proposed_transition_id is not None 

1723 and initial_transition_id == proposed_transition_id 

1724 ): 

1725 await self.reject_transition( 

1726 # state=None will return the initial (current) state 

1727 state=None, 

1728 reason="This run has already made this state transition.", 

1729 )
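# Editor's note (illustrative sketch, not part of the Prefect source): the idempotency check above
# compares an optional client-supplied transition_id on the initial and proposed states; when both
# are present and equal, the retried request is treated as a duplicate and rejected with the
# existing state. A standalone predicate mirroring that check:
#
#     def is_duplicate_transition(initial_details, proposed_details) -> bool:
#         initial_id = getattr(initial_details, "transition_id", None)
#         proposed_id = getattr(proposed_details, "transition_id", None)
#         return (
#             initial_id is not None
#             and proposed_id is not None
#             and initial_id == proposed_id
#         )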