Coverage for /usr/local/lib/python3.12/site-packages/prefect/client/cloud.py: 30%

112 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import re 1a

4from typing import Any, NoReturn, Optional, cast 1a

5 

6import anyio 1a

7import httpx 1a

8import pydantic 1a

9from starlette import status 1a

10from typing_extensions import Self 1a

11 

12import prefect.settings 1a

13from prefect.client.base import PrefectHttpxAsyncClient 1a

14from prefect.client.schemas.objects import ( 1a

15 IPAllowlist, 

16 IPAllowlistMyAccessResponse, 

17 Workspace, 

18) 

19from prefect.exceptions import ObjectNotFound, PrefectException 1a

20from prefect.settings import ( 1a

21 PREFECT_API_KEY, 

22 PREFECT_API_URL, 

23 PREFECT_CLOUD_API_URL, 

24 PREFECT_TESTING_UNIT_TEST_MODE, 

25) 

26 

# Captures the two 36-character identifiers in a Cloud API URL of the form
# ".../accounts/<account id>/workspaces/<workspace id>"; group 1 is the account
# id and group 2 is the workspace id.
PARSE_API_URL_REGEX = re.compile(pattern=r"accounts/(.{36})/workspaces/(.{36})")

# Memo of pydantic TypeAdapters keyed by the type they validate, so repeated
# lookups reuse one adapter instead of constructing a new one each time.
_TYPE_ADAPTER_CACHE: dict[type, pydantic.TypeAdapter[Any]] = dict()

31 

32 

def _get_type_adapter(type_: type) -> pydantic.TypeAdapter[Any]:
    """Return a memoized `pydantic.TypeAdapter` for *type_*.

    The first call for a given type builds the adapter and stores it in
    `_TYPE_ADAPTER_CACHE`; subsequent calls hand back that stored instance.
    """
    adapter = _TYPE_ADAPTER_CACHE.get(type_)
    if adapter is None:
        adapter = _TYPE_ADAPTER_CACHE[type_] = pydantic.TypeAdapter(type_)
    return adapter

38 

39 

def get_cloud_client(
    host: Optional[str] = None,
    api_key: Optional[str] = None,
    httpx_settings: Optional[dict[str, Any]] = None,
    infer_cloud_url: bool = False,
) -> "CloudClient":
    """
    Build a `CloudClient` from explicit arguments and Prefect settings.

    Args:
        host: Base URL of the Prefect Cloud API. Falls back to
            `PREFECT_CLOUD_API_URL` when omitted. Ignored when
            `infer_cloud_url` is True.
        api_key: API key used for the bearer token; falls back to
            `PREFECT_API_KEY` when omitted.
        httpx_settings: Extra keyword arguments for the underlying HTTP client.
            A shallow copy is passed on, so the caller's top-level dict is not
            modified.
        infer_cloud_url: When True, derive the host from `PREFECT_API_URL` by
            stripping its `accounts/<id>/workspaces/<id>` segment.

    Returns:
        A newly constructed (not yet entered) `CloudClient`.

    Raises:
        ValueError: If no host was provided and none could be inferred.
    """
    if httpx_settings is not None:
        httpx_settings = httpx_settings.copy()

    if infer_cloud_url is False:
        host = host or PREFECT_CLOUD_API_URL.value()
    else:
        configured_url = prefect.settings.PREFECT_API_URL.value()
        # Guard against an unset PREFECT_API_URL: `re.sub` on None would raise
        # TypeError instead of the descriptive ValueError below.
        host = (
            re.sub(PARSE_API_URL_REGEX, "", configured_url) if configured_url else None
        )

    # An empty string is no more usable as a base URL than None.
    if not host:
        raise ValueError("Host was not provided and could not be inferred")

    return CloudClient(
        host=host,
        api_key=api_key or PREFECT_API_KEY.value(),
        httpx_settings=httpx_settings,
    )

66 

67 

