Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/work_queues.py: 43%
186 statements
« prev ^ index » next — coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1"""
2Functions for interacting with work queue ORM objects.
3Intended for internal use by the Prefect REST API.
4"""
6import datetime 1b
7from typing import ( 1b
8 Awaitable,
9 Callable,
10 Iterable,
11 Optional,
12 Sequence,
13 Tuple,
14 Union,
15 cast,
16)
17from uuid import UUID 1b
19import sqlalchemy as sa 1b
20from docket import Depends, Retry 1b
21from pydantic import TypeAdapter 1b
22from sqlalchemy import delete, select 1b
23from sqlalchemy.ext.asyncio import AsyncSession 1b
25import prefect.server.models as models 1b
26import prefect.server.schemas as schemas 1b
27from prefect.server.database import ( 1b
28 PrefectDBInterface,
29 db_injector,
30 orm_models,
31 provide_database_interface,
32)
33from prefect.server.events.clients import PrefectServerEventsClient 1b
34from prefect.server.exceptions import ObjectNotFoundError 1b
35from prefect.server.models.events import work_queue_status_event 1b
36from prefect.server.models.workers import ( 1b
37 DEFAULT_AGENT_WORK_POOL_NAME,
38 bulk_update_work_queue_priorities,
39)
40from prefect.server.schemas.states import StateType 1b
41from prefect.server.schemas.statuses import WorkQueueStatus 1b
42from prefect.server.utilities.database import UUID as PrefectUUID 1b
43from prefect.types._datetime import DateTime, now 1b
# A work queue whose last poll happened within this window is considered
# actively polled (see `is_last_polled_recent`).
WORK_QUEUE_LAST_POLLED_TIMEOUT = datetime.timedelta(seconds=60)
@db_injector
async def create_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue: Union[schemas.core.WorkQueue, schemas.actions.WorkQueueCreate],
) -> orm_models.WorkQueue:
    """
    Inserts a WorkQueue.

    If a WorkQueue with the same name exists, an error will be thrown.

    If no work pool id is supplied, the queue is attached to the default agent
    work pool (created on demand if necessary). If no priority is supplied,
    the first unused priority rank in the pool is assigned.

    Args:
        session (AsyncSession): a database session
        work_queue (schemas.core.WorkQueue): a WorkQueue model

    Returns:
        orm_models.WorkQueue: the newly-created or updated WorkQueue
    """
    data = work_queue.model_dump()

    if data.get("work_pool_id") is None:
        # If no work pool is provided, get or create the default agent work pool
        default_agent_work_pool = await models.workers.read_work_pool_by_name(
            session=session, work_pool_name=DEFAULT_AGENT_WORK_POOL_NAME
        )
        if default_agent_work_pool:
            data["work_pool_id"] = default_agent_work_pool.id
        else:
            default_agent_work_pool = await models.workers.create_work_pool(
                session=session,
                work_pool=schemas.actions.WorkPoolCreate(
                    name=DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent"
                ),
            )
            if work_queue.name == "default":
                # If the desired work queue name is default, it was created when the
                # work pool was created. We can just return it.
                default_work_queue = await models.workers.read_work_queue(
                    session=session,
                    work_queue_id=default_agent_work_pool.default_queue_id,
                )
                assert default_work_queue
                return default_work_queue
            data["work_pool_id"] = default_agent_work_pool.id

    # Set the priority to be the max priority + 1
    # This will make the new queue the lowest priority
    if data["priority"] is None:
        # Set the priority to be the first priority value that isn't already taken
        priorities_query = sa.select(db.WorkQueue.priority).where(
            db.WorkQueue.work_pool_id == data["work_pool_id"]
        )
        priorities = (await session.execute(priorities_query)).scalars().all()

        priority = None
        for i, p in enumerate(sorted(priorities)):
            # if a rank was skipped (e.g. the set priority is different than the
            # enumerated priority) then we can "take" that spot for this work
            # queue
            if i + 1 != p:
                priority = i + 1
                break

        # otherwise take the maximum priority plus one
        if priority is None:
            priority = max(priorities, default=0) + 1

        data["priority"] = priority

    model = db.WorkQueue(**data)

    session.add(model)
    # flush + refresh so the generated id and server defaults are populated
    await session.flush()
    await session.refresh(model)

    # an explicitly-requested priority may displace existing queues, so
    # rebalance the pool's priorities around the new queue
    if work_queue.priority:
        await bulk_update_work_queue_priorities(
            session=session,
            work_pool_id=data["work_pool_id"],
            new_priorities={model.id: work_queue.priority},
        )
    return model
