Coverage for /usr/local/lib/python3.12/site-packages/prefect/_experimental/bundles/__init__.py: 18%

229 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import ast 1a

4import asyncio 1a

5import base64 1a

6import gzip 1a

7import importlib 1a

8import inspect 1a

9import json 1a

10import logging 1a

11import multiprocessing 1a

12import multiprocessing.context 1a

13import os 1a

14import subprocess 1a

15import sys 1a

16from contextlib import contextmanager 1a

17from pathlib import Path 1a

18import tempfile 1a

19from types import ModuleType 1a

20from typing import Any, TypedDict 1a

21 

22import anyio 1a

23import cloudpickle # pyright: ignore[reportMissingTypeStubs] 1a

24 

25from prefect.client.schemas.objects import FlowRun 1a

26from prefect.context import SettingsContext, get_settings_context, serialize_context 1a

27from prefect.engine import handle_engine_signals 1a

28from prefect.flow_engine import run_flow 1a

29from prefect.flows import Flow 1a

30from prefect.logging import get_logger 1a

31from prefect.settings.context import get_current_settings 1a

32from prefect.settings.models.root import Settings 1a

33from prefect.utilities.slugify import slugify 1a

34 

35from .execute import execute_bundle_from_file 1a

36 

37logger: logging.Logger = get_logger(__name__) 1a

38 

39 

def _get_uv_path() -> str:
    """
    Locate the `uv` executable to invoke.

    Returns:
        The path reported by `uv.find_uv_bin()` when the `uv` Python package
        can be imported and finds its binary; otherwise the bare command name
        `"uv"`, leaving resolution to the `PATH` lookup of whoever runs it.
    """
    try:
        import uv

        return uv.find_uv_bin()
    except (ImportError, FileNotFoundError):
        # ImportError also covers ModuleNotFoundError (its subclass).
        return "uv"

49 

50 

class SerializedBundle(TypedDict):
    """
    A serialized bundle is a serialized function, context, and flow run that can be
    easily transported for later execution.

    Instances are plain dictionaries; `create_bundle_for_flow_run` builds them and
    the upload helpers in this module write them out with `json.dumps`.
    """

    # The flow object, encoded by `_serialize_bundle_object`
    # (cloudpickle -> gzip -> base64 text).
    function: str
    # The run context, encoded the same way as `function`.
    context: str
    # The flow run as produced by `FlowRun.model_dump(mode="json")`.
    flow_run: dict[str, Any]
    # Newline-separated requirement lines taken from `uv pip freeze`, with
    # `file://` entries removed.
    dependencies: str

61 

62 

def _serialize_bundle_object(obj: Any) -> str:
    """
    Encode an arbitrary object as text suitable for embedding in a bundle.

    The object is pickled with cloudpickle, gzip-compressed, and then
    base64-encoded so the result is a plain string.

    Args:
        obj: The object to encode.

    Returns:
        The base64 text of the compressed pickle.
    """
    pickled = cloudpickle.dumps(obj)  # pyright: ignore[reportUnknownMemberType]
    compressed = gzip.compress(pickled)
    return base64.b64encode(compressed).decode()

68 

69 

def _deserialize_bundle_object(serialized_obj: str) -> Any:
    """
    Decode an object previously encoded by `_serialize_bundle_object`.

    Reverses the encoding steps: base64-decode, gunzip, then unpickle.

    Args:
        serialized_obj: The base64 text to decode.

    Returns:
        The reconstructed object.
    """
    compressed = base64.b64decode(serialized_obj)
    pickled = gzip.decompress(compressed)
    # NOTE(review): unpickling can execute arbitrary code, so bundles must only
    # be loaded from storage that is trusted.
    return cloudpickle.loads(pickled)

75 

76 

