Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/services/task_run_recorder.py: 49%

134 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4from contextlib import AsyncExitStack, asynccontextmanager 1a

5from datetime import timedelta 1a

6from typing import TYPE_CHECKING, AsyncGenerator, NoReturn, Optional 1a

7from uuid import UUID 1a

8 

9from sqlalchemy.ext.asyncio import AsyncSession 1a

10 

11import prefect.types._datetime 1a

12from prefect.logging import get_logger 1a

13from prefect.server.database import ( 1a

14 PrefectDBInterface, 

15 db_injector, 

16 provide_database_interface, 

17) 

18from prefect.server.events.ordering import ( 1a

19 EventArrivedEarly, 

20 get_task_run_recorder_causal_ordering, 

21) 

22from prefect.server.events.schemas.events import ReceivedEvent 1a

23from prefect.server.schemas.core import TaskRun 1a

24from prefect.server.schemas.states import State 1a

25from prefect.server.services.base import RunInEphemeralServers, Service 1a

26from prefect.server.utilities.messaging import ( 1a

27 Consumer, 

28 Message, 

29 MessageHandler, 

30 create_consumer, 

31) 

32from prefect.server.utilities.messaging._consumer_names import ( 1a

33 generate_unique_consumer_name, 

34) 

35from prefect.server.utilities.messaging.memory import log_metrics_periodically 1a

36from prefect.settings.context import get_current_settings 1a

37from prefect.settings.models.server.services import ServicesBaseSetting 1a

38 

if TYPE_CHECKING:
    import logging  # imported only for the annotation on `logger` below

# Module-level logger for the task run recorder service.
logger: "logging.Logger" = get_logger(__name__)

43 

44 

@db_injector
async def _insert_task_run_state(
    db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun
):
    """Insert a row for the task run's current state.

    Uses INSERT ... ON CONFLICT DO NOTHING keyed on the state id, so replaying
    the same event is a no-op rather than an error.
    """
    if TYPE_CHECKING:
        assert task_run.state is not None

    # Dump the state's fields and add the bookkeeping columns; later keys in
    # the dump may not include `created`/`task_run_id`, which we set here.
    row_values = {
        "created": prefect.types._datetime.now("UTC"),
        "task_run_id": task_run.id,
        **task_run.state.model_dump(),
    }

    statement = (
        db.queries.insert(db.TaskRunState)
        .values(**row_values)
        .on_conflict_do_nothing(index_elements=["id"])
    )
    await session.execute(statement)

64 

65 

def task_run_from_event(event: ReceivedEvent) -> TaskRun:
    """Construct a TaskRun (with its embedded State) from a received event.

    The event's payload supplies `validated_state` and `task_run` dictionaries;
    payload values take precedence over the defaults derived from the event.
    """
    task_run_id = event.resource.prefect_object_id("prefect.task-run")

    # The flow run is optional: autonomous task runs have no flow-run resource.
    flow_run_id: Optional[UUID] = None
    flow_run_resource = event.resource_in_role.get("flow-run")
    if flow_run_resource:
        flow_run_id = flow_run_resource.prefect_object_id("prefect.flow-run")

    # Seed the state from the event, then let the payload override.
    state_fields = {"id": event.id, "timestamp": event.occurred}
    state_fields.update(event.payload["validated_state"])
    state: State = State.model_validate(state_fields)
    state.state_details.task_run_id = task_run_id
    state.state_details.flow_run_id = flow_run_id

    run_fields = {
        "id": task_run_id,
        "flow_run_id": flow_run_id,
        "state_id": state.id,
        "state": state,
    }
    run_fields.update(event.payload["task_run"])
    return TaskRun.model_validate(run_fields)

92 

93 

async def record_task_run_event(event: ReceivedEvent, depth: int = 0) -> None:
    """Record a task run and its state change from a single event.

    Waits (via causal ordering) until the preceding event has been recorded,
    then upserts the task run row and inserts the state row in one session.
    Events arriving ahead of their predecessor are parked and replayed later.
    """
    async with AsyncExitStack() as stack:
        try:
            await stack.enter_async_context(
                get_task_run_recorder_causal_ordering().preceding_event_confirmed(
                    record_task_run_event, event, depth=depth
                )
            )
        except EventArrivedEarly:
            # Safe to ACK this message: causal ordering has parked it and will
            # reprocess it when the preceding event arrives.
            return

        task_run = task_run_from_event(event)

        # Columns for the task run row itself; state columns are handled via
        # the denormalized fields below and a separate state-row insert.
        run_columns = task_run.model_dump_for_orm(
            exclude={
                "state_id",
                "state",
                "created",
                "estimated_run_time",
                "estimated_start_time_delta",
            },
            exclude_unset=True,
        )

        assert task_run.state

        state_columns = {
            "state_id": task_run.state.id,
            "state_type": task_run.state.type,
            "state_name": task_run.state.name,
            "state_timestamp": task_run.state.timestamp,
        }

        db = provide_database_interface()
        async with db.session_context() as session:
            # Assemble every column for one atomic upsert.
            insert_values = dict(run_columns)
            insert_values.update(state_columns)
            insert_values["created"] = prefect.types._datetime.now("UTC")

            # Single atomic INSERT ... ON CONFLICT DO UPDATE; the WHERE guard
            # prevents an older state from overwriting a newer one.
            await session.execute(
                db.queries.insert(db.TaskRun)
                .values(**insert_values)
                .on_conflict_do_update(
                    index_elements=["id"],
                    set_={
                        "updated": prefect.types._datetime.now("UTC"),
                        **run_columns,
                        **state_columns,
                    },
                    where=db.TaskRun.state_timestamp < task_run.state.timestamp,
                )
            )

            # The task run state row is still inserted separately.
            await _insert_task_run_state(session, task_run)

            await session.commit()

        logger.debug(
            "Recorded task run state change",
            extra={
                "task_run_id": task_run.id,
                "flow_run_id": task_run.flow_run_id,
                "event_id": event.id,
                "event_follows": event.follows,
                "event": event.event,
                "occurred": event.occurred,
                "current_state_type": task_run.state_type,
                "current_state_name": task_run.state_name,
            },
        )

