Coverage for /usr/local/lib/python3.12/site-packages/prefect/transactions.py: 18%
361 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 11:21 +0000
1from __future__ import annotations 1a
3import abc 1a
4import copy 1a
5import inspect 1a
6import logging 1a
7import uuid 1a
8from contextlib import asynccontextmanager, contextmanager 1a
9from contextvars import ContextVar, Token 1a
10from functools import partial 1a
11from typing import ( 1a
12 Any,
13 AsyncGenerator,
14 Callable,
15 ClassVar,
16 Generator,
17 NoReturn,
18 Optional,
19 Type,
20 Union,
21)
23import anyio.to_thread 1a
24from pydantic import Field, PrivateAttr 1a
25from typing_extensions import Self 1a
27from prefect.context import ContextModel 1a
28from prefect.exceptions import ( 1a
29 ConfigurationError,
30 MissingContextError,
31 SerializationError,
32)
33from prefect.filesystems import NullFileSystem 1a
34from prefect.logging.loggers import LoggingAdapter, get_logger, get_run_logger 1a
35from prefect.results import ( 1a
36 ResultRecord,
37 ResultStore,
38 get_result_store,
39)
40from prefect.utilities._engine import get_hook_name 1a
41from prefect.utilities.annotations import NotSet 1a
42from prefect.utilities.asyncutils import run_coro_as_sync 1a
43from prefect.utilities.collections import AutoEnum 1a
# Module-level logger for this module, registered under the name "transactions".
logger: logging.Logger = get_logger("transactions")
class IsolationLevel(AutoEnum):
    """
    Isolation levels a transaction can request.

    Transactions default to `READ_COMMITTED`. With `SERIALIZABLE`, a transaction
    that has both a store and a key acquires a lock on that key when it begins.
    """

    # Default level: no key lock is taken.
    READ_COMMITTED = AutoEnum.auto()
    # Key lock is acquired on begin (store and key required) and released when
    # the transaction commits or rolls back.
    SERIALIZABLE = AutoEnum.auto()
class CommitMode(AutoEnum):
    """
    Controls what a transaction does with its staged work when its context exits.

    When no mode is given, a transaction inherits its parent's mode, or uses
    `LAZY` when it has no parent.
    """

    # Commit as soon as the transaction's own context exits.
    EAGER = AutoEnum.auto()
    # Defer to the parent transaction if there is one; otherwise commit on exit.
    LAZY = AutoEnum.auto()
    # Never commit on exit; a top-level transaction is rolled back instead.
    OFF = AutoEnum.auto()
class TransactionState(AutoEnum):
    """
    Lifecycle states of a transaction.
    """

    # Initial state of a newly constructed transaction.
    PENDING = AutoEnum.auto()
    # Set when the transaction is prepared on context entry.
    ACTIVE = AutoEnum.auto()
    # Set once a value has been staged for commit.
    STAGED = AutoEnum.auto()
    # Set after a successful commit, or on begin when a record for the key
    # already exists and overwriting is disabled.
    COMMITTED = AutoEnum.auto()
    # Set once rollback hooks have run successfully.
    ROLLED_BACK = AutoEnum.auto()
