Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/concurrency_limits_v2.py: 74%

85 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1from typing import List, Optional, Sequence, Union 1b

2from uuid import UUID 1b

3 

4import sqlalchemy as sa 1b

5from sqlalchemy.ext.asyncio import AsyncSession 1b

6from sqlalchemy.sql.elements import ColumnElement 1b

7 

8import prefect.server.schemas as schemas 1b

9from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1b

10from prefect.server.utilities.database import greatest, least 1b

11from prefect.settings import get_current_settings 1b

12 

13 

def active_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build the SQL expression for `active_slots` with decay applied.

    The expression subtracts `floor(slot_decay_per_second * elapsed)` from the
    stored `active_slots` and is floored at zero via `greatest`. `elapsed` is
    `date_diff_seconds(updated)` (presumably the seconds since the row was
    last updated; confirm against the database helper).

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.

    Returns:
        A column expression usable in `WHERE` / `SET` clauses.
    """
    limits = db.ConcurrencyLimitV2
    elapsed_seconds = sa.func.date_diff_seconds(limits.updated)
    decayed_amount = sa.func.floor(limits.slot_decay_per_second * elapsed_seconds)
    return greatest(0, limits.active_slots - decayed_amount)

24 

25 

def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build the SQL expression for `denied_slots` with decay applied.

    The decay rate is chosen per row:

    - `slot_decay_per_second > 0` (rate limits): that value is used unchanged.
    - otherwise (concurrency limits): `1 / avg_slot_occupancy_seconds`, with
      the occupancy first clamped between
      `MINIMUM_OCCUPANCY_SECONDS_PER_SLOT` and the configured maximum wait.

    The clamp mirrors the retry-after calculation, which keeps denied_slots
    from accumulating when clients retry faster than the unclamped rate
    would decay them.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.

    Returns:
        A column expression that never evaluates below zero.
    """
    limits = db.ConcurrencyLimitV2
    settings = get_current_settings()

    # Limits named "tag:..." use the tag-concurrency wait setting; every
    # other limit uses the general maximum slot wait.
    tag_wait = sa.literal(settings.server.tasks.tag_concurrency_slot_wait_seconds)
    general_wait = sa.literal(
        settings.server.concurrency.maximum_concurrency_slot_wait_seconds
    )
    max_wait_for_limit = sa.case(
        (limits.name.like("tag:%"), tag_wait),
        else_=general_wait,
    )

    # Upper bound: max wait. Lower bound: the minimum occupancy, which also
    # keeps the division below from hitting zero.
    occupancy = sa.cast(limits.avg_slot_occupancy_seconds, sa.Float)
    clamped_occupancy = greatest(
        sa.literal(MINIMUM_OCCUPANCY_SECONDS_PER_SLOT),
        least(occupancy, max_wait_for_limit),
    )

    decay_rate_per_second = sa.case(
        # Rate limits: explicit decay rate, no clamping.
        (limits.slot_decay_per_second > 0.0, limits.slot_decay_per_second),
        # Concurrency limits: derive the rate from the clamped occupancy.
        else_=(1.0 / clamped_occupancy),
    )

    elapsed_seconds = sa.func.date_diff_seconds(limits.updated)
    decayed_amount = sa.func.floor(decay_rate_per_second * elapsed_seconds)
    return greatest(0, limits.denied_slots - decayed_amount)

77 

78 

# The running average of occupancy seconds per slot is weighted over
# `limit * OCCUPANCY_SAMPLES_MULTIPLIER` samples.
OCCUPANCY_SAMPLES_MULTIPLIER = 2

# Floor for the per-slot occupancy so the average cannot fall low enough to
# cause divide-by-zero errors.
MINIMUM_OCCUPANCY_SECONDS_PER_SLOT = 0.1

86 

87 

@db_injector
async def create_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: Union[
        schemas.actions.ConcurrencyLimitV2Create, schemas.core.ConcurrencyLimitV2
    ],
) -> orm_models.ConcurrencyLimitV2:
    """
    Persist a new concurrency limit and return its ORM object.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session the new row is added to.
        concurrency_limit: schema whose dumped fields populate the ORM model.

    Returns:
        The ORM instance after it has been added and flushed. The session is
        flushed, not committed, here.
    """
    field_values = concurrency_limit.model_dump()
    orm_limit = db.ConcurrencyLimitV2(**field_values)

    session.add(orm_limit)
    await session.flush()

    return orm_limit

