Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/task_runs.py: 29%

139 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1""" 

2Routes for interacting with task run objects. 

3""" 

4 

5import asyncio 1a

6import datetime 1a

7from typing import TYPE_CHECKING, Any, Dict, List, Optional 1a

8from uuid import UUID 1a

9 

10from docket import Depends as DocketDepends 1a

11from docket import Retry 1a

12from fastapi import ( 1a

13 Body, 

14 Depends, 

15 HTTPException, 

16 Path, 

17 Response, 

18 WebSocket, 

19) 

20from fastapi.responses import ORJSONResponse 1a

21from starlette.websockets import WebSocketDisconnect 1a

22 

23import prefect.server.api.dependencies as dependencies 1a

24import prefect.server.models as models 1a

25import prefect.server.schemas as schemas 1a

26from prefect._internal.compatibility.starlette import status 1a

27from prefect.logging import get_logger 1a

28from prefect.server.api.run_history import run_history 1a

29from prefect.server.database import PrefectDBInterface, provide_database_interface 1a

30from prefect.server.orchestration import dependencies as orchestration_dependencies 1a

31from prefect.server.orchestration.core_policy import CoreTaskPolicy 1a

32from prefect.server.orchestration.policies import TaskRunOrchestrationPolicy 1a

33from prefect.server.schemas.responses import ( 1a

34 OrchestrationResult, 

35 TaskRunPaginationResponse, 

36) 

37from prefect.server.task_queue import MultiQueue, TaskQueue 1a

38from prefect.server.utilities import subscriptions 1a

39from prefect.server.utilities.server import PrefectRouter 1a

40from prefect.types import DateTime 1a

41from prefect.types._datetime import now 1a

42 

if TYPE_CHECKING:
    import logging

# Module logger, registered under the "server.api" name.
logger: "logging.Logger" = get_logger("server.api")

# Every route declared in this module is mounted under ``/task_runs`` and
# grouped under the "Task Runs" tag.
router: PrefectRouter = PrefectRouter(tags=["Task Runs"], prefix="/task_runs")

49 

50 

@router.post("/")
async def create_task_run(
    task_run: schemas.actions.TaskRunCreate,
    response: Response,
    db: PrefectDBInterface = Depends(provide_database_interface),
    orchestration_parameters: Dict[str, Any] = Depends(
        orchestration_dependencies.provide_task_orchestration_parameters
    ),
) -> schemas.core.TaskRun:
    """
    Create a task run. If a task run with the same flow_run_id,
    task_key, and dynamic_key already exists, the existing task
    run will be returned.

    If no state is provided, the task run will be created in a PENDING state.

    For more information, see https://docs.prefect.io/v3/concepts/tasks.
    """
    # Turn the action payload into a full core TaskRun. A falsy id is removed
    # entirely (presumably so the core model's own default id is used).
    payload = task_run.model_dump()
    if not payload.get("id"):
        payload.pop("id", None)
    hydrated_run = schemas.core.TaskRun(**payload)

    if not hydrated_run.state:
        hydrated_run.state = schemas.states.Pending()

    # Captured before the insert: a row whose ``created`` is at or after this
    # instant was written by this request, which is what drives the 201 below.
    request_started_at = now("UTC")

    async with db.session_context(begin_transaction=True) as session:
        model = await models.task_runs.create_task_run(
            session=session,
            task_run=hydrated_run,
            orchestration_parameters=orchestration_parameters,
        )

    if model.created >= request_started_at:
        response.status_code = status.HTTP_201_CREATED

    return schemas.core.TaskRun.model_validate(model)

93 

94 

@router.patch("/{id:uuid}", status_code=status.HTTP_204_NO_CONTENT)
async def update_task_run(
    task_run: schemas.actions.TaskRunUpdate,
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """
    Updates a task run.

    Responds 404 when no task run with the given id exists.
    """
    async with db.session_context(begin_transaction=True) as session:
        was_updated = await models.task_runs.update_task_run(
            session=session, task_run=task_run, task_run_id=task_run_id
        )
    if was_updated:
        return None
    raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task run not found")

110 

111 

@router.post("/count")
async def count_task_runs(
    db: PrefectDBInterface = Depends(provide_database_interface),
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
) -> int:
    """
    Count task runs.

    Each filter is optional; omitted filters are passed through as ``None``
    to ``models.task_runs.count_task_runs``.
    """
    # The filters are annotated Optional explicitly (they default to None),
    # matching read_task_runs / paginate_task_runs instead of relying on the
    # deprecated implicit-Optional convention.
    async with db.session_context() as session:
        return await models.task_runs.count_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
        )

131 

132 