def _is_local_module(module_name: str, module_path: str | None = None) -> bool:
    """
    Decide whether a module is "local", i.e. neither standard library nor an
    installed package.

    Args:
        module_name: The dotted name of the module.
        module_path: Optional path to the module file. When omitted, the
            module is imported to look up its `__file__`.

    Returns:
        True if the module is considered local, False otherwise. A module that
        cannot be imported (when no path was supplied) is reported as not local.
    """
    # Name prefixes that are never bundled: caches, test tooling, and prefect
    # itself (it will be available on the remote side).
    excluded_prefixes = (
        "__pycache__",
        "unittest",
        "pytest",
        "test_",
        "_pytest",
        "prefect",
    )
    if module_name.startswith(excluded_prefixes):
        return False

    # Interpreter built-ins
    if module_name in sys.builtin_module_names:
        return False

    # Standard library (sys.stdlib_module_names only exists on Python 3.10+);
    # match on either the full dotted name or its top-level package.
    stdlib_names = getattr(sys, "stdlib_module_names", frozenset())
    top_level_name = module_name.split(".")[0]
    if module_name in stdlib_names or top_level_name in stdlib_names:
        return False

    # Work out where the module lives, importing it only if no path was given
    location = module_path
    if not location:
        try:
            location = getattr(importlib.import_module(module_name), "__file__", None)
        except (ImportError, AttributeError):
            # If we can't import it, it's probably not a real module
            return False

    # Path fragments that indicate an installed package, the standard library,
    # or a virtual environment
    non_local_markers = ("site-packages", "dist-packages", "/lib/python", "/.venv/")
    if location and any(marker in str(location) for marker in non_local_markers):
        return False

    return True

149 

150 

def _resolve_relative_import(node: ast.ImportFrom, package: str) -> set[str]:
    """Resolve a relative `from ... import ...` node to absolute module names."""
    from importlib.util import resolve_name

    try:
        base = resolve_name("." * node.level + (node.module or ""), package)
    except (ImportError, ValueError):
        # No parent package, or the import climbs above the top-level package
        return set()

    if node.module:
        # `from .mod import name`: track the module only; the imported names
        # might be classes or functions rather than modules.
        return {base}

    # `from . import a, b`: each name may be a sibling submodule. Names that
    # are not modules simply fail to import later and are skipped there.
    return {f"{base}.{alias.name}" for alias in node.names if alias.name != "*"}


def _extract_imports_from_source(
    source_code: str, package: str | None = None
) -> set[str]:
    """
    Extract the names of all modules imported by Python source code.

    The whole syntax tree is walked, so imports nested inside functions,
    classes, or conditional blocks are included as well.

    Args:
        source_code: The Python source code to analyze.
        package: Optional name of the package the source belongs to (the
            module's `__package__`). When given, relative imports such as
            `from .utils import x` or `from . import sibling` are resolved to
            absolute module names against it, and relative imports that cannot
            be resolved are dropped. When omitted, the previous behavior is
            kept: leading dots are ignored, so `from .utils import x` is
            reported as `utils` and `from . import sibling` is not reported.

    Returns:
        A set of imported module names. Source that fails to parse yields an
        empty set.
    """
    imports: set[str] = set()

    try:
        tree = ast.parse(source_code)
    except SyntaxError:
        logger.debug("Failed to parse source code for import extraction")
        return imports

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imports.update(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom):
            if node.level and package is not None:
                imports.update(_resolve_relative_import(node, package))
            elif node.module:
                # Don't add individual imported items as they might be
                # classes/functions. Only track the module itself.
                imports.add(node.module)

    return imports


def _discover_local_dependencies(
    flow: Flow[Any, Any], visited: set[str] | None = None
) -> set[str]:
    """
    Recursively discover local module dependencies of a flow.

    Args:
        flow: The flow to analyze. Only its `fn` attribute is inspected.
        visited: Set of already visited modules to avoid infinite recursion.
            It is updated in place with every module name examined.

    Returns:
        A set of local module names that should be serialized by value. The
        set is empty when the module defining the flow cannot be determined.
    """
    if visited is None:
        visited = set()

    local_modules: set[str] = set()

    # Get the module containing the flow
    try:
        flow_module = inspect.getmodule(flow.fn)
    except (AttributeError, TypeError):
        # Flow function doesn't have a module (e.g., defined in REPL)
        return local_modules

    if not flow_module:
        return local_modules

    # Process the flow's module and all its dependencies recursively
    _process_module_dependencies(
        flow_module, flow_module.__name__, local_modules, visited
    )

    return local_modules


