Coverage for polar/worker/_enqueue.py: 93%

86 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import contextlib 1bc

2import contextvars 1bc

3import itertools 1bc

4import uuid 1bc

5from collections import defaultdict 1bc

6from collections.abc import AsyncIterator, Iterable, Mapping 1bc

7from typing import Any, Self, TypeAlias 1bc

8 

9import dramatiq 1bc

10import structlog 1bc

11 

12from polar.logging import Logger 1bc

13from polar.redis import Redis 1bc

14 

# Module-level structlog logger; the job queue manager below emits its
# "job_enqueued" / "job_flushed" debug events through it.
log: Logger = structlog.get_logger()

16 

17 

18JSONSerializable: TypeAlias = ( 1bc

19 Mapping[str, "JSONSerializable"] 

20 | Iterable["JSONSerializable"] 

21 | str 

22 | int 

23 | float 

24 | bool 

25 | uuid.UUID 

26 | None 

27) 

28 

29 

# Holds the JobQueueManager active in the current context, or None when there
# is none. The explicit ``default=None`` matters: without a default,
# ``ContextVar.get()`` raises LookupError when the variable was never set, so
# a caller's ``is None`` check would never be reached in that case.
_job_queue_manager: contextvars.ContextVar["JobQueueManager | None"] = (
    contextvars.ContextVar("polar.job_queue_manager", default=None)
)

33 

# Maximum number of messages written to Redis per hset/rpush pair when the
# job queue manager flushes a queue.
FLUSH_BATCH_SIZE = 50

35 

36 

