Coverage for /usr/local/lib/python3.12/site-packages/prefect/cache_policies.py: 42%

159 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1import inspect 1a

2from copy import deepcopy 1a

3from dataclasses import dataclass, field 1a

4from pathlib import Path 1a

5from typing import ( 1a

6 TYPE_CHECKING, 

7 Any, 

8 Callable, 

9 Dict, 

10 Literal, 

11 Optional, 

12 Union, 

13) 

14 

15from typing_extensions import Self 1a

16 

17from prefect.context import TaskRunContext 1a

18from prefect.exceptions import HashError 1a

19from prefect.utilities.hashing import hash_objects 1a

20 

21if TYPE_CHECKING: 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true1a

22 from prefect.filesystems import WritableFileSystem 

23 from prefect.locking.protocol import LockManager 

24 from prefect.transactions import IsolationLevel 

25 

STABLE_TRANSFORMS: dict[type, Callable[[Any], Any]] = {}


def _register_stable_transforms() -> None:
    """
    Populate `STABLE_TRANSFORMS` with stabilizing conversions.

    Some input types do not reliably produce deterministic byte strings when
    serialized via `cloudpickle`. Each registered entry maps such a type to a
    callable that converts a value into an equivalent form, so cache keys
    built from those values are deterministic across invocations.

    Conversions for optional third-party types are only registered when the
    owning library can be imported. Otherwise this function does nothing.
    """
    try:
        import pandas as pd  # pyright: ignore
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return

    def _dataframe_columns_by_label(df: Any) -> list[Any]:
        # Emit one column per label, in sorted-label order, so frames holding
        # the same columns in a different order yield the same sequence.
        return [df[label] for label in sorted(df.columns)]

    STABLE_TRANSFORMS[pd.DataFrame] = _dataframe_columns_by_label  # pyright: ignore

43 

44 

@dataclass
class CachePolicy:
    """
    Base class for all cache policies.

    A policy turns a task run's context, inputs and flow parameters into a
    cache key, or `None` when it produces no key. Policies combine with `+`.
    Subtracting a string is a no-op here; subclasses may override `__sub__`.
    """

    # Where cache keys are stored: a writable filesystem block, a path, or None.
    key_storage: Union["WritableFileSystem", str, Path, None] = None
    # Transaction isolation level, given by name or as an IsolationLevel.
    isolation_level: Union[
        Literal["READ_COMMITTED", "SERIALIZABLE"],
        "IsolationLevel",
        None,
    ] = None
    # Lock manager used with this policy, if any.
    lock_manager: Optional["LockManager"] = None

    @classmethod
    def from_cache_key_fn(
        cls, cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
    ) -> "CacheKeyFnPolicy":
        """
        Build a `CacheKeyFnPolicy` that delegates key computation to `cache_key_fn`.
        """
        return CacheKeyFnPolicy(cache_key_fn=cache_key_fn)

    def configure(
        self,
        key_storage: Union["WritableFileSystem", str, Path, None] = None,
        lock_manager: Optional["LockManager"] = None,
        isolation_level: Union[
            Literal["READ_COMMITTED", "SERIALIZABLE"], "IsolationLevel", None
        ] = None,
    ) -> Self:
        """
        Return a configured copy of this cache policy.

        Args:
            key_storage: Storage to use for cache keys. `None` keeps the
                current key storage.
            lock_manager: Lock manager to use. `None` keeps the current
                lock manager.
            isolation_level: Isolation level to use. `None` keeps the current
                isolation level.

        Returns:
            A deep copy of this policy with every non-`None` argument applied.
            The original policy is not modified.
        """
        overrides = {
            "key_storage": key_storage,
            "lock_manager": lock_manager,
            "isolation_level": isolation_level,
        }
        configured = deepcopy(self)
        for attr_name, value in overrides.items():
            if value is not None:
                setattr(configured, attr_name, value)
        return configured

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> Optional[str]:
        """Compute the cache key. Subclasses must override this; the base raises."""
        raise NotImplementedError

    def __sub__(self, other: str) -> "CachePolicy":
        """No-op for all policies except Inputs and Compound."""
        # Kept for interface compatibility: only the argument type is validated.
        if isinstance(other, str):  # type: ignore[reportUnnecessaryIsInstance]
            return self
        raise TypeError("Can only subtract strings from key policies.")

    def __add__(self, other: "CachePolicy") -> "CachePolicy":
        """
        Combine this policy with `other` into a `CompoundCachePolicy`.

        Adding `_None` returns this policy unchanged. The combined policy takes
        each setting from this policy when it is truthy, else from `other`.

        Raises:
            ValueError: if both policies set key storage, isolation level, or
                lock manager to different values.
        """
        if isinstance(other, _None):
            return self

        conflict_checks = (
            ("key_storage", "Cannot add CachePolicies with different storage locations."),
            ("isolation_level", "Cannot add CachePolicies with different isolation levels."),
            ("lock_manager", "Cannot add CachePolicies with different lock implementations."),
        )
        for attr_name, message in conflict_checks:
            theirs = getattr(other, attr_name)
            mine = getattr(self, attr_name)
            if theirs is not None and mine is not None and theirs != mine:
                raise ValueError(message)

        return CompoundCachePolicy(
            policies=[self, other],
            key_storage=self.key_storage or other.key_storage,
            isolation_level=self.isolation_level or other.isolation_level,
            lock_manager=self.lock_manager or other.lock_manager,
        )