def _process_module_dependencies(
    module: ModuleType,
    module_name: str,
    local_modules: set[str],
    visited: set[str],
) -> None:
    """
    Recursively process a module and discover its local dependencies.

    Args:
        module: The module to process.
        module_name: The name of the module.
        local_modules: Set to accumulate discovered local modules.
        visited: Set of already visited modules to avoid infinite recursion.
    """
    # Skip if we've already processed this module
    if module_name in visited:
        return
    visited.add(module_name)

    # Only modules backed by a file that is judged local are followed
    module_file = getattr(module, "__file__", None)
    if not module_file or not _is_local_module(module_name, module_file):
        return

    local_modules.add(module_name)

    try:
        source_code = inspect.getsource(module)
    except (OSError, TypeError):
        # Can't get source for this module
        return

    # Relative imports are anchored at the module's own package. `__package__`
    # is used in preference to trimming the module name because, for a package
    # `__init__`, the anchor is the package itself rather than its parent.
    package = getattr(module, "__package__", None)
    if package is None:
        package = module_name.rpartition(".")[0]

    for import_name in _extract_imports_from_source(source_code, package=package):
        # Skip if already visited
        if import_name in visited:
            continue

        # Names are absolute at this point, so a plain import is sufficient
        try:
            imported_module = importlib.import_module(import_name)
        except (ImportError, AttributeError):
            # Can't import, skip it
            continue

        # Recursively process this imported module
        _process_module_dependencies(
            imported_module, import_name, local_modules, visited
        )

280 

281 

@contextmanager
def _pickle_local_modules_by_value(flow: Flow[Any, Any]):
    """
    Context manager that registers local modules for pickle-by-value serialization.

    On entry, the flow's local module dependencies are discovered and each one
    is registered with cloudpickle. On exit, only the modules registered here
    are unregistered again. Modules that were already registered before entry
    are left untouched, so a caller's own registrations survive.

    Args:
        flow: The flow whose dependencies should be registered.
    """
    registered_modules: list[ModuleType] = []

    try:
        # Discover local dependencies
        local_modules = _discover_local_dependencies(flow)
        logger.debug("Local modules: %s", local_modules)

        if local_modules:
            logger.debug(
                "Registering local modules for pickle-by-value serialization: %s",
                ", ".join(sorted(local_modules)),
            )

            # Snapshot of module names that were registered before we started;
            # these must not be unregistered when the context exits.
            already_registered = cloudpickle.list_registry_pickle_by_value()  # pyright: ignore[reportUnknownMemberType] Missing stubs

            # Register each local module for pickle-by-value
            for module_name in sorted(local_modules):
                try:
                    module = importlib.import_module(module_name)
                    if module.__name__ in already_registered:
                        continue
                    cloudpickle.register_pickle_by_value(module)  # pyright: ignore[reportUnknownMemberType] Missing stubs
                    registered_modules.append(module)
                except (ImportError, AttributeError) as e:
                    logger.debug(
                        "Failed to register module %s for pickle-by-value: %s",
                        module_name,
                        e,
                    )

        yield

    finally:
        # Unregister all modules we registered
        for module in registered_modules:
            try:
                cloudpickle.unregister_pickle_by_value(module)  # pyright: ignore[reportUnknownMemberType] Missing stubs
            except Exception as e:
                # Best-effort cleanup: a failure here must not mask an error
                # raised inside the `with` body.
                logger.debug(
                    "Failed to unregister module %s from pickle-by-value: %s",
                    getattr(module, "__name__", module),
                    e,
                )

329 

330 

def create_bundle_for_flow_run(
    flow: Flow[Any, Any],
    flow_run: FlowRun,
    context: dict[str, Any] | None = None,
) -> SerializedBundle:
    """
    Build a transportable bundle for a flow run.

    The bundle carries the serialized flow, the serialized context, the flow
    run as JSON-compatible data, and the requirement lines reported by
    `uv pip freeze` for the current environment.

    Args:
        flow: The flow to bundle.
        flow_run: The flow run to bundle.
        context: The context to use when running the flow. When not provided
            (or empty), the result of `serialize_context()` is used.

    Returns:
        A serialized bundle.
    """
    context = context or serialize_context()

    freeze_command = [
        _get_uv_path(),
        "pip",
        "freeze",
        # Exclude editable installs because we won't be able to install them in the execution environment
        "--exclude-editable",
    ]
    frozen_output = subprocess.check_output(freeze_command).decode().strip()
    requirement_lines = frozen_output.split("\n")

    # Requirements installed from a local file path can't be installed in the
    # execution environment, so they are split off and reported instead. The
    # user is responsible for making them available there.
    file_dependencies = [line for line in requirement_lines if "file://" in line]
    dependencies = "\n".join(
        line for line in requirement_lines if "file://" not in line
    )
    if file_dependencies:
        logger.warning(
            "The following dependencies were installed from a local file path and will not be "
            "automatically installed in the execution environment: %s. If these dependencies "
            "are not available in the execution environment, your flow run may fail.",
            "\n".join(file_dependencies),
        )

    # Local modules are pickled by value only while the bundle is being built
    with _pickle_local_modules_by_value(flow):
        serialized_flow = _serialize_bundle_object(flow)
        serialized_context = _serialize_bundle_object(context)
        flow_run_data = flow_run.model_dump(mode="json")

    return {
        "function": serialized_flow,
        "context": serialized_context,
        "flow_run": flow_run_data,
        "dependencies": dependencies,
    }

