Coverage for /usr/local/lib/python3.12/site-packages/prefect/testing/fixtures.py: 0%
241 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1import asyncio
2import json
3import os
4import socket
5import sys
6from contextlib import contextmanager
7from typing import Any, AsyncGenerator, Callable, Generator, List, Optional, Union
8from unittest import mock
9from unittest.mock import AsyncMock
10from uuid import UUID
12import anyio
13import httpx
14import pytest
15from starlette.status import WS_1008_POLICY_VIOLATION
16from websockets.asyncio.server import (
17 Server,
18 ServerConnection,
19 serve,
20)
21from websockets.exceptions import ConnectionClosed
23from prefect.events import Event
24from prefect.events.clients import (
25 AssertingEventsClient,
26 AssertingPassthroughEventsClient,
27)
28from prefect.events.filters import EventFilter
29from prefect.events.worker import EventsWorker
30from prefect.server.api.server import SubprocessASGIServer
31from prefect.server.events.pipeline import EventsPipeline
32from prefect.settings import (
33 PREFECT_API_URL,
34 PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
35 PREFECT_SERVER_CSRF_PROTECTION_ENABLED,
36 get_current_settings,
37 temporary_settings,
38)
39from prefect.types._datetime import DateTime, now
40from prefect.utilities.asyncutils import sync_compatible
41from prefect.utilities.processutils import open_process
@pytest.fixture(autouse=True)
def add_prefect_loggers_to_caplog(
    caplog: pytest.LogCaptureFixture,
) -> Generator[None, None, None]:
    """Enable propagation on the "prefect" logger for the duration of the test
    so that `caplog` can capture its records, then restore the default."""
    import logging

    prefect_logger = logging.getLogger("prefect")
    prefect_logger.propagate = True
    try:
        yield
    finally:
        # Turn propagation back off so the change does not leak between tests.
        prefect_logger.propagate = False
