Coverage for /usr/local/lib/python3.12/site-packages/prefect/logging/clients.py: 26%

149 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1import asyncio 1a

2from datetime import timedelta 1a

3from types import TracebackType 1a

4from typing import ( 1a

5 TYPE_CHECKING, 

6 Any, 

7 MutableMapping, 

8 Optional, 

9 Tuple, 

10 Type, 

11 cast, 

12) 

13from uuid import UUID 1a

14 

15import orjson 1a

16from cachetools import TTLCache 1a

17from prometheus_client import Counter 1a

18from typing_extensions import Self 1a

19from websockets import Subprotocol 1a

20from websockets.asyncio.client import ClientConnection 1a

21from websockets.exceptions import ( 1a

22 ConnectionClosed, 

23 ConnectionClosedError, 

24 ConnectionClosedOK, 

25) 

26 

27from prefect._internal.websockets import ( 1a

28 create_ssl_context_for_websocket, 

29 websocket_connect, 

30) 

31from prefect.client.schemas.objects import Log 1a

32from prefect.logging import get_logger 1a

33from prefect.settings import ( 1a

34 PREFECT_API_AUTH_STRING, 

35 PREFECT_API_KEY, 

36 PREFECT_API_URL, 

37 PREFECT_CLOUD_API_URL, 

38 PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, 

39) 

40from prefect.types._datetime import now 1a

41 

42if TYPE_CHECKING: 42 ↛ 43line 42 didn't jump to line 43 because the condition on line 42 was never true1a

43 import logging 

44 

45 from prefect.client.schemas.filters import LogFilter 

46 

logger: "logging.Logger" = get_logger(__name__)

# Prometheus counter incremented once for every log a subscriber yields,
# labelled with the subscriber's class name.
LOGS_OBSERVED = Counter(
    name="prefect_logs_observed",
    documentation="The number of logs observed by Prefect log subscribers",
    labelnames=["client"],
)

# Prometheus counter incremented on every connection attempt made by a
# subscriber (both the initial connect and later reconnects).
LOG_WEBSOCKET_CONNECTIONS = Counter(
    name="prefect_log_websocket_connections",
    documentation=(
        "The number of times Prefect log clients have connected to a log stream, "
        "broken down by direction (in/out) and connection (initial/reconnect)"
    ),
    labelnames=["client", "direction", "connection"],
)

# Bounds for the per-subscriber cache of already-yielded log ids: passed as
# `maxsize` and `ttl` to the TTLCache that de-duplicates logs.
SEEN_LOGS_SIZE = 500_000
SEEN_LOGS_TTL = 120

65 

66 

def http_to_ws(url: str) -> str:
    """Convert an HTTP(S) URL into the equivalent WebSocket URL.

    Only the leading scheme is rewritten (``https://`` -> ``wss://``,
    ``http://`` -> ``ws://``); any ``http://`` or ``https://`` text appearing
    later in the URL (for example inside a query string) is left untouched.
    Trailing slashes are stripped. URLs with any other scheme are returned
    unchanged apart from the trailing-slash strip.

    Args:
        url: The URL to convert.

    Returns:
        The URL with a WebSocket scheme and no trailing slash.
    """
    # Check the longer scheme first so "https://" is never matched as "http".
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if url.startswith(http_scheme):
            url = ws_scheme + url[len(http_scheme):]
            break
    return url.rstrip("/")

69 

70 

def logs_out_socket_from_api_url(url: str) -> str:
    """Return the WebSocket URL of the ``/logs/out`` endpoint for an API URL."""
    websocket_base = http_to_ws(url)
    return f"{websocket_base}/logs/out"

73 

74 

75def _get_api_url_and_key( 1a

76 api_url: Optional[str], api_key: Optional[str] 

77) -> Tuple[str, str]: 

78 api_url = api_url or PREFECT_API_URL.value() 

79 api_key = api_key or PREFECT_API_KEY.value() 

80 

81 if not api_url or not api_key: 

82 raise ValueError( 

83 "api_url and api_key must be provided or set in the Prefect configuration" 

84 ) 

85 

86 return api_url, api_key 

87 

88 

