Coverage for /usr/local/lib/python3.12/site-packages/prefect/events/clients.py: 27%
326 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1import abc 1a
2import asyncio 1a
3from datetime import timedelta 1a
4from types import TracebackType 1a
5from typing import ( 1a
6 TYPE_CHECKING,
7 Any,
8 ClassVar,
9 Dict,
10 List,
11 MutableMapping,
12 Optional,
13 Tuple,
14 Type,
15 cast,
16)
17from uuid import UUID 1a
19import orjson 1a
20from cachetools import TTLCache 1a
21from prometheus_client import Counter 1a
22from typing_extensions import Self 1a
23from websockets import Subprotocol 1a
24from websockets.asyncio.client import ClientConnection 1a
25from websockets.exceptions import ( 1a
26 ConnectionClosed,
27 ConnectionClosedError,
28 ConnectionClosedOK,
29)
31import prefect.types._datetime 1a
32from prefect._internal.websockets import websocket_connect 1a
33from prefect.events import Event 1a
34from prefect.logging import get_logger 1a
35from prefect.settings import ( 1a
36 PREFECT_API_AUTH_STRING,
37 PREFECT_API_KEY,
38 PREFECT_API_URL,
39 PREFECT_CLOUD_API_URL,
40 PREFECT_DEBUG_MODE,
41 PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
42)
44if TYPE_CHECKING: 44 ↛ 45line 44 didn't jump to line 45 because the condition on line 44 was never true1a
45 from prefect.events.filters import EventFilter
# Prometheus counters describing event-client activity.  Every counter carries
# a "client" label, which the clients in this module fill with their class name.

# Incremented once per `EventsClient.emit()` call.
EVENTS_EMITTED = Counter(
    name="prefect_events_emitted",
    documentation="The number of events emitted by Prefect event clients",
    labelnames=["client"],
)

# Incremented once per event handed back by a subscriber's iterator.
EVENTS_OBSERVED = Counter(
    name="prefect_events_observed",
    documentation="The number of events observed by Prefect event subscribers",
    labelnames=["client"],
)

# Incremented on every (re)connection attempt to an event stream.
EVENT_WEBSOCKET_CONNECTIONS = Counter(
    name="prefect_event_websocket_connections",
    documentation=(
        "The number of times Prefect event clients have connected to an event stream, "
        "broken down by direction (in/out) and connection (initial/reconnect)"
    ),
    labelnames=["client", "direction", "connection"],
)

# Incremented each time a client confirms a batch of unconfirmed events.
EVENT_WEBSOCKET_CHECKPOINTS = Counter(
    name="prefect_event_websocket_checkpoints",
    documentation="The number of checkpoints performed by Prefect event clients",
    labelnames=["client"],
)
71if TYPE_CHECKING: 71 ↛ 72line 71 didn't jump to line 72 because the condition on line 71 was never true1a
72 import logging
74logger: "logging.Logger" = get_logger(__name__) 1a
def http_to_ws(url: str) -> str:
    """Convert an HTTP(S) URL into the equivalent WebSocket URL.

    Only the leading scheme is rewritten (``https://`` becomes ``wss://`` and
    ``http://`` becomes ``ws://``).  Later occurrences of those substrings,
    for example inside a query string, are left untouched.  Trailing slashes
    are always stripped.

    Args:
        url: The URL to convert.

    Returns:
        The URL with its scheme swapped (when it had an HTTP(S) scheme) and
        without trailing slashes.
    """
    # "https://" does not start with "http://", so the order of the pairs is
    # irrelevant; at most one of them can match the prefix.
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if url.startswith(http_scheme):
            url = ws_scheme + url[len(http_scheme):]
            break
    return url.rstrip("/")
def events_in_socket_from_api_url(url: str) -> str:
    """Return the WebSocket URL of the ``/events/in`` endpoint for the API at `url`."""
    websocket_base = http_to_ws(url)
    return f"{websocket_base}/events/in"
def events_out_socket_from_api_url(url: str) -> str:
    """Return the WebSocket URL of the ``/events/out`` endpoint for the API at `url`."""
    websocket_base = http_to_ws(url)
    return f"{websocket_base}/events/out"