74 ] = None, 

75 ) -> Self: 

76 """ 

77 Configure the cache policy with the given key storage, lock manager, and isolation level. 

78 

79 Args: 

80 key_storage: The storage to use for cache keys. If not provided, 

81 the current key storage will be used. 

82 lock_manager: The lock manager to use for the cache policy. If not provided, 

83 the current lock manager will be used. 

84 isolation_level: The isolation level to use for the cache policy. If not provided, 

85 the current isolation level will be used. 

86 

87 Returns: 

88 A new cache policy with the given key storage, lock manager, and isolation level. 

89 """ 

90 new = deepcopy(self) 

91 if key_storage is not None: 

92 new.key_storage = key_storage 

93 if lock_manager is not None: 

94 new.lock_manager = lock_manager 

95 if isolation_level is not None: 

96 new.isolation_level = isolation_level 

97 return new 

98 

99 def compute_key( 1a

100 self, 

101 task_ctx: TaskRunContext, 

102 inputs: dict[str, Any], 

103 flow_parameters: dict[str, Any], 

104 **kwargs: Any, 

105 ) -> Optional[str]: 

106 raise NotImplementedError 

107 

108 def __sub__(self, other: str) -> "CachePolicy": 1a

109 "No-op for all policies except Inputs and Compound" 

110 

111 # for interface compatibility 

112 if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] 

113 raise TypeError("Can only subtract strings from key policies.") 

114 return self 

115 

116 def __add__(self, other: "CachePolicy") -> "CachePolicy": 1a

117 # adding _None is a no-op 

118 if isinstance(other, _None): 118 ↛ 119line 118 didn't jump to line 119 because the condition on line 118 was never true1a

119 return self 

120 

121 if ( 121 ↛ 126line 121 didn't jump to line 126 because the condition on line 121 was never true

122 other.key_storage is not None 

123 and self.key_storage is not None 

124 and other.key_storage != self.key_storage 

125 ): 

126 raise ValueError( 

127 "Cannot add CachePolicies with different storage locations." 

128 ) 

129 if ( 129 ↛ 134line 129 didn't jump to line 134 because the condition on line 129 was never true

130 other.isolation_level is not None 

131 and self.isolation_level is not None 

132 and other.isolation_level != self.isolation_level 

133 ): 

134 raise ValueError( 

135 "Cannot add CachePolicies with different isolation levels." 

136 ) 

137 if ( 137 ↛ 142line 137 didn't jump to line 142 because the condition on line 137 was never true

138 other.lock_manager is not None 

139 and self.lock_manager is not None 

140 and other.lock_manager != self.lock_manager 

141 ): 

142 raise ValueError( 

143 "Cannot add CachePolicies with different lock implementations." 

144 ) 

145 

146 return CompoundCachePolicy( 1a

147 policies=[self, other], 

148 key_storage=self.key_storage or other.key_storage, 

149 isolation_level=self.isolation_level or other.isolation_level, 

150 lock_manager=self.lock_manager or other.lock_manager, 

151 ) 

152 

153 

154@dataclass 1a

155class CacheKeyFnPolicy(CachePolicy): 1a

156 """ 

157 This policy accepts a custom function with signature f(task_run_context, task_parameters, flow_parameters) -> str 

158 and uses it to compute a task run cache key. 

159 """ 

160 

