Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/workers.py: 17%
228 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1"""
2Functions for interacting with worker ORM objects.
3Intended for internal use by the Prefect REST API.
4"""
6import datetime 1a
7from typing import ( 1a
8 Awaitable,
9 Callable,
10 Dict,
11 List,
12 Optional,
13 Sequence,
14 Union,
15)
16from uuid import UUID 1a
18import sqlalchemy as sa 1a
19from sqlalchemy import delete, select 1a
20from sqlalchemy.ext.asyncio import AsyncSession 1a
22import prefect.server.schemas as schemas 1a
23from prefect._internal.uuid7 import uuid7 1a
24from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1a
25from prefect.server.events.clients import PrefectServerEventsClient 1a
26from prefect.server.exceptions import ObjectNotFoundError 1a
27from prefect.server.models.events import work_pool_status_event 1a
28from prefect.server.schemas.statuses import WorkQueueStatus 1a
29from prefect.server.utilities.database import UUID as PrefectUUID 1a
30from prefect.types._datetime import DateTime, now 1a
# Name reserved for the default agent work pool. It is not referenced within
# this module; presumably consumed by callers elsewhere (verify before renaming).
DEFAULT_AGENT_WORK_POOL_NAME: str = "default-agent-pool"
34# -----------------------------------------------------
35# --
36# --
37# -- Work Pools
38# --
39# --
40# -----------------------------------------------------
@db_injector
async def create_work_pool(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool: Union[schemas.core.WorkPool, schemas.actions.WorkPoolCreate],
) -> orm_models.WorkPool:
    """
    Create a work pool together with its "default" work queue.

    If a WorkPool with the same name exists, an error will be raised.

    Pools whose type is not "prefect-agent" get an initial status derived from
    `is_paused`: PAUSED when paused, NOT_READY otherwise. "prefect-agent" pools
    keep whatever status the ORM object was constructed with. Once the pool
    row is flushed, a queue named "default" is created in it and its id is
    stored as the pool's `default_queue_id`.

    Args:
        session (AsyncSession): a database session
        work_pool: the WorkPool / WorkPoolCreate model to persist

    Returns:
        orm_models.WorkPool: the newly-created WorkPool
    """
    new_pool = db.WorkPool(**work_pool.model_dump())

    if new_pool.type != "prefect-agent":
        new_pool.status = (
            schemas.statuses.WorkPoolStatus.PAUSED
            if new_pool.is_paused
            else schemas.statuses.WorkPoolStatus.NOT_READY
        )

    session.add(new_pool)
    # Flush so the pool row exists before its default queue is created.
    await session.flush()

    queue_spec = schemas.actions.WorkQueueCreate(
        name="default", description="The work pool's default queue."
    )
    default_queue = await create_work_queue(
        session=session, work_pool_id=new_pool.id, work_queue=queue_spec
    )

    new_pool.default_queue_id = default_queue.id  # type: ignore
    await session.flush()

    return new_pool
@db_injector
async def read_work_pool(
    db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID
) -> Optional[orm_models.WorkPool]:
    """
    Look up a single work pool by its id.

    Args:
        session (AsyncSession): A database session
        work_pool_id (UUID): id of the work pool to fetch

    Returns:
        Optional[orm_models.WorkPool]: the matching WorkPool, or None when no
        pool has that id
    """
    matches_id = db.WorkPool.id == work_pool_id
    rows = await session.execute(sa.select(db.WorkPool).where(matches_id).limit(1))
    return rows.scalar()
@db_injector
async def read_work_pool_by_name(
    db: PrefectDBInterface, session: AsyncSession, work_pool_name: str
) -> Optional[orm_models.WorkPool]:
    """
    Look up a single work pool by its name.

    Args:
        session (AsyncSession): A database session
        work_pool_name (str): name of the work pool to fetch

    Returns:
        Optional[orm_models.WorkPool]: the matching WorkPool, or None when no
        pool has that name
    """
    matches_name = db.WorkPool.name == work_pool_name
    rows = await session.execute(
        sa.select(db.WorkPool).where(matches_name).limit(1)
    )
    return rows.scalar()
