Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/clients.py: 0%

124 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1from __future__ import annotations 

2 

3import base64 

4from typing import TYPE_CHECKING, Any, Dict, List, Optional 

5from urllib.parse import quote 

6from uuid import UUID 

7 

8import httpx 

9import pydantic 

10from httpx import Response 

11from starlette import status 

12from typing_extensions import Self 

13 

14from prefect.client.base import PrefectHttpxAsyncClient 

15from prefect.exceptions import ObjectNotFound 

16from prefect.logging import get_logger 

17from prefect.server.schemas.actions import DeploymentFlowRunCreate, StateCreate 

18from prefect.server.schemas.core import WorkPool 

19from prefect.server.schemas.filters import VariableFilter, VariableFilterName 

20from prefect.server.schemas.responses import DeploymentResponse, OrchestrationResult 

21from prefect.settings import get_current_settings 

22from prefect.types import StrictVariableValue 

23 

if TYPE_CHECKING:
    import logging

# Module-scoped logger keyed to this module's dotted path.
logger: "logging.Logger" = get_logger(__name__)

28 

29 

class BaseClient:
    """Async HTTP client bound to the in-process Prefect API ASGI application.

    Requests are dispatched directly to the currently running server app via
    an ``httpx.ASGITransport`` (no network hop). Subclasses issue requests
    through ``self._http_client``.
    """

    _http_client: PrefectHttpxAsyncClient

    def __init__(self, additional_headers: dict[str, str] | None = None):
        from prefect.server.api.server import create_app

        # Copy the caller's dict (if any) so the setdefault below never
        # mutates an object owned by the caller.
        headers: dict[str, str] = dict(additional_headers or {})

        # create_app caches application instances, and invoking it with no
        # arguments will point it to the currently running server instance
        api_app = create_app()

        settings = get_current_settings()

        # we pull the auth string from _server_ settings because this client is run on the server
        if auth_string_secret := settings.server.api.auth_string:
            if auth_string := auth_string_secret.get_secret_value():
                token = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
                headers.setdefault("Authorization", f"Basic {token}")

        self._http_client = PrefectHttpxAsyncClient(
            transport=httpx.ASGITransport(app=api_app, raise_app_exceptions=False),
            headers=headers,
            base_url=f"http://prefect-in-memory{settings.server.api.base_path or '/api'}",
            enable_csrf_support=settings.server.api.csrf_protection_enabled,
            raise_on_all_errors=False,
        )

    async def __aenter__(self) -> Self:
        await self._http_client.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self._http_client.__aexit__(*args)

64 

65 

