Coverage for /usr/local/lib/python3.12/site-packages/prefect/infrastructure/provisioners/ecs.py: 0%
454 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1from __future__ import annotations
3import base64
4import contextlib
5import contextvars
6import importlib
7import ipaddress
8import json
9from copy import deepcopy
10from functools import partial
11from textwrap import dedent
12from types import ModuleType
13from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional
15import anyio
16import anyio.to_thread
17from anyio import run_process
18from rich.console import Console
19from rich.panel import Panel
20from rich.progress import Progress, SpinnerColumn, TextColumn
21from rich.prompt import Confirm
22from rich.syntax import Syntax
24from prefect._internal.installation import ainstall_packages
25from prefect.client.schemas.actions import BlockDocumentCreate
26from prefect.client.utilities import inject_client
27from prefect.exceptions import ObjectNotFound
28from prefect.settings import (
29 PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE,
30 update_current_profile,
31)
32from prefect.utilities.collections import get_from_dict
33from prefect.utilities.importtools import lazy_import
35if TYPE_CHECKING:
36 from prefect.client.orchestration import PrefectClient
# Deferred via `lazy_import`, presumably so boto3 is only required once an AWS
# client is actually created. `ElasticContainerServicePushProvisioner.
# _prompt_boto3_installation` rebinds this module-level name after installing it.
boto3: ModuleType = lazy_import("boto3")

# Console that the resource classes in this module print progress messages to.
# Defaults to a plain `Console()`; use `console_context` to swap it temporarily.
current_console: contextvars.ContextVar[Console] = contextvars.ContextVar(
    "console", default=Console()
)
@contextlib.contextmanager
def console_context(value: Console) -> Generator[None, None, None]:
    """
    Temporarily install `value` as the console returned by `current_console`.

    The previously active console is restored when the `with` block exits,
    whether it exits normally or by raising.

    Args:
        value: The console to use for the duration of the context.
    """
    with contextlib.ExitStack() as cleanup:
        # Register the reset before yielding so the prior console is restored
        # on every exit path; the callback never suppresses exceptions.
        cleanup.callback(current_console.reset, current_console.set(value))
        yield
