Coverage for /usr/local/lib/python3.12/site-packages/prefect/_experimental/bundles/__init__.py: 18% of 229 statements
coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

from __future__ import annotations

import ast
import asyncio
import base64
import gzip
import importlib
import inspect
import json
import logging
import multiprocessing
import multiprocessing.context
import os
import subprocess
import sys
from contextlib import contextmanager
from pathlib import Path
import tempfile
from types import ModuleType
from typing import Any, TypedDict

import anyio
import cloudpickle  # pyright: ignore[reportMissingTypeStubs]

from prefect.client.schemas.objects import FlowRun
from prefect.context import SettingsContext, get_settings_context, serialize_context
from prefect.engine import handle_engine_signals
from prefect.flow_engine import run_flow
from prefect.flows import Flow
from prefect.logging import get_logger
from prefect.settings.context import get_current_settings
from prefect.settings.models.root import Settings
from prefect.utilities.slugify import slugify

from .execute import execute_bundle_from_file

logger: logging.Logger = get_logger(__name__)


def _get_uv_path() -> str:
    try:
        import uv

        uv_path = uv.find_uv_bin()
    except (ImportError, ModuleNotFoundError, FileNotFoundError):
        uv_path = "uv"

    return uv_path


class SerializedBundle(TypedDict):
    """
    A serialized bundle is a serialized function, context, and flow run that can be
    easily transported for later execution.
    """

    function: str
    context: str
    flow_run: dict[str, Any]
    dependencies: str
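
# Illustrative sketch of what a SerializedBundle looks like in practice (values
# shortened and hypothetical):
#
#     bundle: SerializedBundle = {
#         "function": "H4sIAAAA...",  # base64-encoded, gzip-compressed cloudpickle of the flow
#         "context": "H4sIAAAA...",   # base64-encoded, gzip-compressed cloudpickle of the run context
#         "flow_run": {"id": "...", "name": "..."},  # FlowRun.model_dump(mode="json")
#         "dependencies": "anyio==4.4.0\nhttpx==0.27.0",  # `uv pip freeze` output (pins hypothetical)
#     }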


def _serialize_bundle_object(obj: Any) -> str:
    """
    Serializes an object to a string.
    """
    return base64.b64encode(gzip.compress(cloudpickle.dumps(obj))).decode()  # pyright: ignore[reportUnknownMemberType]


def _deserialize_bundle_object(serialized_obj: str) -> Any:
    """
    Deserializes an object from a string.
    """
    return cloudpickle.loads(gzip.decompress(base64.b64decode(serialized_obj)))
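
# Illustrative sketch: the two helpers above are inverses, so any object
# cloudpickle can handle round-trips through them unchanged:
#
#     payload = _serialize_bundle_object({"answer": 42})
#     assert isinstance(payload, str)
#     assert _deserialize_bundle_object(payload) == {"answer": 42}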


def _is_local_module(module_name: str, module_path: str | None = None) -> bool:
    """
    Check if a module is a local module (not from standard library or site-packages).

    Args:
        module_name: The name of the module.
        module_path: Optional path to the module file.

    Returns:
        True if the module is a local module, False otherwise.
    """
    # Skip modules that are known to be problematic or not needed
    skip_modules = {
        "__pycache__",
        # Skip test modules
        "unittest",
        "pytest",
        "test_",
        "_pytest",
        # Skip prefect modules - they'll be available on remote
        "prefect",
    }

    # Check module name prefixes
    for skip in skip_modules:
        if module_name.startswith(skip):
            return False

    # Check if it's a built-in module
    if module_name in sys.builtin_module_names:
        return False

    # Check if it's in the standard library (Python 3.10+)
    if hasattr(sys, "stdlib_module_names"):
        # Check both full module name and base module name
        base_module = module_name.split(".")[0]
        if (
            module_name in sys.stdlib_module_names
            or base_module in sys.stdlib_module_names
        ):
            return False

    # If we have the module path, check if it's in site-packages or dist-packages
    if module_path:
        path_str = str(module_path)
        # Also exclude standard library paths
        if (
            "site-packages" in path_str
            or "dist-packages" in path_str
            or "/lib/python" in path_str
            or "/.venv/" in path_str
        ):
            return False
    else:
        # Try to import the module to get its path
        try:
            module = importlib.import_module(module_name)
            if hasattr(module, "__file__") and module.__file__:
                path_str = str(module.__file__)
                if (
                    "site-packages" in path_str
                    or "dist-packages" in path_str
                    or "/lib/python" in path_str
                    or "/.venv/" in path_str
                ):
                    return False
        except (ImportError, AttributeError):
            # If we can't import it, it's probably not a real module
            return False

    # Only consider it local if it exists and we can verify it
    return True
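
# Illustrative sketch of the intended classifications (exact results depend on
# the interpreter's environment; project module names below are hypothetical):
#
#     _is_local_module("json")                                # False: standard library
#     _is_local_module("prefect.flows")                       # False: prefect is skipped
#     _is_local_module("anyio")                               # False: resolves to site-packages
#     _is_local_module("my_helpers", "/app/my_helpers.py")    # True: project-local file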


def _extract_imports_from_source(source_code: str) -> set[str]:
    """
    Extract all import statements from Python source code.

    Args:
        source_code: The Python source code to analyze.

    Returns:
        A set of imported module names.
    """
    imports: set[str] = set()

    try:
        tree = ast.parse(source_code)
    except SyntaxError:
        logger.debug("Failed to parse source code for import extraction")
        return imports

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imports.add(alias.name)
        elif isinstance(node, ast.ImportFrom):
            if node.module:
                imports.add(node.module)
                # Don't add individual imported items as they might be classes/functions
                # Only track the module itself

    return imports
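
# Illustrative sketch: plain and from-imports are collected at module
# granularity only, and relative imports without an explicit module name are
# skipped (module names below are hypothetical):
#
#     _extract_imports_from_source(
#         "import os\nfrom my_helpers.math import add\nfrom . import sibling"
#     )
#     # -> {"os", "my_helpers.math"}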


def _discover_local_dependencies(
    flow: Flow[Any, Any], visited: set[str] | None = None
) -> set[str]:
    """
    Recursively discover local module dependencies of a flow.

    Args:
        flow: The flow to analyze.
        visited: Set of already visited modules to avoid infinite recursion.

    Returns:
        A set of local module names that should be serialized by value.
    """
    if visited is None:
        visited = set()

    local_modules: set[str] = set()

    # Get the module containing the flow
    try:
        flow_module = inspect.getmodule(flow.fn)
    except (AttributeError, TypeError):
        # Flow function doesn't have a module (e.g., defined in REPL)
        return local_modules

    if not flow_module:
        return local_modules

    module_name = flow_module.__name__

    # Process the flow's module and all its dependencies recursively
    _process_module_dependencies(flow_module, module_name, local_modules, visited)

    return local_modules
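
# Illustrative sketch, assuming a flow defined in a local `my_flows.py` that
# imports a sibling `my_helpers.py`: discovery starts at the flow's module and
# follows its imports, keeping only the project-local ones:
#
#     _discover_local_dependencies(my_flow)
#     # -> {"my_flows", "my_helpers"}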


def _process_module_dependencies(
    module: ModuleType,
    module_name: str,
    local_modules: set[str],
    visited: set[str],
) -> None:
    """
    Recursively process a module and discover its local dependencies.

    Args:
        module: The module to process.
        module_name: The name of the module.
        local_modules: Set to accumulate discovered local modules.
        visited: Set of already visited modules to avoid infinite recursion.
    """
    # Skip if we've already processed this module
    if module_name in visited:
        return
    visited.add(module_name)

    # Check if this is a local module
    module_file = getattr(module, "__file__", None)
    if not module_file or not _is_local_module(module_name, module_file):
        return

    local_modules.add(module_name)

    # Get the source code of the module
    try:
        source_code = inspect.getsource(module)
    except (OSError, TypeError):
        # Can't get source for this module
        return

    imports = _extract_imports_from_source(source_code)

    # Check each import to see if it's local and recursively process it
    for import_name in imports:
        # Skip if already visited
        if import_name in visited:
            continue

        # Try to resolve the import
        imported_module = None
        try:
            # Handle relative imports by resolving them
            if module_name and "." in module_name:
                package = ".".join(module_name.split(".")[:-1])
                try:
                    imported_module = importlib.import_module(import_name, package)
                except ImportError:
                    imported_module = importlib.import_module(import_name)
            else:
                imported_module = importlib.import_module(import_name)
        except (ImportError, AttributeError):
            # Can't import, skip it
            continue

        # Recursively process this imported module
        _process_module_dependencies(
            imported_module, import_name, local_modules, visited
        )


@contextmanager
def _pickle_local_modules_by_value(flow: Flow[Any, Any]):
    """
    Context manager that registers local modules for pickle-by-value serialization.

    Args:
        flow: The flow whose dependencies should be registered.
    """
    registered_modules: list[ModuleType] = []

    try:
        # Discover local dependencies
        local_modules = _discover_local_dependencies(flow)
        logger.debug("Local modules: %s", local_modules)

        if local_modules:
            logger.debug(
                "Registering local modules for pickle-by-value serialization: %s",
                ", ".join(local_modules),
            )

            # Register each local module for pickle-by-value
            for module_name in local_modules:
                try:
                    module = importlib.import_module(module_name)
                    cloudpickle.register_pickle_by_value(module)  # pyright: ignore[reportUnknownMemberType] Missing stubs
                    registered_modules.append(module)
                except (ImportError, AttributeError) as e:
                    logger.debug(
                        "Failed to register module %s for pickle-by-value: %s",
                        module_name,
                        e,
                    )

        yield

    finally:
        # Unregister all modules we registered
        for module in registered_modules:
            try:
                cloudpickle.unregister_pickle_by_value(module)  # pyright: ignore[reportUnknownMemberType] Missing stubs
            except Exception as e:
                logger.debug(
                    "Failed to unregister module %s from pickle-by-value: %s",
                    getattr(module, "__name__", module),
                    e,
                )
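
# Illustrative sketch: while the context manager is active, cloudpickle
# serializes the discovered local modules by value rather than by reference,
# so the pickled flow can be loaded in an environment where those modules are
# not importable:
#
#     with _pickle_local_modules_by_value(my_flow):
#         payload = cloudpickle.dumps(my_flow)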


def create_bundle_for_flow_run(
    flow: Flow[Any, Any],
    flow_run: FlowRun,
    context: dict[str, Any] | None = None,
) -> SerializedBundle:
    """
    Creates a bundle for a flow run.

    Args:
        flow: The flow to bundle.
        flow_run: The flow run to bundle.
        context: The context to use when running the flow.

    Returns:
        A serialized bundle.
    """
    context = context or serialize_context()

    dependencies = (
        subprocess.check_output(
            [
                _get_uv_path(),
                "pip",
                "freeze",
                # Exclude editable installs because we won't be able to install them in the execution environment
                "--exclude-editable",
            ]
        )
        .decode()
        .strip()
    )

    # Remove dependencies installed from a local file path because we won't be able
    # to install them in the execution environment. The user will be responsible for
    # making sure they are available in the execution environment
    filtered_dependencies: list[str] = []
    file_dependencies: list[str] = []
    for line in dependencies.split("\n"):
        if "file://" in line:
            file_dependencies.append(line)
        else:
            filtered_dependencies.append(line)
    dependencies = "\n".join(filtered_dependencies)
    if file_dependencies:
        logger.warning(
            "The following dependencies were installed from a local file path and will not be "
            "automatically installed in the execution environment: %s. If these dependencies "
            "are not available in the execution environment, your flow run may fail.",
            "\n".join(file_dependencies),
        )

    # Automatically register local modules for pickle-by-value serialization
    with _pickle_local_modules_by_value(flow):
        return {
            "function": _serialize_bundle_object(flow),
            "context": _serialize_bundle_object(context),
            "flow_run": flow_run.model_dump(mode="json"),
            "dependencies": dependencies,
        }
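
# Illustrative sketch, assuming an existing flow and a flow run created via the
# Prefect client: the returned bundle is a plain, JSON-serializable dict that
# can be written to disk or shipped to another environment.
#
#     bundle = create_bundle_for_flow_run(flow=my_flow, flow_run=my_flow_run)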


def extract_flow_from_bundle(bundle: SerializedBundle) -> Flow[Any, Any]:
    """
    Extracts a flow from a bundle.
    """
    return _deserialize_bundle_object(bundle["function"])


def _extract_and_run_flow(
    bundle: SerializedBundle,
    cwd: Path | str | None = None,
    env: dict[str, Any] | None = None,
) -> None:
    """
    Extracts a flow from a bundle and runs it.

    Designed to be run in a subprocess.

    Args:
        bundle: The bundle to extract and run.
        cwd: The working directory to use when running the flow.
        env: The environment to use when running the flow.
    """

    os.environ.update(env or {})
    # TODO: make this a thing we can pass directly to the engine
    os.environ["PREFECT__ENABLE_CANCELLATION_AND_CRASHED_HOOKS"] = "false"
    settings_context = get_settings_context()

    flow = _deserialize_bundle_object(bundle["function"])
    context = _deserialize_bundle_object(bundle["context"])
    flow_run = FlowRun.model_validate(bundle["flow_run"])

    if cwd:
        os.chdir(cwd)

    with SettingsContext(
        profile=settings_context.profile,
        settings=Settings(),
    ):
        with handle_engine_signals(flow_run.id):
            maybe_coro = run_flow(
                flow=flow,
                flow_run=flow_run,
                context=context,
            )
            if asyncio.iscoroutine(maybe_coro):
                # This is running in a brand new process, so there won't be an existing
                # event loop.
                asyncio.run(maybe_coro)


def execute_bundle_in_subprocess(
    bundle: SerializedBundle,
    env: dict[str, Any] | None = None,
    cwd: Path | str | None = None,
) -> multiprocessing.context.SpawnProcess:
    """
    Executes a bundle in a subprocess.

    Args:
        bundle: The bundle to execute.
        env: Additional environment variables to set in the subprocess.
        cwd: The working directory to use when running the flow.

    Returns:
        A multiprocessing.context.SpawnProcess.
    """

    ctx = multiprocessing.get_context("spawn")
    env = env or {}

    # Install dependencies if necessary
    if dependencies := bundle.get("dependencies"):
        subprocess.check_call(
            [_get_uv_path(), "pip", "install", *dependencies.split("\n")],
            # Copy the current environment to ensure we install into the correct venv
            env=os.environ,
        )

    process = ctx.Process(
        target=_extract_and_run_flow,
        kwargs={
            "bundle": bundle,
            "env": get_current_settings().to_environment_variables(exclude_unset=True)
            | os.environ
            | env,
            "cwd": cwd,
        },
    )

    process.start()

    return process
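
# Illustrative sketch of the full round trip: bundle a flow run, execute it in
# a freshly spawned process, and wait for it to finish:
#
#     bundle = create_bundle_for_flow_run(flow=my_flow, flow_run=my_flow_run)
#     process = execute_bundle_in_subprocess(bundle)
#     process.join()
#     assert process.exitcode == 0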


def convert_step_to_command(
    step: dict[str, Any], key: str, quiet: bool = False
) -> list[str]:
    """
    Converts a bundle upload or execution step to a command.

    Args:
        step: The step to convert.
        key: The key to use for the remote file when downloading or uploading.
        quiet: Whether to suppress `uv` output from the command.

    Returns:
        A list of strings representing the command to run the step.
    """
    # Start with uv run
    command = ["uv", "run"]

    if quiet:
        command.append("--quiet")

    step_keys = list(step.keys())

    if len(step_keys) != 1:
        raise ValueError("Expected exactly one function in step")

    function_fqn = step_keys[0]
    function_args = step[function_fqn]

    # Add the `--with` argument to handle dependencies for running the step
    requires: list[str] | str = function_args.get("requires", [])
    if isinstance(requires, str):
        requires = [requires]
    if requires:
        command.extend(["--with", ",".join(requires)])

    # Add the `--python` argument to handle the Python version for running the step
    python_version = sys.version_info
    command.extend(["--python", f"{python_version.major}.{python_version.minor}"])

    # Add the `-m` argument to define the function to run
    command.extend(["-m", function_fqn])

    # Add any arguments with values defined in the step
    for arg_name, arg_value in function_args.items():
        if arg_name == "requires":
            continue

        command.extend([f"--{slugify(arg_name)}", arg_value])

    # Add the `--key` argument to specify the remote file name
    command.extend(["--key", key])

    return command
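
# Illustrative sketch with a hypothetical upload step: the single top-level key
# is the fully qualified module passed to `-m`, `requires` becomes `--with`,
# and every other entry becomes a slugified CLI flag:
#
#     step = {
#         "my_bundle_steps.upload": {
#             "requires": "my-storage-client>=1.0",
#             "bucket": "my-bucket",
#         }
#     }
#     convert_step_to_command(step, key="bundle-abc123", quiet=True)
#     # -> ["uv", "run", "--quiet", "--with", "my-storage-client>=1.0",
#     #     "--python", "3.12", "-m", "my_bundle_steps.upload",
#     #     "--bucket", "my-bucket", "--key", "bundle-abc123"]
#     # (the `--python` value tracks the running interpreter; 3.12 assumed here)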


def upload_bundle_to_storage(
    bundle: SerializedBundle, key: str, upload_command: list[str]
) -> None:
    """
    Uploads a bundle to storage.

    Args:
        bundle: The serialized bundle to upload.
        key: The key to use for the remote file when uploading.
        upload_command: The command to use to upload the bundle as a list of strings.
    """
    # Write the bundle to a temporary directory so it can be uploaded to the bundle storage
    # via the upload command
    with tempfile.TemporaryDirectory() as temp_dir:
        Path(temp_dir).joinpath(key).write_bytes(json.dumps(bundle).encode("utf-8"))

        try:
            full_command = upload_command + [key]
            logger.debug("Uploading execution bundle with command: %s", full_command)
            subprocess.check_call(
                full_command,
                cwd=temp_dir,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(e.stderr.decode("utf-8")) from e
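
# Illustrative sketch with a hypothetical upload command: the key is appended
# as the command's final argument and the command runs inside a temporary
# directory containing the bundle file named after the key:
#
#     upload_bundle_to_storage(
#         bundle,
#         key="bundle-abc123",
#         upload_command=["python", "-m", "my_bundle_steps.upload", "--bucket", "my-bucket"],
#     )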


async def aupload_bundle_to_storage(
    bundle: SerializedBundle, key: str, upload_command: list[str]
) -> None:
    """
    Asynchronously uploads a bundle to storage.

    Args:
        bundle: The serialized bundle to upload.
        key: The key to use for the remote file when uploading.
        upload_command: The command to use to upload the bundle as a list of strings.
    """
    # Write the bundle to a temporary directory so it can be uploaded to the bundle storage
    # via the upload command
    with tempfile.TemporaryDirectory() as temp_dir:
        await (
            anyio.Path(temp_dir)
            .joinpath(key)
            .write_bytes(json.dumps(bundle).encode("utf-8"))
        )

        try:
            full_command = upload_command + [key]
            logger.debug("Uploading execution bundle with command: %s", full_command)
            await anyio.run_process(
                full_command,
                cwd=temp_dir,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(e.stderr.decode("utf-8")) from e


__all__ = [
    "execute_bundle_from_file",
    "convert_step_to_command",
    "create_bundle_for_flow_run",
    "extract_flow_from_bundle",
    "execute_bundle_in_subprocess",
    "SerializedBundle",
]