Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/utilities/messaging/memory.py: 65% of 205 statements
(coverage.py v7.10.6, created at 2025-12-05 10:48 +0000)

from __future__ import annotations

import asyncio
import copy
import threading
from collections import defaultdict
from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import uuid4

import anyio
from cachetools import TTLCache
from exceptiongroup import BaseExceptionGroup  # novermin
from pydantic_core import to_json
from typing_extensions import Self

from prefect.logging import get_logger
from prefect.server.utilities.messaging import Cache as _Cache
from prefect.server.utilities.messaging import Consumer as _Consumer
from prefect.server.utilities.messaging import Message, MessageHandler, StopConsumer
from prefect.server.utilities.messaging import Publisher as _Publisher
from prefect.settings.context import get_current_settings

if TYPE_CHECKING:
    import logging

logger: "logging.Logger" = get_logger(__name__)

# Simple global counters by topic with thread-safe access
_metrics_lock: threading.Lock | None = None
METRICS: dict[str, dict[str, int]] = defaultdict(
    lambda: {
        "published": 0,
        "retried": 0,
        "consumed": 0,
    }
)


async def log_metrics_periodically(interval: float = 2.0) -> None:
    if _metrics_lock is None:
        return
    while True:
        await asyncio.sleep(interval)
        with _metrics_lock:
            for topic, data in METRICS.items():
                if data["published"] == 0:
                    continue
                depth = data["published"] - data["consumed"]
                logger.debug(
                    "Topic=%r | published=%d consumed=%d retried=%d depth=%d",
                    topic,
                    data["published"],
                    data["consumed"],
                    data["retried"],
                    depth,
                )


async def update_metric(topic: str, key: str, amount: int = 1) -> None:
    global _metrics_lock
    if _metrics_lock is None:
        _metrics_lock = threading.Lock()
    with _metrics_lock:
        METRICS[topic][key] += amount
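

# Usage sketch (illustrative, not part of the module): update_metric lazily
# creates the lock on its first call, and log_metrics_periodically only logs
# once that lock exists. The topic name "events" below is hypothetical.
#
#     await update_metric("events", "published")
#     await update_metric("events", "consumed")
#     assert METRICS["events"]["published"] - METRICS["events"]["consumed"] == 0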


@dataclass
class MemoryMessage:
    data: Union[bytes, str]
    attributes: Mapping[str, Any]
    retry_count: int = 0


class Subscription:
    """
    A subscription to a topic.

    Messages are delivered to the subscription's queue and retried up to a
    maximum number of times. If a message cannot be delivered after the
    maximum number of retries, it is moved to the dead letter queue.

    The dead letter queue is a directory of JSON files, each containing a
    serialized message.

    Messages remain in the dead letter queue until they are removed manually.

    Attributes:
        topic: The topic that the subscription receives messages from.
        max_retries: The maximum number of times a message will be retried for
            this subscription.
        dead_letter_queue_path: The path to the dead letter queue folder.
    """

    def __init__(
        self,
        topic: "Topic",
        max_retries: int = 3,
        dead_letter_queue_path: Path | str | None = None,
    ) -> None:
        self.topic = topic
        self.max_retries = max_retries
        self.dead_letter_queue_path: Path = (
            Path(dead_letter_queue_path)
            if dead_letter_queue_path
            else get_current_settings().home / "dlq"
        )
        self._queue: asyncio.Queue[MemoryMessage] = asyncio.Queue(maxsize=10000)
        self._retry: asyncio.Queue[MemoryMessage] = asyncio.Queue(maxsize=1000)

    async def deliver(self, message: MemoryMessage) -> None:
        """
        Deliver a message to the subscription's queue.

        Args:
            message: The message to deliver.
        """
        try:
            self._queue.put_nowait(message)
            await update_metric(self.topic.name, "published")
            logger.debug(
                "Delivered message to topic=%r queue_size=%d retry_queue_size=%d",
                self.topic.name,
                self._queue.qsize(),
                self._retry.qsize(),
            )
        except asyncio.QueueFull:
            logger.warning(
                "Subscription queue is full, dropping message for topic=%r queue_size=%d retry_queue_size=%d",
                self.topic.name,
                self._queue.qsize(),
                self._retry.qsize(),
            )

    async def retry(self, message: MemoryMessage) -> None:
        """
        Place a message back on the retry queue.

        If the message has been retried more than the maximum number of times,
        it is moved to the dead letter queue.

        Args:
            message: The message to retry.
        """
        message.retry_count += 1
        if message.retry_count > self.max_retries:
            logger.warning(
                "Message failed after %d retries and will be moved to the dead letter queue",
                message.retry_count,
                extra={"event_message": message},
            )
            await self.send_to_dead_letter_queue(message)
        else:
            await self._retry.put(message)
            await update_metric(self.topic.name, "retried")
            logger.debug(
                "Retried message on topic=%r retry_count=%d queue_size=%d retry_queue_size=%d",
                self.topic.name,
                message.retry_count,
                self._queue.qsize(),
                self._retry.qsize(),
            )

    async def get(self) -> MemoryMessage:
        """
        Get a message from the subscription's queue.

        Messages waiting on the retry queue are returned first.
        """
        if not self._retry.empty():
            return await self._retry.get()
        return await self._queue.get()

    async def send_to_dead_letter_queue(self, message: MemoryMessage) -> None:
        """
        Send a message to the dead letter queue.

        The dead letter queue is a directory of JSON files containing the
        serialized messages.

        Args:
            message: The message to send to the dead letter queue.
        """
        self.dead_letter_queue_path.mkdir(parents=True, exist_ok=True)
        try:
            await anyio.Path(self.dead_letter_queue_path / uuid4().hex).write_bytes(
                to_json(asdict(message))
            )
        except Exception as e:
            logger.warning("Failed to write message to dead letter queue", exc_info=e)
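

# Usage sketch (illustrative, not part of the module): each dead-lettered
# message is written as a JSON file named with a random hex UUID, so it can be
# read back into a MemoryMessage for inspection or manual replay.
#
#     import json
#
#     for path in subscription.dead_letter_queue_path.iterdir():
#         payload = json.loads(path.read_bytes())
#         message = MemoryMessage(**payload)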


class Topic:
    _topics: dict[str, "Topic"] = {}

    name: str
    _subscriptions: list[Subscription]

    def __init__(self, name: str) -> None:
        self.name = name
        self._subscriptions = []

    @classmethod
    def by_name(cls, name: str) -> "Topic":
        try:
            return cls._topics[name]
        except KeyError:
            topic = cls(name)
            cls._topics[name] = topic
            return topic

    @classmethod
    def clear_all(cls) -> None:
        for topic in cls._topics.values():
            topic.clear()
        cls._topics = {}

    def subscribe(self, **subscription_kwargs: Any) -> Subscription:
        subscription = Subscription(self, **subscription_kwargs)
        self._subscriptions.append(subscription)
        return subscription

    def unsubscribe(self, subscription: Subscription) -> None:
        self._subscriptions.remove(subscription)

    def clear(self) -> None:
        # Iterate over a copy, since unsubscribe mutates the list
        for subscription in list(self._subscriptions):
            self.unsubscribe(subscription)
        self._subscriptions = []

    async def publish(self, message: MemoryMessage) -> None:
        for subscription in self._subscriptions:
            # Ensure that each subscription gets its own copy of the message
            await subscription.deliver(copy.deepcopy(message))
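

# Usage sketch (illustrative, not part of the module): topics are
# process-local singletons keyed by name, and each subscriber receives its own
# deep copy of every published message. The topic name "events" is
# hypothetical.
#
#     topic = Topic.by_name("events")
#     subscription = topic.subscribe()
#     await topic.publish(MemoryMessage(data=b"hello", attributes={}))
#     message = await subscription.get()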


@asynccontextmanager
async def break_topic():
    from unittest import mock

    publishing_mock = mock.AsyncMock(side_effect=ValueError("oops"))

    with mock.patch(
        "prefect.server.utilities.messaging.memory.Topic.publish",
        publishing_mock,
    ):
        yield
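

# Usage sketch (illustrative, not part of the module): break_topic is a test
# helper that patches Topic.publish to raise, letting tests exercise the
# failure path in Publisher.publish_data. The `publisher` name is
# hypothetical.
#
#     async with break_topic():
#         await publisher.publish_data(b"hello", {})  # raises ValueError("oops")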


M = TypeVar("M", bound=Message)


class Cache(_Cache):
    _recently_seen_messages: MutableMapping[str, bool] = TTLCache(
        maxsize=1000,
        ttl=timedelta(minutes=5).total_seconds(),
    )

    async def clear_recently_seen_messages(self) -> None:
        self._recently_seen_messages.clear()

    async def without_duplicates(
        self, attribute: str, messages: Iterable[M]
    ) -> list[M]:
        messages_with_attribute: list[M] = []
        messages_without_attribute: list[M] = []

        for m in messages:
            if not m.attributes or attribute not in m.attributes:
                logger.warning(
                    "Message is missing deduplication attribute %r",
                    attribute,
                    extra={"event_message": m},
                )
                messages_without_attribute.append(m)
                continue

            if self._recently_seen_messages.get(m.attributes[attribute]):
                continue

            self._recently_seen_messages[m.attributes[attribute]] = True
            messages_with_attribute.append(m)

        return messages_with_attribute + messages_without_attribute

    async def forget_duplicates(self, attribute: str, messages: Iterable[M]) -> None:
        for m in messages:
            if not m.attributes or attribute not in m.attributes:
                logger.warning(
                    "Message is missing deduplication attribute %r",
                    attribute,
                    extra={"event_message": m},
                )
                continue
            self._recently_seen_messages.pop(m.attributes[attribute], None)
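

# Usage sketch (illustrative, not part of the module): messages sharing the
# same value for the deduplication attribute are dropped for the cache's TTL
# (five minutes). The attribute name "id" is hypothetical.
#
#     cache = Cache()
#     first = MemoryMessage(b"a", {"id": "1"})
#     second = MemoryMessage(b"b", {"id": "1"})
#     kept = await cache.without_duplicates("id", [first, second])
#     assert kept == [first]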


class Publisher(_Publisher):
    def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = None):
        self.topic: Topic = Topic.by_name(topic)
        self.deduplicate_by = deduplicate_by
        self._cache = cache

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        return None

    async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None:
        to_publish = [MemoryMessage(data, attributes)]
        if self.deduplicate_by:
            to_publish = await self._cache.without_duplicates(
                self.deduplicate_by, to_publish
            )

        try:
            for message in to_publish:
                await self.topic.publish(message)
        except Exception:
            if self.deduplicate_by:
                await self._cache.forget_duplicates(self.deduplicate_by, to_publish)
            raise
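

# Usage sketch (illustrative, not part of the module): when deduplicate_by is
# set, a failed publish forgets the message's deduplication key so a later
# retry is not silently dropped. The names below are hypothetical.
#
#     async with Publisher("events", cache=Cache(), deduplicate_by="id") as publisher:
#         await publisher.publish_data(b"hello", {"id": "1"})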


class Consumer(_Consumer):
    def __init__(
        self,
        topic: str,
        subscription: Optional[Subscription] = None,
        concurrency: int = 2,
        **kwargs: Any,
    ):
        self.topic: Topic = Topic.by_name(topic)
        if not subscription:
            subscription = self.topic.subscribe()
        assert subscription.topic is self.topic
        self.subscription = subscription
        self.concurrency = concurrency

    async def run(self, handler: MessageHandler) -> None:
        try:
            async with anyio.create_task_group() as tg:
                for _ in range(self.concurrency):
                    tg.start_soon(self._consume_loop, handler)
        except BaseExceptionGroup as group:  # novermin
            if all(isinstance(exc, StopConsumer) for exc in group.exceptions):
                logger.debug("StopConsumer received")
                return  # Exit cleanly when all tasks stop
            # Re-raise if any non-StopConsumer exceptions
            raise group

    async def cleanup(self) -> None:
        """
        Clean up resources by unsubscribing from the topic.

        This should be called when the consumer is no longer needed to prevent
        memory leaks from orphaned subscriptions.
        """
        self.topic.unsubscribe(self.subscription)
        logger.debug("Unsubscribed from topic=%r", self.topic.name)

    async def _consume_loop(self, handler: MessageHandler) -> None:
        while True:
            message = await self.subscription.get()
            try:
                await handler(message)
                await update_metric(self.topic.name, "consumed")
            except StopConsumer as e:
                if not e.ack:
                    await self.subscription.retry(message)
                raise  # Propagate to task group
            except Exception:
                logger.exception("Failed in consume_loop")
                await self.subscription.retry(message)
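

# Usage sketch (illustrative, not part of the module): run() blocks until
# every worker raises StopConsumer. Assuming StopConsumer accepts an `ack`
# flag (as suggested by the `e.ack` check above), a handler can acknowledge
# the final message and stop. The handler below is hypothetical.
#
#     async def handler(message: Message) -> None:
#         print(message.data)
#         raise StopConsumer(ack=True)
#
#     consumer = Consumer("events")
#     await consumer.run(handler)
#     await consumer.cleanup()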


@asynccontextmanager
async def ephemeral_subscription(topic: str) -> AsyncGenerator[Mapping[str, Any], None]:
    subscription = Topic.by_name(topic).subscribe()
    try:
        yield {"topic": topic, "subscription": subscription}
    finally:
        Topic.by_name(topic).unsubscribe(subscription)
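

# Usage sketch (illustrative, not part of the module): the yielded mapping
# matches Consumer's constructor parameters, so a temporary consumer can be
# created whose subscription is removed when the context exits.
#
#     async with ephemeral_subscription("events") as consumer_kwargs:
#         consumer = Consumer(**consumer_kwargs)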