class CloudUnauthorizedError(PrefectException):
    """Signals that the Prefect Cloud API rejected a request as unauthorized.

    `CloudClient.request` raises this for HTTP 401 (Unauthorized) and
    HTTP 403 (Forbidden) responses.
    """

72 

73 

class CloudClient:
    """
    Async HTTP client for the Prefect Cloud API.

    Must be used as an async context manager (`async with CloudClient(...)`);
    entering it with a plain `with` raises `RuntimeError`.
    """

    # Parsed in `__init__` from the host or `PREFECT_API_URL`; both stay `None`
    # when neither URL contains an `accounts/<id>/workspaces/<id>` segment.
    account_id: Optional[str] = None
    workspace_id: Optional[str] = None

    def __init__(
        self,
        host: str,
        api_key: str,
        httpx_settings: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            host: Base URL for requests; used as the HTTP client's `base_url`
                unless `httpx_settings` already provides one.
            api_key: Sent as `Authorization: Bearer <api_key>` unless
                `httpx_settings["headers"]` already sets `Authorization`.
            httpx_settings: Extra keyword arguments for the HTTP client. The
                mapping and its `headers` are copied, never mutated.
        """
        # Copy both the outer settings and the nested headers: a shallow copy of
        # the outer dict alone would still share (and mutate) the caller's
        # headers mapping when the Authorization default is applied.
        settings: dict[str, Any] = dict(httpx_settings) if httpx_settings else {}
        headers = settings.get("headers")
        headers = headers.copy() if headers is not None else {}
        headers.setdefault("Authorization", f"Bearer {api_key}")
        settings["headers"] = headers

        settings.setdefault("base_url", host)
        if not PREFECT_TESTING_UNIT_TEST_MODE.value():
            settings.setdefault("follow_redirects", True)
        self._client = PrefectHttpxAsyncClient(
            **settings, enable_csrf_support=False
        )

        # Prefer IDs embedded in the host; fall back to the configured API URL.
        api_url: str = prefect.settings.PREFECT_API_URL.value() or ""
        if match := (
            re.search(PARSE_API_URL_REGEX, host)
            or re.search(PARSE_API_URL_REGEX, api_url)
        ):
            self.account_id, self.workspace_id = match.groups()

    @property
    def account_base_url(self) -> str:
        """Relative route `accounts/<account_id>`; ValueError if unset."""
        if not self.account_id:
            raise ValueError("Account ID not set")

        return f"accounts/{self.account_id}"

    @property
    def workspace_base_url(self) -> str:
        """Relative route `accounts/<account_id>/workspaces/<workspace_id>`.

        Raises:
            ValueError: If the workspace ID (or account ID) is not set.
        """
        if not self.workspace_id:
            raise ValueError("Workspace ID not set")

        return f"{self.account_base_url}/workspaces/{self.workspace_id}"

    async def api_healthcheck(self) -> None:
        """
        Attempts to connect to the Cloud API and raises the encountered exception if not
        successful.

        The check lists the user's workspaces and is bounded by a 10 second
        `anyio.fail_after` timeout.

        If successful, returns `None`.
        """
        with anyio.fail_after(10):
            await self.read_workspaces()

    async def read_workspaces(self) -> list[Workspace]:
        """Return the workspaces listed by `GET /me/workspaces`."""
        workspaces = _get_type_adapter(list[Workspace]).validate_python(
            await self.get("/me/workspaces")
        )
        return workspaces

    async def read_current_workspace(self) -> Workspace:
        """
        Return the workspace whose API URL matches `PREFECT_API_URL`.

        Raises:
            ValueError: If no workspace matches (including when
                `PREFECT_API_URL` is unset).
        """
        workspaces = await self.read_workspaces()
        # Normalize once outside the loop; an unset setting is treated as an
        # empty string rather than raising AttributeError on `None.rstrip`.
        current_api_url = (PREFECT_API_URL.value() or "").rstrip("/")
        for workspace in workspaces:
            if workspace.api_url() == current_api_url:
                return workspace
        raise ValueError("Current workspace not found")

    async def read_worker_metadata(self) -> dict[str, Any]:
        """Return the JSON from `<workspace route>/collections/work_pool_types`."""
        response = await self.get(
            f"{self.workspace_base_url}/collections/work_pool_types"
        )
        return cast(dict[str, Any], response)

    async def read_account_settings(self) -> dict[str, Any]:
        """Return the JSON from `<account route>/settings`."""
        response = await self.get(f"{self.account_base_url}/settings")
        return cast(dict[str, Any], response)

    async def update_account_settings(self, settings: dict[str, Any]) -> None:
        """PATCH `<account route>/settings` with *settings* as the JSON body."""
        await self.request(
            "PATCH",
            f"{self.account_base_url}/settings",
            json=settings,
        )

    async def read_account_ip_allowlist(self) -> IPAllowlist:
        """Return the account's IP allowlist as an `IPAllowlist`."""
        response = await self.get(f"{self.account_base_url}/ip_allowlist")
        return IPAllowlist.model_validate(response)

    async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist) -> None:
        """PUT *updated_allowlist* (JSON-mode dump) to the account allowlist route."""
        await self.request(
            "PUT",
            f"{self.account_base_url}/ip_allowlist",
            json=updated_allowlist.model_dump(mode="json"),
        )

    async def check_ip_allowlist_access(self) -> IPAllowlistMyAccessResponse:
        """Return the response of `<account route>/ip_allowlist/my_access`."""
        response = await self.get(f"{self.account_base_url}/ip_allowlist/my_access")
        return IPAllowlistMyAccessResponse.model_validate(response)

    async def __aenter__(self) -> Self:
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        return await self._client.__aexit__(*exc_info)

    def __enter__(self) -> NoReturn:
        raise RuntimeError(
            "The `CloudClient` must be entered with an async context. Use 'async "
            "with CloudClient(...)' not 'with CloudClient(...)'"
        )

    def __exit__(self, *_: object) -> NoReturn:
        # Only defined to pair with `__enter__`, which always raises first.
        # Raise explicitly rather than `assert False` so the NoReturn contract
        # still holds when assertions are stripped under `python -O`.
        raise AssertionError(
            "This should never be called but must be defined for __enter__"
        )

    async def get(self, route: str, **kwargs: Any) -> Any:
        """Shorthand for `request("GET", route, **kwargs)`."""
        return await self.request("GET", route, **kwargs)

    async def raw_request(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        path_params: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> httpx.Response:
        """
        Make a raw HTTP request and return the Response object.

        Unlike request(), this does not parse JSON or raise special exceptions,
        returning the raw httpx.Response for direct access to headers, status, etc.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API path/route
            params: Query parameters
            path_params: Path parameters for formatting
            **kwargs: Additional arguments passed to httpx (json, headers, etc.)

        Returns:
            Raw httpx.Response object
        """
        if path_params:
            path = path.format(**path_params)
        request = self._client.build_request(method, path, params=params, **kwargs)
        return await self._client.send(request)

    async def request(self, method: str, route: str, **kwargs: Any) -> Any:
        """
        Send a request and return the decoded JSON body.

        Returns:
            `None` for a 204 No Content response, otherwise `res.json()`.

        Raises:
            CloudUnauthorizedError: On HTTP 401 or 403.
            ObjectNotFound: On HTTP 404.
            httpx.HTTPStatusError: For any other error status.
        """
        try:
            res = await self._client.request(method, route, **kwargs)
            res.raise_for_status()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code in (
                status.HTTP_401_UNAUTHORIZED,
                status.HTTP_403_FORBIDDEN,
            ):
                raise CloudUnauthorizedError(str(exc)) from exc
            elif exc.response.status_code == status.HTTP_404_NOT_FOUND:
                raise ObjectNotFound(http_exc=exc) from exc
            else:
                raise

        if res.status_code == status.HTTP_204_NO_CONTENT:
            return

        return res.json()