Coverage for /usr/local/lib/python3.12/site-packages/prefect/client/orchestration/_work_pools/client.py: 13%

167 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1from __future__ import annotations 1a

2 

3import warnings 1a

4from datetime import datetime 1a

5from typing import TYPE_CHECKING, Any 1a

6 

7from httpx import HTTPStatusError 1a

8 

9from prefect.client.base import ServerType 1a

10from prefect.client.orchestration.base import BaseAsyncClient, BaseClient 1a

11 

12if TYPE_CHECKING: 12 ↛ 13line 12 didn't jump to line 13 because the condition on line 12 was never true1a

13 from uuid import UUID 

14 

15 from prefect.client.schemas.actions import ( 

16 WorkPoolCreate, 

17 WorkPoolUpdate, 

18 ) 

19 from prefect.client.schemas.filters import ( 

20 WorkerFilter, 

21 WorkPoolFilter, 

22 ) 

23 from prefect.client.schemas.objects import ( 

24 Worker, 

25 WorkerMetadata, 

26 WorkPool, 

27 ) 

28 from prefect.client.schemas.responses import WorkerFlowRunResponse 

29 

30from prefect.exceptions import ObjectAlreadyExists, ObjectNotFound, ObjectUnsupported 1a

31 

32 

class WorkPoolClient(BaseClient):
    """Work pool and worker operations for the synchronous client.

    Each method issues a blocking HTTP request through ``self.request`` and
    translates selected HTTP error statuses into Prefect exceptions.
    """

    def send_worker_heartbeat(
        self,
        work_pool_name: str,
        worker_name: str,
        heartbeat_interval_seconds: float | None = None,
        get_worker_id: bool = False,
        worker_metadata: "WorkerMetadata | None" = None,
    ) -> "UUID | None":
        """
        Sends a worker heartbeat for a given work pool.

        Args:
            work_pool_name: The name of the work pool to heartbeat against.
            worker_name: The name of the worker sending the heartbeat.
            heartbeat_interval_seconds: Heartbeat interval, in seconds,
                reported to the server.
            get_worker_id: Whether to ask the server to return the worker ID.
                Note: `None` is still returned when the ID cannot be obtained
                (see Returns), even if `get_worker_id` is `True`.
            worker_metadata: Metadata about the worker to send to the server.

        Returns:
            The worker ID parsed from the response body when `get_worker_id`
            is `True`, the server is Prefect Cloud (or test mode is enabled),
            and the response status is 200; otherwise `None`.
        """
        from uuid import UUID

        params: dict[str, Any] = {
            "name": worker_name,
            "heartbeat_interval_seconds": heartbeat_interval_seconds,
        }
        if worker_metadata:
            params["metadata"] = worker_metadata.model_dump(mode="json")
        if get_worker_id:
            # Only sent when requested so servers unaware of the flag are unaffected.
            params["return_id"] = get_worker_id

        resp = self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/heartbeat",
            path_params={"work_pool_name": work_pool_name},
            json=params,
        )
        from prefect.settings import get_current_settings

        # Worker IDs are only returned by Cloud servers (or in test mode).
        if (
            (
                self.server_type == ServerType.CLOUD
                or get_current_settings().testing.test_mode
            )
            and get_worker_id
            and resp.status_code == 200
        ):
            return UUID(resp.text)
        else:
            return None

    def read_workers_for_work_pool(
        self,
        work_pool_name: str,
        worker_filter: "WorkerFilter | None" = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list["Worker"]:
        """
        Reads workers for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                member workers.
            worker_filter: Criteria by which to filter workers.
            offset: Offset for the worker query.
            limit: Limit for the worker query.

        Returns:
            The workers returned by the server for the work pool.
        """
        from prefect.client.schemas.objects import Worker

        response = self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/filter",
            path_params={"work_pool_name": work_pool_name},
            json={
                "workers": (
                    worker_filter.model_dump(mode="json", exclude_unset=True)
                    if worker_filter
                    else None
                ),
                "offset": offset,
                "limit": limit,
            },
        )

        return Worker.model_validate_list(response.json())

    def read_work_pool(self, work_pool_name: str) -> "WorkPool":
        """
        Reads information for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                information.

        Returns:
            Information about the requested work pool.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        from prefect.client.schemas.objects import WorkPool

        try:
            response = self.request(
                "GET",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
            return WorkPool.model_validate(response.json())
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    def read_work_pools(
        self,
        limit: int | None = None,
        offset: int = 0,
        work_pool_filter: "WorkPoolFilter | None" = None,
    ) -> list["WorkPool"]:
        """
        Reads work pools.

        Args:
            limit: Limit for the work pool query.
            offset: Offset for the work pool query.
            work_pool_filter: Criteria by which to filter work pools.

        Returns:
            A list of work pools.
        """
        from prefect.client.schemas.objects import WorkPool

        body: dict[str, Any] = {
            "limit": limit,
            "offset": offset,
            "work_pools": (
                work_pool_filter.model_dump(mode="json") if work_pool_filter else None
            ),
        }
        response = self.request("POST", "/work_pools/filter", json=body)
        return WorkPool.model_validate_list(response.json())

    def create_work_pool(
        self,
        work_pool: "WorkPoolCreate",
        overwrite: bool = False,
    ) -> "WorkPool":
        """
        Creates a work pool with the provided configuration.

        Args:
            work_pool: Desired configuration for the new work pool.
            overwrite: If `True` and a work pool with the same name already
                exists (HTTP 409), update the existing pool with the provided
                configuration instead of raising. The pool type cannot be
                changed; a `UserWarning` is emitted if it differs.

        Returns:
            Information about the newly created (or overwritten) work pool.

        Raises:
            ObjectUnsupported: If the server responds with 403 and the error
                mentions that the plan does not support the request.
            ObjectAlreadyExists: If the work pool already exists and
                `overwrite` is `False`.
        """
        from prefect.client.schemas.actions import WorkPoolUpdate
        from prefect.client.schemas.objects import WorkPool

        try:
            response = self.request(
                "POST",
                "/work_pools/",
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
            response.raise_for_status()
        except HTTPStatusError as e:
            if e.response.status_code == 403 and "plan does not support" in str(e):
                raise ObjectUnsupported(http_exc=e) from e
            if e.response.status_code == 409:
                if overwrite:
                    existing_work_pool = self.read_work_pool(
                        work_pool_name=work_pool.name
                    )
                    if existing_work_pool.type != work_pool.type:
                        warnings.warn(
                            "Overwriting work pool type is not supported. Ignoring provided type.",
                            category=UserWarning,
                        )
                    # Name identifies the pool and type is immutable, so
                    # neither is part of the update payload.
                    self.update_work_pool(
                        work_pool_name=work_pool.name,
                        work_pool=WorkPoolUpdate.model_validate(
                            work_pool.model_dump(exclude={"name", "type"})
                        ),
                    )
                    # Re-read so the returned object reflects the stored state.
                    response = self.request(
                        "GET",
                        "/work_pools/{name}",
                        path_params={"name": work_pool.name},
                    )
                else:
                    raise ObjectAlreadyExists(http_exc=e) from e
            else:
                raise

        return WorkPool.model_validate(response.json())

    def update_work_pool(
        self,
        work_pool_name: str,
        work_pool: "WorkPoolUpdate",
    ) -> None:
        """
        Updates a work pool.

        Only fields explicitly set on `work_pool` are sent to the server.

        Args:
            work_pool_name: Name of the work pool to update.
            work_pool: Fields to update in the work pool.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        try:
            self.request(
                "PATCH",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    def delete_work_pool(
        self,
        work_pool_name: str,
    ) -> None:
        """
        Deletes a work pool.

        Args:
            work_pool_name: Name of the work pool to delete.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        try:
            self.request(
                "DELETE",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    def get_scheduled_flow_runs_for_work_pool(
        self,
        work_pool_name: str,
        work_queue_names: list[str] | None = None,
        scheduled_before: datetime | None = None,
    ) -> list["WorkerFlowRunResponse"]:
        """
        Retrieves scheduled flow runs for the provided set of work pool queues.

        Args:
            work_pool_name: The name of the work pool that the work pool
                queues are associated with.
            work_queue_names: The names of the work pool queues from which
                to get scheduled flow runs.
            scheduled_before: Datetime used to filter returned flow runs. Flow runs
                scheduled for after the given datetime will not be returned.

        Returns:
            A list of worker flow run responses containing information about the
            retrieved flow runs.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        from prefect.client.schemas.responses import WorkerFlowRunResponse

        body: dict[str, Any] = {}
        if work_queue_names is not None:
            body["work_queue_names"] = list(work_queue_names)
        if scheduled_before:
            body["scheduled_before"] = str(scheduled_before)

        try:
            response = self.request(
                "POST",
                "/work_pools/{name}/get_scheduled_flow_runs",
                path_params={"name": work_pool_name},
                json=body,
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

        return WorkerFlowRunResponse.model_validate_list(response.json())

318 

319 

class WorkPoolAsyncClient(BaseAsyncClient):
    """Work pool and worker operations for the asynchronous client.

    Each method awaits an HTTP request through ``self.request`` and
    translates selected HTTP error statuses into Prefect exceptions.
    """

    async def send_worker_heartbeat(
        self,
        work_pool_name: str,
        worker_name: str,
        heartbeat_interval_seconds: float | None = None,
        get_worker_id: bool = False,
        worker_metadata: "WorkerMetadata | None" = None,
    ) -> "UUID | None":
        """
        Sends a worker heartbeat for a given work pool.

        Args:
            work_pool_name: The name of the work pool to heartbeat against.
            worker_name: The name of the worker sending the heartbeat.
            heartbeat_interval_seconds: Heartbeat interval, in seconds,
                reported to the server.
            get_worker_id: Whether to ask the server to return the worker ID.
                Note: `None` is still returned when the ID cannot be obtained
                (see Returns), even if `get_worker_id` is `True`.
            worker_metadata: Metadata about the worker to send to the server.

        Returns:
            The worker ID parsed from the response body when `get_worker_id`
            is `True`, the server is Prefect Cloud (or test mode is enabled),
            and the response status is 200; otherwise `None`.
        """
        from uuid import UUID

        params: dict[str, Any] = {
            "name": worker_name,
            "heartbeat_interval_seconds": heartbeat_interval_seconds,
        }
        if worker_metadata:
            params["metadata"] = worker_metadata.model_dump(mode="json")
        if get_worker_id:
            # Only sent when requested so servers unaware of the flag are unaffected.
            params["return_id"] = get_worker_id

        resp = await self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/heartbeat",
            path_params={"work_pool_name": work_pool_name},
            json=params,
        )
        from prefect.settings import get_current_settings

        # Worker IDs are only returned by Cloud servers (or in test mode).
        if (
            (
                self.server_type == ServerType.CLOUD
                or get_current_settings().testing.test_mode
            )
            and get_worker_id
            and resp.status_code == 200
        ):
            return UUID(resp.text)
        else:
            return None

    async def read_workers_for_work_pool(
        self,
        work_pool_name: str,
        worker_filter: "WorkerFilter | None" = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list["Worker"]:
        """
        Reads workers for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                member workers.
            worker_filter: Criteria by which to filter workers.
            offset: Offset for the worker query.
            limit: Limit for the worker query.

        Returns:
            The workers returned by the server for the work pool.
        """
        from prefect.client.schemas.objects import Worker

        response = await self.request(
            "POST",
            "/work_pools/{work_pool_name}/workers/filter",
            path_params={"work_pool_name": work_pool_name},
            json={
                "workers": (
                    worker_filter.model_dump(mode="json", exclude_unset=True)
                    if worker_filter
                    else None
                ),
                "offset": offset,
                "limit": limit,
            },
        )

        return Worker.model_validate_list(response.json())

    async def read_work_pool(self, work_pool_name: str) -> "WorkPool":
        """
        Reads information for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                information.

        Returns:
            Information about the requested work pool.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        from prefect.client.schemas.objects import WorkPool

        try:
            response = await self.request(
                "GET",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
            return WorkPool.model_validate(response.json())
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    async def read_work_pools(
        self,
        limit: int | None = None,
        offset: int = 0,
        work_pool_filter: "WorkPoolFilter | None" = None,
    ) -> list["WorkPool"]:
        """
        Reads work pools.

        Args:
            limit: Limit for the work pool query.
            offset: Offset for the work pool query.
            work_pool_filter: Criteria by which to filter work pools.

        Returns:
            A list of work pools.
        """
        from prefect.client.schemas.objects import WorkPool

        body: dict[str, Any] = {
            "limit": limit,
            "offset": offset,
            "work_pools": (
                work_pool_filter.model_dump(mode="json") if work_pool_filter else None
            ),
        }
        response = await self.request("POST", "/work_pools/filter", json=body)
        return WorkPool.model_validate_list(response.json())

    async def create_work_pool(
        self,
        work_pool: "WorkPoolCreate",
        overwrite: bool = False,
    ) -> "WorkPool":
        """
        Creates a work pool with the provided configuration.

        Args:
            work_pool: Desired configuration for the new work pool.
            overwrite: If `True` and a work pool with the same name already
                exists (HTTP 409), update the existing pool with the provided
                configuration instead of raising. The pool type cannot be
                changed; a `UserWarning` is emitted if it differs.

        Returns:
            Information about the newly created (or overwritten) work pool.

        Raises:
            ObjectUnsupported: If the server responds with 403 and the error
                mentions that the plan does not support the request.
            ObjectAlreadyExists: If the work pool already exists and
                `overwrite` is `False`.
        """
        from prefect.client.schemas.actions import WorkPoolUpdate
        from prefect.client.schemas.objects import WorkPool

        try:
            response = await self.request(
                "POST",
                "/work_pools/",
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
        except HTTPStatusError as e:
            # Mirror the synchronous client: surface plan restrictions as a
            # dedicated exception rather than a raw HTTP error.
            if e.response.status_code == 403 and "plan does not support" in str(e):
                raise ObjectUnsupported(http_exc=e) from e
            if e.response.status_code == 409:
                if overwrite:
                    existing_work_pool = await self.read_work_pool(
                        work_pool_name=work_pool.name
                    )
                    if existing_work_pool.type != work_pool.type:
                        warnings.warn(
                            "Overwriting work pool type is not supported. Ignoring provided type.",
                            category=UserWarning,
                        )
                    # Name identifies the pool and type is immutable, so
                    # neither is part of the update payload.
                    await self.update_work_pool(
                        work_pool_name=work_pool.name,
                        work_pool=WorkPoolUpdate.model_validate(
                            work_pool.model_dump(exclude={"name", "type"})
                        ),
                    )
                    # Re-read so the returned object reflects the stored state.
                    response = await self.request(
                        "GET",
                        "/work_pools/{name}",
                        path_params={"name": work_pool.name},
                    )
                else:
                    raise ObjectAlreadyExists(http_exc=e) from e
            else:
                raise

        return WorkPool.model_validate(response.json())

    async def update_work_pool(
        self,
        work_pool_name: str,
        work_pool: "WorkPoolUpdate",
    ) -> None:
        """
        Updates a work pool.

        Only fields explicitly set on `work_pool` are sent to the server.

        Args:
            work_pool_name: Name of the work pool to update.
            work_pool: Fields to update in the work pool.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        try:
            await self.request(
                "PATCH",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
                json=work_pool.model_dump(mode="json", exclude_unset=True),
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    async def delete_work_pool(
        self,
        work_pool_name: str,
    ) -> None:
        """
        Deletes a work pool.

        Args:
            work_pool_name: Name of the work pool to delete.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        try:
            await self.request(
                "DELETE",
                "/work_pools/{name}",
                path_params={"name": work_pool_name},
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

    async def get_scheduled_flow_runs_for_work_pool(
        self,
        work_pool_name: str,
        work_queue_names: list[str] | None = None,
        scheduled_before: datetime | None = None,
    ) -> list["WorkerFlowRunResponse"]:
        """
        Retrieves scheduled flow runs for the provided set of work pool queues.

        Args:
            work_pool_name: The name of the work pool that the work pool
                queues are associated with.
            work_queue_names: The names of the work pool queues from which
                to get scheduled flow runs.
            scheduled_before: Datetime used to filter returned flow runs. Flow runs
                scheduled for after the given datetime will not be returned.

        Returns:
            A list of worker flow run responses containing information about the
            retrieved flow runs.

        Raises:
            ObjectNotFound: If the server responds with 404.
        """
        from prefect.client.schemas.responses import WorkerFlowRunResponse

        body: dict[str, Any] = {}
        if work_queue_names is not None:
            body["work_queue_names"] = list(work_queue_names)
        if scheduled_before:
            body["scheduled_before"] = str(scheduled_before)

        try:
            response = await self.request(
                "POST",
                "/work_pools/{name}/get_scheduled_flow_runs",
                path_params={"name": work_pool_name},
                json=body,
            )
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise

        return WorkerFlowRunResponse.model_validate_list(response.json())