class JobQueueManager:
    """Buffer jobs and event ids in memory and write them to Redis on flush.

    One instance is bound to the current context through the module-level
    ``_job_queue_manager`` context variable (see ``set``/``get``/``close``).
    ``open`` wraps that lifecycle in an async context manager that flushes
    when the managed block exits without raising.
    """

    __slots__ = ("_enqueued_jobs", "_ingested_events")

    def __init__(self) -> None:
        """Start with no pending jobs and no pending event ids."""
        # Each pending job is (actor name, positional args, keyword args).
        self._enqueued_jobs: list[
            tuple[str, tuple[JSONSerializable, ...], dict[str, JSONSerializable]]
        ] = []
        # Event ids collected via enqueue_events(); turned into a single
        # "event.ingested" job at flush time.
        self._ingested_events: list[uuid.UUID] = []

    def enqueue_job(
        self, actor: str, *args: JSONSerializable, **kwargs: JSONSerializable
    ) -> None:
        """Record a job for ``actor``; nothing is sent until ``flush``."""
        self._enqueued_jobs.append((actor, args, kwargs))
        log.debug("polar.worker.job_enqueued", actor=actor)

    def enqueue_events(self, *event_ids: uuid.UUID) -> None:
        """Record event ids to be delivered by the next ``flush``."""
        self._ingested_events.extend(event_ids)

    async def flush(self, broker: dramatiq.Broker, redis: Redis) -> None:
        """Write every pending job to Redis, then clear the buffers.

        Pending event ids are first folded into one ``"event.ingested"`` job
        whose single positional argument is the list of ids. Each job is
        turned into a message through its actor (looked up on ``broker``),
        grouped by the message's queue name and written in batches of
        ``FLUSH_BATCH_SIZE``. If there is nothing to send, the buffers are
        reset and the method returns without touching Redis.

        Exceptions from the actor lookup or from the Redis calls propagate;
        in that case the buffers are left as they were.
        """
        if self._ingested_events:
            self.enqueue_job("event.ingested", self._ingested_events)

        if not self._enqueued_jobs:
            self.reset()
            return

        # queue name -> [(redis message id, encoded message), ...]
        queue_messages = defaultdict[str, list[tuple[str, Any]]](list)
        # (actor name, encoded message) pairs, kept only for the debug log.
        all_messages: list[tuple[str, Any]] = []

        for actor_name, args, kwargs in self._enqueued_jobs:
            fn: dramatiq.Actor[Any, Any] = broker.get_actor(actor_name)
            redis_message_id = str(uuid.uuid4())
            message = fn.message_with_options(
                args=args, kwargs=kwargs, redis_message_id=redis_message_id
            )
            # Encode once and reuse the result for both the Redis payload and
            # the debug log entry.
            encoded_message = message.encode()
            queue_messages[message.queue_name].append(
                (redis_message_id, encoded_message)
            )
            all_messages.append((fn.actor_name, encoded_message))

        for queue_name, messages in queue_messages.items():
            for batch in itertools.batched(messages, FLUSH_BATCH_SIZE):
                # Store the payloads before pushing their ids onto the queue
                # list, so an id is never visible without its payload.
                await self._batch_hset_messages(redis, queue_name, batch)
                await self._batch_rpush_queue(
                    redis, queue_name, (message_id for message_id, _ in batch)
                )

        for actor_name, encoded_message in all_messages:
            log.debug(
                "polar.worker.job_flushed", actor=actor_name, message=encoded_message
            )

        self.reset()

    async def _batch_hset_messages(
        self,
        redis: Redis,
        queue_name: str,
        message_batch: Iterable[tuple[str, Any]],
    ) -> None:
        """Store a batch of encoded messages with a single ``hset``.

        The pairs in ``message_batch`` are ``(message id, encoded message)``
        and are written to the hash ``dramatiq:{queue_name}.msgs``.
        NOTE(review): the key layout presumably mirrors what dramatiq's Redis
        broker reads — confirm against the broker in use.
        """
        hash_key = f"dramatiq:{queue_name}.msgs"
        await redis.hset(
            hash_key,
            mapping={
                message_id: encoded_message
                for message_id, encoded_message in message_batch
            },
        )

    async def _batch_rpush_queue(
        self, redis: Redis, queue_name: str, message_ids: Iterable[str]
    ) -> None:
        """Append a batch of message ids to ``dramatiq:{queue_name}`` with one ``rpush``."""
        queue_key = f"dramatiq:{queue_name}"
        await redis.rpush(queue_key, *message_ids)

    def reset(self) -> None:
        """Drop all pending jobs and event ids by rebinding fresh lists."""
        self._enqueued_jobs = []
        self._ingested_events = []

    @classmethod
    def set(cls) -> "Self":
        """Create a new manager, bind it to the current context and return it."""
        job_queue_manager = cls()
        _job_queue_manager.set(job_queue_manager)
        return job_queue_manager

    @classmethod
    def close(cls) -> None:
        """Reset the current manager and unbind it from the context.

        Raises:
            RuntimeError: if no manager is bound to the current context.
        """
        job_queue_manager = cls.get()
        job_queue_manager.reset()
        _job_queue_manager.set(None)

    @classmethod
    @contextlib.asynccontextmanager
    async def open(cls, broker: dramatiq.Broker, redis: Redis) -> AsyncIterator["Self"]:
        """Bind a fresh manager for the duration of an ``async with`` block.

        Pending work is flushed only when the block finishes without raising;
        the manager is closed (reset and unbound) in every case.
        """
        job_queue_manager = cls.set()
        try:
            yield job_queue_manager
            await job_queue_manager.flush(broker, redis)
        finally:
            cls.close()

    @classmethod
    def get(cls) -> "JobQueueManager":
        """Return the manager bound to the current context.

        Raises:
            RuntimeError: if no manager is bound, either because none was
                ever set in this context or because it has been closed.
        """
        # Pass an explicit fallback: a bare ``ContextVar.get()`` raises
        # LookupError when the variable was never set, which would bypass the
        # RuntimeError below.
        job_queue_manager = _job_queue_manager.get(None)
        if job_queue_manager is None:
            raise RuntimeError("JobQueueManager not initialized")
        return job_queue_manager

147 

148 

def enqueue_job(
    actor: str, *args: JSONSerializable, **kwargs: JSONSerializable
) -> None:
    """Queue a job for ``actor`` on the manager bound to the current context.

    Positional and keyword arguments are forwarded unchanged to
    ``JobQueueManager.enqueue_job``. Any exception raised by
    ``JobQueueManager.get()`` propagates to the caller.
    """
    JobQueueManager.get().enqueue_job(actor, *args, **kwargs)

155 

156 

def enqueue_events(*event_ids: uuid.UUID) -> None:
    """Hand event ids to the manager bound to the current context.

    The ids are forwarded unchanged to ``JobQueueManager.enqueue_events``.
    Any exception raised by ``JobQueueManager.get()`` propagates to the
    caller.
    """
    JobQueueManager.get().enqueue_events(*event_ids)