Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/utilities/subscriptions.py: 23%

50 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1import asyncio 1a

2from asyncio import IncompleteReadError as IOError 1a

3from logging import Logger 1a

4from typing import Optional 1a

5 

6from fastapi import WebSocket 1a

7from starlette.status import WS_1002_PROTOCOL_ERROR, WS_1008_POLICY_VIOLATION 1a

8from starlette.websockets import WebSocketDisconnect 1a

9from websockets.exceptions import ConnectionClosed 1a

10 

11from prefect.logging import get_logger 1a

12from prefect.settings import get_current_settings 1a

13 

# Exceptions signalling that the client went away, cleanly or abruptly.
# The helpers in this module catch these and treat them as an ordinary end
# of the connection rather than an error.
# NOTE: ``IOError`` here is ``asyncio.IncompleteReadError`` (imported under
# that alias at the top of the module), not the builtin ``IOError``/``OSError``.
NORMAL_DISCONNECT_EXCEPTIONS = (IOError, ConnectionClosed, WebSocketDisconnect)

# Module-level logger for websocket subscription/auth diagnostics.
logger: Logger = get_logger("prefect.server.utilities.subscriptions")

17 

18 

async def accept_prefect_socket(websocket: WebSocket) -> Optional[WebSocket]:
    """Accept a websocket speaking the ``prefect`` subprotocol and authenticate it.

    The handshake is:

    1. The client must offer ``prefect`` in its ``Sec-WebSocket-Protocol``
       header; otherwise the socket is closed with ``WS_1002_PROTOCOL_ERROR``
       before being accepted.
    2. The socket is accepted with the ``prefect`` subprotocol.
    3. The first message must be a JSON object with ``"type": "auth"``. When
       ``server.api.auth_string`` is configured, that message must also carry
       a matching string ``"token"``. Any failure closes the socket with
       ``WS_1008_POLICY_VIOLATION``.
    4. On success, ``{"type": "auth_success"}`` is sent to the client.

    Args:
        websocket: The incoming, not-yet-accepted websocket connection.

    Returns:
        The accepted websocket when the handshake succeeds; ``None`` when the
        connection was rejected/closed or the client disconnected
        (any of ``NORMAL_DISCONNECT_EXCEPTIONS``) during the handshake.
    """
    import hmac

    # The header is a comma-separated list, and entries may be separated by
    # optional whitespace (e.g. "foo, prefect"), so strip each entry.
    subprotocols = [
        protocol.strip()
        for protocol in websocket.headers.get("Sec-WebSocket-Protocol", "").split(",")
    ]
    if "prefect" not in subprotocols:
        await websocket.close(WS_1002_PROTOCOL_ERROR)
        return None

    await websocket.accept(subprotocol="prefect")

    try:
        # Websocket connections are authenticated via messages. The first
        # message is expected to be an auth message, and if any other type of
        # message is received then the connection will be closed.
        #
        # The protocol requires receiving an auth message for compatibility
        # with Prefect Cloud, even if server-side auth is not configured.
        try:
            message = await websocket.receive_json()
        except ValueError:
            # Payloads that are not valid JSON (json.JSONDecodeError and
            # UnicodeDecodeError are both ValueError subclasses) are treated
            # like any other non-auth first message below.
            message = None

        auth_setting = (
            auth_setting_secret.get_secret_value()
            if (auth_setting_secret := get_current_settings().server.api.auth_string)
            else None
        )
        logger.debug(
            f"PREFECT_SERVER_API_AUTH_STRING setting: {'*' * len(auth_setting) if auth_setting else 'Not set'}"
        )

        # receive_json() may return any JSON value (list, string, number...),
        # so guard before calling .get on it.
        if not isinstance(message, dict) or message.get("type") != "auth":
            logger.warning(
                "WebSocket connection closed: Expected 'auth' message first."
            )
            await websocket.close(
                WS_1008_POLICY_VIOLATION, reason="Expected 'auth' message"
            )
            return None

        # Check authentication if PREFECT_SERVER_API_AUTH_STRING is set
        if auth_setting:
            received_token = message.get("token")
            # str() keeps the mask from raising on non-string tokens such as
            # numbers; for string tokens the output is unchanged.
            logger.debug(
                f"Auth required. Received token: {'*' * len(str(received_token)) if received_token else 'None'}"
            )
            if not received_token:
                logger.warning(
                    "WebSocket connection closed: Auth required but no token received."
                )
                await websocket.close(
                    WS_1008_POLICY_VIOLATION,
                    reason="Auth required but no token provided",
                )
                return None

            # Constant-time comparison avoids leaking how much of the token
            # matched through response timing. "surrogatepass" keeps encoding
            # from raising on lone surrogates a JSON client may send, while
            # still mapping distinct strings to distinct bytes.
            if not isinstance(received_token, str) or not hmac.compare_digest(
                received_token.encode("utf-8", "surrogatepass"),
                auth_setting.encode("utf-8", "surrogatepass"),
            ):
                logger.warning("WebSocket connection closed: Invalid token.")
                await websocket.close(
                    WS_1008_POLICY_VIOLATION, reason="Invalid token"
                )
                return None
            logger.debug("WebSocket token authentication successful.")
        else:
            logger.debug("No server auth string set, skipping token check.")

        await websocket.send_json({"type": "auth_success"})
        logger.debug("Sent auth_success to WebSocket.")
        return websocket

    except NORMAL_DISCONNECT_EXCEPTIONS:
        # it's fine if a client disconnects either normally or abnormally
        return None

83 

84 

async def still_connected(websocket: WebSocket) -> bool:
    """Check whether a client websocket still appears to be connected.

    Intended for periods where the server is sending messages and does not
    expect anything from the client. The socket is polled for up to 0.1
    seconds:

    * a timeout means nothing arrived, so the client is assumed connected;
    * a ``websocket.disconnect`` message means the client went away;
    * any other received message is discarded and the client is assumed
      connected;
    * ``RuntimeError`` (raised by Starlette when receiving on a socket that
      has already seen a disconnect) or one of ``NORMAL_DISCONNECT_EXCEPTIONS``
      means the client is gone.

    Args:
        websocket: An accepted websocket connection.

    Returns:
        ``True`` if the client still seems connected, ``False`` otherwise.
    """
    try:
        message = await asyncio.wait_for(websocket.receive(), timeout=0.1)
    except asyncio.TimeoutError:
        # The fact that we timed out rather than getting another kind of error
        # here means we're still connected to our client, so we can continue
        # to send events.
        return True
    except RuntimeError:
        # starlette raises this if we test a client that's disconnected
        return False
    except NORMAL_DISCONNECT_EXCEPTIONS:
        return False

    # Starlette's WebSocket.receive() hands the ASGI "websocket.disconnect"
    # message back to the caller instead of raising, so the first poll after
    # the client leaves must be detected by inspecting the message type.
    return message.get("type") != "websocket.disconnect"