Coverage for /usr/local/lib/python3.12/site-packages/prefect/testing/fixtures.py: 0%

241 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1import asyncio 

2import json 

3import os 

4import socket 

5import sys 

6from contextlib import contextmanager 

7from typing import Any, AsyncGenerator, Callable, Generator, List, Optional, Union 

8from unittest import mock 

9from unittest.mock import AsyncMock 

10from uuid import UUID 

11 

12import anyio 

13import httpx 

14import pytest 

15from starlette.status import WS_1008_POLICY_VIOLATION 

16from websockets.asyncio.server import ( 

17 Server, 

18 ServerConnection, 

19 serve, 

20) 

21from websockets.exceptions import ConnectionClosed 

22 

23from prefect.events import Event 

24from prefect.events.clients import ( 

25 AssertingEventsClient, 

26 AssertingPassthroughEventsClient, 

27) 

28from prefect.events.filters import EventFilter 

29from prefect.events.worker import EventsWorker 

30from prefect.server.api.server import SubprocessASGIServer 

31from prefect.server.events.pipeline import EventsPipeline 

32from prefect.settings import ( 

33 PREFECT_API_URL, 

34 PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, 

35 PREFECT_SERVER_CSRF_PROTECTION_ENABLED, 

36 get_current_settings, 

37 temporary_settings, 

38) 

39from prefect.types._datetime import DateTime, now 

40from prefect.utilities.asyncutils import sync_compatible 

41from prefect.utilities.processutils import open_process 

42 

43 

@pytest.fixture(autouse=True)
def add_prefect_loggers_to_caplog(
    caplog: pytest.LogCaptureFixture,
) -> Generator[None, None, None]:
    """
    Turn on propagation for the "prefect" logger while a test runs.

    With `propagate` enabled, records emitted on the "prefect" logger are handed
    to the handlers of its ancestor loggers. The fixture is named for `caplog`,
    so presumably the intent is to make those records visible to pytest's log
    capture -- confirm against pytest's handler placement.

    The value of `propagate` found at setup is put back at teardown, so a logger
    that was already propagating is not silently switched off after the first
    test.

    Args:
        caplog: requested only so pytest sets up log capture before this
            fixture runs; it is not used directly.
    """
    import logging

    logger = logging.getLogger("prefect")
    # Remember the pre-test state so teardown restores it rather than forcing
    # a hard-coded value.
    previous_propagate = logger.propagate
    logger.propagate = True

    try:
        yield
    finally:
        logger.propagate = previous_propagate

57 

58 