class BaseTransaction(ContextModel, abc.ABC):
    """
    A base model for transaction state.

    Holds the configuration shared by synchronous and asynchronous transactions
    (result store, key, commit mode, isolation level, commit/rollback hooks), a
    scratch space of named values that nested transactions can read, and the
    transaction's lifecycle state. The currently active transaction is tracked
    in the `__var__` context variable.
    """

    store: Optional[ResultStore] = None
    key: Optional[str] = None
    children: list[Self] = Field(default_factory=list)
    commit_mode: Optional[CommitMode] = None
    isolation_level: Optional[IsolationLevel] = IsolationLevel.READ_COMMITTED
    state: TransactionState = TransactionState.PENDING
    on_commit_hooks: list[Callable[[Self], None]] = Field(default_factory=list)
    on_rollback_hooks: list[Callable[[Self], None]] = Field(default_factory=list)
    overwrite: bool = False
    logger: Union[logging.Logger, LoggingAdapter] = Field(
        default_factory=partial(get_logger, "transactions")
    )
    write_on_commit: bool = True
    # Named values saved with `set` on this transaction instance.
    _stored_values: dict[str, Any] = PrivateAttr(default_factory=dict)
    # The value most recently passed to `stage`.
    _staged_value: ResultRecord[Any] | Any = None
    # Unique identity passed as `holder` to the store's lock/read/write calls.
    _holder: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
    __var__: ClassVar[ContextVar[Self]] = ContextVar("transaction")

    def set(self, name: str, value: Any) -> None:
        """
        Set a stored value in the transaction.

        Args:
            name: The name of the value to set
            value: The value to set

        Examples:
            Set a value for use later in the transaction:
            ```python
            with transaction() as txn:
                txn.set("key", "value")
                ...
                assert txn.get("key") == "value"
            ```
        """
        self._stored_values[name] = value

    def get(self, name: str, default: Any = NotSet) -> Any:
        """
        Get a stored value from the transaction.

        Child transactions will return values from their parents unless a value with
        the same name is set in the child transaction.

        Direct changes to returned values will not update the stored value. To update the
        stored value, use the `set` method.

        Args:
            name: The name of the value to get
            default: The default value to return if the value is not found

        Returns:
            The value from the transaction

        Raises:
            ValueError: if the value is not found in this transaction or any
                parent and no default was provided.

        Examples:
            Get a value from the transaction:
            ```python
            with transaction() as txn:
                txn.set("key", "value")
                ...
                assert txn.get("key") == "value"
            ```

            Get a value from a parent transaction:
            ```python
            with transaction() as parent:
                parent.set("key", "parent_value")
                with transaction() as child:
                    assert child.get("key") == "parent_value"
            ```

            Update a stored value:
            ```python
            with transaction() as txn:
                txn.set("key", [1, 2, 3])
                value = txn.get("key")
                value.append(4)
                # Stored value is not updated until `.set` is called
                assert value == [1, 2, 3, 4]
                assert txn.get("key") == [1, 2, 3]

                txn.set("key", value)
                assert txn.get("key") == [1, 2, 3, 4]
            ```
        """
        # deepcopy to prevent mutation of stored values
        value = copy.deepcopy(self._stored_values.get(name, NotSet))
        if value is NotSet:
            # if there's a parent transaction, get the value from the parent
            parent = self.get_parent()
            if parent is not None:
                value = parent.get(name, default)
            # if there's no parent transaction, use the default
            elif default is not NotSet:
                value = default
            else:
                raise ValueError(f"Could not retrieve value for unknown key: {name}")
        return value

    def is_committed(self) -> bool:
        """Return whether the transaction is in the COMMITTED state."""
        return self.state == TransactionState.COMMITTED

    def is_rolled_back(self) -> bool:
        """Return whether the transaction is in the ROLLED_BACK state."""
        return self.state == TransactionState.ROLLED_BACK

    def is_staged(self) -> bool:
        """Return whether the transaction is in the STAGED state."""
        return self.state == TransactionState.STAGED

    def is_pending(self) -> bool:
        """Return whether the transaction is in the PENDING state."""
        return self.state == TransactionState.PENDING

    def is_active(self) -> bool:
        """Return whether the transaction is in the ACTIVE state."""
        return self.state == TransactionState.ACTIVE

    def prepare_transaction(self) -> None:
        """
        Helper method to prepare transaction state and validate configuration.

        Fills in `commit_mode` and `isolation_level` when they are `None`,
        inheriting from the currently active (parent) transaction if there is
        one. It then checks that the store supports the isolation level and
        marks the transaction ACTIVE.

        Raises:
            RuntimeError: if this transaction has already been entered.
            ConfigurationError: if a store and key are set and the store does
                not support the chosen isolation level.
        """
        if self._token is not None:
            raise RuntimeError(
                "Context already entered. Context enter calls cannot be nested."
            )
        parent = get_transaction()
        # set default commit behavior; either inherit from parent or default to lazy
        if self.commit_mode is None:
            self.commit_mode = parent.commit_mode if parent else CommitMode.LAZY
        # set default isolation level; either inherit from parent or set a default of read committed
        if self.isolation_level is None:
            self.isolation_level = (
                parent.isolation_level if parent else IsolationLevel.READ_COMMITTED
            )

        assert self.isolation_level is not None, "Isolation level was not set correctly"
        if (
            self.store
            and self.key
            and not self.store.supports_isolation_level(self.isolation_level)
        ):
            raise ConfigurationError(
                f"Isolation level {self.isolation_level.name} is not supported by provided "
                "configuration. Please ensure you've provided a lock file directory or lock "
                "manager when using the SERIALIZABLE isolation level."
            )

        # this needs to go before begin, which could set the state to committed
        self.state = TransactionState.ACTIVE

    def add_child(self, transaction: Self) -> None:
        """Register a nested transaction so this transaction commits/rolls it back."""
        self.children.append(transaction)

    def get_parent(self) -> Self | None:
        """
        Return the parent transaction, or `None` if there is none.

        While this transaction is entered, the parent is the context-variable
        value that was replaced when it was entered. Once the token has been
        cleared by a reset, the currently active transaction is returned instead.
        """
        parent = None
        if self._token:
            prev_var = self._token.old_value
            # `Token.MISSING` is a sentinel: compare by identity
            if prev_var is not Token.MISSING:
                parent = prev_var
        else:
            # `_token` has been reset so we need to get the active transaction from the context var
            parent = self.get_active()
        return parent

    def stage(
        self,
        value: Any,
        on_rollback_hooks: Optional[list[Callable[..., Any]]] = None,
        on_commit_hooks: Optional[list[Callable[..., Any]]] = None,
    ) -> None:
        """
        Stage a value to be committed later.

        The value replaces any previously staged value, the given hooks are
        appended to the transaction's hook lists, and the state becomes STAGED.
        Nothing happens if the transaction is already COMMITTED.

        Args:
            value: The value to stage.
            on_rollback_hooks: Extra hooks to run if the transaction rolls back.
            on_commit_hooks: Extra hooks to run when the transaction commits.
        """
        on_commit_hooks = on_commit_hooks or []
        on_rollback_hooks = on_rollback_hooks or []

        if self.state != TransactionState.COMMITTED:
            self._staged_value = value
            self.on_rollback_hooks += on_rollback_hooks
            self.on_commit_hooks += on_commit_hooks
            self.state = TransactionState.STAGED

    @classmethod
    def get_active(cls: Type[Self]) -> Optional[Self]:
        """Return the transaction active in the current context, or `None`."""
        return cls.__var__.get(None)

    def __eq__(self, other: Any) -> bool:
        """Transactions are equal when all of their model fields are equal."""
        if not isinstance(other, BaseTransaction):
            return False
        return dict(self) == dict(other)
