Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/orchestration/core_policy.py: 20%
547 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1"""
2Orchestration logic that fires on state transitions.
4`CoreFlowPolicy` and `CoreTaskPolicy` contain all default orchestration rules that
5Prefect enforces on a state transition.
6"""
8from __future__ import annotations 1a
10import datetime 1a
11import logging 1a
12import math 1a
13from typing import Any, Union, cast 1a
14from uuid import UUID, uuid4 1a
16import sqlalchemy as sa 1a
17from packaging.version import Version 1a
18from sqlalchemy import select 1a
20from prefect.logging import get_logger 1a
21from prefect.server import models 1a
22from prefect.server.concurrency.lease_storage import ( 1a
23 ConcurrencyLeaseHolder,
24 ConcurrencyLimitLeaseMetadata,
25 get_concurrency_lease_storage,
26)
27from prefect.server.database import PrefectDBInterface, orm_models 1a
28from prefect.server.database.dependencies import db_injector, provide_database_interface 1a
29from prefect.server.exceptions import ObjectNotFoundError 1a
30from prefect.server.models import concurrency_limits, concurrency_limits_v2, deployments 1a
31from prefect.server.orchestration.dependencies import ( 1a
32 MIN_CLIENT_VERSION_FOR_CONCURRENCY_LIMIT_LEASING,
33 WORKER_VERSIONS_THAT_MANAGE_DEPLOYMENT_CONCURRENCY,
34)
35from prefect.server.orchestration.policies import ( 1a
36 FlowRunOrchestrationPolicy,
37 TaskRunOrchestrationPolicy,
38)
39from prefect.server.orchestration.rules import ( 1a
40 ALL_ORCHESTRATION_STATES,
41 TERMINAL_STATES,
42 BaseOrchestrationRule,
43 BaseUniversalTransform,
44 FlowOrchestrationContext,
45 FlowRunOrchestrationRule,
46 FlowRunUniversalTransform,
47 GenericOrchestrationRule,
48 OrchestrationContext,
49 TaskRunOrchestrationRule,
50 TaskRunUniversalTransform,
51)
52from prefect.server.schemas import core, filters, states 1a
53from prefect.server.schemas.states import StateType 1a
54from prefect.server.task_queue import TaskQueue 1a
55from prefect.settings import ( 1a
56 get_current_settings,
57)
58from prefect.types._datetime import now 1a
59from prefect.utilities.math import clamped_poisson_interval 1a
61from .instrumentation_policies import InstrumentFlowRunStateTransitions 1a
# Module-level logger shared by all orchestration rules in this module.
logger: logging.Logger = get_logger(__name__)
class CoreFlowPolicy(FlowRunOrchestrationPolicy):
    """
    Orchestration rules that run against flow-run-state transitions in priority order.
    """

    @staticmethod
    def priority() -> list[
        Union[
            type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]],
            type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]],
        ]
    ]:
        # Rules execute in list order; earlier entries take precedence.
        ordered_rules = [
            PreventDuplicateTransitions,
            HandleFlowTerminalStateTransitions,
            EnforceCancellingToCancelledTransition,
            BypassCancellingFlowRunsWithNoInfra,
            PreventPendingTransitions,
            CopyDeploymentConcurrencyLeaseID,
            SecureFlowConcurrencySlots,
            RemoveDeploymentConcurrencyLeaseForOldClientVersions,
            EnsureOnlyScheduledFlowsMarkedLate,
            HandlePausingFlows,
            HandleResumingPausedFlows,
            CopyScheduledTime,
            WaitForScheduledTime,
            RetryFailedFlows,
            InstrumentFlowRunStateTransitions,
            ReleaseFlowConcurrencySlots,
        ]
        # cast() narrows the heterogeneous list to the declared union type.
        return cast(
            list[
                Union[
                    type[
                        BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]
                    ],
                    type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]],
                ]
            ],
            ordered_rules,
        )
class CoreTaskPolicy(TaskRunOrchestrationPolicy):
    """
    Orchestration rules that run against task-run-state transitions in priority order.
    """

    @staticmethod
    def priority() -> list[
        Union[
            type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]],
            type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]],
        ]
    ]:
        # Rules execute in list order; earlier entries take precedence.
        ordered_rules = [
            CacheRetrieval,
            HandleTaskTerminalStateTransitions,
            PreventRunningTasksFromStoppedFlows,
            SecureTaskConcurrencySlots,  # retrieve cached states even if slots are full
            CopyScheduledTime,
            WaitForScheduledTime,
            RetryFailedTasks,
            RenameReruns,
            UpdateFlowRunTrackerOnTasks,
            CacheInsertion,
            ReleaseTaskConcurrencySlots,
        ]
        # cast() narrows the heterogeneous list to the declared union type.
        return cast(
            list[
                Union[
                    type[
                        BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]
                    ],
                    type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]],
                ]
            ],
            ordered_rules,
        )
class ClientSideTaskOrchestrationPolicy(TaskRunOrchestrationPolicy):
    """
    Orchestration rules that run against task-run-state transitions in priority order,
    specifically for clients doing client-side orchestration.
    """

    @staticmethod
    def priority() -> list[
        Union[
            type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]],
            type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]],
        ]
    ]:
        # Same ordering as CoreTaskPolicy, minus server-side concurrency
        # slot enforcement (handled client-side in this mode).
        ordered_rules = [
            CacheRetrieval,
            HandleTaskTerminalStateTransitions,
            PreventRunningTasksFromStoppedFlows,
            CopyScheduledTime,
            WaitForScheduledTime,
            RetryFailedTasks,
            RenameReruns,
            UpdateFlowRunTrackerOnTasks,
            CacheInsertion,
            ReleaseTaskConcurrencySlots,
        ]
        return cast(
            list[
                Union[
                    type[
                        BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]
                    ],
                    type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]],
                ]
            ],
            ordered_rules,
        )
class BackgroundTaskPolicy(TaskRunOrchestrationPolicy):
    """
    Orchestration rules that run against task-run-state transitions in priority order.
    """

    @staticmethod
    def priority() -> list[
        type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]]
        | type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]]
    ]:
        # Rules execute in list order; earlier entries take precedence.
        ordered_rules = [
            PreventPendingTransitions,
            CacheRetrieval,
            HandleTaskTerminalStateTransitions,
            # SecureTaskConcurrencySlots, # retrieve cached states even if slots are full
            CopyScheduledTime,
            CopyTaskParametersID,
            WaitForScheduledTime,
            RetryFailedTasks,
            RenameReruns,
            UpdateFlowRunTrackerOnTasks,
            CacheInsertion,
            ReleaseTaskConcurrencySlots,
            EnqueueScheduledTasks,
        ]
        return cast(
            list[
                Union[
                    type[
                        BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]
                    ],
                    type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]],
                ]
            ],
            ordered_rules,
        )