def is_port_in_use(port: int) -> bool:
    """
    Report whether a TCP connection to ``localhost`` on ``port`` succeeds.

    Args:
        port: the TCP port number to probe.

    Returns:
        True when the connection attempt succeeds, False otherwise.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex reports failure through its return value (an errno code)
        # instead of raising; 0 means the connection was established.
        connect_status = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return connect_status == 0

62 

63 

@pytest.fixture(scope="session")
async def hosted_api_server(
    unused_tcp_port_factory: Callable[[], int],
) -> AsyncGenerator[str, None]:
    """
    Runs an instance of the Prefect API server in a subprocess instead of using
    the ephemeral application.

    Uses the same database as the rest of the tests.

    Yields:
        The API URL

    Raises:
        RuntimeError: if no HTTP response was ever received from the health
            endpoint within the 20 second startup window.
    """
    port = unused_tcp_port_factory()
    print(f"Running hosted API server on port {port}")

    # Export the current settings as environment variables so the subprocess
    # connects to the same database as normal test clients. Settings are layered
    # on top of the inherited environment and win on conflicts.
    settings = get_current_settings().to_environment_variables(exclude_unset=True)
    subprocess_env = dict(os.environ)
    subprocess_env.update(settings)

    uvicorn_command = [
        "uvicorn",
        "--factory",
        "prefect.server.api.server:create_app",
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
        "--log-level",
        "info",
    ]

    async with open_process(
        command=uvicorn_command,
        stdout=sys.stdout,
        stderr=sys.stderr,
        env=subprocess_env,
    ) as process:
        api_url = f"http://localhost:{port}/api"
        health_url = api_url + "/health"

        # Poll the health endpoint until it answers 200 or the window expires
        async with httpx.AsyncClient() as client:
            response = None
            with anyio.move_on_after(20):
                while True:
                    try:
                        response = await client.get(health_url)
                    except httpx.ConnectError:
                        # Nothing is accepting connections yet; retry shortly
                        await anyio.sleep(0.1)
                        continue
                    if response.status_code == 200:
                        break
                    await anyio.sleep(0.1)

            if response:
                # Surfaces the last non-success status if the window expired
                # while the server was answering with errors
                response.raise_for_status()
            else:
                raise RuntimeError(
                    "Timed out while attempting to connect to hosted test Prefect API."
                )

        # Hand the URL to the consuming tests
        yield api_url

        # Teardown: stop the subprocess
        try:
            process.terminate()

            # Check once per second, for up to 10 seconds, whether it exited
            for _ in range(10):
                if process.returncode is not None:
                    break
                await anyio.sleep(1)
            else:
                # Still running after the grace period; force it down
                process.kill()

        except ProcessLookupError:
            # The process is already gone
            pass

141 

142 

@pytest.fixture(autouse=True)
def use_hosted_api_server(hosted_api_server: str) -> Generator[str, None, None]:
    """
    Points `PREFECT_API_URL` at the test session's hosted API endpoint and turns
    `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` off while the test runs.

    Yields:
        The hosted API URL
    """
    overrides = {
        PREFECT_API_URL: hosted_api_server,
        PREFECT_SERVER_CSRF_PROTECTION_ENABLED: False,
    }
    with temporary_settings(overrides):
        yield hosted_api_server

155 

156 

@pytest.fixture
def disable_hosted_api_server() -> Generator[None, None, None]:
    """
    Turns the hosted API server off for a test by setting `PREFECT_API_URL` to
    `None` while the test runs.
    """
    overrides = {PREFECT_API_URL: None}
    with temporary_settings(overrides):
        yield

168 

169 

@pytest.fixture
def enable_ephemeral_server(
    disable_hosted_api_server: None,
) -> Generator[None, None, None]:
    """
    Turns the ephemeral server on for a test by setting
    `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` to `True` while the test runs.

    Depends on `disable_hosted_api_server`, so `PREFECT_API_URL` is `None` for
    the same test.
    """
    overrides = {PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True}
    with temporary_settings(overrides):
        yield

    # Once the setting override is lifted, stop the subprocess server.
    # NOTE(review): this calls stop() on a freshly constructed object, so it
    # presumably relies on SubprocessASGIServer sharing state between
    # instances -- confirm.
    SubprocessASGIServer().stop()

185 

186 

@pytest.fixture
def mock_anyio_sleep(
    monkeypatch: pytest.MonkeyPatch,
) -> AsyncMock:
    """
    Mock sleep used to not actually sleep but to set the current time to now + sleep
    delay seconds while still yielding to other tasks in the event loop.

    Replaces `anyio.sleep` with an `AsyncMock` and `prefect.types._datetime.now`
    with a wrapper that adds the total mocked sleep time to the real time.

    Provides "assert_sleeps_for" context manager which asserts a sleep time occurred
    within the context while using the actual runtime of the context as a tolerance.

    Returns:
        The `AsyncMock` standing in for `anyio.sleep`, with the
        `assert_sleeps_for` context manager attached as an attribute.
    """
    original_now = now
    original_sleep = anyio.sleep
    # Total number of seconds "slept" through the mock so far
    time_shift = 0.0

    async def callback(delay_in_seconds: float) -> None:
        nonlocal time_shift
        time_shift += float(delay_in_seconds)
        # Preserve yield effects of sleep
        await original_sleep(0)

    def latest_now(*args: Any) -> DateTime:
        # Fast-forwards the time by the total sleep time.
        # Round to the nearest microsecond before splitting: truncating the
        # fractional part instead loses a microsecond whenever the float
        # lands just below the intended value (e.g. 2.3 -> 2.299999).
        total_microseconds = round(time_shift * 1_000_000)
        whole_seconds, microseconds = divmod(total_microseconds, 1_000_000)
        return original_now(*args).add(
            seconds=whole_seconds,
            microseconds=microseconds,
        )

    monkeypatch.setattr("prefect.types._datetime.now", latest_now)

    sleep = AsyncMock(side_effect=callback)
    monkeypatch.setattr("anyio.sleep", sleep)

    @contextmanager
    def assert_sleeps_for(
        seconds: Union[int, float], extra_tolerance: Union[int, float] = 0
    ):
        """
        Assert that sleep was called for N seconds during the duration of the context.
        The runtime of the code during the context of the duration is used as an
        upper tolerance to account for sleeps that start based on a time. This is less
        brittle than attempting to freeze the current time.

        If an integer is provided, the upper tolerance will be rounded up to the nearest
        integer. If a float is provided, the upper tolerance will be a float.

        An optional extra tolerance may be provided to account for any other issues.
        This will be applied symmetrically.
        """
        # Real (unshifted) wall-clock time and mocked sleep total on entry
        run_t0 = original_now().timestamp()
        sleep_t0 = time_shift
        yield
        run_t1 = original_now().timestamp()
        sleep_t1 = time_shift
        runtime = run_t1 - run_t0
        if isinstance(seconds, int):
            # Round tolerance up to the nearest integer if input is an int
            runtime = int(runtime) + 1
        sleeptime = sleep_t1 - sleep_t0
        assert (
            sleeptime - float(extra_tolerance)
            <= seconds
            <= sleeptime + runtime + extra_tolerance
        ), (
            f"Sleep was called for {sleeptime}; expected {seconds} with tolerance of"
            f" +{runtime + extra_tolerance}, -{extra_tolerance}"
        )

    sleep.assert_sleeps_for = assert_sleeps_for

    return sleep

259 

260 

class Recorder:
    """
    Mutable record of what the test events server has observed.

    Every attribute is initialized in `__init__`, so each one can be read before
    any client has connected.
    """

    # Number of connections the server's handler has been invoked for
    connections: int
    # Request path of the most recent connection that was not refused
    path: Optional[str]
    # Events received on an "/events/in" connection, in arrival order
    events: List[Event]
    # Token from the most recent auth message on an "/events/out" connection
    token: Optional[str]
    # Filter from the most recent filter message on an "/events/out" connection
    filter: Optional[EventFilter]

    def __init__(self):
        self.connections = 0
        self.path = None
        self.events = []
        # A bare class-level annotation does not create an attribute, so these
        # need explicit defaults to be readable before the server sets them.
        self.token = None
        self.filter = None

272 

273 

class Puppeteer:
    """
    Knobs a test can set to steer how the test events server behaves.

    Every attribute is initialized in `__init__`, so the server can read any of
    them even when a test has not configured it.
    """

    # Token the server compares against the one a client sends when
    # authenticating
    token: Optional[str]

    # When True, an auth mismatch closes the connection without first sending
    # an "auth_failure" message
    hard_auth_failure: bool
    # When True, the handler raises for every new connection
    refuse_any_further_connections: bool
    # ID of the event after which the server raises to break the connection;
    # the server resets it to None once triggered
    hard_disconnect_after: Optional[UUID]

    # Events the server sends to a client on an "/events/out" connection
    outgoing_events: List[Event]

    def __init__(self):
        # A bare class-level annotation does not create an attribute; without
        # this default the server's token comparison raises AttributeError
        # whenever a test has not assigned a token.
        self.token = None
        self.hard_auth_failure = False
        self.refuse_any_further_connections = False
        self.hard_disconnect_after = None
        self.outgoing_events = []

288 

289 

@pytest.fixture
def recorder() -> Recorder:
    """
    Provides a fresh `Recorder` for each test.
    """
    fresh_recorder = Recorder()
    return fresh_recorder

293 

294 

@pytest.fixture
def puppeteer() -> Puppeteer:
    """
    Provides a fresh `Puppeteer` for each test.
    """
    fresh_puppeteer = Puppeteer()
    return fresh_puppeteer

298 

299 

@pytest.fixture
async def events_server(
    unused_tcp_port: int, recorder: Recorder, puppeteer: Puppeteer
) -> AsyncGenerator[Server, None]:
    """
    Runs a websocket server on `localhost:unused_tcp_port` that stands in for the
    events API.

    Connections whose path ends in "/events/in" have their events recorded on
    `recorder`; connections whose path ends in "/events/out" go through an auth
    and filter handshake and are then sent `puppeteer.outgoing_events`. The
    `puppeteer` flags inject failures along the way.

    Yields:
        The running websocket server
    """
    server: Server

    async def incoming_events(connection: ServerConnection):
        # Record every event the client sends until the connection closes
        while True:
            try:
                raw_message = await connection.recv()
            except ConnectionClosed:
                return

            received = Event.model_validate_json(raw_message)
            recorder.events.append(received)

            if puppeteer.hard_disconnect_after == received.id:
                # One-shot trigger: clear it, then fail the handler
                puppeteer.hard_disconnect_after = None
                raise ValueError("Disconnect after incoming event")

    async def outgoing_events(connection: ServerConnection):
        # 1. authentication
        auth_message = json.loads(await connection.recv())

        assert auth_message["type"] == "auth"
        recorder.token = auth_message["token"]
        if puppeteer.token != recorder.token:
            if not puppeteer.hard_auth_failure:
                # Soft failure: tell the client why before closing
                failure_payload = json.dumps({"type": "auth_failure", "reason": "nope"})
                await connection.send(failure_payload)
            await connection.close(WS_1008_POLICY_VIOLATION)
            return

        await connection.send(json.dumps({"type": "auth_success"}))

        # 2. filter
        filter_message = json.loads(await connection.recv())
        assert filter_message["type"] == "filter"
        recorder.filter = EventFilter.model_validate(filter_message["filter"])

        # 3. send events
        for event in puppeteer.outgoing_events:
            event_payload = json.dumps(
                {
                    "type": "event",
                    "event": event.model_dump(mode="json"),
                }
            )
            await connection.send(event_payload)
            if puppeteer.hard_disconnect_after == event.id:
                # One-shot trigger: clear it, then fail the handler
                puppeteer.hard_disconnect_after = None
                raise ValueError("zonk")

    async def handler(connection: ServerConnection) -> None:
        assert connection.request
        request_path = connection.request.path
        # Refused connections are still counted
        recorder.connections += 1
        if puppeteer.refuse_any_further_connections:
            raise ValueError("nope")

        recorder.path = request_path

        # Any other path is accepted and the handler simply returns
        if request_path.endswith("/events/in"):
            await incoming_events(connection)
        elif request_path.endswith("/events/out"):
            await outgoing_events(connection)

    async with serve(handler, host="localhost", port=unused_tcp_port) as server:
        yield server

371 

372 

@pytest.fixture
def events_api_url(events_server: Server, unused_tcp_port: int) -> str:
    """
    Base URL of the test events server.

    Requesting `events_server` ensures the server is running on
    `unused_tcp_port` before the URL is handed out.
    """
    base_url = f"http://localhost:{unused_tcp_port}"
    return base_url

376 

377 

@pytest.fixture
def events_cloud_api_url(events_server: Server, unused_tcp_port: int) -> str:
    """
    URL of the test events server with an account/workspace path appended
    ("/accounts/A/workspaces/W").

    Requesting `events_server` ensures the server is running on
    `unused_tcp_port` before the URL is handed out.
    """
    workspace_url = f"http://localhost:{unused_tcp_port}/accounts/A/workspaces/W"
    return workspace_url

381 

382 

@pytest.fixture
def mock_should_emit_events(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
    """
    Replaces `prefect.events.utilities.should_emit_events` with a mock that
    returns `True`.

    Returns:
        The mock, so a test can inspect calls or change its return value
    """
    always_emit = mock.Mock(return_value=True)
    monkeypatch.setattr("prefect.events.utilities.should_emit_events", always_emit)
    return always_emit

389 

390 

@pytest.fixture
def asserting_events_worker(
    monkeypatch: pytest.MonkeyPatch,
) -> Generator[EventsWorker, None, None]:
    """
    Provides an `EventsWorker` obtained with `AssertingEventsClient` and makes
    `EventsWorker.instance` return that same worker for the rest of the test.

    The worker is drained at teardown.
    """
    shared_worker = EventsWorker.instance(AssertingEventsClient)

    def _return_shared_worker(*_ignored):
        # Whatever is asked for, hand back the asserting worker
        return shared_worker

    monkeypatch.setattr(EventsWorker, "instance", _return_shared_worker)
    try:
        yield shared_worker
    finally:
        shared_worker.drain()

402 

403 

@pytest.fixture
def asserting_and_emitting_events_worker(
    monkeypatch: pytest.MonkeyPatch,
) -> Generator[EventsWorker, None, None]:
    """
    Provides an `EventsWorker` obtained with `AssertingPassthroughEventsClient`
    and makes `EventsWorker.instance` return that same worker for the rest of
    the test.

    The worker is drained at teardown.
    """
    shared_worker = EventsWorker.instance(AssertingPassthroughEventsClient)

    def _return_shared_worker(*_ignored):
        # Whatever is asked for, hand back the asserting worker
        return shared_worker

    monkeypatch.setattr(EventsWorker, "instance", _return_shared_worker)
    try:
        yield shared_worker
    finally:
        shared_worker.drain()

415 

416 

@pytest.fixture
async def events_pipeline(
    asserting_events_worker: EventsWorker,
) -> AsyncGenerator[EventsPipeline, None]:
    """
    Provides an `EventsPipeline` whose `process_events` feeds the events held by
    the asserting worker's client through the pipeline.
    """

    class AssertingEventsPipeline(EventsPipeline):
        @sync_compatible
        async def process_events(
            self,
            dequeue_events: bool = True,
            min_events: int = 0,
            timeout: int = 10,
        ):
            """
            Convert the worker client's events to messages and process them.

            Args:
                dequeue_events: when True, take the events off the client with
                    `pop_events()`; when False, read `events` and leave them
                    in place.
                min_events: when non-zero, poll until the client holds at least
                    this many events before processing. When zero, call the
                    worker's `wait_until_empty()` instead.
                timeout: seconds to wait for `min_events` to be reached.

            Raises:
                TimeoutError: if `min_events` is not reached within `timeout`.
            """

            async def wait_for_min_events():
                while len(asserting_events_worker._client.events) < min_events:
                    await asyncio.sleep(0.1)

            if min_events:
                try:
                    await asyncio.wait_for(wait_for_min_events(), timeout=timeout)
                # `asyncio.wait_for` raises `asyncio.TimeoutError`. That is the
                # builtin `TimeoutError` from Python 3.11 on but a separate
                # class before, so name it explicitly to catch it everywhere
                # and re-raise the same type with a useful message.
                except asyncio.TimeoutError as exc:
                    raise asyncio.TimeoutError(
                        f"Timed out waiting for {min_events} events after"
                        f" {timeout} seconds. Only observed"
                        f" {len(asserting_events_worker._client.events)} events."
                    ) from exc
            else:
                asserting_events_worker.wait_until_empty()

            if dequeue_events:
                events = asserting_events_worker._client.pop_events()
            else:
                events = asserting_events_worker._client.events

            messages = self.events_to_messages(events)
            await self.process_messages(messages)

    yield AssertingEventsPipeline()

452 

453 

@pytest.fixture
async def emitting_events_pipeline(
    asserting_and_emitting_events_worker: EventsWorker,
) -> AsyncGenerator[EventsPipeline, None]:
    """
    Provides an `EventsPipeline` whose `process_events` feeds the events held by
    the asserting-and-emitting worker's client through the pipeline.
    """
    worker = asserting_and_emitting_events_worker

    class AssertingAndEmittingEventsPipeline(EventsPipeline):
        @sync_compatible
        async def process_events(self):
            # Wait on the worker, then take its client's events and run them
            # through the pipeline as messages
            worker.wait_until_empty()
            drained_events = worker._client.pop_events()

            await self.process_messages(self.events_to_messages(drained_events))

    yield AssertingAndEmittingEventsPipeline()

468 

469 

@pytest.fixture
def reset_worker_events(
    asserting_events_worker: EventsWorker,
) -> Generator[None, None, None]:
    """
    Empties the asserting worker client's recorded events after the test.
    """
    yield
    client = asserting_events_worker._client
    # Only the asserting client keeps an `events` list to reset
    assert isinstance(client, AssertingEventsClient)
    client.events = []