Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/task_workers.py: 46%
55 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1import datetime 1a
2import time 1a
3from collections import defaultdict 1a
4from typing import Dict, List, Set 1a
6from pydantic import BaseModel 1a
7from typing_extensions import TypeAlias 1a
9from prefect.types import DateTime 1a
10from prefect.types._datetime import now 1a
# Readability aliases for signatures below; both are plain ``str`` at runtime.
TaskKey: TypeAlias = str
WorkerId: TypeAlias = str
class TaskWorkerResponse(BaseModel):
    # Snapshot of one tracked worker, as built by
    # InMemoryTaskWorkerTracker._create_worker_response.
    identifier: WorkerId  # the worker's id
    task_keys: List[TaskKey]  # task keys the worker has reported so far
    # Wall-clock time of the worker's most recent observation (derived by
    # subtracting monotonic elapsed time from the current UTC time).
    timestamp: DateTime
class InMemoryTaskWorkerTracker:
    """Track task workers and the task keys they serve, in process memory.

    Three indexes are kept in sync:

    - ``workers``: worker id -> set of task keys that worker has reported.
    - ``task_keys``: task key -> set of worker ids that reported it
      (reverse index; keys with no workers are removed).
    - ``worker_timestamps``: worker id -> ``time.monotonic()`` value at the
      worker's most recent observation.
    """

    def __init__(self) -> None:
        self.workers: Dict[WorkerId, Set[TaskKey]] = {}
        # defaultdict so observe_worker can add to a key's set without a
        # membership check. Read paths must use .get() so that looking up an
        # unknown key does not insert an empty entry.
        self.task_keys: Dict[TaskKey, Set[WorkerId]] = defaultdict(set)
        self.worker_timestamps: Dict[WorkerId, float] = {}

    async def observe_worker(
        self,
        task_keys: List[TaskKey],
        worker_id: WorkerId,
    ) -> None:
        """Record that ``worker_id`` is alive and serves ``task_keys``.

        The keys are merged (set union) with any previously reported for this
        worker; earlier keys are not removed. The worker's timestamp is
        refreshed to the current monotonic time.
        """
        self.workers[worker_id] = self.workers.get(worker_id, set()) | set(task_keys)
        self.worker_timestamps[worker_id] = time.monotonic()

        for task_key in task_keys:
            self.task_keys[task_key].add(worker_id)

    async def forget_worker(
        self,
        worker_id: WorkerId,
    ) -> None:
        """Remove ``worker_id`` from every index; unknown ids are ignored.

        Task keys left with no remaining workers are dropped from the reverse
        index so ``task_keys`` never holds empty sets.
        """
        if worker_id in self.workers:
            task_keys = self.workers.pop(worker_id)
            for task_key in task_keys:
                self.task_keys[task_key].discard(worker_id)
                if not self.task_keys[task_key]:
                    del self.task_keys[task_key]
        self.worker_timestamps.pop(worker_id, None)

    async def get_workers_for_task_keys(
        self,
        task_keys: List[TaskKey],
    ) -> List[TaskWorkerResponse]:
        """Return a response for every worker serving any of ``task_keys``.

        An empty ``task_keys`` list returns all known workers.
        """
        if not task_keys:
            return await self.get_all_workers()
        # Use .get() instead of indexing: indexing the defaultdict would
        # insert an empty set for each unknown key queried. That grows
        # ``task_keys`` without bound and leaves entries that forget_worker's
        # cleanup never removes.
        active_workers = set().union(
            *(self.task_keys.get(key, set()) for key in task_keys)
        )
        return [self._create_worker_response(worker_id) for worker_id in active_workers]

    async def get_all_workers(self) -> List[TaskWorkerResponse]:
        """Return a response for every worker with a recorded timestamp."""
        return [
            self._create_worker_response(worker_id)
            for worker_id in self.worker_timestamps.keys()
        ]

    def _create_worker_response(self, worker_id: WorkerId) -> TaskWorkerResponse:
        """Build the response object for ``worker_id``.

        Time since the last observation is measured on the monotonic clock,
        then subtracted from the current UTC time to give a wall-clock
        ``timestamp``.

        Raises:
            KeyError: if ``worker_id`` has no recorded timestamp.
        """
        elapsed = time.monotonic() - self.worker_timestamps[worker_id]
        return TaskWorkerResponse(
            identifier=worker_id,
            task_keys=list(self.workers.get(worker_id, set())),
            timestamp=now("UTC") - datetime.timedelta(seconds=elapsed),
        )

    def reset(self) -> None:
        """Testing utility to reset the state of the task worker tracker"""
        self.workers.clear()
        self.task_keys.clear()
        self.worker_timestamps.clear()
# Global instance of the task worker tracker.
# Module-level singleton that the wrapper functions below delegate to; its
# state lives in plain dicts inside this process.
task_worker_tracker: InMemoryTaskWorkerTracker = InMemoryTaskWorkerTracker()
# Main utilities to be used in the API layer
async def observe_worker(
    task_keys: List[TaskKey],
    worker_id: WorkerId,
) -> None:
    """Record on the module-level tracker that ``worker_id`` serves ``task_keys``."""
    tracker = task_worker_tracker
    await tracker.observe_worker(task_keys=task_keys, worker_id=worker_id)
async def forget_worker(
    worker_id: WorkerId,
) -> None:
    """Remove ``worker_id`` from the module-level tracker."""
    tracker = task_worker_tracker
    await tracker.forget_worker(worker_id=worker_id)
async def get_workers_for_task_keys(
    task_keys: List[TaskKey],
) -> List[TaskWorkerResponse]:
    """Return workers on the module-level tracker that serve any of ``task_keys``.

    With the in-memory tracker, an empty ``task_keys`` list yields all workers.
    """
    tracker = task_worker_tracker
    matching = await tracker.get_workers_for_task_keys(task_keys=task_keys)
    return matching
async def get_all_workers() -> List[TaskWorkerResponse]:
    """Return a response entry for every worker known to the module-level tracker."""
    tracker = task_worker_tracker
    everyone = await tracker.get_all_workers()
    return everyone