161 # making it optional for tests 

162 cache_key_fn: Optional[ 1a

163 Callable[["TaskRunContext", dict[str, Any]], Optional[str]] 

164 ] = None 

165 

166 def compute_key( 1a

167 self, 

168 task_ctx: TaskRunContext, 

169 inputs: dict[str, Any], 

170 flow_parameters: dict[str, Any], 

171 **kwargs: Any, 

172 ) -> Optional[str]: 

173 if self.cache_key_fn: 

174 return self.cache_key_fn(task_ctx, inputs) 

175 

176 

177@dataclass 1a

178class CompoundCachePolicy(CachePolicy): 1a

179 """ 

180 This policy is constructed from two or more other cache policies and works by computing the keys 

181 for each policy individually, and then hashing a sorted tuple of all computed keys. 

182 

183 Any keys that return `None` will be ignored. 

184 """ 

185 

186 policies: list[CachePolicy] = field(default_factory=lambda: []) 1a

187 

188 def __post_init__(self) -> None: 1a

189 # flatten any CompoundCachePolicies 

190 self.policies = [ 1a

191 policy 

192 for p in self.policies 

193 for policy in (p.policies if isinstance(p, CompoundCachePolicy) else [p]) 

194 ] 

195 

196 # deduplicate any Inputs policies 

197 inputs_policies = [p for p in self.policies if isinstance(p, Inputs)] 1a

198 self.policies = [p for p in self.policies if not isinstance(p, Inputs)] 1a

199 if inputs_policies: 199 ↛ exitline 199 didn't return from function '__post_init__' because the condition on line 199 was always true1a

200 all_excludes: set[str] = set() 1a

201 for inputs_policy in inputs_policies: 1a

202 all_excludes.update(inputs_policy.exclude) 1a

203 self.policies.append(Inputs(exclude=sorted(all_excludes))) 1a

204 

205 def compute_key( 1a

206 self, 

207 task_ctx: TaskRunContext, 

208 inputs: dict[str, Any], 

209 flow_parameters: dict[str, Any], 

210 **kwargs: Any, 

211 ) -> Optional[str]: 

212 keys: list[str] = [] 

213 for policy in self.policies: 

214 policy_key = policy.compute_key( 

215 task_ctx=task_ctx, 

216 inputs=inputs, 

217 flow_parameters=flow_parameters, 

218 **kwargs, 

219 ) 

220 if policy_key is not None: 

221 keys.append(policy_key) 

222 if not keys: 

223 return None 

224 return hash_objects(*keys, raise_on_failure=True) 

225 

226 def __add__(self, other: "CachePolicy") -> "CachePolicy": 1a

227 # Call the superclass add method to handle validation 

228 super().__add__(other) 1a

229 

230 if isinstance(other, CompoundCachePolicy): 230 ↛ 231line 230 didn't jump to line 231 because the condition on line 230 was never true1a

231 policies = [*self.policies, *other.policies] 

232 else: 

233 policies = [*self.policies, other] 1a

234 

235 return CompoundCachePolicy( 1a

236 policies=policies, 

237 key_storage=self.key_storage or other.key_storage, 

238 isolation_level=self.isolation_level or other.isolation_level, 

239 lock_manager=self.lock_manager or other.lock_manager, 

240 ) 

241 

242 def __sub__(self, other: str) -> "CachePolicy": 1a

243 if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] 

244 raise TypeError("Can only subtract strings from key policies.") 

245 

246 inputs_policies = [p for p in self.policies if isinstance(p, Inputs)] 

247 

248 if inputs_policies: 

249 new = Inputs(exclude=[other]) 

250 return CompoundCachePolicy(policies=[*self.policies, new]) 

251 else: 

252 # no dependency on inputs already 

253 return self 

254 

255 

256@dataclass 1a

257class _None(CachePolicy): 1a

258 """ 

259 Policy that always returns `None` for the computed cache key. 

260 This policy prevents persistence and avoids caching entirely. 

261 """ 

262 

263 def compute_key( 1a

264 self, 

265 task_ctx: TaskRunContext, 

266 inputs: dict[str, Any], 

267 flow_parameters: dict[str, Any], 

268 **kwargs: Any, 

269 ) -> Optional[str]: 

270 return None 

271 

272 def __add__(self, other: "CachePolicy") -> "CachePolicy": 1a

273 # adding _None is a no-op 

274 return other 

275 

276 

