Coverage for /usr/local/lib/python3.12/site-packages/prefect/infrastructure/provisioners/ecs.py: 0%

454 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 

2 

3import base64 

4import contextlib 

5import contextvars 

6import importlib 

7import ipaddress 

8import json 

9from copy import deepcopy 

10from functools import partial 

11from textwrap import dedent 

12from types import ModuleType 

13from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional 

14 

15import anyio 

16import anyio.to_thread 

17from anyio import run_process 

18from rich.console import Console 

19from rich.panel import Panel 

20from rich.progress import Progress, SpinnerColumn, TextColumn 

21from rich.prompt import Confirm 

22from rich.syntax import Syntax 

23 

24from prefect._internal.installation import ainstall_packages 

25from prefect.client.schemas.actions import BlockDocumentCreate 

26from prefect.client.utilities import inject_client 

27from prefect.exceptions import ObjectNotFound 

28from prefect.settings import ( 

29 PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE, 

30 update_current_profile, 

31) 

32from prefect.utilities.collections import get_from_dict 

33from prefect.utilities.importtools import lazy_import 

34 

35if TYPE_CHECKING: 

36 from prefect.client.orchestration import PrefectClient 

37 

# boto3 is loaded through `lazy_import`; presumably the real import is deferred
# until first attribute access (see prefect.utilities.importtools).
boto3: ModuleType = lazy_import("boto3")

# Console that provisioning steps print to. Swap it temporarily with
# `console_context`.
current_console: contextvars.ContextVar[Console] = contextvars.ContextVar(
    "console",
    default=Console(),
)

43 

44 

@contextlib.contextmanager
def console_context(value: Console) -> Generator[None, None, None]:
    """
    Temporarily make `value` the console returned by `current_console.get()`.

    The previously active console is restored when the `with` block exits,
    whether it exits normally or through an exception.

    Args:
        value: The console to install for the duration of the context.
    """
    previous = current_console.set(value)
    try:
        yield None
    finally:
        current_console.reset(previous)

52 

53 

class IamPolicyResource:
    """
    Represents an IAM policy resource for managing ECS tasks.

    Args:
        policy_name: The name of the IAM policy. Defaults to "prefect-ecs-policy".
    """

    def __init__(
        self,
        policy_name: str = "prefect-ecs-policy",
    ):
        self._iam_client = boto3.client("iam")
        self._policy_name = policy_name

        # Cached result of `requires_provisioning`; None until the first check.
        self._requires_provisioning: Optional[bool] = None

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 1 if await self.requires_provisioning() else 0

    def _get_policy_by_name(self, name: str) -> dict[str, Any] | None:
        """
        Find a customer-managed IAM policy by name.

        Pages through `list_policies` restricted to `Scope="Local"` and returns
        the first entry whose `PolicyName` equals `name`.

        Args:
            name: The policy name to look for.

        Returns:
            The matching entry from the `list_policies` response, or None if
            no such policy exists.
        """
        paginator = self._iam_client.get_paginator("list_policies")
        for page in paginator.paginate(Scope="Local"):
            for policy in page["Policies"]:
                if policy["PolicyName"] == name:
                    return policy
        return None

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        The blocking boto3 lookup runs in a worker thread and its result is
        cached on the instance.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is None:
            policy = await anyio.to_thread.run_sync(
                partial(self._get_policy_by_name, self._policy_name)
            )
            self._requires_provisioning = policy is None
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions, or an empty list if
            provisioning is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating and attaching an IAM policy for managing ECS tasks:"
                f" [blue]{self._policy_name}[/]"
            ]
        return []

    async def provision(
        self,
        policy_document: dict[str, Any],
        advance: Callable[[], None],
    ) -> str:
        """
        Provisions an IAM policy.

        Creates the policy from `policy_document` if it does not exist yet;
        otherwise looks up the existing policy and returns its ARN without
        modifying it.

        Args:
            policy_document: The IAM policy document used when creating the
                policy. Ignored if the policy already exists.
            advance: A callback function to indicate progress. Called once, only
                when the policy is created.

        Returns:
            str: The ARN (Amazon Resource Name) of the IAM policy.

        Raises:
            RuntimeError: If the policy was reported as existing but cannot be
                found anymore.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Creating IAM policy")
            policy = await anyio.to_thread.run_sync(
                partial(
                    self._iam_client.create_policy,
                    PolicyName=self._policy_name,
                    PolicyDocument=json.dumps(policy_document),
                )
            )
            policy_arn = policy["Policy"]["Arn"]
            advance()
            return policy_arn

        policy = await anyio.to_thread.run_sync(
            partial(self._get_policy_by_name, self._policy_name)
        )
        # Should not happen, since the policy was found when checking whether
        # provisioning was required. Raise explicitly instead of using
        # `assert`, which is stripped when Python runs with -O.
        if policy is None:
            raise RuntimeError(
                f"Could not find expected IAM policy {self._policy_name!r}"
            )
        return policy["Arn"]

    @property
    def next_steps(self) -> list[str]:
        """This resource has no follow-up steps for the user."""
        return []

162 

163 

class IamUserResource:
    """
    Represents an IAM user resource for managing ECS tasks.

    Args:
        user_name: The desired name of the IAM user.
    """

    def __init__(self, user_name: str):
        self._iam_client = boto3.client("iam")
        self._user_name = user_name
        # Memoized answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Count the provisioning tasks this resource contributes.

        Returns:
            int: 1 when the IAM user has to be created, otherwise 0.
        """
        needs_user = await self.requires_provisioning()
        return 1 if needs_user else 0

    async def requires_provisioning(self) -> bool:
        """
        Report whether the IAM user is missing and must be created.

        Calls `get_user` in a worker thread once and caches the outcome.

        Returns:
            bool: True if `get_user` raised `NoSuchEntityException`, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        lookup = partial(self._iam_client.get_user, UserName=self._user_name)
        try:
            await anyio.to_thread.run_sync(lookup)
        except self._iam_client.exceptions.NoSuchEntityException:
            self._requires_provisioning = True
        else:
            self._requires_provisioning = False
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what provisioning would do.

        Returns:
            List[str]: A single description when the user must be created,
            otherwise an empty list.
        """
        if not await self.requires_provisioning():
            return []
        return [
            f"Creating an IAM user for managing ECS tasks: [blue]{self._user_name}[/]"
        ]

    async def provision(
        self,
        advance: Callable[[], None],
    ) -> None:
        """
        Create the IAM user when it does not exist yet.

        Args:
            advance: Progress callback, invoked once after the user is created.
        """
        if not await self.requires_provisioning():
            return
        current_console.get().print("Provisioning IAM user")
        create = partial(self._iam_client.create_user, UserName=self._user_name)
        await anyio.to_thread.run_sync(create)
        advance()

    @property
    def next_steps(self) -> list[str]:
        """No follow-up steps are needed for this resource."""
        return []