@router.post("/history")
async def task_run_history(
    history_start: DateTime = Body(..., description="The history's start time."),
    history_end: DateTime = Body(..., description="The history's end time."),
    # Workaround for the fact that FastAPI does not let us configure ser_json_timedelta
    # to represent timedeltas as floats in JSON.
    history_interval: float = Body(
        ...,
        description=(
            "The size of each history interval, in seconds. Must be at least 1 second."
        ),
        json_schema_extra={"format": "time-delta"},
        alias="history_interval_seconds",
    ),
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.responses.HistoryResponse]:
    """
    Query for task run history data across a given range and interval.

    The interval arrives as a number of seconds (JSON field
    ``history_interval_seconds``) and is converted to a ``datetime.timedelta``
    before being passed to ``run_history``. Responds 422 when the interval is
    shorter than 1 second or is not a finite number of seconds that a
    ``timedelta`` can represent.
    """
    # Accept an already-built timedelta as-is; convert any other number
    # (float *or* int) so the comparison below never mixes a number with a
    # timedelta.
    if isinstance(history_interval, datetime.timedelta):
        interval = history_interval
    else:
        try:
            interval = datetime.timedelta(seconds=history_interval)
        except (OverflowError, ValueError) as exc:
            # inf / NaN / values beyond timedelta's range come from the client;
            # report them as a validation error instead of a 500.
            raise HTTPException(
                status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="History interval must be a finite number of seconds.",
            ) from exc

    if interval < datetime.timedelta(seconds=1):
        raise HTTPException(
            status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="History interval must not be less than 1 second.",
        )

    async with db.session_context() as session:
        return await run_history(
            session=session,
            run_type="task_run",
            history_start=history_start,
            history_end=history_end,
            history_interval=interval,
            flows=flows,
            flow_runs=flow_runs,
            task_runs=task_runs,
            deployments=deployments,
        )

177 

178 

@router.get("/{id:uuid}")
async def read_task_run(
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.core.TaskRun:
    """
    Get a task run by id.

    Responds 404 when no task run with the given id exists.
    """
    async with db.session_context() as session:
        found = await models.task_runs.read_task_run(
            session=session, task_run_id=task_run_id
        )
        if found:
            return found
        raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task not found")

194 

195 

@router.post("/filter")
async def read_task_runs(
    sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC),
    limit: int = dependencies.LimitBody(),
    offset: int = Body(0, ge=0),
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.core.TaskRun]:
    """
    Query for task runs.

    Results are ordered by ``sort`` and windowed by ``offset`` / ``limit``.
    """
    async with db.session_context() as session:
        matching_runs = await models.task_runs.read_task_runs(
            session=session,
            sort=sort,
            limit=limit,
            offset=offset,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
        )
        return matching_runs

221 

222 

@router.post("/paginate", response_class=ORJSONResponse)
async def paginate_task_runs(
    sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC),
    limit: int = dependencies.LimitBody(),
    page: int = Body(1, ge=1),
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> TaskRunPaginationResponse:
    """
    Pagination query for task runs.

    ``page`` is 1-based. The response carries the requested page of results,
    the total number of matching runs, and the number of pages of size
    ``limit`` needed to hold them (0 when ``limit`` is 0).
    """
    offset = (page - 1) * limit

    async with db.session_context() as session:
        runs = await models.task_runs.read_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
            offset=offset,
            limit=limit,
            sort=sort,
        )

        total_count = await models.task_runs.count_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
        )

        # Ceiling division. NOTE(review): a limit of 0 appears reachable if
        # dependencies.LimitBody permits it; guard so it yields 0 pages rather
        # than a ZeroDivisionError (HTTP 500).
        pages = (total_count + limit - 1) // limit if limit > 0 else 0

        return TaskRunPaginationResponse.model_validate(
            dict(
                results=runs,
                count=total_count,
                limit=limit,
                pages=pages,
                page=page,
            )
        )

268 

269 