@db_injector
async def read_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: Union[UUID, PrefectUUID],
) -> Optional[orm_models.WorkQueue]:
    """
    Reads a WorkQueue by id.

    Args:
        session (AsyncSession): A database session
        work_queue_id (UUID): a WorkQueue id

    Returns:
        orm_models.WorkQueue: the WorkQueue, or None if no row matches the id
    """
    work_queue = await session.get(db.WorkQueue, work_queue_id)
    return work_queue
@db_injector
async def read_work_queue_by_name(
    db: PrefectDBInterface, session: AsyncSession, name: str
) -> Optional[orm_models.WorkQueue]:
    """
    Reads a WorkQueue by name.

    Args:
        session (AsyncSession): A database session
        name (str): a WorkQueue name

    Returns:
        orm_models.WorkQueue: the WorkQueue, or None if no queue matches
    """
    default_work_pool = await models.workers.read_work_pool_by_name(
        session=session, work_pool_name=DEFAULT_AGENT_WORK_POOL_NAME
    )

    # Logic to make sure this functionality doesn't break during migration:
    # scope the lookup to the default agent pool only once that pool exists.
    criteria = {"name": name}
    if default_work_pool is not None:
        criteria["work_pool_id"] = default_work_pool.id

    query = select(db.WorkQueue).filter_by(**criteria)
    result = await session.execute(query)
    return result.scalar()
@db_injector
async def read_work_queues(
    db: PrefectDBInterface,
    session: AsyncSession,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None,
) -> Sequence[orm_models.WorkQueue]:
    """
    Read WorkQueues, ordered by name.

    Args:
        session: A database session
        offset: Query offset
        limit: Query limit
        work_queue_filter: only select work queues matching these filters

    Returns:
        Sequence[orm_models.WorkQueue]: WorkQueues
    """
    statement = select(db.WorkQueue).order_by(db.WorkQueue.name)

    if offset is not None:
        statement = statement.offset(offset)
    if limit is not None:
        statement = statement.limit(limit)
    if work_queue_filter:
        statement = statement.where(work_queue_filter.as_sql_filter())

    rows = await session.execute(statement)
    return rows.scalars().unique().all()
def is_last_polled_recent(last_polled: Optional[DateTime]) -> bool:
    """Return True if `last_polled` falls within WORK_QUEUE_LAST_POLLED_TIMEOUT."""
    if last_polled is None:
        return False
    elapsed = now("UTC") - last_polled
    return elapsed <= WORK_QUEUE_LAST_POLLED_TIMEOUT
@db_injector
async def update_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: UUID,
    work_queue: schemas.actions.WorkQueueUpdate,
    emit_status_change: Optional[
        Callable[[orm_models.WorkQueue], Awaitable[None]]
    ] = None,
) -> bool:
    """
    Update a WorkQueue by id.

    When `is_paused` is part of the update, the queue's status is also
    transitioned (PAUSED / NOT_READY / READY) based on the new pause state
    and how recently the queue was polled.

    Args:
        session (AsyncSession): A database session
        work_queue: the work queue data
        work_queue_id (str): a WorkQueue id
        emit_status_change: optional callback invoked with the updated queue
            when its status changed

    Returns:
        bool: whether or not the WorkQueue was updated
    """
    # exclude_unset=True allows us to only update values provided by
    # the user, ignoring any defaults on the model
    update_data = work_queue.model_dump_for_orm(exclude_unset=True)

    if "is_paused" in update_data:
        wq = await read_work_queue(session=session, work_queue_id=work_queue_id)
        if wq is None:
            # nothing to update
            return False

        # Only update the status to paused if it's not already paused. This ensures a work queue that is already
        # paused will not get a status update if it's paused again
        if update_data.get("is_paused") and wq.status != WorkQueueStatus.PAUSED:
            update_data["status"] = WorkQueueStatus.PAUSED

        # If unpausing, only update status if it's currently paused. This ensures a work queue that is already
        # unpaused will not get a status update if it's unpaused again
        if (
            update_data.get("is_paused") is False
            and wq.status == WorkQueueStatus.PAUSED
        ):
            # Default status if unpaused
            update_data["status"] = WorkQueueStatus.NOT_READY

        # Determine source of last_polled: update_data or database
        last_polled: Optional[DateTime]
        if "last_polled" in update_data:
            last_polled = cast(DateTime, update_data["last_polled"])
        else:
            last_polled = wq.last_polled

        # Check if last polled is recent and set status to READY if so
        if is_last_polled_recent(last_polled):
            update_data["status"] = schemas.statuses.WorkQueueStatus.READY

    update_stmt = (
        sa.update(db.WorkQueue)
        .where(db.WorkQueue.id == work_queue_id)
        .values(**update_data)
    )
    result = await session.execute(update_stmt)
    updated = result.rowcount > 0

    if updated:
        # re-read the row so the callback sees the persisted status
        if "status" in update_data and emit_status_change:
            wq = await read_work_queue(session=session, work_queue_id=work_queue_id)
            assert wq
            await emit_status_change(wq)

    return updated
