Coverage for /usr/local/lib/python3.12/site-packages/prefect/client/cloud.py: 30%
112 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1from __future__ import annotations 1a
3import re 1a
4from typing import Any, NoReturn, Optional, cast 1a
6import anyio 1a
7import httpx 1a
8import pydantic 1a
9from starlette import status 1a
10from typing_extensions import Self 1a
12import prefect.settings 1a
13from prefect.client.base import PrefectHttpxAsyncClient 1a
14from prefect.client.schemas.objects import ( 1a
15 IPAllowlist,
16 IPAllowlistMyAccessResponse,
17 Workspace,
18)
19from prefect.exceptions import ObjectNotFound, PrefectException 1a
20from prefect.settings import ( 1a
21 PREFECT_API_KEY,
22 PREFECT_API_URL,
23 PREFECT_CLOUD_API_URL,
24 PREFECT_TESTING_UNIT_TEST_MODE,
25)
# Matches the "accounts/<id>/workspaces/<id>" segment of an API URL and
# captures the two IDs (group 1: account, group 2: workspace). Each ID is
# matched as any 36 characters; the pattern does not validate their format.
# NOTE(review): 36 is the length of a canonical UUID string — presumably the
# IDs are UUIDs; confirm before tightening the pattern.
PARSE_API_URL_REGEX = re.compile(r"accounts/(.{36})/workspaces/(.{36})")
# Module-level memo of TypeAdapter objects keyed by the type they validate,
# so each adapter is constructed at most once per process.
_TYPE_ADAPTER_CACHE: dict[type, pydantic.TypeAdapter[Any]] = {}


def _get_type_adapter(type_: type) -> pydantic.TypeAdapter[Any]:
    """Return the shared TypeAdapter for `type_`, building it on first use."""
    adapter = _TYPE_ADAPTER_CACHE.get(type_)
    if adapter is None:
        # First request for this type: build the adapter and remember it.
        adapter = pydantic.TypeAdapter(type_)
        _TYPE_ADAPTER_CACHE[type_] = adapter
    return adapter
def get_cloud_client(
    host: Optional[str] = None,
    api_key: Optional[str] = None,
    httpx_settings: Optional[dict[str, Any]] = None,
    infer_cloud_url: bool = False,
) -> "CloudClient":
    """
    Build a `CloudClient` from explicit arguments and the current settings.

    Args:
        host: Base URL for the client. When `infer_cloud_url` is False this
            falls back to `PREFECT_CLOUD_API_URL` if not given. When
            `infer_cloud_url` is True it is only used if `PREFECT_API_URL`
            is unset.
        api_key: API key for the client; falls back to `PREFECT_API_KEY`.
        httpx_settings: Extra keyword settings forwarded to `CloudClient`.
            A shallow copy is passed on, not the caller's dict.
        infer_cloud_url: When not False, derive the host from
            `PREFECT_API_URL` by removing its
            "accounts/<id>/workspaces/<id>" segment.

    Returns:
        A new `CloudClient`.

    Raises:
        ValueError: If no host was provided and none could be inferred.
    """
    if httpx_settings is not None:
        httpx_settings = httpx_settings.copy()

    if infer_cloud_url is False:
        host = host or PREFECT_CLOUD_API_URL.value()
    else:
        configured_url = prefect.settings.PREFECT_API_URL.value()
        # PREFECT_API_URL may be unset. Passing None to the regex would raise
        # a TypeError and hide the intended ValueError below, so in that case
        # keep whatever `host` the caller supplied.
        if configured_url is not None:
            host = PARSE_API_URL_REGEX.sub("", configured_url)

    if host is None:
        raise ValueError("Host was not provided and could not be inferred")

    return CloudClient(
        host=host,
        api_key=api_key or PREFECT_API_KEY.value(),
        httpx_settings=httpx_settings,
    )
class CloudUnauthorizedError(PrefectException):
    """
    Signals that the Cloud API answered a `CloudClient` request with a 401
    or 403 status.
    """
