# Source: /usr/local/lib/python3.12/site-packages/prefect/cli/transfer/_dag.py
1"""
2Execution DAG for managing resource transfer dependencies.
4This module provides a pure execution engine that:
5- Stores nodes by UUID for deduplication
6- Implements Kahn's algorithm for topological sorting
7- Manages concurrent execution with worker pools
8- Handles failure propagation (skip descendants)
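
A minimal usage sketch (illustrative; ``resources`` stands for any sequence of
MigratableProtocol objects and ``transfer_one`` for any async callable):

    dag = TransferDAG()
    await dag.build_from_roots(resources)
    results = await dag.execute_concurrent(transfer_one, max_workers=5)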
9"""
11from __future__ import annotations 1a
13import asyncio 1a
14import uuid 1a
15from collections import defaultdict, deque 1a
16from dataclasses import dataclass, field 1a
17from enum import Enum 1a
18from typing import Any, Awaitable, Callable, Coroutine, Sequence 1a
20import anyio 1a
21from anyio import create_task_group 1a
22from anyio.abc import TaskGroup 1a
24from prefect.cli.transfer._exceptions import TransferSkipped 1a
25from prefect.cli.transfer._migratable_resources import MigratableProtocol 1a
26from prefect.logging import get_logger 1a
28logger = get_logger(__name__) 1a


class NodeState(Enum):
    """State of a node during traversal."""

    PENDING = "pending"
    READY = "ready"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class NodeStatus:
    """Tracks the status of a node during traversal."""

    node: MigratableProtocol
    state: NodeState = NodeState.PENDING
    dependents: set[uuid.UUID] = field(default_factory=set)
    dependencies: set[uuid.UUID] = field(default_factory=set)
    error: Exception | None = None


class TransferDAG:
    """
    Execution DAG for managing resource transfer dependencies.

    Uses Kahn's algorithm for topological sorting and concurrent execution.
    See: https://en.wikipedia.org/wiki/Topological_sorting#Kahn%27s_algorithm

    The DAG ensures resources are transferred in dependency order while
    maximizing parallelism for independent resources.
    """

    def __init__(self):
        self.nodes: dict[uuid.UUID, MigratableProtocol] = {}
        self._dependencies: dict[uuid.UUID, set[uuid.UUID]] = defaultdict(set)
        self._dependents: dict[uuid.UUID, set[uuid.UUID]] = defaultdict(set)
        self._status: dict[uuid.UUID, NodeStatus] = {}
        self._lock = asyncio.Lock()

    def add_node(self, node: MigratableProtocol) -> uuid.UUID:
        """
        Add a node to the graph, deduplicating by source ID.

        Args:
            node: Resource to add to the graph

        Returns:
            The node's source UUID
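
        Example:
            An illustrative sketch (``flow`` is any MigratableProtocol
            object); adding the same resource twice is a no-op:

                first = dag.add_node(flow)
                second = dag.add_node(flow)  # already present -> same UUID
                assert first == second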
80 """
81 if node.source_id not in self.nodes:
82 self.nodes[node.source_id] = node
83 self._status[node.source_id] = NodeStatus(node)
84 return node.source_id

    def add_edge(self, dependent_id: uuid.UUID, dependency_id: uuid.UUID) -> None:
        """
        Add a dependency edge where dependent depends on dependency.

        Args:
            dependent_id: ID of the resource that has a dependency
            dependency_id: ID of the resource being depended upon
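
        Example:
            An illustrative sketch; both IDs must already have been returned
            by add_node, since the status table is keyed by node ID. Here a
            deployment depends on its flow:

                dag.add_edge(dependent_id=deployment_id, dependency_id=flow_id)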
93 """
94 if dependency_id in self._dependencies[dependent_id]:
95 return
97 self._dependencies[dependent_id].add(dependency_id)
98 self._dependents[dependency_id].add(dependent_id)
100 self._status[dependent_id].dependencies.add(dependency_id)
101 self._status[dependency_id].dependents.add(dependent_id)

    async def build_from_roots(self, roots: Sequence[MigratableProtocol]) -> None:
        """
        Build the graph from root resources by recursively discovering dependencies.

        Args:
            roots: Collection of root resources to start discovery from
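
        Example:
            An illustrative sketch (``deployment`` is a MigratableProtocol
            object whose get_dependencies() returns its flow); discovery adds
            both nodes plus a deployment -> flow edge:

                await dag.build_from_roots([deployment])
                assert deployment.source_id in dag.nodes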
109 """
110 visited: set[uuid.UUID] = set()
112 async def visit(resource: MigratableProtocol):
113 if resource.source_id in visited:
114 return
115 visited.add(resource.source_id)
117 rid = self.add_node(resource)
119 visit_coroutines: list[Coroutine[Any, Any, None]] = []
120 for dep in await resource.get_dependencies():
121 did = self.add_node(dep)
122 self.add_edge(rid, did)
123 visit_coroutines.append(visit(dep))
124 await asyncio.gather(*visit_coroutines)
126 visit_coroutines = [visit(r) for r in roots]
127 await asyncio.gather(*visit_coroutines)

    def has_cycles(self) -> bool:
        """
        Check if the graph has cycles using three-color DFS.

        Uses the classic three-color algorithm where:
        - WHITE (0): Unvisited node
        - GRAY (1): Currently being explored (in DFS stack)
        - BLACK (2): Fully explored

        A cycle exists if we encounter a GRAY node during traversal (back edge).
        See: https://en.wikipedia.org/wiki/Depth-first_search#Vertex_orderings

        Returns:
            True if the graph contains cycles, False otherwise
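
        Example:
            An illustrative sketch; two nodes that each depend on the other
            form a back edge, so the DFS reports a cycle:

                dag.add_edge(a_id, b_id)
                dag.add_edge(b_id, a_id)
                assert dag.has_cycles()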
143 """
144 WHITE, GRAY, BLACK = 0, 1, 2
145 color = {node_id: WHITE for node_id in self.nodes}
147 def visit(node_id: uuid.UUID) -> bool:
148 if color[node_id] == GRAY:
149 return True # Back edge found - cycle detected
150 if color[node_id] == BLACK:
151 return False # Already fully explored
153 color[node_id] = GRAY
154 for dep_id in self._dependencies[node_id]:
155 if visit(dep_id):
156 return True
157 color[node_id] = BLACK
158 return False
160 for node_id in self.nodes:
161 if color[node_id] == WHITE:
162 if visit(node_id):
163 return True
164 return False

    def get_execution_layers(
        self, *, _assume_acyclic: bool = False
    ) -> list[list[MigratableProtocol]]:
        """
        Get execution layers using Kahn's algorithm.

        Each layer contains nodes that can be executed in parallel.
        Kahn's algorithm repeatedly removes nodes with no dependencies,
        forming layers of concurrent work.

        See: https://en.wikipedia.org/wiki/Topological_sorting#Kahn%27s_algorithm

        Args:
            _assume_acyclic: Skip cycle check if caller already verified

        Returns:
            List of layers, each containing nodes that can run in parallel

        Raises:
            ValueError: If the graph contains cycles
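
        Example:
            An illustrative sketch for a diamond where d depends on b and c,
            and b and c both depend on a:

                layers = dag.get_execution_layers()
                # -> [[a], [b, c], [d]]: a runs first, then b and c in
                #    parallel, then d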
186 """
187 if not _assume_acyclic and self.has_cycles():
188 raise ValueError("Cannot sort DAG with cycles")
190 in_degree = {n: len(self._dependencies[n]) for n in self.nodes}
192 layers: list[list[MigratableProtocol]] = []
193 cur = [n for n in self.nodes if in_degree[n] == 0]
195 while cur:
196 layers.append([self.nodes[n] for n in cur])
197 nxt: list[uuid.UUID] = []
198 for n in cur:
199 for d in self._dependents[n]:
200 in_degree[d] -= 1
201 if in_degree[d] == 0:
202 nxt.append(d)
203 cur = nxt
205 return layers

    async def execute_concurrent(
        self,
        process_node: Callable[[MigratableProtocol], Awaitable[Any]],
        max_workers: int = 10,
        skip_on_failure: bool = True,
    ) -> dict[uuid.UUID, Any]:
        """
        Execute the DAG concurrently using Kahn's algorithm.

        Processes nodes in topological order while maximizing parallelism.
        When a node completes, its dependents are checked to see if they're
        ready to execute (all dependencies satisfied).

        Args:
            process_node: Async function to process each node
            max_workers: Maximum number of concurrent workers
            skip_on_failure: Whether to skip descendants when a node fails

        Returns:
            Dictionary mapping node IDs to their results (or exceptions)

        Raises:
            ValueError: If the graph contains cycles
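
        Example:
            An illustrative sketch (``transfer_one`` is any async callable
            that accepts a MigratableProtocol); failed and skipped nodes map
            to their exceptions in the result:

                results = await dag.execute_concurrent(transfer_one, max_workers=5)
                failures = {
                    nid: res
                    for nid, res in results.items()
                    if isinstance(res, Exception)
                }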
230 """
231 if self.has_cycles():
232 raise ValueError("Cannot execute DAG with cycles")
234 layers = self.get_execution_layers(_assume_acyclic=True)
235 logger.debug(f"Execution plan has {len(layers)} layers")
236 for i, layer in enumerate(layers):
237 # Count each type in the layer
238 type_counts: dict[str, int] = {}
239 for node in layer:
240 node_type = type(node).__name__
241 type_counts[node_type] = type_counts.get(node_type, 0) + 1
243 type_summary = ", ".join(
244 [f"{count} {type_name}" for type_name, count in type_counts.items()]
245 )
246 logger.debug(f"Layer {i}: ({type_summary})")
248 # Initialize with nodes that have no dependencies
249 ready_queue: list[uuid.UUID] = []
250 for nid in self.nodes:
251 if not self._dependencies[nid]:
252 ready_queue.append(nid)
253 self._status[nid].state = NodeState.READY
255 results: dict[uuid.UUID, Any] = {}
256 limiter = anyio.CapacityLimiter(max_workers)
257 processing: set[uuid.UUID] = set()
259 async def worker(nid: uuid.UUID, tg: TaskGroup):
260 """Process a single node."""
261 node = self.nodes[nid]
263 # Check if node was skipped after being queued
264 if self._status[nid].state != NodeState.READY:
265 logger.debug(f"Node {node} was skipped before execution")
266 return
268 async with limiter:
269 try:
270 self._status[nid].state = NodeState.IN_PROGRESS
271 logger.debug(f"Processing {node}")
273 res = await process_node(node)
274 results[nid] = res
276 self._status[nid].state = NodeState.COMPLETED
277 logger.debug(f"Completed {node}")
279 # Mark dependents as ready if all their dependencies are satisfied
280 async with self._lock:
281 for did in self._status[nid].dependents:
282 dst = self._status[did]
283 if dst.state == NodeState.PENDING:
284 if all(
285 self._status[d].state == NodeState.COMPLETED
286 for d in dst.dependencies
287 ):
288 dst.state = NodeState.READY
289 # Start the newly ready task immediately
290 if did not in processing:
291 processing.add(did)
292 tg.start_soon(worker, did, tg)
294 except TransferSkipped as e:
295 results[nid] = e
296 self._status[nid].state = NodeState.SKIPPED
297 self._status[nid].error = e
298 logger.debug(f"Skipped {node}: {e}")
300 except Exception as e:
301 results[nid] = e
302 self._status[nid].state = NodeState.FAILED
303 self._status[nid].error = e
304 logger.debug(f"Failed to process {node}: {e}")
306 if skip_on_failure:
307 # Skip all descendants of the failed node
308 to_skip = deque([nid])
309 seen_failed: set[uuid.UUID] = set()
311 while to_skip:
312 cur = to_skip.popleft()
313 if cur in seen_failed:
314 continue
315 seen_failed.add(cur)
317 for did in self._status[cur].dependents:
318 st = self._status[did]
319 # Skip nodes that haven't started yet
320 if st.state in {NodeState.PENDING, NodeState.READY}:
321 st.state = NodeState.SKIPPED
322 results[did] = TransferSkipped(
323 "Skipped due to upstream resource failure"
324 )
325 logger.debug(
326 f"Skipped {self.nodes[did]} due to upstream failure"
327 )
328 to_skip.append(did)
329 finally:
330 processing.discard(nid)
332 async with create_task_group() as tg:
333 # Start processing all initially ready nodes
334 for nid in ready_queue:
335 if self._status[nid].state == NodeState.READY:
336 processing.add(nid)
337 tg.start_soon(worker, nid, tg)
339 return results

    def get_statistics(self) -> dict[str, Any]:
        """
        Get statistics about the DAG structure.

        Returns:
            Dictionary with node counts, edge counts, and cycle detection
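
        Example:
            Illustrative output for a three-node chain where c depends on b
            and b depends on a:

                {"total_nodes": 3, "total_edges": 2, "max_in_degree": 1,
                 "max_out_degree": 1, "has_cycles": False}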
347 """
348 deps = self._dependencies
349 return {
350 "total_nodes": len(self.nodes),
351 "total_edges": sum(len(v) for v in deps.values()),
352 "max_in_degree": max((len(deps[n]) for n in self.nodes), default=0),
353 "max_out_degree": max(
354 (len(self._dependents[n]) for n in self.nodes), default=0
355 ),
356 "has_cycles": self.has_cycles(),
357 }
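

# Illustrative, self-contained demo (an addition for exposition, not part of
# the original module): a stub resource implements just the two members of
# MigratableProtocol that this engine actually uses -- ``source_id`` and
# ``get_dependencies()`` -- and is wired through build_from_roots and
# execute_concurrent.
if __name__ == "__main__":

    class _StubResource:
        """Minimal stand-in for a MigratableProtocol resource."""

        def __init__(self, name: str, deps: list["_StubResource"] | None = None):
            self.name = name
            self.source_id = uuid.uuid4()
            self._deps = deps or []

        async def get_dependencies(self) -> list["_StubResource"]:
            return self._deps

        def __repr__(self) -> str:
            return f"<_StubResource {self.name}>"

    async def _demo() -> None:
        # deployment depends on flow, so flow transfers first
        flow = _StubResource("flow")
        deployment = _StubResource("deployment", deps=[flow])

        dag = TransferDAG()
        await dag.build_from_roots([deployment])

        async def transfer(node: MigratableProtocol) -> str:
            return f"transferred {node}"

        print(dag.get_statistics())
        print(await dag.execute_concurrent(transfer, max_workers=2))

    asyncio.run(_demo())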