Coverage for /usr/local/lib/python3.12/site-packages/prefect/testing/utilities.py: 0%

118 statements  

coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Internal utilities for tests. 

3""" 

4 

from __future__ import annotations

import atexit
import shutil
import warnings
from contextlib import ExitStack, contextmanager
from pathlib import Path
from pprint import pprint
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Generator

import prefect.context
import prefect.settings
from prefect.blocks.core import Block
from prefect.client.orchestration import get_client
from prefect.client.schemas import sorting
from prefect.client.schemas.filters import FlowFilter, FlowFilterName
from prefect.client.utilities import inject_client
from prefect.events.worker import EventsWorker
from prefect.logging.handlers import APILogWorker
from prefect.results import (
    ResultRecord,
    ResultRecordMetadata,
    ResultStore,
    get_default_result_storage,
)
from prefect.serializers import Serializer
from prefect.server.api.server import SubprocessASGIServer
from prefect.states import State

if TYPE_CHECKING:
    from prefect.client.orchestration import PrefectClient
    from prefect.client.schemas.objects import FlowRun
    from prefect.filesystems import ReadableFileSystem


def exceptions_equal(a: Exception, b: Exception) -> bool:
    """
    Exceptions cannot be compared with `==`. They can be compared with `is`, but that
    check fails if the exception has been serialized and deserialized, so this utility
    does its best to assert equality using the type and args used to initialize the
    exception.
    """
    if a == b:
        return True
    return type(a) is type(b) and getattr(a, "args", None) == getattr(b, "args", None)
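
# Illustrative usage (a sketch, not part of this module): `exceptions_equal` is useful
# when an exception has crossed a serialization boundary and identity checks no longer
# hold.
#
#     import pickle
#
#     original = ValueError("boom")
#     restored = pickle.loads(pickle.dumps(original))
#     assert original is not restored  # identity is lost across the round trip
#     assert exceptions_equal(original, restored)
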

def kubernetes_environments_equal(
    actual: list[dict[str, str]],
    expected: list[dict[str, str]] | dict[str, str],
) -> bool:
    # Convert `expected` to the required list-of-dicts format and sort by name
    if isinstance(expected, dict):
        expected = [{"name": key, "value": value} for key, value in expected.items()]

    expected = list(sorted(expected, key=lambda item: item["name"]))

    # Only sort `actual` so that its format can still be validated by the comparison
    if isinstance(actual, dict):
        raise TypeError(
            "Unexpected type 'dict' for 'actual' kubernetes environment. "
            "Expected 'List[dict]'. Did you pass your arguments in the wrong order?"
        )

    actual = list(sorted(actual, key=lambda item: item["name"]))

    print("---- Actual Kubernetes environment ----")
    pprint(actual, width=180)
    print()
    print("---- Expected Kubernetes environment ----")
    pprint(expected, width=180)
    print()

    for actual_item, expected_item in zip(actual, expected):
        if actual_item != expected_item:
            print("----- First difference in Kubernetes environments -----")
            print(f"Actual: {actual_item}")
            print(f"Expected: {expected_item}")
            break

    return actual == expected
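
# Illustrative usage (a sketch, not part of this module): the expected environment may be
# given either as a plain mapping or as the same list-of-dicts shape as `actual`.
#
#     env_from_manifest = [{"name": "B", "value": "2"}, {"name": "A", "value": "1"}]
#     assert kubernetes_environments_equal(env_from_manifest, {"A": "1", "B": "2"})
#     assert not kubernetes_environments_equal(env_from_manifest, {"A": "1"})
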

@contextmanager
def assert_does_not_warn(
    ignore_warnings: list[type[Warning]] | None = None,
) -> Generator[None, None, None]:
    """
    Converts warnings to errors within this context to assert warnings are not raised,
    except for those specified in `ignore_warnings`.

    Parameters:
    - ignore_warnings: List of warning types to ignore. Example: [DeprecationWarning, UserWarning]
    """
    ignore_warnings = ignore_warnings or []
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        for warning_type in ignore_warnings:
            warnings.filterwarnings("ignore", category=warning_type)

        try:
            yield
        except Warning as warning:
            raise AssertionError(f"Warning was raised. {warning!r}") from warning
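
# Illustrative usage (a sketch, not part of this module): any warning emitted inside the
# context becomes an `AssertionError`, unless its type is explicitly ignored.
#
#     import warnings
#
#     with assert_does_not_warn(ignore_warnings=[DeprecationWarning]):
#         warnings.warn("old API", DeprecationWarning)  # ignored, does not fail
#
#     try:
#         with assert_does_not_warn():
#             warnings.warn("something suspicious", UserWarning)
#     except AssertionError:
#         pass  # the UserWarning was converted into an assertion failure
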

@contextmanager
def prefect_test_harness(server_startup_timeout: int | None = 30):
    """
    Temporarily run flows against a local SQLite database for testing.

    Args:
        server_startup_timeout: The maximum time to wait for the server to start.
            Defaults to 30 seconds. If set to `None`, the value of
            `PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS` will be used.

    Examples:
        ```python
        from prefect import flow
        from prefect.testing.utilities import prefect_test_harness


        @flow
        def my_flow():
            return 'Done!'

        with prefect_test_harness():
            assert my_flow() == 'Done!'  # run against temporary db
        ```
    """
    from prefect.server.database.dependencies import temporary_database_interface

    # create a temporary directory for the testing database
    temp_dir = mkdtemp()

    def cleanup_temp_dir(temp_dir):
        shutil.rmtree(temp_dir)

    atexit.register(cleanup_temp_dir, temp_dir)

    with ExitStack() as stack:
        # temporarily override any database interface components
        stack.enter_context(temporary_database_interface())

        DB_PATH = "sqlite+aiosqlite:///" + str(Path(temp_dir) / "prefect-test.db")
        stack.enter_context(
            prefect.settings.temporary_settings(
                # Use a temporary directory for the database
                updates={
                    prefect.settings.PREFECT_API_DATABASE_CONNECTION_URL: DB_PATH,
                },
            )
        )
        # start a subprocess server to test against
        test_server = SubprocessASGIServer()
        test_server.start(
            timeout=server_startup_timeout
            if server_startup_timeout is not None
            else prefect.settings.PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS.value()
        )
        stack.enter_context(
            prefect.settings.temporary_settings(
                # Point API calls at the subprocess server
                updates={
                    prefect.settings.PREFECT_API_URL: test_server.api_url,
                },
            )
        )
        yield
        # drain the logs before stopping the server to avoid connection errors on shutdown
        APILogWorker.instance().drain()
        # drain events to prevent stale events from leaking into subsequent test harnesses
        EventsWorker.drain_all()
        test_server.stop()
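
# Illustrative pytest wiring (a sketch, not part of this module): the harness is commonly
# applied once per test session via an autouse fixture so that every test runs against
# the temporary SQLite-backed server.
#
#     import pytest
#     from prefect.testing.utilities import prefect_test_harness
#
#     @pytest.fixture(autouse=True, scope="session")
#     def prefect_test_fixture():
#         with prefect_test_harness():
#             yield
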

async def get_most_recent_flow_run(
    client: "PrefectClient | None" = None, flow_name: str | None = None
) -> "FlowRun":
    if client is None:
        client = get_client()

    flow_runs = await client.read_flow_runs(
        sort=sorting.FlowRunSort.EXPECTED_START_TIME_ASC,
        limit=1,
        flow_filter=FlowFilter(name=FlowFilterName(any_=[flow_name]))
        if flow_name
        else None,
    )

    return flow_runs[0]
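
# Illustrative usage (a sketch, not part of this module, assuming an async test body and
# a reachable API): after running a flow, fetch its flow run to inspect its state.
#
#     from prefect import flow
#
#     @flow
#     def my_flow():
#         return 1
#
#     my_flow()
#     run = await get_most_recent_flow_run(flow_name=my_flow.name)
#     assert run.state.is_completed()
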

def assert_blocks_equal(
    found: Block, expected: Block, exclude_private: bool = True, **kwargs: Any
) -> None:
    assert isinstance(found, type(expected)), (
        f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}"
    )

    if exclude_private:
        exclude = set(kwargs.pop("exclude", set()))
        for field_name in found.__private_attributes__:
            exclude.add(field_name)

    assert found.model_dump(exclude=exclude, **kwargs) == expected.model_dump(
        exclude=exclude, **kwargs
    )
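
# Illustrative usage (a sketch, not part of this module): two blocks are considered equal
# when their `model_dump()` output matches, with private attributes excluded by default.
#
#     from prefect.filesystems import LocalFileSystem
#
#     assert_blocks_equal(
#         LocalFileSystem(basepath="/tmp/test-storage"),
#         LocalFileSystem(basepath="/tmp/test-storage"),
#     )
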

async def assert_uses_result_serializer(
    state: State, serializer: str | Serializer, client: "PrefectClient"
) -> None:
    assert isinstance(state.data, (ResultRecord, ResultRecordMetadata))
    if isinstance(state.data, ResultRecord):
        result_serializer = state.data.metadata.serializer
        storage_block_id = state.data.metadata.storage_block_id
        storage_key = state.data.metadata.storage_key
    else:
        result_serializer = state.data.serializer
        storage_block_id = state.data.storage_block_id
        storage_key = state.data.storage_key

    assert result_serializer.type == (
        serializer if isinstance(serializer, str) else serializer.type
    )
    if storage_block_id is not None:
        block = Block._from_block_document(
            await client.read_block_document(storage_block_id)
        )
    else:
        block = await get_default_result_storage()

    blob = await ResultStore(result_storage=block, serializer=result_serializer).aread(
        storage_key
    )
    assert blob.metadata.serializer == (
        serializer
        if isinstance(serializer, Serializer)
        else Serializer(type=serializer)
    )
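
# Illustrative usage (a sketch, not part of this module, assuming an async test body and
# an open `PrefectClient` bound to `client`): the serializer may be passed by name or as
# a `Serializer` instance.
#
#     from prefect import flow
#
#     @flow(persist_result=True, result_serializer="json")
#     def json_flow():
#         return {"answer": 42}
#
#     state = json_flow(return_state=True)
#     await assert_uses_result_serializer(state, "json", client)
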

@inject_client
async def assert_uses_result_storage(
    state: State, storage: "str | ReadableFileSystem", client: "PrefectClient"
) -> None:
    assert isinstance(state.data, (ResultRecord, ResultRecordMetadata))
    if isinstance(state.data, ResultRecord):
        assert_blocks_equal(
            Block._from_block_document(
                await client.read_block_document(state.data.metadata.storage_block_id)
            ),
            (
                storage
                if isinstance(storage, Block)
                else await Block.aload(storage, client=client)
            ),
        )
    else:
        assert_blocks_equal(
            Block._from_block_document(
                await client.read_block_document(state.data.storage_block_id)
            ),
            (
                storage
                if isinstance(storage, Block)
                else await Block.aload(storage, client=client)
            ),
        )


def a_test_step(**kwargs: Any) -> dict[str, Any]:
    kwargs.update(
        {
            "output1": 1,
            "output2": ["b", 2, 3],
            "output3": "This one is actually a string",
        }
    )
    return kwargs


def b_test_step(**kwargs: Any) -> dict[str, Any]:
    kwargs.update({"output1": 1, "output2": ["b", 2, 3]})
    return kwargs
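
# Illustrative usage (a sketch, not part of this module): both helpers simply merge fixed
# outputs into whatever keyword arguments they receive, which makes their results easy to
# assert against in tests.
#
#     result = a_test_step(input1="foo")
#     assert result == {
#         "input1": "foo",
#         "output1": 1,
#         "output2": ["b", 2, 3],
#         "output3": "This one is actually a string",
#     }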