Coverage for /usr/local/lib/python3.12/site-packages/prefect/concurrency/_asyncio.py: 28%
86 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1from __future__ import annotations 1a
3import asyncio 1a
4import logging 1a
5from contextlib import asynccontextmanager 1a
6from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional 1a
7from uuid import UUID 1a
9import anyio 1a
10import httpx 1a
12from prefect.client.orchestration import get_client 1a
13from prefect.client.schemas.responses import ( 1a
14 ConcurrencyLimitWithLeaseResponse,
15 MinimalConcurrencyLimitResponse,
16)
17from prefect.concurrency._events import ( 1a
18 emit_concurrency_acquisition_events,
19 emit_concurrency_release_events,
20)
21from prefect.concurrency._leases import amaintain_concurrency_lease 1a
22from prefect.concurrency.context import ConcurrencyContext 1a
23from prefect.logging import get_logger 1a
24from prefect.logging.loggers import get_run_logger 1a
26from .services import ( 1a
27 ConcurrencySlotAcquisitionService,
28 ConcurrencySlotAcquisitionWithLeaseService,
29)
31if TYPE_CHECKING: 31 ↛ 32line 31 didn't jump to line 32 because the condition on line 31 was never true1a
32 from prefect.client.schemas.objects import ConcurrencyLeaseHolder
class ConcurrencySlotAcquisitionError(Exception):
    """Raised when an unhandlable error occurs while acquiring concurrency slots.

    Also raised when `strict` acquisition is requested and the named
    concurrency limits do not exist.
    """
class AcquireConcurrencySlotTimeoutError(TimeoutError):
    """
    Signals that concurrency slots were not acquired before the requested timeout.

    Subclasses the built-in `TimeoutError`, so generic timeout handlers catch it.
    """
# Module-level logger named "concurrency" (via prefect's `get_logger`), shared by
# the functions in this module.
logger: logging.Logger = get_logger("concurrency")
async def aacquire_concurrency_slots(
    names: list[str],
    slots: int,
    mode: Literal["concurrency", "rate_limit"] = "concurrency",
    timeout_seconds: Optional[float] = None,
    max_retries: Optional[int] = None,
    strict: bool = False,
) -> list[MinimalConcurrencyLimitResponse]:
    """Acquire `slots` from each named concurrency limit (without a lease).

    The request is sent to the `ConcurrencySlotAcquisitionService` instance
    keyed by `frozenset(names)`, and the resulting future is awaited.

    Args:
        names: Names of the concurrency limits to acquire slots from.
        slots: Number of slots to acquire from each limit.
        mode: `"concurrency"` or `"rate_limit"`; forwarded to the service.
        timeout_seconds: Forwarded to the service. Only used in the error
            message here; `None` is passed through unchanged.
        max_retries: Forwarded to the service.
        strict: If `True`, raise when the server reports no limits instead of
            logging a warning.

    Returns:
        The limits from the response body, parsed as
        `MinimalConcurrencyLimitResponse` objects. An empty list means the
        named limits do not exist (only returned when `strict` is `False`).

    Raises:
        AcquireConcurrencySlotTimeoutError: If awaiting the service raises
            `TimeoutError`.
        ConcurrencySlotAcquisitionError: If awaiting the service raises any
            other exception, or if `strict` is `True` and no limits were
            returned.
    """
    service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
    future = service.send((slots, mode, timeout_seconds, max_retries))
    try:
        # Wrap the service's future so it can be awaited from this coroutine.
        response = await asyncio.wrap_future(future)
    except TimeoutError as timeout:
        raise AcquireConcurrencySlotTimeoutError(
            f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)"
        ) from timeout
    except Exception as exc:
        raise ConcurrencySlotAcquisitionError(
            f"Unable to acquire concurrency slots on {names!r}"
        ) from exc

    retval = _response_to_minimal_concurrency_limit_response(response)

    if not retval:
        if strict:
            raise ConcurrencySlotAcquisitionError(
                f"Concurrency limits {names!r} must be created before acquiring slots"
            )
        # Prefer the run logger so the warning lands in the run's logs. Outside
        # a run context, fall back to the module logger (as the lease variant
        # does) instead of silently dropping the warning. The local name must
        # not be `logger`, or it would shadow the module-level fallback.
        try:
            warning_logger = get_run_logger()
        except Exception:
            warning_logger = logger
        warning_logger.warning(
            f"Concurrency limits {names!r} do not exist - skipping acquisition."
        )

    return retval