@db_injector
async def delete_work_queue(
    db: PrefectDBInterface, session: AsyncSession, work_queue_id: UUID
) -> bool:
    """
    Delete a WorkQueue by id.

    Args:
        session (AsyncSession): A database session
        work_queue_id (UUID): a WorkQueue id

    Returns:
        bool: whether or not the WorkQueue was deleted
    """
    delete_stmt = delete(db.WorkQueue).where(db.WorkQueue.id == work_queue_id)
    result = await session.execute(delete_stmt)
    return result.rowcount > 0
@db_injector
async def get_runs_in_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: UUID,
    limit: Optional[int] = None,
    scheduled_before: Optional[datetime.datetime] = None,
) -> Tuple[orm_models.WorkQueue, Sequence[orm_models.FlowRun]]:
    """
    Get runs from a work queue.

    Args:
        session: A database session.
        work_queue_id: The work queue id.
        scheduled_before: Only return runs scheduled to start before this time.
        limit: An optional limit for the number of runs to return from the
            queue. This limit applies to the request only. It does not affect
            the work queue's concurrency limit. If `limit` exceeds the work
            queue's concurrency limit, it will be ignored.

    Raises:
        ObjectNotFoundError: if no work queue exists with the given id.
    """
    work_queue = await read_work_queue(session=session, work_queue_id=work_queue_id)
    if work_queue is None:
        raise ObjectNotFoundError(f"Work queue with id {work_queue_id} not found.")

    # A non-null filter marks a deprecated tag-based work queue, which is
    # serviced by the legacy query path.
    if work_queue.filter is not None:
        legacy_runs = await _legacy_get_runs_in_work_queue(
            session=session,
            work_queue_id=work_queue_id,
            scheduled_before=scheduled_before,
            limit=limit,
        )
        return work_queue, legacy_runs

    query = db.queries.get_scheduled_flow_runs_from_work_queues(
        limit_per_queue=limit,
        work_queue_ids=[work_queue_id],
        scheduled_before=scheduled_before,
    )
    result = await session.execute(query)
    return work_queue, result.scalars().unique().all()