class Transaction(BaseTransaction):
    """
    A model representing the state of a transaction.

    Used as a synchronous context manager. Entering the context begins the
    transaction and makes it the active one. Exiting it commits, rolls back, or
    hands off to a parent transaction, depending on `commit_mode` and on whether
    an exception was raised.
    """

    def __enter__(self) -> Self:
        """
        Prepare and begin the transaction, then make it the active transaction.

        Raises:
            RuntimeError: if this transaction has already been entered.
            ConfigurationError: if the store does not support the isolation level.
        """
        self.prepare_transaction()
        self.begin()
        self._token = self.__var__.set(self)
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """
        Finalize the transaction when its context exits.

        If the context raised, the transaction is rolled back and reset, and the
        exception is re-raised. Otherwise:
        - an EAGER transaction commits itself;
        - a transaction with a parent is reset, which registers it as the
          parent's child;
        - a top-level OFF transaction is rolled back;
        - a top-level LAZY transaction commits.
        """
        exc_type, exc_val, _ = exc_info
        if not self._token:
            raise RuntimeError(
                "Asymmetric use of context. Context exit called without an enter."
            )
        if exc_type:
            self.rollback()
            self.reset()
            raise exc_val

        if self.commit_mode == CommitMode.EAGER:
            self.commit()

        # if parent, let them take responsibility
        if self.get_parent():
            self.reset()
            return

        if self.commit_mode == CommitMode.OFF:
            # if no one took responsibility to commit, rolling back
            # note that rollback returns if already committed
            self.rollback()
        elif self.commit_mode == CommitMode.LAZY:
            # no one left to take responsibility for committing
            self.commit()

        self.reset()

    def begin(self) -> None:
        """
        Start the transaction.

        Acquires a lock on `key` when a store and key are set and the isolation
        level is SERIALIZABLE. If `overwrite` is False and a record for `key`
        already exists in the store, the transaction is marked COMMITTED.
        """
        if (
            self.store
            and self.key
            and self.isolation_level == IsolationLevel.SERIALIZABLE
        ):
            self.logger.debug(f"Acquiring lock for transaction {self.key!r}")
            self.store.acquire_lock(self.key, holder=self._holder)
        if (
            not self.overwrite
            and self.store
            and self.key
            and self.store.exists(key=self.key)
        ):
            self.state = TransactionState.COMMITTED

    def read(self) -> ResultRecord[Any] | None:
        """Read the record stored under `key`, or return `None` without a store and key."""
        if self.store and self.key:
            return self.store.read(key=self.key, holder=self._holder)
        return None

    def reset(self) -> None:
        """
        Deactivate this transaction.

        Registers it as a child of its parent (if any), restores the previously
        active transaction in the context variable, and, if this transaction was
        rolled back, rolls back the parent too.
        """
        parent = self.get_parent()

        if parent:
            # parent takes responsibility
            parent.add_child(self)

        if self._token:
            self.__var__.reset(self._token)
            self._token = None

        # do this below reset so that get_transaction() returns the relevant txn
        if parent and self.state == TransactionState.ROLLED_BACK:
            parent.rollback()

    def _release_key_lock(self) -> None:
        """
        Release the lock on `key` if this transaction is configured to hold one.

        A lock is held when a store and key are set and the isolation level is
        SERIALIZABLE.
        """
        if (
            self.store
            and self.key
            and self.isolation_level == IsolationLevel.SERIALIZABLE
        ):
            self.logger.debug(f"Releasing lock for transaction {self.key!r}")
            self.store.release_lock(self.key, holder=self._holder)

    def commit(self) -> bool:
        """
        Commit the transaction.

        Commits child transactions, runs commit hooks, and writes the staged
        value to the store (when a store and key are set and `write_on_commit`
        is True). The transaction is then marked COMMITTED and its key lock, if
        any, is released.

        Returns:
            True if this call committed the transaction. False if it was already
            committed or rolled back, or if committing failed. A failed commit
            is logged and the transaction is rolled back.
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            self._release_key_lock()
            return False

        try:
            for child in self.children:
                if inspect.iscoroutinefunction(child.commit):
                    run_coro_as_sync(child.commit())
                else:
                    child.commit()

            for hook in self.on_commit_hooks:
                self.run_hook(hook, "commit")

            if self.store and self.key and self.write_on_commit:
                if isinstance(self._staged_value, ResultRecord):
                    self.store.persist_result_record(
                        result_record=self._staged_value, holder=self._holder
                    )
                else:
                    self.store.write(
                        key=self.key, obj=self._staged_value, holder=self._holder
                    )

            self.state = TransactionState.COMMITTED
            # still inside `try`: a failing release is handled like a failed commit
            self._release_key_lock()
            return True
        except SerializationError as exc:
            if self.logger:
                self.logger.warning(
                    f"Encountered an error while serializing result for transaction {self.key!r}: {exc}"
                    " Code execution will continue, but the transaction will not be committed.",
                )
            self.rollback()
            return False
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while committing transaction {self.key!r}",
                    exc_info=True,
                )
            self.rollback()
            return False

    def run_hook(self, hook: Callable[..., Any], hook_type: str) -> None:
        """
        Run a single commit or rollback hook, passing this transaction to it.

        Coroutine-function hooks are driven to completion with
        `run_coro_as_sync`. Start, success, and failure are logged unless the
        hook sets `log_on_run = False`.

        Args:
            hook: The hook callable.
            hook_type: "commit" or "rollback"; used in log messages.

        Raises:
            Exception: any exception raised by the hook is re-raised.
        """
        hook_name = get_hook_name(hook)
        # Undocumented way to disable logging for a hook. Subject to change.
        should_log = getattr(hook, "log_on_run", True)

        if should_log:
            self.logger.info(f"Running {hook_type} hook {hook_name!r}")

        try:
            if inspect.iscoroutinefunction(hook):
                run_coro_as_sync(hook(self))
            else:
                hook(self)
        except Exception:
            if should_log:
                self.logger.error(
                    f"An error was encountered while running {hook_type} hook {hook_name!r}",
                )
            raise
        else:
            if should_log:
                self.logger.info(
                    f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
                )

    def rollback(self) -> bool:
        """
        Roll back the transaction.

        Runs rollback hooks in reverse registration order, marks the transaction
        ROLLED_BACK, then rolls back child transactions in reverse order. If a
        hook raises, the state is left unchanged. Once a rollback is attempted,
        the key lock (if any) is always released.

        Returns:
            True if this call rolled the transaction back. False if it was
            already committed or rolled back, or if a hook or child raised (the
            error is logged).
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            return False

        try:
            for hook in reversed(self.on_rollback_hooks):
                self.run_hook(hook, "rollback")

            self.state = TransactionState.ROLLED_BACK

            for child in reversed(self.children):
                if inspect.iscoroutinefunction(child.rollback):
                    run_coro_as_sync(child.rollback())
                else:
                    child.rollback()

            return True
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while rolling back transaction {self.key!r}",
                    exc_info=True,
                )
            return False
        finally:
            self._release_key_lock()