102 

103 

@db_injector
async def read_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> Union[orm_models.ConcurrencyLimitV2, None]:
    """
    Fetch a single concurrency limit by ID or, failing that, by name.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the query.
        concurrency_limit_id: ID to look up; takes precedence over `name`.
        name: name to look up when no ID is given.

    Returns:
        The matching ORM object, or `None` when nothing matches.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    if not (concurrency_limit_id or name):
        raise ValueError("Must provide either concurrency_limit_id or name")

    limits = db.ConcurrencyLimitV2
    if concurrency_limit_id:
        criterion = limits.id == concurrency_limit_id
    else:
        criterion = limits.name == name

    rows = await session.execute(sa.select(limits).where(criterion))
    return rows.scalar()

122 

123 

@db_injector
async def read_all_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    limit: int,
    offset: int,
) -> Sequence[orm_models.ConcurrencyLimitV2]:
    """
    Return a page of concurrency limits ordered by name.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the query.
        limit: maximum number of rows to return; skipped when `None`.
        offset: number of rows to skip; skipped when `None`.

    Returns:
        The de-duplicated ORM objects for the requested page.
    """
    limits = db.ConcurrencyLimitV2
    stmt = sa.select(limits).order_by(limits.name)

    # Pagination is applied only for values that are actually supplied.
    if offset is not None:
        stmt = stmt.offset(offset)
    if limit is not None:
        stmt = stmt.limit(limit)

    rows = await session.execute(stmt)
    return rows.scalars().unique().all()

140 

141 

@db_injector
async def update_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: schemas.actions.ConcurrencyLimitV2Update,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Update the concurrency limit identified by ID or, failing that, by name.

    Only the fields explicitly set on `concurrency_limit` are written
    (`model_dump(exclude_unset=True)`).

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used for the lookup and the update.
        concurrency_limit: update schema carrying the new field values.
        concurrency_limit_id: ID of the limit; takes precedence over `name`.
        name: name of the limit, used when no ID is given.

    Returns:
        `True` if at least one row was updated, `False` if no matching limit
        exists.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    # Validate first. This check used to sit after the lookup below, where it
    # could never fire because the lookup raises the same error for the same
    # condition.
    if not concurrency_limit_id and not name:
        raise ValueError("Must provide either concurrency_limit_id or name")

    current_concurrency_limit = await read_concurrency_limit(
        session, concurrency_limit_id=concurrency_limit_id, name=name
    )
    if not current_concurrency_limit:
        return False

    where = (
        db.ConcurrencyLimitV2.id == concurrency_limit_id
        if concurrency_limit_id
        else db.ConcurrencyLimitV2.name == name
    )

    result = await session.execute(
        sa.update(db.ConcurrencyLimitV2)
        .where(where)
        .values(**concurrency_limit.model_dump(exclude_unset=True))
    )

    return result.rowcount > 0

172 

173 

@db_injector
async def delete_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Delete the concurrency limit identified by ID or, failing that, by name.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the delete.
        concurrency_limit_id: ID of the limit; takes precedence over `name`.
        name: name of the limit, used when no ID is given.

    Returns:
        `True` if at least one row was deleted, otherwise `False`.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    if not (concurrency_limit_id or name):
        raise ValueError("Must provide either concurrency_limit_id or name")

    limits = db.ConcurrencyLimitV2
    if concurrency_limit_id:
        criterion = limits.id == concurrency_limit_id
    else:
        criterion = limits.name == name

    outcome = await session.execute(sa.delete(limits).where(criterion))
    return outcome.rowcount > 0

193 

194 

@db_injector
async def bulk_read_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    names: List[str],
) -> List[orm_models.ConcurrencyLimitV2]:
    """
    Fetch every existing concurrency limit whose name appears in `names`.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the query.
        names: names to look up; names without a matching row are simply
            absent from the result.

    Returns:
        A list of the matching ORM objects.
    """
    limits = db.ConcurrencyLimitV2
    stmt = sa.select(limits).where(limits.name.in_(names))

    rows = await session.execute(stmt)
    return list(rows.scalars().all())

208 

209 

@db_injector
async def bulk_increment_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Try to take `slots` active slots on each of the given limits.

    One UPDATE touches only rows that are in `concurrency_limit_ids`, are
    active, and whose decayed `active_slots` plus `slots` still fits within
    `limit`. Matching rows also get their `denied_slots` refreshed to the
    decayed value.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the update.
        concurrency_limit_ids: IDs of the limits to increment.
        slots: number of slots to add to each limit.

    Returns:
        `True` only when the number of updated rows equals
        `len(concurrency_limit_ids)`.
    """
    limits = db.ConcurrencyLimitV2
    decayed_active = active_slots_after_decay(db)
    decayed_denied = denied_slots_after_decay(db)

    eligible = sa.and_(
        limits.id.in_(concurrency_limit_ids),
        limits.active == True,  # noqa: E712 - SQL comparison, not a Python truth test
        decayed_active + slots <= limits.limit,
    )

    stmt = (
        sa.update(limits)
        .where(eligible)
        .values(
            active_slots=decayed_active + slots,
            denied_slots=decayed_denied,
        )
        .execution_options(synchronize_session=False)
    )

    outcome = await session.execute(stmt)
    return outcome.rowcount == len(concurrency_limit_ids)

