Coverage for /usr/local/lib/python3.12/site-packages/prefect/cache_policies.py: 42%
159 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1import inspect 1a
2from copy import deepcopy 1a
3from dataclasses import dataclass, field 1a
4from pathlib import Path 1a
5from typing import ( 1a
6 TYPE_CHECKING,
7 Any,
8 Callable,
9 Dict,
10 Literal,
11 Optional,
12 Union,
13)
15from typing_extensions import Self 1a
17from prefect.context import TaskRunContext 1a
18from prefect.exceptions import HashError 1a
19from prefect.utilities.hashing import hash_objects 1a
21if TYPE_CHECKING: 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true1a
22 from prefect.filesystems import WritableFileSystem
23 from prefect.locking.protocol import LockManager
24 from prefect.transactions import IsolationLevel
# Registry mapping a type to a function that converts instances of that type
# into a form that serializes deterministically. Entries are looked up by the
# exact type of a value (`type(value)`), not by `isinstance`.
STABLE_TRANSFORMS: dict[type, Callable[[Any], Any]] = {}


def _register_stable_transforms() -> None:
    """
    Some inputs do not reliably produce deterministic byte strings when serialized via
    `cloudpickle`. This utility registers stabilizing transformations of such types
    so that cache keys that utilize them are deterministic across invocations.

    Currently registers a transform for `pandas.DataFrame` when pandas is
    importable; otherwise it does nothing.
    """
    try:
        import pandas as pd  # pyright: ignore
    except ImportError:
        # pandas is an optional dependency. `ModuleNotFoundError` is a subclass
        # of `ImportError`, so catching `ImportError` covers both.
        return

    def _stable_dataframe(df: Any) -> list[Any]:
        """Return the frame's columns as a list ordered by sorted column label."""
        labels = list(df.columns)
        try:
            ordered = sorted(labels)
        except TypeError:
            # Mixed-type labels (e.g. ints and strings) cannot be compared
            # directly; fall back to ordering by `repr` so the transform stays
            # deterministic instead of raising.
            ordered = sorted(labels, key=repr)
        return [df[col] for col in ordered]

    STABLE_TRANSFORMS[pd.DataFrame] = _stable_dataframe  # pyright: ignore
@dataclass
class CachePolicy:
    """
    Common base for every cache policy.

    A policy carries three optional settings (key storage, isolation level and
    lock manager) and is expected to override `compute_key`. Policies are
    combined with `+`; subtracting a string with `-` is a no-op on this base
    class.
    """

    key_storage: Union["WritableFileSystem", str, Path, None] = None
    isolation_level: Union[
        Literal["READ_COMMITTED", "SERIALIZABLE"],
        "IsolationLevel",
        None,
    ] = None
    lock_manager: Optional["LockManager"] = None

    @classmethod
    def from_cache_key_fn(
        cls, cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
    ) -> "CacheKeyFnPolicy":
        """
        Build a `CacheKeyFnPolicy` that wraps the given cache key function.
        """
        policy = CacheKeyFnPolicy(cache_key_fn=cache_key_fn)
        return policy

    def configure(
        self,
        key_storage: Union["WritableFileSystem", str, Path, None] = None,
        lock_manager: Optional["LockManager"] = None,
        isolation_level: Union[
            Literal["READ_COMMITTED", "SERIALIZABLE"], "IsolationLevel", None
        ] = None,
    ) -> Self:
        """
        Return a deep copy of this policy with the supplied settings applied.

        Args:
            key_storage: The storage to use for cache keys. `None` keeps the
                current key storage.
            lock_manager: The lock manager to use. `None` keeps the current
                lock manager.
            isolation_level: The isolation level to use. `None` keeps the
                current isolation level.

        Returns:
            A new cache policy; the policy this is called on is not modified.
        """
        configured = deepcopy(self)
        overrides = {
            "key_storage": key_storage,
            "lock_manager": lock_manager,
            "isolation_level": isolation_level,
        }
        for attr_name, value in overrides.items():
            # `None` means "leave the copied value alone".
            if value is not None:
                setattr(configured, attr_name, value)
        return configured

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """Compute the cache key; the base class always raises `NotImplementedError`."""
        raise NotImplementedError

    def __sub__(self, other: str) -> "CachePolicy":
        """No-op for all policies except Inputs and Compound."""
        # for interface compatibility
        if isinstance(other, str):  # type: ignore[reportUnnecessaryIsInstance]
            return self
        raise TypeError("Can only subtract strings from key policies.")

    def __add__(self, other: "CachePolicy") -> "CachePolicy":
        """
        Combine this policy with `other` into a `CompoundCachePolicy`.

        Raises:
            ValueError: if both policies set the same setting to different
                non-`None` values.
        """
        # adding _None is a no-op
        if isinstance(other, _None):
            return self

        conflict_messages = (
            (
                "key_storage",
                "Cannot add CachePolicies with different storage locations.",
            ),
            (
                "isolation_level",
                "Cannot add CachePolicies with different isolation levels.",
            ),
            (
                "lock_manager",
                "Cannot add CachePolicies with different lock implementations.",
            ),
        )
        for attr_name, message in conflict_messages:
            theirs = getattr(other, attr_name)
            mine = getattr(self, attr_name)
            if theirs is not None and mine is not None and theirs != mine:
                raise ValueError(message)

        # Each setting comes from this policy when truthy, else from `other`.
        return CompoundCachePolicy(
            policies=[self, other],
            key_storage=self.key_storage or other.key_storage,
            isolation_level=self.isolation_level or other.isolation_level,
            lock_manager=self.lock_manager or other.lock_manager,
        )