@db_injector
async def read_work_pools(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
) -> Sequence[orm_models.WorkPool]:
    """
    List work pools ordered by name.

    Args:
        session: A database session
        work_pool_filter: optional filter criteria applied to the query
        offset: number of rows to skip (omitted when None)
        limit: maximum number of rows to return (omitted when None)

    Returns:
        Sequence[orm_models.WorkPool]: the matching work pools
    """
    stmt = select(db.WorkPool).order_by(db.WorkPool.name)

    if work_pool_filter is not None:
        stmt = stmt.where(work_pool_filter.as_sql_filter())

    # Only apply the pagination bounds the caller actually supplied.
    stmt = stmt if offset is None else stmt.offset(offset)
    stmt = stmt if limit is None else stmt.limit(limit)

    rows = await session.execute(stmt)
    return rows.scalars().unique().all()
@db_injector
async def count_work_pools(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None,
) -> int:
    """
    Count work pools, optionally restricted by a filter.

    Args:
        session: A database session
        work_pool_filter: filter criteria to apply to the count

    Returns:
        int: the count of work pools matching the criteria
    """
    stmt = select(sa.func.count()).select_from(db.WorkPool)
    if work_pool_filter is not None:
        stmt = stmt.where(work_pool_filter.as_sql_filter())

    return (await session.execute(stmt)).scalar_one()
@db_injector
async def update_work_pool(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    work_pool: schemas.actions.WorkPoolUpdate,
    emit_status_change: Optional[
        Callable[
            [UUID, DateTime, orm_models.WorkPool, orm_models.WorkPool],
            Awaitable[None],
        ]
    ] = None,
) -> bool:
    """
    Update a WorkPool by id.

    Only fields explicitly set on `work_pool` are written. For pools whose type
    is not "prefect-agent", the status is derived from `is_paused`:
    - pausing sets PAUSED;
    - unpausing sets READY if the pool has any ONLINE worker, otherwise
      NOT_READY.

    Whenever the update carries a status, a fresh `last_status_event_id` and
    `last_transitioned_status_at` are recorded as well.

    Args:
        session (AsyncSession): A database session
        work_pool_id (UUID): a WorkPool id
        work_pool (schemas.actions.WorkPoolUpdate): the fields to update
        emit_status_change: optional coroutine function awaited (with keyword
            arguments event_id, occurred, pre_update_work_pool, work_pool)
            after a successful update that includes a status

    Returns:
        bool: whether or not the work pool was updated

    Raises:
        ObjectNotFoundError: if no work pool with `work_pool_id` exists
    """
    # exclude_unset=True keeps only values the caller provided, ignoring any
    # defaults declared on the model.
    changes = work_pool.model_dump_for_orm(exclude_unset=True)

    before = await read_work_pool(session=session, work_pool_id=work_pool_id)
    if not before:
        raise ObjectNotFoundError

    # Detach the loaded instance so it retains a snapshot of the pre-update
    # state to compare against when emitting the status event.
    session.expunge(before)

    if before.type != "prefect-agent":
        pausing = changes.get("is_paused")
        if pausing:
            changes["status"] = schemas.statuses.WorkPoolStatus.PAUSED
        elif pausing is False:
            # Unpausing: READY if any worker is currently online, else NOT_READY.
            online_workers = await read_workers(
                session=session,
                work_pool_id=work_pool_id,
                worker_filter=schemas.filters.WorkerFilter(
                    status=schemas.filters.WorkerFilterStatus(
                        any_=[schemas.statuses.WorkerStatus.ONLINE]
                    )
                ),
            )
            changes["status"] = (
                schemas.statuses.WorkPoolStatus.READY
                if online_workers
                else schemas.statuses.WorkPoolStatus.NOT_READY
            )

    status_changed = "status" in changes
    if status_changed:
        changes["last_status_event_id"] = uuid7()
        changes["last_transitioned_status_at"] = now("UTC")

    result = await session.execute(
        sa.update(db.WorkPool)
        .where(db.WorkPool.id == work_pool_id)
        .values(**changes)
    )
    if result.rowcount < 1:
        return False

    after = await read_work_pool(session=session, work_pool_id=work_pool_id)
    assert after is not None
    assert before is not after

    if status_changed and emit_status_change:
        await emit_status_change(
            event_id=changes["last_status_event_id"],  # type: ignore
            occurred=changes["last_transitioned_status_at"],
            pre_update_work_pool=before,
            work_pool=after,
        )

    return True
