Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/utilities/messaging/memory.py: 65%

205 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from __future__ import annotations 1a

2 

3import asyncio 1a

4import copy 1a

5import threading 1a

6from collections import defaultdict 1a

7from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping 1a

8from contextlib import asynccontextmanager 1a

9from dataclasses import asdict, dataclass 1a

10from datetime import timedelta 1a

11from pathlib import Path 1a

12from types import TracebackType 1a

13from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union 1a

14from uuid import uuid4 1a

15 

16import anyio 1a

17from cachetools import TTLCache 1a

18from exceptiongroup import BaseExceptionGroup # novermin 1a

19from pydantic_core import to_json 1a

20from typing_extensions import Self 1a

21 

22from prefect.logging import get_logger 1a

23from prefect.server.utilities.messaging import Cache as _Cache 1a

24from prefect.server.utilities.messaging import Consumer as _Consumer 1a

25from prefect.server.utilities.messaging import Message, MessageHandler, StopConsumer 1a

26from prefect.server.utilities.messaging import Publisher as _Publisher 1a

27from prefect.settings.context import get_current_settings 1a

28 

# `logging` is needed only for the annotation below, so import it lazily
# under TYPE_CHECKING to avoid any runtime cost.
if TYPE_CHECKING:
    import logging

# Module-level logger shared by all in-memory messaging components.
logger: "logging.Logger" = get_logger(__name__)

33 

# Simple global counters by topic with thread-safe access
# The lock is created lazily by `update_metric`; `None` means no metric has
# been recorded yet (checked by `log_metrics_periodically`).
_metrics_lock: threading.Lock | None = None
# Per-topic counters; `defaultdict` seeds each new topic with zeroed counts.
METRICS: dict[str, dict[str, int]] = defaultdict(
    lambda: {
        "published": 0,
        "retried": 0,
        "consumed": 0,
    }
)

43 

44 

async def log_metrics_periodically(interval: float = 2.0) -> None:
    """
    Periodically log the in-memory metric counters for each topic.

    Intended to be scheduled as a long-lived background task; emits one debug
    line per topic that has published at least one message.

    Args:
        interval: Seconds to sleep between reports.
    """
    while True:
        await asyncio.sleep(interval)
        if _metrics_lock is None:
            # No metric has been recorded yet. Keep waiting instead of
            # returning: previously the task exited permanently if it was
            # started before the first call to `update_metric`, so metrics
            # recorded afterwards were never logged.
            continue
        with _metrics_lock:
            for topic, data in METRICS.items():
                if data["published"] == 0:
                    continue
                depth = data["published"] - data["consumed"]
                logger.debug(
                    "Topic=%r | published=%d consumed=%d retried=%d depth=%d",
                    topic,
                    data["published"],
                    data["consumed"],
                    data["retried"],
                    depth,
                )

63 

64 

async def update_metric(topic: str, key: str, amount: int = 1) -> None:
    """
    Increment the counter `key` for `topic` by `amount` under the module lock.

    The module-wide lock is created on first use.
    NOTE(review): the lazy check-then-create of the lock is not itself guarded;
    it is safe on a single event loop (no await between check and assignment)
    — confirm there are no multi-threaded callers.
    """
    global _metrics_lock
    lock = _metrics_lock
    if lock is None:
        lock = _metrics_lock = threading.Lock()
    with lock:
        METRICS[topic][key] += amount

71 

72 

@dataclass
class MemoryMessage:
    """
    A single message as held by the in-memory messaging implementation.

    Attributes:
        data: The serialized message payload.
        attributes: Metadata attached to the message by the publisher.
        retry_count: How many times delivery of this message has been retried.
    """

    data: Union[bytes, str]
    attributes: Mapping[str, Any]
    retry_count: int = 0

78 

79 