class CloudClient:
    """
    Asynchronous client for the Cloud API.

    Wraps a `PrefectHttpxAsyncClient` configured with a bearer
    `Authorization` header and exposes helpers for workspace, account
    settings and IP allowlist routes. Use it with `async with`; entering it
    with a plain `with` raises `RuntimeError`.
    """

    # Filled in by __init__ when the host or PREFECT_API_URL contains an
    # "accounts/<id>/workspaces/<id>" segment; otherwise they stay None.
    account_id: Optional[str] = None
    workspace_id: Optional[str] = None

    def __init__(
        self,
        host: str,
        api_key: str,
        httpx_settings: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Configure the underlying HTTP client.

        Args:
            host: Default `base_url` for the HTTP client; also searched for
                account and workspace IDs.
            api_key: Key used for the default `Authorization: Bearer` header.
            httpx_settings: Optional keyword settings for the HTTP client.
                Values supplied here take precedence over the defaults set
                below. The mapping and its `headers` entry are copied, so
                the caller's objects are never modified.
        """
        # Work on copies: writing defaults into the caller's dict (or its
        # nested headers mapping) would leak the Authorization header back
        # to the caller and into any later reuse of the same settings.
        settings: dict[str, Any] = dict(httpx_settings) if httpx_settings else {}
        caller_headers = settings.get("headers")
        headers = caller_headers.copy() if caller_headers is not None else {}
        headers.setdefault("Authorization", f"Bearer {api_key}")
        settings["headers"] = headers

        settings.setdefault("base_url", host)
        if not PREFECT_TESTING_UNIT_TEST_MODE.value():
            settings.setdefault("follow_redirects", True)
        self._client = PrefectHttpxAsyncClient(
            **settings, enable_csrf_support=False
        )

        # Prefer IDs found in the host; fall back to the configured API URL.
        api_url: str = prefect.settings.PREFECT_API_URL.value() or ""
        if match := (
            PARSE_API_URL_REGEX.search(host) or PARSE_API_URL_REGEX.search(api_url)
        ):
            self.account_id, self.workspace_id = match.groups()

    @property
    def account_base_url(self) -> str:
        """Route prefix for the current account.

        Raises:
            ValueError: If no account ID is set.
        """
        if not self.account_id:
            raise ValueError("Account ID not set")

        return f"accounts/{self.account_id}"

    @property
    def workspace_base_url(self) -> str:
        """Route prefix for the current workspace.

        Raises:
            ValueError: If the workspace ID or the account ID is not set.
        """
        if not self.workspace_id:
            raise ValueError("Workspace ID not set")

        return f"{self.account_base_url}/workspaces/{self.workspace_id}"

    async def api_healthcheck(self) -> None:
        """
        Attempts to connect to the Cloud API and raises the encountered
        exception if not successful. The attempt runs inside
        `anyio.fail_after(10)`.

        If successful, returns `None`.
        """
        with anyio.fail_after(10):
            await self.read_workspaces()

    async def read_workspaces(self) -> list[Workspace]:
        """Fetch "/me/workspaces" and validate the payload as `Workspace` objects."""
        payload = await self.get("/me/workspaces")
        return _get_type_adapter(list[Workspace]).validate_python(payload)

    async def read_current_workspace(self) -> Workspace:
        """
        Return the workspace whose API URL equals `PREFECT_API_URL`
        (ignoring trailing slashes on the setting).

        Raises:
            ValueError: If no workspace matches, including when
                `PREFECT_API_URL` is unset.
        """
        workspaces = await self.read_workspaces()
        configured_url = PREFECT_API_URL.value()
        # An unset URL cannot match any workspace; fall through to the
        # ValueError instead of failing on `None.rstrip`.
        if configured_url is not None:
            target_url = configured_url.rstrip("/")
            for workspace in workspaces:
                if workspace.api_url() == target_url:
                    return workspace
        raise ValueError("Current workspace not found")

    async def read_worker_metadata(self) -> dict[str, Any]:
        """Fetch the workspace's "collections/work_pool_types" route."""
        response = await self.get(
            f"{self.workspace_base_url}/collections/work_pool_types"
        )
        return cast(dict[str, Any], response)

    async def read_account_settings(self) -> dict[str, Any]:
        """Fetch the account's "settings" route."""
        response = await self.get(f"{self.account_base_url}/settings")
        return cast(dict[str, Any], response)

    async def update_account_settings(self, settings: dict[str, Any]) -> None:
        """PATCH the account's "settings" route with `settings` as the JSON body."""
        await self.request(
            "PATCH",
            f"{self.account_base_url}/settings",
            json=settings,
        )

    async def read_account_ip_allowlist(self) -> IPAllowlist:
        """Fetch the account's "ip_allowlist" route and validate it as `IPAllowlist`."""
        response = await self.get(f"{self.account_base_url}/ip_allowlist")
        return IPAllowlist.model_validate(response)

    async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist) -> None:
        """PUT `updated_allowlist`, dumped in JSON mode, to the account's "ip_allowlist" route."""
        await self.request(
            "PUT",
            f"{self.account_base_url}/ip_allowlist",
            json=updated_allowlist.model_dump(mode="json"),
        )

    async def check_ip_allowlist_access(self) -> IPAllowlistMyAccessResponse:
        """Fetch "ip_allowlist/my_access" for the account and validate the response."""
        response = await self.get(f"{self.account_base_url}/ip_allowlist/my_access")
        return IPAllowlistMyAccessResponse.model_validate(response)

    async def __aenter__(self) -> Self:
        """Enter the underlying HTTP client's async context and return this client."""
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """Exit the underlying HTTP client's async context."""
        return await self._client.__aexit__(*exc_info)

    def __enter__(self) -> NoReturn:
        """Always raises: the client only supports `async with`."""
        raise RuntimeError(
            "The `CloudClient` must be entered with an async context. Use 'async "
            "with CloudClient(...)' not 'with CloudClient(...)'"
        )

    def __exit__(self, *_: object) -> NoReturn:
        """Defined only so that `__enter__` can exist; always raises."""
        # Raise explicitly: an `assert` statement is stripped under `-O`,
        # which would let this NoReturn method return silently.
        raise AssertionError(
            "This should never be called but must be defined for __enter__"
        )

    async def get(self, route: str, **kwargs: Any) -> Any:
        """Shorthand for `request("GET", route, **kwargs)`."""
        return await self.request("GET", route, **kwargs)

    async def raw_request(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        path_params: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> httpx.Response:
        """
        Make a raw HTTP request and return the Response object.

        Unlike request(), this does not parse JSON or raise special exceptions,
        returning the raw httpx.Response for direct access to headers, status, etc.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API path/route
            params: Query parameters
            path_params: Path parameters for formatting
            **kwargs: Additional arguments passed to httpx (json, headers, etc.)

        Returns:
            Raw httpx.Response object
        """
        if path_params:
            path = path.format(**path_params)
        request = self._client.build_request(method, path, params=params, **kwargs)
        return await self._client.send(request)

    async def request(self, method: str, route: str, **kwargs: Any) -> Any:
        """
        Send a request and return the decoded JSON body.

        Args:
            method: HTTP method.
            route: Route passed to the underlying HTTP client.
            **kwargs: Additional arguments forwarded to the HTTP client.

        Returns:
            The parsed JSON response, or `None` for a 204 response.

        Raises:
            CloudUnauthorizedError: On a 401 or 403 response.
            ObjectNotFound: On a 404 response.
            httpx.HTTPStatusError: On any other error status.
        """
        try:
            res = await self._client.request(method, route, **kwargs)
            res.raise_for_status()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code in (
                status.HTTP_401_UNAUTHORIZED,
                status.HTTP_403_FORBIDDEN,
            ):
                raise CloudUnauthorizedError(str(exc)) from exc
            elif exc.response.status_code == status.HTTP_404_NOT_FOUND:
                raise ObjectNotFound(http_exc=exc) from exc
            else:
                raise

        # A 204 response has no body to decode.
        if res.status_code == status.HTTP_204_NO_CONTENT:
            return None

        return res.json()