class MinimalFlowPolicy(FlowRunOrchestrationPolicy):
    @staticmethod
    def priority() -> list[
        Union[
            type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]],
            type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]],
        ]
    ]:
        rules = [
            BypassCancellingFlowRunsWithNoInfra,  # cancel scheduled or suspended runs from the UI
            InstrumentFlowRunStateTransitions,
            ReleaseFlowConcurrencySlots,
        ]
        return rules
class MarkLateRunsPolicy(FlowRunOrchestrationPolicy):
    @staticmethod
    def priority() -> list[
        Union[
            type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]],
            type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]],
        ]
    ]:
        rules = [
            EnsureOnlyScheduledFlowsMarkedLate,
            InstrumentFlowRunStateTransitions,
        ]
        return rules
class MinimalTaskPolicy(TaskRunOrchestrationPolicy):
    @staticmethod
    def priority() -> list[
        Union[
            type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]],
            type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]],
        ]
    ]:
        rules = [
            ReleaseTaskConcurrencySlots,  # always release concurrency slots
        ]
        return rules
class SecureTaskConcurrencySlots(TaskRunOrchestrationRule):
    """
    Checks relevant concurrency slots are available before entering a Running state.

    This rule checks if concurrency limits have been set on the tags associated with a
    TaskRun. If so, a concurrency slot will be secured against each concurrency limit
    before being allowed to transition into a running state. If a concurrency limit has
    been reached, the client will be instructed to delay the transition for the duration
    specified by the "PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS" setting
    before trying again. If the concurrency limit set on a tag is 0, the transition will
    be aborted to prevent deadlocks.

    V2 (lease-based) limits are consulted first; any tag governed by an active V2
    limit is excluded from V1 processing so the same tag is never double-counted.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.RUNNING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        settings = get_current_settings()
        # Bookkeeping consumed by cleanup() if a later rule invalidates the
        # transition: V1 tags whose slots we claimed, and V2 leases we created.
        self._applied_limits: list[str] = []
        self._acquired_v2_lease_ids: list[UUID] = []
        v1_limits = (
            await concurrency_limits.filter_concurrency_limits_for_orchestration(
                context.session, tags=context.run.tags
            )
        )
        # V2 limits are stored under "tag:<tag-name>" names.
        v2_names = [f"tag:{tag}" for tag in context.run.tags]
        v2_limits = await concurrency_limits_v2.bulk_read_concurrency_limits(
            context.session, names=v2_names
        )

        # Handle V2 limits first (if they exist)
        v2_tags: set[str] = set()  # Track which tags have V2 limits
        if v2_limits:
            lease_storage = get_concurrency_lease_storage()
            # Track which tags have V2 limits to exclude from V1 processing
            v2_tags = {
                limit.name.removeprefix("tag:") for limit in v2_limits if limit.active
            }

            # Check for zero limits that would deadlock
            for limit in v2_limits:
                if limit.active and limit.limit == 0:
                    # Defensively release any already-acquired V2 leases before
                    # aborting (none are acquired before this point today, but
                    # this keeps the abort path safe if ordering changes).
                    for lease_id in self._acquired_v2_lease_ids:
                        try:
                            lease = await lease_storage.read_lease(
                                lease_id=lease_id,
                            )
                            if lease:
                                await concurrency_limits_v2.bulk_decrement_active_slots(
                                    session=context.session,
                                    concurrency_limit_ids=lease.resource_ids,
                                    slots=lease.metadata.slots if lease.metadata else 1,
                                )
                                await lease_storage.revoke_lease(
                                    lease_id=lease.id,
                                )
                        except Exception:
                            logger.warning(
                                f"Failed to clean up lease {lease_id} during abort",
                                exc_info=True,
                            )

                    await self.abort_transition(
                        reason=f'The concurrency limit on tag "{limit.name.removeprefix("tag:")}" is 0 and will deadlock if the task tries to run again.',
                    )

            # Try to acquire V2 slots with lease (exclude zero limits as they're handled above)
            active_v2_limits = [
                limit for limit in v2_limits if limit.active and limit.limit > 0
            ]
            if active_v2_limits:
                # Attempt to acquire slots in a dedicated transaction so a
                # failed acquisition can be rolled back atomically.
                async with provide_database_interface().session_context(
                    begin_transaction=True
                ) as session:
                    acquired = await concurrency_limits_v2.bulk_increment_active_slots(
                        session=session,
                        concurrency_limit_ids=[limit.id for limit in active_v2_limits],
                        slots=1,
                    )
                    if not acquired:
                        await session.rollback()
                        # Slots not available, delay transition
                        delay_seconds = clamped_poisson_interval(
                            average_interval=settings.server.tasks.tag_concurrency_slot_wait_seconds,
                        )
                        await self.delay_transition(
                            delay_seconds=round(delay_seconds),
                            reason=f"Concurrency limit reached for tags: {', '.join([limit.name.removeprefix('tag:') for limit in active_v2_limits])}",
                        )
                        return

                # Create lease for acquired slots with minimal metadata first
                lease = await lease_storage.create_lease(
                    resource_ids=[limit.id for limit in active_v2_limits],
                    ttl=concurrency_limits.V1_LEASE_TTL,
                    metadata=ConcurrencyLimitLeaseMetadata(
                        slots=1,
                        holder=ConcurrencyLeaseHolder(
                            type="task_run",
                            id=context.run.id,
                        ),
                    ),
                )

                self._acquired_v2_lease_ids.append(lease.id)

        # Process only the V1 limits whose tags are NOT already governed by an
        # active V2 limit, so a tag is never enforced twice.
        remaining_v1_limits = [limit for limit in v1_limits if limit.tag not in v2_tags]
        if remaining_v1_limits:
            # BUGFIX: build the working map from remaining_v1_limits. Previously
            # this indexed all v1_limits, which re-enforced (and double-counted)
            # tags that the V2 path above had already handled.
            run_limits = {limit.tag: limit for limit in remaining_v1_limits}
            for tag, cl in run_limits.items():
                limit = cl.concurrency_limit
                if limit == 0:
                    # limits of 0 will deadlock, and the transition needs to abort
                    for stale_tag in self._applied_limits:
                        stale_limit = run_limits.get(stale_tag, None)
                        if stale_limit:
                            active_slots: set[str] = set(stale_limit.active_slots)
                            active_slots.discard(str(context.run.id))
                            stale_limit.active_slots = list(active_slots)

                    await self.abort_transition(
                        reason=(
                            f'The concurrency limit on tag "{tag}" is 0 and will deadlock'
                            " if the task tries to run again."
                        ),
                    )
                elif len(cl.active_slots) >= limit:
                    # if the limit has already been reached, delay the transition
                    for stale_tag in self._applied_limits:
                        stale_limit = run_limits.get(stale_tag, None)
                        if stale_limit:
                            active_slots = set(stale_limit.active_slots)
                            active_slots.discard(str(context.run.id))
                            stale_limit.active_slots = list(active_slots)

                    await self.delay_transition(
                        delay_seconds=int(
                            settings.server.tasks.tag_concurrency_slot_wait_seconds
                        ),
                        reason=f"Concurrency limit for the {tag} tag has been reached",
                    )
                else:
                    # log the TaskRun ID to active_slots
                    self._applied_limits.append(tag)
                    active_slots = set(cl.active_slots)
                    active_slots.add(str(context.run.id))
                    cl.active_slots = list(active_slots)

    async def cleanup(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        # Invoked when a later rule invalidates this transition: release every
        # V2 lease and V1 slot claimed in before_transition.
        lease_storage = get_concurrency_lease_storage()
        # Clean up V2 leases
        for lease_id in self._acquired_v2_lease_ids:
            try:
                lease = await lease_storage.read_lease(
                    lease_id=lease_id,
                )
                if lease:
                    await concurrency_limits_v2.bulk_decrement_active_slots(
                        session=context.session,
                        concurrency_limit_ids=lease.resource_ids,
                        slots=lease.metadata.slots if lease.metadata else 1,
                    )
                    await lease_storage.revoke_lease(
                        lease_id=lease.id,
                    )
                else:
                    logger.warning(f"Lease {lease_id} not found during cleanup")
            except Exception:
                # Best-effort: log and keep releasing the remaining leases.
                logger.warning(f"Failed to clean up lease {lease_id}", exc_info=True)

        # Release V1 slots by removing this run's ID from each applied tag.
        for tag in self._applied_limits:
            cl = await concurrency_limits.read_concurrency_limit_by_tag(
                context.session, tag
            )
            if cl:
                active_slots = set(cl.active_slots)
                active_slots.discard(str(context.run.id))
                cl.active_slots = list(active_slots)
class ReleaseTaskConcurrencySlots(TaskRunUniversalTransform):
    """
    Releases any concurrency slots held by a run upon exiting a Running or
    Cancelling state.
    """

    async def after_transition(
        self,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        # A nullified transition means no state change actually occurred, so
        # there is nothing to release.
        if self.nullified_transition():
            return

        # Only release when the validated state is NOT Running/Cancelling,
        # i.e. the run has exited the slot-holding states.
        if context.validated_state and context.validated_state.type not in [
            states.StateType.RUNNING,
            states.StateType.CANCELLING,
        ]:
            # V2 limits are stored under "tag:<tag-name>" names.
            v2_names = [f"tag:{tag}" for tag in context.run.tags]
            v2_limits = await concurrency_limits_v2.bulk_read_concurrency_limits(
                context.session, names=v2_names
            )
            # Release V2 leases for this task run
            if v2_limits:
                lease_storage = get_concurrency_lease_storage()
                # Collect lease IDs into a set first so a lease spanning
                # multiple limits is only reconciled once.
                lease_ids_to_reconcile: set[UUID] = set()
                for v2_limit in v2_limits:
                    # Find holders for this limit
                    holders_with_leases: list[
                        tuple[UUID, ConcurrencyLeaseHolder]
                    ] = await lease_storage.list_holders_for_limit(
                        limit_id=v2_limit.id,
                    )
                    # Find leases that belong to this task run
                    for lease_id, holder in holders_with_leases:
                        if holder.id == context.run.id:
                            lease_ids_to_reconcile.add(lease_id)

                # Reconcile all found leases: decrement the slots each lease
                # holds, then revoke the lease itself.
                for lease_id in lease_ids_to_reconcile:
                    try:
                        lease = await lease_storage.read_lease(
                            lease_id=lease_id,
                        )
                        if lease:
                            await concurrency_limits_v2.bulk_decrement_active_slots(
                                session=context.session,
                                concurrency_limit_ids=lease.resource_ids,
                                # Default to one slot when the lease carries no metadata.
                                slots=lease.metadata.slots if lease.metadata else 1,
                            )
                            await lease_storage.revoke_lease(
                                lease_id=lease.id,
                            )
                        else:
                            logger.warning(f"Lease {lease_id} not found during release")
                    except Exception:
                        # Best-effort: log and continue with remaining leases.
                        logger.warning(
                            f"Failed to reconcile lease {lease_id} during release",
                            exc_info=True,
                        )

            # Release legacy V1 slots by removing this run's ID from each
            # matching tag limit's active set.
            v1_limits = (
                await concurrency_limits.filter_concurrency_limits_for_orchestration(
                    context.session, tags=context.run.tags
                )
            )
            for cl in v1_limits:
                active_slots = set(cl.active_slots)
                active_slots.discard(str(context.run.id))
                cl.active_slots = list(active_slots)
class SecureFlowConcurrencySlots(FlowRunOrchestrationRule):
    """
    Enforce deployment concurrency limits.

    This rule enforces concurrency limits on deployments. If a deployment has a concurrency limit,
    this rule will prevent more than that number of flow runs from being submitted concurrently
    based on the concurrency limit behavior configured for the deployment.

    We use the PENDING state as the target transition because this allows workers to secure a slot
    before provisioning dynamic infrastructure to run a flow. If a slot isn't available, the worker
    won't provision infrastructure.

    A lease is created for the concurrency limit. The client will be responsible for maintaining the lease.
    """

    # PENDING/RUNNING/CANCELLING are excluded as source states so a run that
    # already holds a slot does not try to acquire a second one.
    FROM_STATES = ALL_ORCHESTRATION_STATES - {
        states.StateType.PENDING,
        states.StateType.RUNNING,
        states.StateType.CANCELLING,
    }
    TO_STATES = {states.StateType.PENDING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: FlowOrchestrationContext,
    ) -> None:
        # Nothing to enforce without a session, a deployment, or a proposed
        # state — or when the worker version manages deployment concurrency
        # itself (older workers that pre-date server-side leasing).
        if (
            not context.session
            or not context.run.deployment_id
            or not proposed_state
            or context.client_version
            in WORKER_VERSIONS_THAT_MANAGE_DEPLOYMENT_CONCURRENCY
        ):
            return

        deployment = await deployments.read_deployment(
            session=context.session,
            deployment_id=context.run.deployment_id,
        )
        if not deployment:
            await self.abort_transition("Deployment not found.")
            return

        # Deployments without a configured concurrency limit are unrestricted.
        if (
            not deployment.global_concurrency_limit
            or not deployment.concurrency_limit_id
        ):
            return

        # A zero limit can never be satisfied; abort rather than retry forever.
        if deployment.global_concurrency_limit.limit == 0:
            await self.abort_transition(
                "The deployment concurrency limit is 0. The flow will deadlock if submitted again."
            )
            return

        acquired = await concurrency_limits_v2.bulk_increment_active_slots(
            session=context.session,
            concurrency_limit_ids=[deployment.concurrency_limit_id],
            slots=1,
        )
        if acquired:
            # Record the acquisition with a lease; the client is expected to
            # renew it, and the lease ID travels on the proposed state.
            lease_storage = get_concurrency_lease_storage()
            settings = get_current_settings()
            lease = await lease_storage.create_lease(
                resource_ids=[deployment.concurrency_limit_id],
                metadata=ConcurrencyLimitLeaseMetadata(
                    slots=1,
                ),
                ttl=datetime.timedelta(
                    seconds=settings.server.concurrency.initial_deployment_lease_duration
                ),
            )
            proposed_state.state_details.deployment_concurrency_lease_id = lease.id

        else:
            # No slot available: resolve per the deployment's collision
            # strategy, defaulting to ENQUEUE when none is configured.
            concurrency_options = (
                deployment.concurrency_options
                or core.ConcurrencyOptions(
                    collision_strategy=core.ConcurrencyLimitStrategy.ENQUEUE
                )
            )

            if (
                concurrency_options.collision_strategy
                == core.ConcurrencyLimitStrategy.ENQUEUE
            ):
                # Reschedule the run to try again after the configured wait.
                settings = get_current_settings()
                await self.reject_transition(
                    state=states.Scheduled(
                        name="AwaitingConcurrencySlot",
                        scheduled_time=now("UTC")
                        + datetime.timedelta(
                            seconds=settings.server.deployments.concurrency_slot_wait_seconds
                        ),
                    ),
                    reason="Deployment concurrency limit reached.",
                )
            elif (
                concurrency_options.collision_strategy
                == core.ConcurrencyLimitStrategy.CANCEL_NEW
            ):
                # CANCEL_NEW: the new run loses; cancel it outright.
                await self.reject_transition(
                    state=states.Cancelled(
                        message="Deployment concurrency limit reached."
                    ),
                    reason="Deployment concurrency limit reached.",
                )

    async def cleanup(  # type: ignore
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: FlowOrchestrationContext,
    ) -> None:
        # Invoked when a later rule invalidates this transition: give back the
        # slot (and revoke the lease) acquired in before_transition.
        logger = get_logger()
        if not context.session or not context.run.deployment_id:
            return

        try:
            deployment = await deployments.read_deployment(
                session=context.session,
                deployment_id=context.run.deployment_id,
            )

            if not deployment or not deployment.concurrency_limit_id:
                return

            # NOTE(review): this decrements unconditionally; it appears to rely
            # on cleanup only firing after a successful acquisition — confirm.
            await concurrency_limits_v2.bulk_decrement_active_slots(
                session=context.session,
                concurrency_limit_ids=[deployment.concurrency_limit_id],
                slots=1,
            )
            if (
                validated_state
                and validated_state.state_details.deployment_concurrency_lease_id
            ):
                lease_storage = get_concurrency_lease_storage()
                await lease_storage.revoke_lease(
                    lease_id=validated_state.state_details.deployment_concurrency_lease_id,
                )
                validated_state.state_details.deployment_concurrency_lease_id = None

        except Exception as e:
            # Best-effort cleanup: never let a cleanup failure mask the
            # original orchestration outcome.
            logger.error(f"Error releasing concurrency slots on cleanup: {e}")
class CopyDeploymentConcurrencyLeaseID(FlowRunOrchestrationRule):
    """
    Copies the deployment concurrency lease ID to the proposed state.
    """

    FROM_STATES = {states.StateType.PENDING}
    TO_STATES = {states.StateType.RUNNING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        # Nothing to carry over unless both sides of the transition exist.
        if initial_state is None or proposed_state is None:
            return

        details = proposed_state.state_details
        if details.deployment_concurrency_lease_id:
            # The proposed state already carries a lease ID; leave it alone.
            return

        # Propagate the lease acquired while PENDING onto the RUNNING state.
        details.deployment_concurrency_lease_id = (
            initial_state.state_details.deployment_concurrency_lease_id
        )
class RemoveDeploymentConcurrencyLeaseForOldClientVersions(FlowRunOrchestrationRule):
    """
    Removes a deployment concurrency lease if the client version is less than the minimum version for leasing.
    """

    FROM_STATES = {states.StateType.PENDING}
    TO_STATES = {states.StateType.RUNNING, states.StateType.CANCELLING}

    async def after_transition(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if not initial_state:
            return
        client_version = context.client_version
        if (
            client_version
            and Version(client_version)
            >= MIN_CLIENT_VERSION_FOR_CONCURRENCY_LIMIT_LEASING
        ):
            # New enough clients maintain their own leases; nothing to revoke.
            return

        lease_id = initial_state.state_details.deployment_concurrency_lease_id
        if lease_id:
            # Old clients cannot renew a lease, so revoke it server-side.
            lease_storage = get_concurrency_lease_storage()
            await lease_storage.revoke_lease(
                lease_id=lease_id,
            )
class ReleaseFlowConcurrencySlots(FlowRunUniversalTransform):
    """
    Releases deployment concurrency slots held by a flow run.

    This rule releases a concurrency slot for a deployment when a flow run
    transitions out of the Running or Cancelling state.
    """

    async def after_transition(
        self,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        # A nullified transition means no state change occurred — keep the slot.
        if self.nullified_transition():
            return

        initial_state_type = (
            context.initial_state.type if context.initial_state else None
        )
        proposed_state_type = (
            context.proposed_state.type if context.proposed_state else None
        )

        # Check if the transition is valid for releasing concurrency slots.
        # This should happen within `after_transition` because BaseUniversalTransforms
        # don't know how to "fizzle" themselves if they encounter a transition that
        # shouldn't apply to them, even if they use FROM_STATES and TO_STATES.
        if not (
            initial_state_type
            in {
                states.StateType.RUNNING,
                states.StateType.CANCELLING,
                states.StateType.PENDING,
            }
            and proposed_state_type
            not in {
                states.StateType.PENDING,
                states.StateType.RUNNING,
                states.StateType.CANCELLING,
            }
        ):
            return
        # Without a session or a deployment there is no slot to release.
        if not context.session or not context.run.deployment_id:
            return

        lease_storage = get_concurrency_lease_storage()
        # Preferred path: the initial state carries a lease ID and the lease
        # (with metadata) can still be read — release exactly what it holds.
        if (
            context.initial_state
            and context.initial_state.state_details.deployment_concurrency_lease_id
            and (
                lease := await lease_storage.read_lease(
                    lease_id=context.initial_state.state_details.deployment_concurrency_lease_id,
                )
            )
            and lease.metadata
        ):
            await concurrency_limits_v2.bulk_decrement_active_slots(
                session=context.session,
                concurrency_limit_ids=lease.resource_ids,
                slots=lease.metadata.slots,
            )
            await lease_storage.revoke_lease(
                lease_id=lease.id,
            )
        else:
            # Fallback path (no usable lease): look up the deployment's limit
            # directly and release a single slot.
            deployment = await deployments.read_deployment(
                session=context.session,
                deployment_id=context.run.deployment_id,
            )
            if not deployment or not deployment.concurrency_limit_id:
                return

            await concurrency_limits_v2.bulk_decrement_active_slots(
                session=context.session,
                concurrency_limit_ids=[deployment.concurrency_limit_id],
                slots=1,
            )
class CacheInsertion(TaskRunOrchestrationRule):
    """
    Caches completed states with cache keys after they are validated.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.COMPLETED}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        if proposed_state is None:
            return

        # Reject over-long cache keys up front so they never reach the cache
        # table; the run still moves to the proposed state via the rejection.
        settings = get_current_settings()
        cache_key = proposed_state.state_details.cache_key
        if cache_key and len(cache_key) > settings.server.tasks.max_cache_key_length:
            await self.reject_transition(
                state=proposed_state,
                reason=f"Cache key exceeded maximum allowed length of {settings.server.tasks.max_cache_key_length} characters.",
            )
            return

    @db_injector
    async def after_transition(
        self,
        db: PrefectDBInterface,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        if not validated_state or not context.session:
            return

        # Persist a cache-table row mapping the key to this validated state so
        # CacheRetrieval can serve it to future runs until it expires.
        cache_key = validated_state.state_details.cache_key
        if cache_key:
            new_cache_item = db.TaskRunStateCache(
                cache_key=cache_key,
                cache_expiration=validated_state.state_details.cache_expiration,
                task_run_state_id=validated_state.id,
            )
            context.session.add(new_cache_item)
class CacheRetrieval(TaskRunOrchestrationRule):
    """
    Rejects running states if a completed state has been cached.

    This rule rejects transitions into a running state with a cache key if the key
    has already been associated with a completed state in the cache table. The client
    will be instructed to transition into the cached completed state instead.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.RUNNING}

    @db_injector
    async def before_transition(
        self,
        db: PrefectDBInterface,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        if not proposed_state:
            return

        # `refresh_cache` forces re-execution, bypassing any cached state.
        cache_key = proposed_state.state_details.cache_key
        if cache_key and not proposed_state.state_details.refresh_cache:
            # Check for cached states matching the cache key: the most recent
            # unexpired entry wins (NULL expiration means "never expires").
            cached_state_id = (
                select(db.TaskRunStateCache.task_run_state_id)
                .where(
                    sa.and_(
                        db.TaskRunStateCache.cache_key == cache_key,
                        sa.or_(
                            db.TaskRunStateCache.cache_expiration.is_(None),
                            db.TaskRunStateCache.cache_expiration > now("UTC"),
                        ),
                    ),
                )
                .order_by(db.TaskRunStateCache.created.desc())
                .limit(1)
            ).scalar_subquery()
            query = select(db.TaskRunState).where(db.TaskRunState.id == cached_state_id)
            cached_state = (await context.session.execute(query)).scalar()
            if cached_state:
                # Hand the client a fresh copy of the cached completed state
                # instead of letting the task run again.
                new_state = cached_state.as_state().fresh_copy()
                new_state.name = "Cached"
                await self.reject_transition(
                    state=new_state, reason="Retrieved state from cache"
                )
class RetryFailedFlows(FlowRunOrchestrationRule):
    """
    Rejects failed states and schedules a retry if the retry limit has not been reached.

    This rule rejects transitions into a failed state if `retries` has been
    set and the run count has not reached the specified limit. The client will be
    instructed to transition into a scheduled state to retry flow execution.
    """

    FROM_STATES = {StateType.RUNNING}
    TO_STATES = {StateType.FAILED}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        run_settings = context.run_settings
        run_count = context.run.run_count

        # Out of retries (or retries unset): let the run fail.
        if run_settings.retries is None or run_count > run_settings.retries:
            # Clear retry type to allow for future infrastructure level retries (e.g. via the UI)
            updated_policy = context.run.empirical_policy.model_dump()
            updated_policy["retry_type"] = None
            context.run.empirical_policy = core.FlowRunPolicy(**updated_policy)

            return  # Retry count exceeded, allow transition to failed

        scheduled_start_time = now("UTC") + datetime.timedelta(
            seconds=run_settings.retry_delay or 0
        )

        # support old-style flow run retries for older clients
        # older flow retries require us to loop over failed tasks to update their state
        # this is not required after API version 0.8.3
        # NOTE(review): presumably "api-version" is already a packaging
        # Version instance here (compared with `<`) — confirm against callers.
        api_version = context.parameters.get("api-version", None)
        if api_version and api_version < Version("0.8.3"):
            failed_task_runs = await models.task_runs.read_task_runs(
                context.session,
                flow_run_filter=filters.FlowRunFilter(
                    id=filters.FlowRunFilterId(any_=[context.run.id])
                ),
                task_run_filter=filters.TaskRunFilter(
                    state=filters.TaskRunFilterState(
                        type=filters.TaskRunFilterStateType(any_=[StateType.FAILED])
                    )
                ),
            )
            for run in failed_task_runs:
                await models.task_runs.set_task_run_state(
                    context.session,
                    run.id,
                    state=states.AwaitingRetry(scheduled_time=scheduled_start_time),
                    force=True,
                )
                # Reset the run count so that the task run retries still work correctly
                run.run_count = 0

        # Reset pause metadata on retry
        # Pauses as a concept only exist after API version 0.8.4
        api_version = context.parameters.get("api-version", None)
        if api_version is None or api_version >= Version("0.8.4"):
            updated_policy = context.run.empirical_policy.model_dump()
            updated_policy["resuming"] = False
            updated_policy["pause_keys"] = set()
            updated_policy["retry_type"] = "in_process"
            context.run.empirical_policy = core.FlowRunPolicy(**updated_policy)

        # Generate a new state for the flow
        retry_state = states.AwaitingRetry(
            scheduled_time=scheduled_start_time,
            message=proposed_state.message,
            data=proposed_state.data,
        )
        await self.reject_transition(state=retry_state, reason="Retrying")
class RetryFailedTasks(TaskRunOrchestrationRule):
    """
    Rejects failed states and schedules a retry if the retry limit has not been reached.

    This rule rejects transitions into a failed state if `retries` has been
    set, the run count has not reached the specified limit, and the client
    asserts it is a retriable task run. The client will be instructed to
    transition into a scheduled state to retry task execution.
    """

    FROM_STATES = {StateType.RUNNING}
    TO_STATES = {StateType.FAILED}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        run_settings = context.run_settings
        run_count = context.run.run_count

        # Resolve the base delay: a list is indexed by attempt number,
        # clamped to its final entry; a scalar (or None) is used directly.
        configured_delay = run_settings.retry_delay
        if isinstance(configured_delay, list):
            base_delay = configured_delay[min(run_count - 1, len(configured_delay) - 1)]
        else:
            base_delay = configured_delay or 0

        # guard against negative relative jitter inputs
        jitter_factor = run_settings.retry_jitter_factor
        delay = (
            clamped_poisson_interval(base_delay, clamping_factor=jitter_factor)
            if jitter_factor
            else base_delay
        )

        # set by user to conditionally retry a task using @task(retry_condition_fn=...)
        if getattr(proposed_state.state_details, "retriable", True) is False:
            return

        # Retries remain: reschedule instead of letting the task fail.
        if run_settings.retries is not None and run_count <= run_settings.retries:
            retry_state = states.AwaitingRetry(
                scheduled_time=now("UTC") + datetime.timedelta(seconds=delay),
                message=proposed_state.message,
                data=proposed_state.data,
            )
            await self.reject_transition(state=retry_state, reason="Retrying")
class EnqueueScheduledTasks(TaskRunOrchestrationRule):
    """
    Enqueues background task runs when they are scheduled
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.SCHEDULED}

    async def after_transition(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        # Act only on valid transitions of tasks that were deferred for
        # background execution.
        if not validated_state or not validated_state.state_details.deferred:
            return

        task_run: core.TaskRun = core.TaskRun.model_validate(context.run)
        queue: TaskQueue = TaskQueue.for_key(task_run.task_key)

        # Retries re-enter the queue through a dedicated path.
        if validated_state.name == "AwaitingRetry":
            await queue.retry(task_run)
        else:
            await queue.enqueue(task_run)
class RenameReruns(GenericOrchestrationRule):
    """
    Name the states if they have run more than once.

    In the special case where the initial state is an "AwaitingRetry" scheduled state,
    the proposed state will be renamed to "Retrying" instead.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.RUNNING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[
            orm_models.Run, core.TaskRunPolicy | core.FlowRunPolicy
        ],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # Only rename states for runs that have executed at least once before.
        if context.run.run_count > 0:
            new_name = (
                "Retrying" if initial_state.name == "AwaitingRetry" else "Rerunning"
            )
            await self.rename_state(new_name)
class CopyScheduledTime(
    BaseOrchestrationRule[orm_models.Run, Union[core.TaskRunPolicy, core.FlowRunPolicy]]
):
    """
    Ensures scheduled time is copied from scheduled states to pending states.

    If a new scheduled time has been proposed on the pending state, the scheduled time
    on the scheduled state will be ignored.
    """

    FROM_STATES = {StateType.SCHEDULED}
    TO_STATES = {StateType.PENDING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[
            orm_models.Run, core.TaskRunPolicy | core.FlowRunPolicy
        ],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # A scheduled time proposed by the client takes precedence; otherwise
        # carry forward the time recorded on the scheduled state.
        if proposed_state.state_details.scheduled_time:
            return
        proposed_state.state_details.scheduled_time = (
            initial_state.state_details.scheduled_time
        )
class WaitForScheduledTime(
    BaseOrchestrationRule[orm_models.Run, Union[core.TaskRunPolicy, core.FlowRunPolicy]]
):
    """
    Prevents transitions to running states from happening too early.

    This rule enforces that all scheduled states will only start with the machine clock
    used by the Prefect REST API instance. This rule will identify transitions from scheduled
    states that are too early and nullify them. Instead, no state will be written to the
    database and the client will be sent an instruction to wait for `delay_seconds`
    before attempting the transition again.
    """

    FROM_STATES = {StateType.SCHEDULED, StateType.PENDING}
    TO_STATES = {StateType.RUNNING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[
            orm_models.Run, core.TaskRunPolicy | core.FlowRunPolicy
        ],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        scheduled_time = initial_state.state_details.scheduled_time
        if not scheduled_time:
            return

        # The API schema specifies an integer delay, so the remaining wait is
        # rounded to the nearest second: floor the whole seconds, then round
        # the fractional microsecond component up or down.
        remaining = scheduled_time - now("UTC")
        wait_seconds = math.floor(remaining.total_seconds())
        wait_seconds += round(remaining.microseconds / 1e6)
        if wait_seconds > 0:
            await self.delay_transition(
                wait_seconds, reason="Scheduled time is in the future"
            )
class CopyTaskParametersID(TaskRunOrchestrationRule):
    """
    Ensures a task's parameters ID is copied from Scheduled to Pending and from
    Pending to Running states.

    If a parameters ID has been included on the proposed state, the parameters ID
    on the initial state will be ignored.
    """

    FROM_STATES = {StateType.SCHEDULED, StateType.PENDING}
    TO_STATES = {StateType.PENDING, StateType.RUNNING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # A parameters ID already present on the proposed state wins; only
        # fill it in from the initial state when missing.
        if proposed_state.state_details.task_parameters_id:
            return
        proposed_state.state_details.task_parameters_id = (
            initial_state.state_details.task_parameters_id
        )
class HandlePausingFlows(FlowRunOrchestrationRule):
    """
    Governs runs attempting to enter a Paused/Suspended state
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.PAUSED}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if proposed_state is None:
            return

        # "Suspended" is a flavor of PAUSED; pick the right verb for
        # user-facing messages.
        verb = "suspend" if proposed_state.name == "Suspended" else "pause"

        if initial_state is None:
            await self.abort_transition(f"Cannot {verb} flows with no state.")
            return

        if not initial_state.is_running():
            await self.reject_transition(
                state=None,
                reason=f"Cannot {verb} flows that are not currently running.",
            )
            return

        # Stash the pause key on the rule instance so after_transition can
        # record it on the run's empirical policy.
        self.key = proposed_state.state_details.pause_key
        if self.key is None:
            # if no pause key is provided, default to a UUID
            self.key = str(uuid4())

        # A pause key that is already recorded means this pause fired before
        # (e.g. on a previous attempt) and must not fire again.
        pause_keys = context.run.empirical_policy.pause_keys or set()
        if self.key in pause_keys:
            await self.reject_transition(
                state=None, reason=f"This {verb} has already fired."
            )
            return

        # A rescheduling pause exits the process, so it cannot be used from a
        # subflow and requires a deployment to relaunch from.
        if proposed_state.state_details.pause_reschedule:
            if context.run.parent_task_run_id:
                await self.abort_transition(
                    reason=f"Cannot {verb} subflows.",
                )
                return

            if context.run.deployment_id is None:
                await self.abort_transition(
                    reason=f"Cannot {verb} flows without a deployment.",
                )
                return

    async def after_transition(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        # Record the pause key captured in before_transition so the same
        # pause cannot fire twice.
        updated_policy = context.run.empirical_policy.model_dump()
        updated_policy["pause_keys"].add(self.key)
        context.run.empirical_policy = core.FlowRunPolicy(**updated_policy)
class HandleResumingPausedFlows(FlowRunOrchestrationRule):
    """
    Governs runs attempting to leave a Paused state
    """

    FROM_STATES = {StateType.PAUSED}
    TO_STATES = ALL_ORCHESTRATION_STATES

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # A paused run may only resume (RUNNING), be rescheduled (SCHEDULED),
        # or finish in a terminal state; all other targets are rejected.
        if not (
            proposed_state
            and (
                proposed_state.is_running()
                or proposed_state.is_scheduled()
                or proposed_state.is_final()
            )
        ):
            await self.reject_transition(
                state=None,
                reason=(
                    f"This run cannot transition to the {proposed_state.type} state"
                    f" from the {initial_state.type} state."
                ),
            )
            return

        # "Suspended" is a flavor of PAUSED whose infrastructure was torn down.
        verb = "suspend" if proposed_state.name == "Suspended" else "pause"

        # Human-readable state name for messages; falls back to the state type.
        display_state_name = (
            proposed_state.name.lower()
            if proposed_state.name
            else proposed_state.type.value.lower()
        )

        # A run paused with reschedule=True exited its process, so resuming it
        # requires a deployment to relaunch from.
        if initial_state.state_details.pause_reschedule:
            if not context.run.deployment_id:
                await self.reject_transition(
                    state=None,
                    reason=(
                        f"Cannot reschedule a {display_state_name} flow run"
                        " without a deployment."
                    ),
                )
                return
        pause_timeout = initial_state.state_details.pause_timeout
        # A pause whose timeout has elapsed can no longer resume; fail the run.
        if pause_timeout and pause_timeout < now("UTC"):
            pause_timeout_failure = states.Failed(
                message=(f"The flow was {display_state_name} and never resumed."),
            )
            await self.reject_transition(
                state=pause_timeout_failure,
                reason=f"The flow run {verb} has timed out and can no longer resume.",
            )
            return

    async def after_transition(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        # Mark the run as resuming so downstream rules can distinguish a
        # resumed run from a fresh one.
        updated_policy = context.run.empirical_policy.model_dump()
        updated_policy["resuming"] = True
        context.run.empirical_policy = core.FlowRunPolicy(**updated_policy)
class UpdateFlowRunTrackerOnTasks(TaskRunOrchestrationRule):
    """
    Tracks the flow run attempt a task run state is associated with.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.RUNNING}

    async def after_transition(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        # Autonomous task runs have no parent flow run to track.
        if context.run.flow_run_id is None:
            return

        self.flow_run: orm_models.FlowRun | None = await context.flow_run()
        if self.flow_run is None:
            raise ObjectNotFoundError(
                (
                    "Unable to read flow run associated with task run:"
                    f" {context.run.id}, this flow run might have been deleted"
                ),
            )
        # Record which flow run attempt this task run state belongs to.
        context.run.flow_run_run_count = self.flow_run.run_count
class HandleTaskTerminalStateTransitions(TaskRunOrchestrationRule):
    """
    We do not allow tasks to leave terminal states if:
    - The task is completed and has a persisted result
    - The task is going to CANCELLING / PAUSED / CRASHED

    We reset the run count when a task leaves a terminal state for a non-terminal state
    which resets task run retries; this is particularly relevant for flow run retries.
    """

    FROM_STATES: set[states.StateType | None] = TERMINAL_STATES  # pyright: ignore[reportAssignmentType] technically TERMINAL_STATES doesn't contain None
    TO_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # Snapshot the run count so `cleanup` can restore it if a later rule
        # invalidates this transition.
        self.original_run_count: int = context.run.run_count

        # Do not allow runs to be marked as crashed, paused, or cancelling if already terminal
        if proposed_state.type in {
            StateType.CANCELLING,
            StateType.PAUSED,
            StateType.CRASHED,
        }:
            await self.abort_transition(f"Run is already {initial_state.type.value}.")
            return

        # Only allow departure from a happily completed state if the result is not persisted
        if (
            initial_state.is_completed()
            and initial_state.data
            and initial_state.data.get("type") != "unpersisted"
        ):
            await self.reject_transition(None, "This run is already completed.")
            return

        if not proposed_state.is_final():
            # Reset run count to reset retries
            context.run.run_count = 0

        # Change the name of the state to retrying if its a flow run retry
        if proposed_state.is_running() and context.run.flow_run_id is not None:
            self.flow_run: orm_models.FlowRun | None = await context.flow_run()
            if self.flow_run is not None:
                # The parent flow has run more times than this task has seen,
                # so this transition is part of a flow run retry.
                flow_retrying = context.run.flow_run_run_count < self.flow_run.run_count
                if flow_retrying:
                    await self.rename_state("Retrying")

    async def cleanup(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        # reset run count
        context.run.run_count = self.original_run_count
class HandleFlowTerminalStateTransitions(FlowRunOrchestrationRule):
    """
    We do not allow flows to leave terminal states if:
    - The flow is completed and has a persisted result
    - The flow is going to CANCELLING / PAUSED / CRASHED
    - The flow is going to scheduled and has no deployment

    We reset the pause metadata when a flow leaves a terminal state for a non-terminal
    state. This resets pause behavior during manual flow run retries.
    """

    FROM_STATES: set[states.StateType | None] = TERMINAL_STATES  # pyright: ignore[reportAssignmentType] technically TERMINAL_STATES doesn't contain None
    TO_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # Snapshot the policy so `cleanup` can restore it if a later rule
        # invalidates this transition.
        self.original_flow_policy: dict[str, Any] = (
            context.run.empirical_policy.model_dump()
        )

        # Do not allow runs to be marked as crashed, paused, or cancelling if already terminal
        if proposed_state.type in {
            StateType.CANCELLING,
            StateType.PAUSED,
            StateType.CRASHED,
        }:
            await self.abort_transition(
                f"Run is already in terminal state {initial_state.type.value}."
            )
            return

        # Only allow departure from a happily completed state if the result is
        # not persisted and a rerun is being proposed
        if (
            initial_state.is_completed()
            and not proposed_state.is_final()
            and initial_state.data
            and initial_state.data.get("type") != "unpersisted"
        ):
            await self.reject_transition(None, "Run is already COMPLETED.")
            return

        # Do not allow runs to be rescheduled without a deployment
        if proposed_state.is_scheduled() and not context.run.deployment_id:
            await self.abort_transition(
                "Cannot reschedule a run without an associated deployment."
            )
            return

        if not proposed_state.is_final():
            # Reset pause metadata when leaving a terminal state
            # (pauses as a concept only exist after API version 0.8.4)
            api_version = context.parameters.get("api-version", None)
            if api_version is None or api_version >= Version("0.8.4"):
                updated_policy = context.run.empirical_policy.model_dump()
                updated_policy["resuming"] = False
                updated_policy["pause_keys"] = set()
                # Record how this retry will execute so other rules can react.
                if proposed_state.is_scheduled():
                    updated_policy["retry_type"] = "reschedule"
                else:
                    updated_policy["retry_type"] = None
                context.run.empirical_policy = core.FlowRunPolicy(**updated_policy)

    async def cleanup(
        self,
        initial_state: states.State[Any] | None,
        validated_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        # Restore the policy captured in before_transition.
        context.run.empirical_policy = core.FlowRunPolicy(**self.original_flow_policy)
class PreventPendingTransitions(GenericOrchestrationRule):
    """
    Prevents transitions to PENDING.

    This rule is only used for flow runs.

    This is intended to prevent race conditions during duplicate submissions of runs.
    Before a run is submitted to its execution environment, it should be placed in a
    PENDING state. If two workers attempt to submit the same run, one of them should
    encounter a PENDING -> PENDING transition and abort orchestration of the run.

    Similarly, if the execution environment starts quickly the run may be in a RUNNING
    state when the second worker attempts the PENDING transition. We deny these state
    changes as well to prevent duplicate submission. If a run has transitioned to a
    RUNNING state a worker should not attempt to submit it again unless it has moved
    into a terminal state.

    CANCELLING and CANCELLED runs should not be allowed to transition to PENDING.
    For re-runs of deployed runs, they should transition to SCHEDULED first.
    For re-runs of ad-hoc runs, they should transition directly to RUNNING.
    """

    FROM_STATES = {
        StateType.PENDING,
        StateType.CANCELLING,
        StateType.RUNNING,
        StateType.CANCELLED,
    }
    TO_STATES = {StateType.PENDING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[
            orm_models.Run, Union[core.FlowRunPolicy, core.TaskRunPolicy]
        ],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # Every matched transition is aborted outright; FROM_STATES restricts
        # this to the states from which PENDING is never legal.
        message = (
            f"This run is in a {initial_state.type.name} state and cannot"
            " transition to a PENDING state."
        )
        await self.abort_transition(reason=message)
class EnsureOnlyScheduledFlowsMarkedLate(FlowRunOrchestrationRule):
    """Only allows runs that are already scheduled to be marked "Late"."""

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.SCHEDULED}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # "Late" is a flavor of SCHEDULED; a run that is not already
        # scheduled cannot become late.
        if (
            proposed_state.is_scheduled()
            and proposed_state.name == "Late"
            and not initial_state.is_scheduled()
        ):
            await self.reject_transition(
                state=None, reason="Only scheduled flows can be marked late."
            )
class PreventRunningTasksFromStoppedFlows(TaskRunOrchestrationRule):
    """
    Prevents running tasks from stopped flows.

    A running state implies execution, but also the converse. This rule ensures that a
    flow's tasks cannot be run unless the flow is also running.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = {StateType.RUNNING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        flow_run = await context.flow_run()
        if flow_run is None:
            # Autonomous task runs have no enclosing flow to check.
            return

        flow_state = flow_run.state
        if flow_state is None:
            await self.abort_transition(
                reason="The enclosing flow must be running to begin task execution."
            )
        elif flow_state.type == StateType.PAUSED:
            # Use the flow run's Paused state details to preserve data like
            # timeouts.
            paused_state = states.Paused(
                name="NotReady",
                pause_expiration_time=flow_state.state_details.pause_timeout,
                reschedule=flow_state.state_details.pause_reschedule,
            )
            await self.reject_transition(
                state=paused_state,
                reason=(
                    "The flow is paused, new tasks can execute after resuming flow"
                    f" run: {flow_run.id}."
                ),
            )
        elif flow_state.type != StateType.RUNNING:
            # task runners should abort task run execution
            await self.abort_transition(
                reason=(
                    "The enclosing flow must be running to begin task execution."
                ),
            )
class EnforceCancellingToCancelledTransition(TaskRunOrchestrationRule):
    """
    Rejects transitions from Cancelling to any terminal state except for Cancelled.
    """

    FROM_STATES = {StateType.CANCELLED, StateType.CANCELLING}
    TO_STATES = ALL_ORCHESTRATION_STATES - {StateType.CANCELLED}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.TaskRun, core.TaskRunPolicy],
    ) -> None:
        # Every transition matched by FROM/TO is rejected outright: once a run
        # is cancelling (or cancelled), Cancelled is the only valid destination.
        # NOTE(review): the message says "flows" although this rule is typed
        # for task runs — confirm whether the wording is intentional (e.g.
        # shared with flow policies) before changing it.
        await self.reject_transition(
            state=None,
            reason=(
                "Cannot transition flows that are cancelling to a state other "
                "than Cancelled."
            ),
        )
