Coverage for polar/worker/__init__.py: 71%
100 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1import contextlib 1ba
2import functools 1ba
3from collections.abc import Awaitable, Callable 1ba
4from enum import IntEnum 1ba
5from typing import Any, ParamSpec 1ba
7import dramatiq 1ba
8import logfire 1ba
9import redis 1ba
10import structlog 1ba
11from apscheduler.triggers.cron import CronTrigger 1ba
12from dramatiq import actor as _actor 1ba
13from dramatiq import middleware 1ba
14from dramatiq.brokers.redis import RedisBroker 1ba
16from polar.config import settings 1ba
17from polar.logfire import instrument_httpx 1ba
19from ._encoder import JSONEncoder 1ba
20from ._enqueue import JobQueueManager, enqueue_events, enqueue_job 1ba
21from ._health import HealthMiddleware 1ba
22from ._redis import RedisMiddleware 1ba
23from ._sqlalchemy import AsyncSessionMaker, SQLAlchemyMiddleware 1ba
class MaxRetriesMiddleware(dramatiq.Middleware):
    """Pin the effective ``max_retries`` value onto every message before it runs.

    Precedence: a ``max_retries`` option already present on the message wins;
    otherwise the actor's own ``max_retries`` option is used; otherwise
    ``settings.WORKER_MAX_RETRIES``.
    """

    def before_process_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        declared_actor = broker.get_actor(message.actor_name)
        # Fallback used only when the message does not carry its own value.
        fallback = declared_actor.options.get(
            "max_retries", settings.WORKER_MAX_RETRIES
        )
        message.options.setdefault("max_retries", fallback)
def get_retries() -> int:
    """Return the ``retries`` option of the message currently being processed.

    A message without a ``retries`` option counts as ``0``. Must be called
    while ``CurrentMessage`` has a message set (i.e. from inside an actor);
    otherwise the assertion fails.
    """
    current = middleware.CurrentMessage.get_current_message()
    assert current is not None
    options = current.options
    return options.get("retries", 0)
def can_retry() -> bool:
    """Return whether the message currently being processed may still be retried.

    Compares the retries already performed (``get_retries()``) with the
    ``max_retries`` option that ``MaxRetriesMiddleware`` stores on every
    message. A ``max_retries`` of ``None`` means "no limit" (as in dramatiq's
    Retries middleware), so a retry is always possible in that case.

    Raises:
        AssertionError: if called outside of message processing.
        KeyError: if the message has no ``max_retries`` option.
    """
    message = middleware.CurrentMessage.get_current_message()
    assert message is not None
    max_retries = message.options["max_retries"]
    # Comparing an int against None would raise TypeError; treat None as unlimited.
    if max_retries is None:
        return True
    return get_retries() < max_retries
class SchedulerMiddleware(dramatiq.Middleware):
    """Collect actors that are declared with an APScheduler ``cron_trigger``.

    Every actor declared with a truthy ``cron_trigger`` option contributes a
    ``(actor.send, trigger)`` pair to ``cron_triggers``. The pairs are
    presumably consumed by a scheduler process elsewhere — confirm against
    callers.
    """

    def __init__(self) -> None:
        # (send callable, trigger) pairs, in actor-declaration order.
        self.cron_triggers: list[tuple[Callable[..., Any], CronTrigger]] = []

    @property
    def actor_options(self) -> set[str]:
        # Actor options this middleware makes available on actor declarations.
        return {"cron_trigger"}

    def after_declare_actor(
        self, broker: dramatiq.Broker, actor: dramatiq.Actor[Any, Any]
    ) -> None:
        trigger = actor.options.get("cron_trigger")
        if not trigger:
            return
        self.cron_triggers.append((actor.send, trigger))