class IamPolicyResource:
    """
    Represents an IAM policy resource for managing ECS tasks.

    Args:
        policy_name: The name of the IAM policy. Defaults to "prefect-ecs-policy".
    """

    def __init__(
        self,
        policy_name: str = "prefect-ecs-policy",
    ):
        self._iam_client = boto3.client("iam")
        self._policy_name = policy_name
        # Cached answer of `requires_provisioning`; None until first checked.
        self._requires_provisioning: bool | None = None

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 1 if await self.requires_provisioning() else 0

    def _get_policy_by_name(self, name: str) -> dict[str, Any] | None:
        """
        Look up a policy by name among the policies listed with `Scope="Local"`.

        Pages of `list_policies` are scanned until a policy whose `PolicyName`
        equals `name` is found.

        Args:
            name: The policy name to search for.

        Returns:
            The matching policy description, or None if no policy matches.
        """
        paginator = self._iam_client.get_paginator("list_policies")
        for page in paginator.paginate(Scope="Local"):
            for policy in page["Policies"]:
                if policy["PolicyName"] == name:
                    return policy
        return None

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        The lookup is performed once and the result cached on the instance.

        Returns:
            bool: True if no policy with the configured name exists, False otherwise.
        """
        if self._requires_provisioning is None:
            policy = await anyio.to_thread.run_sync(
                partial(self._get_policy_by_name, self._policy_name)
            )
            self._requires_provisioning = policy is None
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
            is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating and attaching an IAM policy for managing ECS tasks:"
                f" [blue]{self._policy_name}[/]"
            ]
        return []

    async def provision(
        self,
        policy_document: dict[str, Any],
        advance: Callable[[], None],
    ) -> str:
        """
        Provisions an IAM policy.

        If the policy already exists it is left untouched (`policy_document` is
        ignored) and its ARN is returned.

        Args:
            policy_document: The IAM policy document; serialized to JSON when the
                policy is created.
            advance: A callback function to indicate progress.

        Returns:
            str: The ARN (Amazon Resource Name) of the created or existing IAM policy.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Creating IAM policy")
            policy = await anyio.to_thread.run_sync(
                partial(
                    self._iam_client.create_policy,
                    PolicyName=self._policy_name,
                    PolicyDocument=json.dumps(policy_document),
                )
            )
            policy_arn = policy["Policy"]["Arn"]
            advance()
            return policy_arn

        policy = await anyio.to_thread.run_sync(
            partial(self._get_policy_by_name, self._policy_name)
        )
        # `requires_provisioning` already found this policy, so it should exist.
        assert policy is not None, "Could not find expected policy"
        return policy["Arn"]

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions for the user; this resource has none."""
        return []
class IamUserResource:
    """
    An IAM user that Prefect uses to manage ECS tasks.

    Args:
        user_name: The desired name of the IAM user.
    """

    def __init__(self, user_name: str):
        self._iam_client = boto3.client("iam")
        self._user_name = user_name
        # Memoized result of `requires_provisioning` (None = not yet checked).
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Number of progress steps `provision` will report.

        Returns:
            int: 1 when the user must be created, otherwise 0.
        """
        if await self.requires_provisioning():
            return 1
        return 0

    async def requires_provisioning(self) -> bool:
        """
        Determine whether the IAM user still has to be created.

        Returns:
            bool: True when `get_user` reports `NoSuchEntityException`, False when
            the user exists. The answer is cached after the first call.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        lookup_user = partial(self._iam_client.get_user, UserName=self._user_name)
        try:
            await anyio.to_thread.run_sync(lookup_user)
        except self._iam_client.exceptions.NoSuchEntityException:
            self._requires_provisioning = True
        else:
            self._requires_provisioning = False
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what `provision` would do.

        Returns:
            List[str]: One description when the user must be created, else empty.
        """
        if not await self.requires_provisioning():
            return []
        return [f"Creating an IAM user for managing ECS tasks: [blue]{self._user_name}[/]"]

    async def provision(
        self,
        advance: Callable[[], None],
    ) -> None:
        """
        Create the IAM user if it does not exist yet.

        Args:
            advance: Progress callback, invoked once after the user is created.
        """
        console = current_console.get()
        if not await self.requires_provisioning():
            return

        console.print("Provisioning IAM user")
        create_user = partial(self._iam_client.create_user, UserName=self._user_name)
        await anyio.to_thread.run_sync(create_user)
        advance()

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions for the user; none for this resource."""
        return []
class CredentialsBlockResource:
    """
    An `aws-credentials` block holding an access key for the Prefect IAM user.

    Args:
        user_name: Name of the IAM user an access key is generated for.
        block_document_name: Name of the `aws-credentials` block document.
    """

    def __init__(self, user_name: str, block_document_name: str):
        self._block_document_name = block_document_name
        self._user_name = user_name
        # Cached result of `requires_provisioning`; None until first checked.
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned (access key + block).
        """
        return 2 if await self.requires_provisioning() else 0

    @inject_client
    async def requires_provisioning(
        self, client: Optional["PrefectClient"] = None
    ) -> bool:
        """
        Check whether the credentials block still needs to be created.

        Args:
            client: A Prefect client, injected by `inject_client` when omitted.

        Returns:
            bool: True if no `aws-credentials` block document with the configured
            name exists. The result is cached after the first call.
        """
        if self._requires_provisioning is None:
            try:
                assert client is not None
                await client.read_block_document_by_name(
                    self._block_document_name, "aws-credentials"
                )
                self._requires_provisioning = False
            except ObjectNotFound:
                self._requires_provisioning = True
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
            is not required.
        """
        if await self.requires_provisioning():
            return ["Storing generated AWS credentials in a block"]
        return []

    @inject_client
    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
        client: Optional["PrefectClient"] = None,
    ):
        """
        Provisions an AWS credentials block.

        Will generate new credentials if the block does not already exist. Updates
        the `aws_credentials` variable in the job template to reference the block.

        Args:
            base_job_template: The base job template; modified in place.
            advance: A callback function to indicate progress.
            client: A Prefect client to use for interacting with the Prefect API.

        Raises:
            RuntimeError: If the `aws-credentials` block type is not registered.
        """
        assert client is not None, "Client injection failed"
        if not await self.requires_provisioning():
            block_doc = await client.read_block_document_by_name(
                self._block_document_name, "aws-credentials"
            )
        else:
            console = current_console.get()

            # Resolve the block type and schema *before* creating an access key.
            # Doing it afterwards left an orphaned key behind whenever the block
            # type was missing, and IAM allows only two access keys per user, so
            # retries would soon fail outright.
            try:
                credentials_block_type = await client.read_block_type_by_slug(
                    "aws-credentials"
                )
            except ObjectNotFound as exc:
                raise RuntimeError(
                    dedent(
                        """\
                    Unable to find block type "aws-credentials".
                    To register the `aws-credentials` block type, run:

                            pip install prefect-aws
                            prefect blocks register -m prefect_aws

                    """
                    )
                ) from exc

            credentials_block_schema = (
                await client.get_most_recent_block_schema_for_block_type(
                    block_type_id=credentials_block_type.id
                )
            )
            assert credentials_block_schema is not None, (
                f"Unable to find schema for block type {credentials_block_type.slug}"
            )

            console.print("Generating AWS credentials")
            iam_client = boto3.client("iam")
            access_key_data = await anyio.to_thread.run_sync(
                partial(iam_client.create_access_key, UserName=self._user_name)
            )
            access_key = access_key_data["AccessKey"]
            advance()

            console.print("Creating AWS credentials block")
            block_doc = await client.create_block_document(
                block_document=BlockDocumentCreate(
                    name=self._block_document_name,
                    data={
                        "aws_access_key_id": access_key["AccessKeyId"],
                        "aws_secret_access_key": access_key["SecretAccessKey"],
                        "region_name": boto3.session.Session().region_name,
                    },
                    block_type_id=credentials_block_type.id,
                    block_schema_id=credentials_block_schema.id,
                )
            )
            advance()

        base_job_template["variables"]["properties"]["aws_credentials"]["default"] = {
            "$ref": {"block_document_id": str(block_doc.id)}
        }

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions for the user; this resource has none."""
        return []
