Coverage for /usr/local/lib/python3.12/site-packages/prefect/client/orchestration/_work_pools/client.py: 13%
167 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1from __future__ import annotations 1a
3import warnings 1a
4from datetime import datetime 1a
5from typing import TYPE_CHECKING, Any 1a
7from httpx import HTTPStatusError 1a
9from prefect.client.base import ServerType 1a
10from prefect.client.orchestration.base import BaseAsyncClient, BaseClient 1a
12if TYPE_CHECKING: 12 ↛ 13line 12 didn't jump to line 13 because the condition on line 12 was never true1a
13 from uuid import UUID
15 from prefect.client.schemas.actions import (
16 WorkPoolCreate,
17 WorkPoolUpdate,
18 )
19 from prefect.client.schemas.filters import (
20 WorkerFilter,
21 WorkPoolFilter,
22 )
23 from prefect.client.schemas.objects import (
24 Worker,
25 WorkerMetadata,
26 WorkPool,
27 )
28 from prefect.client.schemas.responses import WorkerFlowRunResponse
30from prefect.exceptions import ObjectAlreadyExists, ObjectNotFound, ObjectUnsupported 1a
class WorkPoolClient(BaseClient):
    """Synchronous client methods for the `/work_pools` API routes."""

    def send_worker_heartbeat(
        self,
        work_pool_name: str,
        worker_name: str,
        heartbeat_interval_seconds: float | None = None,
        get_worker_id: bool = False,
        worker_metadata: "WorkerMetadata | None" = None,
    ) -> "UUID | None":
        """
        Sends a worker heartbeat for a given work pool.

        Args:
            work_pool_name: The name of the work pool to heartbeat against.
            worker_name: The name of the worker sending the heartbeat.
            heartbeat_interval_seconds: The worker's heartbeat interval in
                seconds; sent to the server as-is (may be `None`).
            get_worker_id: Whether to ask the server to return the worker ID.
                Note: will return `None` if the connected server does not
                support returning worker IDs, even if `get_worker_id` is `True`.
            worker_metadata: Metadata about the worker to send to the server.

        Returns:
            The worker ID parsed from the response body when `get_worker_id` is
            `True`, the server type is `ServerType.CLOUD` (or test mode is
            enabled), and the response status is 200; otherwise `None`.
        """
        from uuid import UUID

        params: dict[str, Any] = {
            "name": worker_name,
            "heartbeat_interval_seconds": heartbeat_interval_seconds,
        }
        if worker_metadata:
            params["metadata"] = worker_metadata.model_dump(mode="json")
        if get_worker_id:
            # Only include the flag when requested; otherwise omit the key entirely.
            params["return_id"] = get_worker_id

        resp = self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/heartbeat",
            path_params={"work_pool_name": work_pool_name},
            json=params,
        )
        from prefect.settings import get_current_settings

        # The response body is only parsed as a worker ID for Cloud servers
        # (or when test mode is on); other servers yield None.
        if (
            (
                self.server_type == ServerType.CLOUD
                or get_current_settings().testing.test_mode
            )
            and get_worker_id
            and resp.status_code == 200
        ):
            return UUID(resp.text)
        else:
            return None

    def read_workers_for_work_pool(
        self,
        work_pool_name: str,
        worker_filter: "WorkerFilter | None" = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list["Worker"]:
        """
        Reads workers for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                member workers.
            worker_filter: Criteria by which to filter workers. Only fields
                that were explicitly set on the filter are sent.
            offset: Offset for the worker query.
            limit: Limit for the worker query.

        Returns:
            The workers returned by the server, validated as `Worker` objects.
        """
        from prefect.client.schemas.objects import Worker

        response = self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/filter",
            path_params={"work_pool_name": work_pool_name},
            json={
                "workers": (
                    worker_filter.model_dump(mode="json", exclude_unset=True)
                    if worker_filter
                    else None
                ),
                "offset": offset,
                "limit": limit,
            },
        )

        return Worker.model_validate_list(response.json())

    def read_work_pool(self, work_pool_name: str) -> "WorkPool":
        """
        Reads information for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                information.

        Returns:
            Information about the requested work pool.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        from prefect.client.schemas.objects import WorkPool

        try:
            response = self.request(
                "GET",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
            return WorkPool.model_validate(response.json())
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    def read_work_pools(
        self,
        limit: int | None = None,
        offset: int = 0,
        work_pool_filter: "WorkPoolFilter | None" = None,
    ) -> list["WorkPool"]:
        """
        Reads work pools.

        Args:
            limit: Limit for the work pool query.
            offset: Offset for the work pool query.
            work_pool_filter: Criteria by which to filter work pools.

        Returns:
            A list of work pools.
        """
        from prefect.client.schemas.objects import WorkPool

        body: dict[str, Any] = {
            "limit": limit,
            "offset": offset,
            "work_pools": (
                work_pool_filter.model_dump(mode="json") if work_pool_filter else None
            ),
        }
        response = self.request("POST", "/work_pools/filter", json=body)
        return WorkPool.model_validate_list(response.json())

    def create_work_pool(
        self,
        work_pool: "WorkPoolCreate",
        overwrite: bool = False,
    ) -> "WorkPool":
        """
        Creates a work pool with the provided configuration.

        Args:
            work_pool: Desired configuration for the new work pool.
            overwrite: If `True` and a work pool with the same name already
                exists (HTTP 409), update the existing pool with the provided
                configuration instead of raising. The pool type cannot be
                changed this way; a `UserWarning` is emitted if it differs.

        Returns:
            Information about the newly created (or overwritten) work pool.

        Raises:
            ObjectUnsupported: If the server responds with a 403 whose message
                says the plan does not support the request.
            ObjectAlreadyExists: If the work pool already exists and
                `overwrite` is `False`.
        """
        from prefect.client.schemas.actions import WorkPoolUpdate
        from prefect.client.schemas.objects import WorkPool

        try:
            response = self.request(
                "POST",
                "/work_pools/",
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
            response.raise_for_status()
        except HTTPStatusError as e:
            if e.response.status_code == 403 and "plan does not support" in str(e):
                raise ObjectUnsupported(http_exc=e) from e
            if e.response.status_code == 409:
                if overwrite:
                    existing_work_pool = self.read_work_pool(
                        work_pool_name=work_pool.name
                    )
                    if existing_work_pool.type != work_pool.type:
                        warnings.warn(
                            "Overwriting work pool type is not supported. Ignoring provided type.",
                            category=UserWarning,
                        )
                    # Name and type are excluded: the name identifies the pool and
                    # the type cannot be changed.
                    self.update_work_pool(
                        work_pool_name=work_pool.name,
                        work_pool=WorkPoolUpdate.model_validate(
                            work_pool.model_dump(exclude={"name", "type"})
                        ),
                    )
                    # Re-read so the returned object reflects the updated pool.
                    response = self.request(
                        "GET",
                        "/work_pools/{name}",
                        path_params={"name": work_pool.name},
                    )
                else:
                    raise ObjectAlreadyExists(http_exc=e) from e
            else:
                raise

        return WorkPool.model_validate(response.json())

    def update_work_pool(
        self,
        work_pool_name: str,
        work_pool: "WorkPoolUpdate",
    ) -> None:
        """
        Updates a work pool.

        Args:
            work_pool_name: Name of the work pool to update.
            work_pool: Fields to update in the work pool. Only fields that
                were explicitly set are sent.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        try:
            self.request(
                "PATCH",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    def delete_work_pool(
        self,
        work_pool_name: str,
    ) -> None:
        """
        Deletes a work pool.

        Args:
            work_pool_name: Name of the work pool to delete.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        try:
            self.request(
                "DELETE",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    def get_scheduled_flow_runs_for_work_pool(
        self,
        work_pool_name: str,
        work_queue_names: list[str] | None = None,
        scheduled_before: datetime | None = None,
    ) -> list["WorkerFlowRunResponse"]:
        """
        Retrieves scheduled flow runs for the provided set of work pool queues.

        Args:
            work_pool_name: The name of the work pool that the work pool
                queues are associated with.
            work_queue_names: The names of the work pool queues from which
                to get scheduled flow runs. Omitted from the request when `None`.
            scheduled_before: Datetime used to filter returned flow runs; it is
                sent as `str(scheduled_before)`. Flow runs scheduled after the
                given datetime will not be returned.

        Returns:
            A list of worker flow run responses containing information about the
            retrieved flow runs.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        from prefect.client.schemas.responses import WorkerFlowRunResponse

        body: dict[str, Any] = {}
        if work_queue_names is not None:
            body["work_queue_names"] = list(work_queue_names)
        if scheduled_before:
            body["scheduled_before"] = str(scheduled_before)

        try:
            response = self.request(
                "POST",
                "/work_pools/{name}/get_scheduled_flow_runs",
                path_params={"name": work_pool_name},
                json=body,
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

        return WorkerFlowRunResponse.model_validate_list(response.json())
class WorkPoolAsyncClient(BaseAsyncClient):
    """Asynchronous client methods for the `/work_pools` API routes."""

    async def send_worker_heartbeat(
        self,
        work_pool_name: str,
        worker_name: str,
        heartbeat_interval_seconds: float | None = None,
        get_worker_id: bool = False,
        worker_metadata: "WorkerMetadata | None" = None,
    ) -> "UUID | None":
        """
        Sends a worker heartbeat for a given work pool.

        Args:
            work_pool_name: The name of the work pool to heartbeat against.
            worker_name: The name of the worker sending the heartbeat.
            heartbeat_interval_seconds: The worker's heartbeat interval in
                seconds; sent to the server as-is (may be `None`).
            get_worker_id: Whether to ask the server to return the worker ID.
                Note: will return `None` if the connected server does not
                support returning worker IDs, even if `get_worker_id` is `True`.
            worker_metadata: Metadata about the worker to send to the server.

        Returns:
            The worker ID parsed from the response body when `get_worker_id` is
            `True`, the server type is `ServerType.CLOUD` (or test mode is
            enabled), and the response status is 200; otherwise `None`.
        """
        from uuid import UUID

        params: dict[str, Any] = {
            "name": worker_name,
            "heartbeat_interval_seconds": heartbeat_interval_seconds,
        }
        if worker_metadata:
            params["metadata"] = worker_metadata.model_dump(mode="json")
        if get_worker_id:
            # Only include the flag when requested; otherwise omit the key entirely.
            params["return_id"] = get_worker_id

        resp = await self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/heartbeat",
            path_params={"work_pool_name": work_pool_name},
            json=params,
        )
        from prefect.settings import get_current_settings

        # The response body is only parsed as a worker ID for Cloud servers
        # (or when test mode is on); other servers yield None.
        if (
            (
                self.server_type == ServerType.CLOUD
                or get_current_settings().testing.test_mode
            )
            and get_worker_id
            and resp.status_code == 200
        ):
            return UUID(resp.text)
        else:
            return None

    async def read_workers_for_work_pool(
        self,
        work_pool_name: str,
        worker_filter: "WorkerFilter | None" = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list["Worker"]:
        """
        Reads workers for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                member workers.
            worker_filter: Criteria by which to filter workers. Only fields
                that were explicitly set on the filter are sent.
            offset: Offset for the worker query.
            limit: Limit for the worker query.

        Returns:
            The workers returned by the server, validated as `Worker` objects.
        """
        from prefect.client.schemas.objects import Worker

        response = await self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/filter",
            path_params={"work_pool_name": work_pool_name},
            json={
                "workers": (
                    worker_filter.model_dump(mode="json", exclude_unset=True)
                    if worker_filter
                    else None
                ),
                "offset": offset,
                "limit": limit,
            },
        )

        return Worker.model_validate_list(response.json())

    async def read_work_pool(self, work_pool_name: str) -> "WorkPool":
        """
        Reads information for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                information.

        Returns:
            Information about the requested work pool.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        from prefect.client.schemas.objects import WorkPool

        try:
            response = await self.request(
                "GET",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
            return WorkPool.model_validate(response.json())
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    async def read_work_pools(
        self,
        limit: int | None = None,
        offset: int = 0,
        work_pool_filter: "WorkPoolFilter | None" = None,
    ) -> list["WorkPool"]:
        """
        Reads work pools.

        Args:
            limit: Limit for the work pool query.
            offset: Offset for the work pool query.
            work_pool_filter: Criteria by which to filter work pools.

        Returns:
            A list of work pools.
        """
        from prefect.client.schemas.objects import WorkPool

        body: dict[str, Any] = {
            "limit": limit,
            "offset": offset,
            "work_pools": (
                work_pool_filter.model_dump(mode="json") if work_pool_filter else None
            ),
        }
        response = await self.request("POST", "/work_pools/filter", json=body)
        return WorkPool.model_validate_list(response.json())

    async def create_work_pool(
        self,
        work_pool: "WorkPoolCreate",
        overwrite: bool = False,
    ) -> "WorkPool":
        """
        Creates a work pool with the provided configuration.

        Args:
            work_pool: Desired configuration for the new work pool.
            overwrite: If `True` and a work pool with the same name already
                exists (HTTP 409), update the existing pool with the provided
                configuration instead of raising. The pool type cannot be
                changed this way; a `UserWarning` is emitted if it differs.

        Returns:
            Information about the newly created (or overwritten) work pool.

        Raises:
            ObjectUnsupported: If the server responds with a 403 whose message
                says the plan does not support the request.
            ObjectAlreadyExists: If the work pool already exists and
                `overwrite` is `False`.
        """
        from prefect.client.schemas.actions import WorkPoolUpdate
        from prefect.client.schemas.objects import WorkPool

        try:
            response = await self.request(
                "POST",
                "/work_pools/",
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
            # Mirror the sync client: make sure error statuses surface as
            # HTTPStatusError so they are mapped below.
            response.raise_for_status()
        except HTTPStatusError as e:
            # Keep parity with the sync client, which maps plan-restricted
            # creation attempts to ObjectUnsupported.
            if e.response.status_code == 403 and "plan does not support" in str(e):
                raise ObjectUnsupported(http_exc=e) from e
            if e.response.status_code == 409:
                if overwrite:
                    existing_work_pool = await self.read_work_pool(
                        work_pool_name=work_pool.name
                    )
                    if existing_work_pool.type != work_pool.type:
                        warnings.warn(
                            "Overwriting work pool type is not supported. Ignoring provided type.",
                            category=UserWarning,
                        )
                    # Name and type are excluded: the name identifies the pool and
                    # the type cannot be changed.
                    await self.update_work_pool(
                        work_pool_name=work_pool.name,
                        work_pool=WorkPoolUpdate.model_validate(
                            work_pool.model_dump(exclude={"name", "type"})
                        ),
                    )
                    # Re-read so the returned object reflects the updated pool.
                    response = await self.request(
                        "GET",
                        "/work_pools/{name}",
                        path_params={"name": work_pool.name},
                    )
                else:
                    raise ObjectAlreadyExists(http_exc=e) from e
            else:
                raise

        return WorkPool.model_validate(response.json())

    async def update_work_pool(
        self,
        work_pool_name: str,
        work_pool: "WorkPoolUpdate",
    ) -> None:
        """
        Updates a work pool.

        Args:
            work_pool_name: Name of the work pool to update.
            work_pool: Fields to update in the work pool. Only fields that
                were explicitly set are sent.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        try:
            await self.request(
                "PATCH",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    async def delete_work_pool(
        self,
        work_pool_name: str,
    ) -> None:
        """
        Deletes a work pool.

        Args:
            work_pool_name: Name of the work pool to delete.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        try:
            await self.request(
                "DELETE",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    async def get_scheduled_flow_runs_for_work_pool(
        self,
        work_pool_name: str,
        work_queue_names: list[str] | None = None,
        scheduled_before: datetime | None = None,
    ) -> list["WorkerFlowRunResponse"]:
        """
        Retrieves scheduled flow runs for the provided set of work pool queues.

        Args:
            work_pool_name: The name of the work pool that the work pool
                queues are associated with.
            work_queue_names: The names of the work pool queues from which
                to get scheduled flow runs. Omitted from the request when `None`.
            scheduled_before: Datetime used to filter returned flow runs; it is
                sent as `str(scheduled_before)`. Flow runs scheduled after the
                given datetime will not be returned.

        Returns:
            A list of worker flow run responses containing information about the
            retrieved flow runs.

        Raises:
            ObjectNotFound: If the server responds with a 404.
        """
        from prefect.client.schemas.responses import WorkerFlowRunResponse

        body: dict[str, Any] = {}
        if work_queue_names is not None:
            body["work_queue_names"] = list(work_queue_names)
        if scheduled_before:
            body["scheduled_before"] = str(scheduled_before)

        try:
            response = await self.request(
                "POST",
                "/work_pools/{name}/get_scheduled_flow_runs",
                path_params={"name": work_pool_name},
                json=body,
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

        return WorkerFlowRunResponse.model_validate_list(response.json())