Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/task_workers.py: 46%

55 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1import datetime 1a

2import time 1a

3from collections import defaultdict 1a

4from typing import Dict, List, Set 1a

5 

6from pydantic import BaseModel 1a

7from typing_extensions import TypeAlias 1a

8 

9from prefect.types import DateTime 1a

10from prefect.types._datetime import now 1a

11 

12TaskKey: TypeAlias = str 1a

13WorkerId: TypeAlias = str 1a

14 

15 

class TaskWorkerResponse(BaseModel):
    """API representation of a single task worker known to the tracker."""

    # Identifier the worker was observed under.
    identifier: WorkerId
    # Task keys the worker has reported.
    task_keys: List[TaskKey]
    # Point in time at which the worker was last observed.
    timestamp: DateTime

20 

21 

class InMemoryTaskWorkerTracker:
    """Track which task workers are alive and which task keys each one serves.

    State lives in process memory in three structures:

    * ``workers`` maps a worker id to the set of task keys it has reported.
    * ``task_keys`` is the reverse index, mapping a task key to the ids of the
      workers that reported it. Entries are removed once their set is empty.
    * ``worker_timestamps`` records, per worker, the ``time.monotonic()``
      reading taken at its most recent observation.
    """

    def __init__(self) -> None:
        self.workers: Dict[WorkerId, Set[TaskKey]] = {}
        self.task_keys: Dict[TaskKey, Set[WorkerId]] = defaultdict(set)
        self.worker_timestamps: Dict[WorkerId, float] = {}

    async def observe_worker(
        self,
        task_keys: List[TaskKey],
        worker_id: WorkerId,
    ) -> None:
        """Record that ``worker_id`` is alive and serving ``task_keys``.

        Task keys accumulate across calls: keys from earlier observations are
        kept and the new ones are merged in. The worker's last-seen time is
        refreshed on every call.
        """
        self.workers[worker_id] = self.workers.get(worker_id, set()) | set(task_keys)
        self.worker_timestamps[worker_id] = time.monotonic()

        for task_key in task_keys:
            self.task_keys[task_key].add(worker_id)

    async def forget_worker(
        self,
        worker_id: WorkerId,
    ) -> None:
        """Remove ``worker_id`` and all of its task-key associations.

        Reverse-index entries left empty by the removal are deleted. Forgetting
        an unknown worker is a no-op.
        """
        if worker_id in self.workers:
            task_keys = self.workers.pop(worker_id)
            for task_key in task_keys:
                self.task_keys[task_key].discard(worker_id)
                if not self.task_keys[task_key]:
                    del self.task_keys[task_key]
        self.worker_timestamps.pop(worker_id, None)

    async def get_workers_for_task_keys(
        self,
        task_keys: List[TaskKey],
    ) -> List[TaskWorkerResponse]:
        """Return responses for workers serving any of ``task_keys``.

        An empty ``task_keys`` list means "no filter" and returns every
        tracked worker.
        """
        if not task_keys:
            return await self.get_all_workers()
        # Use ``.get`` instead of indexing. ``self.task_keys`` is a defaultdict,
        # so ``self.task_keys[key]`` would insert an empty set for every unknown
        # key that is queried, and nothing would ever remove it. The index would
        # grow without bound on a read-only path.
        active_workers: Set[WorkerId] = set().union(
            *(self.task_keys.get(key, set()) for key in task_keys)
        )
        return [self._create_worker_response(worker_id) for worker_id in active_workers]

    async def get_all_workers(self) -> List[TaskWorkerResponse]:
        """Return a response for every worker that has a recorded timestamp."""
        return [
            self._create_worker_response(worker_id)
            for worker_id in self.worker_timestamps
        ]

    def _create_worker_response(self, worker_id: WorkerId) -> TaskWorkerResponse:
        """Build the API response for one tracked worker.

        The last-seen time is stored as a monotonic clock reading. It is
        converted to wall-clock time by subtracting the elapsed monotonic
        seconds from the current UTC time.

        Raises:
            KeyError: if ``worker_id`` has no recorded timestamp.
        """
        seconds_since_seen = time.monotonic() - self.worker_timestamps[worker_id]
        return TaskWorkerResponse(
            identifier=worker_id,
            task_keys=list(self.workers.get(worker_id, set())),
            timestamp=now("UTC") - datetime.timedelta(seconds=seconds_since_seen),
        )

    def reset(self) -> None:
        """Testing utility to reset the state of the task worker tracker"""
        self.workers.clear()
        self.task_keys.clear()
        self.worker_timestamps.clear()

79 

80 

# Single shared tracker; the module-level functions below delegate to it.
task_worker_tracker: InMemoryTaskWorkerTracker = InMemoryTaskWorkerTracker()

83 

84 

# Main utilities to be used in the API layer. Each one delegates to the shared
# ``task_worker_tracker`` instance.
async def observe_worker(
    task_keys: List[TaskKey],
    worker_id: WorkerId,
) -> None:
    """Record on the shared tracker that ``worker_id`` serves ``task_keys``."""
    tracker = task_worker_tracker
    await tracker.observe_worker(task_keys=task_keys, worker_id=worker_id)

91 

92 

async def forget_worker(
    worker_id: WorkerId,
) -> None:
    """Remove ``worker_id`` and its task-key associations from the shared tracker."""
    tracker = task_worker_tracker
    await tracker.forget_worker(worker_id=worker_id)

97 

98 

async def get_workers_for_task_keys(
    task_keys: List[TaskKey],
) -> List[TaskWorkerResponse]:
    """Return shared-tracker workers serving any of ``task_keys``.

    An empty ``task_keys`` list returns every tracked worker.
    """
    responses = await task_worker_tracker.get_workers_for_task_keys(task_keys=task_keys)
    return responses

103 

104 

async def get_all_workers() -> List[TaskWorkerResponse]:
    """Return a response for every worker currently held by the shared tracker."""
    responses = await task_worker_tracker.get_all_workers()
    return responses