def get_events_client(
    reconnection_attempts: int = 10,
    checkpoint_every: int = 700,
) -> "EventsClient":
    """Build the events client that matches the current Prefect settings.

    Selection order:

    1. `PREFECT_API_URL` starts with `PREFECT_CLOUD_API_URL`: a
       `PrefectCloudEventsClient`.
    2. `PREFECT_API_URL` is set to anything else: a `PrefectEventsClient`.
    3. No API URL but `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` is enabled: a
       `SubprocessASGIServer` is started and a `PrefectEventsClient` pointed
       at it is returned.

    Args:
        reconnection_attempts: Forwarded to the client constructor.
        checkpoint_every: Forwarded to the client constructor.

    Raises:
        ValueError: If no API URL is configured and ephemeral mode is not
            allowed.
    """
    api_url = PREFECT_API_URL.value()

    if isinstance(api_url, str) and api_url.startswith(PREFECT_CLOUD_API_URL.value()):
        return PrefectCloudEventsClient(
            reconnection_attempts=reconnection_attempts,
            checkpoint_every=checkpoint_every,
        )

    if api_url:
        return PrefectEventsClient(
            reconnection_attempts=reconnection_attempts,
            checkpoint_every=checkpoint_every,
        )

    # Read the setting explicitly through .value(), like every other setting
    # in this module, instead of relying on the truthiness of the setting
    # object itself.
    if PREFECT_SERVER_ALLOW_EPHEMERAL_MODE.value():
        # Imported lazily, only when an ephemeral server is actually needed.
        from prefect.server.api.server import SubprocessASGIServer

        server = SubprocessASGIServer()
        server.start()
        return PrefectEventsClient(
            api_url=server.api_url,
            reconnection_attempts=reconnection_attempts,
            checkpoint_every=checkpoint_every,
        )

    raise ValueError(
        "No Prefect API URL provided. Please set PREFECT_API_URL to the address of a running Prefect server."
    )
def get_events_subscriber(
    filter: Optional["EventFilter"] = None,
    reconnection_attempts: int = 10,
) -> "PrefectEventSubscriber":
    """Build the event subscriber that matches the current Prefect settings.

    Selection order:

    1. `PREFECT_API_URL` starts with `PREFECT_CLOUD_API_URL`: a
       `PrefectCloudEventSubscriber`.
    2. `PREFECT_API_URL` is set to anything else: a `PrefectEventSubscriber`.
    3. No API URL but `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` is enabled: a
       `SubprocessASGIServer` is started and a `PrefectEventSubscriber`
       pointed at it is returned.

    Args:
        filter: Optional event filter forwarded to the subscriber.
        reconnection_attempts: Forwarded to the subscriber constructor.

    Raises:
        ValueError: If no API URL is configured and ephemeral mode is not
            allowed.
    """
    api_url = PREFECT_API_URL.value()

    if isinstance(api_url, str) and api_url.startswith(PREFECT_CLOUD_API_URL.value()):
        return PrefectCloudEventSubscriber(
            filter=filter, reconnection_attempts=reconnection_attempts
        )

    if api_url:
        return PrefectEventSubscriber(
            filter=filter, reconnection_attempts=reconnection_attempts
        )

    # Read the setting explicitly through .value(), like every other setting
    # in this module, instead of relying on the truthiness of the setting
    # object itself.
    if PREFECT_SERVER_ALLOW_EPHEMERAL_MODE.value():
        # Imported lazily, only when an ephemeral server is actually needed.
        from prefect.server.api.server import SubprocessASGIServer

        server = SubprocessASGIServer()
        server.start()
        return PrefectEventSubscriber(
            api_url=server.api_url,
            filter=filter,
            reconnection_attempts=reconnection_attempts,
        )

    raise ValueError(
        "No Prefect API URL provided. Please set PREFECT_API_URL to the address of a running Prefect server."
    )