240 

241 

class CredentialsBlockResource:
    """
    Represents the `aws-credentials` block holding an access key for the ECS IAM user.

    Args:
        user_name: Name of the IAM user an access key is generated for.
        block_document_name: Name of the `aws-credentials` block document.
    """

    def __init__(self, user_name: str, block_document_name: str):
        self._user_name = user_name
        self._block_document_name = block_document_name
        # Memoized answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Count the provisioning tasks this resource contributes.

        Returns:
            int: 2 (generate key, create block) when the block is missing, otherwise 0.
        """
        needs_block = await self.requires_provisioning()
        return 2 if needs_block else 0

    @inject_client
    async def requires_provisioning(
        self, client: Optional["PrefectClient"] = None
    ) -> bool:
        """
        Report whether the credentials block is missing.

        Args:
            client: Prefect client; injected by `inject_client` when omitted.

        Returns:
            bool: True if reading the block raised `ObjectNotFound`, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        try:
            assert client is not None
            await client.read_block_document_by_name(
                self._block_document_name, "aws-credentials"
            )
        except ObjectNotFound:
            self._requires_provisioning = True
        else:
            self._requires_provisioning = False
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what provisioning would do.

        Returns:
            List[str]: A single description when the block must be created,
            otherwise an empty list.
        """
        if not await self.requires_provisioning():
            return []
        return ["Storing generated AWS credentials in a block"]

    @inject_client
    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
        client: Optional["PrefectClient"] = None,
    ):
        """
        Provisions an AWS credentials block.

        Generates new credentials and stores them in a block if the block does not
        exist yet, then points the `aws_credentials` variable default of the job
        template at the block document.

        Args:
            base_job_template: The base job template; modified in place.
            advance: A callback function to indicate progress.
            client: A Prefect client to use for interacting with the Prefect API.
        """
        assert client is not None, "Client injection failed"
        if await self.requires_provisioning():
            block_doc = await self._create_credentials_block(client, advance)
        else:
            block_doc = await client.read_block_document_by_name(
                self._block_document_name, "aws-credentials"
            )

        reference = {"$ref": {"block_document_id": str(block_doc.id)}}
        base_job_template["variables"]["properties"]["aws_credentials"]["default"] = (
            reference
        )

    async def _create_credentials_block(
        self, client: "PrefectClient", advance: Callable[[], None]
    ) -> Any:
        """Create an IAM access key and store it in a new `aws-credentials` block."""
        console = current_console.get()

        console.print("Generating AWS credentials")
        iam_client = boto3.client("iam")
        key_response = await anyio.to_thread.run_sync(
            partial(iam_client.create_access_key, UserName=self._user_name)
        )
        access_key = key_response["AccessKey"]
        advance()

        console.print("Creating AWS credentials block")
        try:
            block_type = await client.read_block_type_by_slug("aws-credentials")
        except ObjectNotFound as exc:
            raise RuntimeError(
                dedent(
                    """\
                    Unable to find block type "aws-credentials".
                    To register the `aws-credentials` block type, run:

                            pip install prefect-aws
                            prefect blocks register -m prefect_aws

                    """
                )
            ) from exc

        block_schema = await client.get_most_recent_block_schema_for_block_type(
            block_type_id=block_type.id
        )
        assert block_schema is not None, (
            f"Unable to find schema for block type {block_type.slug}"
        )

        credentials = {
            "aws_access_key_id": access_key["AccessKeyId"],
            "aws_secret_access_key": access_key["SecretAccessKey"],
            "region_name": boto3.session.Session().region_name,
        }
        block_doc = await client.create_block_document(
            block_document=BlockDocumentCreate(
                name=self._block_document_name,
                data=credentials,
                block_type_id=block_type.id,
                block_schema_id=block_schema.id,
            )
        )
        advance()
        return block_doc

    @property
    def next_steps(self) -> list[str]:
        """No follow-up steps are needed for this resource."""
        return []

366 

367 

class AuthenticationResource:
    """
    Groups the IAM and credentials resources an ECS work pool needs.

    Provisions, in order: a task execution role, an IAM user, an IAM policy
    (attached to that user), and an `aws-credentials` block for that user.

    Args:
        work_pool_name: Name of the work pool; used to derive the default
            credentials block name.
        user_name: Name of the IAM user.
        policy_name: Name of the IAM policy.
        credentials_block_name: Name of the credentials block. Defaults to
            "<work_pool_name>-aws-credentials".
    """

    def __init__(
        self,
        work_pool_name: str,
        user_name: str = "prefect-ecs-user",
        policy_name: str = "prefect-ecs-policy",
        credentials_block_name: Optional[str] = None,
    ):
        self._user_name = user_name
        self._credentials_block_name = (
            credentials_block_name or f"{work_pool_name}-aws-credentials"
        )
        self._policy_name = policy_name
        # Base policy document; the PassRole statement for the execution role is
        # added per `provision` call because the role ARN is only known then.
        self._policy_document: dict[str, Any] = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Sid": "PrefectEcsPolicy",
                    "Effect": "Allow",
                    "Action": [
                        "ec2:AuthorizeSecurityGroupIngress",
                        "ec2:CreateSecurityGroup",
                        "ec2:CreateTags",
                        "ec2:DescribeNetworkInterfaces",
                        "ec2:DescribeSecurityGroups",
                        "ec2:DescribeSubnets",
                        "ec2:DescribeVpcs",
                        "ecs:CreateCluster",
                        "ecs:DeregisterTaskDefinition",
                        "ecs:DescribeClusters",
                        "ecs:DescribeTaskDefinition",
                        "ecs:DescribeTasks",
                        "ecs:ListAccountSettings",
                        "ecs:ListClusters",
                        "ecs:ListTaskDefinitions",
                        "ecs:RegisterTaskDefinition",
                        "ecs:RunTask",
                        "ecs:StopTask",
                        "ecs:TagResource",
                        "logs:CreateLogStream",
                        "logs:PutLogEvents",
                        "logs:DescribeLogGroups",
                        "logs:GetLogEvents",
                    ],
                    "Resource": "*",
                }
            ],
        }
        self._iam_user_resource = IamUserResource(user_name=user_name)
        self._iam_policy_resource = IamPolicyResource(policy_name=policy_name)
        self._credentials_block_resource = CredentialsBlockResource(
            user_name=user_name, block_document_name=self._credentials_block_name
        )
        self._execution_role_resource = ExecutionRoleResource()

    @property
    def resources(
        self,
    ) -> list[
        "ExecutionRoleResource | IamUserResource | IamPolicyResource | CredentialsBlockResource"
    ]:
        """The sub-resources, in the order they are provisioned."""
        return [
            self._execution_role_resource,
            self._iam_user_resource,
            self._iam_policy_resource,
            self._credentials_block_resource,
        ]

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return sum([await resource.get_task_count() for resource in self.resources])

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Every sub-resource is checked (a list, not a short-circuiting generator),
        so each one caches its own answer.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        return any(
            [await resource.requires_provisioning() for resource in self.resources]
        )

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: The planned actions of all sub-resources, in provisioning
            order; empty if nothing needs to be provisioned.
        """
        return [
            action
            for resource in self.resources
            for action in await resource.get_planned_actions()
        ]

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Provisions the authentication resources.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for.
            advance: A callback function to indicate progress.
        """
        # Provision task execution role
        role_arn = await self._execution_role_resource.provision(
            base_job_template=base_job_template, advance=advance
        )
        # Build this call's policy document without mutating
        # `self._policy_document`: appending to it made repeated calls accumulate
        # duplicate "AllowPassRoleForEcs" statements, and IAM requires statement
        # Sids to be unique within a policy.
        policy_document = {
            **self._policy_document,
            "Statement": [
                *self._policy_document["Statement"],
                {
                    "Sid": "AllowPassRoleForEcs",
                    "Effect": "Allow",
                    "Action": "iam:PassRole",
                    "Resource": role_arn,
                },
            ],
        }
        # Provision the IAM user
        await self._iam_user_resource.provision(advance=advance)
        # Provision the IAM policy
        policy_arn = await self._iam_policy_resource.provision(
            policy_document=policy_document, advance=advance
        )
        # Attach the policy to the user
        if policy_arn:
            iam_client = boto3.client("iam")
            await anyio.to_thread.run_sync(
                partial(
                    iam_client.attach_user_policy,
                    UserName=self._user_name,
                    PolicyArn=policy_arn,
                )
            )
        await self._credentials_block_resource.provision(
            base_job_template=base_job_template,
            advance=advance,
        )

    @property
    def next_steps(self) -> list[str]:
        """Follow-up steps collected from all sub-resources."""
        return [
            next_step
            for resource in self.resources
            for next_step in resource.next_steps
        ]

524 

525 

class ClusterResource:
    """
    Represents the ECS cluster that Prefect flow runs are launched in.

    Args:
        cluster_name: Name of the ECS cluster to reuse or create.
    """

    def __init__(self, cluster_name: str = "prefect-ecs-cluster"):
        self._ecs_client = boto3.client("ecs")
        self._cluster_name = cluster_name
        # Memoized answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Count the provisioning tasks this resource contributes.

        Returns:
            int: 1 when the cluster has to be created, otherwise 0.
        """
        needs_cluster = await self.requires_provisioning()
        return 1 if needs_cluster else 0

    async def requires_provisioning(self) -> bool:
        """
        Report whether an ACTIVE cluster with the configured name is missing.

        Calls `describe_clusters` in a worker thread once and caches the outcome.

        Returns:
            bool: False if the first described cluster is ACTIVE, True otherwise.
        """
        if self._requires_provisioning is None:
            describe = partial(
                self._ecs_client.describe_clusters, clusters=[self._cluster_name]
            )
            described = await anyio.to_thread.run_sync(describe)
            clusters = described["clusters"]
            is_active = bool(clusters) and clusters[0]["status"] == "ACTIVE"
            self._requires_provisioning = not is_active
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what provisioning would do.

        Returns:
            List[str]: A single description when the cluster must be created,
            otherwise an empty list.
        """
        if not await self.requires_provisioning():
            return []
        return [
            "Creating an ECS cluster for running Prefect flows:"
            f" [blue]{self._cluster_name}[/]"
        ]

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Create the ECS cluster if needed and record it in the job template.

        The `cluster` variable default of the job template is always set to the
        cluster name, whether or not the cluster had to be created.

        Args:
            base_job_template: The work pool's base job template; modified in place.
            advance: Progress callback, invoked once after the cluster is created.
        """
        needs_cluster = await self.requires_provisioning()
        if needs_cluster:
            current_console.get().print("Provisioning ECS cluster")
            create = partial(
                self._ecs_client.create_cluster, clusterName=self._cluster_name
            )
            await anyio.to_thread.run_sync(create)
            advance()

        properties = base_job_template["variables"]["properties"]
        properties["cluster"]["default"] = self._cluster_name

    @property
    def next_steps(self) -> list[str]:
        """No follow-up steps are needed for this resource."""
        return []

