Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/utilities/messaging/memory.py: 44%

205 statements  


from __future__ import annotations

import asyncio
import copy
import threading
from collections import defaultdict
from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import uuid4

import anyio
from cachetools import TTLCache
from exceptiongroup import BaseExceptionGroup  # novermin
from pydantic_core import to_json
from typing_extensions import Self

from prefect.logging import get_logger
from prefect.server.utilities.messaging import Cache as _Cache
from prefect.server.utilities.messaging import Consumer as _Consumer
from prefect.server.utilities.messaging import Message, MessageHandler, StopConsumer
from prefect.server.utilities.messaging import Publisher as _Publisher
from prefect.settings.context import get_current_settings

if TYPE_CHECKING:
    import logging

logger: "logging.Logger" = get_logger(__name__)

# Simple global counters by topic with thread-safe access
_metrics_lock: threading.Lock | None = None
METRICS: dict[str, dict[str, int]] = defaultdict(
    lambda: {
        "published": 0,
        "retried": 0,
        "consumed": 0,
    }
)


async def log_metrics_periodically(interval: float = 2.0) -> None:
    if _metrics_lock is None:
        return
    while True:
        await asyncio.sleep(interval)
        with _metrics_lock:
            for topic, data in METRICS.items():
                if data["published"] == 0:
                    continue
                depth = data["published"] - data["consumed"]
                logger.debug(
                    "Topic=%r | published=%d consumed=%d retried=%d depth=%d",
                    topic,
                    data["published"],
                    data["consumed"],
                    data["retried"],
                    depth,
                )


async def update_metric(topic: str, key: str, amount: int = 1) -> None:
    global _metrics_lock
    if _metrics_lock is None:
        _metrics_lock = threading.Lock()
    with _metrics_lock:
        METRICS[topic][key] += amount
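
# --- Editor's illustrative sketch, not part of the original module. It shows
# how the metric helpers might be exercised together: `update_metric` creates
# the lock lazily, so `log_metrics_periodically` returns immediately unless at
# least one metric has been recorded first. The `_example_` name is
# hypothetical.
async def _example_metrics_usage() -> None:
    await update_metric("metrics-example", "published")
    reporter = asyncio.create_task(log_metrics_periodically(interval=0.5))
    await asyncio.sleep(1.0)  # allow at least one reporting cycle to run
    reporter.cancel()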

@dataclass
class MemoryMessage:
    data: Union[bytes, str]
    attributes: Mapping[str, Any]
    retry_count: int = 0

class Subscription:
    """
    A subscription to a topic.

    Messages are delivered to the subscription's queue and retried up to a
    maximum number of times. If a message cannot be delivered after the
    maximum number of retries, it is moved to the dead letter queue.

    The dead letter queue is a directory of JSON files containing the
    serialized messages.

    Messages remain in the dead letter queue until they are removed manually.

    Attributes:
        topic: The topic that the subscription receives messages from.
        max_retries: The maximum number of times a message will be retried for
            this subscription.
        dead_letter_queue_path: The path to the dead letter queue folder.
    """

    def __init__(
        self,
        topic: "Topic",
        max_retries: int = 3,
        dead_letter_queue_path: Path | str | None = None,
    ) -> None:
        self.topic = topic
        self.max_retries = max_retries
        self.dead_letter_queue_path: Path = (
            Path(dead_letter_queue_path)
            if dead_letter_queue_path
            else get_current_settings().home / "dlq"
        )
        self._queue: asyncio.Queue[MemoryMessage] = asyncio.Queue(maxsize=10000)
        self._retry: asyncio.Queue[MemoryMessage] = asyncio.Queue(maxsize=1000)

    async def deliver(self, message: MemoryMessage) -> None:
        """
        Deliver a message to the subscription's queue.

        Args:
            message: The message to deliver.
        """
        try:
            self._queue.put_nowait(message)
            await update_metric(self.topic.name, "published")
            logger.debug(
                "Delivered message to topic=%r queue_size=%d retry_queue_size=%d",
                self.topic.name,
                self._queue.qsize(),
                self._retry.qsize(),
            )
        except asyncio.QueueFull:
            logger.warning(
                "Subscription queue is full, dropping message for topic=%r queue_size=%d retry_queue_size=%d",
                self.topic.name,
                self._queue.qsize(),
                self._retry.qsize(),
            )

    async def retry(self, message: MemoryMessage) -> None:
        """
        Place a message back on the retry queue.

        If the message has been retried more than the maximum number of times,
        it is moved to the dead letter queue.

        Args:
            message: The message to retry.
        """
        message.retry_count += 1
        if message.retry_count > self.max_retries:
            logger.warning(
                "Message failed after %d retries and will be moved to the dead letter queue",
                message.retry_count,
                extra={"event_message": message},
            )
            await self.send_to_dead_letter_queue(message)
        else:
            await self._retry.put(message)
            await update_metric(self.topic.name, "retried")
            logger.debug(
                "Retried message on topic=%r retry_count=%d queue_size=%d retry_queue_size=%d",
                self.topic.name,
                message.retry_count,
                self._queue.qsize(),
                self._retry.qsize(),
            )

    async def get(self) -> MemoryMessage:
        """
        Get a message from the subscription's queue, checking the retry queue
        first.
        """
        if not self._retry.empty():
            return await self._retry.get()
        return await self._queue.get()

    async def send_to_dead_letter_queue(self, message: MemoryMessage) -> None:
        """
        Send a message to the dead letter queue.

        The dead letter queue is a directory of JSON files containing the
        serialized messages.

        Args:
            message: The message to send to the dead letter queue.
        """
        self.dead_letter_queue_path.mkdir(parents=True, exist_ok=True)
        try:
            await anyio.Path(self.dead_letter_queue_path / uuid4().hex).write_bytes(
                to_json(asdict(message))
            )
        except Exception as e:
            logger.warning("Failed to write message to dead letter queue", exc_info=e)
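
# --- Editor's illustrative sketch, not part of the original module: the
# retry/dead-letter flow. With max_retries=1, the second `retry` call for the
# same message writes it to the dead letter queue directory instead of
# requeueing it. The `_example_` name and the `dlq_dir` parameter are
# hypothetical.
async def _example_subscription_dead_lettering(dlq_dir: Path) -> None:
    subscription = Topic.by_name("dlq-example").subscribe(
        max_retries=1, dead_letter_queue_path=dlq_dir
    )
    await subscription.deliver(MemoryMessage(b"payload", {"id": "1"}))
    message = await subscription.get()
    await subscription.retry(message)   # retry_count becomes 1: requeued
    message = await subscription.get()  # served from the retry queue first
    await subscription.retry(message)   # retry_count becomes 2: sent to the DLQ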

class Topic:
    _topics: dict[str, "Topic"] = {}

    name: str
    _subscriptions: list[Subscription]

    def __init__(self, name: str) -> None:
        self.name = name
        self._subscriptions = []

    @classmethod
    def by_name(cls, name: str) -> "Topic":
        try:
            return cls._topics[name]
        except KeyError:
            topic = cls(name)
            cls._topics[name] = topic
            return topic

    @classmethod
    def clear_all(cls) -> None:
        for topic in cls._topics.values():
            topic.clear()
        cls._topics = {}

    def subscribe(self, **subscription_kwargs: Any) -> Subscription:
        subscription = Subscription(self, **subscription_kwargs)
        self._subscriptions.append(subscription)
        return subscription

    def unsubscribe(self, subscription: Subscription) -> None:
        self._subscriptions.remove(subscription)

    def clear(self) -> None:
        # Iterate over a copy, since unsubscribe() mutates the list
        for subscription in list(self._subscriptions):
            self.unsubscribe(subscription)
        self._subscriptions = []

    async def publish(self, message: MemoryMessage) -> None:
        for subscription in self._subscriptions:
            # Ensure that each subscription gets its own copy of the message
            await subscription.deliver(copy.deepcopy(message))
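
# --- Editor's illustrative sketch, not part of the original module: publishing
# fans out to every subscription, and each subscription receives its own deep
# copy, so one consumer mutating a message cannot affect another's. The
# `_example_` name is hypothetical.
async def _example_topic_fanout() -> None:
    topic = Topic.by_name("fanout-example")
    first, second = topic.subscribe(), topic.subscribe()
    await topic.publish(MemoryMessage(b"payload", {"id": "1"}))
    a, b = await first.get(), await second.get()
    assert a is not b and a.data == b.data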

@asynccontextmanager
async def break_topic():
    # Testing helper: patches Topic.publish so that every publish raises,
    # simulating a broken broker
    from unittest import mock

    publishing_mock = mock.AsyncMock(side_effect=ValueError("oops"))

    with mock.patch(
        "prefect.server.utilities.messaging.memory.Topic.publish",
        publishing_mock,
    ):
        yield


M = TypeVar("M", bound=Message)

class Cache(_Cache):
    _recently_seen_messages: MutableMapping[str, bool] = TTLCache(
        maxsize=1000,
        ttl=timedelta(minutes=5).total_seconds(),
    )

    async def clear_recently_seen_messages(self) -> None:
        self._recently_seen_messages.clear()

    async def without_duplicates(
        self, attribute: str, messages: Iterable[M]
    ) -> list[M]:
        messages_with_attribute: list[M] = []
        messages_without_attribute: list[M] = []

        for m in messages:
            if not m.attributes or attribute not in m.attributes:
                logger.warning(
                    "Message is missing deduplication attribute %r",
                    attribute,
                    extra={"event_message": m},
                )
                messages_without_attribute.append(m)
                continue

            if self._recently_seen_messages.get(m.attributes[attribute]):
                continue

            self._recently_seen_messages[m.attributes[attribute]] = True
            messages_with_attribute.append(m)

        return messages_with_attribute + messages_without_attribute

    async def forget_duplicates(self, attribute: str, messages: Iterable[M]) -> None:
        for m in messages:
            if not m.attributes or attribute not in m.attributes:
                logger.warning(
                    "Message is missing deduplication attribute %r",
                    attribute,
                    extra={"event_message": m},
                )
                continue
            self._recently_seen_messages.pop(m.attributes[attribute], None)
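
# --- Editor's illustrative sketch, not part of the original module:
# deduplication keeps the first message seen with a given attribute value
# within the TTL window; messages missing the attribute pass through and are
# appended after the deduplicated ones. The `_example_` name is hypothetical.
async def _example_cache_deduplication() -> None:
    cache = Cache()
    first = MemoryMessage(b"a", {"message-id": "dedupe-example-1"})
    duplicate = MemoryMessage(b"b", {"message-id": "dedupe-example-1"})
    kept = await cache.without_duplicates("message-id", [first, duplicate])
    assert kept == [first]
    # Forgetting the id makes it eligible again, e.g. after a failed publish
    await cache.forget_duplicates("message-id", [first])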

class Publisher(_Publisher):
    def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = None):
        self.topic: Topic = Topic.by_name(topic)
        self.deduplicate_by = deduplicate_by
        self._cache = cache

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        return None

    async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None:
        to_publish = [MemoryMessage(data, attributes)]
        if self.deduplicate_by:
            to_publish = await self._cache.without_duplicates(
                self.deduplicate_by, to_publish
            )

        try:
            for message in to_publish:
                await self.topic.publish(message)
        except Exception:
            if self.deduplicate_by:
                await self._cache.forget_duplicates(self.deduplicate_by, to_publish)
            raise
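
# --- Editor's illustrative sketch, not part of the original module: publishing
# with deduplication. If `Topic.publish` raises, the dedup entries are rolled
# back via `forget_duplicates` so the same message can be republished later.
# The `_example_` name is hypothetical.
async def _example_publisher_usage() -> None:
    async with Publisher("example-topic", Cache(), deduplicate_by="message-id") as p:
        await p.publish_data(b"payload", {"message-id": "pub-example-1"})
        # Dropped as a duplicate within the cache's TTL window
        await p.publish_data(b"payload", {"message-id": "pub-example-1"})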

class Consumer(_Consumer):
    def __init__(
        self,
        topic: str,
        subscription: Optional[Subscription] = None,
        concurrency: int = 2,
        **kwargs: Any,
    ):
        self.topic: Topic = Topic.by_name(topic)
        if not subscription:
            subscription = self.topic.subscribe()
        assert subscription.topic is self.topic
        self.subscription = subscription
        self.concurrency = concurrency

    async def run(self, handler: MessageHandler) -> None:
        try:
            async with anyio.create_task_group() as tg:
                for _ in range(self.concurrency):
                    tg.start_soon(self._consume_loop, handler)
        except BaseExceptionGroup as group:  # novermin
            if all(isinstance(exc, StopConsumer) for exc in group.exceptions):
                logger.debug("StopConsumer received")
                return  # Exit cleanly when all tasks stop
            # Re-raise if there are any non-StopConsumer exceptions
            raise group

    async def cleanup(self) -> None:
        """
        Clean up resources by unsubscribing from the topic.

        This should be called when the consumer is no longer needed to prevent
        memory leaks from orphaned subscriptions.
        """
        self.topic.unsubscribe(self.subscription)
        logger.debug("Unsubscribed from topic=%r", self.topic.name)

    async def _consume_loop(self, handler: MessageHandler) -> None:
        while True:
            message = await self.subscription.get()
            try:
                await handler(message)
                await update_metric(self.topic.name, "consumed")
            except StopConsumer as e:
                if not e.ack:
                    await self.subscription.retry(message)
                raise  # Propagate to the task group
            except Exception:
                logger.exception("Failed in consume_loop")
                await self.subscription.retry(message)
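
# --- Editor's illustrative sketch, not part of the original module: a handler
# that stops the consumer after the first message. Raising StopConsumer with
# ack=True acknowledges the message and ends all of the consumer's worker
# tasks cleanly. The `_example_` names are hypothetical.
async def _example_consumer_usage() -> None:
    consumer = Consumer("consumer-example", concurrency=1)

    async def handler(message: Message) -> None:
        logger.info("Received message with attributes %r", message.attributes)
        raise StopConsumer(ack=True)

    async with Publisher("consumer-example", Cache()) as publisher:
        await publisher.publish_data(b"payload", {"message-id": "consume-example-1"})
    await consumer.run(handler)
    await consumer.cleanup()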

@asynccontextmanager
async def ephemeral_subscription(topic: str) -> AsyncGenerator[Mapping[str, Any], None]:
    subscription = Topic.by_name(topic).subscribe()
    try:
        yield {"topic": topic, "subscription": subscription}
    finally:
        Topic.by_name(topic).unsubscribe(subscription)
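
# --- Editor's illustrative sketch, not part of the original module:
# `ephemeral_subscription` yields kwargs that can be splatted into `Consumer`,
# and unsubscribes when the context exits so short-lived consumers don't leak
# subscriptions. The `_example_` name is hypothetical.
async def _example_ephemeral_consumer() -> None:
    async with ephemeral_subscription("ephemeral-example") as consumer_kwargs:
        consumer = Consumer(**consumer_kwargs)
        ...  # run the consumer as in the sketch above, then exit the context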