class EventsClient(abc.ABC):
    """The abstract interface for all Prefect Events clients.

    Subclasses implement `_emit`.  Callers use the client as an async context
    manager and call `emit` while inside it.
    """

    @property
    def client_name(self) -> str:
        """The class name of this client, used as the metrics "client" label."""
        return self.__class__.__name__

    async def emit(self, event: Event) -> None:
        """Emit a single event.

        Raises:
            TypeError: If the client is not currently entered as a context
                manager.
        """
        entered = hasattr(self, "_in_context")
        if not entered:
            raise TypeError(
                "Events may only be emitted while this client is being used as a "
                "context manager"
            )

        try:
            outcome = await self._emit(event)
        finally:
            # Counted whether or not the underlying emit succeeded.
            EVENTS_EMITTED.labels(self.client_name).inc()
        return outcome

    @abc.abstractmethod
    async def _emit(self, event: Event) -> None:  # pragma: no cover
        """Deliver `event`; implemented by concrete clients."""
        ...

    async def __aenter__(self) -> Self:
        # The mere presence of this attribute marks the client as "entered".
        self._in_context = True
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Remove the marker so later emit() calls are rejected.
        delattr(self, "_in_context")
        return None
class NullEventsClient(EventsClient):
    """A Prefect Events client implementation that does nothing"""

    async def _emit(self, event: Event) -> None:
        """Discard `event` without sending or recording it."""
        return None
class AssertingEventsClient(EventsClient):
    """A Prefect Events client that records all events sent to it for inspection during
    tests."""

    # The most recently constructed instance, and every instance constructed
    # since the last reset().
    last: ClassVar["Optional[AssertingEventsClient]"] = None
    all: ClassVar[List["AssertingEventsClient"]] = []

    # Constructor arguments, kept verbatim so tests can inspect them.
    args: tuple[Any, ...]
    kwargs: dict[str, Any]
    # Events recorded since the context was entered or since pop_events().
    events: list[Event]

    def __init__(self, *args: Any, **kwargs: Any):
        """Register this instance on the class and remember its arguments."""
        AssertingEventsClient.last = self
        AssertingEventsClient.all.append(self)
        self.args = args
        self.kwargs = kwargs
        # Start with an empty list so `events` and pop_events() are usable on
        # an instance that was constructed but never entered as a context.
        self.events = []

    @classmethod
    def reset(cls) -> None:
        """Reset all captured instances and their events.  For use between
        tests"""
        cls.last = None
        cls.all = []

    def pop_events(self) -> List[Event]:
        """Return the recorded events and clear the recording."""
        events = self.events
        self.events = []
        return events

    async def _emit(self, event: Event) -> None:
        """Record `event` instead of sending it anywhere."""
        self.events.append(event)

    async def __aenter__(self) -> Self:
        await super().__aenter__()
        # Each entry into the context starts a fresh recording.
        self.events = []
        return self