@db_injector
async def delete_work_pool(
    db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID
) -> bool:
    """
    Delete the work pool with the given id.

    Args:
        session (AsyncSession): A database session
        work_pool_id (UUID): id of the work pool to delete

    Returns:
        bool: True when at least one row was deleted, False otherwise
    """
    stmt = delete(db.WorkPool).where(db.WorkPool.id == work_pool_id)
    outcome = await session.execute(stmt)
    deleted = outcome.rowcount > 0
    return deleted
@db_injector
async def get_scheduled_flow_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_ids: Optional[List[UUID]] = None,
    work_queue_ids: Optional[List[UUID]] = None,
    scheduled_before: Optional[datetime.datetime] = None,
    scheduled_after: Optional[datetime.datetime] = None,
    limit: Optional[int] = None,
    respect_queue_priorities: Optional[bool] = None,
) -> Sequence[schemas.responses.WorkerFlowRunResponse]:
    """
    Fetch scheduled flow runs from work pool queues.

    This is a thin wrapper around
    `db.queries.get_scheduled_flow_runs_from_work_pool`; it only resolves the
    default for `respect_queue_priorities` (None means True).

    Args:
        session (AsyncSession): a database session
        work_pool_ids (List[UUID]): a list of work pool ids
        work_queue_ids (List[UUID]): a list of work pool queue ids
        scheduled_before (datetime.datetime): only runs scheduled before this
        scheduled_after (datetime.datetime): only runs scheduled after this
        limit (int): the maximum number of runs to return
        respect_queue_priorities (bool): whether to respect queue priorities;
            defaults to True when None
        db: a database interface

    Returns:
        Sequence[WorkerFlowRunResponse]: the runs, with related work pool details
    """
    honor_priorities = (
        True if respect_queue_priorities is None else respect_queue_priorities
    )

    return await db.queries.get_scheduled_flow_runs_from_work_pool(
        session=session,
        work_pool_ids=work_pool_ids,
        work_queue_ids=work_queue_ids,
        scheduled_before=scheduled_before,
        scheduled_after=scheduled_after,
        respect_queue_priorities=honor_priorities,
        limit=limit,
    )
335# -----------------------------------------------------
336# --
337# --
338# -- Work Pool Queues
339# --
340# --
341# -----------------------------------------------------
def _first_unused_priority(taken_priorities: Sequence[int]) -> int:
    """
    Return the smallest positive integer not present in `taken_priorities`.

    For example, ``[1, 2, 4]`` gives ``3`` and ``[1, 2, 3]`` gives ``4``.
    Membership is checked against a set, so the result is never an already
    taken value even if the input has duplicates or non-positive entries.
    """
    taken = set(taken_priorities)
    candidate = 1
    while candidate in taken:
        candidate += 1
    return candidate


@db_injector
async def create_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    work_queue: schemas.actions.WorkQueueCreate,
) -> orm_models.WorkQueue:
    """
    Create a work queue inside a work pool.

    When `work_queue.priority` is None, the queue takes the lowest positive
    priority not already used in the pool. When a priority is given, the
    queue is inserted with it and the pool's priorities are then rebalanced
    via `bulk_update_work_queue_priorities` so they stay unique.

    Args:
        session (AsyncSession): a database session
        work_pool_id (UUID): a work pool id
        work_queue (schemas.actions.WorkQueueCreate): a WorkQueue action model

    Returns:
        orm_models.WorkQueue: the newly-created WorkQueue
    """
    data = work_queue.model_dump(exclude={"work_pool_id"})

    if work_queue.priority is None:
        priorities_query = sa.select(db.WorkQueue.priority).where(
            db.WorkQueue.work_pool_id == work_pool_id
        )
        priorities = (await session.execute(priorities_query)).scalars().all()
        # Fill the lowest free slot (a gap if there is one, else max + 1).
        data["priority"] = _first_unused_priority(priorities)

    model = db.WorkQueue(**data, work_pool_id=work_pool_id)

    session.add(model)
    await session.flush()
    await session.refresh(model)

    if work_queue.priority:
        # An explicit priority may collide with an existing queue's; rebalance
        # the pool so priorities remain unique.
        await bulk_update_work_queue_priorities(
            session=session,
            work_pool_id=work_pool_id,
            new_priorities={model.id: work_queue.priority},
        )
    return model