class AuthenticationResource:
    """
    Groups the IAM role, IAM user, IAM policy and credentials block needed for
    Prefect to manage ECS tasks.

    Args:
        work_pool_name: Name of the work pool; used to derive the default
            credentials block name.
        user_name: Name of the IAM user.
        policy_name: Name of the IAM policy attached to the user.
        credentials_block_name: Name of the `aws-credentials` block. Defaults to
            "<work_pool_name>-aws-credentials".
    """

    def __init__(
        self,
        work_pool_name: str,
        user_name: str = "prefect-ecs-user",
        policy_name: str = "prefect-ecs-policy",
        credentials_block_name: Optional[str] = None,
    ):
        self._user_name = user_name
        self._credentials_block_name = (
            credentials_block_name or f"{work_pool_name}-aws-credentials"
        )
        self._policy_name = policy_name
        # Base policy document. `provision` extends a *copy* of it with a
        # PassRole statement for the execution role.
        self._policy_document: dict[str, Any] = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Sid": "PrefectEcsPolicy",
                    "Effect": "Allow",
                    "Action": [
                        "ec2:AuthorizeSecurityGroupIngress",
                        "ec2:CreateSecurityGroup",
                        "ec2:CreateTags",
                        "ec2:DescribeNetworkInterfaces",
                        "ec2:DescribeSecurityGroups",
                        "ec2:DescribeSubnets",
                        "ec2:DescribeVpcs",
                        "ecs:CreateCluster",
                        "ecs:DeregisterTaskDefinition",
                        "ecs:DescribeClusters",
                        "ecs:DescribeTaskDefinition",
                        "ecs:DescribeTasks",
                        "ecs:ListAccountSettings",
                        "ecs:ListClusters",
                        "ecs:ListTaskDefinitions",
                        "ecs:RegisterTaskDefinition",
                        "ecs:RunTask",
                        "ecs:StopTask",
                        "ecs:TagResource",
                        "logs:CreateLogStream",
                        "logs:PutLogEvents",
                        "logs:DescribeLogGroups",
                        "logs:GetLogEvents",
                    ],
                    "Resource": "*",
                }
            ],
        }
        self._iam_user_resource = IamUserResource(user_name=user_name)
        self._iam_policy_resource = IamPolicyResource(policy_name=policy_name)
        self._credentials_block_resource = CredentialsBlockResource(
            user_name=user_name, block_document_name=self._credentials_block_name
        )
        self._execution_role_resource = ExecutionRoleResource()

    @property
    def resources(
        self,
    ) -> list[
        "ExecutionRoleResource | IamUserResource | IamPolicyResource | CredentialsBlockResource"
    ]:
        """The sub-resources, in the order they are provisioned."""
        return [
            self._execution_role_resource,
            self._iam_user_resource,
            self._iam_policy_resource,
            self._credentials_block_resource,
        ]

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The sum of the sub-resources' task counts.
        """
        return sum([await resource.get_task_count() for resource in self.resources])

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Returns:
            bool: True if any sub-resource requires provisioning, False otherwise.
        """
        # A list (not a generator) so every sub-resource is checked and caches
        # its own answer, rather than short-circuiting at the first True.
        return any(
            [await resource.requires_provisioning() for resource in self.resources]
        )

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: The planned actions of all sub-resources, in order; empty if
            nothing needs provisioning.
        """
        return [
            action
            for resource in self.resources
            for action in await resource.get_planned_actions()
        ]

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Provisions the authentication resources.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; modified in place by the sub-resources.
            advance: A callback function to indicate progress.
        """
        # Provision task execution role
        role_arn = await self._execution_role_resource.provision(
            base_job_template=base_job_template, advance=advance
        )
        # Extend a copy of the policy document with the role ARN. Mutating
        # `self._policy_document` would append a duplicate PassRole statement on
        # every call to `provision`.
        policy_document = deepcopy(self._policy_document)
        policy_document["Statement"].append(
            {
                "Sid": "AllowPassRoleForEcs",
                "Effect": "Allow",
                "Action": "iam:PassRole",
                "Resource": role_arn,
            }
        )
        # Provision the IAM user
        await self._iam_user_resource.provision(advance=advance)
        # Provision the IAM policy
        policy_arn = await self._iam_policy_resource.provision(
            policy_document=policy_document, advance=advance
        )
        # Attach the policy to the user
        if policy_arn:
            iam_client = boto3.client("iam")
            await anyio.to_thread.run_sync(
                partial(
                    iam_client.attach_user_policy,
                    UserName=self._user_name,
                    PolicyArn=policy_arn,
                )
            )
        await self._credentials_block_resource.provision(
            base_job_template=base_job_template,
            advance=advance,
        )

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions collected from all sub-resources."""
        return [
            next_step
            for resource in self.resources
            for next_step in resource.next_steps
        ]
class ClusterResource:
    """
    An ECS cluster that Prefect flow runs are launched into.

    Args:
        cluster_name: Name of the ECS cluster.
    """

    def __init__(self, cluster_name: str = "prefect-ecs-cluster"):
        self._ecs_client = boto3.client("ecs")
        self._cluster_name = cluster_name
        # Memoized result of `requires_provisioning` (None = not yet checked).
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Number of progress steps `provision` will report.

        Returns:
            int: 1 when the cluster must be created, otherwise 0.
        """
        if await self.requires_provisioning():
            return 1
        return 0

    async def requires_provisioning(self) -> bool:
        """
        Determine whether the ECS cluster still has to be created.

        Returns:
            bool: False only when `describe_clusters` returns the cluster with
            status "ACTIVE"; True otherwise. Cached after the first call.
        """
        if self._requires_provisioning is None:
            describe = partial(
                self._ecs_client.describe_clusters, clusters=[self._cluster_name]
            )
            found = (await anyio.to_thread.run_sync(describe))["clusters"]
            is_active = bool(found) and found[0]["status"] == "ACTIVE"
            self._requires_provisioning = not is_active
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what `provision` would do.

        Returns:
            List[str]: One description when the cluster must be created, else empty.
        """
        if not await self.requires_provisioning():
            return []
        return [
            f"Creating an ECS cluster for running Prefect flows: [blue]{self._cluster_name}[/]"
        ]

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Create the ECS cluster if needed and point the job template at it.

        The `cluster` variable's default is always set to the cluster name, even
        when the cluster already existed.

        Args:
            base_job_template: Work pool base job template; modified in place.
            advance: Progress callback, invoked once after the cluster is created.
        """
        if await self.requires_provisioning():
            current_console.get().print("Provisioning ECS cluster")
            create_cluster = partial(
                self._ecs_client.create_cluster, clusterName=self._cluster_name
            )
            await anyio.to_thread.run_sync(create_cluster)
            advance()

        job_variables = base_job_template["variables"]["properties"]
        job_variables["cluster"]["default"] = self._cluster_name

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions for the user; none for this resource."""
        return []
class VpcResource:
    """
    A VPC in which ECS tasks run.

    No VPC is created when the account already has an available default VPC or
    a VPC tagged `Name=<vpc_name>`.

    Args:
        vpc_name: Value of the `Name` tag identifying the Prefect-created VPC.
        ecs_security_group_name: Name of the security group created in a new VPC.
    """

    def __init__(
        self,
        vpc_name: str = "prefect-ecs-vpc",
        ecs_security_group_name: str = "prefect-ecs-security-group",
    ):
        self._ec2_client = boto3.client("ec2")
        self._ec2_resource = boto3.resource("ec2")
        self._vpc_name = vpc_name
        # Cached result of `requires_provisioning`; None until first checked.
        self._requires_provisioning = None
        self._ecs_security_group_name = ecs_security_group_name

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 4 if await self.requires_provisioning() else 0

    async def _default_vpc_exists(self) -> bool:
        """Return True if `describe_vpcs` lists a default VPC in state "available"."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        return any(
            vpc["IsDefault"] and vpc["State"] == "available"
            for vpc in response["Vpcs"]
        )

    async def _get_prefect_created_vpc(self):
        """Return the first VPC tagged `Name=<vpc_name>`, or None if there is none."""

        def _first_matching_vpc():
            vpcs = self._ec2_resource.vpcs.filter(
                Filters=[{"Name": "tag:Name", "Values": [self._vpc_name]}]
            )
            return next(iter(vpcs), None)

        # boto3 resource collections are lazy: the API request is sent when the
        # collection is iterated, so the iteration must also run off the event loop.
        return await anyio.to_thread.run_sync(_first_matching_vpc)

    async def _get_existing_vpc_cidrs(self) -> list[str]:
        """Return the primary `CidrBlock` of every VPC returned by `describe_vpcs`."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        return [vpc["CidrBlock"] for vpc in response["Vpcs"]]

    async def _find_non_overlapping_cidr(
        self, default_cidr: str = "172.31.0.0/16"
    ) -> str:
        """
        Find a CIDR block that does not overlap any existing VPC's primary CIDR.

        Starting at `default_cidr`, candidate networks of the same prefix length
        are tried in ascending address order.

        Args:
            default_cidr: The first candidate network.

        Returns:
            str: The first non-overlapping candidate, e.g. "172.32.0.0/16".

        Raises:
            Exception: If the IPv4 address space is exhausted without a match.
        """
        existing_networks = [
            ipaddress.ip_network(cidr) for cidr in await self._get_existing_vpc_cidrs()
        ]

        candidate = ipaddress.ip_network(default_cidr)
        while any(candidate.overlaps(network) for network in existing_networks):
            # Step to the next adjacent network of the same size.
            next_address = int(candidate.network_address) + candidate.num_addresses
            try:
                candidate = ipaddress.ip_network(
                    f"{ipaddress.IPv4Address(next_address)}/{candidate.prefixlen}"
                )
            except ValueError:
                raise Exception(
                    "Unable to find a non-overlapping CIDR block in the default"
                    " range"
                ) from None
        return str(candidate)

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Returns:
            bool: False if an available default VPC or a Prefect-tagged VPC
            exists, True otherwise. Cached after the first call.
        """
        if self._requires_provisioning is None:
            # `or` short-circuits, so the tag lookup only runs without a default VPC.
            self._requires_provisioning = not (
                await self._default_vpc_exists()
                or await self._get_prefect_created_vpc() is not None
            )
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
            is not required.
        """
        if await self.requires_provisioning():
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            return [
                f"Creating a VPC with CIDR [blue]{new_vpc_cidr}[/] for running"
                f" ECS tasks: [blue]{self._vpc_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Provisions a VPC.

        Chooses a CIDR block to avoid conflicting with any existing VPCs, then
        creates an internet gateway, three subnets routed through it, and a
        security group. Will update the `vpc_id` variable in the job template to
        reference the Prefect-created VPC (left unset when the default VPC is used).

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; modified in place.
            advance: A callback function to indicate progress.

        Raises:
            RuntimeError: If the region reports no availability zones.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning VPC")
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            vpc = await anyio.to_thread.run_sync(
                partial(self._ec2_resource.create_vpc, CidrBlock=new_vpc_cidr)
            )
            await anyio.to_thread.run_sync(vpc.wait_until_available)
            await anyio.to_thread.run_sync(
                partial(
                    vpc.create_tags,
                    Resources=[vpc.id],
                    Tags=[
                        {
                            "Key": "Name",
                            "Value": self._vpc_name,
                        },
                    ],
                )
            )
            advance()

            console.print("Creating internet gateway")
            internet_gateway = await anyio.to_thread.run_sync(
                self._ec2_resource.create_internet_gateway
            )
            await anyio.to_thread.run_sync(
                partial(
                    vpc.attach_internet_gateway, InternetGatewayId=internet_gateway.id
                )
            )
            advance()

            console.print("Setting up subnets")
            vpc_network = ipaddress.ip_network(new_vpc_cidr)
            # Split the VPC into four equal blocks; the first three become subnets.
            subnet_cidrs = list(
                vpc_network.subnets(new_prefix=vpc_network.prefixlen + 2)
            )[:3]

            azs = (
                await anyio.to_thread.run_sync(
                    self._ec2_client.describe_availability_zones
                )
            )["AvailabilityZones"]
            zones = [az["ZoneName"] for az in azs]
            if not zones:
                raise RuntimeError(
                    "No availability zones found in the current region; unable to"
                    " create subnets."
                )

            subnets: list[Any] = []
            for i, subnet_cidr in enumerate(subnet_cidrs):
                # Some regions expose fewer than three AZs; reuse zones
                # round-robin instead of failing with an IndexError.
                subnets.append(
                    await anyio.to_thread.run_sync(
                        partial(
                            vpc.create_subnet,
                            CidrBlock=str(subnet_cidr),
                            AvailabilityZone=zones[i % len(zones)],
                        )
                    )
                )

            # Create a Route Table for the public subnets and add a route to the
            # Internet Gateway
            public_route_table = await anyio.to_thread.run_sync(vpc.create_route_table)
            await anyio.to_thread.run_sync(
                partial(
                    public_route_table.create_route,
                    DestinationCidrBlock="0.0.0.0/0",
                    GatewayId=internet_gateway.id,
                )
            )
            for subnet in subnets:
                await anyio.to_thread.run_sync(
                    partial(
                        public_route_table.associate_with_subnet, SubnetId=subnet.id
                    )
                )
            advance()

            console.print("Setting up security group")
            # Create a security group to block all inbound traffic
            await anyio.to_thread.run_sync(
                partial(
                    self._ec2_resource.create_security_group,
                    GroupName=self._ecs_security_group_name,
                    Description=(
                        "Block all inbound traffic and allow all outbound traffic"
                    ),
                    VpcId=vpc.id,
                )
            )
            advance()
        else:
            vpc = await self._get_prefect_created_vpc()

        if vpc is not None:
            base_job_template["variables"]["properties"]["vpc_id"]["default"] = str(
                vpc.id
            )

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions for the user; this resource has none."""
        return []
