Coverage for /usr/local/lib/python3.12/site-packages/prefect/utilities/collections.py: 59%

242 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Utilities for extensions of and operations on Python collections. 

3""" 

4 

5import io 1a

6import itertools 1a

7import types 1a

8from collections import OrderedDict 1a

9from collections.abc import ( 1a

10 Callable, 

11 Collection, 

12 Generator, 

13 Hashable, 

14 Iterable, 

15 Iterator, 

16 Sequence, 

17 Set, 

18) 

19from dataclasses import fields, is_dataclass, replace 1a

20from enum import Enum, auto 1a

21from typing import ( 1a

22 TYPE_CHECKING, 

23 Any, 

24 Literal, 

25 Optional, 

26 Union, 

27 cast, 

28 overload, 

29) 

30from unittest.mock import Mock 1a

31 

32import pydantic 1a

33from typing_extensions import TypeAlias, TypeVar 1a

34 

35# Quote moved to `prefect.utilities.annotations` but preserved here for compatibility 

36from prefect.utilities.annotations import BaseAnnotation as BaseAnnotation 1a

37from prefect.utilities.annotations import Quote as Quote 1a

38from prefect.utilities.annotations import quote as quote 1a

39 

40if TYPE_CHECKING: 40 ↛ 41line 40 didn't jump to line 41 because the condition on line 40 was never true1a

41 pass 

42 

43 

class AutoEnum(str, Enum):
    """
    A string-valued enum whose member values are generated from the member names.

    Deriving each value from its name means a renamed member can never be left
    with a stale value.

    Members are also `str` instances, so they are JSON-serializable without any
    extra handling.

    See https://docs.python.org/3/library/enum.html#using-automatic-values

    Example:
        ```python
        class MyEnum(AutoEnum):
            RED = AutoEnum.auto()  # equivalent to RED = 'RED'
            BLUE = AutoEnum.auto()  # equivalent to BLUE = 'BLUE'
        ```
    """

    @staticmethod
    def _generate_next_value_(name: str, *_: object, **__: object) -> str:
        # Hook used by `enum.auto()`: the member's own name becomes its value.
        return name

    @staticmethod
    def auto() -> str:
        """
        Exposes `enum.auto()` to avoid requiring a second import to use `AutoEnum`
        """
        # Inside this function body, `auto` resolves to the module-level
        # `enum.auto`, not to this static method (class scope is not enclosing).
        return auto()

    def __repr__(self) -> str:
        owner = type(self).__name__
        return f"{owner}.{self.value}"

77 

78 

# Generic type variables shared by the helpers in this module.
KT = TypeVar("KT")  # dictionary key type
VT = TypeVar("VT", infer_variance=True)  # dictionary (leaf) value type
VT1 = TypeVar("VT1", infer_variance=True)  # value type of `dct` in `deep_merge`
VT2 = TypeVar("VT2", infer_variance=True)  # value type of `merge` in `deep_merge`
R = TypeVar("R", infer_variance=True)  # type of the `default` in `get_from_dict`
# A dictionary whose values are either leaf values or further nested dictionaries
NestedDict: TypeAlias = dict[KT, Union[VT, "NestedDict[KT, VT]"]]
HashableT = TypeVar("HashableT", bound=Hashable)  # any hashable type

86 

87 

def dict_to_flatdict(dct: NestedDict[KT, VT]) -> dict[tuple[KT, ...], VT]:
    """Flatten a (possibly nested) dictionary into a single-level dictionary.

    Every key of the result is a tuple holding the chain of keys that leads to
    the corresponding value in `dct`. Empty dictionaries are kept as leaf values
    rather than being recursed into.

    Args:
        dct (dict): The dictionary to flatten

    Returns:
        A flattened dict of the same type as dct
    """

    def walk(
        node: NestedDict[KT, VT], path: tuple[KT, ...]
    ) -> Iterator[tuple[tuple[KT, ...], VT]]:
        for key, child in node.items():
            child_path = (*path, key)
            if isinstance(child, dict) and child:
                # Non-empty dict: descend, extending the key chain
                yield from walk(cast("NestedDict[KT, VT]", child), child_path)
            else:
                yield child_path, cast("VT", child)

    # Build the result with the input's own dict type (e.g. OrderedDict)
    flat_type = cast("type[dict[tuple[KT, ...], VT]]", type(dct))
    return flat_type(walk(dct, ()))

115 

116 

def flatdict_to_dict(dct: dict[tuple[KT, ...], VT]) -> NestedDict[KT, VT]:
    """Rebuild a nested dictionary from its flattened representation.

    Args:
        dct (dict): The dictionary to be nested. Each key should be a tuple of keys
            as generated by `dict_to_flatdict`

    Returns:
        A nested dict of the same type as dct
    """
    # Every level of the result is created with the same type as the input
    make_level = cast("type[NestedDict[KT, VT]]", type(dct))

    nested = make_level()
    for path, leaf in dct.items():
        *parents, final = path
        cursor = nested
        for step in parents:
            # Walk down the key chain, creating intermediate levels on demand
            try:
                cursor = cast("NestedDict[KT, VT]", cursor[step])
            except KeyError:
                child = make_level()
                cursor[step] = child
                cursor = child

        # Set the value
        cursor[final] = leaf

    return nested

149 

150 

T = TypeVar("T")  # unconstrained element type used by the iterable helpers below

152 

153 

def isiterable(obj: Any) -> bool:
    """
    Tell whether `obj` should be treated as an iterable.

    Types that are technically iterable but are normally used as single values
    are reported as non-iterable:
        - str
        - bytes
        - IO objects
    """
    # Probe with `iter()` first so that any non-TypeError it raises still
    # propagates to the caller.
    try:
        iter(obj)
    except TypeError:
        return False

    singleton_like = (str, bytes, io.IOBase)
    return not isinstance(obj, singleton_like)

169 

170 

def ensure_iterable(obj: Union[T, Iterable[T]]) -> Collection[T]:
    """
    Return `obj` unchanged if it is a `Sequence` or `Set`; otherwise wrap it in a
    single-element list.
    """
    already_collection = isinstance(obj, (Sequence, Set))
    if already_collection:
        return cast("Collection[T]", obj)
    # Anything else (including other iterables) is treated as a single item
    return [cast("T", obj)]

176 

177 

def listrepr(objs: Iterable[Any], sep: str = " ") -> str:
    """Join the `repr()` of every object in `objs` using `sep`."""
    return sep.join(map(repr, objs))

180 

181 

def extract_instances(
    objects: Iterable[Any],
    types: Union[type[T], tuple[type[T], ...]] = object,
) -> Union[list[T], dict[type[T], list[T]]]:
    """
    Collect the objects that are instances of the requested type(s).

    Args:
        objects: An iterable of objects
        types: A type or tuple of types to extract, defaults to all objects

    Returns:
        If exactly one type is given (bare, or as a one-element tuple): a list of
            the instances of that type, empty when nothing matches
        If several types are given: a mapping of type to a list of instances;
            types without any matching instance are absent from the mapping
    """
    types_collection = ensure_iterable(types)

    # Group instances under the *requested* type rather than `type(o)`, so a
    # subclass instance is reported under the type the caller asked for.
    ret: dict[type[T], list[Any]] = {}

    for o in objects:
        for type_ in types_collection:
            if isinstance(o, type_):
                ret.setdefault(type_, []).append(o)

    if len(types_collection) == 1:
        [type_] = types_collection
        # `.get` instead of indexing: when no object matched there is no entry
        # for the type, and the caller must get [] rather than a KeyError.
        return ret.get(type_, [])

    return ret

213 

214 

def batched_iterable(
    iterable: Iterable[T], size: int
) -> Generator[tuple[T, ...], None, None]:
    """
    Split an iterable into consecutive tuples of at most `size` items.

    Args:
        iterable (Iterable): An iterable
        size (int): The batch size to return

    Yields:
        tuple: A batch of the iterable
    """
    source = iter(iterable)
    # An empty slice means the source is exhausted (or `size` is 0): stop.
    while chunk := tuple(itertools.islice(source, size)):
        yield chunk

234 

235 

class StopVisiting(BaseException):
    """
    Signal raised from a visit function to halt recursion in `visit_collection`.

    The expression being visited when it is raised is returned unmodified, and
    nothing beneath it on that path is visited.
    """

243 

244 

# Overloads: the return type depends on `return_data`, and the expected arity of
# `visit_fn` depends on whether a `context` dictionary is supplied.
@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any, dict[str, VT]], Any],
    *,
    return_data: Literal[True] = ...,
    max_depth: int = ...,
    context: dict[str, VT] = ...,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Any: ...


@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any], Any],
    *,
    return_data: Literal[True] = ...,
    max_depth: int = ...,
    context: None = None,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Any: ...


@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any, dict[str, VT]], Any],
    *,
    return_data: bool = ...,
    max_depth: int = ...,
    context: dict[str, VT] = ...,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Optional[Any]: ...


@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any], Any],
    *,
    return_data: bool = ...,
    max_depth: int = ...,
    context: None = None,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Optional[Any]: ...


@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any, dict[str, VT]], Any],
    *,
    return_data: Literal[False] = False,
    max_depth: int = ...,
    context: dict[str, VT] = ...,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> None: ...


def visit_collection(
    expr: Any,
    visit_fn: Union[Callable[[Any, dict[str, VT]], Any], Callable[[Any], Any]],
    *,
    return_data: bool = False,
    max_depth: int = -1,
    context: Optional[dict[str, VT]] = None,
    remove_annotations: bool = False,
    _seen: Optional[dict[int, Any]] = None,
) -> Optional[Any]:
    """
    Visits and potentially transforms every element of an arbitrary Python collection.

    `visit_fn` is called on every expression, including collections themselves,
    before the children of a collection are visited recursively. The return value
    of `visit_fn` can be used to alter the element if `return_data` is set to `True`.

    Note:
    - When `return_data` is `True`, a copy of each collection is created only if
      `visit_fn` modifies an element within that collection. This approach minimizes
      performance penalties by avoiding unnecessary copying.
    - When `return_data` is `False`, no copies are created, and only side effects from
      `visit_fn` are applied. This mode is faster and should be used when no transformation
      of the collection is required, because it never has to copy any data.

    Supported types:
    - List
    - Tuple
    - Set
    - Dict (note: keys are also visited recursively)
    - Dataclass
    - Pydantic model
    - Prefect annotations

    Note that visit_collection will not consume generators or async generators, as it would prevent
    the caller from iterating over them. `Mock` objects are not recursed into either.

    Args:
        expr (Any): A Python object or expression.
        visit_fn (Callable[[Any, Optional[dict]], Any] or Callable[[Any], Any]): A function
            that will be applied to every element of `expr`. The function can
            accept one or two arguments. If two arguments are accepted, the second argument
            will be the context dictionary.
        return_data (bool): If `True`, a copy of `expr` containing data modified by `visit_fn`
            will be returned. This is slower than `return_data=False` (the default).
        max_depth (int): Controls the depth of recursive visitation. If set to zero, no
            recursion will occur. If set to a positive integer `N`, visitation will only
            descend to `N` layers deep. If set to any negative integer, no limit will be
            enforced and recursion will continue until terminal items are reached. By
            default, recursion is unlimited.
        context (Optional[dict]): An optional dictionary. If passed, the context will be sent
            to each call to the `visit_fn`. The context can be mutated by each visitor and
            will be available for later visits to expressions at the given depth. Values
            will not be available "up" a level from a given expression.
            The context will be automatically populated with an 'annotation' key when
            visiting collections within a `BaseAnnotation` type. This requires the caller to
            pass `context={}` and will not be activated by default.
        remove_annotations (bool): If set, annotations will be replaced by their contents. By
            default, annotations are preserved but their contents are visited.
        _seen (Optional[dict[int, Any]]): Internal cache keyed by `id()` of the objects that
            have already been visited, mapping to the object or its transformed result.
            This prevents infinite recursion when visiting recursive data structures.

    Returns:
        Any: The modified collection if `return_data` is `True`, otherwise `None`.
    """

    if _seen is None:
        _seen = {}

    # The callback arity depends on whether a context was supplied, so build the
    # two helpers (`visit_nested`, `visit_expression`) accordingly.
    if context is not None:
        _callback = cast(Callable[[Any, dict[str, VT]], Any], visit_fn)

        def visit_nested(expr: Any) -> Optional[Any]:
            return visit_collection(
                expr,
                _callback,
                return_data=return_data,
                remove_annotations=remove_annotations,
                max_depth=max_depth - 1,
                # Copy the context on nested calls so it does not "propagate up"
                context=context.copy(),
                _seen=_seen,
            )

        def visit_expression(expr: Any) -> Any:
            return _callback(expr, context)
    else:
        _callback = cast(Callable[[Any], Any], visit_fn)

        def visit_nested(expr: Any) -> Optional[Any]:
            # Utility for a recursive call, preserving options and updating the depth.
            return visit_collection(
                expr,
                _callback,
                return_data=return_data,
                remove_annotations=remove_annotations,
                max_depth=max_depth - 1,
                _seen=_seen,
            )

        def visit_expression(expr: Any) -> Any:
            return _callback(expr)

    # --- 1. Visit every expression
    try:
        result = visit_expression(expr)
    except StopVisiting:
        # The visitor asked to stop: keep the expression and do not descend
        max_depth = 0
        result = expr

    if return_data:
        # Only mutate the root expression if the user indicated we're returning data,
        # otherwise the function could return null and we have no collection to check
        expr = result

    # --- 2. Visit every child of the expression recursively

    # If we have reached the maximum depth or we have already visited this object,
    # return the result if we are returning data, otherwise return None
    obj_id = id(expr)
    if max_depth == 0:
        return result if return_data else None
    elif obj_id in _seen:
        # Return the cached transformed result
        return _seen[obj_id] if return_data else None

    # Mark this object as being processed to handle circular references
    # We'll update with the actual result later
    _seen[obj_id] = expr

    # Then visit every item in the expression if it is a collection

    # presume that the result is the original expression.
    # in each of the following cases, we will update the result if we need to.
    result = expr

    # --- Generators

    if isinstance(expr, (types.GeneratorType, types.AsyncGeneratorType)):
        # Do not attempt to iterate over generators, as it will exhaust them
        pass

    # --- Mocks

    elif isinstance(expr, Mock):
        # Do not attempt to recurse into mock objects
        pass

    # --- Annotations (unmapped, quote, etc.)

    elif isinstance(expr, BaseAnnotation):
        annotated = cast(BaseAnnotation[Any], expr)
        if context is not None:
            context["annotation"] = cast(VT, annotated)
        unwrapped = annotated.unwrap()
        value = visit_nested(unwrapped)

        if return_data:
            # if we are removing annotations, return the value
            if remove_annotations:
                result = value
            # if the value was modified, rewrap it
            elif value is not unwrapped:
                result = annotated.rewrap(value)
            # otherwise return the expr

    # --- Sequences

    elif isinstance(expr, (list, tuple, set)):
        seq = cast(Union[list[Any], tuple[Any], set[Any]], expr)
        items = [visit_nested(o) for o in seq]
        if return_data:
            # Identity comparison: copy only if some element was replaced
            modified = any(item is not orig for item, orig in zip(items, seq))
            if modified:
                result = type(seq)(items)

    # --- Dictionaries

    elif isinstance(expr, (dict, OrderedDict)):
        mapping = cast(dict[Any, Any], expr)
        items = [(visit_nested(k), visit_nested(v)) for k, v in mapping.items()]
        if return_data:
            modified = any(
                k1 is not k2 or v1 is not v2
                for (k1, v1), (k2, v2) in zip(items, mapping.items())
            )
            if modified:
                result = type(mapping)(items)

    # --- Dataclasses

    elif is_dataclass(expr) and not isinstance(expr, type):
        expr_fields = fields(expr)
        values = [visit_nested(getattr(expr, f.name)) for f in expr_fields]
        if return_data:
            modified = any(
                getattr(expr, f.name) is not v for f, v in zip(expr_fields, values)
            )
            if modified:
                result = replace(
                    expr, **{f.name: v for f, v in zip(expr_fields, values)}
                )

    # --- Pydantic models

    elif isinstance(expr, pydantic.BaseModel):
        # when extra=allow, fields not in model_fields may be in model_fields_set
        original_data = dict(expr)
        updated_data = {
            field: visit_nested(value) for field, value in original_data.items()
        }

        if return_data:
            modified = any(
                original_data[field] is not updated_data[field]
                for field in updated_data
            )
            if modified:
                # Use construct to avoid validation and handle immutability
                model_instance = expr.model_construct(
                    _fields_set=expr.model_fields_set, **updated_data
                )
                # Carry the private attributes over to the rebuilt instance
                for private_attr in expr.__private_attributes__:
                    setattr(model_instance, private_attr, getattr(expr, private_attr))
                result = model_instance

    # Update the cache with the final transformed result
    if return_data:
        _seen[obj_id] = result

    # With `return_data=False` the function falls through and returns None
    if return_data:
        return result

542 

543 

@overload
def remove_nested_keys(
    keys_to_remove: list[HashableT], obj: NestedDict[HashableT, VT]
) -> NestedDict[HashableT, VT]: ...


@overload
def remove_nested_keys(keys_to_remove: list[HashableT], obj: Any) -> Any: ...


def remove_nested_keys(
    keys_to_remove: list[HashableT], obj: Union[NestedDict[HashableT, VT], Any]
) -> Union[NestedDict[HashableT, VT], Any]:
    """
    Recursively copy a dictionary, leaving out every key listed in
    `keys_to_remove`. Non-dictionary objects are returned unchanged.

    Only dictionary values are recursed into; dictionaries held inside other
    containers (such as lists) are returned as-is.

    Args:
        keys_to_remove: A list of keys to remove from obj
        obj: The object to remove keys from.

    Returns:
        `obj` without keys matching an entry in `keys_to_remove` if `obj` is a
        dictionary. `obj` if `obj` is not a dictionary.
    """
    if isinstance(obj, dict):
        pruned = {}
        for key, value in obj.items():
            if key in keys_to_remove:
                continue
            pruned[key] = remove_nested_keys(keys_to_remove, value)
        return pruned
    return obj

576 

577 

@overload
def distinct(
    iterable: Iterable[HashableT], key: None = None
) -> Iterator[HashableT]: ...


@overload
def distinct(iterable: Iterable[T], key: Callable[[T], Hashable]) -> Iterator[T]: ...


def distinct(
    iterable: Iterable[Union[T, HashableT]],
    key: Optional[Callable[[T], Hashable]] = None,
) -> Iterator[Union[T, HashableT]]:
    """
    Lazily yield the items of `iterable`, skipping duplicates.

    Two items are duplicates when they produce equal keys; the first item seen for
    each key is the one that is yielded, and the input order is preserved.

    Args:
        iterable: The items to de-duplicate.
        key: Optional function computing the hashable value used to compare items.
            When omitted, the items themselves are used and must be hashable.
            It is called exactly once per item.

    Yields:
        Each item whose key has not been seen before.
    """
    key_fn = cast("Optional[Callable[[Any], Hashable]]", key)

    seen: set[Hashable] = set()
    for item in iterable:
        # Compute the key once per item and reuse it for both the membership
        # test and the insertion.
        item_key = item if key_fn is None else key_fn(item)
        if item_key in seen:
            continue
        seen.add(item_key)
        yield item

604 

605 

@overload
def get_from_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], default: None = None
) -> Optional[VT]: ...


@overload
def get_from_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], default: R
) -> Union[VT, R]: ...


def get_from_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], default: Optional[R] = None
) -> Union[VT, R, None]:
    """
    Fetch a value from a nested dictionary or list using a sequence of keys.

    This function allows to fetch a value from a deeply nested structure
    of dictionaries and lists using either a dot-separated string or a list
    of keys. If a requested key does not exist, the function returns the
    provided default value.

    Keys that look like integers are tried as integers first (so they can index
    lists or match integer dictionary keys). If a mapping has no such integer
    key, the key is retried in its original form, so entries stored under
    numeric strings (e.g. `{"1": ...}`) are reachable as well.

    Args:
        dct: The nested dictionary or list from which to fetch the value.
        keys: The sequence of keys to use for access. Can be a
            dot-separated string or a list of keys. List indices can be included
            in the sequence as either integer keys or as string indices in square
            brackets.
        default: The default value to return if the requested key path does not
            exist. Defaults to None.

    Returns:
        The fetched value if the key exists, or the default value if it does not.

    Examples:

    ```python
    get_from_dict({'a': {'b': {'c': [1, 2, 3, 4]}}}, 'a.b.c[1]')  # 2
    get_from_dict({'a': {'b': [0, {'c': [1, 2]}]}}, ['a', 'b', 1, 'c', 1])  # 2
    get_from_dict({'a': {'b': [0, {'c': [1, 2]}]}}, 'a.b.1.c.2', 'default')  # 'default'
    get_from_dict({'a': {'1': 'x'}}, 'a.1')  # 'x'
    ```
    """
    if isinstance(keys, str):
        # "a.b[1]" -> ["a", "b", "1"]
        keys = keys.replace("[", ".").replace("]", "").split(".")
    value: Any = dct
    try:
        for key in keys:
            try:
                # Try to cast to int to handle list indices
                lookup_key: Any = int(key)
            except ValueError:
                # If it's not an int, use the key as-is for dict lookup
                lookup_key = key
            try:
                value = value[lookup_key]
            except KeyError:
                # Only mappings raise KeyError: the entry may be stored under the
                # key's original form (e.g. the string "1" rather than the int 1).
                # A second KeyError propagates to the outer handler below.
                value = value[key]
        return cast("VT", value)
    except (TypeError, KeyError, IndexError):
        return default

665 

666 

def set_in_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], value: VT
) -> None:
    """
    Sets a value in a nested dictionary using a sequence of keys.

    The dictionary is modified in place. Keys can be given either as a
    dot-separated string (square brackets are treated as separators too) or as a
    list of keys. Intermediate keys that do not exist yet are created as empty
    dictionaries.

    Args:
        dct: The dictionary to set the value in.
        keys: The sequence of keys to use for access. Can be a
            dot-separated string or a list of keys.
        value: The value to set in the dictionary.

    Returns:
        None. `dct` is updated in place.

    Raises:
        TypeError: If an intermediate key of the path exists and its value is
            not a dictionary.
    """
    if isinstance(keys, str):
        keys = keys.replace("[", ".").replace("]", "").split(".")

    cursor: Any = dct
    for k in keys[:-1]:
        # `setdefault` creates the missing level and leaves existing values alone
        child = cursor.setdefault(k, {})
        if not isinstance(child, dict):
            raise TypeError(f"Key path exists and contains a non-dict value: {keys}")
        cursor = child
    cursor[keys[-1]] = value

699 

700 

def deep_merge(
    dct: NestedDict[str, VT1], merge: NestedDict[str, VT2]
) -> NestedDict[str, Union[VT1, VT2]]:
    """
    Recursively merges `merge` into `dct`.

    Where both sides hold a dictionary under the same key, the two are merged
    recursively; in every other case the value from `merge` wins. Neither input
    is modified.

    Args:
        dct: The dictionary to merge into.
        merge: The dictionary to merge from.

    Returns:
        A new dictionary with the merged contents.
    """
    merged: dict[str, Any] = dct.copy()  # shallow copy: start from `dct`'s entries
    for key, incoming in merge.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            # Dictionaries on both sides: combine them level by level
            merged[key] = deep_merge(existing, incoming)
        else:
            # Anything else is simply overwritten by the incoming value
            merged[key] = incoming
    return merged

726 

727 

def deep_merge_dicts(*dicts: NestedDict[str, Any]) -> NestedDict[str, Any]:
    """
    Recursively merges multiple dictionaries.

    The dictionaries are folded together from left to right with `deep_merge`,
    starting from an empty dictionary.

    Args:
        dicts: The dictionaries to merge.

    Returns:
        A new dictionary with the merged contents.
    """
    merged: NestedDict[str, Any] = {}
    for layer in dicts:
        merged = deep_merge(merged, layer)
    return merged