233def _get_api_url_and_key( 1a
234 api_url: Optional[str], api_key: Optional[str]
235) -> Tuple[str, str]:
236 api_url = api_url or PREFECT_API_URL.value()
237 api_key = api_key or PREFECT_API_KEY.value()
239 if not api_url or not api_key:
240 raise ValueError(
241 "api_url and api_key must be provided or set in the Prefect configuration"
242 )
244 return api_url, api_key
class PrefectEventsClient(EventsClient):
    """A Prefect Events client that streams events to a Prefect server"""

    # The open connection, or None while disconnected.
    _websocket: Optional[ClientConnection]
    # Events sent (or about to be sent) that no checkpoint has confirmed yet;
    # they are re-emitted after a reconnect.
    _unconfirmed_events: List[Event]

    def __init__(
        self,
        api_url: Optional[str] = None,
        reconnection_attempts: int = 10,
        checkpoint_every: int = 700,
    ):
        """
        Args:
            api_url: The base URL for a Prefect server
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect
            checkpoint_every: How often the client should sync with the server to
                confirm receipt of all previously sent events

        Raises:
            ValueError: If no `api_url` is given and `PREFECT_API_URL` is unset.
        """
        api_url = api_url or PREFECT_API_URL.value()
        if not api_url:
            raise ValueError(
                "api_url must be provided or set in the Prefect configuration"
            )

        self._events_socket_url = events_in_socket_from_api_url(api_url)
        self._connect = websocket_connect(self._events_socket_url)
        self._websocket = None
        self._reconnection_attempts = reconnection_attempts
        self._unconfirmed_events = []
        self._checkpoint_every = checkpoint_every

    async def __aenter__(self) -> Self:
        # Don't handle any errors in the initial connection, because these are most
        # likely a permission or configuration issue that should propagate
        await super().__aenter__()
        await self._reconnect()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self._websocket = None
        try:
            await self._connect.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Always leave the "in context" state, even when closing the
            # connection raised, so later emit() calls are rejected instead of
            # running against a torn-down client.
            await super().__aexit__(exc_type, exc_val, exc_tb)
        return None

    def _log_debug(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log a debug message prefixed with this client's identity."""
        message = f"EventsClient(id={id(self)}): " + message
        logger.debug(message, *args, **kwargs)

    async def _reconnect(self) -> None:
        """(Re)open the websocket, verify it with a ping, and resend
        every unconfirmed event.

        Connection errors are logged as a warning and re-raised.
        """
        logger.debug("Reconnecting websocket connection.")

        if self._websocket:
            self._websocket = None
            await self._connect.__aexit__(None, None, None)
            logger.debug("Cleared existing websocket connection.")

        try:
            logger.debug("Opening websocket connection.")
            self._websocket = await self._connect.__aenter__()
            # make sure we have actually connected
            logger.debug("Pinging to ensure websocket connected.")
            pong = await self._websocket.ping()
            await pong
            logger.debug("Pong received. Websocket connected.")
        except Exception as e:
            # The client is frequently run in a background thread
            # so we log an additional warning to ensure
            # surfacing the error to the user.
            logger.warning(
                "Unable to connect to %r. "
                "Please check your network settings to ensure websocket connections "
                "to the API are allowed. Otherwise event data (including task run data) may be lost. "
                "Reason: %s. "
                "Set PREFECT_DEBUG_MODE=1 to see the full error.",
                self._events_socket_url,
                str(e),
                exc_info=PREFECT_DEBUG_MODE.value(),
            )
            raise

        events_to_resend = self._unconfirmed_events
        logger.debug("Resending %s unconfirmed events.", len(events_to_resend))
        # Clear the unconfirmed events here, because they are going back through emit
        # and will be added again through the normal checkpointing process
        self._unconfirmed_events = []
        for event in events_to_resend:
            await self.emit(event)
        logger.debug("Finished resending unconfirmed events.")

    async def _checkpoint(self) -> None:
        """Confirm delivery with a ping once enough events are unconfirmed.

        Does nothing until at least `checkpoint_every` events are pending.
        """
        assert self._websocket

        unconfirmed_count = len(self._unconfirmed_events)

        if unconfirmed_count < self._checkpoint_every:
            return

        logger.debug("Pinging to checkpoint unconfirmed events.")
        pong = await self._websocket.ping()
        await pong
        self._log_debug("Pong received. Events checkpointed.")

        # once the pong returns, we know for sure that we've sent all the messages
        # we had enqueued prior to that.  There could be more that came in after, so
        # don't clear the list, just the ones that we are sure of.
        self._unconfirmed_events = self._unconfirmed_events[unconfirmed_count:]

        EVENT_WEBSOCKET_CHECKPOINTS.labels(self.client_name).inc()

    async def _emit(self, event: Event) -> None:
        """Send `event` over the websocket, reconnecting on closed connections.

        Raises:
            ConnectionClosed: If the connection is still closed after the last
                allowed reconnection attempt.
        """
        self._log_debug("Emitting event id=%s.", event.id)

        self._unconfirmed_events.append(event)

        logger.debug(
            "Added event id=%s to unconfirmed events list. "
            "There are now %s unconfirmed events.",
            event.id,
            len(self._unconfirmed_events),
        )

        for i in range(self._reconnection_attempts + 1):
            self._log_debug("Emit reconnection attempt %s.", i)
            try:
                # If we're here and the websocket is None, then we've had a failure in a
                # previous reconnection attempt.
                #
                # Otherwise, after the first time through this loop, we're recovering
                # from a ConnectionClosed, so reconnect now, resending any unconfirmed
                # events before we send this one.
                if not self._websocket or i > 0:
                    self._log_debug("Attempting websocket reconnection.")
                    await self._reconnect()
                    assert self._websocket

                self._log_debug("Sending event id=%s.", event.id)
                await self._websocket.send(event.model_dump_json())
                self._log_debug("Checkpointing event id=%s.", event.id)
                await self._checkpoint()

                return
            except ConnectionClosed:
                self._log_debug("Got ConnectionClosed error.")
                if i == self._reconnection_attempts:
                    # this was our final chance, raise the most recent error
                    raise

                if i > 2:
                    # let the first two attempts happen quickly in case this is just
                    # a standard load balancer timeout, but after that, just take a
                    # beat to let things come back around.
                    logger.debug(
                        "Sleeping for 1 second before next reconnection attempt."
                    )
                    await asyncio.sleep(1)
class AssertingPassthroughEventsClient(PrefectEventsClient):
    """A Prefect Events client that BOTH records all events sent to it for inspection
    during tests AND sends them to a Prefect server."""

    # The most recently constructed instance, and every instance constructed
    # since the last reset().
    last: ClassVar["Optional[AssertingPassthroughEventsClient]"] = None
    all: ClassVar[list["AssertingPassthroughEventsClient"]] = []

    # Constructor arguments, kept verbatim so tests can inspect them.
    args: tuple[Any, ...]
    kwargs: dict[str, Any]
    # Events recorded since the context was entered or since pop_events().
    events: list[Event]

    def __init__(self, *args: Any, **kwargs: Any):
        """Construct the real client, then register this instance on the class."""
        super().__init__(*args, **kwargs)
        AssertingPassthroughEventsClient.last = self
        AssertingPassthroughEventsClient.all.append(self)
        self.args = args
        self.kwargs = kwargs
        # Start with an empty list so `events` and pop_events() are usable on
        # an instance that was constructed but never entered as a context.
        self.events = []

    @classmethod
    def reset(cls) -> None:
        """Forget all captured instances.  For use between tests."""
        cls.last = None
        cls.all = []

    def pop_events(self) -> list[Event]:
        """Return the recorded events and clear the recording."""
        events = self.events
        self.events = []
        return events

    async def _emit(self, event: Event) -> None:
        # actually send the event to the server
        await super()._emit(event)

        # record the event for inspection
        self.events.append(event)

    async def __aenter__(self) -> Self:
        await super().__aenter__()
        # Each entry into the context starts a fresh recording.
        self.events = []
        return self
class PrefectCloudEventsClient(PrefectEventsClient):
    """A Prefect Events client that streams events to a Prefect Cloud Workspace"""

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        reconnection_attempts: int = 10,
        checkpoint_every: int = 700,
    ):
        """
        Args:
            api_url: The base URL for a Prefect Cloud workspace
            api_key: The API key of an actor with the manage_events scope
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect
            checkpoint_every: How often the client should sync with the server to
                confirm receipt of all previously sent events

        Raises:
            ValueError: If the URL or key is neither passed nor configured.
        """
        resolved_url, resolved_key = _get_api_url_and_key(api_url, api_key)
        super().__init__(
            api_url=resolved_url,
            reconnection_attempts=reconnection_attempts,
            checkpoint_every=checkpoint_every,
        )

        # Replace the connection built by the parent with one that also sends
        # the API key as a bearer token.
        auth_headers = {"Authorization": f"bearer {resolved_key}"}
        self._connect = websocket_connect(
            self._events_socket_url,
            additional_headers=auth_headers,
        )