237 

238 

@db_injector
async def bulk_decrement_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
    occupancy_seconds: Optional[float] = None,
) -> bool:
    """
    Release `slots` active slots on each of the given active limits.

    `active_slots` becomes the decayed value minus `slots`, floored at zero,
    and `denied_slots` is refreshed to its decayed value. When a truthy
    `occupancy_seconds` is given, `avg_slot_occupancy_seconds` is also moved
    towards the new per-slot sample.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the update.
        concurrency_limit_ids: IDs of the limits to decrement.
        slots: number of slots to release on each limit.
        occupancy_seconds: total seconds the released slots were held; the
            average is left untouched when this is `None` or `0`.

    Returns:
        `True` only when the number of updated rows equals
        `len(concurrency_limit_ids)`.
    """
    limits = db.ConcurrencyLimitV2
    remaining_active = active_slots_after_decay(db) - slots

    stmt = (
        sa.update(limits)
        .where(
            sa.and_(
                limits.id.in_(concurrency_limit_ids),
                limits.active == True,  # noqa: E712 - SQL comparison
            )
        )
        .values(
            # Never let the stored count go negative.
            active_slots=sa.case(
                (remaining_active < 0, 0),
                else_=remaining_active,
            ),
            denied_slots=denied_slots_after_decay(db),
        )
    )

    if occupancy_seconds:
        sample_per_slot = max(
            occupancy_seconds / slots, MINIMUM_OCCUPANCY_SECONDS_PER_SLOT
        )
        sample_window = limits.limit * OCCUPANCY_SAMPLES_MULTIPLIER

        stmt = stmt.values(
            # Update the average occupancy seconds per slot as a weighted
            # average over the last `limit * OCCUPANCY_SAMPLES_MULTIPLIER`
            # samples: add the new sample's share, drop the old average's share.
            avg_slot_occupancy_seconds=limits.avg_slot_occupancy_seconds
            + (sample_per_slot / sample_window)
            - (limits.avg_slot_occupancy_seconds / sample_window),
        )

    outcome = await session.execute(stmt)
    return outcome.rowcount == len(concurrency_limit_ids)

285 

286 

@db_injector
async def bulk_update_denied_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Add `slots` to the decayed `denied_slots` of each given active limit.

    Args:
        db: database interface exposing the `ConcurrencyLimitV2` ORM model.
        session: session used to run the update.
        concurrency_limit_ids: IDs of the limits to update.
        slots: number of denied slots to add to each limit.

    Returns:
        `True` only when the number of updated rows equals
        `len(concurrency_limit_ids)`.
    """
    limits = db.ConcurrencyLimitV2
    targets = sa.and_(
        limits.id.in_(concurrency_limit_ids),
        limits.active == True,  # noqa: E712 - SQL comparison
    )
    new_denied = denied_slots_after_decay(db) + slots

    stmt = sa.update(limits).where(targets).values(denied_slots=new_denied)

    outcome = await session.execute(stmt)
    return outcome.rowcount == len(concurrency_limit_ids)