Coverage for /usr/local/lib/python3.12/site-packages/prefect/infrastructure/provisioners/cloud_run.py: 0%
182 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1import json
2import shlex
3import subprocess
4import tempfile
5from copy import deepcopy
6from pathlib import Path
7from textwrap import dedent
8from typing import TYPE_CHECKING, Any, Dict, Optional
9from uuid import UUID
11from anyio import run_process
12from rich.console import Console
13from rich.panel import Panel
14from rich.pretty import Pretty
15from rich.progress import Progress, SpinnerColumn, TextColumn
16from rich.prompt import Confirm
17from rich.syntax import Syntax
19from prefect.client.orchestration import ServerType
20from prefect.client.schemas.actions import BlockDocumentCreate
21from prefect.client.utilities import inject_client
22from prefect.exceptions import ObjectAlreadyExists
23from prefect.settings import (
24 PREFECT_DEBUG_MODE,
25 PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE,
26 update_current_profile,
27)
29if TYPE_CHECKING:
30 from prefect.client.orchestration import PrefectClient
class CloudRunPushProvisioner:
    """Provision the GCP resources a Cloud Run push work pool needs.

    Drives the ``gcloud`` CLI to enable the Cloud Run and Artifact Registry
    APIs, create an Artifact Registry repository, create a service account
    (granted the Service Account User and Cloud Run Developer roles) plus a
    key for it, and then stores that key in a ``gcp-credentials`` block that
    the work pool's base job template is pointed at.
    """

    def __init__(self):
        self._console: Console = Console()
        # Project and region are resolved from gcloud in `provision`.
        self._project = None
        self._region = None
        self._service_account_name = "prefect-cloud-run"
        # Derived from the work pool name in `provision`.
        self._credentials_block_name = None
        self._image_repository_name = "prefect-images"

    @property
    def console(self) -> Console:
        """Console used for all prompts, progress bars and messages."""
        return self._console

    @console.setter
    def console(self, value: Console) -> None:
        self._console = value

    @property
    def _service_account_email(self) -> str:
        """Fully-qualified email of the provisioned service account."""
        return (
            f"{self._service_account_name}@{self._project}.iam.gserviceaccount.com"
        )

    async def _run_command(self, command: str, *args: Any, **kwargs: Any) -> str:
        """Run a shell-style command string and return its stripped stdout.

        The string is tokenized with ``shlex.split`` and executed without a
        shell; extra arguments are forwarded to ``anyio.run_process``.

        Raises:
            subprocess.CalledProcessError: if the command exits non-zero. The
                error's ``output`` and ``stderr`` hold the raw captured bytes.
        """
        result = await run_process(shlex.split(command), *args, check=False, **kwargs)

        if result.returncode != 0:
            if PREFECT_DEBUG_MODE:
                self._console.print(
                    "Error running command:",
                    Pretty(
                        {
                            "command": command,
                            # errors="replace" so a decoding problem can never
                            # hide the actual command failure being reported.
                            "stdout": result.stdout.decode("utf-8", errors="replace"),
                            "stderr": result.stderr.decode("utf-8", errors="replace"),
                        }
                    ),
                    style="red",
                )
            raise subprocess.CalledProcessError(
                result.returncode, command, output=result.stdout, stderr=result.stderr
            )

        return result.stdout.decode("utf-8").strip()

    @staticmethod
    def _is_already_exists_error(error: subprocess.CalledProcessError) -> bool:
        """Return True if a failed command reported that its resource exists.

        gcloud writes its error messages to stderr, so both captured streams
        are inspected rather than stdout alone.
        """
        for stream in (error.output, error.stderr):
            if isinstance(stream, bytes):
                stream = stream.decode("utf-8", errors="replace")
            if stream and "already exists" in stream:
                return True
        return False

    async def _verify_gcloud_ready(self):
        """Ensure gcloud is installed and has an active authenticated account.

        Raises:
            RuntimeError: if gcloud cannot be run or no account is active.
        """
        try:
            await self._run_command("gcloud --version")
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            # A missing executable surfaces as FileNotFoundError at process
            # creation time, not as a non-zero exit status.
            raise RuntimeError(
                "gcloud not found. Please install gcloud and ensure it is in your PATH."
            ) from e

        accounts = json.loads(await self._run_command("gcloud auth list --format=json"))
        if not any(account.get("status") == "ACTIVE" for account in accounts):
            raise RuntimeError(
                "No active gcloud account found. Please run `gcloud auth login`."
            )

    async def _get_project(self):
        """Return the GCP project id to provision into.

        Interactive consoles choose from ``gcloud projects list``; otherwise
        the gcloud CLI's configured default project is used.

        Raises:
            RuntimeError: if non-interactive and no default project is set.
        """
        from prefect.cli._prompts import prompt_select_from_table

        if self._console.is_interactive:
            with Progress(
                SpinnerColumn(),
                TextColumn("Fetching your GCP projects..."),
                transient=True,
                console=self.console,
            ) as progress:
                list_projects_task = progress.add_task("list_projects", total=1)
                projects_raw = await self._run_command(
                    "gcloud projects list --format=json"
                )

                progress.update(list_projects_task, completed=1)
            projects = json.loads(projects_raw)
            selected_project = prompt_select_from_table(
                self.console,
                "Please select which GCP project to use",
                [
                    {"header": "Name", "key": "name"},
                    {"header": "Project ID", "key": "projectId"},
                ],
                projects,
            )
            return selected_project["projectId"]

        project = await self._run_command("gcloud config get-value project")
        if not project:
            # An unset value yields empty stdout; fail here with a clear hint
            # instead of letting every later `--project=` command fail.
            raise RuntimeError(
                "No default GCP project configured. Please run"
                " `gcloud config set project <PROJECT_ID>`."
            )
        return project

    async def _get_default_region(self):
        """Return gcloud's configured Cloud Run region, or ``us-central1``."""
        default_region = await self._run_command("gcloud config get-value run/region")
        return default_region or "us-central1"

    async def _enable_cloud_run_api(self):
        """Enable the Cloud Run API in the selected project.

        Raises:
            RuntimeError: if the API could not be enabled.
        """
        try:
            await self._run_command(
                f"gcloud services enable run.googleapis.com --project={self._project}"
            )

        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "Error enabling Cloud Run API. Please ensure you have the necessary"
                " permissions."
            ) from e

    async def _create_service_account(self):
        """Create the service account; an existing account is reused.

        Raises:
            RuntimeError: on any failure other than "already exists".
        """
        try:
            await self._run_command(
                f"gcloud iam service-accounts create {self._service_account_name}"
                ' --display-name "Prefect Cloud Run Service Account"'
                f" --project={self._project}"
            )
        except subprocess.CalledProcessError as e:
            # Re-running provisioning against an existing account is fine;
            # anything else (permissions, bad project, ...) is surfaced.
            if self._is_already_exists_error(e):
                return
            raise RuntimeError(
                "Error creating service account. Please ensure you have the necessary"
                " permissions."
            ) from e

    async def _create_service_account_key(self):
        """Create a new key for the service account and return it as a dict.

        The key file is written to a temporary directory that is removed
        once its JSON has been loaded.

        Raises:
            RuntimeError: if gcloud fails to create the key.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            key_path = Path(tmpdir) / f"{self._service_account_name}-key.json"
            try:
                await self._run_command(
                    "gcloud iam service-accounts keys create"
                    # Quoted because the command string is shlex-split: spaces
                    # or Windows backslashes in the temp path would otherwise
                    # be split apart or treated as escapes.
                    f" {shlex.quote(str(key_path))}"
                    f" --iam-account={self._service_account_email}"
                )
            except subprocess.CalledProcessError as e:
                raise RuntimeError(
                    "Error creating service account key. Please ensure you have the"
                    " necessary permissions."
                ) from e
            return json.loads(key_path.read_text())

    async def _assign_roles(self):
        """Bind the roles the service account needs for Cloud Run jobs.

        Raises:
            RuntimeError: if any policy binding fails.
        """
        member = f"serviceAccount:{self._service_account_email}"
        try:
            for role in ("roles/iam.serviceAccountUser", "roles/run.developer"):
                await self._run_command(
                    "gcloud projects add-iam-policy-binding"
                    f' {self._project} --member="{member}"'
                    f' --role="{role}"'
                )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "Error assigning roles to service account. Please ensure you have the"
                " necessary permissions."
            ) from e

    async def _enable_artifact_registry_api(self):
        """Enable the Artifact Registry API in the selected project.

        Raises:
            RuntimeError: if the API could not be enabled.
        """
        try:
            await self._run_command(
                "gcloud services enable artifactregistry.googleapis.com"
                f" --project={self._project}"
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "Error enabling Artifact Registry API. Please ensure you have the"
                " necessary permissions."
            ) from e

    async def _create_artifact_registry_repository(self, repository_name: str):
        """Create a Docker repository in the selected region; reuse if present.

        Raises:
            RuntimeError: on any failure other than "already exists".
        """
        try:
            await self._run_command(
                "gcloud artifacts repositories create"
                f" {repository_name} --repository-format=docker"
                f" --location={self._region} --project={self._project}"
            )
        except subprocess.CalledProcessError as e:
            if self._is_already_exists_error(e):
                return
            raise RuntimeError(
                "Error creating Artifact Registry repository. Please ensure you have"
                " the necessary permissions."
            ) from e

    async def _login_to_artifact_registry(self):
        """Configure local Docker auth for the region's Artifact Registry host.

        Raises:
            RuntimeError: if gcloud fails to configure Docker.
        """
        try:
            await self._run_command(
                f"gcloud auth configure-docker {self._region}-docker.pkg.dev"
                f" --project={self._project}"
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "Error logging into Artifact Registry. Please ensure you have the"
                " necessary permissions."
            ) from e

    async def _create_gcp_credentials_block(
        self, block_document_name: str, key: dict, client: "PrefectClient"
    ) -> UUID:
        """Store ``key`` in a ``gcp-credentials`` block and return its id.

        If a block with this name already exists, its id is returned instead.
        """
        credentials_block_type = await client.read_block_type_by_slug("gcp-credentials")

        credentials_block_schema = (
            await client.get_most_recent_block_schema_for_block_type(
                block_type_id=credentials_block_type.id
            )
        )

        try:
            block_doc = await client.create_block_document(
                block_document=BlockDocumentCreate(
                    name=block_document_name,
                    data={"service_account_info": key},
                    block_type_id=credentials_block_type.id,
                    block_schema_id=credentials_block_schema.id,
                )
            )
            return block_doc.id
        except ObjectAlreadyExists:
            # NOTE(review): the existing block keeps its stored key, so the
            # key just created above is not saved anywhere — confirm intended.
            block_doc = await client.read_block_document_by_name(
                name=block_document_name,
                block_type_slug="gcp-credentials",
            )
            return block_doc.id

    async def _create_provision_table(
        self, work_pool_name: str, client: "PrefectClient"
    ):
        """Build a panel summarizing every change provisioning will make."""
        return Panel(
            dedent(
                f"""\
                    Provisioning infrastructure for your work pool [blue]{work_pool_name}[/] will require:

                    Updates in GCP project [blue]{self._project}[/] in region [blue]{self._region}[/]

                        - Activate the Cloud Run API for your project
                        - Activate the Artifact Registry API for your project
                        - Create an Artifact Registry repository named [blue]{self._image_repository_name}[/]
                        - Create a service account for managing Cloud Run jobs: [blue]{self._service_account_name}[/]
                            - Service account will be granted the following roles:
                                - Service Account User
                                - Cloud Run Developer
                        - Create a key for service account: [blue]{self._service_account_name}[/]

                    Updates in Prefect {"workspace" if client.server_type == ServerType.CLOUD else "server"}

                        - Create GCP credentials block to store the service account key: [blue]{self._credentials_block_name}[/]
                    """
            ),
            expand=False,
        )

    async def _customize_resource_names(
        self, work_pool_name: str, client: "PrefectClient"
    ) -> bool:
        """Prompt for custom resource names, then ask for final confirmation.

        Returns:
            True if the user confirms provisioning with the chosen names.
        """
        from prefect.cli._prompts import prompt

        self._service_account_name = prompt(
            "Please enter a name for the service account",
            default=self._service_account_name,
        )
        self._credentials_block_name = prompt(
            "Please enter a name for the GCP credentials block",
            default=self._credentials_block_name,
        )
        self._image_repository_name = prompt(
            "Please enter a name for the Artifact Registry repository",
            default=self._image_repository_name,
        )
        table = await self._create_provision_table(work_pool_name, client)
        self._console.print(table)

        return Confirm.ask(
            "Proceed with infrastructure provisioning?", console=self._console
        )

    async def _confirm_provisioning(
        self, work_pool_name: str, client: "PrefectClient"
    ) -> bool:
        """Ask how to proceed; return False if the user declines provisioning."""
        from prefect.cli._prompts import prompt_select_from_table

        # Defined once so the menu text and the comparisons below cannot drift.
        proceed_with_defaults = (
            "Yes, proceed with infrastructure provisioning with default"
            " resource names"
        )
        customize = "Customize resource names"
        decline = "Do not proceed with infrastructure provisioning"

        chosen_option = prompt_select_from_table(
            self._console,
            "Proceed with infrastructure provisioning with default resource names?",
            [
                {"header": "Options:", "key": "option"},
            ],
            [
                {"option": proceed_with_defaults},
                {"option": customize},
                {"option": decline},
            ],
        )
        selected = chosen_option["option"]
        if selected == proceed_with_defaults:
            return True
        if selected == customize:
            return await self._customize_resource_names(work_pool_name, client)
        if selected == decline:
            return False
        # Defensive: only reachable if the option list above gains an entry
        # without a matching branch here.
        raise ValueError(f"Invalid option selected: {selected}")

    def _print_next_steps(
        self, work_pool_name: str, default_docker_build_namespace: str
    ) -> None:
        """Print the configured build namespace and an example deploy script."""
        self._console.print(
            dedent(
                f"""\
                Your default Docker build namespace has been set to [blue]{default_docker_build_namespace!r}[/].
                Use any image name to build and push to this registry by default:
                """
            ),
            Panel(
                Syntax(
                    dedent(
                        f"""\
                        from prefect import flow
                        from prefect.docker import DockerImage


                        @flow(log_prints=True)
                        def my_flow(name: str = "world"):
                            print(f"Hello {{name}}! I'm a flow running in Cloud Run!")


                        if __name__ == "__main__":
                            my_flow.deploy(
                                name="my-deployment",
                                work_pool_name="{work_pool_name}",
                                image=DockerImage(
                                    name="my-image:latest",
                                    platform="linux/amd64",
                                )
                            )"""
                    ),
                    "python",
                    background_color="default",
                ),
                title="example_deploy_script.py",
                expand=False,
            ),
        )

    @inject_client
    async def provision(
        self,
        work_pool_name: str,
        base_job_template: dict,
        client: Optional["PrefectClient"] = None,
    ) -> Dict[str, Any]:
        """Provision GCP infrastructure for a Cloud Run push work pool.

        Side effects:
            - Creates or reuses the GCP resources and the credentials block.
            - Sets ``PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE`` in the current
              profile.

        Args:
            work_pool_name: name of the work pool being provisioned.
            base_job_template: the work pool's base job template. It is not
                mutated.
            client: Prefect client, injected by ``inject_client`` if omitted.

        Returns:
            A copy of ``base_job_template`` whose ``credentials`` variable
            defaults to the new credentials block. If the user declines, the
            original template is returned unchanged.
        """
        assert client, "Client injection failed"
        await self._verify_gcloud_ready()
        self._project = await self._get_project()
        self._region = await self._get_default_region()
        self._credentials_block_name = f"{work_pool_name}-push-pool-credentials"

        table = await self._create_provision_table(work_pool_name, client)
        self._console.print(table)
        if self._console.is_interactive and not await self._confirm_provisioning(
            work_pool_name, client
        ):
            return base_job_template

        with Progress(console=self._console) as progress:
            # One progress step per advance() call below.
            task = progress.add_task("Provisioning Infrastructure", total=9)
            progress.console.print("Activating Cloud Run API")
            await self._enable_cloud_run_api()
            progress.advance(task)

            progress.console.print("Activating Artifact Registry API")
            await self._enable_artifact_registry_api()
            progress.advance(task)

            progress.console.print("Creating Artifact Registry repository")
            await self._create_artifact_registry_repository(self._image_repository_name)
            progress.advance(task)

            progress.console.print("Configuring authentication to Artifact Registry")
            await self._login_to_artifact_registry()
            progress.advance(task)

            progress.console.print("Setting default Docker build namespace")
            default_docker_build_namespace = f"{self._region}-docker.pkg.dev/{self._project}/{self._image_repository_name}"
            update_current_profile(
                {PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE: default_docker_build_namespace}
            )
            progress.advance(task)

            progress.console.print("Creating service account")
            await self._create_service_account()
            progress.advance(task)

            progress.console.print("Assigning roles to service account")
            await self._assign_roles()
            progress.advance(task)

            progress.console.print("Creating service account key")
            key = await self._create_service_account_key()
            progress.advance(task)

            progress.console.print("Creating GCP credentials block")
            block_doc_id = await self._create_gcp_credentials_block(
                self._credentials_block_name, key, client
            )
            # Deep copy so the caller's template is left untouched.
            base_job_template_copy = deepcopy(base_job_template)
            base_job_template_copy["variables"]["properties"]["credentials"][
                "default"
            ] = {"$ref": {"block_document_id": str(block_doc_id)}}
            progress.advance(task)

        self._print_next_steps(work_pool_name, default_docker_build_namespace)

        self._console.print(
            (
                f"Infrastructure successfully provisioned for '{work_pool_name}' work"
                " pool!"
            ),
            style="green",
        )

        return base_job_template_copy