Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/task_runs.py: 29% of 139 statements (coverage.py v7.10.6, created at 2025-12-05 13:38 +0000)
1"""
2Routes for interacting with task run objects.
3"""
5import asyncio 1a
6import datetime 1a
7from typing import TYPE_CHECKING, Any, Dict, List, Optional 1a
8from uuid import UUID 1a
10from docket import Depends as DocketDepends 1a
11from docket import Retry 1a
12from fastapi import ( 1a
13 Body,
14 Depends,
15 HTTPException,
16 Path,
17 Response,
18 WebSocket,
19)
20from fastapi.responses import ORJSONResponse 1a
21from starlette.websockets import WebSocketDisconnect 1a
23import prefect.server.api.dependencies as dependencies 1a
24import prefect.server.models as models 1a
25import prefect.server.schemas as schemas 1a
26from prefect._internal.compatibility.starlette import status 1a
27from prefect.logging import get_logger 1a
28from prefect.server.api.run_history import run_history 1a
29from prefect.server.database import PrefectDBInterface, provide_database_interface 1a
30from prefect.server.orchestration import dependencies as orchestration_dependencies 1a
31from prefect.server.orchestration.core_policy import CoreTaskPolicy 1a
32from prefect.server.orchestration.policies import TaskRunOrchestrationPolicy 1a
33from prefect.server.schemas.responses import ( 1a
34 OrchestrationResult,
35 TaskRunPaginationResponse,
36)
37from prefect.server.task_queue import MultiQueue, TaskQueue 1a
38from prefect.server.utilities import subscriptions 1a
39from prefect.server.utilities.server import PrefectRouter 1a
40from prefect.types import DateTime 1a
41from prefect.types._datetime import now 1a
43if TYPE_CHECKING: 43 ↛ 44line 43 didn't jump to line 44 because the condition on line 43 was never true1a
44 import logging
46logger: "logging.Logger" = get_logger("server.api") 1a
48router: PrefectRouter = PrefectRouter(prefix="/task_runs", tags=["Task Runs"]) 1a
51@router.post("/") 1a
52async def create_task_run( 1a
53 task_run: schemas.actions.TaskRunCreate,
54 response: Response,
55 db: PrefectDBInterface = Depends(provide_database_interface),
56 orchestration_parameters: Dict[str, Any] = Depends(
57 orchestration_dependencies.provide_task_orchestration_parameters
58 ),
59) -> schemas.core.TaskRun:
60 """
61 Create a task run. If a task run with the same flow_run_id,
62 task_key, and dynamic_key already exists, the existing task
63 run will be returned.
65 If no state is provided, the task run will be created in a PENDING state.
67 For more information, see https://docs.prefect.io/v3/concepts/tasks.
68 """
69 # hydrate the input model into a full task run / state model
70 task_run_dict = task_run.model_dump()
71 if not task_run_dict.get("id"):
72 task_run_dict.pop("id", None)
73 task_run = schemas.core.TaskRun(**task_run_dict)
75 if not task_run.state:
76 task_run.state = schemas.states.Pending()
78 right_now = now("UTC")
80 async with db.session_context(begin_transaction=True) as session:
81 model = await models.task_runs.create_task_run(
82 session=session,
83 task_run=task_run,
84 orchestration_parameters=orchestration_parameters,
85 )
87 if model.created >= right_now:
88 response.status_code = status.HTTP_201_CREATED
90 new_task_run: schemas.core.TaskRun = schemas.core.TaskRun.model_validate(model)
92 return new_task_run
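

# A minimal client-side sketch of calling this endpoint, assuming a local API served
# at http://127.0.0.1:4200/api and illustrative payload values; the field names follow
# schemas.actions.TaskRunCreate and the helper below is hypothetical. A 201 response
# indicates a newly created run; a 200 response returns the existing run with the same
# flow_run_id, task_key, and dynamic_key.
async def _example_create_task_run() -> dict:
    """Create (or idempotently fetch) a task run via POST /task_runs/."""
    import httpx  # local import: illustration only

    async with httpx.AsyncClient(base_url="http://127.0.0.1:4200/api") as client:
        response = await client.post(
            "/task_runs/",
            json={
                "task_key": "my_module.my_task",  # identifies the task definition
                "dynamic_key": "0",  # disambiguates repeated runs of the same task
            },
        )
        response.raise_for_status()
        return response.json()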


@router.patch("/{id:uuid}", status_code=status.HTTP_204_NO_CONTENT)
async def update_task_run(
    task_run: schemas.actions.TaskRunUpdate,
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """
    Updates a task run.
    """
    async with db.session_context(begin_transaction=True) as session:
        result = await models.task_runs.update_task_run(
            session=session, task_run=task_run, task_run_id=task_run_id
        )
    if not result:
        raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task run not found")


@router.post("/count")
async def count_task_runs(
    db: PrefectDBInterface = Depends(provide_database_interface),
    flows: schemas.filters.FlowFilter = None,
    flow_runs: schemas.filters.FlowRunFilter = None,
    task_runs: schemas.filters.TaskRunFilter = None,
    deployments: schemas.filters.DeploymentFilter = None,
) -> int:
    """
    Count task runs.
    """
    async with db.session_context() as session:
        return await models.task_runs.count_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
        )


@router.post("/history")
async def task_run_history(
    history_start: DateTime = Body(..., description="The history's start time."),
    history_end: DateTime = Body(..., description="The history's end time."),
    # Workaround for the fact that FastAPI does not let us configure ser_json_timedelta
    # to represent timedeltas as floats in JSON.
    history_interval: float = Body(
        ...,
        description=(
            "The size of each history interval, in seconds. Must be at least 1 second."
        ),
        json_schema_extra={"format": "time-delta"},
        alias="history_interval_seconds",
    ),
    flows: schemas.filters.FlowFilter = None,
    flow_runs: schemas.filters.FlowRunFilter = None,
    task_runs: schemas.filters.TaskRunFilter = None,
    deployments: schemas.filters.DeploymentFilter = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.responses.HistoryResponse]:
    """
    Query for task run history data across a given range and interval.
    """
    if isinstance(history_interval, float):
        history_interval = datetime.timedelta(seconds=history_interval)

    if history_interval < datetime.timedelta(seconds=1):
        raise HTTPException(
            status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="History interval must not be less than 1 second.",
        )

    async with db.session_context() as session:
        return await run_history(
            session=session,
            run_type="task_run",
            history_start=history_start,
            history_end=history_end,
            history_interval=history_interval,
            flows=flows,
            flow_runs=flow_runs,
            task_runs=task_runs,
            deployments=deployments,
        )
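

# A minimal request-body sketch for this endpoint; the timestamps and interval are
# illustrative and the helper name is hypothetical. The interval is sent as
# "history_interval_seconds" (a float of at least 1.0) and converted to a timedelta
# server-side; smaller values are rejected with a 422.
def _example_task_run_history_body() -> dict:
    """Build an illustrative POST /task_runs/history request body."""
    return {
        "history_start": "2025-01-01T00:00:00Z",
        "history_end": "2025-01-01T01:00:00Z",
        "history_interval_seconds": 300.0,  # one bucket per five minutes
    }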


@router.get("/{id:uuid}")
async def read_task_run(
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.core.TaskRun:
    """
    Get a task run by id.
    """
    async with db.session_context() as session:
        task_run = await models.task_runs.read_task_run(
            session=session, task_run_id=task_run_id
        )
        if not task_run:
            raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task not found")
        return task_run


@router.post("/filter")
async def read_task_runs(
    sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC),
    limit: int = dependencies.LimitBody(),
    offset: int = Body(0, ge=0),
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.core.TaskRun]:
    """
    Query for task runs.
    """
    async with db.session_context() as session:
        return await models.task_runs.read_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
            offset=offset,
            limit=limit,
            sort=sort,
        )


@router.post("/paginate", response_class=ORJSONResponse)
async def paginate_task_runs(
    sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC),
    limit: int = dependencies.LimitBody(),
    page: int = Body(1, ge=1),
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> TaskRunPaginationResponse:
    """
    Pagination query for task runs.
    """
    offset = (page - 1) * limit

    async with db.session_context() as session:
        runs = await models.task_runs.read_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
            offset=offset,
            limit=limit,
            sort=sort,
        )

        total_count = await models.task_runs.count_task_runs(
            session=session,
            flow_filter=flows,
            flow_run_filter=flow_runs,
            task_run_filter=task_runs,
            deployment_filter=deployments,
        )

    return TaskRunPaginationResponse.model_validate(
        dict(
            results=runs,
            count=total_count,
            limit=limit,
            pages=(total_count + limit - 1) // limit,
            page=page,
        )
    )
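

# The page count above is ceiling division done with integers: with total_count=101
# and limit=25, (101 + 25 - 1) // 25 == 5, i.e. four full pages plus one partial page.
# A small self-check (hypothetical helper, not part of the API):
def _example_page_count(total_count: int = 101, limit: int = 25) -> int:
    """Illustrate the ceiling-division page count used by paginate_task_runs."""
    pages = (total_count + limit - 1) // limit
    assert pages == 5  # 4 pages x 25 runs, plus 1 run on the final page
    return pages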


@router.delete("/{id:uuid}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_task_run(
    docket: dependencies.Docket,
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """
    Delete a task run by id.
    """
    async with db.session_context(begin_transaction=True) as session:
        result = await models.task_runs.delete_task_run(
            session=session, task_run_id=task_run_id
        )
    if not result:
        raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task not found")
    await docket.add(delete_task_run_logs)(task_run_id=task_run_id)


async def delete_task_run_logs(
    *,
    db: PrefectDBInterface = DocketDepends(provide_database_interface),
    task_run_id: UUID,
    retry: Retry = Retry(attempts=5, delay=datetime.timedelta(seconds=0.5)),
) -> None:
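    """
    Delete all logs associated with a task run.

    Scheduled on the docket by `delete_task_run` above; retried with up to 5 attempts.
    """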
    async with db.session_context(begin_transaction=True) as session:
        await models.logs.delete_logs(
            session=session,
            log_filter=schemas.filters.LogFilter(
                task_run_id=schemas.filters.LogFilterTaskRunId(any_=[task_run_id])
            ),
        )


@router.post("/{id:uuid}/set_state")
async def set_task_run_state(
    task_run_id: UUID = Path(..., description="The task run id", alias="id"),
    state: schemas.actions.StateCreate = Body(..., description="The intended state."),
    force: bool = Body(
        False,
        description=(
            "If false, orchestration rules will be applied that may alter or prevent"
            " the state transition. If True, orchestration rules are not applied."
        ),
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
    response: Response = None,
    task_policy: TaskRunOrchestrationPolicy = Depends(
        orchestration_dependencies.provide_task_policy
    ),
    orchestration_parameters: Dict[str, Any] = Depends(
        orchestration_dependencies.provide_task_orchestration_parameters
    ),
) -> OrchestrationResult:
    """Set a task run state, invoking any orchestration rules."""

    right_now = now("UTC")

    # create the state
    async with db.session_context(
        begin_transaction=True, with_for_update=True
    ) as session:
        orchestration_result = await models.task_runs.set_task_run_state(
            session=session,
            task_run_id=task_run_id,
            state=schemas.states.State.model_validate(
                state
            ),  # convert to a full State object
            force=force,
            task_policy=CoreTaskPolicy,
            orchestration_parameters=orchestration_parameters,
        )

    # set the 201 if a new state was created
    if orchestration_result.state and orchestration_result.state.timestamp >= right_now:
        response.status_code = status.HTTP_201_CREATED
    else:
        response.status_code = status.HTTP_200_OK

    return orchestration_result
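

# A minimal client-side sketch, assuming a local API at http://127.0.0.1:4200/api and
# an existing task run id; the request body follows schemas.actions.StateCreate, and
# the helper below is hypothetical. With force=False the CoreTaskPolicy may alter or
# reject the proposed transition; a 201 means a new state was recorded, a 200 means
# it was not.
async def _example_set_task_run_state(task_run_id: UUID) -> dict:
    """Propose a COMPLETED state for a task run and return the orchestration result."""
    import httpx  # local import: illustration only

    async with httpx.AsyncClient(base_url="http://127.0.0.1:4200/api") as client:
        response = await client.post(
            f"/task_runs/{task_run_id}/set_state",
            json={"state": {"type": "COMPLETED"}, "force": False},
        )
        response.raise_for_status()
        return response.json()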


@router.websocket("/subscriptions/scheduled")
async def scheduled_task_subscription(websocket: WebSocket) -> None:
    websocket = await subscriptions.accept_prefect_socket(websocket)
    if not websocket:
        return

    try:
        subscription = await websocket.receive_json()
    except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
        return

    if subscription.get("type") != "subscribe":
        return await websocket.close(
            code=4001, reason="Protocol violation: expected 'subscribe' message"
        )

    task_keys = subscription.get("keys", [])
    if not task_keys:
        return await websocket.close(
            code=4001, reason="Protocol violation: expected 'keys' in subscribe message"
        )

    if not (client_id := subscription.get("client_id")):
        return await websocket.close(
            code=4001,
            reason="Protocol violation: expected 'client_id' in subscribe message",
        )

    subscribed_queue = MultiQueue(task_keys)

    logger.info(f"Task worker {client_id!r} subscribed to task keys {task_keys!r}")

    while True:
        try:
            # observe here so that all workers with active websockets are tracked
            await models.task_workers.observe_worker(task_keys, client_id)
            task_run = await asyncio.wait_for(subscribed_queue.get(), timeout=1)
        except asyncio.TimeoutError:
            if not await subscriptions.still_connected(websocket):
                await models.task_workers.forget_worker(client_id)
                return
            continue

        try:
            await websocket.send_json(task_run.model_dump(mode="json"))

            acknowledgement = await websocket.receive_json()
            ack_type = acknowledgement.get("type")
            if ack_type != "ack":
                if ack_type == "quit":
                    return await websocket.close()

                raise WebSocketDisconnect(
                    code=4001, reason="Protocol violation: expected 'ack' message"
                )

            await models.task_workers.observe_worker([task_run.task_key], client_id)

        except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
            # If sending fails or pong fails, put the task back into the retry queue
            await asyncio.shield(TaskQueue.for_key(task_run.task_key).retry(task_run))
            return
        finally:
            await models.task_workers.forget_worker(client_id)
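

# A minimal sketch of the client side of this subscription protocol, assuming the
# third-party `websockets` package and an unsecured local server; the real Prefect
# task worker also negotiates the websocket handshake performed by
# subscriptions.accept_prefect_socket, which is omitted here, and the helper name is
# hypothetical. The client sends a "subscribe" message carrying its task keys and
# client id, then must answer each delivered task run with an "ack" (or "quit" to
# close cleanly); anything else is treated as a protocol violation.
async def _example_scheduled_task_client(task_keys: List[str], client_id: str) -> None:
    """Subscribe to scheduled task runs and print each one as it arrives."""
    import json  # local imports: illustration only
    import websockets

    url = "ws://127.0.0.1:4200/api/task_runs/subscriptions/scheduled"
    async with websockets.connect(url) as ws:
        await ws.send(
            json.dumps({"type": "subscribe", "keys": task_keys, "client_id": client_id})
        )
        while True:
            task_run = json.loads(await ws.recv())
            print(task_run["id"], task_run["task_key"])
            # failing to acknowledge causes the server to disconnect this client
            await ws.send(json.dumps({"type": "ack"}))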