class AsyncTransaction(BaseTransaction):
    """
    A model representing the state of an asynchronous transaction.

    Used as an asynchronous context manager (`async with`). Entering the context
    begins the transaction and makes it the active one. Exiting it commits,
    rolls back, or hands off to a parent transaction, depending on
    `commit_mode` and on whether an exception was raised. The synchronous
    `with` statement is not supported.
    """

    async def begin(self) -> None:
        """
        Start the transaction.

        Acquires a lock on `key` when a store and key are set and the isolation
        level is SERIALIZABLE. If `overwrite` is False and a record for `key`
        already exists in the store, the transaction is marked COMMITTED.
        """
        if (
            self.store
            and self.key
            and self.isolation_level == IsolationLevel.SERIALIZABLE
        ):
            self.logger.debug(f"Acquiring lock for transaction {self.key!r}")
            await self.store.aacquire_lock(self.key, holder=self._holder)
        if (
            not self.overwrite
            and self.store
            and self.key
            and await self.store.aexists(key=self.key)
        ):
            self.state = TransactionState.COMMITTED

    async def read(self) -> ResultRecord[Any] | None:
        """Read the record stored under `key`, or return `None` without a store and key."""
        if self.store and self.key:
            return await self.store.aread(key=self.key, holder=self._holder)
        return None

    async def reset(self) -> None:
        """
        Deactivate this transaction.

        Registers it as a child of its parent (if any), restores the previously
        active transaction in the context variable, and, if this transaction was
        rolled back, rolls back the parent too. The parent's rollback is awaited
        when it returns an awaitable.
        """
        parent = self.get_parent()

        if parent:
            # parent takes responsibility
            parent.add_child(self)

        if self._token:
            self.__var__.reset(self._token)
            self._token = None

        # do this below reset so that get_transaction() returns the relevant txn
        if parent and self.state == TransactionState.ROLLED_BACK:
            # the parent may be either a sync or an async transaction
            maybe_coro = parent.rollback()
            if inspect.isawaitable(maybe_coro):
                await maybe_coro

    def _release_key_lock(self) -> None:
        """
        Release the lock on `key` if this transaction is configured to hold one.

        A lock is held when a store and key are set and the isolation level is
        SERIALIZABLE.
        """
        if (
            self.store
            and self.key
            and self.isolation_level == IsolationLevel.SERIALIZABLE
        ):
            self.logger.debug(f"Releasing lock for transaction {self.key!r}")
            self.store.release_lock(self.key, holder=self._holder)

    async def commit(self) -> bool:
        """
        Commit the transaction.

        Commits child transactions (awaiting async ones), runs commit hooks, and
        writes the staged value to the store (when a store and key are set and
        `write_on_commit` is True). The transaction is then marked COMMITTED and
        its key lock, if any, is released.

        Returns:
            True if this call committed the transaction. False if it was already
            committed or rolled back, or if committing failed. A failed commit
            is logged and the transaction is rolled back.
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            self._release_key_lock()
            return False

        try:
            for child in self.children:
                if isinstance(child, AsyncTransaction):
                    await child.commit()
                else:
                    child.commit()

            for hook in self.on_commit_hooks:
                await self.run_hook(hook, "commit")

            if self.store and self.key and self.write_on_commit:
                if isinstance(self._staged_value, ResultRecord):
                    await self.store.apersist_result_record(
                        result_record=self._staged_value, holder=self._holder
                    )
                else:
                    await self.store.awrite(
                        key=self.key, obj=self._staged_value, holder=self._holder
                    )

            self.state = TransactionState.COMMITTED
            # still inside `try`: a failing release is handled like a failed commit
            self._release_key_lock()
            return True
        except SerializationError as exc:
            if self.logger:
                self.logger.warning(
                    f"Encountered an error while serializing result for transaction {self.key!r}: {exc}"
                    " Code execution will continue, but the transaction will not be committed.",
                )
            await self.rollback()
            return False
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while committing transaction {self.key!r}",
                    exc_info=True,
                )
            await self.rollback()
            return False

    async def run_hook(self, hook: Callable[..., Any], hook_type: str) -> None:
        """
        Run a single commit or rollback hook, passing this transaction to it.

        Coroutine-function hooks are awaited. Other hooks run in a worker thread
        via `anyio.to_thread.run_sync`. Start, success, and failure are logged
        unless the hook sets `log_on_run = False`.

        Args:
            hook: The hook callable.
            hook_type: "commit" or "rollback"; used in log messages.

        Raises:
            Exception: any exception raised by the hook is re-raised.
        """
        hook_name = get_hook_name(hook)
        # Undocumented way to disable logging for a hook. Subject to change.
        should_log = getattr(hook, "log_on_run", True)

        if should_log:
            self.logger.info(f"Running {hook_type} hook {hook_name!r}")

        try:
            if inspect.iscoroutinefunction(hook):
                await hook(self)
            else:
                await anyio.to_thread.run_sync(hook, self)
        except Exception:
            if should_log:
                self.logger.error(
                    f"An error was encountered while running {hook_type} hook {hook_name!r}",
                )
            raise
        else:
            if should_log:
                self.logger.info(
                    f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
                )

    async def rollback(self) -> bool:
        """
        Roll back the transaction.

        Runs rollback hooks in reverse registration order, marks the transaction
        ROLLED_BACK, then rolls back child transactions in reverse order
        (awaiting async ones). If a hook raises, the state is left unchanged.
        Once a rollback is attempted, the key lock (if any) is always released.

        Returns:
            True if this call rolled the transaction back. False if it was
            already committed or rolled back, or if a hook or child raised (the
            error is logged).
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            return False

        try:
            for hook in reversed(self.on_rollback_hooks):
                await self.run_hook(hook, "rollback")

            self.state = TransactionState.ROLLED_BACK

            for child in reversed(self.children):
                if isinstance(child, AsyncTransaction):
                    await child.rollback()
                else:
                    child.rollback()

            return True
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while rolling back transaction {self.key!r}",
                    exc_info=True,
                )
            return False
        finally:
            self._release_key_lock()

    async def __aenter__(self) -> Self:
        """
        Prepare and begin the transaction, then make it the active transaction.

        Raises:
            RuntimeError: if this transaction has already been entered.
            ConfigurationError: if the store does not support the isolation level.
        """
        self.prepare_transaction()
        await self.begin()
        self._token = self.__var__.set(self)
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """
        Finalize the transaction when its context exits.

        If the context raised, the transaction is rolled back and reset, and the
        exception is re-raised. Otherwise:
        - an EAGER transaction commits itself;
        - a transaction with a parent is reset, which registers it as the
          parent's child;
        - a top-level OFF transaction is rolled back;
        - a top-level LAZY transaction commits.
        """
        exc_type, exc_val, _ = exc_info
        if not self._token:
            raise RuntimeError(
                "Asymmetric use of context. Context exit called without an enter."
            )
        if exc_type:
            await self.rollback()
            await self.reset()
            raise exc_val

        if self.commit_mode == CommitMode.EAGER:
            await self.commit()

        # if parent, let them take responsibility
        if self.get_parent():
            await self.reset()
            return

        if self.commit_mode == CommitMode.OFF:
            # if no one took responsibility to commit, rolling back
            # note that rollback returns if already committed
            await self.rollback()
        elif self.commit_mode == CommitMode.LAZY:
            # no one left to take responsibility for committing
            await self.commit()

        await self.reset()

    def __enter__(self) -> NoReturn:
        """Reject synchronous use; `async with` is required."""
        raise NotImplementedError(
            "AsyncTransaction does not support the `with` statement. Use the `async with` statement instead."
        )

    def __exit__(self, *exc_info: Any) -> NoReturn:
        """Reject synchronous use; `async with` is required."""
        raise NotImplementedError(
            "AsyncTransaction does not support the `with` statement. Use the `async with` statement instead."
        )
