Coverage for /usr/local/lib/python3.12/site-packages/prefect/client/subscriptions.py: 26%

75 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1import asyncio 1a

2from collections.abc import Iterable 1a

3from logging import Logger 1a

4from typing import Any, Generic, Optional, TypeVar 1a

5 

6import orjson 1a

7import websockets 1a

8import websockets.asyncio.client 1a

9import websockets.exceptions 1a

10from starlette.status import WS_1008_POLICY_VIOLATION 1a

11from typing_extensions import Self 1a

12 

13from prefect._internal.schemas.bases import IDBaseModel 1a

14from prefect._internal.websockets import websocket_connect 1a

15from prefect.logging import get_logger 1a

16from prefect.settings import get_current_settings 1a

17 

# Module-level logger named after this module.
logger: Logger = get_logger(__name__)

# Model type yielded by a Subscription; bound to IDBaseModel.
S = TypeVar("S", bound=IDBaseModel)

21 

22 

class Subscription(Generic[S]):
    """Async iterator over ``model`` instances received from a websocket endpoint.

    Each ``async for`` step connects (and authenticates/subscribes) if needed,
    receives one message, sends an ``ack``, and parses the message into
    ``model``.

    Args:
        model: Model class each incoming message is validated into via
            ``model_validate_json``.
        path: Path appended to the websocket base URL.
        keys: Keys sent in the ``subscribe`` message.
        client_id: Optional ``client_id`` included in the ``subscribe`` message.
        base_url: Base URL whose first ``http`` is replaced by ``ws``
            (so ``https://`` becomes ``wss://``).
    """

    def __init__(
        self,
        model: type[S],
        path: str,
        keys: Iterable[str],
        client_id: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        self.model = model
        self.client_id = client_id
        # Replacing only the first "http" maps http -> ws and https -> wss.
        # A missing base URL contributes nothing instead of the text "None".
        ws_base_url = base_url.replace("http", "ws", 1) if base_url else ""
        self.subscription_url: str = f"{ws_base_url}{path}"

        self.keys: list[str] = list(keys)

        self._connect = websocket_connect(
            self.subscription_url,
            subprotocols=[websockets.Subprotocol("prefect")],
        )
        # Set only after a successful auth handshake + subscribe.
        self._websocket: Optional[websockets.asyncio.client.ClientConnection] = None

    def __aiter__(self) -> Self:
        return self

    @property
    def websocket(self) -> websockets.asyncio.client.ClientConnection:
        """The live connection.

        Raises:
            RuntimeError: if the subscription has not connected yet.
        """
        if self._websocket is None:
            raise RuntimeError("Subscription is not connected")
        return self._websocket

    async def __anext__(self) -> S:
        """Receive, acknowledge and parse the next message.

        Refused or abnormally closed connections are retried after a 0.5 s
        pause. Authentication errors raised by ``_ensure_connected``
        propagate to the caller.
        """
        while True:
            try:
                await self._ensure_connected()
                message = await self.websocket.recv()

                # Acknowledge receipt before parsing the payload.
                await self.websocket.send(orjson.dumps({"type": "ack"}).decode())

                return self.model.model_validate_json(message)
            except (
                ConnectionRefusedError,
                websockets.exceptions.ConnectionClosedError,
            ):
                self._websocket = None
                # NOTE(review): teardown only runs when the connector exposes
                # ``protocol`` -- confirm this attribute exists on the object
                # returned by ``websocket_connect``.
                if hasattr(self._connect, "protocol"):
                    await self._connect.__aexit__(None, None, None)
                await asyncio.sleep(0.5)

    @staticmethod
    def _auth_token() -> Optional[str]:
        """Return the token to authenticate with, or ``None`` if none is set.

        ``api.auth_string`` takes priority over ``api.key`` in the current
        settings.
        """
        settings = get_current_settings()
        auth_token = (
            settings.api.auth_string.get_secret_value()
            if settings.api.auth_string
            else None
        )
        api_key = settings.api.key.get_secret_value() if settings.api.key else None
        return auth_token or api_key  # Prioritize auth_token

    @staticmethod
    def _auth_error(reason: Any) -> Exception:
        """Build the error raised when the server rejects the credentials."""
        return Exception(
            "Unable to authenticate to the subscription. Please ensure the provided "
            "`PREFECT_API_AUTH_STRING` (for self-hosted with auth string) or "
            "`PREFECT_API_KEY` (for Cloud or self-hosted with API key) "
            f"you are using is valid for this environment. Reason: {reason}"
        )

    async def _ensure_connected(self) -> None:
        """Connect, authenticate and subscribe unless already connected.

        The connection is stored on ``self._websocket`` only after the auth
        reply is ``auth_success`` and the ``subscribe`` message has been sent.
        On any failure during the handshake the connection is closed before
        the error propagates, so it is not left open.

        Raises:
            Exception: the server's auth reply was not ``auth_success``, or
                it closed the connection with a policy violation (1008).
            websockets.exceptions.ConnectionClosedError: the connection closed
                for any other reason (``__anext__`` retries these).
        """
        if self._websocket is not None:
            return

        websocket = await self._connect.__aenter__()

        try:
            await websocket.send(
                orjson.dumps({"type": "auth", "token": self._auth_token()}).decode()
            )

            auth: dict[str, Any] = orjson.loads(await websocket.recv())
            # Explicit check rather than ``assert`` so it still runs under
            # ``python -O``; a rejection without a message still yields the
            # authentication error instead of an unexplained AssertionError.
            if auth.get("type") != "auth_success":
                raise self._auth_error(auth.get("message") or "unknown")

            message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
            if self.client_id:
                message["client_id"] = self.client_id

            await websocket.send(orjson.dumps(message).decode())
        except websockets.exceptions.ConnectionClosedError as e:
            await websocket.close()
            # A 1008 close during the handshake means the credentials were refused.
            if e.rcvd and e.rcvd.code == WS_1008_POLICY_VIOLATION:
                raise self._auth_error(e.rcvd.reason or "unknown") from e
            raise
        except Exception:
            # Don't leak the half-initialised connection.
            await websocket.close()
            raise

        self._websocket = websocket

    def __repr__(self) -> str:
        return f"{type(self).__name__}[{self.model.__name__}]"