async def aacquire_concurrency_slots_with_lease(
    names: list[str],
    slots: int,
    mode: Literal["concurrency", "rate_limit"] = "concurrency",
    timeout_seconds: Optional[float] = None,
    max_retries: Optional[int] = None,
    lease_duration: float = 300,
    strict: bool = False,
    holder: "Optional[ConcurrencyLeaseHolder]" = None,
    suppress_warnings: bool = False,
) -> ConcurrencyLimitWithLeaseResponse:
    """Acquire `slots` from each named concurrency limit under a lease.

    The request is sent to the `ConcurrencySlotAcquisitionWithLeaseService`
    instance keyed by `frozenset(names)`, and the resulting future is awaited.

    Args:
        names: Names of the concurrency limits to acquire slots from.
        slots: Number of slots to acquire from each limit.
        mode: `"concurrency"` or `"rate_limit"`; forwarded to the service.
        timeout_seconds: Forwarded to the service; also used in the timeout
            error message.
        max_retries: Forwarded to the service.
        lease_duration: Lease duration in seconds; forwarded to the service.
        strict: Forwarded to the service. When `True` and the response has no
            limits, raise instead of logging.
        holder: Optional lease holder; forwarded to the service.
        suppress_warnings: Log the "limits do not exist" message at DEBUG
            rather than WARNING.

    Returns:
        The response body parsed as `ConcurrencyLimitWithLeaseResponse`. Its
        `limits` is empty when the named limits do not exist (non-strict).

    Raises:
        AcquireConcurrencySlotTimeoutError: If awaiting the service raises
            `TimeoutError`.
        ConcurrencySlotAcquisitionError: If awaiting the service raises any
            other exception, or if `strict` is `True` and no limits exist.
    """
    lease_service = ConcurrencySlotAcquisitionWithLeaseService.instance(
        frozenset(names)
    )
    request = (
        slots,
        mode,
        timeout_seconds,
        max_retries,
        lease_duration,
        strict,
        holder,
    )
    pending = lease_service.send(request)

    try:
        raw_response = await asyncio.wrap_future(pending)
    except TimeoutError as timed_out:
        raise AcquireConcurrencySlotTimeoutError(
            f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)"
        ) from timed_out
    except Exception as failure:
        raise ConcurrencySlotAcquisitionError(
            f"Unable to acquire concurrency slots on {names!r}"
        ) from failure

    lease_response = ConcurrencyLimitWithLeaseResponse.model_validate(
        raw_response.json()
    )

    # Happy path: at least one limit exists, nothing to report.
    if lease_response.limits:
        return lease_response

    if strict:
        raise ConcurrencySlotAcquisitionError(
            f"Concurrency limits {names!r} must be created before acquiring slots"
        )

    # Prefer the run logger when a run context is active.
    try:
        missing_limits_logger = get_run_logger()
    except Exception:
        missing_limits_logger = get_logger("concurrency")

    level = logging.WARNING
    if suppress_warnings:
        level = logging.DEBUG
    missing_limits_logger.log(
        level,
        f"Concurrency limits {names!r} do not exist - skipping acquisition.",
    )
    return lease_response
async def arelease_concurrency_slots(
    names: list[str], slots: int, occupancy_seconds: float
) -> list[MinimalConcurrencyLimitResponse]:
    """Release `slots` on each named limit, reporting `occupancy_seconds`.

    Returns the limits from the server's response body, parsed as
    `MinimalConcurrencyLimitResponse` objects.
    """
    async with get_client() as orchestration_client:
        release_response = await orchestration_client.release_concurrency_slots(
            names=names,
            slots=slots,
            occupancy_seconds=occupancy_seconds,
        )
        # Parse while the client context is still open, as before.
        released_limits = _response_to_minimal_concurrency_limit_response(
            release_response
        )
    return released_limits
async def arelease_concurrency_slots_with_lease(
    lease_id: UUID,
) -> None:
    """Release the concurrency slots held under the lease `lease_id`.

    Uses a client from `get_client()` for the duration of the call.
    """
    async with get_client() as orchestration_client:
        release_lease = orchestration_client.release_concurrency_slots_with_lease
        await release_lease(lease_id=lease_id)