@dataclass
class CacheKeyFnPolicy(CachePolicy):
    """
    This policy accepts a custom function and uses it to compute a task run cache key.

    The function is called as `cache_key_fn(task_run_context, task_parameters)` and
    returns the cache key, or `None`. Flow parameters are not passed to it.
    """

    # making it optional for tests
    cache_key_fn: Optional[
        Callable[["TaskRunContext", dict[str, Any]], Optional[str]]
    ] = None

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """
        Delegate key computation to `cache_key_fn`.

        Returns:
            Whatever `cache_key_fn(task_ctx, inputs)` returns, or `None` when no
            function is configured. `flow_parameters` and `kwargs` are unused.
        """
        # Compare against `None` rather than relying on truthiness so that a
        # callable object that happens to be falsy is still invoked.
        if self.cache_key_fn is None:
            return None
        return self.cache_key_fn(task_ctx, inputs)
@dataclass
class CompoundCachePolicy(CachePolicy):
    """
    This policy is constructed from two or more other cache policies and works by computing the keys
    for each policy individually, and then hashing all computed keys together in policy order.

    Any keys that return `None` will be ignored.
    """

    policies: list[CachePolicy] = field(default_factory=lambda: [])

    def __post_init__(self) -> None:
        """Flatten nested compound policies and merge all `Inputs` policies into one."""
        # flatten any CompoundCachePolicies
        flattened: list[CachePolicy] = []
        for policy in self.policies:
            if isinstance(policy, CompoundCachePolicy):
                flattened.extend(policy.policies)
            else:
                flattened.append(policy)

        # deduplicate any Inputs policies: keep every other policy in order and
        # append a single Inputs policy excluding the union of all exclusions
        inputs_policies = [p for p in flattened if isinstance(p, Inputs)]
        self.policies = [p for p in flattened if not isinstance(p, Inputs)]
        if inputs_policies:
            all_excludes: set[str] = set()
            for inputs_policy in inputs_policies:
                all_excludes.update(inputs_policy.exclude)
            self.policies.append(Inputs(exclude=sorted(all_excludes)))

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """
        Compute a key from the keys of all member policies.

        Returns:
            The hash of every non-`None` member key, or `None` when no member
            policy produced a key.
        """
        keys: list[str] = []
        for policy in self.policies:
            policy_key = policy.compute_key(
                task_ctx=task_ctx,
                inputs=inputs,
                flow_parameters=flow_parameters,
                **kwargs,
            )
            if policy_key is not None:
                keys.append(policy_key)
        if not keys:
            return None
        return hash_objects(*keys, raise_on_failure=True)

    def __add__(self, other: "CachePolicy") -> "CachePolicy":
        """
        Return a new compound policy containing this policy's members plus `other`.

        Raises:
            ValueError: if `other` sets a key storage, isolation level or lock
                manager that conflicts with this policy's.
        """
        # adding _None is a no-op, matching the base class; do not append it
        # to the member list
        if isinstance(other, _None):
            return self

        # Call the superclass add method to handle validation
        super().__add__(other)

        if isinstance(other, CompoundCachePolicy):
            policies = [*self.policies, *other.policies]
        else:
            policies = [*self.policies, other]

        return CompoundCachePolicy(
            policies=policies,
            key_storage=self.key_storage or other.key_storage,
            isolation_level=self.isolation_level or other.isolation_level,
            lock_manager=self.lock_manager or other.lock_manager,
        )

    def __sub__(self, other: str) -> "CachePolicy":
        """
        Exclude the input named `other` from the member `Inputs` policy, if any.

        Returns:
            This policy unchanged when it has no `Inputs` member; otherwise a new
            compound policy with the extra exclusion and the same key storage,
            isolation level and lock manager.

        Raises:
            TypeError: if `other` is not a string.
        """
        if not isinstance(other, str):  # type: ignore[reportUnnecessaryIsInstance]
            raise TypeError("Can only subtract strings from key policies.")

        if not any(isinstance(p, Inputs) for p in self.policies):
            # no dependency on inputs already
            return self

        # The extra Inputs policy is merged with the existing one in
        # __post_init__. Carry the configuration over explicitly so that
        # subtracting an input does not silently reset it.
        return CompoundCachePolicy(
            policies=[*self.policies, Inputs(exclude=[other])],
            key_storage=self.key_storage,
            isolation_level=self.isolation_level,
            lock_manager=self.lock_manager,
        )