172 

173 

async def record_lost_follower_task_run_events() -> None:
    """Replay follower events whose leader event never arrived."""
    ordering = get_task_run_recorder_causal_ordering()
    lost_followers = await ordering.get_lost_followers()

    for follower in lost_followers:
        # Temporarily skip events older than 24 hours: this avoids churning
        # through a large backlog potentially sitting in the waitlist from a
        # period when lost followers were not being processed.
        if follower.occurred < prefect.types._datetime.now("UTC") - timedelta(hours=24):
            await ordering.forget_follower(follower)
        else:
            await record_task_run_event(follower)

188 

189 

async def periodically_process_followers(periodic_granularity: timedelta) -> NoReturn:
    """Periodically process followers that are waiting on a leader event that never arrived"""

    interval_seconds = periodic_granularity.total_seconds()
    logger.info(
        "Starting periodically process followers task every %s seconds",
        interval_seconds,
    )
    while True:
        try:
            await record_lost_follower_task_run_events()
        except asyncio.CancelledError:
            logger.info("Periodically process followers task cancelled")
            return
        except Exception:
            # Best-effort loop: log and keep going on unexpected errors.
            logger.exception("Error running periodically process followers task")
        finally:
            # Sleep between passes whether the pass succeeded or failed.
            await asyncio.sleep(interval_seconds)

207 

208 

@asynccontextmanager
async def consumer() -> AsyncGenerator[MessageHandler, None]:
    """Yield a message handler for task run events.

    A background task processing lost followers runs for the lifetime of the
    context and is cancelled (and awaited) on exit.
    """
    lost_follower_task = asyncio.create_task(
        periodically_process_followers(periodic_granularity=timedelta(seconds=5))
    )

    async def message_handler(message: Message):
        event: ReceivedEvent = ReceivedEvent.model_validate_json(message.data)

        # Only record client-orchestrated task run events; everything else is
        # silently ignored (and thereby ACKed).
        if not event.event.startswith("prefect.task-run"):
            return
        if event.resource.get("prefect.orchestration") != "client":
            return

        logger.debug(
            "Received event: %s with id: %s for resource: %s",
            event.event,
            event.id,
            event.resource.get("prefect.resource.id"),
        )

        await record_task_run_event(event)

    try:
        yield message_handler
    finally:
        lost_follower_task.cancel()
        try:
            await lost_follower_task
        except asyncio.CancelledError:
            logger.info("Periodically process followers task cancelled successfully")

241 

242 

class TaskRunRecorder(RunInEphemeralServers, Service):
    """Constructs task runs and states from client-emitted events"""

    # Task running the message consumer loop; None while stopped.
    consumer_task: asyncio.Task[None] | None = None
    # Task periodically logging in-memory messaging metrics; None while stopped.
    metrics_task: asyncio.Task[None] | None = None

    @classmethod
    def service_settings(cls) -> ServicesBaseSetting:
        """Return the settings block governing this service."""
        return get_current_settings().server.services.task_run_recorder

    def __init__(self):
        super().__init__()
        self._started_event: Optional[asyncio.Event] = None

    @property
    def started_event(self) -> asyncio.Event:
        """Event set once the consumer is running.

        Created lazily so the Event is bound to the loop that first accesses it.
        """
        if self._started_event is None:
            self._started_event = asyncio.Event()
        return self._started_event

    @started_event.setter
    def started_event(self, value: asyncio.Event) -> None:
        self._started_event = value

    async def start(self) -> NoReturn:
        """Run the recorder until cancelled.

        Creates a consumer on the "events" topic in the "task-run-recorder"
        group, starts the handler context (which also launches the
        lost-follower processing task), and blocks on the consumer task.
        """
        assert self.consumer_task is None, "TaskRunRecorder already started"
        self.consumer: Consumer = create_consumer(
            "events",
            group="task-run-recorder",
            name=generate_unique_consumer_name("task-run-recorder"),
        )

        async with consumer() as handler:
            self.consumer_task = asyncio.create_task(self.consumer.run(handler))
            self.metrics_task = asyncio.create_task(log_metrics_periodically())

            logger.debug("TaskRunRecorder started")
            self.started_event.set()

            try:
                await self.consumer_task
            except asyncio.CancelledError:
                pass

    async def stop(self) -> None:
        """Cancel the consumer and metrics tasks and wait for them to finish.

        Always clears both task references, even if awaiting them raises.
        """
        # Assertion message previously said "Logger not started" (copy-paste
        # from another service); corrected to name this service.
        assert self.consumer_task is not None, "TaskRunRecorder not started"
        self.consumer_task.cancel()
        if self.metrics_task:
            self.metrics_task.cancel()
        try:
            await self.consumer_task
            if self.metrics_task:
                await self.metrics_task
        except asyncio.CancelledError:
            pass
        finally:
            self.consumer_task = None
            self.metrics_task = None
        logger.debug("TaskRunRecorder stopped")