Coverage for polar/worker/__init__.py: 71%

100 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import contextlib 1ba

2import functools 1ba

3from collections.abc import Awaitable, Callable 1ba

4from enum import IntEnum 1ba

5from typing import Any, ParamSpec 1ba

6 

7import dramatiq 1ba

8import logfire 1ba

9import redis 1ba

10import structlog 1ba

11from apscheduler.triggers.cron import CronTrigger 1ba

12from dramatiq import actor as _actor 1ba

13from dramatiq import middleware 1ba

14from dramatiq.brokers.redis import RedisBroker 1ba

15 

16from polar.config import settings 1ba

17from polar.logfire import instrument_httpx 1ba

18 

19from ._encoder import JSONEncoder 1ba

20from ._enqueue import JobQueueManager, enqueue_events, enqueue_job 1ba

21from ._health import HealthMiddleware 1ba

22from ._redis import RedisMiddleware 1ba

23from ._sqlalchemy import AsyncSessionMaker, SQLAlchemyMiddleware 1ba

24 

25 

class MaxRetriesMiddleware(dramatiq.Middleware):
    """Ensure every message carries an explicit ``max_retries`` option.

    The value is resolved in this order: the option already present on the
    message, then the option declared on the actor, then
    ``settings.WORKER_MAX_RETRIES``.
    """

    def before_process_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        """Resolve ``max_retries`` and store it in ``message.options``."""
        declared_actor = broker.get_actor(message.actor_name)
        fallback = declared_actor.options.get(
            "max_retries", settings.WORKER_MAX_RETRIES
        )
        # A value already set on the message wins; otherwise use the fallback.
        message.options.setdefault("max_retries", fallback)

37 

38 

def get_retries() -> int:
    """Return the ``retries`` option of the message currently being processed.

    Falls back to ``0`` when the message has no ``retries`` option. Asserts
    that a current message exists.
    """
    current = middleware.CurrentMessage.get_current_message()
    assert current is not None
    retries_so_far: int = current.options.get("retries", 0)
    return retries_so_far

43 

44 

def can_retry() -> bool:
    """Tell whether the current message is still below its ``max_retries``.

    Asserts that a current message exists and reads its ``max_retries``
    option directly (a missing key raises ``KeyError``).
    """
    current = middleware.CurrentMessage.get_current_message()
    assert current is not None
    attempts = get_retries()
    return attempts < current.options["max_retries"]

49 

50 

class SchedulerMiddleware(dramatiq.Middleware):
    """Collect the cron triggers declared on actors.

    Every actor declared with a truthy ``cron_trigger`` option is recorded as
    an ``(actor.send, trigger)`` pair in ``cron_triggers``.
    NOTE(review): presumably an APScheduler-based scheduler elsewhere consumes
    this list — confirm against the scheduler entry point.
    """

    def __init__(self) -> None:
        # (callable that enqueues the actor, trigger describing when to call it)
        self.cron_triggers: list[tuple[Callable[..., Any], CronTrigger]] = []

    @property
    def actor_options(self) -> set[str]:
        """Extra actor option names contributed by this middleware."""
        return {"cron_trigger"}

    def after_declare_actor(
        self, broker: dramatiq.Broker, actor: dramatiq.Actor[Any, Any]
    ) -> None:
        """Record the actor's cron trigger, if it declares one."""
        trigger = actor.options.get("cron_trigger")
        if not trigger:
            return
        self.cron_triggers.append((actor.send, trigger))

66 

67 

# Module-level instance: registered on the broker below and exported in
# ``__all__`` so other modules can read its ``cron_triggers``.
scheduler_middleware = SchedulerMiddleware()

69 

70 