390 

391 

def extract_flow_from_bundle(bundle: SerializedBundle) -> Flow[Any, Any]:
    """
    Recover the flow object stored in a bundle.

    Args:
        bundle: The bundle whose `"function"` entry should be decoded.

    Returns:
        The deserialized flow.
    """
    serialized_flow = bundle["function"]
    return _deserialize_bundle_object(serialized_flow)

397 

398 

def _extract_and_run_flow(
    bundle: SerializedBundle,
    cwd: Path | str | None = None,
    env: dict[str, Any] | None = None,
) -> None:
    """
    Decode the flow, context, and flow run from a bundle and execute the flow.

    Designed to be run in a subprocess: it mutates `os.environ` and may change
    the process working directory.

    Args:
        bundle: The bundle to extract and run.
        cwd: The working directory to switch to before running the flow.
        env: Environment variables to apply to this process before running.
    """
    if env:
        os.environ.update(env)
    # TODO: make this a thing we can pass directly to the engine
    os.environ["PREFECT__ENABLE_CANCELLATION_AND_CRASHED_HOOKS"] = "false"
    settings_context = get_settings_context()

    flow = _deserialize_bundle_object(bundle["function"])
    context = _deserialize_bundle_object(bundle["context"])
    flow_run = FlowRun.model_validate(bundle["flow_run"])

    if cwd:
        os.chdir(cwd)

    # A fresh Settings() is built after the environment update above, while the
    # profile of the current settings context is kept.
    run_settings = SettingsContext(
        profile=settings_context.profile,
        settings=Settings(),
    )
    with run_settings, handle_engine_signals(flow_run.id):
        result = run_flow(
            flow=flow,
            flow_run=flow_run,
            context=context,
        )
        if asyncio.iscoroutine(result):
            # This is running in a brand new process, so there won't be an
            # existing event loop.
            asyncio.run(result)

441 

442 

def execute_bundle_in_subprocess(
    bundle: SerializedBundle,
    env: dict[str, Any] | None = None,
    cwd: Path | str | None = None,
) -> multiprocessing.context.SpawnProcess:
    """
    Start a spawned process that runs the flow contained in a bundle.

    If the bundle lists dependencies, they are installed with `uv pip install`
    before the process is started.

    Args:
        bundle: The bundle to execute.
        env: Extra environment variables for the child process. They take
            precedence over both the current settings and `os.environ`.
        cwd: The working directory handed to the child process.

    Returns:
        The started multiprocessing.context.SpawnProcess.
    """
    extra_env = env or {}

    # Install dependencies if necessary
    requirements = bundle.get("dependencies")
    if requirements:
        install_command = [
            _get_uv_path(),
            "pip",
            "install",
            *requirements.split("\n"),
        ]
        # Pass the current environment to ensure we install into the correct venv
        subprocess.check_call(install_command, env=os.environ)

    # Later mappings win: current settings < os.environ < caller-supplied env
    settings_env = get_current_settings().to_environment_variables(
        exclude_unset=True
    )
    child_env = settings_env | os.environ | extra_env

    spawn_context = multiprocessing.get_context("spawn")
    process = spawn_context.Process(
        target=_extract_and_run_flow,
        kwargs={"bundle": bundle, "env": child_env, "cwd": cwd},
    )
    process.start()

    return process

483 

484 