# Single shared instance: it is registered on the broker and exported via
# ``__all__``, presumably so other modules can read its ``cron_triggers``.
scheduler_middleware = SchedulerMiddleware()
class LogContextMiddleware(dramatiq.Middleware):
    """Attach ``actor_name`` and ``message_id`` to structlog's context per message.

    The keys are bound before a message is processed and unbound once it has
    been processed or skipped.
    """

    def before_process_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        log_context = {
            "actor_name": message.actor_name,
            "message_id": message.message_id,
        }
        structlog.contextvars.bind_contextvars(**log_context)

    def after_process_message(
        self,
        broker: dramatiq.Broker,
        message: dramatiq.Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        structlog.contextvars.unbind_contextvars("actor_name", "message_id")

    def after_skip_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        # A skipped message gets the same cleanup as a processed one.
        return self.after_process_message(broker, message)
class LogfireMiddleware(dramatiq.Middleware):
    """Wrap the processing of each message in a Logfire span.

    Messages whose actor is listed in ``settings.LOGFIRE_IGNORED_ACTORS`` are
    processed with Logfire instrumentation suppressed instead of inside a span.
    The entered context manager is held by an ``ExitStack`` stored in the
    message's ``"logfire_stack"`` option and closed by the after-hooks.
    """

    def before_worker_boot(
        self, broker: dramatiq.Broker, worker: dramatiq.Worker
    ) -> None:
        # Set up httpx instrumentation for Logfire when the worker boots.
        instrument_httpx()

    def before_process_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        logfire_stack = contextlib.ExitStack()
        actor_name = message.actor_name
        if actor_name in settings.LOGFIRE_IGNORED_ACTORS:
            logfire_stack.enter_context(logfire.suppress_instrumentation())
        else:
            logfire_stack.enter_context(
                logfire.span("TASK {actor}", actor=actor_name, message=message.asdict())
            )
        # Stash the stack on the message so the after-hook can exit the context.
        message.options["logfire_stack"] = logfire_stack

    def after_process_message(
        self,
        broker: dramatiq.Broker,
        message: dramatiq.Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        # pop() removes the stack from the message options again; the None
        # default tolerates messages for which before_process_message never
        # stored one.
        logfire_stack: contextlib.ExitStack | None = message.options.pop(
            "logfire_stack", None
        )
        if logfire_stack is not None:
            logfire_stack.close()

        # THEORY: force flush logfire events after each task to avoid memory bursts
        logfire.force_flush()

    def after_skip_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        # Skipped messages must close their span too.
        return self.after_process_message(broker, message)
# Redis-backed Dramatiq broker; installed as the global broker by set_broker below.
broker = RedisBroker(
    connection_pool=redis.ConnectionPool.from_url(
        settings.redis_url,
        # Names the Redis connections after the environment for easier identification.
        client_name=f"{settings.ENV.value}.worker.dramatiq",
    ),
    # Override default middlewares
    # Only these three are installed at construction; the rest are added below.
    middleware=[
        m()
        for m in (
            middleware.AgeLimit,
            middleware.TimeLimit,
            middleware.ShutdownNotifications,
        )
    ],
)

# NOTE(review): registration order matters — dramatiq runs before_* hooks in
# registration order and after_* hooks in reverse. Confirm before reordering.
# Default retry policy; MaxRetriesMiddleware below copies the effective
# max_retries onto each message so can_retry() can read it.
broker.add_middleware(
    middleware.Retries(
        max_retries=settings.WORKER_MAX_RETRIES,
        min_backoff=settings.WORKER_MIN_BACKOFF_MILLISECONDS,
    )
)
broker.add_middleware(HealthMiddleware())
broker.add_middleware(middleware.AsyncIO())
# Required by get_retries()/can_retry(), which read the current message.
broker.add_middleware(middleware.CurrentMessage())
broker.add_middleware(MaxRetriesMiddleware())
broker.add_middleware(SQLAlchemyMiddleware())
broker.add_middleware(RedisMiddleware())
broker.add_middleware(scheduler_middleware)
broker.add_middleware(LogfireMiddleware())
broker.add_middleware(LogContextMiddleware())
dramatiq.set_broker(broker)
dramatiq.set_encoder(JSONEncoder())
class TaskPriority(IntEnum):
    """Priority values passed to Dramatiq by the ``actor`` decorator.

    In Dramatiq, lower numbers are processed first, so ``HIGH`` is ``0``.
    ``HIGH`` also routes actors to ``TaskQueue.HIGH_PRIORITY`` by default.
    """

    HIGH = 0  # most urgent
    MEDIUM = 50
    LOW = 100  # default priority for ``actor``
class TaskQueue:
    """Names of the Dramatiq queues that actors are routed to."""

    # Used by ``actor`` for TaskPriority.HIGH when no queue_name is given.
    HIGH_PRIORITY: str = "high_priority"
    # Used by ``actor`` for every other priority when no queue_name is given.
    DEFAULT: str = "default"
# NOTE(review): ``actor`` below declares its own ``**P`` type parameter (PEP 695
# syntax), so this module-level ParamSpec is not used by it. It is kept in case
# other modules import ``P`` from here — confirm before removing.
P = ParamSpec("P")
192def actor[**P, R]( 1ba
193 actor_class: Callable[..., dramatiq.Actor[Any, Any]] = dramatiq.Actor,
194 actor_name: str | None = None,
195 queue_name: str | None = None,
196 priority: TaskPriority = TaskPriority.LOW,
197 broker: dramatiq.Broker | None = None,
198 **options: Any,
199) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
200 if queue_name is None: 200 ↛ 207line 200 didn't jump to line 207 because the condition on line 200 was always true1a
201 queue_name = ( 1a
202 TaskQueue.HIGH_PRIORITY
203 if priority == TaskPriority.HIGH
204 else TaskQueue.DEFAULT
205 )
207 def decorator( 1a
208 fn: Callable[P, Awaitable[R]],
209 ) -> Callable[P, Awaitable[R]]:
210 @functools.wraps(fn) 1a
211 async def _wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> R: 1a
212 async with JobQueueManager.open(
213 dramatiq.get_broker(), RedisMiddleware.get()
214 ):
215 return await fn(*args, **kwargs)
217 _actor( 1a
218 _wrapped_fn, # type: ignore
219 actor_class=actor_class,
220 actor_name=actor_name,
221 queue_name=queue_name,
222 priority=priority,
223 broker=broker,
224 **options,
225 )
227 return _wrapped_fn 1a
229 return decorator 1a
# Public API of the worker package (names re-exported by ``import *``).
__all__ = [
    "actor",
    "CronTrigger",
    "AsyncSessionMaker",
    "RedisMiddleware",
    "JobQueueManager",
    "scheduler_middleware",
    "enqueue_job",
    "enqueue_events",
    "get_retries",
    "can_retry",
    "TaskPriority",
    "TaskQueue",
]