async def _legacy_get_runs_in_work_queue(
    session: AsyncSession,
    work_queue_id: UUID,
    scheduled_before: Optional[datetime.datetime] = None,
    limit: Optional[int] = None,
) -> Sequence[orm_models.FlowRun]:
    """
    DEPRECATED method for getting runs from a tag-based work queue

    Args:
        session: A database session.
        work_queue_id: The work queue id.
        scheduled_before: Only return runs scheduled to start before this time.
        limit: An optional limit for the number of runs to return from the queue.
            This limit applies to the request only. It does not affect the
            work queue's concurrency limit. If `limit` exceeds the work queue's
            concurrency limit, it will be ignored.

    Returns:
        Scheduled flow runs matching the queue's tag/deployment filter, capped
        by the queue's open concurrency slots (if a concurrency limit is set).

    Raises:
        ObjectNotFoundError: if no work queue exists with the given id.
    """
    work_queue = await read_work_queue(session=session, work_queue_id=work_queue_id)
    if not work_queue:
        raise ObjectNotFoundError(f"Work queue with id {work_queue_id} not found.")

    # a paused queue hands out no work
    if work_queue.is_paused:
        return []

    # ensure the filter object is fully hydrated
    # SQLAlchemy caching logic can result in a dict type instead
    # of the full pydantic model
    work_queue_filter = TypeAdapter(schemas.core.QueueFilter).validate_python(
        work_queue.filter
    )
    flow_run_filter = dict(
        tags=dict(all_=work_queue_filter.tags),
        deployment_id=dict(any_=work_queue_filter.deployment_ids, is_null_=False),
    )

    # if the work queue has a concurrency limit, check how many runs are currently
    # executing and compare that count to the concurrency limit
    if work_queue.concurrency_limit is not None:
        # Note this does not guarantee race conditions won't be hit
        running_frs = await models.flow_runs.count_flow_runs(
            session=session,
            flow_run_filter=schemas.filters.FlowRunFilter(
                **flow_run_filter,
                state=dict(type=dict(any_=[StateType.PENDING, StateType.RUNNING])),
            ),
        )

        # compute the available concurrency slots
        open_concurrency_slots = max(0, work_queue.concurrency_limit - running_frs)

        # if a limit override was given, ensure we return no more
        # than that limit
        if limit is not None:
            limit = min(open_concurrency_slots, limit)
        else:
            limit = open_concurrency_slots

    return await models.flow_runs.read_flow_runs(
        session=session,
        flow_run_filter=schemas.filters.FlowRunFilter(
            **flow_run_filter,
            state=dict(type=dict(any_=[StateType.SCHEDULED])),
            next_scheduled_start_time=dict(before_=scheduled_before),
        ),
        limit=limit,
        sort=schemas.sorting.FlowRunSort.NEXT_SCHEDULED_START_TIME_ASC,
    )
async def ensure_work_queue_exists(
    session: AsyncSession, name: str
) -> orm_models.WorkQueue:
    """
    Checks if a work queue exists and creates it if it does not.

    Useful when working with deployments, agents, and flow runs that automatically create work queues.

    Will also create a work pool queue in the default agent pool to facilitate migration to work pools.
    """
    existing = await models.work_queues.read_work_queue_by_name(
        session=session, name=name
    )
    if existing:
        return existing

    default_pool = await models.workers.read_work_pool_by_name(
        session=session, work_pool_name=DEFAULT_AGENT_WORK_POOL_NAME
    )

    # No default agent pool yet: create a standalone queue.
    if default_pool is None:
        return await models.work_queues.create_work_queue(
            session=session,
            work_queue=schemas.actions.WorkQueueCreate(name=name, priority=1),
        )

    # Create the named queue inside the default agent pool.
    if name != "default":
        return await models.workers.create_work_queue(
            session=session,
            work_pool_id=default_pool.id,
            work_queue=schemas.actions.WorkQueueCreate(name=name, priority=1),
        )

    # The "default" queue was created with the pool itself; just look it up.
    work_queue = await models.work_queues.read_work_queue(
        session=session, work_queue_id=default_pool.default_queue_id
    )
    assert work_queue, "Default work queue not found"
    return work_queue
async def read_work_queue_status(
    session: AsyncSession, work_queue_id: UUID
) -> schemas.core.WorkQueueStatusDetail:
    """
    Get work queue status by id.

    Args:
        session (AsyncSession): A database session
        work_queue_id (UUID): a WorkQueue id

    Returns:
        Information about the status of the work queue.

    Raises:
        ObjectNotFoundError: if no work queue exists with the given id.
    """
    work_queue = await read_work_queue(session=session, work_queue_id=work_queue_id)
    if work_queue is None:
        raise ObjectNotFoundError(f"Work queue with id {work_queue_id} not found")

    # Count this queue's flow runs currently in the "Late" state.
    late_run_count = await models.flow_runs.count_flow_runs(
        session=session,
        flow_run_filter=schemas.filters.FlowRunFilter(
            state=schemas.filters.FlowRunFilterState(name={"any_": ["Late"]}),
        ),
        work_queue_filter=schemas.filters.WorkQueueFilter(
            id=schemas.filters.WorkQueueFilterId(any_=[work_queue_id])
        ),
    )

    # All work queues use the default policy for now
    policy = schemas.core.WorkQueueHealthPolicy(
        maximum_late_runs=0, maximum_seconds_since_last_polled=60
    )
    is_healthy = policy.evaluate_health_status(
        late_runs_count=late_run_count,
        last_polled=work_queue.last_polled,  # type: ignore
    )

    return schemas.core.WorkQueueStatusDetail(
        healthy=is_healthy,
        late_runs_count=late_run_count,
        last_polled=work_queue.last_polled,
        health_check_policy=policy,
    )
