Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/work_queues.py: 43%

186 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Functions for interacting with work queue ORM objects. 

3Intended for internal use by the Prefect REST API. 

4""" 

5 

6import datetime 1b

7from typing import ( 1b

8 Awaitable, 

9 Callable, 

10 Iterable, 

11 Optional, 

12 Sequence, 

13 Tuple, 

14 Union, 

15 cast, 

16) 

17from uuid import UUID 1b

18 

19import sqlalchemy as sa 1b

20from docket import Depends, Retry 1b

21from pydantic import TypeAdapter 1b

22from sqlalchemy import delete, select 1b

23from sqlalchemy.ext.asyncio import AsyncSession 1b

24 

25import prefect.server.models as models 1b

26import prefect.server.schemas as schemas 1b

27from prefect.server.database import ( 1b

28 PrefectDBInterface, 

29 db_injector, 

30 orm_models, 

31 provide_database_interface, 

32) 

33from prefect.server.events.clients import PrefectServerEventsClient 1b

34from prefect.server.exceptions import ObjectNotFoundError 1b

35from prefect.server.models.events import work_queue_status_event 1b

36from prefect.server.models.workers import ( 1b

37 DEFAULT_AGENT_WORK_POOL_NAME, 

38 bulk_update_work_queue_priorities, 

39) 

40from prefect.server.schemas.states import StateType 1b

41from prefect.server.schemas.statuses import WorkQueueStatus 1b

42from prefect.server.utilities.database import UUID as PrefectUUID 1b

43from prefect.types._datetime import DateTime, now 1b

44 

45WORK_QUEUE_LAST_POLLED_TIMEOUT = datetime.timedelta(seconds=60) 1b

46 

47 

@db_injector
async def create_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue: Union[schemas.core.WorkQueue, schemas.actions.WorkQueueCreate],
) -> orm_models.WorkQueue:
    """
    Insert a new WorkQueue row.

    A database error is raised if a WorkQueue with the same name already
    exists. When no work pool id is supplied, the queue is attached to the
    default agent work pool, creating that pool on demand.

    Args:
        session (AsyncSession): a database session
        work_queue (schemas.core.WorkQueue): a WorkQueue model

    Returns:
        orm_models.WorkQueue: the newly-created or updated WorkQueue
    """
    queue_data = work_queue.model_dump()

    if queue_data.get("work_pool_id") is None:
        # No pool given: fall back to the default agent work pool,
        # provisioning it first if it does not exist yet.
        agent_pool = await models.workers.read_work_pool_by_name(
            session=session, work_pool_name=DEFAULT_AGENT_WORK_POOL_NAME
        )
        if agent_pool:
            queue_data["work_pool_id"] = agent_pool.id
        else:
            agent_pool = await models.workers.create_work_pool(
                session=session,
                work_pool=schemas.actions.WorkPoolCreate(
                    name=DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent"
                ),
            )
            if work_queue.name == "default":
                # The "default" queue was created together with the pool,
                # so read it back and return it directly.
                existing_default = await models.workers.read_work_queue(
                    session=session,
                    work_queue_id=agent_pool.default_queue_id,
                )
                assert existing_default
                return existing_default
            queue_data["work_pool_id"] = agent_pool.id

    if queue_data["priority"] is None:
        # Choose the first priority rank not already taken within the pool;
        # the new queue becomes the lowest priority when no gap exists.
        taken = (
            (
                await session.execute(
                    sa.select(db.WorkQueue.priority).where(
                        db.WorkQueue.work_pool_id == queue_data["work_pool_id"]
                    )
                )
            )
            .scalars()
            .all()
        )

        chosen = None
        for rank, existing in enumerate(sorted(taken), start=1):
            # A mismatch between the enumerated rank and the stored priority
            # means that rank was skipped and is free for this queue.
            if rank != existing:
                chosen = rank
                break

        # Otherwise append after the current maximum priority.
        if chosen is None:
            chosen = max(taken, default=0) + 1

        queue_data["priority"] = chosen

    model = db.WorkQueue(**queue_data)

    session.add(model)
    await session.flush()
    await session.refresh(model)

    if work_queue.priority:
        # An explicit priority may displace existing queues; rebalance them.
        await bulk_update_work_queue_priorities(
            session=session,
            work_pool_id=queue_data["work_pool_id"],
            new_priorities={model.id: work_queue.priority},
        )

    return model

132 

133 

@db_injector
async def read_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: Union[UUID, PrefectUUID],
) -> Optional[orm_models.WorkQueue]:
    """
    Read a WorkQueue by id.

    Args:
        session (AsyncSession): A database session
        work_queue_id (str): a WorkQueue id

    Returns:
        orm_models.WorkQueue: the WorkQueue, or None when no row matches
    """
    # Primary-key lookup; hits the session identity map when already loaded.
    return await session.get(db.WorkQueue, work_queue_id)

152 

153 

@db_injector
async def read_work_queue_by_name(
    db: PrefectDBInterface, session: AsyncSession, name: str
) -> Optional[orm_models.WorkQueue]:
    """
    Reads a WorkQueue by name.

    Args:
        session (AsyncSession): A database session
        name (str): a WorkQueue name

    Returns:
        orm_models.WorkQueue: the WorkQueue, or None if no queue with that
            name exists
    """
    default_work_pool = await models.workers.read_work_pool_by_name(
        session=session, work_pool_name=DEFAULT_AGENT_WORK_POOL_NAME
    )
    # Logic to make sure this functionality doesn't break during migration:
    # once the default agent pool exists, scope the lookup to that pool so
    # names are only matched against legacy (agent) queues.
    if default_work_pool is not None:
        query = select(db.WorkQueue).filter_by(
            name=name, work_pool_id=default_work_pool.id
        )
    else:
        query = select(db.WorkQueue).filter_by(name=name)
    result = await session.execute(query)
    return result.scalar()

180 

181 

@db_injector
async def read_work_queues(
    db: PrefectDBInterface,
    session: AsyncSession,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None,
) -> Sequence[orm_models.WorkQueue]:
    """
    Read WorkQueues, ordered by name.

    Args:
        session: A database session
        offset: Query offset
        limit: Query limit
        work_queue_filter: only select work queues matching these filters

    Returns:
        Sequence[orm_models.WorkQueue]: WorkQueues
    """
    stmt = select(db.WorkQueue).order_by(db.WorkQueue.name)

    if offset is not None:
        stmt = stmt.offset(offset)
    if limit is not None:
        stmt = stmt.limit(limit)
    if work_queue_filter:
        stmt = stmt.where(work_queue_filter.as_sql_filter())

    return (await session.execute(stmt)).scalars().unique().all()

213 

214 

def is_last_polled_recent(last_polled: Optional[DateTime]) -> bool:
    """Return True when ``last_polled`` falls within the polling timeout."""
    if last_polled is None:
        return False
    elapsed = now("UTC") - last_polled
    return elapsed <= WORK_QUEUE_LAST_POLLED_TIMEOUT

219 

220 

@db_injector
async def update_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: UUID,
    work_queue: schemas.actions.WorkQueueUpdate,
    emit_status_change: Optional[
        Callable[[orm_models.WorkQueue], Awaitable[None]]
    ] = None,
) -> bool:
    """
    Update a WorkQueue by id.

    Args:
        session (AsyncSession): A database session
        work_queue: the work queue data
        work_queue_id (str): a WorkQueue id
        emit_status_change: optional callback awaited with the refreshed
            queue whenever this update changes the queue's status

    Returns:
        bool: whether or not the WorkQueue was updated
    """
    # exclude_unset=True allows us to only update values provided by
    # the user, ignoring any defaults on the model
    update_data = work_queue.model_dump_for_orm(exclude_unset=True)

    # Status transitions are only derived when the caller touches is_paused;
    # other updates leave the stored status untouched.
    if "is_paused" in update_data:
        wq = await read_work_queue(session=session, work_queue_id=work_queue_id)
        if wq is None:
            # Nothing to update: the queue does not exist.
            return False

        # Only update the status to paused if it's not already paused. This ensures a work queue that is already
        # paused will not get a status update if it's paused again
        if update_data.get("is_paused") and wq.status != WorkQueueStatus.PAUSED:
            update_data["status"] = WorkQueueStatus.PAUSED

        # If unpausing, only update status if it's currently paused. This ensures a work queue that is already
        # unpaused will not get a status update if it's unpaused again
        if (
            update_data.get("is_paused") is False
            and wq.status == WorkQueueStatus.PAUSED
        ):
            # Default status if unpaused
            update_data["status"] = WorkQueueStatus.NOT_READY

        # Determine source of last_polled: update_data or database
        last_polled: Optional[DateTime]
        if "last_polled" in update_data:
            last_polled = cast(DateTime, update_data["last_polled"])
        else:
            last_polled = wq.last_polled

        # Check if last polled is recent and set status to READY if so,
        # overriding the NOT_READY default chosen above.
        if is_last_polled_recent(last_polled):
            update_data["status"] = schemas.statuses.WorkQueueStatus.READY

    update_stmt = (
        sa.update(db.WorkQueue)
        .where(db.WorkQueue.id == work_queue_id)
        .values(**update_data)
    )
    result = await session.execute(update_stmt)
    updated = result.rowcount > 0

    if updated:
        if "status" in update_data and emit_status_change:
            # Re-read the row so the callback sees the post-update state.
            wq = await read_work_queue(session=session, work_queue_id=work_queue_id)
            assert wq
            await emit_status_change(wq)

    return updated

291 

292 

@db_injector
async def delete_work_queue(
    db: PrefectDBInterface, session: AsyncSession, work_queue_id: UUID
) -> bool:
    """
    Delete a WorkQueue by id.

    Args:
        session (AsyncSession): A database session
        work_queue_id (str): a WorkQueue id

    Returns:
        bool: whether or not the WorkQueue was deleted
    """
    delete_stmt = delete(db.WorkQueue).where(db.WorkQueue.id == work_queue_id)
    result = await session.execute(delete_stmt)

    # rowcount counts the rows the DELETE touched; zero means the id was unknown.
    return result.rowcount > 0

312 

313 

@db_injector
async def get_runs_in_work_queue(
    db: PrefectDBInterface,
    session: AsyncSession,
    work_queue_id: UUID,
    limit: Optional[int] = None,
    scheduled_before: Optional[datetime.datetime] = None,
) -> Tuple[orm_models.WorkQueue, Sequence[orm_models.FlowRun]]:
    """
    Get runs from a work queue.

    Args:
        session: A database session.
        work_queue_id: The work queue id.
        scheduled_before: Only return runs scheduled to start before this time.
        limit: An optional limit for the number of runs to return from the
            queue. This limit applies to the request only. It does not affect
            the work queue's concurrency limit. If `limit` exceeds the work
            queue's concurrency limit, it will be ignored.

    Raises:
        ObjectNotFoundError: if no work queue with the given id exists.
    """
    work_queue = await read_work_queue(session=session, work_queue_id=work_queue_id)
    if not work_queue:
        raise ObjectNotFoundError(f"Work queue with id {work_queue_id} not found.")

    if work_queue.filter is not None:
        # A populated filter marks a deprecated tag-based work queue, which
        # is serviced through the legacy code path.
        legacy_runs = await _legacy_get_runs_in_work_queue(
            session=session,
            work_queue_id=work_queue_id,
            scheduled_before=scheduled_before,
            limit=limit,
        )
        return work_queue, legacy_runs

    scheduled_query = db.queries.get_scheduled_flow_runs_from_work_queues(
        limit_per_queue=limit,
        work_queue_ids=[work_queue_id],
        scheduled_before=scheduled_before,
    )
    result = await session.execute(scheduled_query)
    return work_queue, result.scalars().unique().all()

356 

357 

async def _legacy_get_runs_in_work_queue(
    session: AsyncSession,
    work_queue_id: UUID,
    scheduled_before: Optional[datetime.datetime] = None,
    limit: Optional[int] = None,
) -> Sequence[orm_models.FlowRun]:
    """
    DEPRECATED method for getting runs from a tag-based work queue.

    Args:
        session: A database session.
        work_queue_id: The work queue id.
        scheduled_before: Only return runs scheduled to start before this time.
        limit: An optional limit for the number of runs to return from the queue.
            This limit applies to the request only. It does not affect the
            work queue's concurrency limit. If `limit` exceeds the work queue's
            concurrency limit, it will be ignored.
    """
    work_queue = await read_work_queue(session=session, work_queue_id=work_queue_id)
    if not work_queue:
        raise ObjectNotFoundError(f"Work queue with id {work_queue_id} not found.")

    if work_queue.is_paused:
        return []

    # Re-validate the stored filter so it is fully hydrated: SQLAlchemy's
    # caching logic can hand back a plain dict instead of the pydantic model.
    queue_filter = TypeAdapter(schemas.core.QueueFilter).validate_python(
        work_queue.filter
    )
    run_filter_kwargs = dict(
        tags=dict(all_=queue_filter.tags),
        deployment_id=dict(any_=queue_filter.deployment_ids, is_null_=False),
    )

    if work_queue.concurrency_limit is not None:
        # Compare the number of currently-executing runs to the queue's
        # concurrency limit. Note this does not guarantee race conditions
        # won't be hit.
        active_count = await models.flow_runs.count_flow_runs(
            session=session,
            flow_run_filter=schemas.filters.FlowRunFilter(
                **run_filter_kwargs,
                state=dict(type=dict(any_=[StateType.PENDING, StateType.RUNNING])),
            ),
        )

        # Compute the available concurrency slots.
        open_slots = max(0, work_queue.concurrency_limit - active_count)

        # Honor an explicit limit override, but never exceed the open slots.
        limit = open_slots if limit is None else min(open_slots, limit)

    return await models.flow_runs.read_flow_runs(
        session=session,
        flow_run_filter=schemas.filters.FlowRunFilter(
            **run_filter_kwargs,
            state=dict(type=dict(any_=[StateType.SCHEDULED])),
            next_scheduled_start_time=dict(before_=scheduled_before),
        ),
        limit=limit,
        sort=schemas.sorting.FlowRunSort.NEXT_SCHEDULED_START_TIME_ASC,
    )

428 

429 

async def ensure_work_queue_exists(
    session: AsyncSession, name: str
) -> orm_models.WorkQueue:
    """
    Checks if a work queue exists and creates it if it does not.

    Useful when working with deployments, agents, and flow runs that
    automatically create work queues.

    Will also create a work pool queue in the default agent pool to
    facilitate migration to work pools.
    """
    existing = await models.work_queues.read_work_queue_by_name(
        session=session, name=name
    )
    if existing:
        return existing

    default_pool = await models.workers.read_work_pool_by_name(
        session=session, work_pool_name=DEFAULT_AGENT_WORK_POOL_NAME
    )

    if default_pool is None:
        # No default agent pool yet; create_work_queue provisions one itself.
        return await models.work_queues.create_work_queue(
            session=session,
            work_queue=schemas.actions.WorkQueueCreate(name=name, priority=1),
        )

    if name == "default":
        # The pool's default queue already exists; read it back.
        work_queue = await models.work_queues.read_work_queue(
            session=session, work_queue_id=default_pool.default_queue_id
        )
        assert work_queue, "Default work queue not found"
        return work_queue

    return await models.workers.create_work_queue(
        session=session,
        work_pool_id=default_pool.id,
        work_queue=schemas.actions.WorkQueueCreate(name=name, priority=1),
    )

468 

469 

async def read_work_queue_status(
    session: AsyncSession, work_queue_id: UUID
) -> schemas.core.WorkQueueStatusDetail:
    """
    Get work queue status by id.

    Args:
        session (AsyncSession): A database session
        work_queue_id (str): a WorkQueue id

    Returns:
        Information about the status of the work queue.
    """
    work_queue = await read_work_queue(session=session, work_queue_id=work_queue_id)
    if not work_queue:
        raise ObjectNotFoundError(f"Work queue with id {work_queue_id} not found")

    late_runs = await models.flow_runs.count_flow_runs(
        session=session,
        flow_run_filter=schemas.filters.FlowRunFilter(
            state=schemas.filters.FlowRunFilterState(name={"any_": ["Late"]}),
        ),
        work_queue_filter=schemas.filters.WorkQueueFilter(
            id=schemas.filters.WorkQueueFilterId(any_=[work_queue_id])
        ),
    )

    # All work queues use the default policy for now.
    policy = schemas.core.WorkQueueHealthPolicy(
        maximum_late_runs=0, maximum_seconds_since_last_polled=60
    )

    is_healthy = policy.evaluate_health_status(
        late_runs_count=late_runs,
        last_polled=work_queue.last_polled,  # type: ignore
    )

    return schemas.core.WorkQueueStatusDetail(
        healthy=is_healthy,
        late_runs_count=late_runs,
        last_polled=work_queue.last_polled,
        health_check_policy=policy,
    )

514 

515 

@db_injector
async def record_work_queue_polls(
    db: PrefectDBInterface,
    session: AsyncSession,
    polled_work_queue_ids: Sequence[UUID],
    ready_work_queue_ids: Sequence[UUID],
) -> None:
    """Record that the given work queues were polled, and also update the given
    ready_work_queue_ids to READY."""
    polled_at = now("UTC")

    if polled_work_queue_ids:
        # Stamp last_polled without touching the queues' status.
        await session.execute(
            sa.update(db.WorkQueue)
            .where(db.WorkQueue.id.in_(polled_work_queue_ids))
            .values(last_polled=polled_at)
        )

    if ready_work_queue_ids:
        # Stamp last_polled and flip status to READY in one statement.
        await session.execute(
            sa.update(db.WorkQueue)
            .where(db.WorkQueue.id.in_(ready_work_queue_ids))
            .values(last_polled=polled_at, status=WorkQueueStatus.READY)
        )

540 

541 

async def mark_work_queues_ready(
    *,
    db: PrefectDBInterface = Depends(provide_database_interface),
    polled_work_queue_ids: Sequence[UUID],
    ready_work_queue_ids: Sequence[UUID],
    retry: Retry = Retry(attempts=5, delay=datetime.timedelta(seconds=0.5)),
) -> None:
    """Record polls for the given work queues, flip the ready ones to READY,
    and emit a status event for each newly-ready queue."""
    async with db.session_context(begin_transaction=True) as session:
        await record_work_queue_polls(
            session=session,
            polled_work_queue_ids=polled_work_queue_ids,
            ready_work_queue_ids=ready_work_queue_ids,
        )

    # No queues transitioned to ready during this poll: nothing to emit.
    if not ready_work_queue_ids:
        return

    # Emit events for the queues that transitioned to ready, in a separate
    # transaction so the locks from the updates above are released first.
    async with db.session_context(begin_transaction=True) as session:
        ready_rows = await session.execute(
            sa.select(db.WorkQueue).where(db.WorkQueue.id.in_(ready_work_queue_ids))
        )

        events = [
            await work_queue_status_event(
                session=session,
                work_queue=work_queue,
                occurred=now("UTC"),
            )
            for work_queue in ready_rows.scalars().all()
        ]

    async with PrefectServerEventsClient() as events_client:
        for event in events:
            await events_client.emit(event)

579 

580 

@db_injector
async def mark_work_queues_not_ready(
    db: PrefectDBInterface,
    work_queue_ids: Iterable[UUID],
) -> None:
    """Set the given work queues to NOT_READY and emit a status event for each.

    Args:
        work_queue_ids: ids of the work queues to mark; a no-op when empty.
    """
    # Materialize the iterable once: it is truth-tested and then consumed by
    # two separate queries below, which a one-shot iterator would break.
    ids = list(work_queue_ids)
    if not ids:
        return

    async with db.session_context(begin_transaction=True) as session:
        await session.execute(
            sa.update(db.WorkQueue)
            .where(db.WorkQueue.id.in_(ids))
            .values(status=WorkQueueStatus.NOT_READY)
        )

    # Emit events for the work queues that have transitioned to NOT_READY.
    # Uses a separate transaction to avoid keeping locks open longer from the
    # updates in the previous transaction.
    async with db.session_context(begin_transaction=True) as session:
        newly_unready_work_queues = await session.execute(
            sa.select(db.WorkQueue).where(db.WorkQueue.id.in_(ids))
        )

        events = [
            await work_queue_status_event(
                session=session,
                work_queue=work_queue,
                occurred=now("UTC"),
            )
            for work_queue in newly_unready_work_queues.scalars().all()
        ]

    async with PrefectServerEventsClient() as events_client:
        for event in events:
            await events_client.emit(event)

617 

618 

@db_injector
async def emit_work_queue_status_event(
    db: PrefectDBInterface,
    work_queue: orm_models.WorkQueue,
) -> None:
    """Build and publish a status event for the given work queue."""
    async with db.session_context() as session:
        event = await work_queue_status_event(
            session=session,
            work_queue=work_queue,
            occurred=now("UTC"),
        )

    async with PrefectServerEventsClient() as events_client:
        await events_client.emit(event)
632 await events_client.emit(event)