class LogContextMiddleware(dramatiq.Middleware):
    """Bind per-message identifiers to structlog's context variables.

    ``actor_name`` and ``message_id`` are bound before a message is processed
    and unbound once it has been processed or skipped.
    """

    def before_process_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        """Bind the message's actor name and id to the log context."""
        log_context = {
            "actor_name": message.actor_name,
            "message_id": message.message_id,
        }
        structlog.contextvars.bind_contextvars(**log_context)

    def after_process_message(
        self,
        broker: dramatiq.Broker,
        message: dramatiq.Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        """Remove the keys bound by ``before_process_message``."""
        structlog.contextvars.unbind_contextvars("actor_name", "message_id")

    def after_skip_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        """Clean up a skipped message exactly like a processed one."""
        self.after_process_message(broker, message)

95 

96 

class LogfireMiddleware(dramatiq.Middleware):
    """Middleware to manage a Logfire span when handling a message.

    A ``contextlib.ExitStack`` holding either a ``TASK {actor}`` span or a
    ``suppress_instrumentation()`` context is opened before the message is
    processed and stashed in ``message.options["logfire_stack"]``; it is
    popped and closed once the message has been processed or skipped.
    """

    def before_worker_boot(
        self, broker: dramatiq.Broker, worker: dramatiq.Worker
    ) -> None:
        """Call ``instrument_httpx()`` when the worker boots."""
        instrument_httpx()

    def before_process_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        """Open the Logfire context for this message and stash it on the message.

        Actors listed in ``settings.LOGFIRE_IGNORED_ACTORS`` run under
        ``logfire.suppress_instrumentation()``; all others run inside a
        ``TASK {actor}`` span that carries the message as an attribute.
        """
        logfire_stack = contextlib.ExitStack()
        actor_name = message.actor_name
        # The value returned by enter_context() is not needed: the stack
        # itself is what gets closed in after_process_message().
        if actor_name in settings.LOGFIRE_IGNORED_ACTORS:
            logfire_stack.enter_context(logfire.suppress_instrumentation())
        else:
            logfire_stack.enter_context(
                logfire.span("TASK {actor}", actor=actor_name, message=message.asdict())
            )
        message.options["logfire_stack"] = logfire_stack

    def after_process_message(
        self,
        broker: dramatiq.Broker,
        message: dramatiq.Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        """Close the Logfire context opened for this message, then flush.

        The stack is popped from ``message.options`` so it does not stay on
        the message. ``result`` and ``exception`` are not used: the stack is
        closed without exception information.
        """
        logfire_stack: contextlib.ExitStack | None = message.options.pop(
            "logfire_stack", None
        )
        if logfire_stack is not None:
            logfire_stack.close()

        # THEORY: force flush logfire events after each task to avoid memory bursts
        logfire.force_flush()

    def after_skip_message(
        self, broker: dramatiq.Broker, message: dramatiq.Message[Any]
    ) -> None:
        """Clean up a skipped message exactly like a processed one."""
        return self.after_process_message(broker, message)

141 

142 

# Redis-backed broker used by the actors declared through this package.
broker = RedisBroker(
    connection_pool=redis.ConnectionPool.from_url(
        settings.redis_url,
        client_name=f"{settings.ENV.value}.worker.dramatiq",
    ),
    # Override default middlewares
    middleware=[
        middleware.AgeLimit(),
        middleware.TimeLimit(),
        middleware.ShutdownNotifications(),
    ],
)

# Remaining middleware, registered one by one in this exact order.
broker.add_middleware(
    middleware.Retries(
        max_retries=settings.WORKER_MAX_RETRIES,
        min_backoff=settings.WORKER_MIN_BACKOFF_MILLISECONDS,
    )
)
broker.add_middleware(HealthMiddleware())
broker.add_middleware(middleware.AsyncIO())
# get_retries() / can_retry() read the message through CurrentMessage.
broker.add_middleware(middleware.CurrentMessage())
broker.add_middleware(MaxRetriesMiddleware())
broker.add_middleware(SQLAlchemyMiddleware())
broker.add_middleware(RedisMiddleware())
broker.add_middleware(scheduler_middleware)
broker.add_middleware(LogfireMiddleware())
broker.add_middleware(LogContextMiddleware())

# Make this broker and the project's JSON encoder the dramatiq globals.
dramatiq.set_broker(broker)
dramatiq.set_encoder(JSONEncoder())

176 

177 

class TaskPriority(IntEnum):
    """Priority levels accepted by the ``actor`` decorator.

    NOTE(review): the values suggest that a lower number means a more urgent
    task — confirm against dramatiq's priority semantics.
    """

    # ``actor`` routes this level to TaskQueue.HIGH_PRIORITY when no explicit
    # queue name is given.
    HIGH = 0
    MEDIUM = 50
    # Default priority of the ``actor`` decorator.
    LOW = 100

182 

183 

class TaskQueue:
    """Names of the queues the ``actor`` decorator picks from by default."""

    # Chosen when an actor has TaskPriority.HIGH and no explicit queue name.
    HIGH_PRIORITY = "high_priority"
    # Chosen for every other priority when no explicit queue name is given.
    DEFAULT = "default"

187 

188 

# Module-level ParamSpec. Note that ``actor`` declares its own ``P`` through
# its ``[**P, R]`` type-parameter list, which shadows this one inside it.
P = ParamSpec("P")

190 

191 

192def actor[**P, R]( 1ba

193 actor_class: Callable[..., dramatiq.Actor[Any, Any]] = dramatiq.Actor, 

194 actor_name: str | None = None, 

195 queue_name: str | None = None, 

196 priority: TaskPriority = TaskPriority.LOW, 

197 broker: dramatiq.Broker | None = None, 

198 **options: Any, 

199) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: 

200 if queue_name is None: 200 ↛ 207line 200 didn't jump to line 207 because the condition on line 200 was always true1a

201 queue_name = ( 1a

202 TaskQueue.HIGH_PRIORITY 

203 if priority == TaskPriority.HIGH 

204 else TaskQueue.DEFAULT 

205 ) 

206 

207 def decorator( 1a

208 fn: Callable[P, Awaitable[R]], 

209 ) -> Callable[P, Awaitable[R]]: 

210 @functools.wraps(fn) 1a

211 async def _wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> R: 1a

212 async with JobQueueManager.open( 

213 dramatiq.get_broker(), RedisMiddleware.get() 

214 ): 

215 return await fn(*args, **kwargs) 

216 

217 _actor( 1a

218 _wrapped_fn, # type: ignore 

219 actor_class=actor_class, 

220 actor_name=actor_name, 

221 queue_name=queue_name, 

222 priority=priority, 

223 broker=broker, 

224 **options, 

225 ) 

226 

227 return _wrapped_fn 1a

228 

229 return decorator 1a

230 

231 

# Public names of the ``polar.worker`` package: the actor decorator, the
# re-exported helpers from the private submodules, and the task enums.
__all__ = [
    "actor",
    "CronTrigger",
    "AsyncSessionMaker",
    "RedisMiddleware",
    "JobQueueManager",
    "scheduler_middleware",
    "enqueue_job",
    "enqueue_events",
    "get_retries",
    "can_retry",
    "TaskPriority",
    "TaskQueue",
]