Coverage for /usr/local/lib/python3.12/site-packages/prefect/concurrency/_asyncio.py: 28%

86 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4import logging 1a

5from contextlib import asynccontextmanager 1a

6from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional 1a

7from uuid import UUID 1a

8 

9import anyio 1a

10import httpx 1a

11 

12from prefect.client.orchestration import get_client 1a

13from prefect.client.schemas.responses import ( 1a

14 ConcurrencyLimitWithLeaseResponse, 

15 MinimalConcurrencyLimitResponse, 

16) 

17from prefect.concurrency._events import ( 1a

18 emit_concurrency_acquisition_events, 

19 emit_concurrency_release_events, 

20) 

21from prefect.concurrency._leases import amaintain_concurrency_lease 1a

22from prefect.concurrency.context import ConcurrencyContext 1a

23from prefect.logging import get_logger 1a

24from prefect.logging.loggers import get_run_logger 1a

25 

26from .services import ( 1a

27 ConcurrencySlotAcquisitionService, 

28 ConcurrencySlotAcquisitionWithLeaseService, 

29) 

30 

31if TYPE_CHECKING: 31 ↛ 32line 31 didn't jump to line 32 because the condition on line 31 was never true1a

32 from prefect.client.schemas.objects import ConcurrencyLeaseHolder 

33 

34 

class ConcurrencySlotAcquisitionError(Exception):
    """Raised when concurrency slots cannot be acquired for a reason other than a timeout.

    This covers unexpected failures while waiting on the acquisition request
    as well as missing concurrency limits when acquisition is strict.
    """

37 

38 

class AcquireConcurrencySlotTimeoutError(TimeoutError):
    """Timed out while waiting to acquire a concurrency slot."""

41 

42 

# Module-level logger registered under the "concurrency" name.
logger: logging.Logger = get_logger("concurrency")

44 

45 

async def aacquire_concurrency_slots(
    names: list[str],
    slots: int,
    mode: Literal["concurrency", "rate_limit"] = "concurrency",
    timeout_seconds: Optional[float] = None,
    max_retries: Optional[int] = None,
    strict: bool = False,
) -> list[MinimalConcurrencyLimitResponse]:
    """
    Acquire `slots` from each of the named concurrency limits.

    The request is handed to the `ConcurrencySlotAcquisitionService` instance
    keyed by the set of limit names, and this coroutine awaits its result.

    Args:
        names: The names of the concurrency limits to acquire slots from.
        slots: The number of slots to acquire.
        mode: Either "concurrency" or "rate_limit"; forwarded to the service.
        timeout_seconds: Forwarded to the service and reported in the timeout
            error message.
        max_retries: Forwarded to the service.
        strict: If `True`, raise when the response contains no limits instead
            of logging a warning.

    Returns:
        The limits parsed from the service response; empty when none of the
        named limits were returned.

    Raises:
        AcquireConcurrencySlotTimeoutError: If waiting on the service raised
            `TimeoutError`.
        ConcurrencySlotAcquisitionError: If waiting on the service raised any
            other exception, or if no limits were returned and `strict` is
            `True`.
    """
    service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
    future = service.send((slots, mode, timeout_seconds, max_retries))
    try:
        response = await asyncio.wrap_future(future)
    except TimeoutError as timeout:
        raise AcquireConcurrencySlotTimeoutError(
            f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)"
        ) from timeout
    except Exception as exc:
        raise ConcurrencySlotAcquisitionError(
            f"Unable to acquire concurrency slots on {names!r}"
        ) from exc

    retval = _response_to_minimal_concurrency_limit_response(response)
    if retval:
        return retval

    if strict:
        raise ConcurrencySlotAcquisitionError(
            f"Concurrency limits {names!r} must be created before acquiring slots"
        )

    try:
        # Use a run logger if available
        run_logger = get_run_logger()
    except Exception:
        # No run logger could be obtained: fall back to the "concurrency"
        # logger so the warning is not silently dropped, matching
        # `aacquire_concurrency_slots_with_lease`.
        run_logger = get_logger("concurrency")

    run_logger.warning(
        f"Concurrency limits {names!r} do not exist - skipping acquisition."
    )
    return retval