# Bounds for the cache of already-seen event IDs that subscribers use to skip
# duplicate events: maximum number of remembered IDs ...
SEEN_EVENTS_SIZE = 500 * 1_000
# ... and how long each ID is remembered (passed as the TTLCache `ttl`;
# presumably seconds, i.e. two minutes -- confirm against cachetools).
SEEN_EVENTS_TTL = 2 * 60
class PrefectEventSubscriber:
    """
    Subscribes to a Prefect event stream, yielding events as they occur.

    Example:

        from prefect.events.clients import PrefectEventSubscriber
        from prefect.events.filters import EventFilter, EventNameFilter

        filter = EventFilter(event=EventNameFilter(prefix=["prefect.flow-run."]))

        async with PrefectEventSubscriber(filter=filter) as subscriber:
            async for event in subscriber:
                print(event.occurred, event.resource.id, event.event)

    """

    # The open connection, or None while disconnected.
    _websocket: Optional[ClientConnection]
    _filter: "EventFilter"
    # IDs of events already yielded, used to skip duplicates.
    _seen_events: MutableMapping[UUID, bool]

    # Credentials for the auth handshake; the API key wins when both are set.
    _api_key: Optional[str]
    _auth_token: Optional[str]

    def __init__(
        self,
        api_url: Optional[str] = None,
        filter: Optional["EventFilter"] = None,
        reconnection_attempts: int = 10,
    ):
        """
        Args:
            api_url: The base URL of the Prefect API; defaults to the
                `PREFECT_API_URL` setting
            filter: The filter selecting which events to receive; defaults to
                an empty `EventFilter`
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect

        Raises:
            ValueError: If `reconnection_attempts` is negative.
        """
        self._api_key = None
        self._auth_token = PREFECT_API_AUTH_STRING.value()

        if not api_url:
            api_url = cast(str, PREFECT_API_URL.value())

        from prefect.events.filters import EventFilter

        self._filter = filter or EventFilter()  # type: ignore[call-arg]
        self._seen_events = TTLCache(maxsize=SEEN_EVENTS_SIZE, ttl=SEEN_EVENTS_TTL)

        socket_url = events_out_socket_from_api_url(api_url)

        logger.debug("Connecting to %s", socket_url)

        self._connect = websocket_connect(
            socket_url,
            subprotocols=[Subprotocol("prefect")],
        )
        self._websocket = None
        self._reconnection_attempts = reconnection_attempts
        if self._reconnection_attempts < 0:
            raise ValueError("reconnection_attempts must be a non-negative integer")

    @property
    def client_name(self) -> str:
        """The class name of this subscriber, used as the metrics "client" label."""
        return self.__class__.__name__

    async def __aenter__(self) -> Self:
        # Don't handle any errors in the initial connection, because these are most
        # likely a permission or configuration issue that should propagate
        try:
            await self._reconnect()
        finally:
            EVENT_WEBSOCKET_CONNECTIONS.labels(self.client_name, "out", "initial").inc()
        return self

    async def _reconnect(self) -> None:
        """(Re)open the websocket, authenticate, and send the event filter.

        Raises:
            Exception: If the server does not answer the auth message with
                ``auth_success``, or closes the connection with an error
                while the answer is awaited.
        """
        logger.debug("Reconnecting...")
        if self._websocket:
            self._websocket = None
            await self._connect.__aexit__(None, None, None)

        self._websocket = await self._connect.__aenter__()

        # make sure we have actually connected
        logger.debug("  pinging...")
        pong = await self._websocket.ping()
        await pong

        logger.debug("  authenticating...")
        # Use the API key (for Cloud) OR the auth token (for self-hosted with auth string)
        token = self._api_key or self._auth_token
        await self._websocket.send(
            orjson.dumps({"type": "auth", "token": token}).decode()
        )

        try:
            message: Dict[str, Any] = orjson.loads(await self._websocket.recv())
        except ConnectionClosedError as e:
            reason = getattr(e.rcvd, "reason", None)
            msg = "Unable to authenticate to the event stream. Please ensure the "
            msg += "provided api_key or auth_token you are using is valid for this environment. "
            msg += f"Reason: {reason}" if reason else ""
            raise Exception(msg) from e

        logger.debug("  auth result %s", message)

        # Checked explicitly rather than with `assert`, so a rejected
        # authentication is still detected when Python runs with -O.
        if message["type"] != "auth_success":
            auth_failure_reason = message.get("reason", "")
            raise Exception(
                "Unable to authenticate to the event stream. Please ensure the "
                "provided api_key or auth_token you are using is valid for this environment. "
                f"Reason: {auth_failure_reason}"
            )

        from prefect.events.filters import EventOccurredFilter

        # Every (re)connection asks for events from the last minute onward.
        # The window overlaps with what was already received, and duplicates
        # are dropped via `_seen_events` in `__anext__`.
        self._filter.occurred = EventOccurredFilter(
            since=prefect.types._datetime.now("UTC") - timedelta(minutes=1),
            until=prefect.types._datetime.now("UTC") + timedelta(days=365),
        )

        logger.debug("  filtering events since %s...", self._filter.occurred.since)
        filter_message = {
            "type": "filter",
            "filter": self._filter.model_dump(mode="json"),
        }
        await self._websocket.send(orjson.dumps(filter_message).decode())

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self._websocket = None
        await self._connect.__aexit__(exc_type, exc_val, exc_tb)

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> Event:
        """Return the next unseen event, reconnecting on closed connections.

        Raises:
            StopAsyncIteration: When the connection is closed with an OK status.
            ConnectionClosed: If the connection is still closed after the last
                allowed reconnection attempt.
        """
        assert self._reconnection_attempts >= 0
        for i in range(self._reconnection_attempts + 1):  # pragma: no branch
            try:
                # If we're here and the websocket is None, then we've had a failure in a
                # previous reconnection attempt.
                #
                # Otherwise, after the first time through this loop, we're recovering
                # from a ConnectionClosed, so reconnect now.
                if not self._websocket or i > 0:
                    try:
                        await self._reconnect()
                    finally:
                        EVENT_WEBSOCKET_CONNECTIONS.labels(
                            self.client_name, "out", "reconnect"
                        ).inc()
                    assert self._websocket

                while True:
                    message = orjson.loads(await self._websocket.recv())
                    event: Event = Event.model_validate(message["event"])

                    # Skip events that were already yielded.
                    if event.id in self._seen_events:
                        continue
                    self._seen_events[event.id] = True

                    try:
                        return event
                    finally:
                        EVENTS_OBSERVED.labels(self.client_name).inc()
            except ConnectionClosedOK:
                logger.debug('Connection closed with "OK" status')
                raise StopAsyncIteration
            except ConnectionClosed:
                logger.debug(
                    "Connection closed with %s/%s attempts",
                    i + 1,
                    self._reconnection_attempts,
                )
                if i == self._reconnection_attempts:
                    # this was our final chance, raise the most recent error
                    raise

                if i > 2:
                    # let the first two attempts happen quickly in case this is just
                    # a standard load balancer timeout, but after that, just take a
                    # beat to let things come back around.
                    await asyncio.sleep(1)
        raise StopAsyncIteration
