Coverage for /usr/local/lib/python3.12/site-packages/prefect/testing/utilities.py: 0%
118 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1"""
2Internal utilities for tests.
3"""
5from __future__ import annotations
7import atexit
8import shutil
9import warnings
10from contextlib import ExitStack, contextmanager
11from pathlib import Path
12from pprint import pprint
13from tempfile import mkdtemp
14from typing import TYPE_CHECKING, Any, Generator
16import prefect.context
17import prefect.settings
18from prefect.blocks.core import Block
19from prefect.client.orchestration import get_client
20from prefect.client.schemas import sorting
21from prefect.client.schemas.filters import FlowFilter, FlowFilterName
22from prefect.client.utilities import inject_client
23from prefect.events.worker import EventsWorker
24from prefect.logging.handlers import APILogWorker
25from prefect.results import (
26 ResultRecord,
27 ResultRecordMetadata,
28 ResultStore,
29 get_default_result_storage,
30)
31from prefect.serializers import Serializer
32from prefect.server.api.server import SubprocessASGIServer
33from prefect.states import State
35if TYPE_CHECKING:
36 from prefect.client.orchestration import PrefectClient
37 from prefect.client.schemas.objects import FlowRun
38 from prefect.filesystems import ReadableFileSystem
def exceptions_equal(a: Exception, b: Exception) -> bool:
    """
    Best-effort equality check for exceptions.

    Exceptions do not define a meaningful `==`, and identity (`is`) breaks as soon
    as an exception has been serialized and deserialized. Two exceptions are
    therefore treated as equal when they compare equal directly, or when they have
    exactly the same type and the same `args`.
    """
    if a == b:
        return True
    if type(a) is not type(b):
        return False
    # Same concrete type: fall back to comparing the constructor arguments.
    return getattr(a, "args", None) == getattr(b, "args", None)
52def kubernetes_environments_equal(
53 actual: list[dict[str, str]],
54 expected: list[dict[str, str]] | dict[str, str],
55) -> bool:
56 # Convert to a required format and sort by name
57 if isinstance(expected, dict):
58 expected = [{"name": key, "value": value} for key, value in expected.items()]
60 expected = list(sorted(expected, key=lambda item: item["name"]))
62 # Just sort the actual so the format can be tested
63 if isinstance(actual, dict):
64 raise TypeError(
65 "Unexpected type 'dict' for 'actual' kubernetes environment. "
66 "Expected 'List[dict]'. Did you pass your arguments in the wrong order?"
67 )
69 actual = list(sorted(actual, key=lambda item: item["name"]))
71 print("---- Actual Kubernetes environment ----")
72 pprint(actual, width=180)
73 print()
74 print("---- Expected Kubernetes environment ----")
75 pprint(expected, width=180)
76 print()
78 for actual_item, expected_item in zip(actual, expected):
79 if actual_item != expected_item:
80 print("----- First difference in Kubernetes environments -----")
81 print(f"Actual: {actual_item}")
82 print(f"Expected: {expected_item}")
83 break
85 return actual == expected
88@contextmanager
89def assert_does_not_warn(
90 ignore_warnings: list[type[Warning]] | None = None,
91) -> Generator[None, None, None]:
92 """
93 Converts warnings to errors within this context to assert warnings are not raised,
94 except for those specified in ignore_warnings.
96 Parameters:
97 - ignore_warnings: List of warning types to ignore. Example: [DeprecationWarning, UserWarning]
98 """
99 ignore_warnings = ignore_warnings or []
100 with warnings.catch_warnings():
101 warnings.simplefilter("error")
102 for warning_type in ignore_warnings:
103 warnings.filterwarnings("ignore", category=warning_type)
105 try:
106 yield
107 except Warning as warning:
108 raise AssertionError(f"Warning was raised. {warning!r}") from warning
@contextmanager
def prefect_test_harness(server_startup_timeout: int | None = 30):
    """
    Temporarily run flows against a local SQLite database for testing.

    A fresh SQLite database is created in a temporary directory, a subprocess
    API server is started against it, and ``PREFECT_API_URL`` points at that
    server for the duration of the context. The server is stopped on exit, even
    when the body raises.

    Args:
        server_startup_timeout: The maximum time to wait for the server to start.
            Defaults to 30 seconds. If set to `None`, the value of
            `PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS` will be used.

    Examples:
        ```python
        from prefect import flow
        from prefect.testing.utilities import prefect_test_harness


        @flow
        def my_flow():
            return 'Done!'

        with prefect_test_harness():
            assert my_flow() == 'Done!' # run against temporary db
        ```
    """
    from prefect.server.database.dependencies import temporary_database_interface

    # create temp directory for the testing database
    temp_dir = mkdtemp()

    def cleanup_temp_dir(temp_dir):
        # Best-effort: the directory may already be gone by interpreter exit.
        shutil.rmtree(temp_dir, ignore_errors=True)

    # Deferred to process exit rather than done on context exit.
    atexit.register(cleanup_temp_dir, temp_dir)

    with ExitStack() as stack:
        # temporarily override any database interface components
        stack.enter_context(temporary_database_interface())

        db_url = "sqlite+aiosqlite:///" + str(Path(temp_dir) / "prefect-test.db")
        stack.enter_context(
            prefect.settings.temporary_settings(
                # Use a temporary directory for the database
                updates={
                    prefect.settings.PREFECT_API_DATABASE_CONNECTION_URL: db_url,
                },
            )
        )
        # start a subprocess server to test against
        test_server = SubprocessASGIServer()
        test_server.start(
            timeout=server_startup_timeout
            if server_startup_timeout is not None
            else prefect.settings.PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS.value()
        )
        try:
            stack.enter_context(
                prefect.settings.temporary_settings(
                    # Point API clients at the subprocess server
                    updates={
                        prefect.settings.PREFECT_API_URL: test_server.api_url,
                    },
                )
            )
            yield
        finally:
            # Runs before the ExitStack unwinds, so the API URL override is still
            # active while draining; the server is stopped even if the body or a
            # drain raises, so the subprocess is never leaked.
            try:
                # drain the logs before stopping the server to avoid connection errors on shutdown
                APILogWorker.instance().drain()
                # drain events to prevent stale events from leaking into subsequent test harnesses
                EventsWorker.drain_all()
            finally:
                test_server.stop()
