Coverage for /usr/local/lib/python3.12/site-packages/prefect/_internal/pydantic/validated_func.py: 38%
156 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
"""
Pure Pydantic v2 implementation of function argument validation.

This module validates function arguments without calling the function and is
compatible with Python 3.14+ (no Pydantic v1 dependency).
"""
8from __future__ import annotations 1a
10import inspect 1a
11from typing import Any, Callable, ClassVar, Optional 1a
13from pydantic import ( 1a
14 BaseModel,
15 ConfigDict,
16 Field,
17 ValidationError,
18 create_model,
19 field_validator,
20)
# Names of the synthetic model fields that carry call-shape information through
# validation. The values mirror the constants in `pydantic.v1.decorator` so the
# generated models stay compatible with the v1 behaviour.

# Surplus positional arguments beyond the named positional parameters.
V_ARGS_NAME = "v__args"
# Keyword arguments that do not match any named parameter.
V_KWARGS_NAME = "v__kwargs"
# Positional-only parameters that the caller supplied by keyword.
V_POSITIONAL_ONLY_NAME = "v__positional_only"
# Parameters that the caller supplied both positionally and by keyword.
V_DUPLICATE_KWARGS = "v__duplicate_kwargs"
class ValidatedFunction:
    """
    Validates function arguments using Pydantic v2 without calling the function.

    This class inspects a function's signature and creates a Pydantic model
    that can validate arguments passed to the function, including handling
    of *args, **kwargs, positional-only parameters, and duplicate arguments.

    Example:
        ```python
        def greet(name: str, age: int = 0):
            return f"Hello {name}, you are {age} years old"

        vf = ValidatedFunction(greet)

        # Validate arguments
        values = vf.validate_call_args(("Alice",), {"age": 30})
        # Returns: {"name": "Alice", "age": 30}

        # Invalid arguments will raise ValidationError
        vf.validate_call_args(("Bob",), {"age": "not a number"})
        # Raises: ValidationError
        ```
    """

    def __init__(
        self,
        function: Callable[..., Any],
        config: ConfigDict | None = None,
    ):
        """
        Initialize the validated function.

        Args:
            function: The function to validate arguments for
            config: Optional Pydantic ConfigDict or dict configuration. When it
                does not set ``extra``, ``extra="forbid"`` is used.

        Raises:
            ValueError: If function parameters conflict with internal field names
        """
        self.raw_function = function
        self.signature = inspect.signature(function)
        # Positional index -> parameter name, for every parameter that may be
        # supplied positionally (positional-only and positional-or-keyword).
        self.arg_mapping: dict[int, str] = {}
        self.positional_only_args: set[str] = set()
        self.v_args_name = V_ARGS_NAME
        self.v_kwargs_name = V_KWARGS_NAME
        # True while forward references are still unresolved; validate_call_args
        # retries the model rebuild until it succeeds.
        self._needs_rebuild = False

        # Check for conflicts with internal field names
        reserved_names = {
            V_ARGS_NAME,
            V_KWARGS_NAME,
            V_POSITIONAL_ONLY_NAME,
            V_DUPLICATE_KWARGS,
        }
        param_names = set(self.signature.parameters.keys())
        conflicts = reserved_names & param_names
        if conflicts:
            raise ValueError(
                f"Function parameters conflict with internal field names: {conflicts}. "
                f"These names are reserved: {reserved_names}"
            )

        # Build the validation model
        fields, takes_args, takes_kwargs, has_forward_refs = self._build_fields()
        self._create_model(fields, takes_args, takes_kwargs, config, has_forward_refs)

    def _build_fields(self) -> tuple[dict[str, Any], bool, bool, bool]:
        """
        Build field definitions from function signature.

        Side effects: populates ``self.arg_mapping`` and
        ``self.positional_only_args``.

        Returns:
            Tuple of (fields_dict, takes_args, takes_kwargs, has_forward_refs)
        """
        fields: dict[str, Any] = {}
        takes_args = False
        takes_kwargs = False
        has_forward_refs = False
        position = 0

        for param_name, param in self.signature.parameters.items():
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                takes_args = True
                continue

            if param.kind is inspect.Parameter.VAR_KEYWORD:
                takes_kwargs = True
                continue

            # Track positional-only parameters
            if param.kind is inspect.Parameter.POSITIONAL_ONLY:
                self.positional_only_args.add(param_name)

            # Map position to parameter name
            if param.kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                self.arg_mapping[position] = param_name
                position += 1

            # Compare against the `empty` sentinel by identity: `==` would call
            # a user object's __eq__ (e.g. a numpy array or DataFrame default
            # yields an element-wise result whose truth value is ambiguous).
            if param.annotation is inspect.Parameter.empty:
                annotation = Any
            else:
                annotation = param.annotation

            # A string annotation is an unresolved forward reference
            if isinstance(annotation, str):
                has_forward_refs = True

            if param.default is inspect.Parameter.empty:
                # Required parameter
                fields[param_name] = (annotation, Field(...))
            else:
                # Optional parameter with default
                fields[param_name] = (annotation, Field(default=param.default))

        # Always add args/kwargs fields for validation, even if function doesn't accept them
        fields[self.v_args_name] = (Optional[list[Any]], Field(default=None))
        fields[self.v_kwargs_name] = (Optional[dict[str, Any]], Field(default=None))

        # Add special validation fields
        fields[V_POSITIONAL_ONLY_NAME] = (Optional[list[str]], Field(default=None))
        fields[V_DUPLICATE_KWARGS] = (Optional[list[str]], Field(default=None))

        return fields, takes_args, takes_kwargs, has_forward_refs

    def _create_model(
        self,
        fields: dict[str, Any],
        takes_args: bool,
        takes_kwargs: bool,
        config: ConfigDict | None,
        has_forward_refs: bool,
    ) -> None:
        """
        Create the Pydantic validation model and store it on ``self.model``.

        The model's field validators raise ``TypeError`` for call-shape errors
        (surplus positionals without *args, unknown keywords without **kwargs,
        positional-only parameters passed by keyword, duplicate arguments).
        """
        pos_args = len(self.arg_mapping)

        # ConfigDict is a TypedDict, so isinstance() cannot be used on it;
        # copy whatever mapping was supplied and default `extra` to "forbid".
        config_dict = ConfigDict(**(config or {}))
        config_dict.setdefault("extra", "forbid")

        # Create base model with validators
        class DecoratorBaseModel(BaseModel):
            model_config: ClassVar[ConfigDict] = config_dict

            @field_validator(V_ARGS_NAME, check_fields=False)
            @classmethod
            def check_args(cls, v: Optional[list[Any]]) -> Optional[list[Any]]:
                if takes_args or v is None:
                    return v

                raise TypeError(
                    f"{pos_args} positional argument{'s' if pos_args != 1 else ''} "
                    f"expected but {pos_args + len(v)} given"
                )

            @field_validator(V_KWARGS_NAME, check_fields=False)
            @classmethod
            def check_kwargs(
                cls, v: Optional[dict[str, Any]]
            ) -> Optional[dict[str, Any]]:
                if takes_kwargs or v is None:
                    return v

                plural = "" if len(v) == 1 else "s"
                keys = ", ".join(map(repr, v.keys()))
                raise TypeError(f"unexpected keyword argument{plural}: {keys}")

            @field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
            @classmethod
            def check_positional_only(cls, v: Optional[list[str]]) -> None:
                if v is None:
                    return

                plural = "" if len(v) == 1 else "s"
                keys = ", ".join(map(repr, v))
                raise TypeError(
                    f"positional-only argument{plural} passed as keyword "
                    f"argument{plural}: {keys}"
                )

            @field_validator(V_DUPLICATE_KWARGS, check_fields=False)
            @classmethod
            def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None:
                if v is None:
                    return

                plural = "" if len(v) == 1 else "s"
                keys = ", ".join(map(repr, v))
                raise TypeError(f"multiple values for argument{plural}: {keys}")

        # Create the model dynamically
        model_name = f"{self.raw_function.__name__.title()}Model"
        self.model = create_model(
            model_name,
            __base__=DecoratorBaseModel,
            **fields,
        )

        # Rebuild the model with the original function's namespace to resolve forward references
        # This is necessary when using `from __future__ import annotations` or when
        # parameters reference types not in the current scope
        # Only rebuild if we detected forward references to avoid performance overhead
        # If rebuild fails (e.g., forward-referenced types not yet defined), defer to validation time
        if has_forward_refs:
            try:
                self.model.model_rebuild(_types_namespace=self.raw_function.__globals__)
            except (NameError, AttributeError):
                # Forward references can't be resolved yet (e.g., types defined after decorator)
                # Model will be rebuilt during validate_call_args when types are available
                self._needs_rebuild = True

    def validate_call_args(
        self, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate function arguments and return normalized parameters.

        Args:
            args: Positional arguments
            kwargs: Keyword arguments

        Returns:
            Dictionary mapping every parameter name in the signature to its
            validated value; a ``*args`` parameter maps to a list and a
            ``**kwargs`` parameter maps to a dict.

        Raises:
            ValidationError: If argument values don't validate against the
                function's annotations (including missing required arguments)
            TypeError: If the call shape is invalid (surplus positional
                arguments, unexpected keywords, positional-only parameters
                passed by keyword, or duplicate values for a parameter)
        """
        # Build the values dict for validation
        values: dict[str, Any] = {}
        var_args: list[Any] = []
        var_kwargs: dict[str, Any] = {}
        positional_only_passed_as_kw: list[str] = []
        duplicate_kwargs: list[str] = []

        # Process positional arguments
        for i, arg in enumerate(args):
            param_name = self.arg_mapping.get(i)
            if param_name is None:
                # Extra positional args go into *args
                var_args.append(arg)
                continue
            if param_name in kwargs:
                # Duplicate: both positional and keyword
                duplicate_kwargs.append(param_name)
            values[param_name] = arg

        # Process keyword arguments
        for key, value in kwargs.items():
            if key in values:
                # Already set by a positional arg (and flagged as a duplicate above)
                continue

            # Check if this is a positional-only param passed as keyword
            if key in self.positional_only_args:
                positional_only_passed_as_kw.append(key)
                continue

            param = self.signature.parameters.get(key)
            if param is None or param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                # Unknown names -- including the names of the *args/**kwargs
                # parameters themselves, which Python never binds by keyword and
                # which are not model fields -- belong to **kwargs.
                var_kwargs[key] = value
            else:
                values[key] = value

        # Add special fields
        if var_args:
            values[self.v_args_name] = var_args
        if var_kwargs:
            values[self.v_kwargs_name] = var_kwargs
        if positional_only_passed_as_kw:
            values[V_POSITIONAL_ONLY_NAME] = positional_only_passed_as_kw
        if duplicate_kwargs:
            values[V_DUPLICATE_KWARGS] = duplicate_kwargs

        # Retry resolving forward references that were unavailable at init time
        # (e.g. types defined after the decorated function).
        if self._needs_rebuild:
            rebuilt = self.model.model_rebuild(
                _types_namespace=self.raw_function.__globals__, raise_errors=False
            )
            # model_rebuild returns False when references still cannot be
            # resolved; keep the flag so a later call (after the types exist)
            # retries with the function's namespace.
            if rebuilt is not False:
                self._needs_rebuild = False

        # Validate using the model. Call-shape problems surface as TypeError
        # raised by the model's validators; value problems as ValidationError.
        validated = self.model.model_validate(values)

        # Extract only the actual function parameters
        result: dict[str, Any] = {}
        for param_name, param in self.signature.parameters.items():
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                result[param_name] = getattr(validated, self.v_args_name) or []
            elif param.kind is inspect.Parameter.VAR_KEYWORD:
                result[param_name] = getattr(validated, self.v_kwargs_name) or {}
            else:
                # Regular parameter
                result[param_name] = getattr(validated, param_name)

        return result

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Validate arguments and call the function.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            The result of calling the function with validated arguments
        """
        validated_params = self.validate_call_args(args, kwargs)

        # Rebuild a real call from signature order: positional-only and
        # positional-or-keyword values go positionally (positional-only ones
        # cannot be passed by keyword), *args is splatted after them,
        # keyword-only values go by name and **kwargs is unpacked.
        call_args: list[Any] = []
        call_kwargs: dict[str, Any] = {}
        for name, param in self.signature.parameters.items():
            value = validated_params[name]
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                call_args.extend(value)
            elif param.kind is inspect.Parameter.VAR_KEYWORD:
                call_kwargs.update(value)
            elif param.kind is inspect.Parameter.KEYWORD_ONLY:
                call_kwargs[name] = value
            else:
                call_args.append(value)

        return self.raw_function(*call_args, **call_kwargs)