Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/task_runs.py: 69%
148 statements
coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1"""
2Functions for interacting with task run ORM objects.
3Intended for internal use by the Prefect REST API.
4"""
6import contextlib 1d
7from typing import ( 1d
8 TYPE_CHECKING,
9 Any,
10 Dict,
11 Optional,
12 Sequence,
13 Type,
14 TypeVar,
15 Union,
16 cast,
17)
18from uuid import UUID 1d
20import sqlalchemy as sa 1d
21from sqlalchemy import delete, select 1d
22from sqlalchemy.ext.asyncio import AsyncSession 1d
23from sqlalchemy.sql import Select 1d
25import prefect.server.models as models 1d
26import prefect.server.schemas as schemas 1d
27from prefect.logging import get_logger 1d
28from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1d
29from prefect.server.exceptions import ObjectNotFoundError 1d
30from prefect.server.orchestration.core_policy import ( 1d
31 BackgroundTaskPolicy,
32 MinimalTaskPolicy,
33)
34from prefect.server.orchestration.global_policy import GlobalTaskPolicy 1d
35from prefect.server.orchestration.policies import ( 1d
36 TaskRunOrchestrationPolicy,
37)
38from prefect.server.orchestration.rules import TaskOrchestrationContext 1d
39from prefect.server.schemas.responses import OrchestrationResult 1d
40from prefect.types._datetime import now 1d
42if TYPE_CHECKING: 42 ↛ 43line 42 didn't jump to line 43 because the condition on line 42 was never true1d
43 import logging
45T = TypeVar("T", bound=tuple[Any, ...]) 1d
47logger: "logging.Logger" = get_logger("server") 1d


@db_injector
async def create_task_run(
    db: PrefectDBInterface,
    session: AsyncSession,
    task_run: schemas.core.TaskRun,
    orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> orm_models.TaskRun:
    """
    Creates a new task run.

    If a task run with the same flow_run_id, task_key, and dynamic_key already exists,
    the existing task run will be returned. If the provided task run has a state
    attached, it will also be created.

    Args:
        session: a database session
        task_run: a task run model

    Returns:
        orm_models.TaskRun: the newly-created or existing task run
    """

    right_now = now("UTC")
    model: Union[orm_models.TaskRun, None]

    task_run.labels = await with_system_labels_for_task_run(
        session=session, task_run=task_run
    )

    # if the task run belongs to a flow run, guard against conflicts on
    # (flow_run_id, task_key, dynamic_key) with an insert that does nothing on conflict
    if task_run.flow_run_id:
        insert_stmt = (
            db.queries.insert(db.TaskRun)
            .values(
                created=right_now,
                **task_run.model_dump_for_orm(
                    exclude={"state", "created"}, exclude_unset=True
                ),
            )
            .on_conflict_do_nothing(
                index_elements=db.orm.task_run_unique_upsert_columns,
            )
        )
        await session.execute(insert_stmt)

        query = (
            sa.select(db.TaskRun)
            .where(
                sa.and_(
                    db.TaskRun.flow_run_id == task_run.flow_run_id,
                    db.TaskRun.task_key == task_run.task_key,
                    db.TaskRun.dynamic_key == task_run.dynamic_key,
                )
            )
            .limit(1)
            .execution_options(populate_existing=True)
        )
        result = await session.execute(query)
        model = result.scalar_one()
    else:
        # No flow run: emulate an upsert on (task_key, dynamic_key) in application logic.
        query = (
            sa.select(db.TaskRun)
            .where(
                sa.and_(
                    db.TaskRun.flow_run_id.is_(None),
                    db.TaskRun.task_key == task_run.task_key,
                    db.TaskRun.dynamic_key == task_run.dynamic_key,
                )
            )
            .limit(1)
            .execution_options(populate_existing=True)
        )

        result = await session.execute(query)
        model = result.scalar()

        if model is None:
            model = db.TaskRun(
                created=right_now,
                **task_run.model_dump_for_orm(
                    exclude={"state", "created"}, exclude_unset=True
                ),
                state=None,
            )
            session.add(model)
            await session.flush()

    if model.created == right_now and task_run.state:
        await models.task_runs.set_task_run_state(
            session=session,
            task_run_id=model.id,
            state=task_run.state,
            force=True,
            orchestration_parameters=orchestration_parameters,
        )

    return model
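

# Illustrative usage sketch (not part of the original module): one plausible way to
# call `create_task_run` from other server code, assuming an already-configured
# AsyncSession. The task key, dynamic key, and initial state below are hypothetical;
# `db` is supplied by the @db_injector decorator, so callers omit it.
async def _example_create_task_run(session: AsyncSession) -> orm_models.TaskRun:
    task_run = schemas.core.TaskRun(
        flow_run_id=None,  # an autonomous task run; pass a real flow run id otherwise
        task_key="my_task",  # hypothetical task key
        dynamic_key="0",  # hypothetical dynamic key
        state=schemas.states.Pending(),
    )
    return await create_task_run(session=session, task_run=task_run)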


@db_injector
async def update_task_run(
    db: PrefectDBInterface,
    session: AsyncSession,
    task_run_id: UUID,
    task_run: schemas.actions.TaskRunUpdate,
) -> bool:
    """
    Updates a task run.

    Args:
        session: a database session
        task_run_id: the task run id to update
        task_run: a task run model

    Returns:
        bool: whether or not matching rows were found to update
    """
    update_stmt = (
        sa.update(db.TaskRun)
        .where(db.TaskRun.id == task_run_id)
        # exclude_unset=True allows us to only update values provided by
        # the user, ignoring any defaults on the model
        .values(**task_run.model_dump_for_orm(exclude_unset=True))
    )
    result = await session.execute(update_stmt)
    return result.rowcount > 0
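

# Illustrative usage sketch (not part of the original module): renaming an existing
# task run. The id is supplied by the caller and the new name is hypothetical; only
# fields explicitly set on the update model are written, thanks to exclude_unset=True.
async def _example_update_task_run(session: AsyncSession, task_run_id: UUID) -> bool:
    update = schemas.actions.TaskRunUpdate(name="renamed-task-run")  # hypothetical name
    return await update_task_run(
        session=session, task_run_id=task_run_id, task_run=update
    )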


@db_injector
async def read_task_run(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> Union[orm_models.TaskRun, None]:
    """
    Read a task run by id.

    Args:
        session: a database session
        task_run_id: the task run id

    Returns:
        orm_models.TaskRun: the task run
    """

    model = await session.get(db.TaskRun, task_run_id)
    return model


@db_injector
async def read_task_run_with_flow_run_name(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> Union[orm_models.TaskRun, None]:
    """
    Read a task run by id, including the name of its parent flow run.

    Args:
        session: a database session
        task_run_id: the task run id

    Returns:
        orm_models.TaskRun: the task run, with a `flow_run_name` attribute attached
            when a parent flow run exists
    """

    result = await session.execute(
        select(orm_models.TaskRun, orm_models.FlowRun.name.label("flow_run_name"))
        .outerjoin(
            orm_models.FlowRun, orm_models.TaskRun.flow_run_id == orm_models.FlowRun.id
        )
        .where(orm_models.TaskRun.id == task_run_id)
    )
    row = result.first()
    if not row:
        return None

    task_run = row[0]
    flow_run_name = row[1]
    if flow_run_name:
        setattr(task_run, "flow_run_name", flow_run_name)
    return task_run
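

# Illustrative usage sketch (not part of the original module): reading a task run
# together with its flow run name. The attribute is only set when a parent flow run
# exists, so it is read defensively here.
async def _example_read_with_flow_run_name(
    session: AsyncSession, task_run_id: UUID
) -> Optional[str]:
    task_run = await read_task_run_with_flow_run_name(
        session=session, task_run_id=task_run_id
    )
    if task_run is None:
        raise ObjectNotFoundError(f"Task run with id {task_run_id} not found")
    return getattr(task_run, "flow_run_name", None)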


async def _apply_task_run_filters(
    db: PrefectDBInterface,
    query: Select[T],
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
    work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None,
    work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None,
) -> Select[T]:
    """
    Applies filters to a task run query as a combination of EXISTS subqueries.
    """

    if task_run_filter:
        query = query.where(task_run_filter.as_sql_filter())

    # Return a simplified query when the request ONLY filters on flow_run_id (plus an
    # optional task_run_filter). In that case there is no need to generate the complex
    # EXISTS subqueries; the query generated here is much more efficient.
    if (
        flow_run_filter
        and flow_run_filter.only_filters_on_id()
        and flow_run_filter.id
        and flow_run_filter.id.any_
        and not any(
            [flow_filter, deployment_filter, work_pool_filter, work_queue_filter]
        )
    ):
        query = query.where(db.TaskRun.flow_run_id.in_(flow_run_filter.id.any_))
        return query

    if (
        flow_filter
        or flow_run_filter
        or deployment_filter
        or work_pool_filter
        or work_queue_filter
    ):
        exists_clause = select(db.FlowRun).where(
            db.FlowRun.id == db.TaskRun.flow_run_id
        )

        if flow_run_filter:
            exists_clause = exists_clause.where(flow_run_filter.as_sql_filter())

        if flow_filter:
            exists_clause = exists_clause.join(
                db.Flow,
                db.Flow.id == db.FlowRun.flow_id,
            ).where(flow_filter.as_sql_filter())

        if deployment_filter:
            exists_clause = exists_clause.join(
                db.Deployment,
                db.Deployment.id == db.FlowRun.deployment_id,
            ).where(deployment_filter.as_sql_filter())

        if work_queue_filter:
            exists_clause = exists_clause.join(
                db.WorkQueue,
                db.WorkQueue.id == db.FlowRun.work_queue_id,
            ).where(work_queue_filter.as_sql_filter())

        if work_pool_filter:
            exists_clause = exists_clause.join(
                db.WorkPool,
                sa.and_(
                    db.WorkPool.id == db.WorkQueue.work_pool_id,
                    db.WorkQueue.id == db.FlowRun.work_queue_id,
                ),
            ).where(work_pool_filter.as_sql_filter())

        query = query.where(exists_clause.exists())

    return query
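

# Illustrative usage sketch (not part of the original module): how the read/count
# functions below compose this helper. The flow run id is hypothetical; because the
# filter only constrains flow_run_id, the simplified IN-clause branch above applies
# rather than the EXISTS subquery.
async def _example_filtered_query(
    db: PrefectDBInterface,
) -> Select[tuple[orm_models.TaskRun]]:
    query = select(db.TaskRun)
    flow_run_filter = schemas.filters.FlowRunFilter(
        id=schemas.filters.FlowRunFilterId(any_=[UUID(int=0)])  # hypothetical id
    )
    return await _apply_task_run_filters(db, query, flow_run_filter=flow_run_filter)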


@db_injector
async def read_task_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    sort: schemas.sorting.TaskRunSort = schemas.sorting.TaskRunSort.ID_DESC,
) -> Sequence[orm_models.TaskRun]:
    """
    Read task runs.

    Args:
        session: a database session
        flow_filter: only select task runs whose flows match these filters
        flow_run_filter: only select task runs whose flow runs match these filters
        task_run_filter: only select task runs that match these filters
        deployment_filter: only select task runs whose deployments match these filters
        offset: Query offset
        limit: Query limit
        sort: Query sort

    Returns:
        List[orm_models.TaskRun]: the task runs
    """

    query = select(db.TaskRun).order_by(*sort.as_sql_sort())

    query = await _apply_task_run_filters(
        db,
        query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    if offset is not None:
        query = query.offset(offset)

    if limit is not None:
        query = query.limit(limit)

    logger.debug(f"In read_task_runs, query generated is:\n{query}")
    result = await session.execute(query)
    return result.scalars().unique().all()
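

# Illustrative usage sketch (not part of the original module): paging through task
# runs that match a name filter, using the default ID_DESC sort. The task run name
# and page size are hypothetical.
async def _example_read_task_runs(
    session: AsyncSession,
) -> Sequence[orm_models.TaskRun]:
    task_run_filter = schemas.filters.TaskRunFilter(
        name=schemas.filters.TaskRunFilterName(any_=["my-task-0"])  # hypothetical name
    )
    return await read_task_runs(
        session=session,
        task_run_filter=task_run_filter,
        offset=0,
        limit=50,
        sort=schemas.sorting.TaskRunSort.ID_DESC,
    )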


@db_injector
async def count_task_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> int:
    """
    Count task runs.

    Args:
        session: a database session
        flow_filter: only count task runs whose flows match these filters
        flow_run_filter: only count task runs whose flow runs match these filters
        task_run_filter: only count task runs that match these filters
        deployment_filter: only count task runs whose deployments match these filters

    Returns:
        int: count of task runs
    """

    query = select(sa.func.count(None)).select_from(db.TaskRun)

    query = await _apply_task_run_filters(
        db,
        query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    result = await session.execute(query)
    return result.scalar_one()


@db_injector
async def count_task_runs_by_state(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> schemas.states.CountByState:
    """
    Count task runs by state.

    Args:
        session: a database session
        flow_filter: only count task runs whose flows match these filters
        flow_run_filter: only count task runs whose flow runs match these filters
        task_run_filter: only count task runs that match these filters
        deployment_filter: only count task runs whose deployments match these filters

    Returns:
        schemas.states.CountByState: count of task runs by state
    """

    base_query = (
        select(db.TaskRun.state_type, sa.func.count(None).label("count"))
        .select_from(db.TaskRun)
        .group_by(db.TaskRun.state_type)
    )

    query = await _apply_task_run_filters(
        db,
        base_query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    result = await session.execute(query)

    counts = schemas.states.CountByState()

    for row in result:
        setattr(counts, row.state_type, row.count)

    return counts
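

# Illustrative usage sketch (not part of the original module): counting task runs per
# state for a single parent flow run. The flow run id is supplied by the caller.
async def _example_count_by_state(
    session: AsyncSession, flow_run_id: UUID
) -> schemas.states.CountByState:
    flow_run_filter = schemas.filters.FlowRunFilter(
        id=schemas.filters.FlowRunFilterId(any_=[flow_run_id])
    )
    return await count_task_runs_by_state(
        session=session, flow_run_filter=flow_run_filter
    )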


@db_injector
async def delete_task_run(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> bool:
    """
    Delete a task run by id.

    Args:
        session: a database session
        task_run_id: the task run id to delete

    Returns:
        bool: whether or not the task run was deleted
    """

    result = await session.execute(
        delete(db.TaskRun).where(db.TaskRun.id == task_run_id)
    )
    return result.rowcount > 0


async def set_task_run_state(
    session: AsyncSession,
    task_run_id: UUID,
    state: schemas.states.State,
    force: bool = False,
    task_policy: Optional[Type[TaskRunOrchestrationPolicy]] = None,
    orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> OrchestrationResult:
    """
    Creates a new orchestrated task run state.

    Setting a new state on a run is one of the principal actions governed by
    Prefect's orchestration logic. Setting a new run state will not guarantee
    creation, but instead trigger orchestration rules to govern the proposed `state`
    input. If the state is considered valid, it will be written to the database.
    Otherwise, it's possible a different state, or no state, will be created. A
    `force` flag is supplied to bypass a subset of orchestration logic.

    Args:
        session: a database session
        task_run_id: the task run id
        state: a task run state model
        force: if False, orchestration rules will be applied that may alter or prevent
            the state transition. If True, orchestration rules are not applied.

    Returns:
        OrchestrationResult object
    """

    # load the task run
    run = await models.task_runs.read_task_run(session=session, task_run_id=task_run_id)

    if not run:
        raise ObjectNotFoundError(f"Task run with id {task_run_id} not found")

    initial_state = run.state.as_state() if run.state else None
    initial_state_type = initial_state.type if initial_state else None
    proposed_state_type = state.type if state else None
    intended_transition = (initial_state_type, proposed_state_type)

    if state.state_details.deferred:
        task_policy = BackgroundTaskPolicy  # CoreTaskPolicy + prevent `Running` -> `Running` transition
    elif force or task_policy is None:
        task_policy = MinimalTaskPolicy

    orchestration_rules = task_policy.compile_transition_rules(*intended_transition)  # type: ignore
    global_rules = GlobalTaskPolicy.compile_transition_rules(*intended_transition)

    context = TaskOrchestrationContext(
        session=session,
        run=run,
        initial_state=initial_state,
        proposed_state=state,
    )

    if orchestration_parameters is not None:
        context.parameters = orchestration_parameters

    # apply orchestration rules and create the new task run state
    async with contextlib.AsyncExitStack() as stack:
        for rule in orchestration_rules:
            context = await stack.enter_async_context(
                rule(context, *intended_transition)
            )

        for rule in global_rules:
            context = await stack.enter_async_context(
                rule(context, *intended_transition)
            )

        await context.validate_proposed_state()

    if context.orchestration_error is not None:
        raise context.orchestration_error

    result = OrchestrationResult(
        state=context.validated_state,
        status=context.response_status,
        details=context.response_details,
    )

    return result
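

# Illustrative usage sketch (not part of the original module): proposing a terminal
# state and inspecting the orchestration outcome. The task run id is supplied by the
# caller; with force=False, the compiled orchestration rules may alter or reject the
# proposed transition.
async def _example_set_state(session: AsyncSession, task_run_id: UUID) -> None:
    result = await set_task_run_state(
        session=session,
        task_run_id=task_run_id,
        state=schemas.states.Completed(),
        force=False,
    )
    if result.status != schemas.responses.SetStateStatus.ACCEPT:
        logger.warning("Transition was not accepted: %s", result.details)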


async def with_system_labels_for_task_run(
    session: AsyncSession,
    task_run: schemas.core.TaskRun,
) -> schemas.core.KeyValueLabels:
    """Augment user supplied labels with system default labels for a task
    run."""

    client_supplied_labels = task_run.labels or {}
    default_labels = cast(schemas.core.KeyValueLabels, {})
    parent_labels: schemas.core.KeyValueLabels = {}

    if task_run.flow_run_id:
        default_labels["prefect.flow-run.id"] = str(task_run.flow_run_id)
        flow_run = await models.flow_runs.read_flow_run(
            session=session, flow_run_id=task_run.flow_run_id
        )
        parent_labels = flow_run.labels if flow_run and flow_run.labels else {}

    return parent_labels | default_labels | client_supplied_labels
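

# Illustrative sketch (not part of the original module): the merge above gives
# client-supplied labels the highest precedence, followed by the computed defaults,
# followed by labels inherited from the parent flow run. All values below are
# hypothetical.
def _example_label_precedence() -> schemas.core.KeyValueLabels:
    parent_labels: schemas.core.KeyValueLabels = {"team": "data", "env": "dev"}
    default_labels: schemas.core.KeyValueLabels = {"prefect.flow-run.id": "<flow-run-id>"}
    client_supplied_labels: schemas.core.KeyValueLabels = {"env": "prod"}
    # Later operands of `|` win, so "env" resolves to "prod" here.
    return parent_labels | default_labels | client_supplied_labels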