class Subscription:
    """
    A subscription to a topic.

    Messages are delivered to the subscription's queue and retried up to a
    maximum number of times. If a message cannot be delivered after the maximum
    number of retries it is moved to the dead letter queue.

    The dead letter queue is a directory of JSON files containing the serialized
    message.

    Messages remain in the dead letter queue until they are removed manually.

    Attributes:
        topic: The topic that the subscription receives messages from.
        max_retries: The maximum number of times a message will be retried for
            this subscription.
        dead_letter_queue_path: The path to the dead letter queue folder.
    """

    def __init__(
        self,
        topic: "Topic",
        max_retries: int = 3,
        dead_letter_queue_path: Path | str | None = None,
    ) -> None:
        self.topic = topic
        self.max_retries = max_retries
        # Default the DLQ location to <home>/dlq when no explicit path is given.
        self.dead_letter_queue_path: Path = (
            Path(dead_letter_queue_path)
            if dead_letter_queue_path
            else get_current_settings().home / "dlq"
        )
        # Bounded queues so a stalled consumer cannot grow memory without
        # limit; `deliver` drops (with a warning) when the main queue is full.
        self._queue: asyncio.Queue[MemoryMessage] = asyncio.Queue(maxsize=10000)
        self._retry: asyncio.Queue[MemoryMessage] = asyncio.Queue(maxsize=1000)

    async def deliver(self, message: MemoryMessage) -> None:
        """
        Deliver a message to the subscription's queue.

        Delivery is best-effort: if the queue is full the message is dropped
        and a warning is logged rather than blocking the publisher.

        Args:
            message: The message to deliver.
        """
        try:
            self._queue.put_nowait(message)
            await update_metric(self.topic.name, "published")
            logger.debug(
                "Delivered message to topic=%r queue_size=%d retry_queue_size=%d",
                self.topic.name,
                self._queue.qsize(),
                self._retry.qsize(),
            )
        except asyncio.QueueFull:
            logger.warning(
                "Subscription queue is full, dropping message for topic=%r queue_size=%d retry_queue_size=%d",
                self.topic.name,
                self._queue.qsize(),
                self._retry.qsize(),
            )

    async def retry(self, message: MemoryMessage) -> None:
        """
        Place a message back on the retry queue.

        If the message has retried more than the maximum number of times it is
        moved to the dead letter queue.

        Args:
            message: The message to retry.
        """
        message.retry_count += 1
        if message.retry_count > self.max_retries:
            logger.warning(
                "Message failed after %d retries and will be moved to the dead letter queue",
                message.retry_count,
                extra={"event_message": message},
            )
            await self.send_to_dead_letter_queue(message)
        else:
            await self._retry.put(message)
            await update_metric(self.topic.name, "retried")
            logger.debug(
                "Retried message on topic=%r retry_count=%d queue_size=%d retry_queue_size=%d",
                self.topic.name,
                message.retry_count,
                self._queue.qsize(),
                self._retry.qsize(),
            )

    async def get(self) -> MemoryMessage:
        """
        Get a message from the subscription's queue.

        Retried messages are served before new ones. NOTE(review): if the
        retry queue becomes non-empty while this coroutine is already awaiting
        the main queue, the retried message is not picked up until the next
        new message arrives — confirm this is acceptable for callers.
        """
        if not self._retry.empty():
            return await self._retry.get()
        return await self._queue.get()

    async def send_to_dead_letter_queue(self, message: MemoryMessage) -> None:
        """
        Send a message to the dead letter queue.

        The dead letter queue is a directory of JSON files containing the
        serialized messages. Persistence is best-effort: any filesystem
        failure is logged rather than raised, so a broken DLQ cannot crash
        the consume loop that calls `retry`.

        Args:
            message: The message to send to the dead letter queue.
        """
        try:
            # Use anyio's async path so neither directory creation nor the
            # write blocks the event loop (previously `mkdir` was a blocking
            # call whose failures propagated out of `retry`).
            dlq = anyio.Path(self.dead_letter_queue_path)
            await dlq.mkdir(parents=True, exist_ok=True)
            await (dlq / uuid4().hex).write_bytes(to_json(asdict(message)))
        except Exception as e:
            logger.warning("Failed to write message to dead letter queue", exc_info=e)

194 

195 