@db_injector
async def record_work_queue_polls(
    db: PrefectDBInterface,
    session: AsyncSession,
    polled_work_queue_ids: Sequence[UUID],
    ready_work_queue_ids: Sequence[UUID],
) -> None:
    """Record that the given work queues were polled, and also update the given
    ready_work_queue_ids to READY."""
    polled_at = now("UTC")

    # Stamp last_polled on every polled queue.
    if polled_work_queue_ids:
        polled_stmt = (
            sa.update(db.WorkQueue)
            .where(db.WorkQueue.id.in_(polled_work_queue_ids))
            .values(last_polled=polled_at)
        )
        await session.execute(polled_stmt)

    # Queues that became ready also get their status flipped to READY.
    if ready_work_queue_ids:
        ready_stmt = (
            sa.update(db.WorkQueue)
            .where(db.WorkQueue.id.in_(ready_work_queue_ids))
            .values(last_polled=polled_at, status=WorkQueueStatus.READY)
        )
        await session.execute(ready_stmt)
async def mark_work_queues_ready(
    *,
    db: PrefectDBInterface = Depends(provide_database_interface),
    polled_work_queue_ids: Sequence[UUID],
    ready_work_queue_ids: Sequence[UUID],
    retry: Retry = Retry(attempts=5, delay=datetime.timedelta(seconds=0.5)),
) -> None:
    """
    Record polls for the given work queues, marking the ready ones READY, then
    emit status events for the queues that transitioned to ready.
    """
    # First transaction: persist poll timestamps and READY statuses.
    async with db.session_context(begin_transaction=True) as session:
        await record_work_queue_polls(
            session=session,
            polled_work_queue_ids=polled_work_queue_ids,
            ready_work_queue_ids=ready_work_queue_ids,
        )

    # No queues became ready, so there are no events to emit.
    if not ready_work_queue_ids:
        return

    # Emit events for any work queues that have transitioned to ready during
    # this poll. Uses a separate transaction to avoid keeping locks open longer
    # from the updates in the previous transaction.
    events = []
    async with db.session_context(begin_transaction=True) as session:
        ready_rows = await session.execute(
            sa.select(db.WorkQueue).where(db.WorkQueue.id.in_(ready_work_queue_ids))
        )
        for work_queue in ready_rows.scalars().all():
            events.append(
                await work_queue_status_event(
                    session=session,
                    work_queue=work_queue,
                    occurred=now("UTC"),
                )
            )

    async with PrefectServerEventsClient() as events_client:
        for event in events:
            await events_client.emit(event)
@db_injector
async def mark_work_queues_not_ready(
    db: PrefectDBInterface,
    work_queue_ids: Iterable[UUID],
) -> None:
    """
    Set the given work queues to NOT_READY and emit a status event for each.
    """
    if not work_queue_ids:
        return

    # First transaction: flip the statuses.
    async with db.session_context(begin_transaction=True) as session:
        await session.execute(
            sa.update(db.WorkQueue)
            .where(db.WorkQueue.id.in_(work_queue_ids))
            .values(status=WorkQueueStatus.NOT_READY)
        )

    # Emit events for the affected work queues. Uses a separate transaction to
    # avoid keeping locks open longer from the updates in the previous
    # transaction.
    events = []
    async with db.session_context(begin_transaction=True) as session:
        unready_rows = await session.execute(
            sa.select(db.WorkQueue).where(db.WorkQueue.id.in_(work_queue_ids))
        )
        for work_queue in unready_rows.scalars().all():
            events.append(
                await work_queue_status_event(
                    session=session,
                    work_queue=work_queue,
                    occurred=now("UTC"),
                )
            )

    async with PrefectServerEventsClient() as events_client:
        for event in events:
            await events_client.emit(event)
@db_injector
async def emit_work_queue_status_event(
    db: PrefectDBInterface,
    work_queue: orm_models.WorkQueue,
) -> None:
    """Build and emit a status event for a single work queue."""
    async with db.session_context() as session:
        status_event = await work_queue_status_event(
            session=session,
            work_queue=work_queue,
            occurred=now("UTC"),
        )

    async with PrefectServerEventsClient() as events_client:
        await events_client.emit(status_event)