def is_port_in_use(port: int) -> bool:
    """Return True if something is accepting TCP connections on localhost:port."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns 0 on a successful connection, an errno otherwise.
        result = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return result == 0
@pytest.fixture(scope="session")
async def hosted_api_server(
    unused_tcp_port_factory: Callable[[], int],
) -> AsyncGenerator[str, None]:
    """
    Runs an instance of the Prefect API server in a subprocess instead of using the
    ephemeral application.

    Uses the same database as the rest of the tests.

    Yields:
        The API URL
    """
    port = unused_tcp_port_factory()
    print(f"Running hosted API server on port {port}")

    # Will connect to the same database as normal test clients
    settings = get_current_settings().to_environment_variables(exclude_unset=True)
    async with open_process(
        command=[
            "uvicorn",
            "--factory",
            "prefect.server.api.server:create_app",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
            "--log-level",
            "info",
        ],
        stdout=sys.stdout,
        stderr=sys.stderr,
        # Inherit the current environment and overlay the test-session settings
        # so the subprocess uses the same configuration (and database).
        env={
            **os.environ,
            **settings,
        },
    ) as process:
        api_url = f"http://localhost:{port}/api"

        # Wait for the server to be ready: poll /health for up to 20 seconds,
        # tolerating connection errors while uvicorn is still starting up.
        async with httpx.AsyncClient() as client:
            response = None
            with anyio.move_on_after(20):
                while True:
                    try:
                        response = await client.get(api_url + "/health")
                    except httpx.ConnectError:
                        # Not accepting connections yet; retry after a pause.
                        pass
                    else:
                        if response.status_code == 200:
                            break
                    await anyio.sleep(0.1)
            if response:
                # Surface any non-success health response as an error.
                response.raise_for_status()
            if not response:
                # The 20s deadline elapsed without ever reaching the server.
                raise RuntimeError(
                    "Timed out while attempting to connect to hosted test Prefect API."
                )

        # Yield to the consuming tests
        yield api_url

        # Then shutdown the process
        try:
            process.terminate()

            # Give the process a 10 second grace period to shutdown
            for _ in range(10):
                if process.returncode is not None:
                    break
                await anyio.sleep(1)
            else:
                # Kill the process if it is not shutdown in time
                process.kill()

        except ProcessLookupError:
            # The process already exited on its own; nothing to clean up.
            pass
@pytest.fixture(autouse=True)
def use_hosted_api_server(hosted_api_server: str) -> Generator[str, None, None]:
    """
    Sets `PREFECT_API_URL` to the test session's hosted API endpoint.
    """
    overrides = {
        PREFECT_API_URL: hosted_api_server,
        PREFECT_SERVER_CSRF_PROTECTION_ENABLED: False,
    }
    with temporary_settings(overrides):
        yield hosted_api_server
@pytest.fixture
def disable_hosted_api_server() -> Generator[None, None, None]:
    """
    Disables the hosted API server by setting `PREFECT_API_URL` to `None`.
    """
    with temporary_settings({PREFECT_API_URL: None}):
        yield
@pytest.fixture
def enable_ephemeral_server(
    disable_hosted_api_server: None,
) -> Generator[None, None, None]:
    """
    Enables the ephemeral server by setting `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` to `True`.
    """
    with temporary_settings({PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True}):
        yield

    # Stop any subprocess server that ephemeral mode may have started.
    SubprocessASGIServer().stop()
@pytest.fixture
def mock_anyio_sleep(
    monkeypatch: pytest.MonkeyPatch,
) -> AsyncMock:
    """
    Mock sleep used to not actually sleep but to set the current time to now + sleep
    delay seconds while still yielding to other tasks in the event loop.

    Provides "assert_sleeps_for" context manager which asserts a sleep time occurred
    within the context while using the actual runtime of the context as a tolerance.
    """
    original_now = now
    original_sleep = anyio.sleep
    # Total simulated seconds slept so far; advanced by each mocked sleep call.
    time_shift = 0.0

    async def callback(delay_in_seconds: float) -> None:
        nonlocal time_shift
        time_shift += float(delay_in_seconds)
        # Preserve yield effects of sleep
        await original_sleep(0)

    def latest_now(*args: Any) -> DateTime:
        # Fast-forwards the time by the total sleep time
        return original_now(*args).add(
            # Ensure we retain float precision
            seconds=int(time_shift),
            microseconds=int((time_shift - int(time_shift)) * 1000000),
        )

    monkeypatch.setattr("prefect.types._datetime.now", latest_now)

    sleep = AsyncMock(side_effect=callback)
    monkeypatch.setattr("anyio.sleep", sleep)

    @contextmanager
    def assert_sleeps_for(
        seconds: Union[int, float], extra_tolerance: Union[int, float] = 0
    ):
        """
        Assert that sleep was called for N seconds during the duration of the context.

        The runtime of the code during the context of the duration is used as an
        upper tolerance to account for sleeps that start based on a time. This is less
        brittle than attempting to freeze the current time.

        If an integer is provided, the upper tolerance will be rounded up to the nearest
        integer. If a float is provided, the upper tolerance will be a float.

        An optional extra tolerance may be provided to account for any other issues.
        This will be applied symmetrically.
        """
        # Snapshot real time (unpatched `now`) and simulated sleep before/after.
        run_t0 = original_now().timestamp()
        sleep_t0 = time_shift
        yield
        run_t1 = original_now().timestamp()
        sleep_t1 = time_shift
        runtime = run_t1 - run_t0
        if isinstance(seconds, int):
            # Round tolerance up to the nearest integer if input is an int
            runtime = int(runtime) + 1
        sleeptime = sleep_t1 - sleep_t0
        assert (
            sleeptime - float(extra_tolerance)
            <= seconds
            <= sleeptime + runtime + extra_tolerance
        ), (
            f"Sleep was called for {sleeptime}; expected {seconds} with tolerance of"
            f" +{runtime + extra_tolerance}, -{extra_tolerance}"
        )

    # Expose the assertion helper as an attribute of the mock for test use.
    sleep.assert_sleeps_for = assert_sleeps_for

    # NOTE: this fixture returns (not yields) the mock, hence -> AsyncMock.
    return sleep
class Recorder:
    """Records what the stub events server observed: how many connections were
    made, the request path, the events received, and the auth token and filter
    sent on the "out" socket."""

    connections: int
    path: Optional[str]
    events: List[Event]
    token: Optional[str]
    filter: Optional[EventFilter]

    def __init__(self):
        self.connections = 0
        self.path = None
        self.events = []
        # `token` and `filter` were previously declared but never initialized,
        # so reading them before a client connected raised AttributeError.
        # Default them to None so inspection is always safe.
        self.token = None
        self.filter = None
class Puppeteer:
    """Controls how the stub events server misbehaves: the expected auth token,
    whether auth failures are hard (no failure message), whether to refuse new
    connections, an event id after which to hard-disconnect, and the events to
    send on the "out" socket."""

    token: Optional[str]

    hard_auth_failure: bool
    refuse_any_further_connections: bool
    hard_disconnect_after: Optional[UUID]

    outgoing_events: List[Event]

    def __init__(self):
        # `token` was previously declared but never initialized, so reading it
        # before a test assigned one raised AttributeError; default to None.
        self.token = None
        self.hard_auth_failure = False
        self.refuse_any_further_connections = False
        self.hard_disconnect_after = None
        self.outgoing_events = []
@pytest.fixture
def recorder() -> Recorder:
    """A fresh Recorder capturing what the stub events server observes."""
    return Recorder()
@pytest.fixture
def puppeteer() -> Puppeteer:
    """A fresh Puppeteer controlling the stub events server's behavior."""
    return Puppeteer()