@dataclass
class _None(CachePolicy):
    """
    Policy whose computed cache key is always `None`.
    This policy prevents persistence and avoids caching entirely.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """Ignore every argument and produce no key."""
        return None

    def __add__(self, other: "CachePolicy") -> "CachePolicy":
        """Return `other` itself: combining with this policy contributes nothing."""
        return other
@dataclass
class TaskSource(CachePolicy):
    """
    Policy that derives the cache key from the task's source code.

    Only the raw lines of code in the task are considered; the source code of
    nested tasks is not.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: Optional[dict[str, Any]],
        flow_parameters: Optional[dict[str, Any]],
        **kwargs: Any,
    ) -> Optional[str]:
        """
        Hash the source of `task_ctx.task`.

        Returns:
            `None` when `task_ctx` is falsy, otherwise the hash of the source.

        Raises:
            OSError: re-raised when the source lookup fails with a message
                that does not mention "source code".
        """
        if not task_ctx:
            return None

        try:
            source = inspect.getsource(task_ctx.task)
        except TypeError:
            # The task itself is not something `inspect` can read source for;
            # use the class of the wrapped callable instead.
            source = inspect.getsource(task_ctx.task.fn.__class__)
        except OSError as exc:
            if "source code" not in str(exc):
                raise
            # Source is unavailable; fall back to the compiled bytecode.
            source = task_ctx.task.fn.__code__.co_code

        return hash_objects(source, raise_on_failure=True)
@dataclass
class FlowParameters(CachePolicy):
    """
    Policy whose cache key is a hash of the flow parameters.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """Return the hash of `flow_parameters`, or `None` when they are empty or missing."""
        if flow_parameters:
            return hash_objects(flow_parameters, raise_on_failure=True)
        return None
@dataclass
class RunId(CachePolicy):
    """
    Uses the prevailing flow run ID as the cache key, falling back to the
    prevailing task run ID when there is no flow run ID.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """Return the chosen run ID as a string, or `None` when `task_ctx` is falsy."""
        if not task_ctx:
            return None

        flow_run_id = task_ctx.task_run.flow_run_id
        if flow_run_id is not None:
            return str(flow_run_id)
        return str(task_ctx.task_run.id)
@dataclass
class Inputs(CachePolicy):
    """
    Policy that computes a cache key based on a hash of the runtime inputs provided to the task.

    Input names listed in `exclude` are left out of the hash.
    """

    exclude: list[str] = field(default_factory=lambda: [])

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """
        Hash the non-excluded task inputs.

        Returns:
            `None` when `inputs` is empty or missing, otherwise the hash of the
            remaining inputs (after applying any registered stable transform
            for each value's exact type).

        Raises:
            ValueError: if the inputs cannot be hashed.
        """
        if not inputs:
            return None

        exclude = self.exclude or []
        hashed_inputs: dict[str, Any] = {}
        for key, val in inputs.items():
            if key in exclude:
                continue
            transformer = STABLE_TRANSFORMS.get(type(val))  # type: ignore[reportUnknownMemberType]
            hashed_inputs[key] = transformer(val) if transformer else val

        try:
            return hash_objects(hashed_inputs, raise_on_failure=True)
        except HashError as exc:
            msg = (
                f"{exc}\n\n"
                "This often occurs when task inputs contain objects that cannot be cached "
                "like locks, file handles, or other system resources.\n\n"
                "To resolve this, you can:\n"
                "  1. Exclude these arguments by defining a custom `cache_key_fn`\n"
                "  2. Disable caching by passing `cache_policy=NO_CACHE`\n"
            )
            raise ValueError(msg) from exc

    def __sub__(self, other: str) -> "CachePolicy":
        """
        Return a new `Inputs` policy that additionally excludes the input named `other`.

        The key storage, isolation level and lock manager of this policy are
        carried over to the result.

        Raises:
            TypeError: if `other` is not a string.
        """
        if not isinstance(other, str):  # type: ignore[reportUnnecessaryIsInstance]
            raise TypeError("Can only subtract strings from key policies.")
        # Pass the configuration through explicitly; constructing a bare
        # `Inputs` would silently reset it to the defaults.
        return Inputs(
            exclude=[*self.exclude, other],
            key_storage=self.key_storage,
            isolation_level=self.isolation_level,
            lock_manager=self.lock_manager,
        )
# Populate STABLE_TRANSFORMS once, when this module is imported.
_register_stable_transforms()

# Ready-made policy instances.
INPUTS = Inputs()
# NONE and NO_CACHE are two separate instances of the same policy class.
NONE = _None()
NO_CACHE = _None()
TASK_SOURCE = TaskSource()
FLOW_PARAMETERS = FlowParameters()
RUN_ID = RunId()
# Compound policy combining task inputs, task source and run ID.
DEFAULT = INPUTS + TASK_SOURCE + RUN_ID