def get_transaction() -> BaseTransaction | None:
    """
    Look up the transaction that is active in the current context.

    Returns:
        The active transaction, or `None` when no transaction has been entered.
    """
    active_txn = BaseTransaction.get_active()
    return active_txn
@contextmanager
def transaction(
    key: str | None = None,
    store: ResultStore | None = None,
    commit_mode: CommitMode | None = None,
    isolation_level: IsolationLevel | None = None,
    overwrite: bool = False,
    write_on_commit: bool = True,
    logger: logging.Logger | LoggingAdapter | None = None,
) -> Generator[Transaction, None, None]:
    """
    A context manager for opening and managing a transaction.

    Args:
        - key: An identifier to use for the transaction. If not provided, no record
            is persisted.
        - store: The store to use for persisting the transaction result. If not provided
            and a key is given, the store returned by `get_result_store()` is used.
        - commit_mode: The commit mode controlling when the transaction and
            child transactions are committed. If not provided, it is inherited from
            the parent transaction, or `CommitMode.LAZY` when there is no parent.
        - isolation_level: The isolation level of the transaction. If not provided,
            it is inherited from the parent transaction, or
            `IsolationLevel.READ_COMMITTED` when there is no parent.
        - overwrite: Whether to overwrite an existing transaction record in the store
        - write_on_commit: Whether to write the result to the store on commit.
            Defaults to `True`.
        - logger: The logger to use for the transaction. If not provided, the run
            logger is used when a run context is available, and otherwise the
            "transactions" logger.

    Yields:
        - Transaction: An object representing the transaction state
    """
    # if there is no key, we won't persist a record
    if key and not store:
        store = get_result_store()

    # Avoid inheriting a NullFileSystem for metadata_storage from a flow's result store
    if store and isinstance(store.metadata_storage, NullFileSystem):
        store = store.model_copy(update={"metadata_storage": None})

    try:
        _logger: Union[logging.Logger, LoggingAdapter] = logger or get_run_logger()
    except MissingContextError:
        _logger = get_logger("transactions")

    with Transaction(
        key=key,
        store=store,
        commit_mode=commit_mode,
        isolation_level=isolation_level,
        overwrite=overwrite,
        write_on_commit=write_on_commit,
        logger=_logger,
    ) as txn:
        yield txn
