Coverage for /usr/local/lib/python3.12/site-packages/prefect/runner/_observers.py: 21%
91 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1from __future__ import annotations 1a
3import asyncio 1a
4import uuid 1a
5from contextlib import AsyncExitStack 1a
6from typing import Any, Protocol 1a
8from prefect.client.orchestration import PrefectClient, get_client 1a
9from prefect.client.schemas.filters import ( 1a
10 FlowRunFilter,
11 FlowRunFilterId,
12 FlowRunFilterState,
13 FlowRunFilterStateName,
14 FlowRunFilterStateType,
15)
16from prefect.client.schemas.objects import StateType 1a
17from prefect.events.clients import PrefectEventSubscriber, get_events_subscriber 1a
18from prefect.events.filters import EventFilter, EventNameFilter 1a
19from prefect.logging.loggers import get_logger 1a
20from prefect.utilities.services import critical_service_loop 1a
class OnCancellingCallback(Protocol):
    """Structural type for callables invoked with the ID of a cancelling flow run."""

    def __call__(self, flow_run_id: uuid.UUID) -> None:
        """Handle the flow run identified by `flow_run_id`."""
        ...
class FlowRunCancellingObserver:
    """
    Observer that invokes a callback when flow runs are marked as cancelling.

    Used as an async context manager. On entry it subscribes to
    `prefect.flow-run.Cancelling` events and consumes them in a background task.
    When that consumer task ends without being cancelled, the observer switches
    to polling the API for in-flight flow runs that are awaiting cancellation.
    """

    def __init__(
        self, on_cancelling: OnCancellingCallback, polling_interval: float = 10
    ):
        """
        Observer that cancels flow runs when they are marked as cancelling.

        Will use a websocket connection to listen for cancelling flow run events by default with a fallback
        to polling when the websocket connection is lost.

        Args:
            on_cancelling: Callback to call when a flow run is marked as cancelling.
            polling_interval: Interval in seconds to poll for cancelling flow runs when websocket connection is lost.
        """
        self.logger = get_logger("FlowRunCancellingObserver")
        self.on_cancelling = on_cancelling
        self.polling_interval = polling_interval
        # Flow runs that are eligible for cancellation checks while polling.
        self._in_flight_flow_run_ids: set[uuid.UUID] = set()
        # Explicitly initialised so that use before `__aenter__` raises the
        # intended RuntimeError rather than an AttributeError.
        self._events_subscriber: PrefectEventSubscriber | None = None
        self._exit_stack = AsyncExitStack()
        self._consumer_task: asyncio.Task[None] | None = None
        self._polling_task: asyncio.Task[None] | None = None
        self._is_shutting_down = False
        self._client: PrefectClient | None = None
        # Flow runs the polling loop has already reported to `on_cancelling`.
        self._cancelling_flow_run_ids: set[uuid.UUID] = set()

    def add_in_flight_flow_run_id(self, flow_run_id: uuid.UUID):
        """Register a flow run so the polling fallback checks it for cancellation."""
        self.logger.debug("Adding in-flight flow run ID: %s", flow_run_id)
        self._in_flight_flow_run_ids.add(flow_run_id)

    def remove_in_flight_flow_run_id(self, flow_run_id: uuid.UUID):
        """Stop tracking a flow run; unknown IDs are ignored."""
        self.logger.debug("Removing in-flight flow run ID: %s", flow_run_id)
        self._in_flight_flow_run_ids.discard(flow_run_id)
        # Forget the "already reported" marker as well, so the set cannot grow
        # without bound and a re-added run is polled again.
        self._cancelling_flow_run_ids.discard(flow_run_id)

    async def _consume_events(self):
        """
        Forward each received event's flow run ID to `on_cancelling`.

        Raises:
            RuntimeError: If the observer has not been entered with `async with`.
        """
        if self._events_subscriber is None:
            raise RuntimeError(
                "Events subscriber not initialized. Please use `async with` to initialize the observer."
            )
        async for event in self._events_subscriber:
            try:
                flow_run_id = uuid.UUID(
                    event.resource["prefect.resource.id"].replace(
                        "prefect.flow-run.", ""
                    )
                )
                self.on_cancelling(flow_run_id)
            except ValueError:
                self.logger.debug(
                    "Received event with invalid flow run ID: %s",
                    event.resource["prefect.resource.id"],
                )

    def _on_polling_task_done(self, task: asyncio.Task[None]):
        """Log how the polling task finished."""
        # `Task.exception()` raises CancelledError for a cancelled task, which
        # is the normal outcome when `__aexit__` stops the polling loop.
        if task.cancelled():
            self.logger.debug("Polling task cancelled")
            return
        exc = task.exception()
        if exc:
            self.logger.error(
                "Cancellation polling task failed. Execution will continue, but flow run cancellation will fail.",
                exc_info=exc,
            )
        else:
            self.logger.debug("Polling task completed")

    def _start_polling_task(self, task: asyncio.Task[None]):
        """
        Done-callback for the consumer task that starts the polling fallback.

        Args:
            task: The finished consumer task.
        """
        if task.cancelled():
            # If the consumer task was cancelled, the observer is shutting down
            # and we don't need to start the polling task
            return
        if exc := task.exception():
            self.logger.debug(
                "The FlowRunCancellingObserver websocket failed with an exception. Switching to polling mode.",
                exc_info=exc,
            )
        if self._is_shutting_down:
            # The consumer may finish on its own while `__aexit__` is running;
            # a polling task started now would never be cancelled.
            return
        self._polling_task = asyncio.create_task(
            critical_service_loop(
                workload=self._check_for_cancelled_flow_runs,
                interval=self.polling_interval,
                jitter_range=0.3,
            )
        )
        self._polling_task.add_done_callback(self._on_polling_task_done)

    async def _check_for_cancelled_flow_runs(self):
        """
        Poll the API once and report in-flight flow runs awaiting cancellation.

        Two queries are made: one for runs whose state type is CANCELLED with
        the state name "Cancelling", and one for runs whose state type is
        CANCELLING. Runs already reported by a previous poll are excluded.

        Raises:
            RuntimeError: If the observer has not been entered with `async with`.
        """
        if self._is_shutting_down:
            return
        if self._client is None:
            raise RuntimeError(
                "Client not initialized. Please use `async with` to initialize the observer."
            )

        self.logger.debug("Checking for cancelled flow runs")

        # Avoid duplicate cancellation calls
        candidate_ids = list(
            self._in_flight_flow_run_ids - self._cancelling_flow_run_ids
        )
        if not candidate_ids:
            # Nothing to look up, so skip both API round trips.
            return

        named_cancelling_flow_runs = await self._client.read_flow_runs(
            flow_run_filter=FlowRunFilter(
                state=FlowRunFilterState(
                    type=FlowRunFilterStateType(any_=[StateType.CANCELLED]),
                    name=FlowRunFilterStateName(any_=["Cancelling"]),
                ),
                id=FlowRunFilterId(any_=list(candidate_ids)),
            ),
        )

        typed_cancelling_flow_runs = await self._client.read_flow_runs(
            flow_run_filter=FlowRunFilter(
                state=FlowRunFilterState(
                    type=FlowRunFilterStateType(any_=[StateType.CANCELLING]),
                ),
                id=FlowRunFilterId(any_=list(candidate_ids)),
            ),
        )

        cancelling_flow_runs = named_cancelling_flow_runs + typed_cancelling_flow_runs

        if cancelling_flow_runs:
            self.logger.info(
                "Found %s flow runs awaiting cancellation.", len(cancelling_flow_runs)
            )

        for flow_run in cancelling_flow_runs:
            self._cancelling_flow_run_ids.add(flow_run.id)
            self.on_cancelling(flow_run.id)

    async def _cancel_and_wait(
        self, task: asyncio.Task[None] | None, failure_message: str
    ):
        """Cancel `task` (if any) and wait for it, logging unexpected failures."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            self.logger.warning(failure_message, exc_info=True)

    async def __aenter__(self):
        """Open the events subscriber and client, then start consuming events."""
        self._events_subscriber = await self._exit_stack.enter_async_context(
            get_events_subscriber(
                filter=EventFilter(
                    event=EventNameFilter(name=["prefect.flow-run.Cancelling"])
                ),
            )
        )
        self._client = await self._exit_stack.enter_async_context(get_client())
        self._consumer_task = asyncio.create_task(self._consume_events())
        # When the consumer stops for any reason other than cancellation,
        # fall back to polling.
        self._consumer_task.add_done_callback(self._start_polling_task)
        return self

    async def __aexit__(self, *exc_info: Any):
        """Stop the background tasks and close the subscriber and client."""
        self.logger.debug("Shutting down FlowRunCancellingObserver")
        self._is_shutting_down = True
        try:
            # Stop the tasks before closing the resources they depend on.
            await self._cancel_and_wait(
                self._consumer_task, "Consumer task exited with exception"
            )
            await self._cancel_and_wait(
                self._polling_task, "Polling task exited with exception"
            )
        finally:
            # Always release the subscriber and client, even if stopping a
            # task failed unexpectedly.
            await self._exit_stack.__aexit__(*exc_info)