class Topic:
    """
    A named in-memory topic that fans messages out to its subscriptions.

    Topic instances are interned by name in `_topics`, so `by_name` always
    returns the same object for the same name.
    """

    # Registry of all live topics, keyed by name.
    _topics: dict[str, "Topic"] = {}

    name: str
    _subscriptions: list[Subscription]

    def __init__(self, name: str) -> None:
        self.name = name
        self._subscriptions = []

    @classmethod
    def by_name(cls, name: str) -> "Topic":
        """Return the topic registered under `name`, creating it if needed."""
        try:
            return cls._topics[name]
        except KeyError:
            topic = cls(name)
            cls._topics[name] = topic
            return topic

    @classmethod
    def clear_all(cls) -> None:
        """Clear every topic's subscriptions and empty the registry."""
        for topic in cls._topics.values():
            topic.clear()
        cls._topics = {}

    def subscribe(self, **subscription_kwargs: Any) -> Subscription:
        """Create, register, and return a new subscription to this topic."""
        subscription = Subscription(self, **subscription_kwargs)
        self._subscriptions.append(subscription)
        return subscription

    def unsubscribe(self, subscription: Subscription) -> None:
        """Remove a subscription so it no longer receives messages."""
        self._subscriptions.remove(subscription)

    def clear(self) -> None:
        """Unsubscribe every subscription from this topic."""
        # Iterate over a snapshot: `unsubscribe` mutates `_subscriptions`,
        # and removing from a list while iterating it skips elements.
        for subscription in list(self._subscriptions):
            self.unsubscribe(subscription)
        self._subscriptions = []

    async def publish(self, message: MemoryMessage) -> None:
        """Deliver a deep copy of `message` to each subscription."""
        for subscription in self._subscriptions:
            # Ensure that each subscription gets its own copy of the message
            await subscription.deliver(copy.deepcopy(message))

238 

239 

@asynccontextmanager
async def break_topic():
    """
    Test helper: while the context is active, any call to `Topic.publish`
    raises `ValueError("oops")`.
    """
    from unittest import mock

    failing_publish = mock.AsyncMock(side_effect=ValueError("oops"))

    patcher = mock.patch(
        "prefect.server.utilities.messaging.memory.Topic.publish",
        failing_publish,
    )
    with patcher:
        yield

251 

252 

# Type variable for the concrete message type flowing through cache helpers.
M = TypeVar("M", bound=Message)

254 

255 

class Cache(_Cache):
    """
    In-memory deduplication cache backed by a TTL mapping.

    Keys seen within the TTL window (5 minutes, up to 1000 entries) are
    treated as duplicates. The mapping is a class attribute, so it is shared
    by every `Cache` instance in the process.
    """

    _recently_seen_messages: MutableMapping[str, bool] = TTLCache(
        maxsize=1000,
        ttl=timedelta(minutes=5).total_seconds(),
    )

    async def clear_recently_seen_messages(self) -> None:
        """Forget every recently seen deduplication key."""
        self._recently_seen_messages.clear()

    async def without_duplicates(
        self, attribute: str, messages: Iterable[M]
    ) -> list[M]:
        """
        Return `messages` with recently seen duplicates removed.

        Messages lacking `attribute` cannot be deduplicated; they are kept,
        with a warning, and placed after the deduplicated messages.
        """
        deduplicated: list[M] = []
        missing_attribute: list[M] = []

        for message in messages:
            attributes = message.attributes
            if not attributes or attribute not in attributes:
                logger.warning(
                    "Message is missing deduplication attribute %r",
                    attribute,
                    extra={"event_message": message},
                )
                missing_attribute.append(message)
                continue

            key = attributes[attribute]
            if self._recently_seen_messages.get(key):
                continue

            self._recently_seen_messages[key] = True
            deduplicated.append(message)

        return deduplicated + missing_attribute

    async def forget_duplicates(self, attribute: str, messages: Iterable[M]) -> None:
        """
        Remove the deduplication records for `messages`, e.g. after a failed
        publish, so the same messages can be published again later.
        """
        for message in messages:
            attributes = message.attributes
            if attributes and attribute in attributes:
                self._recently_seen_messages.pop(attributes[attribute], None)
            else:
                logger.warning(
                    "Message is missing deduplication attribute %r",
                    attribute,
                    extra={"event_message": message},
                )