class BypassCancellingFlowRunsWithNoInfra(FlowRunOrchestrationRule):
    """Rejects transitions from Scheduled to Cancelling, and instead sets the state to Cancelled,
    if the flow run has no associated infrastructure process ID. Also Rejects transitions from
    Paused to Cancelling if the Paused state's details indicates the flow run has been suspended,
    exiting the flow and tearing down infra.

    The `Cancelling` state is used to clean up infrastructure. If there is not infrastructure
    to clean up, we can transition directly to `Cancelled`. Runs that are `Resuming` are in a
    `Scheduled` state that were previously `Suspended` and do not yet have infrastructure.

    Runs that are `AwaitingRetry` are a `Scheduled` state that may have associated infrastructure.
    """

    FROM_STATES = {StateType.SCHEDULED, StateType.PAUSED}
    TO_STATES = {StateType.CANCELLING}

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # Scheduled runs without an infra PID, and "Resuming" runs (which
        # never have infrastructure yet), can skip Cancelling entirely.
        scheduled_without_infra = (
            initial_state.type == states.StateType.SCHEDULED
            and not context.run.infrastructure_pid
        )
        if scheduled_without_infra or initial_state.name == "Resuming":
            await self.reject_transition(
                state=states.Cancelled(),
                reason="Scheduled flow run has no infrastructure to terminate.",
            )
            return

        # Suspended runs already exited their process and tore down infra.
        suspended = (
            initial_state.type == states.StateType.PAUSED
            and initial_state.state_details.pause_reschedule
        )
        if suspended:
            await self.reject_transition(
                state=states.Cancelled(),
                reason="Suspended flow run has no infrastructure to terminate.",
            )