84 

85 

async def aacquire_concurrency_slots_with_lease(
    names: list[str],
    slots: int,
    mode: Literal["concurrency", "rate_limit"] = "concurrency",
    timeout_seconds: Optional[float] = None,
    max_retries: Optional[int] = None,
    lease_duration: float = 300,
    strict: bool = False,
    holder: "Optional[ConcurrencyLeaseHolder]" = None,
    suppress_warnings: bool = False,
) -> ConcurrencyLimitWithLeaseResponse:
    """
    Acquire `slots` from the named concurrency limits and return the lease response.

    The request is handed to the `ConcurrencySlotAcquisitionWithLeaseService`
    instance keyed by the set of limit names; the JSON body of its response is
    validated into a `ConcurrencyLimitWithLeaseResponse`.

    Args:
        names: The names of the concurrency limits to acquire slots from.
        slots: The number of slots to acquire.
        mode: Either "concurrency" or "rate_limit"; forwarded to the service.
        timeout_seconds: Forwarded to the service and reported in the timeout
            error message.
        max_retries: Forwarded to the service.
        lease_duration: Forwarded to the service.
        strict: Forwarded to the service; additionally, if `True`, raise when
            the response contains no limits.
        holder: Forwarded to the service.
        suppress_warnings: If `True`, the "limits do not exist" message is
            logged at DEBUG instead of WARNING.

    Raises:
        AcquireConcurrencySlotTimeoutError: If waiting on the service raised
            `TimeoutError`.
        ConcurrencySlotAcquisitionError: If waiting on the service raised any
            other exception, or if no limits were returned and `strict` is
            `True`.
    """
    acquisition_service = ConcurrencySlotAcquisitionWithLeaseService.instance(
        frozenset(names)
    )
    request = (
        slots,
        mode,
        timeout_seconds,
        max_retries,
        lease_duration,
        strict,
        holder,
    )
    pending = acquisition_service.send(request)

    try:
        raw_response = await asyncio.wrap_future(pending)
    except TimeoutError as timeout:
        raise AcquireConcurrencySlotTimeoutError(
            f"Attempt to acquire concurrency slots timed out after {timeout_seconds} second(s)"
        ) from timeout
    except Exception as exc:
        raise ConcurrencySlotAcquisitionError(
            f"Unable to acquire concurrency slots on {names!r}"
        ) from exc

    lease_response = ConcurrencyLimitWithLeaseResponse.model_validate(
        raw_response.json()
    )

    # Happy path: at least one limit came back.
    if lease_response.limits:
        return lease_response

    if strict:
        raise ConcurrencySlotAcquisitionError(
            f"Concurrency limits {names!r} must be created before acquiring slots"
        )

    # Prefer the run logger; otherwise use the "concurrency" logger.
    try:
        missing_limit_logger = get_run_logger()
    except Exception:
        missing_limit_logger = get_logger("concurrency")

    if suppress_warnings:
        level = logging.DEBUG
    else:
        level = logging.WARNING

    missing_limit_logger.log(
        level,
        f"Concurrency limits {names!r} do not exist - skipping acquisition.",
    )
    return lease_response

133 

134 

async def arelease_concurrency_slots(
    names: list[str], slots: int, occupancy_seconds: float
) -> list[MinimalConcurrencyLimitResponse]:
    """
    Release `slots` on the named concurrency limits through the client.

    Args:
        names: The names of the concurrency limits to release slots on.
        slots: The number of slots to release.
        occupancy_seconds: Passed through to the client's release call.

    Returns:
        The limits parsed from the client's response.
    """
    async with get_client() as api_client:
        release_response = await api_client.release_concurrency_slots(
            names=names,
            slots=slots,
            occupancy_seconds=occupancy_seconds,
        )
        return _response_to_minimal_concurrency_limit_response(release_response)

