Coverage for /usr/local/lib/python3.12/site-packages/prefect/cli/transfer/_dag.py: 15%

174 statements  

coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Execution DAG for managing resource transfer dependencies. 

3 

4This module provides a pure execution engine that: 

5- Stores nodes by UUID for deduplication 

6- Implements Kahn's algorithm for topological sorting 

7- Manages concurrent execution with worker pools 

8- Handles failure propagation (skip descendants) 

9""" 

10 

11from __future__ import annotations 1a

12 

13import asyncio 1a

14import uuid 1a

15from collections import defaultdict, deque 1a

16from dataclasses import dataclass, field 1a

17from enum import Enum 1a

18from typing import Any, Awaitable, Callable, Coroutine, Sequence 1a

19 

20import anyio 1a

21from anyio import create_task_group 1a

22from anyio.abc import TaskGroup 1a

23 

24from prefect.cli.transfer._exceptions import TransferSkipped 1a

25from prefect.cli.transfer._migratable_resources import MigratableProtocol 1a

26from prefect.logging import get_logger 1a

27 

28logger = get_logger(__name__) 1a

29 

30 

class NodeState(Enum):
    """State of a node during traversal."""

    PENDING = "pending"
    READY = "ready"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class NodeStatus:
    """Tracks the status of a node during traversal."""

    node: MigratableProtocol
    state: NodeState = NodeState.PENDING
    dependents: set[uuid.UUID] = field(default_factory=set)
    dependencies: set[uuid.UUID] = field(default_factory=set)
    error: Exception | None = None


class TransferDAG:
    """
    Execution DAG for managing resource transfer dependencies.

    Uses Kahn's algorithm for topological sorting and concurrent execution.
    See: https://en.wikipedia.org/wiki/Topological_sorting#Kahn%27s_algorithm

    The DAG ensures resources are transferred in dependency order while
    maximizing parallelism for independent resources.
    """

    def __init__(self):
        self.nodes: dict[uuid.UUID, MigratableProtocol] = {}
        self._dependencies: dict[uuid.UUID, set[uuid.UUID]] = defaultdict(set)
        self._dependents: dict[uuid.UUID, set[uuid.UUID]] = defaultdict(set)
        self._status: dict[uuid.UUID, NodeStatus] = {}
        self._lock = asyncio.Lock()

    def add_node(self, node: MigratableProtocol) -> uuid.UUID:
        """
        Add a node to the graph, deduplicating by source ID.

        Args:
            node: Resource to add to the graph

        Returns:
            The node's source UUID
        """
        if node.source_id not in self.nodes:
            self.nodes[node.source_id] = node
            self._status[node.source_id] = NodeStatus(node)
        return node.source_id

    def add_edge(self, dependent_id: uuid.UUID, dependency_id: uuid.UUID) -> None:
        """
        Add a dependency edge where dependent depends on dependency.

        Args:
            dependent_id: ID of the resource that has a dependency
            dependency_id: ID of the resource being depended upon
        """
        if dependency_id in self._dependencies[dependent_id]:
            return

        self._dependencies[dependent_id].add(dependency_id)
        self._dependents[dependency_id].add(dependent_id)

        self._status[dependent_id].dependencies.add(dependency_id)
        self._status[dependency_id].dependents.add(dependent_id)

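    # A minimal sketch of manual graph construction, assuming ``flow`` and
    # ``deployment`` are hypothetical MigratableProtocol objects where the
    # deployment depends on the flow:
    #
    #     dag = TransferDAG()
    #     flow_id = dag.add_node(flow)
    #     deployment_id = dag.add_node(deployment)
    #     dag.add_edge(dependent_id=deployment_id, dependency_id=flow_id)
    #
    # add_node deduplicates by source_id, so re-adding a shared dependency is
    # a no-op, and add_edge ignores duplicate edges.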

    async def build_from_roots(self, roots: Sequence[MigratableProtocol]) -> None:
        """
        Build the graph from root resources by recursively discovering dependencies.

        Args:
            roots: Collection of root resources to start discovery from
        """
        visited: set[uuid.UUID] = set()

        async def visit(resource: MigratableProtocol):
            if resource.source_id in visited:
                return
            visited.add(resource.source_id)

            rid = self.add_node(resource)

            visit_coroutines: list[Coroutine[Any, Any, None]] = []
            for dep in await resource.get_dependencies():
                did = self.add_node(dep)
                self.add_edge(rid, did)
                visit_coroutines.append(visit(dep))
            await asyncio.gather(*visit_coroutines)

        visit_coroutines = [visit(r) for r in roots]
        await asyncio.gather(*visit_coroutines)

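    # A minimal sketch of graph discovery, assuming ``roots`` is a list of
    # hypothetical MigratableProtocol objects whose get_dependencies() returns
    # further migratable resources:
    #
    #     dag = TransferDAG()
    #     await dag.build_from_roots(roots)
    #
    # Dependencies shared by multiple roots are visited once: the ``visited``
    # set short-circuits repeat traversals and add_node deduplicates by
    # source_id.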

    def has_cycles(self) -> bool:
        """
        Check if the graph has cycles using three-color DFS.

        Uses the classic three-color algorithm where:
        - WHITE (0): Unvisited node
        - GRAY (1): Currently being explored (in DFS stack)
        - BLACK (2): Fully explored

        A cycle exists if we encounter a GRAY node during traversal (back edge).
        See: https://en.wikipedia.org/wiki/Depth-first_search#Vertex_orderings

        Returns:
            True if the graph contains cycles, False otherwise
        """
        WHITE, GRAY, BLACK = 0, 1, 2
        color = {node_id: WHITE for node_id in self.nodes}

        def visit(node_id: uuid.UUID) -> bool:
            if color[node_id] == GRAY:
                return True  # Back edge found - cycle detected
            if color[node_id] == BLACK:
                return False  # Already fully explored

            color[node_id] = GRAY
            for dep_id in self._dependencies[node_id]:
                if visit(dep_id):
                    return True
            color[node_id] = BLACK
            return False

        for node_id in self.nodes:
            if color[node_id] == WHITE:
                if visit(node_id):
                    return True
        return False

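    # A minimal sketch of cycle detection, assuming ``a_id`` and ``b_id`` are
    # hypothetical node IDs already added to the graph; mutual dependencies
    # create a back edge, so has_cycles() returns True:
    #
    #     dag.add_edge(dependent_id=a_id, dependency_id=b_id)
    #     dag.add_edge(dependent_id=b_id, dependency_id=a_id)
    #     assert dag.has_cycles()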

    def get_execution_layers(
        self, *, _assume_acyclic: bool = False
    ) -> list[list[MigratableProtocol]]:
        """
        Get execution layers using Kahn's algorithm.

        Each layer contains nodes that can be executed in parallel.
        Kahn's algorithm repeatedly removes nodes with no dependencies,
        forming layers of concurrent work.

        See: https://en.wikipedia.org/wiki/Topological_sorting#Kahn%27s_algorithm

        Args:
            _assume_acyclic: Skip cycle check if caller already verified

        Returns:
            List of layers, each containing nodes that can run in parallel

        Raises:
            ValueError: If the graph contains cycles
        """
        if not _assume_acyclic and self.has_cycles():
            raise ValueError("Cannot sort DAG with cycles")

        in_degree = {n: len(self._dependencies[n]) for n in self.nodes}

        layers: list[list[MigratableProtocol]] = []
        cur = [n for n in self.nodes if in_degree[n] == 0]

        while cur:
            layers.append([self.nodes[n] for n in cur])
            nxt: list[uuid.UUID] = []
            for n in cur:
                for d in self._dependents[n]:
                    in_degree[d] -= 1
                    if in_degree[d] == 0:
                        nxt.append(d)
            cur = nxt

        return layers

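    # A minimal sketch of the layering, assuming a hypothetical diamond graph
    # where B and C depend on A, and D depends on both B and C:
    #
    #     layers = dag.get_execution_layers()
    #     # [[A], [B, C], [D]]
    #
    # A is transferred first, B and C may run in parallel, and D runs last.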

    async def execute_concurrent(
        self,
        process_node: Callable[[MigratableProtocol], Awaitable[Any]],
        max_workers: int = 10,
        skip_on_failure: bool = True,
    ) -> dict[uuid.UUID, Any]:
        """
        Execute the DAG concurrently using Kahn's algorithm.

        Processes nodes in topological order while maximizing parallelism.
        When a node completes, its dependents are checked to see if they're
        ready to execute (all dependencies satisfied).

        Args:
            process_node: Async function to process each node
            max_workers: Maximum number of concurrent workers
            skip_on_failure: Whether to skip descendants when a node fails

        Returns:
            Dictionary mapping node IDs to their results (or exceptions)

        Raises:
            ValueError: If the graph contains cycles
        """
        if self.has_cycles():
            raise ValueError("Cannot execute DAG with cycles")

        layers = self.get_execution_layers(_assume_acyclic=True)
        logger.debug(f"Execution plan has {len(layers)} layers")
        for i, layer in enumerate(layers):
            # Count each type in the layer
            type_counts: dict[str, int] = {}
            for node in layer:
                node_type = type(node).__name__
                type_counts[node_type] = type_counts.get(node_type, 0) + 1

            type_summary = ", ".join(
                [f"{count} {type_name}" for type_name, count in type_counts.items()]
            )
            logger.debug(f"Layer {i}: ({type_summary})")

        # Initialize with nodes that have no dependencies
        ready_queue: list[uuid.UUID] = []
        for nid in self.nodes:
            if not self._dependencies[nid]:
                ready_queue.append(nid)
                self._status[nid].state = NodeState.READY

        results: dict[uuid.UUID, Any] = {}
        limiter = anyio.CapacityLimiter(max_workers)
        processing: set[uuid.UUID] = set()

        async def worker(nid: uuid.UUID, tg: TaskGroup):
            """Process a single node."""
            node = self.nodes[nid]

            # Check if node was skipped after being queued
            if self._status[nid].state != NodeState.READY:
                logger.debug(f"Node {node} was skipped before execution")
                return

            async with limiter:
                try:
                    self._status[nid].state = NodeState.IN_PROGRESS
                    logger.debug(f"Processing {node}")

                    res = await process_node(node)
                    results[nid] = res

                    self._status[nid].state = NodeState.COMPLETED
                    logger.debug(f"Completed {node}")

                    # Mark dependents as ready if all their dependencies are satisfied
                    async with self._lock:
                        for did in self._status[nid].dependents:
                            dst = self._status[did]
                            if dst.state == NodeState.PENDING:
                                if all(
                                    self._status[d].state == NodeState.COMPLETED
                                    for d in dst.dependencies
                                ):
                                    dst.state = NodeState.READY
                                    # Start the newly ready task immediately
                                    if did not in processing:
                                        processing.add(did)
                                        tg.start_soon(worker, did, tg)

                except TransferSkipped as e:
                    results[nid] = e
                    self._status[nid].state = NodeState.SKIPPED
                    self._status[nid].error = e
                    logger.debug(f"Skipped {node}: {e}")

                except Exception as e:
                    results[nid] = e
                    self._status[nid].state = NodeState.FAILED
                    self._status[nid].error = e
                    logger.debug(f"Failed to process {node}: {e}")

                    if skip_on_failure:
                        # Skip all descendants of the failed node
                        to_skip = deque([nid])
                        seen_failed: set[uuid.UUID] = set()

                        while to_skip:
                            cur = to_skip.popleft()
                            if cur in seen_failed:
                                continue
                            seen_failed.add(cur)

                            for did in self._status[cur].dependents:
                                st = self._status[did]
                                # Skip nodes that haven't started yet
                                if st.state in {NodeState.PENDING, NodeState.READY}:
                                    st.state = NodeState.SKIPPED
                                    results[did] = TransferSkipped(
                                        "Skipped due to upstream resource failure"
                                    )
                                    logger.debug(
                                        f"Skipped {self.nodes[did]} due to upstream failure"
                                    )
                                    to_skip.append(did)
                finally:
                    processing.discard(nid)

        async with create_task_group() as tg:
            # Start processing all initially ready nodes
            for nid in ready_queue:
                if self._status[nid].state == NodeState.READY:
                    processing.add(nid)
                    tg.start_soon(worker, nid, tg)

        return results

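    # A minimal usage sketch, assuming ``transfer_one`` is a hypothetical
    # coroutine that migrates a single resource:
    #
    #     results = await dag.execute_concurrent(transfer_one, max_workers=5)
    #     for node_id, result in results.items():
    #         if isinstance(result, TransferSkipped):
    #             ...  # skipped, e.g. because an upstream resource failed
    #         elif isinstance(result, Exception):
    #             ...  # this resource itself failed
    #
    # Exceptions are recorded as values instead of being raised, so one failed
    # resource does not abort the rest of the transfer.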

    def get_statistics(self) -> dict[str, Any]:
        """
        Get statistics about the DAG structure.

        Returns:
            Dictionary with node counts, edge counts, and cycle detection
        """
        deps = self._dependencies
        return {
            "total_nodes": len(self.nodes),
            "total_edges": sum(len(v) for v in deps.values()),
            "max_in_degree": max((len(deps[n]) for n in self.nodes), default=0),
            "max_out_degree": max(
                (len(self._dependents[n]) for n in self.nodes), default=0
            ),
            "has_cycles": self.has_cycles(),
        }
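
    # A minimal usage sketch: the statistics are plain values, convenient for
    # logging before a transfer starts:
    #
    #     stats = dag.get_statistics()
    #     logger.debug(
    #         "Transferring %d resources across %d dependency edges",
    #         stats["total_nodes"],
    #         stats["total_edges"],
    #     )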