def convert_step_to_command(
    step: dict[str, Any], key: str, quiet: bool = False
) -> list[str]:
    """
    Converts a bundle upload or execution step to a command.

    The resulting command has the form
    `uv run [--quiet] [--with <requires>] --python <major.minor> -m <function>
    [--<arg> <value> ...] --key <key>`.

    Args:
        step: The step to convert: a mapping with exactly one entry, from the
            module to run to its arguments. The arguments may be `None` or
            empty when the step takes none. A `requires` argument (a string or
            a list of strings) is turned into `--with` instead of being passed
            through.
        key: The key to use for the remote file when downloading or uploading.
        quiet: Whether to suppress `uv` output from the command.

    Returns:
        A list of strings representing the command to run the step.

    Raises:
        ValueError: If the step does not contain exactly one entry.
    """
    if len(step) != 1:
        raise ValueError("Expected exactly one function in step")

    function_fqn, function_args = next(iter(step.items()))
    # A step declared without arguments has a `None` value; treat it the same
    # as an empty argument mapping.
    function_args = function_args or {}

    # Start with uv run
    command = ["uv", "run"]

    if quiet:
        command.append("--quiet")

    # Add the `--with` argument to handle dependencies for running the step
    requires: list[str] | str = function_args.get("requires", [])
    if isinstance(requires, str):
        requires = [requires]
    if requires:
        command.extend(["--with", ",".join(requires)])

    # Add the `--python` argument so the step runs on the same Python version
    # as the current interpreter
    python_version = sys.version_info
    command.extend(["--python", f"{python_version.major}.{python_version.minor}"])

    # Add the `-m` argument to define the function to run
    command.extend(["-m", function_fqn])

    # Add any arguments with values defined in the step
    for arg_name, arg_value in function_args.items():
        if arg_name == "requires":
            continue

        command.extend([f"--{slugify(arg_name)}", arg_value])

    # Add the `--key` argument to specify the remote file name
    command.extend(["--key", key])

    return command

538 

539 

def upload_bundle_to_storage(
    bundle: SerializedBundle, key: str, upload_command: list[str]
) -> None:
    """
    Uploads a bundle to storage.

    The bundle is written as JSON to a file named `key` inside a temporary
    directory, and `upload_command` (with `key` appended as its last argument)
    is run with that directory as its working directory.

    Args:
        bundle: The serialized bundle to upload.
        key: The key to use for the remote file when uploading.
        upload_command: The command to use to upload the bundle as a list of strings.

    Raises:
        RuntimeError: If the upload command exits with a non-zero status. The
            original `subprocess.CalledProcessError` is attached as the cause.
    """
    # Write the bundle to a temporary directory so it can be uploaded to the bundle storage
    # via the upload command
    with tempfile.TemporaryDirectory() as temp_dir:
        Path(temp_dir).joinpath(key).write_bytes(json.dumps(bundle).encode("utf-8"))

        full_command = upload_command + [key]
        logger.debug("Uploading execution bundle with command: %s", full_command)
        try:
            subprocess.check_call(
                full_command,
                cwd=temp_dir,
            )
        except subprocess.CalledProcessError as e:
            # check_call does not capture output, so `e.stderr` is normally
            # None here; fall back to the exception's own description rather
            # than failing while building the error message.
            detail = e.stderr
            if isinstance(detail, bytes):
                detail = detail.decode("utf-8", errors="replace")
            raise RuntimeError(detail or str(e)) from e

565 

566 

async def aupload_bundle_to_storage(
    bundle: SerializedBundle, key: str, upload_command: list[str]
) -> None:
    """
    Asynchronous counterpart of `upload_bundle_to_storage`.

    The bundle is written as JSON to a file named `key` in a temporary
    directory, and `upload_command` (with `key` appended) is run from that
    directory via `anyio.run_process`.

    Args:
        bundle: The serialized bundle to upload.
        key: The key to use for the remote file when uploading.
        upload_command: The command to use to upload the bundle as a list of strings.

    Raises:
        RuntimeError: If a `subprocess.CalledProcessError` is raised while
            running the command; its decoded stderr becomes the message.
    """
    # The upload command picks the bundle up from a temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        payload = json.dumps(bundle).encode("utf-8")
        bundle_file = anyio.Path(temp_dir).joinpath(key)
        await bundle_file.write_bytes(payload)

        try:
            full_command = upload_command + [key]
            logger.debug("Uploading execution bundle with command: %s", full_command)
            await anyio.run_process(full_command, cwd=temp_dir)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(e.stderr.decode("utf-8")) from e

596 

597 

# Names exported by `from ... import *`. `execute_bundle_from_file` is
# re-exported from the sibling `.execute` module.
# NOTE(review): `upload_bundle_to_storage` and `aupload_bundle_to_storage` are
# defined here without a leading underscore but are not listed — confirm
# whether that omission is intentional.
__all__ = [
    "execute_bundle_from_file",
    "convert_step_to_command",
    "create_bundle_for_flow_run",
    "extract_flow_from_bundle",
    "execute_bundle_in_subprocess",
    "SerializedBundle",
]