class ContainerRepositoryResource:
    """
    An ECR repository for storing images of Prefect flows.

    Args:
        work_pool_name: Name of the work pool; used in the example deploy script.
        repository_name: Name of the ECR repository.
    """

    def __init__(self, work_pool_name: str, repository_name: str = "prefect-flows"):
        self._ecr_client = boto3.client("ecr")
        self._repository_name = repository_name
        # Cached result of `requires_provisioning`; None until first checked.
        self._requires_provisioning = None
        self._work_pool_name = work_pool_name
        self._next_steps: list[str | Panel] = []

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 3 if await self.requires_provisioning() else 0

    async def _get_prefect_created_registry(self) -> dict[str, Any] | None:
        """
        Return the description of the configured ECR repository, or None.

        Returns:
            The first entry of the response's `repositories` list, or None if the
            repository does not exist.
        """
        try:
            response = await anyio.to_thread.run_sync(
                partial(
                    self._ecr_client.describe_repositories,
                    repositoryNames=[self._repository_name],
                )
            )
        except self._ecr_client.exceptions.RepositoryNotFoundException:
            return None
        # The response is a dict; iterating it directly yields its keys, not
        # repositories, so look inside the "repositories" list.
        return next(iter(response.get("repositories", [])), None)

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Returns:
            bool: True if the repository does not exist, False otherwise. Cached
            after the first call.
        """
        if self._requires_provisioning is None:
            registry = await self._get_prefect_created_registry()
            self._requires_provisioning = registry is None
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
            is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating an ECR repository for storing Prefect images:"
                f" [blue]{self._repository_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> None:
        """
        Provisions an ECR repository.

        Creates the repository, logs the local Docker client in to ECR, and sets
        the current profile's default Docker build namespace to the registry host.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for (not modified by this resource).
            advance: A callback function to indicate progress.
        """
        if not await self.requires_provisioning():
            return

        console = current_console.get()
        console.print("Provisioning ECR repository")
        response = await anyio.to_thread.run_sync(
            partial(
                self._ecr_client.create_repository,
                repositoryName=self._repository_name,
            )
        )
        advance()

        console.print("Authenticating with ECR")
        # boto3 calls block, so keep this one off the event loop as well.
        auth_token = await anyio.to_thread.run_sync(
            self._ecr_client.get_authorization_token
        )
        auth_data = auth_token["authorizationData"][0]
        # The token decodes to "<user>:<password>"; partition tolerates a ':' in
        # the password, where split(":") would fail to unpack.
        user, _, passwd = (
            base64.b64decode(auth_data["authorizationToken"]).decode().partition(":")
        )
        proxy_endpoint = auth_data["proxyEndpoint"]
        # Pass an argv list (no shell) and feed the password on stdin so it is
        # not exposed in the process list or interpreted by a shell.
        await run_process(
            [
                "docker",
                "login",
                "--username",
                user,
                "--password-stdin",
                proxy_endpoint,
            ],
            input=passwd.encode(),
        )
        advance()

        console.print("Setting default Docker build namespace")
        namespace = response["repository"]["repositoryUri"].split("/")[0]
        update_current_profile({PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE: namespace})
        self._update_next_steps(namespace)
        advance()

    def _update_next_steps(self, repository_uri: str):
        """
        Append instructions and an example deploy script to `next_steps`.

        Args:
            repository_uri: The Docker build namespace that was configured.
        """
        self._next_steps.extend(
            [
                dedent(
                    f"""\

                Your default Docker build namespace has been set to [blue]{repository_uri!r}[/].

                To build and push a Docker image to your newly created repository, use [blue]{self._repository_name!r}[/] as your image name:
                """
                ),
                Panel(
                    Syntax(
                        dedent(
                            f"""\
                    from prefect import flow
                    from prefect.docker import DockerImage


                    @flow(log_prints=True)
                    def my_flow(name: str = "world"):
                        print(f"Hello {{name}}! I'm a flow running on ECS!")


                    if __name__ == "__main__":
                        my_flow.deploy(
                            name="my-deployment",
                            work_pool_name="{self._work_pool_name}",
                            image=DockerImage(
                                name="{self._repository_name}:latest",
                                platform="linux/amd64",
                            )
                        )"""
                        ),
                        "python",
                        background_color="default",
                    ),
                    title="example_deploy_script.py",
                    expand=False,
                ),
            ]
        )

    @property
    def next_steps(self) -> list[str | Panel]:
        """Follow-up instructions added by a successful `provision`."""
        return self._next_steps
class ExecutionRoleResource:
    """
    The IAM role assumed by ECS tasks (trusted principal `ecs-tasks.amazonaws.com`).

    Args:
        execution_role_name: Name of the IAM role.
    """

    def __init__(self, execution_role_name: str = "PrefectEcsTaskExecutionRole"):
        self._iam_client = boto3.client("iam")
        self._execution_role_name = execution_role_name
        trust_policy = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Principal": {"Service": "ecs-tasks.amazonaws.com"},
                    "Action": "sts:AssumeRole",
                }
            ],
        }
        # Stored pre-serialized because `create_role` receives it as a JSON string.
        self._trust_policy_document = json.dumps(trust_policy)
        # Memoized result of `requires_provisioning` (None = not yet checked).
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Number of progress steps `provision` will report.

        Returns:
            int: 1 when the role must be created, otherwise 0.
        """
        if await self.requires_provisioning():
            return 1
        return 0

    async def requires_provisioning(self) -> bool:
        """
        Determine whether the execution role still has to be created.

        Returns:
            bool: True when `get_role` reports `NoSuchEntityException`, False when
            the role exists. Cached after the first call.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        lookup_role = partial(
            self._iam_client.get_role, RoleName=self._execution_role_name
        )
        try:
            await anyio.to_thread.run_sync(lookup_role)
        except self._iam_client.exceptions.NoSuchEntityException:
            self._requires_provisioning = True
        else:
            self._requires_provisioning = False
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what `provision` would do.

        Returns:
            List[str]: One description when the role must be created, else empty.
        """
        if not await self.requires_provisioning():
            return []
        return [
            f"Creating an IAM role assigned to ECS tasks: [blue]{self._execution_role_name}[/]"
        ]

    async def provision(
        self,
        base_job_template: dict[str, Any],
        advance: Callable[[], None],
    ) -> str:
        """
        Create the execution role if needed and point the job template at it.

        A newly created role gets the `AmazonECSTaskExecutionRolePolicy` managed
        policy attached.

        Args:
            base_job_template: Work pool base job template; its
                `execution_role_arn` default is set in place.
            advance: Progress callback, invoked once after the role is created.

        Returns:
            str: The ARN of the created or existing role.
        """
        if not await self.requires_provisioning():
            role_info = await anyio.to_thread.run_sync(
                partial(self._iam_client.get_role, RoleName=self._execution_role_name)
            )
        else:
            current_console.get().print("Provisioning execution role")
            role_info = await anyio.to_thread.run_sync(
                partial(
                    self._iam_client.create_role,
                    RoleName=self._execution_role_name,
                    Description="Role for ECS tasks to access ECR and other resources.",
                    AssumeRolePolicyDocument=self._trust_policy_document,
                )
            )
            await anyio.to_thread.run_sync(
                partial(
                    self._iam_client.attach_role_policy,
                    RoleName=self._execution_role_name,
                    PolicyArn="arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy",
                )
            )
            advance()

        role_arn = role_info["Role"]["Arn"]
        job_variables = base_job_template["variables"]["properties"]
        job_variables["execution_role_arn"]["default"] = role_arn
        return role_arn

    @property
    def next_steps(self) -> list[str]:
        """Follow-up instructions for the user; none for this resource."""
        return []
