Coverage for /usr/local/lib/python3.12/site-packages/prefect/runner/_observers.py: 21%

91 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4import uuid 1a

5from contextlib import AsyncExitStack 1a

6from typing import Any, Protocol 1a

7 

8from prefect.client.orchestration import PrefectClient, get_client 1a

9from prefect.client.schemas.filters import ( 1a

10 FlowRunFilter, 

11 FlowRunFilterId, 

12 FlowRunFilterState, 

13 FlowRunFilterStateName, 

14 FlowRunFilterStateType, 

15) 

16from prefect.client.schemas.objects import StateType 1a

17from prefect.events.clients import PrefectEventSubscriber, get_events_subscriber 1a

18from prefect.events.filters import EventFilter, EventNameFilter 1a

19from prefect.logging.loggers import get_logger 1a

20from prefect.utilities.services import critical_service_loop 1a

21 

22 

class OnCancellingCallback(Protocol):
    """Structural type for callables notified when a flow run is marked as cancelling."""

    def __call__(self, flow_run_id: uuid.UUID) -> None:
        """Receive the ID of the flow run that should be cancelled."""
        ...

25 

26 

class FlowRunCancellingObserver:
    """
    Observer that reports flow runs which have been marked as cancelling.

    Listens on a websocket for ``prefect.flow-run.Cancelling`` events by default
    and falls back to polling the API when the websocket consumer stops.
    Must be entered with ``async with`` before use.
    """

    def __init__(
        self, on_cancelling: OnCancellingCallback, polling_interval: float = 10
    ):
        """
        Args:
            on_cancelling: Callback to call with a flow run's ID when it is
                marked as cancelling.
            polling_interval: Interval in seconds between polls for cancelling
                flow runs once the websocket connection is lost.
        """
        self.logger = get_logger("FlowRunCancellingObserver")
        self.on_cancelling = on_cancelling
        self.polling_interval = polling_interval
        # Flow runs currently executing here; polling mode only checks these.
        self._in_flight_flow_run_ids: set[uuid.UUID] = set()
        # Assigned (not merely annotated) so the `is None` guard in
        # `_consume_events` works before `__aenter__` has run.
        self._events_subscriber: PrefectEventSubscriber | None = None
        self._exit_stack = AsyncExitStack()
        self._consumer_task: asyncio.Task[None] | None = None
        self._polling_task: asyncio.Task[None] | None = None
        self._is_shutting_down = False
        self._client: PrefectClient | None = None
        # Flow runs the poller has already passed to `on_cancelling`.
        self._cancelling_flow_run_ids: set[uuid.UUID] = set()

    def add_in_flight_flow_run_id(self, flow_run_id: uuid.UUID):
        """Start tracking ``flow_run_id`` so the poller checks it."""
        self.logger.debug("Adding in-flight flow run ID: %s", flow_run_id)
        self._in_flight_flow_run_ids.add(flow_run_id)

    def remove_in_flight_flow_run_id(self, flow_run_id: uuid.UUID):
        """Stop tracking ``flow_run_id``; unknown IDs are ignored."""
        self.logger.debug("Removing in-flight flow run ID: %s", flow_run_id)
        self._in_flight_flow_run_ids.discard(flow_run_id)
        # The poller only queries in-flight runs, so the "already cancelling"
        # marker is no longer needed; dropping it bounds the set's size.
        self._cancelling_flow_run_ids.discard(flow_run_id)

    async def _consume_events(self):
        """
        Call `on_cancelling` for every cancelling event received on the websocket.

        Raises:
            RuntimeError: if the observer has not been entered with `async with`.
        """
        if self._events_subscriber is None:
            raise RuntimeError(
                "Events subscriber not initialized. Please use `async with` to initialize the observer."
            )
        async for event in self._events_subscriber:
            resource_id = event.resource["prefect.resource.id"]
            try:
                flow_run_id = uuid.UUID(resource_id.removeprefix("prefect.flow-run."))
            except ValueError:
                self.logger.debug(
                    "Received event with invalid flow run ID: %s",
                    resource_id,
                )
                continue
            # Called outside the `try` so a ValueError raised by the callback
            # is not misreported as an invalid flow run ID.
            self.on_cancelling(flow_run_id)

    def _start_polling_task(self, task: asyncio.Task[None]):
        """Done-callback for the websocket consumer: switch to polling mode."""
        if task.cancelled():
            # The consumer is only cancelled during shutdown.
            return
        exc = task.exception()
        if self._is_shutting_down:
            # The websocket closed because the observer is exiting; starting a
            # poller now would leave an orphaned background task.
            return
        if exc is not None:
            self.logger.debug(
                "The FlowRunCancellingObserver websocket failed with an exception. Switching to polling mode.",
                exc_info=exc,
            )
        self._polling_task = asyncio.create_task(
            critical_service_loop(
                workload=self._check_for_cancelled_flow_runs,
                interval=self.polling_interval,
                jitter_range=0.3,
            )
        )
        self._polling_task.add_done_callback(self._on_polling_task_done)

    def _on_polling_task_done(self, task: asyncio.Task[None]) -> None:
        """Log how the polling task ended."""
        # `Task.exception()` raises CancelledError on a cancelled task, and
        # `__aexit__` cancels this task, so check cancellation first.
        if task.cancelled():
            self.logger.debug("Polling task cancelled")
            return
        if exc := task.exception():
            self.logger.error(
                "Cancellation polling task failed. Execution will continue, but flow run cancellation will fail.",
                exc_info=exc,
            )
        else:
            self.logger.debug("Polling task completed")

    async def _check_for_cancelled_flow_runs(self):
        """
        Query the API for in-flight flow runs that are cancelling and call
        `on_cancelling` once per newly found run.

        Raises:
            RuntimeError: if the observer has not been entered with `async with`.
        """
        if self._is_shutting_down:
            return
        if self._client is None:
            raise RuntimeError(
                "Client not initialized. Please use `async with` to initialize the observer."
            )

        # Exclude runs already handed to `on_cancelling` to avoid duplicate calls.
        candidate_ids = list(
            self._in_flight_flow_run_ids - self._cancelling_flow_run_ids
        )
        if not candidate_ids:
            # No candidate flow runs, so there is nothing to query.
            return

        self.logger.debug("Checking for cancelled flow runs")
        id_filter = FlowRunFilterId(any_=candidate_ids)

        # States of type CANCELLED that are named "Cancelling".
        named_cancelling_flow_runs = await self._client.read_flow_runs(
            flow_run_filter=FlowRunFilter(
                state=FlowRunFilterState(
                    type=FlowRunFilterStateType(any_=[StateType.CANCELLED]),
                    name=FlowRunFilterStateName(any_=["Cancelling"]),
                ),
                id=id_filter,
            ),
        )

        # States of type CANCELLING.
        typed_cancelling_flow_runs = await self._client.read_flow_runs(
            flow_run_filter=FlowRunFilter(
                state=FlowRunFilterState(
                    type=FlowRunFilterStateType(any_=[StateType.CANCELLING]),
                ),
                id=id_filter,
            ),
        )

        cancelling_flow_runs = named_cancelling_flow_runs + typed_cancelling_flow_runs

        if cancelling_flow_runs:
            self.logger.info(
                "Found %s flow runs awaiting cancellation.", len(cancelling_flow_runs)
            )

        for flow_run in cancelling_flow_runs:
            self._cancelling_flow_run_ids.add(flow_run.id)
            self.on_cancelling(flow_run.id)

    async def __aenter__(self):
        """Open the events subscriber and API client, then start the websocket consumer."""
        try:
            self._events_subscriber = await self._exit_stack.enter_async_context(
                get_events_subscriber(
                    filter=EventFilter(
                        event=EventNameFilter(name=["prefect.flow-run.Cancelling"])
                    ),
                )
            )
            self._client = await self._exit_stack.enter_async_context(get_client())
        except BaseException:
            # `__aexit__` is not called when `__aenter__` raises, so release
            # whatever was already entered here.
            await self._exit_stack.aclose()
            raise
        self._consumer_task = asyncio.create_task(self._consume_events())
        self._consumer_task.add_done_callback(self._start_polling_task)
        return self

    async def __aexit__(self, *exc_info: Any):
        """Stop the background tasks, then close the subscriber and client."""
        self.logger.debug("Shutting down FlowRunCancellingObserver")
        self._is_shutting_down = True
        try:
            # Cancel the tasks before closing the resources they use.
            # Otherwise, closing the websocket first could make the consumer
            # fail and spawn a polling task after `_polling_task` was checked.
            await self._cancel_task(
                self._consumer_task, "Consumer task exited with exception"
            )
            await self._cancel_task(
                self._polling_task, "Polling task exited with exception"
            )
        finally:
            await self._exit_stack.__aexit__(*exc_info)

    async def _cancel_task(
        self, task: asyncio.Task[None] | None, failure_message: str
    ) -> None:
        """Cancel ``task`` (if any), wait for it, and log a non-cancellation failure."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            self.logger.warning(failure_message, exc_info=True)