Coverage for /usr/local/lib/python3.12/site-packages/prefect/workers/base.py: 21% of 674 statements (coverage.py v7.10.6, created at 2025-12-05 13:38 +0000)

from __future__ import annotations

import abc
import asyncio
import datetime
import threading
import uuid
import warnings
from contextlib import AsyncExitStack
from functools import partial
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Optional,
    Type,
)
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo

import anyio
import anyio.abc
import httpx
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
from importlib_metadata import (
    distributions,  # type: ignore[reportUnknownVariableType] incomplete typing
)
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Literal, Self, TypeVar

import prefect
import prefect.types._datetime
from prefect._internal.compatibility.deprecated import PrefectDeprecationWarning
from prefect._internal.schemas.validators import return_v_or_none
from prefect.client.base import ServerType
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas.actions import WorkPoolCreate, WorkPoolUpdate
from prefect.client.schemas.objects import Flow as APIFlow
from prefect.client.schemas.objects import (
    Integration,
    StateType,
    WorkerMetadata,
    WorkPool,
)
from prefect.client.utilities import inject_client
from prefect.context import FlowRunContext, TagsContext
from prefect.events import Event, RelatedResource, emit_event
from prefect.events.related import object_as_related_resource, tags_as_related_resources
from prefect.exceptions import (
    Abort,
    ObjectNotFound,
)
from prefect.filesystems import LocalFileSystem
from prefect.futures import PrefectFlowRunFuture
from prefect.logging.loggers import (
    PrefectLogAdapter,
    flow_run_logger,
    get_worker_logger,
)
from prefect.plugins import load_prefect_collections
from prefect.settings import (
    PREFECT_API_URL,
    PREFECT_TEST_MODE,
    PREFECT_WORKER_HEARTBEAT_SECONDS,
    PREFECT_WORKER_PREFETCH_SECONDS,
    PREFECT_WORKER_QUERY_SECONDS,
    get_current_settings,
)
from prefect.states import (
    Cancelled,
    Crashed,
    Pending,
    exception_to_failed_state,
)
from prefect.tasks import Task
from prefect.types import KeyValueLabels
from prefect.utilities.dispatch import get_registry_for_type, register_base_type
from prefect.utilities.engine import propose_state
from prefect.utilities.services import (
    critical_service_loop,
    start_client_metrics_server,
    stop_client_metrics_server,
)
from prefect.utilities.slugify import slugify
from prefect.utilities.templating import (
    apply_values,
    resolve_block_document_references,
    resolve_variables,
)
from prefect.utilities.urls import url_for

if TYPE_CHECKING:
    from prefect.client.schemas.objects import FlowRun
    from prefect.client.schemas.responses import (
        DeploymentResponse,
        WorkerFlowRunResponse,
    )
    from prefect.flows import Flow
class BaseJobConfiguration(BaseModel):
    command: Optional[str] = Field(
        default=None,
        description=(
            "The command to use when starting a flow run. "
            "In most cases, this should be left blank and the command "
            "will be automatically generated by the worker."
        ),
    )
    env: dict[str, Optional[str]] = Field(
        default_factory=dict,
        title="Environment Variables",
        description="Environment variables to set when starting a flow run.",
    )
    labels: dict[str, str] = Field(
        default_factory=dict,
        description=(
            "Labels applied to infrastructure created by the worker using "
            "this job configuration."
        ),
    )
    name: Optional[str] = Field(
        default=None,
        description=(
            "Name given to infrastructure created by the worker using this "
            "job configuration."
        ),
    )

    _related_objects: dict[str, Any] = PrivateAttr(default_factory=dict)

    @property
    def is_using_a_runner(self) -> bool:
        return self.command is not None and "prefect flow-run execute" in self.command

    @field_validator("command")
    @classmethod
    def _coerce_command(cls, v: str | None) -> str | None:
        return return_v_or_none(v)

    @field_validator("env", mode="before")
    @classmethod
    def _coerce_env(cls, v: dict[str, Any]) -> dict[str, str | None]:
        return {k: str(v) if v is not None else None for k, v in v.items()}

    @staticmethod
    def _get_base_config_defaults(variables: dict[str, Any]) -> dict[str, Any]:
        """Get default values from base config for all variables that have them."""
        defaults: dict[str, Any] = {}
        for variable_name, attrs in variables.items():
            # We remove `None` values because we don't want to use them in templating.
            # The current logic depends on keys not existing to populate the correct
            # value in some cases.
            # Pydantic will provide default values if the keys are missing when creating
            # a configuration class.
            if "default" in attrs and attrs.get("default") is not None:
                defaults[variable_name] = attrs["default"]

        return defaults

    @classmethod
    @inject_client
    async def from_template_and_values(
        cls,
        base_job_template: dict[str, Any],
        values: dict[str, Any],
        client: "PrefectClient | None" = None,
    ):
        """Creates a valid worker configuration object from the provided base
        configuration and overrides.

        Important: this method expects that the base_job_template was already
        validated server-side.
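
        Example (a minimal sketch; the template and values shown here are
        hypothetical):
        ```python
        config = await BaseJobConfiguration.from_template_and_values(
            base_job_template={
                "job_configuration": {"command": "{{ command }}"},
                "variables": {"properties": {"command": {"default": None}}},
            },
            values={"command": "prefect flow-run execute"},
        )
        assert config.command == "prefect flow-run execute"
        ```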
        """
        base_config: dict[str, Any] = base_job_template["job_configuration"]
        variables_schema = base_job_template["variables"]
        variables = cls._get_base_config_defaults(
            variables_schema.get("properties", {})
        )

        # merge variable defaults for `env` into base config before they're replaced by
        # deployment overrides
        if variables.get("env"):
            if isinstance(base_config.get("env"), dict):
                # Merge: preserve env vars from job_configuration
                base_config["env"] = base_config["env"] | variables.get("env")
            else:
                # Replace template with defaults
                base_config["env"] = variables.get("env")

        variables.update(values)

        # deep merge `env`
        if isinstance(base_config.get("env"), dict) and (
            deployment_env := variables.get("env")
        ):
            base_config["env"] = base_config.get("env") | deployment_env

        populated_configuration = apply_values(template=base_config, values=variables)
        populated_configuration = await resolve_block_document_references(
            template=populated_configuration, client=client
        )
        populated_configuration = await resolve_variables(
            template=populated_configuration, client=client
        )
        return cls(**populated_configuration)

    @classmethod
    def json_template(cls) -> dict[str, Any]:
        """Returns a dict with job configuration as keys and the corresponding templates as values

        Defaults to using the job configuration parameter name as the template variable name.

        e.g.
        ```python
        {
            key1: '{{ key1 }}', # default variable template
            key2: '{{ template2 }}', # `template2` specifically provided as the template
        }
        ```
        """
        configuration: dict[str, Any] = {}
        properties = cls.model_json_schema()["properties"]
        for k, v in properties.items():
            if v.get("template"):
                template = v["template"]
            else:
                template = "{{ " + k + " }}"
            configuration[k] = template

        return configuration

    def prepare_for_flow_run(
        self,
        flow_run: "FlowRun",
        deployment: "DeploymentResponse | None" = None,
        flow: "APIFlow | None" = None,
        work_pool: "WorkPool | None" = None,
        worker_name: str | None = None,
    ) -> None:
        """
        Prepare the job configuration for a flow run.

        This method is called by the worker before starting a flow run. It
        should be used to set any configuration values that are dependent on
        the flow run.

        Args:
            flow_run: The flow run to be executed.
            deployment: The deployment that the flow run is associated with.
            flow: The flow that the flow run is associated with.
            work_pool: The work pool that the flow run is running in.
            worker_name: The name of the worker that is submitting the flow run.
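
        Example (a hedged sketch; `config` and `flow_run` are assumed to come
        from the caller):
        ```python
        config.prepare_for_flow_run(flow_run=flow_run, worker_name="my-worker")
        # config.env, config.labels, config.name, and config.command are now
        # populated with flow-run-specific values
        ```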
        """

        self._related_objects = {
            "deployment": deployment,
            "flow": flow,
            "flow-run": flow_run,
        }

        env = {
            **self._base_environment(),
            **self._base_flow_run_environment(flow_run),
            **(self.env if isinstance(self.env, dict) else {}),  # pyright: ignore[reportUnnecessaryIsInstance]
        }
        self.env = {key: value for key, value in env.items() if value is not None}
        self.labels = {
            **self._base_flow_run_labels(flow_run),
            **self._base_work_pool_labels(work_pool),
            **self._base_worker_name_label(worker_name),
            **self._base_flow_labels(flow),
            **self._base_deployment_labels(deployment),
            **self.labels,
        }
        self.name = self.name or flow_run.name
        self.command = self.command or self._base_flow_run_command()

    @staticmethod
    def _base_flow_run_command() -> str:
        """
        Generate a command for a flow run job.
        """
        return "prefect flow-run execute"

    @staticmethod
    def _base_flow_run_labels(flow_run: "FlowRun") -> dict[str, str]:
        """
        Generate a dictionary of labels for a flow run job.
        """
        return {
            "prefect.io/flow-run-id": str(flow_run.id),
            "prefect.io/flow-run-name": flow_run.name,
            "prefect.io/version": prefect.__version__,
        }

    @classmethod
    def _base_environment(cls) -> dict[str, str]:
        """
        Environment variables that should be passed to all created infrastructure.

        These values should be overridable with the `env` field.
        """
        return get_current_settings().to_environment_variables(exclude_unset=True)

    @staticmethod
    def _base_flow_run_environment(flow_run: "FlowRun | None") -> dict[str, str]:
        """
        Generate a dictionary of environment variables for a flow run job.
        """
        if flow_run is None:
            return {}

        return {"PREFECT__FLOW_RUN_ID": str(flow_run.id)}

    @staticmethod
    def _base_deployment_labels(
        deployment: "DeploymentResponse | None",
    ) -> dict[str, str]:
        if deployment is None:
            return {}

        labels = {
            "prefect.io/deployment-id": str(deployment.id),
            "prefect.io/deployment-name": deployment.name,
        }
        if deployment.updated is not None:
            labels["prefect.io/deployment-updated"] = deployment.updated.astimezone(
                ZoneInfo("UTC")
            ).isoformat()
        return labels

    @staticmethod
    def _base_flow_labels(flow: "APIFlow | None") -> dict[str, str]:
        if flow is None:
            return {}

        return {
            "prefect.io/flow-id": str(flow.id),
            "prefect.io/flow-name": flow.name,
        }

    @staticmethod
    def _base_work_pool_labels(work_pool: "WorkPool | None") -> dict[str, str]:
        """Adds the work pool labels to the job manifest."""
        if work_pool is None:
            return {}

        return {
            "prefect.io/work-pool-name": work_pool.name,
            "prefect.io/work-pool-id": str(work_pool.id),
        }

    @staticmethod
    def _base_worker_name_label(worker_name: str | None) -> dict[str, str]:
        """Adds the worker name label to the job manifest."""
        if worker_name is None:
            return {}

        return {"prefect.io/worker-name": worker_name}

    def _related_resources(self) -> list[RelatedResource]:
        tags: set[str] = set()
        related: list[RelatedResource] = []

        for kind, obj in self._related_objects.items():
            if obj is None:
                continue
            if hasattr(obj, "tags"):
                tags.update(obj.tags)
            related.append(object_as_related_resource(kind=kind, role=kind, object=obj))

        return related + tags_as_related_resources(tags)


class BaseVariables(BaseModel):
    name: Optional[str] = Field(
        default=None,
        description="Name given to infrastructure created by a worker.",
    )
    env: dict[str, Optional[str]] = Field(
        default_factory=dict,
        title="Environment Variables",
        description="Environment variables to set when starting a flow run.",
    )
    labels: dict[str, str] = Field(
        default_factory=dict,
        description="Labels applied to infrastructure created by a worker.",
    )
    command: Optional[str] = Field(
        default=None,
        description=(
            "The command to use when starting a flow run. "
            "In most cases, this should be left blank and the command "
            "will be automatically generated by the worker."
        ),
    )

    @classmethod
    def model_json_schema(
        cls,
        by_alias: bool = True,
        ref_template: str = "#/definitions/{model}",
        schema_generator: Type[GenerateJsonSchema] = GenerateJsonSchema,
        mode: Literal["validation", "serialization"] = "validation",
        *,
        union_format: Literal["any_of", "primitive_type_array"] = "any_of",
    ) -> dict[str, Any]:
        """TODO: stop overriding this method - use GenerateSchema in ConfigDict instead?"""
        schema = super().model_json_schema(
            by_alias, ref_template, schema_generator, mode
        )

        # ensure backwards compatibility by copying $defs into definitions
        if "$defs" in schema:
            schema["definitions"] = schema.pop("$defs")

        # we aren't expecting these additional fields in the schema
        if "additionalProperties" in schema:
            schema.pop("additionalProperties")

        for _, definition in schema.get("definitions", {}).items():
            if "additionalProperties" in definition:
                definition.pop("additionalProperties")

        return schema


class BaseWorkerResult(BaseModel, abc.ABC):
    identifier: str
    status_code: int

    def __bool__(self) -> bool:
        return self.status_code == 0


C = TypeVar("C", bound=BaseJobConfiguration)
V = TypeVar("V", bound=BaseVariables)
R = TypeVar("R", bound=BaseWorkerResult)
FR = TypeVar("FR")  # used to capture the return type of a flow


@register_base_type
class BaseWorker(abc.ABC, Generic[C, V, R]):
    type: str
    job_configuration: Type[C] = BaseJobConfiguration  # type: ignore
    job_configuration_variables: Optional[Type[V]] = None

    _documentation_url = ""
    _logo_url = ""
    _description = ""

    def __init__(
        self,
        work_pool_name: str,
        work_queues: list[str] | None = None,
        name: str | None = None,
        prefetch_seconds: float | None = None,
        create_pool_if_not_found: bool = True,
        limit: int | None = None,
        heartbeat_interval_seconds: int | None = None,
        *,
        base_job_template: dict[str, Any] | None = None,
    ):
        """
        Base class for all Prefect workers.

        Args:
            name: The name of the worker. If not provided, a random one
                will be generated. If provided, it cannot contain '/' or '%'.
                The name is used to identify the worker in the UI; if two
                processes have the same name, they will be treated as the same
                worker.
            work_pool_name: The name of the work pool to poll.
            work_queues: A list of work queues to poll. If not provided, all
                work queues in the work pool will be polled.
            prefetch_seconds: The number of seconds to prefetch flow runs for.
            create_pool_if_not_found: Whether to create the work pool
                if it is not found. Defaults to `True`, but can be set to `False` to
                ensure that work pools are not created accidentally.
            limit: The maximum number of flow runs this worker should be running at
                a given time.
            heartbeat_interval_seconds: The number of seconds between worker heartbeats.
            base_job_template: If creating the work pool, provide the base job
                template to use. Logs a warning if the pool already exists.
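
        Example (a minimal sketch of a subclass; `EchoWorker` and its behavior
        are hypothetical):
        ```python
        class EchoWorkerResult(BaseWorkerResult):
            pass

        class EchoWorker(BaseWorker):
            type = "echo"

            async def run(self, flow_run, configuration, task_status=None):
                # A real worker would create infrastructure here and report
                # its identifier via task_status.started(...)
                if task_status:
                    task_status.started(0)
                return EchoWorkerResult(identifier=str(flow_run.id), status_code=0)
        ```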
        """
        if name and ("/" in name or "%" in name):
            raise ValueError("Worker name cannot contain '/' or '%'")
        self.name: str = name or f"{self.__class__.__name__} {uuid4()}"
        self._started_event: Optional[Event] = None
        self.backend_id: Optional[UUID] = None
        self._logger = get_worker_logger(self)

        self.is_setup = False
        self._create_pool_if_not_found = create_pool_if_not_found
        self._base_job_template = base_job_template
        self._work_pool_name = work_pool_name
        self._work_queues: set[str] = set(work_queues) if work_queues else set()

        self._prefetch_seconds: float = (
            prefetch_seconds or PREFECT_WORKER_PREFETCH_SECONDS.value()
        )
        self.heartbeat_interval_seconds: int = (
            heartbeat_interval_seconds or PREFECT_WORKER_HEARTBEAT_SECONDS.value()
        )

        self._work_pool: Optional[WorkPool] = None
        self._exit_stack: AsyncExitStack = AsyncExitStack()
        self._runs_task_group: Optional[anyio.abc.TaskGroup] = None
        self._client: Optional[PrefectClient] = None
        self._last_polled_time: datetime.datetime = prefect.types._datetime.now("UTC")
        self._limit = limit
        self._limiter: Optional[anyio.CapacityLimiter] = None
        self._submitting_flow_run_ids: set[UUID] = set()
        self._cancelling_flow_run_ids: set[UUID] = set()
        self._scheduled_task_scopes: set[anyio.CancelScope] = set()
        self._worker_metadata_sent = False

    @property
    def client(self) -> PrefectClient:
        if self._client is None:
            raise RuntimeError(
                "Worker has not been correctly initialized. Please use the worker class as an async context manager."
            )
        return self._client

    @property
    def work_pool(self) -> WorkPool:
        if self._work_pool is None:
            raise RuntimeError(
                "Worker has not been correctly initialized. Please use the worker class as an async context manager."
            )
        return self._work_pool

    @property
    def limiter(self) -> anyio.CapacityLimiter:
        if self._limiter is None:
            raise RuntimeError(
                "Worker has not been correctly initialized. Please use the worker class as an async context manager."
            )
        return self._limiter

    @classmethod
    def get_documentation_url(cls) -> str:
        return cls._documentation_url

    @classmethod
    def get_logo_url(cls) -> str:
        return cls._logo_url

    @classmethod
    def get_description(cls) -> str:
        return cls._description

    @classmethod
    def get_default_base_job_template(cls) -> dict[str, Any]:
        if cls.job_configuration_variables is None:
            schema = cls.job_configuration.model_json_schema()
            # remove "template" key from all dicts in schema['properties'] because it is not a
            # relevant field
            for key, value in schema["properties"].items():
                if isinstance(value, dict):
                    schema["properties"][key].pop("template", None)
            variables_schema = schema
        else:
            variables_schema = cls.job_configuration_variables.model_json_schema()
        variables_schema.pop("title", None)
        return {
            "job_configuration": cls.job_configuration.json_template(),
            "variables": variables_schema,
        }

    @staticmethod
    def get_worker_class_from_type(
        type: str,
    ) -> Optional[Type["BaseWorker[Any, Any, Any]"]]:
        """
        Returns the worker class for a given worker type. If the worker type
        is not recognized, returns None.
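
        Example (assuming the built-in process worker type is registered):
        ```python
        worker_cls = BaseWorker.get_worker_class_from_type("process")
        ```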
        """
        load_prefect_collections()
        worker_registry = get_registry_for_type(BaseWorker)
        if worker_registry is not None:
            return worker_registry.get(type)

    @staticmethod
    def get_all_available_worker_types() -> list[str]:
        """
        Returns all worker types available in the local registry.
        """
        load_prefect_collections()
        worker_registry = get_registry_for_type(BaseWorker)
        if worker_registry is not None:
            return list(worker_registry.keys())
        return []

    def get_name_slug(self) -> str:
        return slugify(self.name)

    def get_flow_run_logger(self, flow_run: "FlowRun") -> PrefectLogAdapter:
        extra = {
            "worker_name": self.name,
            "work_pool_name": (
                self._work_pool_name if self._work_pool else "<unknown>"
            ),
            "work_pool_id": str(getattr(self._work_pool, "id", "unknown")),
        }
        if self.backend_id:
            extra["worker_id"] = str(self.backend_id)

        return flow_run_logger(flow_run=flow_run).getChild(
            "worker",
            extra=extra,
        )

    async def start(
        self,
        run_once: bool = False,
        with_healthcheck: bool = False,
        printer: Callable[..., None] = print,
    ) -> None:
        """
        Starts the worker and runs the main worker loops.

        By default, the worker will run loops to poll for scheduled/cancelled flow
        runs and sync with the Prefect API server.

        If `run_once` is set, the worker will only run each loop once and then return.

        If `with_healthcheck` is set, the worker will start a healthcheck server which
        can be used to determine if the worker is still polling for flow runs and restart
        the worker if necessary.

        Args:
            run_once: If set, the worker will only run each loop once then return.
            with_healthcheck: If set, the worker will start a healthcheck server.
            printer: A `print`-like function where logs will be reported.
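
        Example (a minimal sketch; `MyWorker` is a hypothetical subclass):
        ```python
        worker = MyWorker(work_pool_name="my-pool")
        await worker.start(run_once=True)  # poll once, then shut down
        ```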
        """
        healthcheck_server = None
        healthcheck_thread = None
        try:
            async with self as worker:
                # schedule the scheduled flow run polling loop
                async with anyio.create_task_group() as loops_task_group:
                    loops_task_group.start_soon(
                        partial(
                            critical_service_loop,
                            workload=self.get_and_submit_flow_runs,
                            interval=PREFECT_WORKER_QUERY_SECONDS.value(),
                            run_once=run_once,
                            jitter_range=0.3,
                            backoff=4,  # Up to ~1 minute interval during backoff
                        )
                    )
                    # schedule the sync loop
                    loops_task_group.start_soon(
                        partial(
                            critical_service_loop,
                            workload=self.sync_with_backend,
                            interval=self.heartbeat_interval_seconds,
                            run_once=run_once,
                            jitter_range=0.3,
                            backoff=4,
                        )
                    )

                    self._started_event = await self._emit_worker_started_event()

                    start_client_metrics_server()

                    if with_healthcheck:
                        from prefect.workers.server import build_healthcheck_server

                        # we'll start the ASGI server in a separate thread so that
                        # uvicorn does not block the main thread
                        healthcheck_server = build_healthcheck_server(
                            worker=worker,
                            query_interval_seconds=PREFECT_WORKER_QUERY_SECONDS.value(),
                        )
                        healthcheck_thread = threading.Thread(
                            name="healthcheck-server-thread",
                            target=healthcheck_server.run,
                            daemon=True,
                        )
                        healthcheck_thread.start()
                    printer(f"Worker {worker.name!r} started!")

                # If running once, wait for active runs to finish before teardown
                if run_once and self._limiter:
                    # Use the limiter's borrowed token count as the source of truth
                    while self.limiter.borrowed_tokens > 0:
                        self._logger.debug(
                            "Waiting for %s active run(s) to finish before shutdown...",
                            self.limiter.borrowed_tokens,
                        )
                        await anyio.sleep(0.1)
        finally:
            stop_client_metrics_server()

            if healthcheck_server and healthcheck_thread:
                self._logger.debug("Stopping healthcheck server...")
                healthcheck_server.should_exit = True
                healthcheck_thread.join()
                self._logger.debug("Healthcheck server stopped.")

            printer(f"Worker {worker.name!r} stopped!")

    @abc.abstractmethod
    async def run(
        self,
        flow_run: "FlowRun",
        configuration: C,
        task_status: Optional[anyio.abc.TaskStatus[int]] = None,
    ) -> R:
        """
        Runs a given flow run on the current worker.
        """
        raise NotImplementedError(
            "Workers must implement a method for running submitted flow runs"
        )

    async def _initiate_run(
        self,
        flow_run: "FlowRun",
        configuration: C,
    ) -> None:
        """
        This method is called by the worker to initiate a flow run and should return as
        soon as possible.

        This method is used in `.submit` to allow non-blocking submission of flows. For
        workers that wait for completion in their `run` method, this method should be
        implemented to return immediately.

        If this method is not implemented, `.submit` will fall back to the `.run` method.
        """
        raise NotImplementedError(
            "This worker has not implemented `_initiate_run`. Please use `run` instead."
        )

    async def submit(
        self,
        flow: "Flow[..., FR]",
        parameters: dict[str, Any] | None = None,
        job_variables: dict[str, Any] | None = None,
    ) -> "PrefectFlowRunFuture[FR]":
        """
        EXPERIMENTAL: The interface for this method is subject to change.

        Submits a flow to run via the worker.

        Args:
            flow: The flow to submit
            parameters: The parameters to pass to the flow
            job_variables: Job variables to use when running the flow

        Returns:
            A future that can be used to wait for the flow run's completion and
            retrieve its result
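
        Example (a hedged sketch; assumes `my_flow` is a `@flow`-decorated
        function and storage is configured for the work pool):
        ```python
        async with MyWorker(work_pool_name="my-pool") as worker:
            future = await worker.submit(my_flow, parameters={"x": 1})
        ```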
        """
        warnings.warn(
            "Ad-hoc flow submission via workers is experimental. The interface "
            "and behavior of this feature are subject to change.",
            category=FutureWarning,
        )
        if self._runs_task_group is None:
            raise RuntimeError("Worker not properly initialized")

        flow_run = await self._runs_task_group.start(
            partial(
                self._submit_adhoc_run,
                flow=flow,
                parameters=parameters,
                job_variables=job_variables,
            ),
        )
        return PrefectFlowRunFuture(flow_run_id=flow_run.id)

    async def _submit_adhoc_run(
        self,
        flow: "Flow[..., FR]",
        parameters: dict[str, Any] | None = None,
        job_variables: dict[str, Any] | None = None,
        task_status: anyio.abc.TaskStatus["FlowRun"] | None = None,
    ):
        """
        Creates a flow run for the given flow and submits it for execution by the
        worker.
        """
        from prefect._experimental.bundles import (
            aupload_bundle_to_storage,
            convert_step_to_command,
            create_bundle_for_flow_run,
        )

        if (
            self.work_pool.storage_configuration.bundle_upload_step is None
            or self.work_pool.storage_configuration.bundle_execution_step is None
        ):
            raise RuntimeError(
                f"Storage is not configured for work pool {self.work_pool.name!r}. "
                "Please configure storage for the work pool by running `prefect "
                "work-pool storage configure`."
            )

        from prefect.results import aresolve_result_storage, get_result_store

        current_result_store = get_result_store()
        # Check result storage and use the work pool default if needed
        if (
            current_result_store.result_storage is None
            or isinstance(current_result_store.result_storage, LocalFileSystem)
            and flow.result_storage is None
        ):
            if (
                self.work_pool.storage_configuration.default_result_storage_block_id
                is None
            ):
                self._logger.warning(
                    f"Flow {flow.name!r} has no result storage configured. Please configure "
                    "result storage for the flow if you want to retrieve the result for the flow run."
                )
            else:
                # Use the work pool's default result storage block for the flow run to ensure the caller can retrieve the result
                flow = flow.with_options(
                    result_storage=await aresolve_result_storage(
                        self.work_pool.storage_configuration.default_result_storage_block_id
                    ),
                    persist_result=True,
                )

        bundle_key = str(uuid.uuid4())
        upload_command = convert_step_to_command(
            self.work_pool.storage_configuration.bundle_upload_step,
            bundle_key,
            quiet=True,
        )
        execute_command = convert_step_to_command(
            self.work_pool.storage_configuration.bundle_execution_step, bundle_key
        )

        job_variables = (job_variables or {}) | {"command": " ".join(execute_command)}
        parameters = parameters or {}

        # Create a parent task run if this is a child flow run to ensure it shows up as a child flow in the UI
        parent_task_run = None
        if flow_run_ctx := FlowRunContext.get():
            parent_task = Task[Any, Any](
                name=flow.name,
                fn=flow.fn,
                version=flow.version,
            )
            parent_task_run = await parent_task.create_run(
                flow_run_context=flow_run_ctx,
                parameters=parameters,
            )

        flow_run = await self.client.create_flow_run(
            flow,
            parameters=flow.serialize_parameters(parameters),
            state=Pending(),
            job_variables=job_variables,
            work_pool_name=self.work_pool.name,
            tags=TagsContext.get().current_tags,
            parent_task_run_id=getattr(parent_task_run, "id", None),
        )
        if task_status is not None:
            # Emit the flow run object to .submit to allow it to return a future as soon as possible
            task_status.started(flow_run)
        # Avoid an API call to get the flow
        api_flow = APIFlow(id=flow_run.flow_id, name=flow.name, labels={})
        logger = self.get_flow_run_logger(flow_run)

        configuration = await self.job_configuration.from_template_and_values(
            base_job_template=self.work_pool.base_job_template,
            values=job_variables,
            client=self._client,
        )
        configuration.prepare_for_flow_run(
            flow_run=flow_run,
            flow=api_flow,
            work_pool=self.work_pool,
            worker_name=self.name,
        )

        bundle = create_bundle_for_flow_run(flow=flow, flow_run=flow_run)
        await aupload_bundle_to_storage(bundle, bundle_key, upload_command)

        logger.debug("Successfully uploaded execution bundle")

        try:
            # Call the implementation-specific run method with the constructed configuration. This is where the
            # rubber meets the road.
            try:
                await self._initiate_run(flow_run, configuration)
            except NotImplementedError:
                result = await self.run(flow_run, configuration)

                if result.status_code != 0:
                    await self._propose_crashed_state(
                        flow_run,
                        (
                            "Flow run infrastructure exited with non-zero status code"
                            f" {result.status_code}."
                        ),
                    )
        except Exception as exc:
            # This flow run was being submitted and did not start successfully
            logger.exception(
                f"Failed to submit flow run '{flow_run.id}' to infrastructure."
            )
            message = f"Flow run could not be submitted to infrastructure:\n{exc!r}"
            await self._propose_crashed_state(flow_run, message, client=self.client)

    @classmethod
    def __dispatch_key__(cls) -> str | None:
        if cls.__name__ == "BaseWorker":
            return None  # The base class is abstract
        return cls.type

    async def setup(self) -> None:
        """Prepares the worker to run."""
        self._logger.debug("Setting up worker...")
        self._runs_task_group = anyio.create_task_group()
        self._limiter = (
            anyio.CapacityLimiter(self._limit) if self._limit is not None else None
        )

        if not PREFECT_TEST_MODE and not PREFECT_API_URL.value():
            raise ValueError("`PREFECT_API_URL` must be set to start a Worker.")

        self._client = get_client()

        await self._exit_stack.enter_async_context(self._client)
        await self._exit_stack.enter_async_context(self._runs_task_group)

        await self.sync_with_backend()

        self.is_setup = True

    async def teardown(self, *exc_info: Any) -> None:
        """Cleans up resources after the worker is stopped."""
        self._logger.debug("Tearing down worker...")
        self.is_setup: bool = False
        for scope in self._scheduled_task_scopes:
            scope.cancel()

        # Emit stopped event before closing client
        if self._started_event:
            try:
                await self._emit_worker_stopped_event(self._started_event)
            except Exception:
                self._logger.exception("Failed to emit worker stopped event")

        await self._exit_stack.__aexit__(*exc_info)
        self._runs_task_group = None
        self._client = None

    def is_worker_still_polling(self, query_interval_seconds: float) -> bool:
        """
        This method is invoked by a webserver healthcheck handler
        and returns a boolean indicating if the worker has recorded a
        scheduled flow run poll within a variable amount of time.

        The `query_interval_seconds` is the same value that is used by
        the loop services - we will evaluate if the _last_polled_time
        was within that interval x 30 (so 10s -> 5m)

        The instance property `self._last_polled_time`
        is currently set/updated in `get_and_submit_flow_runs()`
        """
        threshold_seconds = query_interval_seconds * 30

        seconds_since_last_poll = (
            prefect.types._datetime.now("UTC") - self._last_polled_time
        ).seconds

        is_still_polling = seconds_since_last_poll <= threshold_seconds

        if not is_still_polling:
            self._logger.error(
                f"Worker has not polled in the last {seconds_since_last_poll} seconds "
                "and should be restarted"
            )

        return is_still_polling

    async def get_and_submit_flow_runs(self) -> list["FlowRun"]:
        runs_response = await self._get_scheduled_flow_runs()

        self._last_polled_time = prefect.types._datetime.now("UTC")

        return await self._submit_scheduled_flow_runs(flow_run_response=runs_response)

    async def _update_local_work_pool_info(self) -> None:
        if TYPE_CHECKING:
            assert self._client is not None
        try:
            work_pool = await self._client.read_work_pool(
                work_pool_name=self._work_pool_name
            )

        except ObjectNotFound:
            if self._create_pool_if_not_found:
                wp = WorkPoolCreate(
                    name=self._work_pool_name,
                    type=self.type,
                )
                if self._base_job_template is not None:
                    wp.base_job_template = self._base_job_template

                work_pool = await self._client.create_work_pool(work_pool=wp)
                self._logger.info(f"Work pool {self._work_pool_name!r} created.")
            else:
                self._logger.warning(f"Work pool {self._work_pool_name!r} not found!")
                if self._base_job_template is not None:
                    self._logger.warning(
                        "Ignoring supplied base job template because the work pool"
                        " already exists"
                    )
                return

        # if the remote config type changes (or if it's being loaded for the
        # first time), check if it matches the local type and warn if not
        if getattr(self._work_pool, "type", 0) != work_pool.type:
            if work_pool.type != self.__class__.type:
                self._logger.warning(
                    "Worker type mismatch! This worker process expects type "
                    f"{self.type!r} but received {work_pool.type!r}"
                    " from the server. Unexpected behavior may occur."
                )

        # once the work pool is loaded, verify that it has a `base_job_template` and
        # set it if not
        if not work_pool.base_job_template:
            job_template = self.__class__.get_default_base_job_template()
            await self._set_work_pool_template(work_pool, job_template)
            work_pool.base_job_template = job_template

        self._work_pool = work_pool

    async def _worker_metadata(self) -> Optional[WorkerMetadata]:
        """
        Returns metadata about installed Prefect collections for the worker.
        """
        installed_integrations = load_prefect_collections().keys()

        integration_versions = [
            Integration(name=dist.metadata["Name"], version=dist.version)  # pyright: ignore[reportOptionalSubscript]
            for dist in distributions()
            # PyPI packages often use dashes, but Python package names use underscores
            # because they must be valid identifiers.
            if dist.metadata  # pyright: ignore[reportOptionalMemberAccess]
            and (name := dist.metadata.get("Name"))
            and (name.replace("-", "_") in installed_integrations)
        ]

        if integration_versions:
            return WorkerMetadata(integrations=integration_versions)
        return None

    async def _send_worker_heartbeat(self) -> Optional[UUID]:
        """
        Sends a heartbeat to the API.
        """
        if not self._client:
            self._logger.warning("Client has not been initialized; skipping heartbeat.")
            return None
        if not self._work_pool:
            self._logger.debug("Worker has no work pool; skipping heartbeat.")
            return None

        should_get_worker_id = self._should_get_worker_id()

        params: dict[str, Any] = {
            "work_pool_name": self._work_pool_name,
            "worker_name": self.name,
            "heartbeat_interval_seconds": self.heartbeat_interval_seconds,
            "get_worker_id": should_get_worker_id,
        }
        if (
            self._client.server_type == ServerType.CLOUD
            and not self._worker_metadata_sent
        ):
            worker_metadata = await self._worker_metadata()
            if worker_metadata:
                params["worker_metadata"] = worker_metadata
                self._worker_metadata_sent = True

        worker_id = None
        try:
            worker_id = await self._client.send_worker_heartbeat(**params)
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 422 and should_get_worker_id:
                self._logger.warning(
                    "Failed to retrieve worker ID from the Prefect API server."
                )
                params["get_worker_id"] = False
                worker_id = await self._client.send_worker_heartbeat(**params)
            else:
                raise e

        if should_get_worker_id and worker_id is None:
            self._logger.warning(
                "Failed to retrieve worker ID from the Prefect API server."
            )

        return worker_id

    async def sync_with_backend(self) -> None:
        """
        Updates the worker's local information about its current work pool and
        queues. Sends a worker heartbeat to the API.
        """
        await self._update_local_work_pool_info()

        remote_id = await self._send_worker_heartbeat()
        if remote_id:
            self.backend_id = remote_id
            self._logger = get_worker_logger(self)

        self._logger.debug(
            "Worker synchronized with the Prefect API server. "
            + (f"Remote ID: {self.backend_id}" if self.backend_id else "")
        )

    def _should_get_worker_id(self):
        """Determines if the worker should request an ID from the API server."""
        return (
            self._client
            and self._client.server_type == ServerType.CLOUD
            and self.backend_id is None
        )

    async def _get_scheduled_flow_runs(
        self,
    ) -> list["WorkerFlowRunResponse"]:
        """
        Retrieve scheduled flow runs from the work pool's queues.
        """
        scheduled_before = prefect.types._datetime.now("UTC") + datetime.timedelta(
            seconds=int(self._prefetch_seconds)
        )
        self._logger.debug(
            f"Querying for flow runs scheduled before {scheduled_before}"
        )
        try:
            scheduled_flow_runs = (
                await self.client.get_scheduled_flow_runs_for_work_pool(
                    work_pool_name=self._work_pool_name,
                    scheduled_before=scheduled_before,
                    work_queue_names=list(self._work_queues),
                )
            )
            self._logger.debug(
                f"Discovered {len(scheduled_flow_runs)} scheduled_flow_runs"
            )
            return scheduled_flow_runs
        except ObjectNotFound:
            # the pool doesn't exist; it will be created on the next
            # heartbeat (or an appropriate warning will be logged)
            return []

    async def _submit_scheduled_flow_runs(
        self, flow_run_response: list["WorkerFlowRunResponse"]
    ) -> list["FlowRun"]:
        """
        Takes a list of WorkerFlowRunResponses and submits the referenced flow runs
        for execution by the worker.
        """
        submittable_flow_runs = [entry.flow_run for entry in flow_run_response]

        for flow_run in submittable_flow_runs:
            if flow_run.id in self._submitting_flow_run_ids:
                self._logger.debug(
                    f"Skipping {flow_run.id} because it's already being submitted"
                )
                continue
            try:
                if self._limiter:
                    self._limiter.acquire_on_behalf_of_nowait(flow_run.id)
            except anyio.WouldBlock:
                self._logger.debug(
                    f"Flow run limit reached; {self.limiter.borrowed_tokens} flow runs"
                    " in progress."
                )
                break
            else:
                run_logger = self.get_flow_run_logger(flow_run)
                run_logger.info(
                    f"Worker '{self.name}' submitting flow run '{flow_run.id}'"
                )
                if self.backend_id:
                    try:
                        worker_url = url_for(
                            "worker",
                            obj_id=self.backend_id,
                            work_pool_name=self._work_pool_name,
                        )

                        run_logger.info(
                            f"Running on worker id: {self.backend_id}. See worker logs here: {worker_url}"
                        )
                    except ValueError as ve:
                        run_logger.warning(f"Failed to generate worker URL: {ve}")

                self._submitting_flow_run_ids.add(flow_run.id)
                if TYPE_CHECKING:
                    assert self._runs_task_group is not None
                self._runs_task_group.start_soon(
                    self._submit_run,
                    flow_run,
                )

        return list(
            filter(
                lambda run: run.id in self._submitting_flow_run_ids,
                submittable_flow_runs,
            )
        )

    async def _submit_run(self, flow_run: "FlowRun") -> None:
        """
        Submits a given flow run for execution by the worker.
        """
        run_logger = self.get_flow_run_logger(flow_run)

        if flow_run.deployment_id:
            try:
                await self.client.read_deployment(flow_run.deployment_id)
            except ObjectNotFound:
                self._logger.exception(
                    f"Deployment {flow_run.deployment_id} no longer exists. "
                    f"Flow run {flow_run.id} will not be submitted for"
                    " execution"
                )
                self._submitting_flow_run_ids.remove(flow_run.id)
                await self._mark_flow_run_as_cancelled(
                    flow_run,
                    state_updates=dict(
                        message=f"Deployment {flow_run.deployment_id} no longer exists, cancelled run."
                    ),
                )
                return

        ready_to_submit = await self._propose_pending_state(flow_run)
        self._logger.debug(f"Ready to submit {flow_run.id}: {ready_to_submit}")
        if ready_to_submit:
            if TYPE_CHECKING:
                assert self._runs_task_group is not None
            readiness_result = await self._runs_task_group.start(
                self._submit_run_and_capture_errors, flow_run
            )

            if readiness_result and not isinstance(readiness_result, Exception):
                try:
                    await self.client.update_flow_run(
                        flow_run_id=flow_run.id,
                        infrastructure_pid=str(readiness_result),
                    )
                except Exception:
                    run_logger.exception(
                        "An error occurred while setting the `infrastructure_pid` on "
                        f"flow run {flow_run.id!r}. The flow run will "
                        "not be cancellable."
                    )

                run_logger.info(f"Completed submission of flow run '{flow_run.id}'")

            else:
                # If the run is not ready to submit, release the concurrency slot
                self._release_limit_slot(flow_run.id)
        else:
            self._release_limit_slot(flow_run.id)
        self._submitting_flow_run_ids.remove(flow_run.id)

    async def _submit_run_and_capture_errors(
        self,
        flow_run: "FlowRun",
        task_status: anyio.abc.TaskStatus[int | Exception] | None = None,
    ) -> BaseWorkerResult | Exception:
        run_logger = self.get_flow_run_logger(flow_run)

        try:
            configuration = await self._get_configuration(flow_run)
            submitted_event = self._emit_flow_run_submitted_event(configuration)
            await self._give_worker_labels_to_flow_run(flow_run.id)

            result = await self.run(
                flow_run=flow_run,
                task_status=task_status,
                configuration=configuration,
            )
        except Exception as exc:
            if task_status and not getattr(task_status, "_future").done():
                # This flow run was being submitted and did not start successfully
                run_logger.exception(
                    f"Failed to submit flow run '{flow_run.id}' to infrastructure."
                )
                # Mark the task as started to prevent agent crash
                task_status.started(exc)
                message = f"Flow run could not be submitted to infrastructure:\n{exc!r}"
                await self._propose_crashed_state(flow_run, message)
            else:
                run_logger.exception(
                    f"An error occurred while monitoring flow run '{flow_run.id}'. "
                    "The flow run will not be marked as failed, but an issue may have "
                    "occurred."
                )
            return exc
        finally:
            self._release_limit_slot(flow_run.id)

        if task_status and not getattr(task_status, "_future").done():
            run_logger.error(
                f"Infrastructure returned without reporting flow run '{flow_run.id}' "
                "as started or raising an error. This behavior is not expected and "
                "generally indicates improper implementation of infrastructure. The "
                "flow run will not be marked as failed, but an issue may have occurred."
            )
            # Mark the task as started to prevent agent crash
            task_status.started(
                RuntimeError(
                    "Infrastructure returned without reporting flow run as started or raising an error."
                )
            )

        if result.status_code != 0:
            await self._propose_crashed_state(
                flow_run,
                (
                    "Flow run infrastructure exited with non-zero status code"
                    f" {result.status_code}."
                ),
            )

        if submitted_event:
            self._emit_flow_run_executed_event(result, configuration, submitted_event)

        return result

    def _release_limit_slot(self, flow_run_id: UUID) -> None:
        """
        Frees up a slot taken by the given flow run id.

        This method gracefully handles cases where the slot has already been released
        to prevent worker crashes from double-release scenarios.
        """
        if self._limiter:
            try:
                self._limiter.release_on_behalf_of(flow_run_id)
                self._logger.debug("Limit slot released for flow run '%s'", flow_run_id)
            except RuntimeError:
                # Slot was already released - this can happen in certain error paths
                # where multiple cleanup attempts occur. Log it but don't crash.
                self._logger.debug(
                    "Limit slot for flow run '%s' was already released", flow_run_id
                )

    def get_status(self) -> dict[str, Any]:
        """
        Retrieves the status of the current worker including its name, current worker
        pool, the work pool queues it is polling, and its local settings.
        """
        return {
            "name": self.name,
            "work_pool": (
                self._work_pool.model_dump(mode="json")
                if self._work_pool is not None
                else None
            ),
            "settings": {
                "prefetch_seconds": self._prefetch_seconds,
            },
        }

    async def _get_configuration(
        self,
        flow_run: "FlowRun",
        deployment: Optional["DeploymentResponse"] = None,
    ) -> C:
        if not deployment and flow_run.deployment_id:
            deployment = await self.client.read_deployment(flow_run.deployment_id)

        flow = await self.client.read_flow(flow_run.flow_id)

        deployment_vars = getattr(deployment, "job_variables", {}) or {}
        flow_run_vars = flow_run.job_variables or {}
        job_variables = {**deployment_vars}

        # merge environment variables carefully, otherwise full override
        if isinstance(job_variables.get("env"), dict):
            job_variables["env"].update(flow_run_vars.pop("env", {}))
        job_variables.update(flow_run_vars)

        configuration = await self.job_configuration.from_template_and_values(
            base_job_template=self.work_pool.base_job_template,
            values=job_variables,
            client=self.client,
        )
        try:
            configuration.prepare_for_flow_run(
                flow_run=flow_run,
                deployment=deployment,
                flow=flow,
                work_pool=self.work_pool,
                worker_name=self.name,
            )
        except TypeError:
            warnings.warn(
                "This worker is missing the `work_pool` and `worker_name` arguments "
                "in its JobConfiguration.prepare_for_flow_run method. Please update "
                "the worker's JobConfiguration class to accept these arguments to "
                "avoid this warning.",
                category=PrefectDeprecationWarning,
            )
            # Handle older subclasses that don't accept work_pool and worker_name
            configuration.prepare_for_flow_run(
                flow_run=flow_run, deployment=deployment, flow=flow
            )
        return configuration

    async def _propose_pending_state(self, flow_run: "FlowRun") -> bool:
        run_logger = self.get_flow_run_logger(flow_run)
        state = flow_run.state
        try:
            state = await propose_state(self.client, Pending(), flow_run_id=flow_run.id)
        except Abort as exc:
            run_logger.info(
                (
                    f"Aborted submission of flow run '{flow_run.id}'. "
                    f"Server sent an abort signal: {exc}"
                ),
            )

            return False
        except Exception:
            run_logger.exception(
                f"Failed to update state of flow run '{flow_run.id}'",
            )
            return False

        if not state.is_pending():
            run_logger.info(
                (
                    f"Aborted submission of flow run '{flow_run.id}': "
                    f"Server returned a non-pending state {state.type.value!r}"
                ),
            )
            return False

        return True

    async def _propose_failed_state(self, flow_run: "FlowRun", exc: Exception) -> None:
        run_logger = self.get_flow_run_logger(flow_run)
        try:
            await propose_state(
                self.client,
                await exception_to_failed_state(message="Submission failed.", exc=exc),
                flow_run_id=flow_run.id,
            )
        except Abort:
            # We've already failed, no need to note the abort but we don't want it to
            # raise in the agent process
            pass
        except Exception:
            run_logger.error(
                f"Failed to update state of flow run '{flow_run.id}'",
                exc_info=True,
            )

    async def _propose_crashed_state(
        self, flow_run: "FlowRun", message: str, client: PrefectClient | None = None
    ) -> None:
        run_logger = self.get_flow_run_logger(flow_run)
        try:
            state = await propose_state(
                client or self.client,
                Crashed(message=message),
                flow_run_id=flow_run.id,
            )
        except Abort:
            # Flow run already marked as failed
            pass
        except ObjectNotFound:
            # Flow run was deleted - log it but don't crash the worker
            run_logger.debug(
                f"Flow run '{flow_run.id}' was deleted before state could be updated"
            )
        except Exception:
            run_logger.exception(f"Failed to update state of flow run '{flow_run.id}'")
        else:
            if state.is_crashed():
                run_logger.info(
                    f"Reported flow run '{flow_run.id}' as crashed: {message}"
                )

    async def _mark_flow_run_as_cancelled(
        self, flow_run: "FlowRun", state_updates: dict[str, Any] | None = None
    ) -> None:
        state_updates = state_updates or {}
        state_updates.setdefault("name", "Cancelled")

        if flow_run.state:
            state_updates.setdefault("type", StateType.CANCELLED)
            state = flow_run.state.model_copy(update=state_updates)
        else:
            # In the unexpected case that the flow run has no state, create a new
            # Cancelled state; it does not need the type set explicitly
            state = Cancelled(**state_updates)
        try:
            await self.client.set_flow_run_state(flow_run.id, state, force=True)
        except ObjectNotFound:
            # Flow run was deleted - log it but don't crash the worker
            run_logger = self.get_flow_run_logger(flow_run)
            run_logger.debug(
                f"Flow run '{flow_run.id}' was deleted before it could be marked as cancelled"
            )
        # Do not remove the flow run from the cancelling set immediately because
        # the API caches responses for `read_flow_runs` and we do not want to
        # issue duplicate cancellations.
        await self._schedule_task(
            60 * 10, self._cancelling_flow_run_ids.remove, flow_run.id
        )

    async def _set_work_pool_template(
        self, work_pool: "WorkPool", job_template: dict[str, Any]
    ):
        """Updates the `base_job_template` for the worker's work pool server side."""

        await self.client.update_work_pool(
            work_pool_name=work_pool.name,
            work_pool=WorkPoolUpdate(
                base_job_template=job_template,
            ),
        )

    async def _schedule_task(
        self, __in_seconds: int, fn: Callable[..., Any], *args: Any, **kwargs: Any
    ):
        """
        Schedule a background task to start after some time.

        These tasks will be run immediately when the worker exits instead of waiting.

        The function may be async or sync. Async functions will be awaited.
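
        Example (a hedged sketch; the logging callback shown is arbitrary):
        ```python
        await self._schedule_task(60, self._logger.info, "one minute elapsed")
        ```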
        """
        if not self._runs_task_group:
            raise RuntimeError(
                "Worker has not been correctly initialized. Please use the worker class as an async context manager."
            )

        async def wrapper(task_status: anyio.abc.TaskStatus[Any]):
            # If we are shutting down, do not sleep; otherwise sleep until the scheduled
            # time or shutdown
            if self.is_setup:
                with anyio.CancelScope() as scope:
                    self._scheduled_task_scopes.add(scope)
                    task_status.started()
                    await anyio.sleep(__in_seconds)

                self._scheduled_task_scopes.remove(scope)
            else:
                task_status.started()

            result = fn(*args, **kwargs)
            if asyncio.iscoroutine(result):
                await result

        await self._runs_task_group.start(wrapper)

    async def _give_worker_labels_to_flow_run(self, flow_run_id: UUID):
        """
        Give this worker's identifying labels to the specified flow run.
        """
        if self._client:
            labels: KeyValueLabels = {
                "prefect.worker.name": self.name,
                "prefect.worker.type": self.type,
            }

            if self._work_pool:
                labels.update(
                    {
                        "prefect.work-pool.name": self._work_pool.name,
                        "prefect.work-pool.id": str(self._work_pool.id),
                    }
                )

            await self._client.update_flow_run_labels(flow_run_id, labels)

    async def __aenter__(self) -> Self:
        self._logger.debug("Entering worker context...")
        await self.setup()

        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        try:
            self._logger.debug("Exiting worker context...")
            await self.teardown(*exc_info)
        except (ExceptionGroup, BaseExceptionGroup) as exc:
            # For less verbose tracebacks
            exceptions = exc.exceptions
            if len(exceptions) == 1:
                raise exceptions[0] from None
            else:
                raise

    def __repr__(self) -> str:
        return f"Worker(pool={self._work_pool_name!r}, name={self.name!r})"

    def _event_resource(self):
        return {
            "prefect.resource.id": f"prefect.worker.{self.type}.{self.get_name_slug()}",
            "prefect.resource.name": self.name,
            "prefect.version": prefect.__version__,
            "prefect.worker-type": self.type,
        }

    def _event_related_resources(
        self,
        configuration: BaseJobConfiguration | None = None,
        include_self: bool = False,
    ) -> list[RelatedResource]:
        related: list[RelatedResource] = []
        if configuration:
            related += getattr(configuration, "_related_resources")()

        if self._work_pool:
            related.append(
                object_as_related_resource(
                    kind="work-pool", role="work-pool", object=self._work_pool
                )
            )

        if include_self:
            worker_resource = self._event_resource()
            worker_resource["prefect.resource.role"] = "worker"
            related.append(RelatedResource.model_validate(worker_resource))

        return related

    def _emit_flow_run_submitted_event(
        self, configuration: BaseJobConfiguration
    ) -> Event | None:
        return emit_event(
            event="prefect.worker.submitted-flow-run",
            resource=self._event_resource(),
            related=self._event_related_resources(configuration=configuration),
        )

    def _emit_flow_run_executed_event(
        self,
        result: BaseWorkerResult,
        configuration: BaseJobConfiguration,
        submitted_event: Event | None = None,
    ):
        related = self._event_related_resources(configuration=configuration)

        for resource in related:
            if resource.role == "flow-run":
                resource["prefect.infrastructure.identifier"] = str(result.identifier)
                resource["prefect.infrastructure.status-code"] = str(result.status_code)

        emit_event(
            event="prefect.worker.executed-flow-run",
            resource=self._event_resource(),
            related=related,
            follows=submitted_event,
        )

    async def _emit_worker_started_event(self) -> Event | None:
        return emit_event(
            "prefect.worker.started",
            resource=self._event_resource(),
            related=self._event_related_resources(),
        )

    async def _emit_worker_stopped_event(self, started_event: Event):
        emit_event(
            "prefect.worker.stopped",
            resource=self._event_resource(),
            related=self._event_related_resources(),
            follows=started_event,
        )