299 

300 

class Publisher(_Publisher):
    """
    Publishes messages to an in-memory topic, with optional deduplication by
    a named message attribute.
    """

    def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = None):
        # Resolve (or lazily create) the shared topic object by name.
        self.topic: Topic = Topic.by_name(topic)
        self.deduplicate_by = deduplicate_by
        self._cache = cache

    async def __aenter__(self) -> Self:
        # Nothing to acquire for the in-memory implementation.
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Nothing to release.
        return None

    async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None:
        """
        Publish `data` with `attributes` to the topic.

        When `deduplicate_by` is set, a message whose deduplication attribute
        was seen recently is skipped; if publishing raises, the deduplication
        records are rolled back so the messages can be retried.
        """
        messages = [MemoryMessage(data, attributes)]
        if self.deduplicate_by:
            messages = await self._cache.without_duplicates(
                self.deduplicate_by, messages
            )

        try:
            for outgoing in messages:
                await self.topic.publish(outgoing)
        except Exception:
            if self.deduplicate_by:
                await self._cache.forget_duplicates(self.deduplicate_by, messages)
            raise

332 

333 

class Consumer(_Consumer):
    """
    Consumes messages from an in-memory topic, running a configurable number
    of concurrent consume loops.
    """

    def __init__(
        self,
        topic: str,
        subscription: Optional[Subscription] = None,
        concurrency: int = 2,
        **kwargs: Any,  # accepted for interface compatibility; unused here
    ):
        self.topic: Topic = Topic.by_name(topic)
        if not subscription:
            subscription = self.topic.subscribe()
        if subscription.topic is not self.topic:
            # Previously an `assert`, which is stripped under `python -O`;
            # raise explicitly so a mismatched subscription is always rejected.
            raise ValueError("Subscription must belong to this consumer's topic")
        self.subscription = subscription
        self.concurrency = concurrency

    async def run(self, handler: MessageHandler) -> None:
        """
        Run `concurrency` consume loops until the handler stops them.

        Exits cleanly only when every loop stopped via `StopConsumer`;
        otherwise the exception group from the task group is re-raised.
        """
        try:
            async with anyio.create_task_group() as tg:
                for _ in range(self.concurrency):
                    tg.start_soon(self._consume_loop, handler)
        except BaseExceptionGroup as group:  # novermin
            if all(isinstance(exc, StopConsumer) for exc in group.exceptions):
                logger.debug("StopConsumer received")
                return  # Exit cleanly when all tasks stop
            # Re-raise if any non-StopConsumer exceptions
            raise group

    async def cleanup(self) -> None:
        """
        Cleanup resources by unsubscribing from the topic.

        This should be called when the consumer is no longer needed to prevent
        memory leaks from orphaned subscriptions.
        """
        self.topic.unsubscribe(self.subscription)
        logger.debug("Unsubscribed from topic=%r", self.topic.name)

    async def _consume_loop(self, handler: MessageHandler) -> None:
        # One worker: pull messages forever and hand each to `handler`.
        while True:
            message = await self.subscription.get()
            try:
                await handler(message)
                await update_metric(self.topic.name, "consumed")
            except StopConsumer as e:
                if not e.ack:
                    await self.subscription.retry(message)
                raise  # Propagate to task group
            except Exception:
                # Handler failure: log and requeue (or dead-letter) the message.
                logger.exception("Failed in consume_loop")
                await self.subscription.retry(message)

384 

385 

386@asynccontextmanager 1a

387async def ephemeral_subscription(topic: str) -> AsyncGenerator[Mapping[str, Any], None]: 1a

388 subscription = Topic.by_name(topic).subscribe() 

389 try: 

390 yield {"topic": topic, "subscription": subscription} 

391 finally: 

392 Topic.by_name(topic).unsubscribe(subscription)