class PreventDuplicateTransitions(FlowRunOrchestrationRule):
    """
    Prevent duplicate transitions from being made right after one another.

    This rule allows for clients to set an optional transition_id on a state. If the
    run's next transition has the same transition_id, the transition will be
    rejected and the existing state will be returned.

    This allows for clients to make state transition requests without worrying about
    the following case:
    - A client making a state transition request
    - The server accepts transition and commits the transition
    - The client is unable to receive the response and retries the request
    """

    FROM_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES
    TO_STATES: set[states.StateType | None] = ALL_ORCHESTRATION_STATES

    async def before_transition(
        self,
        initial_state: states.State[Any] | None,
        proposed_state: states.State[Any] | None,
        context: OrchestrationContext[orm_models.FlowRun, core.FlowRunPolicy],
    ) -> None:
        if initial_state is None or proposed_state is None:
            return

        # Older states may predate the transition_id field, hence getattr.
        current_id = getattr(initial_state.state_details, "transition_id", None)
        incoming_id = getattr(proposed_state.state_details, "transition_id", None)

        # Matching non-null IDs mean this exact transition already succeeded
        # (the client retried after losing the response).
        if current_id is None or incoming_id is None:
            return
        if current_id == incoming_id:
            await self.reject_transition(
                # state=None will return the initial (current) state
                state=None,
                reason="This run has already made this state transition.",
            )