143 

144 

async def arelease_concurrency_slots_with_lease(
    lease_id: UUID,
) -> None:
    """
    Release the concurrency slots held under `lease_id` through the client.

    Args:
        lease_id: The ID of the lease whose slots should be released.
    """
    async with get_client() as api_client:
        await api_client.release_concurrency_slots_with_lease(lease_id=lease_id)

150 

151 

152def _response_to_minimal_concurrency_limit_response( 1a

153 response: httpx.Response, 

154) -> list[MinimalConcurrencyLimitResponse]: 

155 return [ 

156 MinimalConcurrencyLimitResponse.model_validate(obj_) for obj_ in response.json() 

157 ] 

158 

159 

160@asynccontextmanager 1a

161async def concurrency( 1a

162 names: str | list[str], 

163 occupy: int = 1, 

164 timeout_seconds: Optional[float] = None, 

165 max_retries: Optional[int] = None, 

166 lease_duration: float = 300, 

167 strict: bool = False, 

168 holder: "Optional[ConcurrencyLeaseHolder]" = None, 

169 suppress_warnings: bool = False, 

170) -> AsyncGenerator[None, None]: 

171 """ 

172 Internal version of the `concurrency` context manager. The public version is located in `prefect.concurrency.asyncio`. 

173 

174 Args: 

175 names: The names of the concurrency limits to acquire slots from. 

176 occupy: The number of slots to acquire and hold from each limit. 

177 timeout_seconds: The number of seconds to wait for the slots to be acquired before 

178 raising a `TimeoutError`. A timeout of `None` will wait indefinitely. 

179 max_retries: The maximum number of retries to acquire the concurrency slots. 

180 lease_duration: The duration of the lease for the acquired slots in seconds. 

181 strict: A boolean specifying whether to raise an error if the concurrency limit does not exist. 

182 Defaults to `False`. 

183 holder: A dictionary containing information about the holder of the concurrency slots. 

184 Typically includes 'type' and 'id' keys. 

185 

186 Raises: 

187 TimeoutError: If the slots are not acquired within the given timeout. 

188 ConcurrencySlotAcquisitionError: If the concurrency limit does not exist and `strict` is `True`. 

189 

190 Example: 

191 A simple example of using the async `concurrency` context manager: 

192 ```python 

193 from prefect.concurrency.asyncio import concurrency 

194 

195 async def resource_heavy(): 

196 async with concurrency("test", occupy=1): 

197 print("Resource heavy task") 

198 

199 async def main(): 

200 await resource_heavy() 

201 ``` 

202 """ 

203 if not names: 

204 yield 

205 return 

206 

207 names = names if isinstance(names, list) else [names] 

208 

209 response = await aacquire_concurrency_slots_with_lease( 

210 names=names, 

211 slots=occupy, 

212 timeout_seconds=timeout_seconds, 

213 max_retries=max_retries, 

214 lease_duration=lease_duration, 

215 strict=strict, 

216 holder=holder, 

217 suppress_warnings=suppress_warnings, 

218 ) 

219 emitted_events = emit_concurrency_acquisition_events(response.limits, occupy) 

220 

221 try: 

222 async with amaintain_concurrency_lease( 

223 response.lease_id, 

224 lease_duration, 

225 raise_on_lease_renewal_failure=strict, 

226 suppress_warnings=suppress_warnings, 

227 ): 

228 yield 

229 finally: 

230 try: 

231 await arelease_concurrency_slots_with_lease( 

232 lease_id=response.lease_id, 

233 ) 

234 except anyio.get_cancelled_exc_class(): 

235 # The task was cancelled before it could release the lease. Add the 

236 # lease ID to the cleanup list so it can be released when the 

237 # concurrency context is exited. 

238 if ctx := ConcurrencyContext.get(): 

239 ctx.cleanup_lease_ids.append(response.lease_id) 

240 

241 emit_concurrency_release_events(response.limits, occupy, emitted_events)