@router.delete("/{id:uuid}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_task_run(
    docket: dependencies.Docket,
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """
    Delete a task run by id.

    Responds 404 when no task run with the given id exists. On success, removal
    of the run's logs is handed to docket via ``delete_task_run_logs`` instead
    of being done inline.
    """
    async with db.session_context(begin_transaction=True) as session:
        was_deleted = await models.task_runs.delete_task_run(
            session=session, task_run_id=task_run_id
        )
    if not was_deleted:
        raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task not found")

    schedule_log_cleanup = docket.add(delete_task_run_logs)
    await schedule_log_cleanup(task_run_id=task_run_id)

286 

287 

async def delete_task_run_logs(
    *,
    db: PrefectDBInterface = DocketDepends(provide_database_interface),
    task_run_id: UUID,
    retry: Retry = Retry(attempts=5, delay=datetime.timedelta(seconds=0.5)),
) -> None:
    """
    Delete every log whose task run id is ``task_run_id``.

    Scheduled through docket by ``delete_task_run``. ``retry`` is not read in
    the body; presumably docket uses it to retry the job (5 attempts, 0.5 s
    apart).
    """
    logs_of_run = schemas.filters.LogFilter(
        task_run_id=schemas.filters.LogFilterTaskRunId(any_=[task_run_id])
    )
    async with db.session_context(begin_transaction=True) as session:
        await models.logs.delete_logs(session=session, log_filter=logs_of_run)

301 

302 

@router.post("/{id:uuid}/set_state")
async def set_task_run_state(
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    state: schemas.actions.StateCreate = Body(..., description="The intended state."),
    force: bool = Body(
        False,
        description=(
            "If false, orchestration rules will be applied that may alter or prevent"
            " the state transition. If True, orchestration rules are not applied."
        ),
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
    response: Response = None,
    task_policy: TaskRunOrchestrationPolicy = Depends(
        orchestration_dependencies.provide_task_policy
    ),
    orchestration_parameters: Dict[str, Any] = Depends(
        orchestration_dependencies.provide_task_orchestration_parameters
    ),
) -> OrchestrationResult:
    """Set a task run state, invoking any orchestration rules."""

    # Timestamp taken before orchestration; a resulting state stamped at or
    # after it was created by this request (answered with 201, otherwise 200).
    requested_at = now("UTC")

    async with db.session_context(
        begin_transaction=True, with_for_update=True
    ) as session:
        # Promote the action payload to a full State object.
        full_state = schemas.states.State.model_validate(state)
        result = await models.task_runs.set_task_run_state(
            session=session,
            task_run_id=task_run_id,
            state=full_state,
            force=force,
            # NOTE(review): the injected ``task_policy`` dependency is not used;
            # CoreTaskPolicy is always passed. Confirm this is intentional.
            task_policy=CoreTaskPolicy,
            orchestration_parameters=orchestration_parameters,
        )

    resulting_state = result.state
    newly_created = bool(resulting_state) and resulting_state.timestamp >= requested_at
    response.status_code = (
        status.HTTP_201_CREATED if newly_created else status.HTTP_200_OK
    )

    return result

349 

350 

@router.websocket("/subscriptions/scheduled")
async def scheduled_task_subscription(websocket: WebSocket) -> None:
    """
    Stream scheduled task runs to a subscribed task worker over a websocket.

    Protocol, as implemented below:

    1. The client sends ``{"type": "subscribe", "keys": [...], "client_id": ...}``.
       A malformed subscribe message closes the socket with code 4001.
    2. For each task run taken from the queues for ``keys``, the server sends
       it as JSON and reads one reply: ``{"type": "ack"}`` continues,
       ``{"type": "quit"}`` closes the socket, anything else is raised as a
       ``WebSocketDisconnect`` protocol violation.
    3. If delivery hits one of ``subscriptions.NORMAL_DISCONNECT_EXCEPTIONS``,
       the task run is passed to ``TaskQueue.for_key(...).retry`` and the
       handler returns.

    The worker is observed via ``models.task_workers`` while the loop runs and
    forgotten exactly once when the handler leaves the loop, on any path.
    """
    websocket = await subscriptions.accept_prefect_socket(websocket)
    if not websocket:
        return

    try:
        subscription = await websocket.receive_json()
    except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
        return

    # The payload is client-supplied; reject anything that is not a JSON object
    # instead of failing with AttributeError on ``.get``.
    if not isinstance(subscription, dict) or subscription.get("type") != "subscribe":
        return await websocket.close(
            code=4001, reason="Protocol violation: expected 'subscribe' message"
        )

    task_keys = subscription.get("keys", [])
    # Require a non-empty list; a bare string would otherwise be treated as a
    # sequence of one-character keys.
    if not task_keys or not isinstance(task_keys, list):
        return await websocket.close(
            code=4001, reason="Protocol violation: expected 'keys' in subscribe message"
        )

    if not (client_id := subscription.get("client_id")):
        return await websocket.close(
            code=4001,
            reason="Protocol violation: expected 'client_id' in subscribe message",
        )

    subscribed_queue = MultiQueue(task_keys)

    logger.info(f"Task worker {client_id!r} subscribed to task keys {task_keys!r}")

    try:
        while True:
            try:
                # observe here so that all workers with active websockets are tracked
                await models.task_workers.observe_worker(task_keys, client_id)
                task_run = await asyncio.wait_for(subscribed_queue.get(), timeout=1)
            except asyncio.TimeoutError:
                if not await subscriptions.still_connected(websocket):
                    return
                continue

            try:
                await websocket.send_json(task_run.model_dump(mode="json"))

                acknowledgement = await websocket.receive_json()
                # A non-object reply falls through to the protocol-violation path.
                ack_type = (
                    acknowledgement.get("type")
                    if isinstance(acknowledgement, dict)
                    else None
                )
                if ack_type != "ack":
                    if ack_type == "quit":
                        return await websocket.close()

                    raise WebSocketDisconnect(
                        code=4001, reason="Protocol violation: expected 'ack' message"
                    )

                await models.task_workers.observe_worker([task_run.task_key], client_id)

            except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
                # If sending fails or pong fails, put the task back into the retry queue
                await asyncio.shield(TaskQueue.for_key(task_run.task_key).retry(task_run))
                return
    finally:
        # Forget the worker once, when the subscription ends for any reason.
        # Doing this per delivered task would immediately undo the
        # observe_worker call made right after each ack.
        await models.task_workers.forget_worker(client_id)
405 ) 

406 

407 await models.task_workers.observe_worker([task_run.task_key], client_id) 

408 

409 except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS: 

410 # If sending fails or pong fails, put the task back into the retry queue 

411 await asyncio.shield(TaskQueue.for_key(task_run.task_key).retry(task_run)) 

412 return 

413 finally: 

414 await models.task_workers.forget_worker(client_id)