Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/clients.py: 0%
124 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1from __future__ import annotations
3import base64
4from typing import TYPE_CHECKING, Any, Dict, List, Optional
5from urllib.parse import quote
6from uuid import UUID
8import httpx
9import pydantic
10from httpx import Response
11from starlette import status
12from typing_extensions import Self
14from prefect.client.base import PrefectHttpxAsyncClient
15from prefect.exceptions import ObjectNotFound
16from prefect.logging import get_logger
17from prefect.server.schemas.actions import DeploymentFlowRunCreate, StateCreate
18from prefect.server.schemas.core import WorkPool
19from prefect.server.schemas.filters import VariableFilter, VariableFilterName
20from prefect.server.schemas.responses import DeploymentResponse, OrchestrationResult
21from prefect.settings import get_current_settings
22from prefect.types import StrictVariableValue
24if TYPE_CHECKING:
25 import logging
# Module-level logger for this file. With postponed evaluation of annotations
# the annotation below is never evaluated at runtime, which is why `logging`
# only needs to be imported under TYPE_CHECKING.
logger: logging.Logger = get_logger(__name__)
class BaseClient:
    """Async HTTP client that talks to the in-process Prefect server API.

    Requests go through ``httpx.ASGITransport`` straight into the ASGI app
    returned by ``create_app()``, addressed via the placeholder host
    ``prefect-in-memory``.
    """

    _http_client: PrefectHttpxAsyncClient

    def __init__(self, additional_headers: dict[str, str] | None = None):
        """Build the underlying in-memory HTTP client.

        Args:
            additional_headers: Extra headers sent with every request. The
                mapping is copied, so the caller's dict is never modified.
                If a server auth string is configured and no
                ``Authorization`` header is supplied, a Basic auth header
                is added.
        """
        from prefect.server.api.server import create_app

        # create_app caches application instances, and invoking it with no
        # arguments will point it to the currently running server instance
        api_app = create_app()

        settings = get_current_settings()

        # we pull the auth string from _server_ settings because this client
        # is run on the server
        auth_string_secret = settings.server.api.auth_string
        auth_string = (
            auth_string_secret.get_secret_value() if auth_string_secret else None
        )
        headers = self._build_headers(additional_headers, auth_string)

        self._http_client = PrefectHttpxAsyncClient(
            transport=httpx.ASGITransport(app=api_app, raise_app_exceptions=False),
            headers=headers,
            base_url=f"http://prefect-in-memory{settings.server.api.base_path or '/api'}",
            enable_csrf_support=settings.server.api.csrf_protection_enabled,
            raise_on_all_errors=False,
        )

    @staticmethod
    def _build_headers(
        additional_headers: dict[str, str] | None, auth_string: str | None
    ) -> dict[str, str]:
        """Return a fresh header dict, adding Basic auth when configured.

        Args:
            additional_headers: Caller-supplied headers; copied, never mutated.
            auth_string: Plain-text ``user:password`` style auth string, or
                ``None``/empty when server auth is not configured.

        Returns:
            A new dict containing ``additional_headers`` plus, if
            ``auth_string`` is non-empty and no ``Authorization`` header is
            present, ``Authorization: Basic <base64(auth_string)>``.
        """
        # Copy so that setdefault below never writes into the caller's dict.
        headers = dict(additional_headers or {})
        if auth_string:
            token = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
            # An explicitly supplied Authorization header takes precedence.
            headers.setdefault("Authorization", f"Basic {token}")
        return headers

    async def __aenter__(self) -> Self:
        """Enter the underlying HTTP client and return this wrapper."""
        await self._http_client.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Exit the underlying HTTP client, forwarding exception info."""
        await self._http_client.__aexit__(*args)
class OrchestrationClient(BaseClient):
    """In-memory client for the orchestration API calls made server-side.

    Methods suffixed ``_raw`` return the ``httpx.Response`` without calling
    ``raise_for_status``; status handling is left to the caller.
    """

    async def read_deployment_raw(self, deployment_id: UUID) -> Response:
        """GET ``/deployments/{deployment_id}``."""
        return await self._http_client.get(f"/deployments/{deployment_id}")

    async def read_deployment(
        self, deployment_id: UUID
    ) -> Optional[DeploymentResponse]:
        """Read a deployment, returning ``None`` if the server answers 404.

        Raises:
            httpx.HTTPStatusError: For any non-404 error status.
        """
        try:
            response = await self.read_deployment_raw(deployment_id)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == status.HTTP_404_NOT_FOUND:
                return None
            raise
        return DeploymentResponse.model_validate(response.json())

    async def read_flow_raw(self, flow_id: UUID) -> Response:
        """GET ``/flows/{flow_id}``."""
        return await self._http_client.get(f"/flows/{flow_id}")

    async def create_flow_run(
        self, deployment_id: UUID, flow_run_create: DeploymentFlowRunCreate
    ) -> Response:
        """POST a new flow run for ``deployment_id``."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/create_flow_run",
            json=flow_run_create.model_dump(mode="json"),
        )

    async def read_flow_run_raw(self, flow_run_id: UUID) -> Response:
        """GET ``/flow_runs/{flow_run_id}``."""
        return await self._http_client.get(f"/flow_runs/{flow_run_id}")

    async def read_task_run_raw(self, task_run_id: UUID) -> Response:
        """GET ``/task_runs/{task_run_id}``."""
        return await self._http_client.get(f"/task_runs/{task_run_id}")

    async def resume_flow_run(self, flow_run_id: UUID) -> OrchestrationResult:
        """Resume a flow run and return the parsed orchestration result.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        response = await self._http_client.post(
            f"/flow_runs/{flow_run_id}/resume",
        )
        response.raise_for_status()
        return OrchestrationResult.model_validate(response.json())

    async def pause_deployment(self, deployment_id: UUID) -> Response:
        """POST ``/deployments/{deployment_id}/pause_deployment``."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/pause_deployment",
        )

    async def resume_deployment(self, deployment_id: UUID) -> Response:
        """POST ``/deployments/{deployment_id}/resume_deployment``."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/resume_deployment",
        )

    async def set_flow_run_state(
        self, flow_run_id: UUID, state: StateCreate
    ) -> Response:
        """Propose ``state`` for a flow run (sent with ``force=False``)."""
        return await self._http_client.post(
            f"/flow_runs/{flow_run_id}/set_state",
            json={
                "state": state.model_dump(mode="json"),
                "force": False,
            },
        )

    async def pause_work_pool(self, work_pool_name: str) -> Response:
        """PATCH the named work pool with ``is_paused=True``."""
        return await self._http_client.patch(
            f"/work_pools/{quote(work_pool_name)}", json={"is_paused": True}
        )

    async def resume_work_pool(self, work_pool_name: str) -> Response:
        """PATCH the named work pool with ``is_paused=False``."""
        return await self._http_client.patch(
            f"/work_pools/{quote(work_pool_name)}", json={"is_paused": False}
        )

    async def read_work_pool_raw(self, work_pool_id: UUID) -> Response:
        """Filter work pools by a single id; the response body is a list."""
        return await self._http_client.post(
            "/work_pools/filter",
            json={"work_pools": {"id": {"any_": [str(work_pool_id)]}}},
        )

    async def read_work_pool(self, work_pool_id: UUID) -> Optional[WorkPool]:
        """Return the work pool with ``work_pool_id``, or ``None`` if absent.

        Raises:
            httpx.HTTPStatusError: If the filter request returns an error status.
        """
        response = await self.read_work_pool_raw(work_pool_id)
        response.raise_for_status()

        pools = pydantic.TypeAdapter(List[WorkPool]).validate_python(response.json())
        return pools[0] if pools else None

    async def read_work_queue_raw(self, work_queue_id: UUID) -> Response:
        """GET ``/work_queues/{work_queue_id}``."""
        return await self._http_client.get(f"/work_queues/{work_queue_id}")

    async def read_work_queue_status_raw(self, work_queue_id: UUID) -> Response:
        """GET ``/work_queues/{work_queue_id}/status``."""
        return await self._http_client.get(f"/work_queues/{work_queue_id}/status")

    async def pause_work_queue(self, work_queue_id: UUID) -> Response:
        """PATCH the work queue with ``is_paused=True``."""
        return await self._http_client.patch(
            f"/work_queues/{work_queue_id}",
            json={"is_paused": True},
        )

    async def resume_work_queue(self, work_queue_id: UUID) -> Response:
        """PATCH the work queue with ``is_paused=False``."""
        return await self._http_client.patch(
            f"/work_queues/{work_queue_id}",
            json={"is_paused": False},
        )

    async def read_block_document_raw(
        self,
        block_document_id: UUID,
        include_secrets: bool = True,
    ) -> Response:
        """GET a block document; secrets are included by default."""
        return await self._http_client.get(
            f"/block_documents/{block_document_id}",
            params=dict(include_secrets=include_secrets),
        )

    # Page size for /variables/filter requests and the overall read cap.
    VARIABLE_PAGE_SIZE = 200
    MAX_VARIABLES_PER_WORKSPACE = 1000

    async def read_workspace_variables(
        self, names: Optional[List[str]] = None
    ) -> Dict[str, StrictVariableValue]:
        """Return a ``{name: value}`` mapping of workspace variables.

        Args:
            names: Restrict the result to these variable names. ``None`` reads
                all variables; an empty list returns ``{}`` without a request.

        Returns:
            At most ``MAX_VARIABLES_PER_WORKSPACE`` variables, fetched in
            pages of ``VARIABLE_PAGE_SIZE``.

        Raises:
            httpx.HTTPStatusError: If any page request returns an error status.
        """
        variables: Dict[str, StrictVariableValue] = {}

        if names is not None and not names:
            # An explicit empty list of names can never match anything.
            return variables

        variable_filter = VariableFilter()
        if names is not None:
            variable_filter.name = VariableFilterName(any_=list(set(names)))

        for offset in range(
            0, self.MAX_VARIABLES_PER_WORKSPACE, self.VARIABLE_PAGE_SIZE
        ):
            response = await self._http_client.post(
                "/variables/filter",
                json={
                    # mode="json" matches the other request bodies in this class
                    "variables": variable_filter.model_dump(mode="json"),
                    "limit": self.VARIABLE_PAGE_SIZE,
                    "offset": offset,
                },
            )
            if response.status_code >= 300:
                response.raise_for_status()

            results = response.json()
            for variable in results:
                variables[variable["name"]] = variable["value"]

            # A short page means there is nothing left to fetch.
            if len(results) < self.VARIABLE_PAGE_SIZE:
                break

        return variables

    async def read_concurrency_limit_v2_raw(
        self, concurrency_limit_id: UUID
    ) -> Response:
        """GET ``/v2/concurrency_limits/{concurrency_limit_id}``."""
        return await self._http_client.get(
            f"/v2/concurrency_limits/{concurrency_limit_id}"
        )
class WorkPoolsOrchestrationClient(BaseClient):
    """In-memory API client for reading work pools by name.

    NOTE(review): ``__aenter__`` is overridden to return ``self`` without
    entering the underlying HTTP client, while ``__aexit__`` is inherited from
    ``BaseClient`` and still exits it. Confirm this asymmetry is intentional.
    """

    async def __aenter__(self) -> Self:
        return self

    async def read_work_pool(self, work_pool_name: str) -> WorkPool:
        """
        Reads information for a given work pool

        Args:
            work_pool_name: The name of the work pool for which to get
                information. It is percent-encoded into the URL path, matching
                the work-pool calls on ``OrchestrationClient``.

        Returns:
            Information about the requested work pool.

        Raises:
            ObjectNotFound: If the server responds with 404.
            httpx.HTTPStatusError: For any other error status.
        """
        try:
            # Quote the name so characters such as '?' or '#' stay in the path.
            response = await self._http_client.get(
                f"/work_pools/{quote(work_pool_name)}"
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == status.HTTP_404_NOT_FOUND:
                raise ObjectNotFound(http_exc=e) from e
            raise
        return WorkPool.model_validate(response.json())