class ElasticContainerServicePushProvisioner:
    """
    An infrastructure provisioner for ECS push work pools.

    It inspects the AWS account for the resources an ECS push work pool uses:
    an IAM user, an IAM policy, an AWS credentials block, an ECS cluster, a
    VPC with an ECS security group, and an ECR repository. It then runs each
    resource's provisioning step, and each resource records its settings in a
    copy of the work pool's base job template.
    """

    def __init__(self):
        self._console = Console()

    @property
    def console(self) -> Console:
        """Console used for all output and interactive prompts."""
        return self._console

    @console.setter
    def console(self, value: Console) -> None:
        self._console = value

    async def _prompt_boto3_installation(self):
        """
        Install boto3 and rebind the module-level ``boto3`` name to it.

        The module-level name is rebound so that code in this module referring
        to ``boto3`` uses the freshly installed package.
        """
        global boto3
        await ainstall_packages(["boto3"])
        # boto3 was installed after interpreter start-up. ``provision`` only
        # calls this after ``is_boto3_installed`` failed to import it. The
        # import system's finder caches may therefore not know about the new
        # package yet.
        importlib.invalidate_caches()
        boto3 = importlib.import_module("boto3")

    @staticmethod
    def is_boto3_installed() -> bool:
        """
        Check if boto3 is installed.

        Returns:
            True if ``boto3`` can be imported, False otherwise.
        """
        try:
            importlib.import_module("boto3")
            return True
        except ModuleNotFoundError:
            return False

    def _generate_resources(
        self,
        work_pool_name: str,
        user_name: str = "prefect-ecs-user",
        policy_name: str = "prefect-ecs-policy",
        credentials_block_name: Optional[str] = None,
        cluster_name: str = "prefect-ecs-cluster",
        vpc_name: str = "prefect-ecs-vpc",
        ecs_security_group_name: str = "prefect-ecs-security-group",
        repository_name: str = "prefect-flows",
    ):
        """
        Build the resource objects to inspect and provision, in provisioning
        order.

        Args:
            work_pool_name: The work pool the resources are provisioned for.
            user_name: Name of the IAM user.
            policy_name: Name of the IAM policy.
            credentials_block_name: Name of the AWS credentials block. ``None``
                is passed through unchanged (presumably the authentication
                resource then chooses its own name — confirm).
            cluster_name: Name of the ECS cluster.
            vpc_name: Name of the VPC.
            ecs_security_group_name: Name of the ECS security group.
            repository_name: Name of the ECR repository.

        Returns:
            A list of authentication, cluster, VPC and container repository
            resources.
        """
        return [
            AuthenticationResource(
                work_pool_name=work_pool_name,
                user_name=user_name,
                policy_name=policy_name,
                credentials_block_name=credentials_block_name,
            ),
            ClusterResource(cluster_name=cluster_name),
            VpcResource(
                vpc_name=vpc_name,
                ecs_security_group_name=ecs_security_group_name,
            ),
            ContainerRepositoryResource(
                work_pool_name=work_pool_name,
                repository_name=repository_name,
            ),
        ]

    def _prompt_for_resource_names(self, work_pool_name: str) -> dict[str, str]:
        """
        Interactively ask for a custom name for each resource and print a
        summary panel of the chosen names.

        Args:
            work_pool_name: The work pool being provisioned. It is used for the
                default credentials block name and in the summary.

        Returns:
            Keyword arguments for ``_generate_resources``, excluding
            ``work_pool_name``.
        """
        from prefect.cli._prompts import prompt

        user_name = prompt(
            "Enter a name for the IAM user (manages ECS tasks)",
            default="prefect-ecs-user",
        )
        policy_name = prompt(
            (
                "Enter a name for the IAM policy (defines ECS task execution"
                " and image management permissions)"
            ),
            default="prefect-ecs-policy",
        )
        cluster_name = prompt(
            "Enter a name for the ECS cluster (hosts ECS tasks)",
            default="prefect-ecs-cluster",
        )
        credentials_name = prompt(
            (
                "Enter a name for the AWS credentials block (stores AWS"
                " credentials for managing ECS tasks)"
            ),
            default=f"{work_pool_name}-aws-credentials",
        )
        vpc_name = prompt(
            (
                "Enter a name for the VPC (provides network isolation for ECS"
                " tasks)"
            ),
            default="prefect-ecs-vpc",
        )
        ecs_security_group_name = prompt(
            (
                "Enter a name for the ECS security group (controls task network"
                " traffic)"
            ),
            default="prefect-ecs-security-group",
        )
        repository_name = prompt(
            (
                "Enter a name for the ECR repository (stores Docker images for"
                " ECS tasks)"
            ),
            default="prefect-flows",
        )

        provision_preview = Panel(
            dedent(
                f"""\
                Custom names for infrastructure resources for
                [blue]{work_pool_name}[/]:

                - IAM user: [blue]{user_name}[/]
                - IAM policy: [blue]{policy_name}[/]
                - ECS cluster: [blue]{cluster_name}[/]
                - AWS credentials block: [blue]{credentials_name}[/]
                - VPC: [blue]{vpc_name}[/]
                - ECS security group: [blue]{ecs_security_group_name}[/]
                - ECR repository: [blue]{repository_name}[/]
                """
            ),
            expand=False,
        )
        self.console.print(provision_preview)

        return {
            "user_name": user_name,
            "policy_name": policy_name,
            "credentials_block_name": credentials_name,
            "cluster_name": cluster_name,
            "vpc_name": vpc_name,
            "ecs_security_group_name": ecs_security_group_name,
            "repository_name": repository_name,
        }

    async def provision(
        self,
        work_pool_name: str,
        base_job_template: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Provisions the infrastructure for an ECS push work pool.

        The steps are:

        1. Offer to install boto3 if it is missing.
        2. Optionally let the user customize resource names.
        3. Report the planned actions.
        4. In interactive sessions, ask for confirmation.
        5. Run every resource's provisioning step against a deep copy of the
           template.

        Args:
            work_pool_name: The name of the work pool to provision infrastructure for.
            base_job_template: The base job template of the work pool to provision
                infrastructure for. It is not modified.

        Returns:
            dict: An updated copy of the base job template. If the user declines
            provisioning, the original template is returned unchanged.

        Raises:
            RuntimeError: If boto3 is missing and the user does not agree to
                install it (or the console is not interactive). Also raised if
                any later step fails. For exceptions carrying a ``response``
                mapping (e.g. boto3 ``ClientError``), the message is taken from
                ``Error.Message`` when present.
        """
        if not self.is_boto3_installed():
            if self.console.is_interactive and Confirm.ask(
                "boto3 is required to configure your AWS account. Would you like to"
                " install it?",
                console=self.console,
            ):
                await self._prompt_boto3_installation()
            else:
                raise RuntimeError(
                    "boto3 is required to configure your AWS account. Please install it"
                    " and try again."
                )

        try:
            if self.console.is_interactive and Confirm.ask(
                "Would you like to customize the resource names for your"
                " infrastructure? This includes an IAM user, IAM policy, ECS cluster,"
                " VPC, ECS security group, and ECR repository.",
                console=self.console,
            ):
                resource_names = self._prompt_for_resource_names(work_pool_name)
            else:
                resource_names = {}
            resources = self._generate_resources(
                work_pool_name=work_pool_name, **resource_names
            )

            with Progress(
                SpinnerColumn(),
                TextColumn(
                    "Checking your AWS account for infrastructure that needs to be"
                    " provisioned..."
                ),
                transient=True,
                console=self.console,
            ) as progress:
                inspect_aws_account_task = progress.add_task(
                    "inspect_aws_account", total=1
                )
                # The summed task counts decide whether anything needs
                # provisioning and size the provisioning progress bar below.
                num_tasks = sum(
                    [await resource.get_task_count() for resource in resources]
                )
                progress.update(inspect_aws_account_task, completed=1)

            if num_tasks > 0:
                message = (
                    "Provisioning infrastructure for your work pool"
                    f" [blue]{work_pool_name}[/] will require: \n"
                )
                for resource in resources:
                    planned_actions = await resource.get_planned_actions()
                    for action in planned_actions:
                        message += f"\n\t - {action}"

                self.console.print(Panel(message))

                if self.console.is_interactive:
                    if not Confirm.ask(
                        "Proceed with infrastructure provisioning?",
                        console=self.console,
                    ):
                        return base_job_template
            else:
                self.console.print(
                    "No additional infrastructure required for work pool"
                    f" [blue]{work_pool_name}[/]"
                )
                # Don't return early: the provision calls below are no-ops, but
                # they still fill in the base job template.

            base_job_template_copy = deepcopy(base_job_template)
            next_steps: list[str | Panel] = []
            with Progress(console=self.console, disable=num_tasks == 0) as progress:
                task = progress.add_task(
                    "Provisioning Infrastructure",
                    total=num_tasks,
                )
                for resource in resources:
                    # Route resource output through the progress console so it
                    # does not interfere with the progress display.
                    with console_context(progress.console):
                        await resource.provision(
                            advance=partial(progress.advance, task),
                            base_job_template=base_job_template_copy,
                        )
                    next_steps.extend(resource.next_steps)

            for step in next_steps:
                self.console.print(step)

            if num_tasks > 0:
                self.console.print(
                    "Infrastructure successfully provisioned!", style="green"
                )

            return base_job_template_copy
        except Exception as exc:
            if hasattr(exc, "response"):
                # Catching boto3 ClientError
                response = getattr(exc, "response", {})
                error_message = get_from_dict(response, "Error.Message") or str(exc)
                raise RuntimeError(error_message) from exc
            # Catching any other exception
            raise RuntimeError(str(exc)) from exc