@pytest.fixture
async def events_server(
    unused_tcp_port: int, recorder: Recorder, puppeteer: Puppeteer
) -> AsyncGenerator[Server, None]:
    """Run a stub websocket events server on an unused port.

    Routes paths ending in "/events/in" to an event-receiving handler and
    "/events/out" to an event-sending handler. Everything observed is written
    to `recorder`; misbehavior is driven by `puppeteer`.
    """
    server: Server

    async def handler(socket: ServerConnection) -> None:
        # Dispatch each new connection based on its request path.
        assert socket.request
        path = socket.request.path
        recorder.connections += 1
        if puppeteer.refuse_any_further_connections:
            raise ValueError("nope")
        recorder.path = path

        if path.endswith("/events/in"):
            await incoming_events(socket)
        elif path.endswith("/events/out"):
            await outgoing_events(socket)

    async def incoming_events(socket: ServerConnection):
        # Receive and record events until the client closes the connection,
        # optionally hard-disconnecting after a specific event id.
        while True:
            try:
                message = await socket.recv()
            except ConnectionClosed:
                return

            event = Event.model_validate_json(message)
            recorder.events.append(event)

            if puppeteer.hard_disconnect_after == event.id:
                puppeteer.hard_disconnect_after = None
                raise ValueError("Disconnect after incoming event")

    async def outgoing_events(socket: ServerConnection):
        # 1. authentication
        auth_message = json.loads(await socket.recv())

        assert auth_message["type"] == "auth"
        recorder.token = auth_message["token"]
        if puppeteer.token != recorder.token:
            # On a "hard" failure, close without sending the failure message.
            if not puppeteer.hard_auth_failure:
                await socket.send(
                    json.dumps({"type": "auth_failure", "reason": "nope"})
                )
            await socket.close(WS_1008_POLICY_VIOLATION)
            return

        await socket.send(json.dumps({"type": "auth_success"}))

        # 2. filter
        filter_message = json.loads(await socket.recv())
        assert filter_message["type"] == "filter"
        recorder.filter = EventFilter.model_validate(filter_message["filter"])

        # 3. send events
        for event in puppeteer.outgoing_events:
            await socket.send(
                json.dumps(
                    {
                        "type": "event",
                        "event": event.model_dump(mode="json"),
                    }
                )
            )
            if puppeteer.hard_disconnect_after == event.id:
                puppeteer.hard_disconnect_after = None
                raise ValueError("zonk")

    async with serve(handler, host="localhost", port=unused_tcp_port) as server:
        yield server
@pytest.fixture
def events_api_url(events_server: Server, unused_tcp_port: int) -> str:
    """The base URL of the stub events server for this test."""
    url = f"http://localhost:{unused_tcp_port}"
    return url
