Coverage for /usr/local/lib/python3.12/site-packages/prefect/client/subscriptions.py: 26%
75 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1import asyncio 1a
2from collections.abc import Iterable 1a
3from logging import Logger 1a
4from typing import Any, Generic, Optional, TypeVar 1a
6import orjson 1a
7import websockets 1a
8import websockets.asyncio.client 1a
9import websockets.exceptions 1a
10from starlette.status import WS_1008_POLICY_VIOLATION 1a
11from typing_extensions import Self 1a
13from prefect._internal.schemas.bases import IDBaseModel 1a
14from prefect._internal.websockets import websocket_connect 1a
15from prefect.logging import get_logger 1a
16from prefect.settings import get_current_settings 1a
# Model type yielded by a Subscription; must be an IDBaseModel subclass.
S = TypeVar("S", bound=IDBaseModel)

# Logger for this module, named after the module path.
logger: Logger = get_logger(__name__)
class Subscription(Generic[S]):
    """Async iterator yielding ``model`` instances received over a websocket.

    Each iteration connects, authenticates and subscribes on demand. It then
    receives one message, acknowledges it, and validates it into ``model``.
    Refused connections and abnormal closures are retried after a short delay.
    """

    def __init__(
        self,
        model: type[S],
        path: str,
        keys: Iterable[str],
        client_id: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        """Prepare a subscription; the socket is opened in ``_ensure_connected``.

        Args:
            model: Model class used to validate each received JSON message.
            path: Endpoint path appended to ``base_url``.
            keys: Keys sent in the ``subscribe`` message.
            client_id: Optional id included in the ``subscribe`` message.
            base_url: Base URL whose first ``"http"`` is rewritten to ``"ws"``
                (so ``https://`` becomes ``wss://``).
        """
        self.model = model
        self.client_id = client_id
        # NOTE(review): if base_url is None this produces "None<path>";
        # callers presumably always pass a base_url — confirm.
        base_url = base_url.replace("http", "ws", 1) if base_url else None
        self.subscription_url: str = f"{base_url}{path}"

        self.keys: list[str] = list(keys)

        # Connector object only; it is entered (awaited) in _ensure_connected.
        self._connect = websocket_connect(
            self.subscription_url,
            subprotocols=[websockets.Subprotocol("prefect")],
        )
        self._websocket: Optional[websockets.asyncio.client.ClientConnection] = None

    def __aiter__(self) -> Self:
        """Return self; the subscription is its own async iterator."""
        return self

    @property
    def websocket(self) -> websockets.asyncio.client.ClientConnection:
        """The live connection.

        Raises:
            RuntimeError: if no authenticated connection is established yet.
        """
        if not self._websocket:
            raise RuntimeError("Subscription is not connected")
        return self._websocket

    async def __anext__(self) -> S:
        """Receive, acknowledge and return the next message as ``model``.

        ``ConnectionRefusedError`` and ``ConnectionClosedError`` reset the
        connection and retry after 0.5 s, so this never raises
        ``StopAsyncIteration``. Any other exception propagates to the caller.
        """
        while True:
            try:
                await self._ensure_connected()
                message = await self.websocket.recv()

                # Acknowledge receipt before parsing the payload.
                await self.websocket.send(orjson.dumps({"type": "ack"}).decode())

                return self.model.model_validate_json(message)
            except (
                ConnectionRefusedError,
                websockets.exceptions.ConnectionClosedError,
            ):
                # Drop the dead connection so the next loop reconnects.
                self._websocket = None
                if hasattr(self._connect, "protocol"):
                    await self._connect.__aexit__(None, None, None)
                await asyncio.sleep(0.5)

    async def _ensure_connected(self):
        """Open, authenticate and subscribe the websocket if not connected.

        The connection is stored on ``self`` only after the subscribe message
        was sent successfully.

        Raises:
            Exception: when authentication is rejected, i.e. the auth reply is
                not ``auth_success`` or the server closes with a 1008 policy
                violation.
            websockets.exceptions.ConnectionClosedError: when the connection
                closes for any other reason during the handshake.
        """
        if self._websocket:
            return

        websocket = await self._connect.__aenter__()

        try:
            settings = get_current_settings()
            auth_token = (
                settings.api.auth_string.get_secret_value()
                if settings.api.auth_string
                else None
            )
            api_key = settings.api.key.get_secret_value() if settings.api.key else None
            token = auth_token or api_key  # Prioritize auth_token

            await websocket.send(
                orjson.dumps({"type": "auth", "token": token}).decode()
            )

            auth: dict[str, Any] = orjson.loads(await websocket.recv())
            # Explicit check rather than ``assert``: asserts are stripped under
            # ``python -O``, which would silently accept a failed auth. Raising
            # AssertionError keeps the handling below unchanged.
            if auth.get("type") != "auth_success":
                raise AssertionError(auth.get("message"))

            message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
            if self.client_id:
                message["client_id"] = self.client_id

            await websocket.send(orjson.dumps(message).decode())
        except (
            AssertionError,
            websockets.exceptions.ConnectionClosedError,
        ) as e:
            reason = self._auth_failure_reason(e)
            if reason:
                error_message = (
                    "Unable to authenticate to the subscription. Please ensure the provided "
                    "`PREFECT_API_AUTH_STRING` (for self-hosted with auth string) or "
                    "`PREFECT_API_KEY` (for Cloud or self-hosted with API key) "
                    f"you are using is valid for this environment. Reason: {reason}"
                )
                raise Exception(error_message) from e
            raise
        else:
            self._websocket = websocket

    @staticmethod
    def _auth_failure_reason(
        error: "AssertionError | websockets.exceptions.ConnectionClosedError",
    ) -> Optional[str]:
        """Return a reason string if ``error`` is an auth failure, else None.

        A failed auth check always yields a reason, falling back to
        ``"unknown"`` when the server sent no message. A close frame yields a
        reason only for code 1008 (policy violation).
        """
        if isinstance(error, AssertionError):
            # The server's message may be missing/None; never let that demote
            # an auth failure to a bare AssertionError.
            detail = error.args[0] if error.args else None
            return str(detail) if detail else "unknown"
        rcvd = error.rcvd
        if rcvd and rcvd.code == WS_1008_POLICY_VIOLATION:
            return rcvd.reason or "unknown"
        return None

    def __repr__(self) -> str:
        return f"{type(self).__name__}[{self.model.__name__}]"