605 

606 

class VpcResource:
    """
    Represents the VPC that ECS tasks run in.

    Nothing is created when an available default VPC exists or a VPC tagged
    `Name=<vpc_name>` is found. Otherwise a new VPC is created with an internet
    gateway, up to three subnets whose route table sends 0.0.0.0/0 through that
    gateway, and a security group.

    Args:
        vpc_name: Value of the `Name` tag used to find or tag the VPC.
        ecs_security_group_name: Name of the security group created in a new VPC.
    """

    def __init__(
        self,
        vpc_name: str = "prefect-ecs-vpc",
        ecs_security_group_name: str = "prefect-ecs-security-group",
    ):
        self._ec2_client = boto3.client("ec2")
        self._ec2_resource = boto3.resource("ec2")
        self._vpc_name = vpc_name
        # Cached result of `requires_provisioning`; None until the first check.
        self._requires_provisioning = None
        self._ecs_security_group_name = ecs_security_group_name

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 4 if await self.requires_provisioning() else 0

    async def _default_vpc_exists(self):
        """Return True if the account has a default VPC in the `available` state."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        default_vpc = next(
            (
                vpc
                for vpc in response["Vpcs"]
                if vpc["IsDefault"] and vpc["State"] == "available"
            ),
            None,
        )
        return default_vpc is not None

    async def _get_prefect_created_vpc(self):
        """Return the first VPC tagged `Name=<vpc_name>`, or None."""
        vpcs = await anyio.to_thread.run_sync(
            partial(
                self._ec2_resource.vpcs.filter,
                Filters=[{"Name": "tag:Name", "Values": [self._vpc_name]}],
            )
        )
        return next(iter(vpcs), None)

    async def _get_existing_vpc_cidrs(self):
        """Return the primary CIDR block of every VPC returned by `describe_vpcs`."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        return [vpc["CidrBlock"] for vpc in response["Vpcs"]]

    @staticmethod
    def _first_free_cidr(
        existing_cidrs: list[str], default_cidr: str = "172.31.0.0/16"
    ) -> str:
        """
        Return the first IPv4 block, starting at `default_cidr`, that overlaps none
        of `existing_cidrs`.

        Candidates are consecutive blocks with the same prefix length as
        `default_cidr`.

        Raises:
            RuntimeError: If the search runs past the end of the IPv4 space.
        """
        taken = [ipaddress.ip_network(cidr) for cidr in existing_cidrs]
        candidate = ipaddress.ip_network(default_cidr)
        while any(candidate.overlaps(network) for network in taken):
            # Step to the adjacent block of the same size.
            next_address = int(candidate.network_address) + candidate.num_addresses
            try:
                candidate = ipaddress.ip_network(
                    f"{ipaddress.IPv4Address(next_address)}/{candidate.prefixlen}"
                )
            except ValueError:
                raise RuntimeError(
                    "Unable to find a non-overlapping CIDR block in the default"
                    " range"
                ) from None
        return str(candidate)

    async def _find_non_overlapping_cidr(
        self, default_cidr: str = "172.31.0.0/16"
    ) -> str:
        """Find a CIDR block that does not overlap any existing VPC's primary CIDR."""
        existing_cidrs = await self._get_existing_vpc_cidrs()
        return self._first_free_cidr(existing_cidrs, default_cidr)

    @staticmethod
    def _plan_subnets(
        vpc_cidr: str, zones: list[str], max_subnets: int = 3
    ) -> list[tuple[str, str]]:
        """
        Pair subnet CIDRs with availability zones.

        The VPC block is split into quarters (prefix length + 2), and one subnet
        is planned per zone, up to `max_subnets`. Fewer zones yield fewer subnets.

        Returns:
            A list of `(subnet_cidr, zone_name)` pairs.
        """
        vpc_network = ipaddress.ip_network(vpc_cidr)
        subnet_cidrs = vpc_network.subnets(new_prefix=vpc_network.prefixlen + 2)
        return [
            (str(subnet_cidr), zone)
            for subnet_cidr, zone in zip(subnet_cidrs, zones[:max_subnets])
        ]

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Returns:
            bool: True if neither an available default VPC nor a VPC tagged with
            the configured name exists, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        if await self._default_vpc_exists():
            self._requires_provisioning = False
            return False

        if await self._get_prefect_created_vpc() is not None:
            self._requires_provisioning = False
            return False

        self._requires_provisioning = True
        return True

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions, or an empty list if
            provisioning is not required.
        """
        if await self.requires_provisioning():
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            return [
                f"Creating a VPC with CIDR [blue]{new_vpc_cidr}[/] for running"
                f" ECS tasks: [blue]{self._vpc_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Provisions a VPC.

        Chooses a CIDR block to avoid conflicting with any existing VPCs. Will update
        the `vpc_id` variable in the job template to reference the VPC.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for.
            advance: A callback function to indicate progress.

        Raises:
            RuntimeError: If no CIDR block or no availability zone is available.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning VPC")
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            vpc = await anyio.to_thread.run_sync(
                partial(self._ec2_resource.create_vpc, CidrBlock=new_vpc_cidr)
            )
            await anyio.to_thread.run_sync(vpc.wait_until_available)
            await anyio.to_thread.run_sync(
                partial(
                    vpc.create_tags,
                    Resources=[vpc.id],
                    Tags=[
                        {
                            "Key": "Name",
                            "Value": self._vpc_name,
                        },
                    ],
                )
            )
            advance()

            console.print("Creating internet gateway")
            internet_gateway = await anyio.to_thread.run_sync(
                self._ec2_resource.create_internet_gateway
            )
            await anyio.to_thread.run_sync(
                partial(
                    vpc.attach_internet_gateway, InternetGatewayId=internet_gateway.id
                )
            )
            advance()

            console.print("Setting up subnets")
            azs = (
                await anyio.to_thread.run_sync(
                    self._ec2_client.describe_availability_zones
                )
            )["AvailabilityZones"]
            zones = [az["ZoneName"] for az in azs]
            # Some regions expose fewer than three zones to an account; plan one
            # subnet per zone instead of indexing zones[0..2] unconditionally.
            subnet_plan = self._plan_subnets(new_vpc_cidr, zones)
            if not subnet_plan:
                raise RuntimeError(
                    "No availability zones found in the current region; unable to"
                    " create subnets for the VPC"
                )
            subnets: list[Any] = []
            for subnet_cidr, zone in subnet_plan:
                subnets.append(
                    await anyio.to_thread.run_sync(
                        partial(
                            vpc.create_subnet,
                            CidrBlock=subnet_cidr,
                            AvailabilityZone=zone,
                        )
                    )
                )

            # Create a Route Table for the public subnets and add a route to the Internet Gateway
            public_route_table = await anyio.to_thread.run_sync(vpc.create_route_table)
            await anyio.to_thread.run_sync(
                partial(
                    public_route_table.create_route,
                    DestinationCidrBlock="0.0.0.0/0",
                    GatewayId=internet_gateway.id,
                )
            )
            for subnet in subnets:
                await anyio.to_thread.run_sync(
                    partial(
                        public_route_table.associate_with_subnet, SubnetId=subnet.id
                    )
                )
            advance()

            console.print("Setting up security group")
            # Create a security group to block all inbound traffic
            await anyio.to_thread.run_sync(
                partial(
                    self._ec2_resource.create_security_group,
                    GroupName=self._ecs_security_group_name,
                    Description=(
                        "Block all inbound traffic and allow all outbound traffic"
                    ),
                    VpcId=vpc.id,
                )
            )
            advance()
        else:
            # May be None when the default VPC is what made provisioning unnecessary.
            vpc = await self._get_prefect_created_vpc()

        if vpc is not None:
            base_job_template["variables"]["properties"]["vpc_id"]["default"] = str(
                vpc.id
            )

    @property
    def next_steps(self) -> list[str]:
        """This resource has no follow-up steps for the user."""
        return []

843 

844 

class ContainerRepositoryResource:
    """
    Represents the ECR repository that flow images are pushed to.

    Provisioning also runs `docker login` against the registry and sets the
    `PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE` setting in the current profile.

    Args:
        work_pool_name: Work pool name used in the example deploy script.
        repository_name: Name of the ECR repository to reuse or create.
    """

    def __init__(self, work_pool_name: str, repository_name: str = "prefect-flows"):
        self._ecr_client = boto3.client("ecr")
        self._repository_name = repository_name
        # Cached result of `requires_provisioning`; None until the first check.
        self._requires_provisioning = None
        self._work_pool_name = work_pool_name
        self._next_steps: list[str | Panel] = []

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 3 if await self.requires_provisioning() else 0

    @staticmethod
    def _first_repository(response: dict[str, Any]) -> dict[str, Any] | None:
        """Return the first repository in a `describe_repositories` response, or None."""
        return next(iter(response.get("repositories", [])), None)

    @staticmethod
    def _decode_ecr_credentials(authorization_token: str) -> tuple[str, str]:
        """
        Decode an ECR authorization token into `(user, password)`.

        The token is base64 of `user:password`; only the first ':' separates
        the two parts.

        Raises:
            ValueError: If the decoded token contains no ':'.
        """
        user, password = (
            base64.b64decode(authorization_token).decode().split(":", 1)
        )
        return user, password

    async def _get_prefect_created_registry(self) -> dict[str, Any] | None:
        """Return the description of the configured repository, or None if missing."""
        try:
            response = await anyio.to_thread.run_sync(
                partial(
                    self._ecr_client.describe_repositories,
                    repositoryNames=[self._repository_name],
                )
            )
        except self._ecr_client.exceptions.RepositoryNotFoundException:
            return None
        # Iterating the response dict itself would yield its keys, not the
        # repositories.
        return self._first_repository(response)

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        if await self._get_prefect_created_registry() is not None:
            self._requires_provisioning = False
            return False

        self._requires_provisioning = True
        return True

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions, or an empty list if
            provisioning is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating an ECR repository for storing Prefect images:"
                f" [blue]{self._repository_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Provisions an ECR repository.

        Creates the repository, logs Docker in to the registry, and sets the
        default Docker build namespace in the current Prefect profile. Does
        nothing if the repository already exists.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for. Not modified by this resource.
            advance: A callback function to indicate progress.
        """
        if not await self.requires_provisioning():
            return

        console = current_console.get()
        console.print("Provisioning ECR repository")
        response = await anyio.to_thread.run_sync(
            partial(
                self._ecr_client.create_repository,
                repositoryName=self._repository_name,
            )
        )
        advance()

        console.print("Authenticating with ECR")
        # Blocking boto3 call: run it in a worker thread like the other calls.
        auth_token = await anyio.to_thread.run_sync(
            self._ecr_client.get_authorization_token
        )
        authorization_data = auth_token["authorizationData"][0]
        user, password = self._decode_ecr_credentials(
            authorization_data["authorizationToken"]
        )
        proxy_endpoint = authorization_data["proxyEndpoint"]
        # Pass the argv as a list (no shell) and the password on stdin so the
        # secret does not appear in the process list.
        await run_process(
            ["docker", "login", "--username", user, "--password-stdin", proxy_endpoint],
            input=password.encode(),
        )
        advance()

        console.print("Setting default Docker build namespace")
        # The registry host is the part of the repository URI before the first '/'.
        namespace = response["repository"]["repositoryUri"].split("/")[0]
        update_current_profile({PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE: namespace})
        self._update_next_steps(namespace)
        advance()

    def _update_next_steps(self, repository_uri: str):
        """Append the user-facing instructions and example deploy script to next steps."""
        self._next_steps.extend(
            [
                dedent(
                    f"""\

                    Your default Docker build namespace has been set to [blue]{repository_uri!r}[/].

                    To build and push a Docker image to your newly created repository, use [blue]{self._repository_name!r}[/] as your image name:
                    """
                ),
                Panel(
                    Syntax(
                        dedent(
                            f"""\
                            from prefect import flow
                            from prefect.docker import DockerImage


                            @flow(log_prints=True)
                            def my_flow(name: str = "world"):
                                print(f"Hello {{name}}! I'm a flow running on ECS!")


                            if __name__ == "__main__":
                                my_flow.deploy(
                                    name="my-deployment",
                                    work_pool_name="{self._work_pool_name}",
                                    image=DockerImage(
                                        name="{self._repository_name}:latest",
                                        platform="linux/amd64",
                                    )
                                )"""
                        ),
                        "python",
                        background_color="default",
                    ),
                    title="example_deploy_script.py",
                    expand=False,
                ),
            ]
        )

    @property
    def next_steps(self) -> list[str | Panel]:
        """Follow-up instructions gathered during provisioning (empty until then)."""
        return self._next_steps