@asynccontextmanager
async def atransaction(
    key: str | None = None,
    store: ResultStore | None = None,
    commit_mode: CommitMode | None = None,
    isolation_level: IsolationLevel | None = None,
    overwrite: bool = False,
    write_on_commit: bool = True,
    logger: logging.Logger | LoggingAdapter | None = None,
) -> AsyncGenerator[AsyncTransaction, None]:
    """
    An asynchronous context manager for opening and managing an asynchronous transaction.

    Args:
        - key: An identifier to use for the transaction. If not provided, no record
            is persisted.
        - store: The store to use for persisting the transaction result. If not provided
            and a key is given, the store returned by `get_result_store()` is used.
        - commit_mode: The commit mode controlling when the transaction and
            child transactions are committed. If not provided, it is inherited from
            the parent transaction, or `CommitMode.LAZY` when there is no parent.
        - isolation_level: The isolation level of the transaction. If not provided,
            it is inherited from the parent transaction, or
            `IsolationLevel.READ_COMMITTED` when there is no parent.
        - overwrite: Whether to overwrite an existing transaction record in the store
        - write_on_commit: Whether to write the result to the store on commit.
            Defaults to `True`.
        - logger: The logger to use for the transaction. If not provided, the run
            logger is used when a run context is available, and otherwise the
            "transactions" logger.

    Yields:
        - AsyncTransaction: An object representing the transaction state
    """

    # if there is no key, we won't persist a record
    if key and not store:
        store = get_result_store()

    # Avoid inheriting a NullFileSystem for metadata_storage from a flow's result store
    if store and isinstance(store.metadata_storage, NullFileSystem):
        store = store.model_copy(update={"metadata_storage": None})

    try:
        _logger: Union[logging.Logger, LoggingAdapter] = logger or get_run_logger()
    except MissingContextError:
        _logger = get_logger("transactions")

    async with AsyncTransaction(
        key=key,
        store=store,
        commit_mode=commit_mode,
        isolation_level=isolation_level,
        overwrite=overwrite,
        write_on_commit=write_on_commit,
        logger=_logger,
    ) as txn:
        yield txn