@db_injector
async def bulk_update_work_queue_priorities(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    new_priorities: Dict[UUID, int],
) -> None:
    """
    Apply `new_priorities` to a pool's work queues and renumber the others.

    This is a brute force update of all work queue priorities for a given
    work pool:
    - All queues in the pool are loaded into memory in ascending priority
      order.
    - Queues named in `new_priorities` receive exactly their requested
      priority.
    - Every other queue keeps its relative order and, where possible, its
      current priority. It is bumped upward only when its value is already
      used by an earlier queue or is claimed by a requested priority.

    The result is that all priorities in the pool are unique integers > 0.
    For example, queues A=1, B=2, C=3 with `new_priorities` {B: 1, C: 2}
    end up as B=1, C=2, A=3. Ids in `new_priorities` that do not belong to
    this pool are ignored.

    Updating queue priorities is not a common operation (it happens on the
    same scale as queue modification, far less often than reading from
    queues), so this load-everything approach favours simplicity over speed.

    Args:
        session (AsyncSession): a database session
        work_pool_id (UUID): the pool whose queues are renumbered
        new_priorities (Dict[UUID, int]): requested priority per queue id; an
            empty dict only normalizes the existing priorities

    Raises:
        ValueError: if two queues are given the same target priority, or if a
            target priority is less than 1
    """
    if len(set(new_priorities.values())) != len(new_priorities):
        raise ValueError("Duplicate target priorities provided")
    if any(priority < 1 for priority in new_priorities.values()):
        raise ValueError("Target priorities must be positive integers")

    # get all the work queues, sorted by priority
    work_queues_query = (
        sa.select(db.WorkQueue)
        .where(db.WorkQueue.work_pool_id == work_pool_id)
        .order_by(db.WorkQueue.priority.asc())
    )
    result = await session.execute(work_queues_query)
    all_work_queues = result.scalars().all()

    # Requested priorities for queues that actually exist in this pool; the
    # remaining queues must step around these values.
    reserved = {
        new_priorities[wq.id] for wq in all_work_queues if wq.id in new_priorities
    }

    # Walk the queues in their current order. Requested queues take their
    # target. Every other queue takes the smallest value that is:
    # - at least its current priority,
    # - strictly above the previous non-requested queue, and
    # - not reserved.
    # Relative order is preserved, and queues are only touched when they must
    # move. (Previously, a requested queue whose target exceeded every other
    # queue was never placed in the walk, so renumbering could collide with it.)
    last_priority = 0
    for queue in all_work_queues:
        if queue.id in new_priorities:
            queue.priority = new_priorities[queue.id]
            continue

        priority = max(queue.priority, last_priority + 1)
        while priority in reserved:
            priority += 1
        if priority != queue.priority:
            queue.priority = priority
        last_priority = priority

    await session.flush()
@db_injector
async def read_work_queues(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
) -> Sequence[orm_models.WorkQueue]:
    """
    List the work queues of one work pool, lowest priority value first.

    Args:
        session (AsyncSession): a database session
        work_pool_id (UUID): the pool whose queues are listed
        work_queue_filter: optional filter criteria for the queues
        offset: number of rows to skip (omitted when None)
        limit: maximum number of rows to return (omitted when None)

    Returns:
        Sequence[orm_models.WorkQueue]: the matching WorkQueues
    """
    in_pool = db.WorkQueue.work_pool_id == work_pool_id
    stmt = sa.select(db.WorkQueue).where(in_pool).order_by(
        db.WorkQueue.priority.asc()
    )

    if work_queue_filter is not None:
        stmt = stmt.where(work_queue_filter.as_sql_filter())

    # Only apply the pagination bounds the caller actually supplied.
    stmt = stmt if offset is None else stmt.offset(offset)
    stmt = stmt if limit is None else stmt.limit(limit)

    rows = await session.execute(stmt)
    return rows.scalars().unique().all()
@db_injector
async def read_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: Union[UUID, PrefectUUID],
) -> Optional[orm_models.WorkQueue]:
    """
    Fetch one work queue by primary key via `session.get`.

    Args:
        session (AsyncSession): a database session
        work_queue_id (UUID): id of the work queue

    Returns:
        Optional[orm_models.WorkQueue]: the WorkQueue, or None if not found
    """
    queue = await session.get(db.WorkQueue, work_queue_id)
    return queue