def _response_to_minimal_concurrency_limit_response(
    response: httpx.Response,
) -> list[MinimalConcurrencyLimitResponse]:
    """Validate each element of the JSON response body as a limit model."""
    payload = response.json()
    return list(map(MinimalConcurrencyLimitResponse.model_validate, payload))
@asynccontextmanager
async def concurrency(
    names: str | list[str],
    occupy: int = 1,
    timeout_seconds: Optional[float] = None,
    max_retries: Optional[int] = None,
    lease_duration: float = 300,
    strict: bool = False,
    holder: "Optional[ConcurrencyLeaseHolder]" = None,
    suppress_warnings: bool = False,
) -> AsyncGenerator[None, None]:
    """
    Internal version of the `concurrency` context manager. The public version is located in `prefect.concurrency.asyncio`.

    Acquires `occupy` slots on every named limit under a lease and holds them
    while the body runs. On exit, whether the body succeeded or raised, it
    releases the lease and emits release events.

    Args:
        names: A single concurrency limit name, or a list of names, to acquire
            slots from. An empty value skips acquisition entirely.
        occupy: The number of slots to acquire and hold from each limit.
        timeout_seconds: The number of seconds to wait for the slots to be acquired before
            raising a `TimeoutError`. A timeout of `None` will wait indefinitely.
        max_retries: The maximum number of retries to acquire the concurrency slots.
        lease_duration: The duration of the lease for the acquired slots in seconds.
        strict: A boolean specifying whether to raise an error if the concurrency limit does not exist.
            Also controls whether lease renewal failures are raised. Defaults to `False`.
        holder: An optional `ConcurrencyLeaseHolder` describing who holds the
            slots; forwarded to the acquisition request.
        suppress_warnings: If `True`, "limit does not exist" messages are
            logged at DEBUG instead of WARNING. The flag is also forwarded to
            the lease maintenance helper.

    Raises:
        AcquireConcurrencySlotTimeoutError: A `TimeoutError` subclass, raised if
            the slots are not acquired within the given timeout.
        ConcurrencySlotAcquisitionError: If acquisition fails unexpectedly, or if
            the concurrency limit does not exist and `strict` is `True`.

    Example:
        A simple example of using the async `concurrency` context manager:
        ```python
        from prefect.concurrency.asyncio import concurrency

        async def resource_heavy():
            async with concurrency("test", occupy=1):
                print("Resource heavy task")

        async def main():
            await resource_heavy()
        ```
    """
    if not names:
        yield
        return

    # Accept a single name or a list (or any other iterable) of names. A bare
    # `str` must be wrapped, not iterated character by character; anything
    # else is materialized so tuples etc. are not mistaken for a single name.
    names = [names] if isinstance(names, str) else list(names)

    response = await aacquire_concurrency_slots_with_lease(
        names=names,
        slots=occupy,
        timeout_seconds=timeout_seconds,
        max_retries=max_retries,
        lease_duration=lease_duration,
        strict=strict,
        holder=holder,
        suppress_warnings=suppress_warnings,
    )
    emitted_events = emit_concurrency_acquisition_events(response.limits, occupy)

    try:
        # Presumably keeps the lease renewed while the body runs (see
        # `amaintain_concurrency_lease`); renewal failures raise only when strict.
        async with amaintain_concurrency_lease(
            response.lease_id,
            lease_duration,
            raise_on_lease_renewal_failure=strict,
            suppress_warnings=suppress_warnings,
        ):
            yield
    finally:
        try:
            await arelease_concurrency_slots_with_lease(
                lease_id=response.lease_id,
            )
        except anyio.get_cancelled_exc_class():
            # The task was cancelled before it could release the lease. Add the
            # lease ID to the cleanup list so it can be released when the
            # concurrency context is exited.
            # NOTE(review): the cancellation is not re-raised here. If the body
            # raised, that exception still propagates after `finally`. If the
            # body completed normally, this cancellation is absorbed. Confirm
            # that is intended.
            if ctx := ConcurrencyContext.get():
                ctx.cleanup_lease_ids.append(response.lease_id)

        emit_concurrency_release_events(response.limits, occupy, emitted_events)