class PrefectCloudEventSubscriber(PrefectEventSubscriber):
    """An event subscriber that authenticates with a Prefect Cloud API key."""

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        filter: Optional["EventFilter"] = None,
        reconnection_attempts: int = 10,
    ):
        """
        Args:
            api_url: The base URL for a Prefect Cloud workspace
            api_key: The API key of an actor with the manage_events scope
            filter: The filter selecting which events to receive
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect

        Raises:
            ValueError: If the URL or key is neither passed nor configured.
        """
        resolved_url, resolved_key = _get_api_url_and_key(api_url, api_key)

        super().__init__(
            api_url=resolved_url,
            filter=filter,
            reconnection_attempts=reconnection_attempts,
        )

        # Set after the parent constructor, which resets `_api_key` to None.
        self._api_key = resolved_key
class PrefectCloudAccountEventSubscriber(PrefectCloudEventSubscriber):
    """A Cloud event subscriber connected at the account level, i.e. to the
    part of the API URL that precedes ``/workspaces/``."""

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        filter: Optional["EventFilter"] = None,
        reconnection_attempts: int = 10,
    ):
        """
        Args:
            api_url: The base URL for a Prefect Cloud workspace
            api_key: The API key of an actor with the manage_events scope
            filter: The filter selecting which events to receive
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect

        Raises:
            ValueError: If the URL or key is neither passed nor configured.
        """
        api_url, api_key = _get_api_url_and_key(api_url, api_key)

        # Keep only the part of the URL before "/workspaces/"; a URL without
        # that segment is used unchanged.
        account_api_url, _, _ = api_url.partition("/workspaces/")

        # Forward the resolved key.  The parent constructor resolves the key
        # again, and without it an explicitly passed `api_key` would be
        # ignored there and a ValueError raised whenever PREFECT_API_KEY is
        # not configured.
        super().__init__(
            api_url=account_api_url,
            api_key=api_key,
            filter=filter,
            reconnection_attempts=reconnection_attempts,
        )

        self._api_key = api_key