993 

994 

class ExecutionRoleResource:
    """
    Represents the IAM role that ECS tasks run under (the task execution role).

    When created, the role trusts the `ecs-tasks.amazonaws.com` service principal
    and has the AWS-managed `AmazonECSTaskExecutionRolePolicy` attached.

    Args:
        execution_role_name: The name of the IAM role. Defaults to
            "PrefectEcsTaskExecutionRole".
    """

    def __init__(self, execution_role_name: str = "PrefectEcsTaskExecutionRole"):
        self._iam_client = boto3.client("iam")
        self._execution_role_name = execution_role_name
        # Trust policy that lets the ECS tasks service principal assume the role.
        self._trust_policy_document = json.dumps(
            {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "ecs-tasks.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    }
                ],
            }
        )
        # Cached result of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning: Optional[bool] = None

    async def _run_in_thread(self, func: Callable[..., Any], **kwargs: Any) -> Any:
        """
        Run a blocking boto3 client call in a worker thread and return its result.
        """
        return await anyio.to_thread.run_sync(partial(func, **kwargs))

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: 1 if the role needs to be created, 0 otherwise.
        """
        return 1 if await self.requires_provisioning() else 0

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        The IAM lookup is performed once; later calls reuse the cached answer.

        Returns:
            bool: True if the execution role does not exist yet, False otherwise.
        """
        if self._requires_provisioning is None:
            try:
                await self._run_in_thread(
                    self._iam_client.get_role, RoleName=self._execution_role_name
                )
            except self._iam_client.exceptions.NoSuchEntityException:
                self._requires_provisioning = True
            else:
                self._requires_provisioning = False

        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: One human-readable description per planned action, or an
            empty list if provisioning is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating an IAM role assigned to ECS tasks:"
                f" [blue]{self._execution_role_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> str:
        """
        Provisions an IAM role (if needed) and records its ARN in the job template.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for. Its `execution_role_arn` variable default is
                set in place to the role's ARN.
            advance: A callback function to indicate progress. Called once, and
                only when the role is actually created.

        Returns:
            str: The ARN of the (new or existing) execution role.

        Raises:
            Any error raised by the boto3 IAM client is propagated unchanged.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning execution role")
            response = await self._run_in_thread(
                self._iam_client.create_role,
                RoleName=self._execution_role_name,
                Description="Role for ECS tasks to access ECR and other resources.",
                AssumeRolePolicyDocument=self._trust_policy_document,
            )
            await self._run_in_thread(
                self._iam_client.attach_role_policy,
                RoleName=self._execution_role_name,
                PolicyArn="arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy",
            )
            # The role exists now. Keep the cached check in sync so a repeated
            # `provision()` looks the role up instead of trying to create it again.
            self._requires_provisioning = False
            advance()
        else:
            response = await self._run_in_thread(
                self._iam_client.get_role, RoleName=self._execution_role_name
            )

        role_arn = response["Role"]["Arn"]
        base_job_template["variables"]["properties"]["execution_role_arn"][
            "default"
        ] = role_arn
        return role_arn

    @property
    def next_steps(self) -> list[str]:
        """This resource produces no follow-up instructions for the user."""
        return []

1102 

1103 

class ElasticContainerServicePushProvisioner:
    """
    An infrastructure provisioner for ECS push work pools.

    It asks each resource (IAM user/policy and credentials block, ECS cluster,
    VPC, ECR repository) whether it needs provisioning, shows the plan, provisions
    the resources, and lets each resource update a copy of the work pool's base
    job template.
    """

    def __init__(self):
        self._console = Console()

    @property
    def console(self) -> Console:
        """The console used for all output and interactive prompts."""
        return self._console

    @console.setter
    def console(self, value: Console) -> None:
        self._console = value

    async def _prompt_boto3_installation(self):
        """Install boto3 and rebind the module-level `boto3` name to the real module."""
        global boto3
        await ainstall_packages(["boto3"])
        boto3 = importlib.import_module("boto3")

    @staticmethod
    def is_boto3_installed() -> bool:
        """
        Check if boto3 is installed.

        Returns:
            bool: True if `boto3` can be imported, False otherwise.
        """
        try:
            importlib.import_module("boto3")
            return True
        except ModuleNotFoundError:
            return False

    def _generate_resources(
        self,
        work_pool_name: str,
        user_name: str = "prefect-ecs-user",
        policy_name: str = "prefect-ecs-policy",
        credentials_block_name: Optional[str] = None,
        cluster_name: str = "prefect-ecs-cluster",
        vpc_name: str = "prefect-ecs-vpc",
        ecs_security_group_name: str = "prefect-ecs-security-group",
        repository_name: str = "prefect-flows",
    ):
        """
        Build the resource objects to inspect and provision, in provisioning order.
        """
        return [
            AuthenticationResource(
                work_pool_name=work_pool_name,
                user_name=user_name,
                policy_name=policy_name,
                credentials_block_name=credentials_block_name,
            ),
            ClusterResource(cluster_name=cluster_name),
            VpcResource(
                vpc_name=vpc_name,
                ecs_security_group_name=ecs_security_group_name,
            ),
            ContainerRepositoryResource(
                work_pool_name=work_pool_name,
                repository_name=repository_name,
            ),
        ]

    def _prompt_for_resource_names(self, work_pool_name: str) -> dict[str, str]:
        """
        Interactively ask for custom resource names and print a summary panel.

        Returns:
            dict: Keyword arguments for `_generate_resources` (everything except
            `work_pool_name`).
        """
        from prefect.cli._prompts import prompt

        # Send prompts to the same console that output and `is_interactive` use.
        ask = partial(prompt, console=self.console)

        user_name = ask(
            "Enter a name for the IAM user (manages ECS tasks)",
            default="prefect-ecs-user",
        )
        policy_name = ask(
            (
                "Enter a name for the IAM policy (defines ECS task execution"
                " and image management permissions)"
            ),
            default="prefect-ecs-policy",
        )
        cluster_name = ask(
            "Enter a name for the ECS cluster (hosts ECS tasks)",
            default="prefect-ecs-cluster",
        )
        credentials_name = ask(
            (
                "Enter a name for the AWS credentials block (stores AWS"
                " credentials for managing ECS tasks)"
            ),
            default=f"{work_pool_name}-aws-credentials",
        )
        vpc_name = ask(
            (
                "Enter a name for the VPC (provides network isolation for ECS"
                " tasks)"
            ),
            default="prefect-ecs-vpc",
        )
        ecs_security_group_name = ask(
            (
                "Enter a name for the ECS security group (controls task network"
                " traffic)"
            ),
            default="prefect-ecs-security-group",
        )
        repository_name = ask(
            (
                "Enter a name for the ECR repository (stores Docker images for"
                " ECS tasks)"
            ),
            default="prefect-flows",
        )

        provision_preview = Panel(
            dedent(
                f"""\
                    Custom names for infrastructure resources for
                    [blue]{work_pool_name}[/]:

                    - IAM user: [blue]{user_name}[/]
                    - IAM policy: [blue]{policy_name}[/]
                    - ECS cluster: [blue]{cluster_name}[/]
                    - AWS credentials block: [blue]{credentials_name}[/]
                    - VPC: [blue]{vpc_name}[/]
                    - ECS security group: [blue]{ecs_security_group_name}[/]
                    - ECR repository: [blue]{repository_name}[/]
                    """
            ),
            expand=False,
        )
        self.console.print(provision_preview)

        return {
            "user_name": user_name,
            "policy_name": policy_name,
            "credentials_block_name": credentials_name,
            "cluster_name": cluster_name,
            "vpc_name": vpc_name,
            "ecs_security_group_name": ecs_security_group_name,
            "repository_name": repository_name,
        }

    async def provision(
        self,
        work_pool_name: str,
        base_job_template: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Provisions the infrastructure for an ECS push work pool.

        Args:
            work_pool_name: The name of the work pool to provision infrastructure for.
            base_job_template: The base job template of the work pool to provision
                infrastructure for. It is not modified.

        Returns:
            dict: An updated copy of the base job template. If the user declines the
            provisioning confirmation, the original template is returned unchanged.

        Raises:
            RuntimeError: If boto3 is missing and is not installed, or if any step of
                provisioning fails. For AWS client errors, the AWS error message
                is used when one is available.
        """
        if not self.is_boto3_installed():
            if self.console.is_interactive and Confirm.ask(
                "boto3 is required to configure your AWS account. Would you like to"
                " install it?",
                console=self.console,
            ):
                await self._prompt_boto3_installation()
            else:
                raise RuntimeError(
                    "boto3 is required to configure your AWS account. Please install it"
                    " and try again."
                )

        try:
            if self.console.is_interactive and Confirm.ask(
                "Would you like to customize the resource names for your"
                " infrastructure? This includes an IAM user, IAM policy, ECS cluster,"
                " VPC, ECS security group, and ECR repository.",
                console=self.console,
            ):
                resources = self._generate_resources(
                    work_pool_name=work_pool_name,
                    **self._prompt_for_resource_names(work_pool_name),
                )
            else:
                resources = self._generate_resources(work_pool_name=work_pool_name)

            with Progress(
                SpinnerColumn(),
                TextColumn(
                    "Checking your AWS account for infrastructure that needs to be"
                    " provisioned..."
                ),
                transient=True,
                console=self.console,
            ) as progress:
                inspect_aws_account_task = progress.add_task(
                    "inspect_aws_account", total=1
                )
                num_tasks = sum(
                    [await resource.get_task_count() for resource in resources]
                )
                progress.update(inspect_aws_account_task, completed=1)

            if num_tasks > 0:
                message = (
                    "Provisioning infrastructure for your work pool"
                    f" [blue]{work_pool_name}[/] will require: \n"
                )
                for resource in resources:
                    for action in await resource.get_planned_actions():
                        message += f"\n\t - {action}"

                self.console.print(Panel(message))

                if self.console.is_interactive and not Confirm.ask(
                    "Proceed with infrastructure provisioning?",
                    console=self.console,
                ):
                    return base_job_template
            else:
                self.console.print(
                    "No additional infrastructure required for work pool"
                    f" [blue]{work_pool_name}[/]"
                )
                # Don't return early: the resource `provision` calls will be no-ops,
                # but they still fill in the base job template.

            base_job_template_copy = deepcopy(base_job_template)
            next_steps: list[str | Panel] = []
            with Progress(console=self.console, disable=num_tasks == 0) as progress:
                task = progress.add_task(
                    "Provisioning Infrastructure",
                    total=num_tasks,
                )
                for resource in resources:
                    # Resources print via `current_console`; point it at the
                    # progress display's console while they run.
                    with console_context(progress.console):
                        await resource.provision(
                            advance=partial(progress.advance, task),
                            base_job_template=base_job_template_copy,
                        )
                    next_steps.extend(resource.next_steps)

            for step in next_steps:
                self.console.print(step)

            if num_tasks > 0:
                self.console.print(
                    "Infrastructure successfully provisioned!", style="green"
                )

            return base_job_template_copy
        except Exception as exc:
            if hasattr(exc, "response"):
                # botocore ClientError exposes the AWS error details on `.response`.
                response = getattr(exc, "response", {})
                error_message = get_from_dict(response, "Error.Message") or str(exc)
                raise RuntimeError(error_message) from exc
            # Any other exception
            raise RuntimeError(str(exc)) from exc