async def get_most_recent_flow_run(
    client: "PrefectClient | None" = None, flow_name: str | None = None
) -> "FlowRun":
    """
    Return the flow run with the latest expected start time.

    Args:
        client: Client to query with. If omitted, a client is created with
            `get_client()`, entered for the duration of the query and then closed.
        flow_name: If given (and non-empty), only runs of the flow with this name
            are considered.

    Returns:
        The matching flow run whose expected start time is the latest.

    Raises:
        IndexError: If no matching flow run exists.
    """
    if client is None:
        # PrefectClient is an async context manager; enter it so the client we
        # created here is set up and closed again instead of being leaked.
        async with get_client() as owned_client:
            return await get_most_recent_flow_run(
                client=owned_client, flow_name=flow_name
            )

    flow_runs = await client.read_flow_runs(
        # Descending order with limit=1 yields the latest run (ascending would
        # return the oldest one).
        sort=sorting.FlowRunSort.EXPECTED_START_TIME_DESC,
        limit=1,
        flow_filter=FlowFilter(name=FlowFilterName(any_=[flow_name]))
        if flow_name
        else None,
    )

    if not flow_runs:
        raise IndexError(
            "No flow runs found"
            + (f" for flow {flow_name!r}" if flow_name else "")
        )
    return flow_runs[0]
def assert_blocks_equal(
    found: Block, expected: Block, exclude_private: bool = True, **kwargs: Any
) -> None:
    """
    Assert that two blocks have the same type and dump to the same data.

    Args:
        found: The block under test; must be an instance of ``type(expected)``.
        expected: The reference block.
        exclude_private: If True, the names in ``found.__private_attributes__``
            are added to the set of excluded fields.
        **kwargs: Forwarded to ``model_dump`` on both blocks. An ``exclude``
            entry is converted to a set and combined with the private names.

    Raises:
        AssertionError: If the types or the dumped data differ.
    """
    assert isinstance(found, type(expected)), (
        f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}"
    )

    # Bind `exclude` unconditionally so `exclude_private=False` works too, both
    # with and without a caller-supplied `exclude=`.
    exclude = set(kwargs.pop("exclude", set()))
    if exclude_private:
        exclude.update(found.__private_attributes__)

    assert found.model_dump(exclude=exclude, **kwargs) == expected.model_dump(
        exclude=exclude, **kwargs
    )