@db_injector
async def read_work_queue_by_name(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_name: str,
    work_queue_name: str,
) -> Optional[orm_models.WorkQueue]:
    """
    Find a work queue by its own name and the name of its work pool.

    Args:
        session (AsyncSession): A database session
        work_pool_name (str): name of the owning WorkPool
        work_queue_name (str): name of the WorkQueue

    Returns:
        Optional[orm_models.WorkQueue]: the first matching WorkQueue, or None
    """
    owning_pool = db.WorkPool.id == db.WorkQueue.work_pool_id
    stmt = (
        sa.select(db.WorkQueue)
        .join(db.WorkPool, owning_pool)
        .where(db.WorkPool.name == work_pool_name)
        .where(db.WorkQueue.name == work_queue_name)
        .limit(1)
    )
    rows = await session.execute(stmt)
    return rows.scalar()
@db_injector
async def update_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: UUID,
    work_queue: schemas.actions.WorkQueueUpdate,
    emit_status_change: Optional[
        Callable[[orm_models.WorkQueue], Awaitable[None]]
    ] = None,
    default_status: WorkQueueStatus = WorkQueueStatus.NOT_READY,
) -> bool:
    """
    Update a work pool queue.

    Only fields explicitly set on `work_queue` are written. When `is_paused`
    is supplied, the status is adjusted as follows:
    - pausing a queue that is not already PAUSED sets PAUSED;
    - unpausing a PAUSED queue sets READY if its last poll is recent (per
      `is_last_polled_recent`), otherwise `default_status`.

    Args:
        session (AsyncSession): a database session
        work_queue_id (UUID): a work pool queue ID
        work_queue (schemas.actions.WorkQueueUpdate): the fields to update
        emit_status_change: optional coroutine function awaited with the
            updated queue when the update includes a status
        default_status: status applied when unpausing a queue whose last poll
            is not recent

    Returns:
        bool: whether or not the WorkQueue was updated
    """
    from prefect.server.models.work_queues import is_last_polled_recent

    values = work_queue.model_dump_for_orm(exclude_unset=True)

    if "is_paused" in values:
        current = await session.get(db.WorkQueue, work_queue_id)
        if current is None:
            return False

        pausing = values.get("is_paused")
        # Re-pausing an already paused queue (or re-unpausing an active one)
        # deliberately produces no status change.
        if pausing and current.status != WorkQueueStatus.PAUSED:
            values["status"] = WorkQueueStatus.PAUSED
        elif pausing is False and current.status == WorkQueueStatus.PAUSED:
            # Prefer the last_polled value being written over the stored one.
            last_polled = (
                values["last_polled"]
                if "last_polled" in values
                else current.last_polled
            )
            values["status"] = (
                schemas.statuses.WorkQueueStatus.READY
                if is_last_polled_recent(last_polled)
                else default_status
            )

    result = await session.execute(
        sa.update(db.WorkQueue)
        .where(db.WorkQueue.id == work_queue_id)
        .values(values)
    )
    if result.rowcount < 1:
        return False

    priority_changed = "priority" in values
    status_changed = "status" in values
    if priority_changed or status_changed:
        refreshed = await session.get(db.WorkQueue, work_queue_id)
        assert refreshed

        if priority_changed:
            await bulk_update_work_queue_priorities(
                session,
                work_pool_id=refreshed.work_pool_id,
                new_priorities={work_queue_id: values["priority"]},
            )

        if status_changed and emit_status_change:
            await emit_status_change(refreshed)

    return True
@db_injector
async def delete_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: UUID,
) -> bool:
    """
    Delete a work pool queue and re-normalize its pool's queue priorities.

    Args:
        session (AsyncSession): a database session
        work_queue_id (UUID): a work pool queue ID

    Returns:
        bool: False if no queue has that id, True once it has been deleted

    Raises:
        ValueError: if the delete fails with a foreign-key IntegrityError; the
            message assumes this is because the queue is a pool's default queue
        sqlalchemy.exc.IntegrityError: for any other integrity failure
    """
    work_queue = await session.get(db.WorkQueue, work_queue_id)
    if work_queue is None:
        return False

    await session.delete(work_queue)
    try:
        await session.flush()
    except sa.exc.IntegrityError as exc:
        # If an error was raised, check whether the user tried to delete a
        # pool's default queue; chain the original error for debugging.
        if "foreign key constraint" in str(exc).lower():
            raise ValueError("Can't delete a pool's default queue.") from exc
        raise

    # Re-run priority normalization (no explicit targets) over the remaining
    # queues in the pool.
    await bulk_update_work_queue_priorities(
        session,
        work_pool_id=work_queue.work_pool_id,
        new_priorities={},
    )
    return True
