Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/pydantic/validated_func.py: 38%

156 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1""" 

2Pure Pydantic v2 implementation of function argument validation. 

3 

4This module provides validation of function arguments without calling the function, 

5compatible with Python 3.14+ (no Pydantic v1 dependency). 

6""" 

7 

8from __future__ import annotations 1a

9 

10import inspect 1a

11from typing import Any, Callable, ClassVar, Optional 1a

12 

13from pydantic import ( 1a

14 BaseModel, 

15 ConfigDict, 

16 Field, 

17 ValidationError, 

18 create_model, 

19 field_validator, 

20) 

21 

# Special field names for validation.
# These match the pydantic.v1.decorator constants for compatibility.
# ValidatedFunction rejects any function whose parameters use one of these
# names, because they are added to the generated model as bookkeeping fields.

# Field holding positional arguments beyond the named positional parameters.
V_ARGS_NAME = "v__args"
# Field holding keyword arguments that do not match any named parameter.
V_KWARGS_NAME = "v__kwargs"
# Field holding names of positional-only parameters that were passed by keyword.
V_POSITIONAL_ONLY_NAME = "v__positional_only"
# Field holding names of parameters supplied both positionally and by keyword.
V_DUPLICATE_KWARGS = "v__duplicate_kwargs"

28 

29 