def get_logs_subscriber(
    filter: Optional["LogFilter"] = None,
    reconnection_attempts: int = 10,
) -> "PrefectLogsSubscriber":
    """
    Get a logs subscriber based on the current Prefect configuration.

    Similar to get_events_subscriber, this automatically detects whether
    you're using Prefect Cloud or OSS and returns the appropriate subscriber.

    Resolution order:

    1. PREFECT_API_URL starts with PREFECT_CLOUD_API_URL: a
       PrefectCloudLogsSubscriber is returned.
    2. PREFECT_API_URL is set to anything else: a PrefectLogsSubscriber for
       that URL is returned.
    3. No API URL, but PREFECT_SERVER_ALLOW_EPHEMERAL_MODE is enabled: a
       SubprocessASGIServer is started and a PrefectLogsSubscriber pointed at
       its API URL is returned.

    Args:
        filter: Log filter to apply
        reconnection_attempts: When the client is disconnected, how many times
            the client should attempt to reconnect

    Raises:
        ValueError: If no API URL is configured and ephemeral mode is disabled.
    """
    api_url = PREFECT_API_URL.value()

    if isinstance(api_url, str) and api_url.startswith(PREFECT_CLOUD_API_URL.value()):
        return PrefectCloudLogsSubscriber(
            filter=filter, reconnection_attempts=reconnection_attempts
        )
    elif api_url:
        return PrefectLogsSubscriber(
            api_url=api_url,
            filter=filter,
            reconnection_attempts=reconnection_attempts,
        )
    # Read the setting's value explicitly, like every other setting in this
    # module, rather than relying on the truthiness of the setting object.
    elif PREFECT_SERVER_ALLOW_EPHEMERAL_MODE.value():
        # Imported here rather than at module level, so the server package
        # is only loaded when this branch runs.
        from prefect.server.api.server import SubprocessASGIServer

        server = SubprocessASGIServer()
        server.start()
        return PrefectLogsSubscriber(
            api_url=server.api_url,
            filter=filter,
            reconnection_attempts=reconnection_attempts,
        )
    else:
        raise ValueError(
            "No Prefect API URL provided. Please set PREFECT_API_URL to the address of a running Prefect server."
        )

125 

126 