685# -----------------------------------------------------
686# --
687# --
688# -- Workers
689# --
690# --
691# -----------------------------------------------------
@db_injector
async def read_workers(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    worker_filter: Optional[schemas.filters.WorkerFilter] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Sequence[orm_models.Worker]:
    """
    List the workers of a work pool, most recent heartbeat first.

    Args:
        session (AsyncSession): a database session
        work_pool_id (UUID): the pool whose workers are listed
        worker_filter: optional filter criteria for the workers
        limit: maximum number of rows to return (omitted when None)
        offset: number of rows to skip (omitted when None)

    Returns:
        Sequence[orm_models.Worker]: the matching workers
    """
    # The limit is applied once below, only when provided (it was previously
    # also applied unconditionally here, which was redundant).
    query = (
        sa.select(db.Worker)
        .where(db.Worker.work_pool_id == work_pool_id)
        .order_by(db.Worker.last_heartbeat_time.desc())
    )

    if worker_filter is not None:
        query = query.where(worker_filter.as_sql_filter())

    if limit is not None:
        query = query.limit(limit)

    if offset is not None:
        query = query.offset(offset)

    result = await session.execute(query)
    return result.scalars().all()
@db_injector
async def worker_heartbeat(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    worker_name: str,
    heartbeat_interval_seconds: Optional[int] = None,
) -> bool:
    """
    Record a worker process heartbeat as an insert-or-update (upsert).

    The worker row, keyed by (work_pool_id, name), is created if missing.
    Otherwise its heartbeat time and ONLINE status are refreshed, along with
    the heartbeat interval when one is supplied.

    Args:
        session (AsyncSession): a database session
        work_pool_id (UUID): a work pool ID
        worker_name (str): a worker name
        heartbeat_interval_seconds (Optional[int]): interval to store; left
            unchanged on existing rows when None

    Returns:
        bool: whether or not the worker was updated
    """
    heartbeat_at = now("UTC")

    # Columns identifying the worker row; these never change between heartbeats.
    identity = {"work_pool_id": work_pool_id, "name": worker_name}
    # Columns refreshed on every heartbeat; also used as the conflict update set.
    refreshed = {
        "last_heartbeat_time": heartbeat_at,
        "status": schemas.statuses.WorkerStatus.ONLINE,
    }
    if heartbeat_interval_seconds is not None:
        refreshed["heartbeat_interval_seconds"] = heartbeat_interval_seconds

    upsert = (
        db.queries.insert(db.Worker)
        .values(**identity, **refreshed)
        .on_conflict_do_update(
            index_elements=[db.Worker.work_pool_id, db.Worker.name],
            set_=refreshed,
        )
    )

    outcome = await session.execute(upsert)
    return outcome.rowcount > 0
@db_injector
async def delete_worker(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_pool_id: UUID,
    worker_name: str,
) -> bool:
    """
    Remove a named worker from a work pool.

    Args:
        session (AsyncSession): a database session
        work_pool_id (UUID): id of the owning work pool
        worker_name (str): name of the worker to remove

    Returns:
        bool: True when at least one worker row was deleted
    """
    in_pool = db.Worker.work_pool_id == work_pool_id
    named = db.Worker.name == worker_name
    outcome = await session.execute(delete(db.Worker).where(in_pool, named))

    return outcome.rowcount > 0
async def emit_work_pool_status_event(
    event_id: UUID,
    occurred: DateTime,
    pre_update_work_pool: Optional[orm_models.WorkPool],
    work_pool: orm_models.WorkPool,
) -> None:
    """
    Emit a work pool status event through a server events client.

    Does nothing when `work_pool.status` is falsy. Otherwise the event built by
    `work_pool_status_event` is emitted inside a `PrefectServerEventsClient`
    context.

    Args:
        event_id: id to assign to the emitted event
        occurred: timestamp of the status transition
        pre_update_work_pool: the pool as it was before the update, if known
        work_pool: the pool after the update
    """
    if not work_pool.status:
        return

    async with PrefectServerEventsClient() as client:
        status_event = await work_pool_status_event(
            event_id=event_id,
            occurred=occurred,
            pre_update_work_pool=pre_update_work_pool,
            work_pool=work_pool,
        )
        await client.emit(status_event)