# Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/services/task_run_recorder.py: 49% (134 statements)

from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from datetime import timedelta
from typing import TYPE_CHECKING, AsyncGenerator, NoReturn, Optional
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

import prefect.types._datetime
from prefect.logging import get_logger
from prefect.server.database import (
    PrefectDBInterface,
    db_injector,
    provide_database_interface,
)
from prefect.server.events.ordering import (
    EventArrivedEarly,
    get_task_run_recorder_causal_ordering,
)
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.schemas.core import TaskRun
from prefect.server.schemas.states import State
from prefect.server.services.base import RunInEphemeralServers, Service
from prefect.server.utilities.messaging import (
    Consumer,
    Message,
    MessageHandler,
    create_consumer,
)
from prefect.server.utilities.messaging._consumer_names import (
    generate_unique_consumer_name,
)
from prefect.server.utilities.messaging.memory import log_metrics_periodically
from prefect.settings.context import get_current_settings
from prefect.settings.models.server.services import ServicesBaseSetting

if TYPE_CHECKING:
    import logging

logger: "logging.Logger" = get_logger(__name__)


@db_injector
async def _insert_task_run_state(
    db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun
):
    if TYPE_CHECKING:
        assert task_run.state is not None
    await session.execute(
        db.queries.insert(db.TaskRunState)
        .values(
            created=prefect.types._datetime.now("UTC"),
            task_run_id=task_run.id,
            **task_run.state.model_dump(),
        )
        .on_conflict_do_nothing(index_elements=["id"])
    )
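
# Note: the ON CONFLICT DO NOTHING in _insert_task_run_state makes state
# insertion idempotent. The state id is the id of the originating event (see
# task_run_from_event below), so replaying the same event inserts nothing new.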


def task_run_from_event(event: ReceivedEvent) -> TaskRun:
    task_run_id = event.resource.prefect_object_id("prefect.task-run")

    flow_run_id: Optional[UUID] = None
    if flow_run_resource := event.resource_in_role.get("flow-run"):
        flow_run_id = flow_run_resource.prefect_object_id("prefect.flow-run")

    state: State = State.model_validate(
        {
            "id": event.id,
            "timestamp": event.occurred,
            **event.payload["validated_state"],
        }
    )
    state.state_details.task_run_id = task_run_id
    state.state_details.flow_run_id = flow_run_id

    return TaskRun.model_validate(
        {
            "id": task_run_id,
            "flow_run_id": flow_run_id,
            "state_id": state.id,
            "state": state,
            **event.payload["task_run"],
        }
    )
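
# For orientation, a rough sketch of the event shape task_run_from_event
# expects; the field values here are illustrative, not taken from this module:
#
#   {
#       "event": "prefect.task-run.Completed",
#       "occurred": "2025-01-01T00:00:00Z",
#       "resource": {"prefect.resource.id": "prefect.task-run.<uuid>"},
#       "payload": {
#           "validated_state": {"type": "COMPLETED", "name": "Completed"},
#           "task_run": {"name": "my-task", "task_key": "..."},
#       },
#   }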


async def record_task_run_event(event: ReceivedEvent, depth: int = 0) -> None:
    async with AsyncExitStack() as stack:
        try:
            await stack.enter_async_context(
                get_task_run_recorder_causal_ordering().preceding_event_confirmed(
                    record_task_run_event, event, depth=depth
                )
            )
        except EventArrivedEarly:
            # We're safe to ACK this message because it has been parked by the
            # causal ordering mechanism and will be reprocessed when the
            # preceding event arrives.
            return

        task_run = task_run_from_event(event)

        task_run_attributes = task_run.model_dump_for_orm(
            exclude={
                "state_id",
                "state",
                "created",
                "estimated_run_time",
                "estimated_start_time_delta",
            },
            exclude_unset=True,
        )

        assert task_run.state

        denormalized_state_attributes = {
            "state_id": task_run.state.id,
            "state_type": task_run.state.type,
            "state_name": task_run.state.name,
            "state_timestamp": task_run.state.timestamp,
        }

        db = provide_database_interface()
        async with db.session_context() as session:
            # Combine all attributes for a single atomic operation
            all_attributes = {
                **task_run_attributes,
                **denormalized_state_attributes,
                "created": prefect.types._datetime.now("UTC"),
            }

            # Single atomic INSERT ... ON CONFLICT DO UPDATE
            await session.execute(
                db.queries.insert(db.TaskRun)
                .values(**all_attributes)
                .on_conflict_do_update(
                    index_elements=["id"],
                    set_={
                        "updated": prefect.types._datetime.now("UTC"),
                        **task_run_attributes,
                        **denormalized_state_attributes,
                    },
                    where=db.TaskRun.state_timestamp < task_run.state.timestamp,
                )
            )
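
            # The WHERE clause above is the out-of-order guard: an event whose
            # state timestamp is not newer than the row already recorded leaves
            # the row untouched, so late deliveries cannot regress a task run
            # to an earlier state.
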
            # Still need to insert the task_run_state separately
            await _insert_task_run_state(session, task_run)

            await session.commit()

        logger.debug(
            "Recorded task run state change",
            extra={
                "task_run_id": task_run.id,
                "flow_run_id": task_run.flow_run_id,
                "event_id": event.id,
                "event_follows": event.follows,
                "event": event.event,
                "occurred": event.occurred,
                "current_state_type": task_run.state_type,
                "current_state_name": task_run.state_name,
            },
        )


async def record_lost_follower_task_run_events() -> None:
    ordering = get_task_run_recorder_causal_ordering()
    events = await ordering.get_lost_followers()

    for event in events:
        # Temporarily skip events that are older than 24 hours, to avoid
        # processing a large backlog of events that may be sitting in the
        # waitlist from a period when lost followers were not being processed.
        if event.occurred < prefect.types._datetime.now("UTC") - timedelta(hours=24):
            await ordering.forget_follower(event)
            continue

        await record_task_run_event(event)
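
# "Followers" are events that arrived before the event they causally depend on
# (their "leader"). When the leader never arrives, the ordering layer reports
# those followers as lost so they can still be recorded, which is what the
# periodic task below does.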


async def periodically_process_followers(periodic_granularity: timedelta) -> NoReturn:
    """Periodically process followers that are waiting on a leader event that never arrived"""
    logger.info(
        "Starting periodically process followers task every %s seconds",
        periodic_granularity.total_seconds(),
    )
    while True:
        try:
            await record_lost_follower_task_run_events()
        except asyncio.CancelledError:
            logger.info("Periodically process followers task cancelled")
            return
        except Exception:
            logger.exception("Error running periodically process followers task")
        finally:
            await asyncio.sleep(periodic_granularity.total_seconds())


@asynccontextmanager
async def consumer() -> AsyncGenerator[MessageHandler, None]:
    record_lost_followers_task = asyncio.create_task(
        periodically_process_followers(periodic_granularity=timedelta(seconds=5))
    )

    async def message_handler(message: Message):
        event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data)

        if not event.event.startswith("prefect.task-run"):
            return

        if event.resource.get("prefect.orchestration") != "client":
            return

        logger.debug(
            "Received event: %s with id: %s for resource: %s",
            event.event,
            event.id,
            event.resource.get("prefect.resource.id"),
        )

        await record_task_run_event(event)

    try:
        yield message_handler
    finally:
        record_lost_followers_task.cancel()
        try:
            await record_lost_followers_task
        except asyncio.CancelledError:
            logger.info("Periodically process followers task cancelled successfully")
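
# The consumer() context manager owns the lifecycle of the lost-followers task:
# any consumer that uses the yielded handler gets periodic lost-follower
# processing for free, and the background task is cancelled and awaited when
# the context exits.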


class TaskRunRecorder(RunInEphemeralServers, Service):
    """Constructs task runs and states from client-emitted events"""

    consumer_task: asyncio.Task[None] | None = None
    metrics_task: asyncio.Task[None] | None = None

    @classmethod
    def service_settings(cls) -> ServicesBaseSetting:
        return get_current_settings().server.services.task_run_recorder

    def __init__(self):
        super().__init__()
        self._started_event: Optional[asyncio.Event] = None

    @property
    def started_event(self) -> asyncio.Event:
        if self._started_event is None:
            self._started_event = asyncio.Event()
        return self._started_event

    @started_event.setter
    def started_event(self, value: asyncio.Event) -> None:
        self._started_event = value

    async def start(self) -> NoReturn:
        assert self.consumer_task is None, "TaskRunRecorder already started"
        self.consumer: Consumer = create_consumer(
            "events",
            group="task-run-recorder",
            name=generate_unique_consumer_name("task-run-recorder"),
        )

        async with consumer() as handler:
            self.consumer_task = asyncio.create_task(self.consumer.run(handler))
            self.metrics_task = asyncio.create_task(log_metrics_periodically())

            logger.debug("TaskRunRecorder started")
            self.started_event.set()

            try:
                await self.consumer_task
            except asyncio.CancelledError:
                pass

    async def stop(self) -> None:
        assert self.consumer_task is not None, "TaskRunRecorder not started"
        self.consumer_task.cancel()
        if self.metrics_task:
            self.metrics_task.cancel()
        try:
            await self.consumer_task
            if self.metrics_task:
                await self.metrics_task
        except asyncio.CancelledError:
            pass
        finally:
            self.consumer_task = None
            self.metrics_task = None
            logger.debug("TaskRunRecorder stopped")
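
# A minimal sketch of driving the recorder directly (assumes a running event
# loop and a configured messaging backend; `main` is illustrative and not part
# of this module):
#
#   async def main() -> None:
#       service = TaskRunRecorder()
#       run = asyncio.create_task(service.start())
#       await service.started_event.wait()
#       ...  # the service is now consuming task-run events
#       await service.stop()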