277@dataclass 1a

278class TaskSource(CachePolicy): 1a

279 """ 

280 Policy for computing a cache key based on the source code of the task. 

281 

282 This policy only considers raw lines of code in the task, and not the source code of nested tasks. 

283 """ 

284 

285 def compute_key( 1a

286 self, 

287 task_ctx: TaskRunContext, 

288 inputs: Optional[dict[str, Any]], 

289 flow_parameters: Optional[dict[str, Any]], 

290 **kwargs: Any, 

291 ) -> Optional[str]: 

292 if not task_ctx: 

293 return None 

294 try: 

295 lines = inspect.getsource(task_ctx.task) 

296 except TypeError: 

297 lines = inspect.getsource(task_ctx.task.fn.__class__) 

298 except OSError as exc: 

299 if "source code" in str(exc): 

300 lines = task_ctx.task.fn.__code__.co_code 

301 else: 

302 raise 

303 return hash_objects(lines, raise_on_failure=True) 

304 

305 

306@dataclass 1a

307class FlowParameters(CachePolicy): 1a

308 """ 

309 Policy that computes the cache key based on a hash of the flow parameters. 

310 """ 

311 

312 def compute_key( 1a

313 self, 

314 task_ctx: TaskRunContext, 

315 inputs: dict[str, Any], 

316 flow_parameters: dict[str, Any], 

317 **kwargs: Any, 

318 ) -> Optional[str]: 

319 if not flow_parameters: 

320 return None 

321 return hash_objects(flow_parameters, raise_on_failure=True) 

322 

323 

324@dataclass 1a

325class RunId(CachePolicy): 1a

326 """ 

327 Returns either the prevailing flow run ID, or if not found, the prevailing task 

328 run ID. 

329 """ 

330 

331 def compute_key( 1a

332 self, 

333 task_ctx: TaskRunContext, 

334 inputs: dict[str, Any], 

335 flow_parameters: dict[str, Any], 

336 **kwargs: Any, 

337 ) -> Optional[str]: 

338 if not task_ctx: 

339 return None 

340 run_id = task_ctx.task_run.flow_run_id 

341 if run_id is None: 

342 run_id = task_ctx.task_run.id 

343 return str(run_id) 

344 

345 

346@dataclass 1a

347class Inputs(CachePolicy): 1a

348 """ 

349 Policy that computes a cache key based on a hash of the runtime inputs provided to the task.. 

350 """ 

351 

352 exclude: list[str] = field(default_factory=lambda: []) 1a

353 

354 def compute_key( 1a

355 self, 

356 task_ctx: TaskRunContext, 

357 inputs: dict[str, Any], 

358 flow_parameters: dict[str, Any], 

359 **kwargs: Any, 

360 ) -> Optional[str]: 

361 hashed_inputs = {} 

362 inputs = inputs or {} 

363 exclude = self.exclude or [] 

364 

365 if not inputs: 

366 return None 

367 

368 for key, val in inputs.items(): 

369 if key not in exclude: 

370 transformer = STABLE_TRANSFORMS.get(type(val)) # type: ignore[reportUnknownMemberType] 

371 hashed_inputs[key] = transformer(val) if transformer else val 

372 

373 try: 

374 return hash_objects(hashed_inputs, raise_on_failure=True) 

375 except HashError as exc: 

376 msg = ( 

377 f"{exc}\n\n" 

378 "This often occurs when task inputs contain objects that cannot be cached " 

379 "like locks, file handles, or other system resources.\n\n" 

380 "To resolve this, you can:\n" 

381 " 1. Exclude these arguments by defining a custom `cache_key_fn`\n" 

382 " 2. Disable caching by passing `cache_policy=NO_CACHE`\n" 

383 ) 

384 raise ValueError(msg) from exc 

385 

386 def __sub__(self, other: str) -> "CachePolicy": 1a

387 if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] 

388 raise TypeError("Can only subtract strings from key policies.") 

389 return Inputs(exclude=self.exclude + [other]) 

390 

391 

392_register_stable_transforms() 1a

393 

394INPUTS = Inputs() 1a

395NONE = _None() 1a

396NO_CACHE = _None() 1a

397TASK_SOURCE = TaskSource() 1a

398FLOW_PARAMETERS = FlowParameters() 1a

399RUN_ID = RunId() 1a

400DEFAULT = INPUTS + TASK_SOURCE + RUN_ID 1a