@pytest.fixture
def events_cloud_api_url(events_server: Server, unused_tcp_port: int) -> str:
    """The stub events server URL shaped like a Cloud account/workspace URL."""
    url = f"http://localhost:{unused_tcp_port}/accounts/A/workspaces/W"
    return url
@pytest.fixture
def mock_should_emit_events(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
    """Patch the emission gate so events are always considered emittable."""
    always_emit = mock.Mock(return_value=True)
    monkeypatch.setattr("prefect.events.utilities.should_emit_events", always_emit)
    return always_emit
@pytest.fixture
def asserting_events_worker(
    monkeypatch: pytest.MonkeyPatch,
) -> Generator[EventsWorker, None, None]:
    """Yield an EventsWorker backed by the asserting client, patched in as the
    shared instance so every caller receives this same worker."""
    asserting_worker = EventsWorker.instance(AssertingEventsClient)
    monkeypatch.setattr(EventsWorker, "instance", lambda *_: asserting_worker)
    try:
        yield asserting_worker
    finally:
        # Flush any queued events before the fixture tears down.
        asserting_worker.drain()
@pytest.fixture
def asserting_and_emitting_events_worker(
    monkeypatch: pytest.MonkeyPatch,
) -> Generator[EventsWorker, None, None]:
    """Yield an EventsWorker that both records and passes through events,
    patched in as the shared instance so every caller receives it."""
    passthrough_worker = EventsWorker.instance(AssertingPassthroughEventsClient)
    monkeypatch.setattr(EventsWorker, "instance", lambda *_: passthrough_worker)
    try:
        yield passthrough_worker
    finally:
        # Flush any queued events before the fixture tears down.
        passthrough_worker.drain()
@pytest.fixture
async def events_pipeline(
    asserting_events_worker: EventsWorker,
) -> AsyncGenerator[EventsPipeline, None]:
    """Yield a pipeline that feeds events captured by the asserting worker
    through the server-side event messaging path."""

    class AssertingEventsPipeline(EventsPipeline):
        @sync_compatible
        async def process_events(
            self,
            dequeue_events: bool = True,
            min_events: int = 0,
            timeout: int = 10,
        ):
            # Poll until the asserting client has captured at least min_events.
            async def wait_for_min_events():
                while len(asserting_events_worker._client.events) < min_events:
                    await asyncio.sleep(0.1)

            if min_events:
                try:
                    await asyncio.wait_for(wait_for_min_events(), timeout=timeout)
                except TimeoutError:
                    raise TimeoutError(
                        f"Timed out waiting for {min_events} events after {timeout} seconds. Only observed {len(asserting_events_worker._client.events)} events."
                    )
            else:
                # No minimum requested: just wait for the worker to drain.
                asserting_events_worker.wait_until_empty()

            if dequeue_events:
                # Remove events from the client so they are not reprocessed.
                events = asserting_events_worker._client.pop_events()
            else:
                events = asserting_events_worker._client.events

            messages = self.events_to_messages(events)
            await self.process_messages(messages)

    yield AssertingEventsPipeline()
@pytest.fixture
async def emitting_events_pipeline(
    asserting_and_emitting_events_worker: EventsWorker,
) -> AsyncGenerator[EventsPipeline, None]:
    """Yield a pipeline that drains the emitting worker and pushes its captured
    events through the server-side messaging path."""

    class AssertingAndEmittingEventsPipeline(EventsPipeline):
        @sync_compatible
        async def process_events(self):
            # Wait for the worker to drain, then take everything it captured.
            asserting_and_emitting_events_worker.wait_until_empty()
            collected = asserting_and_emitting_events_worker._client.pop_events()
            await self.process_messages(self.events_to_messages(collected))

    yield AssertingAndEmittingEventsPipeline()
@pytest.fixture
def reset_worker_events(
    asserting_events_worker: EventsWorker,
) -> Generator[None, None, None]:
    """Clear the asserting client's recorded events after the test completes."""
    yield
    client = asserting_events_worker._client
    assert isinstance(client, AssertingEventsClient)
    client.events = []