class OrchestrationClient(BaseClient):
    """Client for orchestration endpoints of the in-process Prefect API.

    ``*_raw`` methods return the raw ``httpx.Response`` so callers can
    inspect status codes themselves; other methods validate payloads into
    schema objects or raise on HTTP errors.
    """

    async def read_deployment_raw(self, deployment_id: UUID) -> Response:
        return await self._http_client.get(f"/deployments/{deployment_id}")

    async def read_deployment(
        self, deployment_id: UUID
    ) -> Optional[DeploymentResponse]:
        """Return the deployment, or ``None`` if it does not exist (404)."""
        try:
            response = await self.read_deployment_raw(deployment_id)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == status.HTTP_404_NOT_FOUND:
                return None
            raise
        return DeploymentResponse.model_validate(response.json())

    async def read_flow_raw(self, flow_id: UUID) -> Response:
        return await self._http_client.get(f"/flows/{flow_id}")

    async def create_flow_run(
        self, deployment_id: UUID, flow_run_create: DeploymentFlowRunCreate
    ) -> Response:
        """Create a flow run from a deployment; returns the raw response."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/create_flow_run",
            json=flow_run_create.model_dump(mode="json"),
        )

    async def read_flow_run_raw(self, flow_run_id: UUID) -> Response:
        return await self._http_client.get(f"/flow_runs/{flow_run_id}")

    async def read_task_run_raw(self, task_run_id: UUID) -> Response:
        return await self._http_client.get(f"/task_runs/{task_run_id}")

    async def resume_flow_run(self, flow_run_id: UUID) -> OrchestrationResult:
        """Resume a paused flow run; raises on any HTTP error."""
        response = await self._http_client.post(
            f"/flow_runs/{flow_run_id}/resume",
        )
        response.raise_for_status()
        return OrchestrationResult.model_validate(response.json())

    async def pause_deployment(self, deployment_id: UUID) -> Response:
        return await self._http_client.post(
            f"/deployments/{deployment_id}/pause_deployment",
        )

    async def resume_deployment(self, deployment_id: UUID) -> Response:
        return await self._http_client.post(
            f"/deployments/{deployment_id}/resume_deployment",
        )

    async def set_flow_run_state(
        self, flow_run_id: UUID, state: StateCreate
    ) -> Response:
        """Propose a new state for a flow run (never forced)."""
        return await self._http_client.post(
            f"/flow_runs/{flow_run_id}/set_state",
            json={
                "state": state.model_dump(mode="json"),
                "force": False,
            },
        )

    async def pause_work_pool(self, work_pool_name: str) -> Response:
        # Name is user-supplied; quote it so special characters survive the URL.
        return await self._http_client.patch(
            f"/work_pools/{quote(work_pool_name)}", json={"is_paused": True}
        )

    async def resume_work_pool(self, work_pool_name: str) -> Response:
        return await self._http_client.patch(
            f"/work_pools/{quote(work_pool_name)}", json={"is_paused": False}
        )

    async def read_work_pool_raw(self, work_pool_id: UUID) -> Response:
        # Work pools are only addressable by name in the GET route, so we
        # filter by id instead.
        return await self._http_client.post(
            "/work_pools/filter",
            json={"work_pools": {"id": {"any_": [str(work_pool_id)]}}},
        )

    async def read_work_pool(self, work_pool_id: UUID) -> Optional[WorkPool]:
        """Return the work pool with the given id, or ``None`` if absent."""
        response = await self.read_work_pool_raw(work_pool_id)
        response.raise_for_status()

        pools = pydantic.TypeAdapter(List[WorkPool]).validate_python(response.json())
        return pools[0] if pools else None

    async def read_work_queue_raw(self, work_queue_id: UUID) -> Response:
        return await self._http_client.get(f"/work_queues/{work_queue_id}")

    async def read_work_queue_status_raw(self, work_queue_id: UUID) -> Response:
        return await self._http_client.get(f"/work_queues/{work_queue_id}/status")

    async def pause_work_queue(self, work_queue_id: UUID) -> Response:
        return await self._http_client.patch(
            f"/work_queues/{work_queue_id}",
            json={"is_paused": True},
        )

    async def resume_work_queue(self, work_queue_id: UUID) -> Response:
        return await self._http_client.patch(
            f"/work_queues/{work_queue_id}",
            json={"is_paused": False},
        )

    async def read_block_document_raw(
        self,
        block_document_id: UUID,
        include_secrets: bool = True,
    ) -> Response:
        return await self._http_client.get(
            f"/block_documents/{block_document_id}",
            params=dict(include_secrets=include_secrets),
        )

    # Pagination parameters for read_workspace_variables.
    VARIABLE_PAGE_SIZE = 200
    MAX_VARIABLES_PER_WORKSPACE = 1000

    async def read_workspace_variables(
        self, names: Optional[List[str]] = None
    ) -> Dict[str, StrictVariableValue]:
        """Read workspace variables as a ``{name: value}`` mapping.

        Args:
            names: When provided, only variables with these names are
                fetched; an empty list short-circuits to an empty result.

        Returns:
            Mapping of variable name to value, capped at
            ``MAX_VARIABLES_PER_WORKSPACE`` entries.
        """
        variables: Dict[str, StrictVariableValue] = {}

        if names is not None and not names:
            return variables

        # Avoid shadowing the builtin `filter`.
        variable_filter = VariableFilter()
        if names is not None:
            variable_filter.name = VariableFilterName(any_=list(set(names)))

        for offset in range(
            0, self.MAX_VARIABLES_PER_WORKSPACE, self.VARIABLE_PAGE_SIZE
        ):
            response = await self._http_client.post(
                "/variables/filter",
                json={
                    "variables": variable_filter.model_dump(),
                    "limit": self.VARIABLE_PAGE_SIZE,
                    "offset": offset,
                },
            )
            # Guarded: httpx's raise_for_status also raises for 1xx/3xx;
            # only treat >= 300 as a failure here.
            if response.status_code >= 300:
                response.raise_for_status()

            results = response.json()
            for variable in results:
                variables[variable["name"]] = variable["value"]

            # A short page means we've drained the results.
            if len(results) < self.VARIABLE_PAGE_SIZE:
                break

        return variables

    async def read_concurrency_limit_v2_raw(
        self, concurrency_limit_id: UUID
    ) -> Response:
        return await self._http_client.get(
            f"/v2/concurrency_limits/{concurrency_limit_id}"
        )

224 

225 

class WorkPoolsOrchestrationClient(BaseClient):
    """Client for work pool endpoints of the in-process Prefect API."""

    async def __aenter__(self) -> Self:
        # NOTE(review): unlike BaseClient, this does not enter the underlying
        # http client's context — presumably intentional for short-lived use;
        # confirm before "fixing".
        return self

    async def read_work_pool(self, work_pool_name: str) -> WorkPool:
        """
        Reads information for a given work pool.

        Args:
            work_pool_name: The name of the work pool for which to get
                information.

        Returns:
            Information about the requested work pool.

        Raises:
            ObjectNotFound: If no work pool with the given name exists.
        """
        try:
            # Quote the name for URL safety, consistent with the
            # pause/resume work pool routes.
            response = await self._http_client.get(
                f"/work_pools/{quote(work_pool_name)}"
            )
            response.raise_for_status()
            return WorkPool.model_validate(response.json())
        except httpx.HTTPStatusError as e:
            if e.response.status_code == status.HTTP_404_NOT_FOUND:
                raise ObjectNotFound(http_exc=e) from e
            else:
                raise