class PrefectLogsSubscriber:
    """
    Subscribes to a Prefect logs stream, yielding logs as they occur.

    The subscriber is an async context manager and an async iterator. Entering
    the context connects, authenticates and sends the log filter; iterating
    yields `Log` objects, transparently reconnecting up to
    `reconnection_attempts` times when the connection drops.

    Example:

        from prefect.logging.clients import PrefectLogsSubscriber
        from prefect.client.schemas.filters import LogFilter, LogFilterLevel
        import logging

        filter = LogFilter(level=LogFilterLevel(ge_=logging.INFO))

        async with PrefectLogsSubscriber(filter=filter) as subscriber:
            async for log in subscriber:
                print(log.timestamp, log.level, log.message)

    """

    # The live connection, or None when not connected.
    _websocket: Optional[ClientConnection]
    # Filter sent to the server after authenticating.
    _filter: "LogFilter"
    # Ids of logs already yielded, used to skip duplicates.
    _seen_logs: MutableMapping[UUID, bool]

    # Credentials for the auth handshake; the API key wins when both are set.
    _api_key: Optional[str]
    _auth_token: Optional[str]

    def __init__(
        self,
        api_url: Optional[str] = None,
        filter: Optional["LogFilter"] = None,
        reconnection_attempts: int = 10,
    ):
        """
        Args:
            api_url: The base URL for a Prefect workspace; falls back to the
                PREFECT_API_URL setting when omitted
            filter: Log filter to apply; defaults to an empty LogFilter
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect

        Raises:
            ValueError: If reconnection_attempts is negative
        """
        # Reject invalid arguments before building any connection objects.
        if reconnection_attempts < 0:
            raise ValueError("reconnection_attempts must be a non-negative integer")

        self._api_key = None
        self._auth_token = PREFECT_API_AUTH_STRING.value()

        if not api_url:
            api_url = cast(str, PREFECT_API_URL.value())

        # Imported at call time; at module level LogFilter is only imported
        # under TYPE_CHECKING.
        from prefect.client.schemas.filters import LogFilter

        self._filter = filter or LogFilter()  # type: ignore[call-arg]
        self._seen_logs = TTLCache(maxsize=SEEN_LOGS_SIZE, ttl=SEEN_LOGS_TTL)

        socket_url = logs_out_socket_from_api_url(api_url)

        logger.debug("Connecting to %s", socket_url)

        # Configure SSL context for the connection
        ssl_context = create_ssl_context_for_websocket(socket_url)
        connect_kwargs: dict[str, Any] = {"subprotocols": [Subprotocol("prefect")]}
        if ssl_context:
            connect_kwargs["ssl"] = ssl_context

        self._connect = websocket_connect(socket_url, **connect_kwargs)
        self._websocket = None
        self._reconnection_attempts = reconnection_attempts

    @property
    def client_name(self) -> str:
        """The class name of this subscriber, used as the metrics `client` label."""
        return self.__class__.__name__

    async def __aenter__(self) -> Self:
        """Connect, authenticate and send the filter, then return self."""
        # Don't handle any errors in the initial connection, because these are most
        # likely a permission or configuration issue that should propagate
        try:
            await self._reconnect()
        finally:
            # Counted whether or not the attempt succeeded.
            LOG_WEBSOCKET_CONNECTIONS.labels(self.client_name, "out", "initial").inc()
        return self

    async def _reconnect(self) -> None:
        """
        (Re)establish the websocket connection and perform the handshake.

        Closes any existing connection, opens a new one, pings it, sends the
        auth message, checks the auth response, and finally sends the filter
        with its timestamp window starting one minute before now.

        Raises:
            Exception: If the server rejects the credentials or closes the
                connection with an error while authenticating
        """
        logger.debug("Reconnecting...")
        if self._websocket:
            self._websocket = None
            await self._connect.__aexit__(None, None, None)

        self._websocket = await self._connect.__aenter__()

        # make sure we have actually connected
        logger.debug(" pinging...")
        pong = await self._websocket.ping()
        await pong

        # Send authentication message - logs WebSocket requires auth handshake
        auth_token = self._api_key or self._auth_token
        auth_message = {"type": "auth", "token": auth_token}
        logger.debug(" authenticating...")
        await self._websocket.send(orjson.dumps(auth_message).decode())

        # Wait for auth response
        try:
            message = orjson.loads(await self._websocket.recv())
        except ConnectionClosedError as e:
            reason = getattr(e.rcvd, "reason", None)
            msg = "Unable to authenticate to the log stream. Please ensure the "
            msg += "provided api_key or auth_token you are using is valid for this environment. "
            msg += f"Reason: {reason}" if reason else ""
            raise Exception(msg) from e

        logger.debug(" auth result %s", message)
        # Explicit check rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a rejected handshake through.
        if message["type"] != "auth_success":
            raise Exception(
                "Unable to authenticate to the log stream. Please ensure the "
                "provided api_key or auth_token you are using is valid for this environment. "
                f"Reason: {message.get('reason', '')}"
            )

        from prefect.client.schemas.filters import LogFilterTimestamp

        # Re-anchor the time window on every (re)connect. The one-minute
        # look-back can re-deliver logs; __anext__ skips ids already seen.
        current_time = now("UTC")
        self._filter.timestamp = LogFilterTimestamp(
            after_=current_time - timedelta(minutes=1),  # type: ignore[arg-type]
            before_=current_time + timedelta(days=365),  # type: ignore[arg-type]
        )

        logger.debug(" filtering logs since %s...", self._filter.timestamp.after_)
        filter_message = {
            "type": "filter",
            "filter": self._filter.model_dump(mode="json"),
        }
        await self._websocket.send(orjson.dumps(filter_message).decode())

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Drop the websocket reference and close the underlying connection."""
        self._websocket = None
        await self._connect.__aexit__(exc_type, exc_val, exc_tb)

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> Log:
        """
        Return the next log not yet seen, reconnecting on dropped connections.

        Raises:
            StopAsyncIteration: When the server closes the connection cleanly
            ConnectionClosed: When the connection is still failing after the
                final reconnection attempt
        """
        assert self._reconnection_attempts >= 0
        for i in range(self._reconnection_attempts + 1):  # pragma: no branch
            try:
                # If we're here and the websocket is None, then we've had a failure in a
                # previous reconnection attempt.
                #
                # Otherwise, after the first time through this loop, we're recovering
                # from a ConnectionClosed, so reconnect now.
                if not self._websocket or i > 0:
                    try:
                        await self._reconnect()
                    finally:
                        LOG_WEBSOCKET_CONNECTIONS.labels(
                            self.client_name, "out", "reconnect"
                        ).inc()
                    assert self._websocket

                while True:
                    message = orjson.loads(await self._websocket.recv())
                    log: Log = Log.model_validate(message["log"])

                    # Skip logs already yielded (e.g. re-sent after a reconnect).
                    if log.id in self._seen_logs:
                        continue
                    self._seen_logs[log.id] = True

                    try:
                        return log
                    finally:
                        LOGS_OBSERVED.labels(self.client_name).inc()

            except ConnectionClosedOK:
                logger.debug('Connection closed with "OK" status')
                raise StopAsyncIteration
            except ConnectionClosed:
                logger.debug(
                    "Connection closed with %s/%s attempts",
                    i + 1,
                    self._reconnection_attempts,
                )
                if i == self._reconnection_attempts:
                    # this was our final chance, raise the most recent error
                    raise

                if i > 2:
                    # let the first two attempts happen quickly in case this is just
                    # a standard load balancer timeout, but after that, just take a
                    # beat to let things come back around.
                    await asyncio.sleep(1)
        raise StopAsyncIteration

319 

320 

class PrefectCloudLogsSubscriber(PrefectLogsSubscriber):
    """Logs subscriber that authenticates to Prefect Cloud with an API key."""

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        filter: Optional["LogFilter"] = None,
        reconnection_attempts: int = 10,
    ):
        """
        Args:
            api_url: The base URL for a Prefect Cloud workspace
            api_key: The API key of an actor with the see_flows scope
            filter: Log filter to apply
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect

        Raises:
            ValueError: If the URL or key is neither passed in nor configured
        """
        resolved_url, resolved_key = _get_api_url_and_key(api_url, api_key)

        super().__init__(
            api_url=resolved_url,
            filter=filter,
            reconnection_attempts=reconnection_attempts,
        )

        # Assigned after the base initializer, which sets _api_key to None.
        self._api_key = resolved_key