async def assert_uses_result_serializer(
    state: State, serializer: str | Serializer, client: "PrefectClient"
) -> None:
    """
    Assert that the result attached to `state` was persisted with `serializer`.

    Two checks are made:
    1. The serializer recorded in the state's result metadata has the expected type.
    2. The persisted record, read back from its storage block (or the default
       result storage when no block id is recorded), has metadata whose serializer
       equals the expected serializer.

    Args:
        state: A state whose `data` is a `ResultRecord` or `ResultRecordMetadata`.
        serializer: Either a serializer type string or a `Serializer` instance.
        client: Client used to read the storage block document.

    Raises:
        AssertionError: If either check fails.
    """
    assert isinstance(state.data, (ResultRecord, ResultRecordMetadata))
    if isinstance(state.data, ResultRecord):
        result_serializer = state.data.metadata.serializer
        storage_block_id = state.data.metadata.storage_block_id
        storage_key = state.data.metadata.storage_key
    else:
        result_serializer = state.data.serializer
        storage_block_id = state.data.storage_block_id
        storage_key = state.data.storage_key

    # Resolve the expected value first. Writing `a == b if cond else c` parses as
    # `(a == b) if cond else c`, which asserts a truthy object instead of comparing.
    expected_type = serializer if isinstance(serializer, str) else serializer.type
    assert result_serializer.type == expected_type, (
        f"Unexpected result serializer type {result_serializer.type!r}, "
        f"expected {expected_type!r}"
    )

    if storage_block_id is not None:
        block = Block._from_block_document(
            await client.read_block_document(storage_block_id)
        )
    else:
        block = await get_default_result_storage()

    blob = await ResultStore(result_storage=block, serializer=result_serializer).aread(
        storage_key
    )
    expected_serializer = (
        serializer
        if isinstance(serializer, Serializer)
        else Serializer(type=serializer)
    )
    assert blob.metadata.serializer == expected_serializer, (
        f"Persisted result serializer {blob.metadata.serializer!r} does not match "
        f"expected {expected_serializer!r}"
    )
@inject_client
async def assert_uses_result_storage(
    state: State, storage: "str | ReadableFileSystem", client: "PrefectClient"
) -> None:
    """
    Assert that the result attached to `state` is stored in `storage`.

    The storage block id is taken from the result metadata. For a `ResultRecord`
    that is `data.metadata`; for a bare `ResultRecordMetadata` it is `data` itself.
    That block document is read back and compared with `storage` via
    `assert_blocks_equal`. `storage` may be a `Block` instance or a value
    accepted by `Block.aload`.
    """
    result = state.data
    assert isinstance(result, (ResultRecord, ResultRecordMetadata))

    metadata = result.metadata if isinstance(result, ResultRecord) else result
    found_block = Block._from_block_document(
        await client.read_block_document(metadata.storage_block_id)
    )

    if isinstance(storage, Block):
        expected_block = storage
    else:
        expected_block = await Block.aload(storage, client=client)

    assert_blocks_equal(found_block, expected_block)
def a_test_step(**kwargs: Any) -> dict[str, Any]:
    """
    Dummy step used in tests: echo the keyword arguments back, plus three fixed outputs.

    The outputs override any keyword arguments with the same name.
    """
    return {
        **kwargs,
        "output1": 1,
        "output2": ["b", 2, 3],
        "output3": "This one is actually a string",
    }
def b_test_step(**kwargs: Any) -> dict[str, Any]:
    """
    Dummy step used in tests: echo the keyword arguments back, plus two fixed outputs.

    The outputs override any keyword arguments with the same name.
    """
    return {**kwargs, "output1": 1, "output2": ["b", 2, 3]}