class ValidatedFunction:
    """
    Validates function arguments using Pydantic v2 without calling the function.

    This class inspects a function's signature and creates a Pydantic model
    that can validate arguments passed to the function, including handling
    of *args, **kwargs, positional-only parameters, and duplicate arguments.

    Example:
        ```python
        def greet(name: str, age: int = 0):
            return f"Hello {name}, you are {age} years old"

        vf = ValidatedFunction(greet)

        # Validate arguments
        values = vf.validate_call_args(("Alice",), {"age": 30})
        # Returns: {"name": "Alice", "age": 30}

        # Invalid arguments will raise ValidationError
        vf.validate_call_args(("Bob",), {"age": "not a number"})
        # Raises: ValidationError
        ```
    """

    def __init__(
        self,
        function: Callable[..., Any],
        config: ConfigDict | None = None,
    ):
        """
        Initialize the validated function.

        Args:
            function: The function to validate arguments for
            config: Optional Pydantic ConfigDict or dict configuration

        Raises:
            ValueError: If function parameters conflict with internal field names
        """
        self.raw_function = function
        self.signature = inspect.signature(function)
        # Index of each named positional parameter -> its name.
        self.arg_mapping: dict[int, str] = {}
        self.positional_only_args: set[str] = set()
        self.v_args_name = V_ARGS_NAME
        self.v_kwargs_name = V_KWARGS_NAME
        # Set when forward references could not be resolved at construction
        # time; validate_call_args() then retries the rebuild once.
        self._needs_rebuild = False

        # Check for conflicts with internal field names
        reserved_names = {
            V_ARGS_NAME,
            V_KWARGS_NAME,
            V_POSITIONAL_ONLY_NAME,
            V_DUPLICATE_KWARGS,
        }
        param_names = set(self.signature.parameters.keys())
        conflicts = reserved_names & param_names
        if conflicts:
            raise ValueError(
                f"Function parameters conflict with internal field names: {conflicts}. "
                f"These names are reserved: {reserved_names}"
            )

        # Build the validation model
        fields, takes_args, takes_kwargs, has_forward_refs = self._build_fields()
        self._create_model(fields, takes_args, takes_kwargs, config, has_forward_refs)

    def _build_fields(self) -> tuple[dict[str, Any], bool, bool, bool]:
        """
        Build field definitions from function signature.

        As a side effect this fills ``self.arg_mapping`` (position -> name for
        positional-only and positional-or-keyword parameters) and
        ``self.positional_only_args``.

        Returns:
            Tuple of (fields_dict, takes_args, takes_kwargs, has_forward_refs)
        """
        fields: dict[str, Any] = {}
        takes_args = False
        takes_kwargs = False
        has_forward_refs = False
        position = 0
        empty = inspect.Parameter.empty

        for param_name, param in self.signature.parameters.items():
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                takes_args = True
                continue

            if param.kind == inspect.Parameter.VAR_KEYWORD:
                takes_kwargs = True
                continue

            # Track positional-only parameters
            if param.kind == inspect.Parameter.POSITIONAL_ONLY:
                self.positional_only_args.add(param_name)

            # Map position to parameter name
            if param.kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                self.arg_mapping[position] = param_name
                position += 1

            # Determine type and default.
            # The sentinel is compared by identity: annotations or defaults
            # with an overloaded ``__eq__`` (e.g. array-like objects) would
            # otherwise yield a non-boolean comparison result.
            annotation = Any if param.annotation is empty else param.annotation

            # Check if annotation is a string (forward reference)
            if isinstance(annotation, str):
                has_forward_refs = True

            if param.default is empty:
                # Required parameter
                fields[param_name] = (annotation, Field(...))
            else:
                # Optional parameter with default
                fields[param_name] = (annotation, Field(default=param.default))

        # Always add args/kwargs fields for validation, even if function doesn't accept them
        fields[self.v_args_name] = (Optional[list[Any]], Field(default=None))
        fields[self.v_kwargs_name] = (Optional[dict[str, Any]], Field(default=None))

        # Add special validation fields
        fields[V_POSITIONAL_ONLY_NAME] = (Optional[list[str]], Field(default=None))
        fields[V_DUPLICATE_KWARGS] = (Optional[list[str]], Field(default=None))

        return fields, takes_args, takes_kwargs, has_forward_refs

    def _create_model(
        self,
        fields: dict[str, Any],
        takes_args: bool,
        takes_kwargs: bool,
        config: ConfigDict | None,
        has_forward_refs: bool,
    ) -> None:
        """
        Create the Pydantic validation model and store it on ``self.model``.

        Args:
            fields: Field definitions produced by ``_build_fields``
            takes_args: Whether the function declares ``*args``
            takes_kwargs: Whether the function declares ``**kwargs``
            config: Optional model configuration; ``extra="forbid"`` is applied
                unless the caller set ``extra`` explicitly
            has_forward_refs: Whether any annotation is a string and the model
                therefore needs rebuilding against the function's globals
        """
        pos_args = len(self.arg_mapping)

        # Process config
        # Note: ConfigDict is a TypedDict, so we can't use isinstance() in Python 3.14
        # Instead, check if it's a dict-like object and merge with defaults
        if config is None:
            config_dict = ConfigDict(extra="forbid")
        else:
            # Copy so the caller's config is never mutated.
            config_dict = config.copy()
            if "extra" not in config_dict:
                config_dict["extra"] = "forbid"

        # Create base model with validators
        class DecoratorBaseModel(BaseModel):
            model_config: ClassVar[ConfigDict] = config_dict

            @field_validator(V_ARGS_NAME, check_fields=False)
            @classmethod
            def check_args(cls, v: Optional[list[Any]]) -> Optional[list[Any]]:
                # Surplus positional arguments are only legal with *args.
                if takes_args or v is None:
                    return v

                raise TypeError(
                    f"{pos_args} positional argument{'s' if pos_args != 1 else ''} "
                    f"expected but {pos_args + len(v)} given"
                )

            @field_validator(V_KWARGS_NAME, check_fields=False)
            @classmethod
            def check_kwargs(
                cls, v: Optional[dict[str, Any]]
            ) -> Optional[dict[str, Any]]:
                # Unknown keyword arguments are only legal with **kwargs.
                if takes_kwargs or v is None:
                    return v

                plural = "" if len(v) == 1 else "s"
                keys = ", ".join(map(repr, v.keys()))
                raise TypeError(f"unexpected keyword argument{plural}: {keys}")

            @field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
            @classmethod
            def check_positional_only(cls, v: Optional[list[str]]) -> None:
                # Any value here means a positional-only parameter was passed by keyword.
                if v is None:
                    return

                plural = "" if len(v) == 1 else "s"
                keys = ", ".join(map(repr, v))
                raise TypeError(
                    f"positional-only argument{plural} passed as keyword "
                    f"argument{plural}: {keys}"
                )

            @field_validator(V_DUPLICATE_KWARGS, check_fields=False)
            @classmethod
            def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None:
                # Any value here means a parameter was given positionally and by keyword.
                if v is None:
                    return

                plural = "" if len(v) == 1 else "s"
                keys = ", ".join(map(repr, v))
                raise TypeError(f"multiple values for argument{plural}: {keys}")

        # Create the model dynamically.
        # Not every callable has ``__name__`` (e.g. functools.partial objects),
        # so fall back to the callable's type name.
        function_name = getattr(
            self.raw_function, "__name__", type(self.raw_function).__name__
        )
        model_name = f"{function_name.title()}Model"
        self.model = create_model(
            model_name,
            __base__=DecoratorBaseModel,
            **fields,
        )

        # Rebuild the model with the original function's namespace to resolve forward references
        # This is necessary when using `from __future__ import annotations` or when
        # parameters reference types not in the current scope
        # Only rebuild if we detected forward references to avoid performance overhead
        # If rebuild fails (e.g., forward-referenced types not yet defined), defer to validation time
        if has_forward_refs:
            try:
                self.model.model_rebuild(_types_namespace=self.raw_function.__globals__)
            except (NameError, AttributeError):
                # Forward references can't be resolved yet (e.g., types defined after decorator)
                # Model will be rebuilt during validate_call_args when types are available
                self._needs_rebuild = True

    def validate_call_args(
        self, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate function arguments and return normalized parameters.

        Args:
            args: Positional arguments
            kwargs: Keyword arguments

        Returns:
            Dictionary mapping parameter names to values. A ``*args`` parameter
            maps to a list and a ``**kwargs`` parameter maps to a dict.

        Raises:
            ValidationError: If arguments don't match the function signature
            TypeError: Raised by the model's validators for call-shape errors
                (surplus positional arguments, unexpected keyword arguments,
                positional-only arguments passed by keyword, duplicate values)
        """
        # Build the values dict for validation
        values: dict[str, Any] = {}
        var_args: list[Any] = []
        var_kwargs: dict[str, Any] = {}
        positional_only_passed_as_kw: list[str] = []
        duplicate_kwargs: list[str] = []

        # Process positional arguments
        for i, arg in enumerate(args):
            if i in self.arg_mapping:
                param_name = self.arg_mapping[i]
                if param_name in kwargs:
                    # Duplicate: both positional and keyword
                    duplicate_kwargs.append(param_name)
                values[param_name] = arg
            else:
                # Extra positional args go into *args
                var_args.append(arg)

        # Process keyword arguments
        for key, value in kwargs.items():
            if key in values:
                # Already set by positional arg
                continue

            # Check if this is a positional-only param passed as keyword
            if key in self.positional_only_args:
                positional_only_passed_as_kw.append(key)
                continue

            # Check if this is a known parameter
            if key in self.signature.parameters:
                values[key] = value
            else:
                # Unknown parameter goes into **kwargs
                var_kwargs[key] = value

        # Add special fields
        if var_args:
            values[self.v_args_name] = var_args
        if var_kwargs:
            values[self.v_kwargs_name] = var_kwargs
        if positional_only_passed_as_kw:
            values[V_POSITIONAL_ONLY_NAME] = positional_only_passed_as_kw
        if duplicate_kwargs:
            values[V_DUPLICATE_KWARGS] = duplicate_kwargs

        # Rebuild model if needed to resolve any forward references that weren't available
        # at initialization time (e.g., when using `from __future__ import annotations`)
        # Only rebuild if we previously failed to resolve forward refs at init time
        if self._needs_rebuild:
            # Try rebuilding with raise_errors=False to handle any remaining issues gracefully.
            # Callables without ``__globals__`` get no extra namespace instead of
            # failing with AttributeError here.
            self.model.model_rebuild(
                _types_namespace=getattr(self.raw_function, "__globals__", None),
                raise_errors=False,
            )
            # Clear the flag - we only need to rebuild once
            self._needs_rebuild = False

        # Validate using the model
        try:
            validated = self.model.model_validate(values)
        except ValidationError as e:
            # Convert ValidationError to TypeError for certain cases to match Python's behavior
            # Check if the error is about extra kwargs
            for error in e.errors():
                if error.get("type") == "extra_forbidden" and error.get("loc") == (
                    "kwargs",
                ):
                    # This is an extra keyword argument error
                    extra_keys = error.get("input", {})
                    if isinstance(extra_keys, dict):
                        plural = "" if len(extra_keys) == 1 else "s"
                        keys = ", ".join(map(repr, extra_keys.keys()))
                        raise TypeError(
                            f"unexpected keyword argument{plural}: {keys}"
                        ) from e
            # For other validation errors, re-raise as-is
            raise

        # Extract only the actual function parameters
        result: dict[str, Any] = {}
        for param_name, param in self.signature.parameters.items():
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                result[param_name] = getattr(validated, self.v_args_name) or []
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                result[param_name] = getattr(validated, self.v_kwargs_name) or {}
            else:
                # Regular parameter
                result[param_name] = getattr(validated, param_name)

        return result

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Validate arguments and call the function.

        The validated values are passed back according to each parameter's
        kind: positional-only and positional-or-keyword parameters are passed
        positionally, ``*args`` is unpacked after them, keyword-only parameters
        are passed by name and ``**kwargs`` is unpacked as keywords. Passing
        everything by keyword would fail for positional-only parameters and
        would hand ``*args`` / ``**kwargs`` over as single named values.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            The result of calling the function with validated arguments
        """
        validated_params = self.validate_call_args(args, kwargs)

        call_args: list[Any] = []
        call_kwargs: dict[str, Any] = {}
        # Signature order is positional parameters, *args, keyword-only
        # parameters, **kwargs, so appending in iteration order preserves
        # positions.
        for param_name, param in self.signature.parameters.items():
            value = validated_params[param_name]
            if param.kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                call_args.append(value)
            elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                call_args.extend(value)
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                call_kwargs.update(value)
            else:
                # Keyword-only parameter
                call_kwargs[param_name] = value

        return self.raw_function(*call_args, **call_kwargs)