Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/concurrency_limits_v2.py: 31%

85 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1from typing import List, Optional, Sequence, Union 1a

2from uuid import UUID 1a

3 

4import sqlalchemy as sa 1a

5from sqlalchemy.ext.asyncio import AsyncSession 1a

6from sqlalchemy.sql.elements import ColumnElement 1a

7 

8import prefect.server.schemas as schemas 1a

9from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1a

10from prefect.server.utilities.database import greatest, least 1a

11from prefect.settings import get_current_settings 1a

12 

13 

def active_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build a SQL expression for ``active_slots`` after time-based decay.

    Slots are released at ``slot_decay_per_second`` per second over the time
    reported by ``date_diff_seconds(updated)`` (presumably seconds elapsed since
    the row was last updated). The released amount is floored, and the result
    is never allowed to drop below zero.
    """
    limit_model = db.ConcurrencyLimitV2
    elapsed_seconds = sa.func.date_diff_seconds(limit_model.updated)
    released_slots = sa.func.floor(
        limit_model.slot_decay_per_second * elapsed_seconds
    )
    return greatest(0, limit_model.active_slots - released_slots)

24 

25 

def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build a SQL expression for ``denied_slots`` after time-based decay.

    Rate limits (``slot_decay_per_second > 0``) decay at
    ``slot_decay_per_second`` per second. Concurrency limits decay at
    ``1 / avg_slot_occupancy_seconds``, with the occupancy first clamped to the
    range ``[MINIMUM_OCCUPANCY_SECONDS_PER_SLOT, max_wait]``.

    ``max_wait`` depends on the limit's name:

    - limits named ``tag:...`` use the tag concurrency slot-wait setting;
    - all other limits use the general maximum slot-wait setting.

    The clamping mirrors the retry-after calculation, so ``denied_slots`` does
    not pile up when clients retry faster than the unclamped decay rate. The
    floored decay is subtracted, and the result never drops below zero.
    """
    limit_model = db.ConcurrencyLimitV2
    settings = get_current_settings()
    tag_wait_seconds = settings.server.tasks.tag_concurrency_slot_wait_seconds
    default_wait_seconds = (
        settings.server.concurrency.maximum_concurrency_slot_wait_seconds
    )

    # Upper clamp bound, chosen per row from the limit name prefix.
    max_wait_for_limit = sa.case(
        (limit_model.name.like("tag:%"), sa.literal(tag_wait_seconds)),
        else_=sa.literal(default_wait_seconds),
    )

    # Lower bound keeps the 1/occupancy rate below from dividing by zero.
    capped_occupancy = least(
        sa.cast(limit_model.avg_slot_occupancy_seconds, sa.Float),
        max_wait_for_limit,
    )
    clamped_occupancy = greatest(
        sa.literal(MINIMUM_OCCUPANCY_SECONDS_PER_SLOT), capped_occupancy
    )

    # Rate limits keep their configured (unclamped) decay; concurrency limits
    # decay at one slot per clamped average occupancy period.
    is_rate_limit = limit_model.slot_decay_per_second > 0.0
    decay_rate_per_second = sa.case(
        (is_rate_limit, limit_model.slot_decay_per_second),
        else_=(1.0 / clamped_occupancy),
    )

    elapsed_seconds = sa.func.date_diff_seconds(limit_model.updated)
    released_slots = sa.func.floor(decay_rate_per_second * elapsed_seconds)
    return greatest(0, limit_model.denied_slots - released_slots)

77 

78 

# Controls the window of the moving average of per-slot occupancy seconds.
# In bulk_decrement_active_slots the average is weighted over roughly
# ``limit * OCCUPANCY_SAMPLES_MULTIPLIER`` samples.
OCCUPANCY_SAMPLES_MULTIPLIER: int = 2

# Floor for per-slot occupancy seconds. Keeping the value strictly positive
# prevents division by zero wherever the occupancy is used as a denominator.
MINIMUM_OCCUPANCY_SECONDS_PER_SLOT: float = 0.1

86 

87 

@db_injector
async def create_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: Union[
        schemas.actions.ConcurrencyLimitV2Create, schemas.core.ConcurrencyLimitV2
    ],
) -> orm_models.ConcurrencyLimitV2:
    """
    Insert a new concurrency limit row.

    Args:
        db: Injected database interface providing the ORM models.
        session: Active async session; it is flushed but not committed here.
        concurrency_limit: Schema object whose dumped fields populate the row.

    Returns:
        The newly added ORM ``ConcurrencyLimitV2`` instance.
    """
    row_fields = concurrency_limit.model_dump()
    new_limit = db.ConcurrencyLimitV2(**row_fields)
    session.add(new_limit)
    # Flush so the INSERT is issued inside the caller's transaction.
    await session.flush()
    return new_limit

102 

103 

@db_injector
async def read_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> Union[orm_models.ConcurrencyLimitV2, None]:
    """
    Fetch a single concurrency limit by id or, failing that, by name.

    ``concurrency_limit_id`` takes precedence when both are supplied.

    Returns:
        The matching ORM row, or ``None`` if nothing matches.

    Raises:
        ValueError: if neither ``concurrency_limit_id`` nor ``name`` is given.
    """
    if not (concurrency_limit_id or name):
        raise ValueError("Must provide either concurrency_limit_id or name")

    limit_model = db.ConcurrencyLimitV2
    if concurrency_limit_id:
        criterion = limit_model.id == concurrency_limit_id
    else:
        criterion = limit_model.name == name

    outcome = await session.execute(sa.select(limit_model).where(criterion))
    return outcome.scalar()

122 

123 

@db_injector
async def read_all_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    limit: int,
    offset: int,
) -> Sequence[orm_models.ConcurrencyLimitV2]:
    """
    List concurrency limits ordered by name, with optional pagination.

    Args:
        limit: Maximum number of rows to return; ``None`` disables the limit.
        offset: Number of rows to skip; ``None`` disables the offset.

    Returns:
        The (de-duplicated) ORM rows for the requested page.
    """
    limit_model = db.ConcurrencyLimitV2
    statement = sa.select(limit_model).order_by(limit_model.name)

    # Both pagination knobs are applied only when explicitly provided.
    if offset is not None:
        statement = statement.offset(offset)
    if limit is not None:
        statement = statement.limit(limit)

    outcome = await session.execute(statement)
    return outcome.scalars().unique().all()

140 

141 

@db_injector
async def update_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: schemas.actions.ConcurrencyLimitV2Update,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Apply a partial update to a single concurrency limit.

    The target row is selected by ``concurrency_limit_id`` when given,
    otherwise by ``name``. Only fields explicitly set on ``concurrency_limit``
    (``model_dump(exclude_unset=True)``) are written.

    Args:
        db: Injected database interface providing the ORM models.
        session: Active async session used for the lookup and the UPDATE.
        concurrency_limit: Partial update payload.
        concurrency_limit_id: Id of the limit to update.
        name: Name of the limit to update (used when no id is given).

    Returns:
        True if a row was updated, False if no matching limit exists.

    Raises:
        ValueError: if neither ``concurrency_limit_id`` nor ``name`` is given.
    """
    # Validate the arguments before touching the database.
    if not concurrency_limit_id and not name:
        raise ValueError("Must provide either concurrency_limit_id or name")

    current_concurrency_limit = await read_concurrency_limit(
        session, concurrency_limit_id=concurrency_limit_id, name=name
    )
    if not current_concurrency_limit:
        return False

    where = (
        db.ConcurrencyLimitV2.id == concurrency_limit_id
        if concurrency_limit_id
        else db.ConcurrencyLimitV2.name == name
    )

    result = await session.execute(
        sa.update(db.ConcurrencyLimitV2)
        .where(where)
        .values(**concurrency_limit.model_dump(exclude_unset=True))
    )

    return result.rowcount > 0

172 

173 

@db_injector
async def delete_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Delete a concurrency limit identified by id or, failing that, by name.

    Returns:
        True if at least one row was deleted, otherwise False.

    Raises:
        ValueError: if neither ``concurrency_limit_id`` nor ``name`` is given.
    """
    if not (concurrency_limit_id or name):
        raise ValueError("Must provide either concurrency_limit_id or name")

    limit_model = db.ConcurrencyLimitV2
    if concurrency_limit_id:
        criterion = limit_model.id == concurrency_limit_id
    else:
        criterion = limit_model.name == name

    outcome = await session.execute(sa.delete(limit_model).where(criterion))
    return outcome.rowcount > 0

193 

194 

@db_injector
async def bulk_read_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    names: List[str],
) -> List[orm_models.ConcurrencyLimitV2]:
    """
    Return every existing concurrency limit whose name appears in ``names``.

    Names with no matching row are simply absent from the result.
    """
    limit_model = db.ConcurrencyLimitV2
    statement = sa.select(limit_model).where(limit_model.name.in_(names))
    outcome = await session.execute(statement)
    return list(outcome.scalars().all())

208 

209 

@db_injector
async def bulk_increment_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Try to occupy ``slots`` additional slots on each listed concurrency limit.

    For every active limit in ``concurrency_limit_ids`` whose decayed
    ``active_slots`` plus ``slots`` still fits within ``limit``:

    - ``active_slots`` becomes the decayed value plus ``slots``;
    - ``denied_slots`` is refreshed to its decayed value.

    Rows that are inactive or would exceed their limit are left untouched.

    Returns:
        True only if every listed id was updated. A False result means some
        rows may still have been incremented; presumably the caller rolls back
        in that case (verify against callers).
    """
    limit_model = db.ConcurrencyLimitV2
    decayed_active = active_slots_after_decay(db)
    decayed_denied = denied_slots_after_decay(db)

    matches_and_fits = sa.and_(
        limit_model.id.in_(concurrency_limit_ids),
        limit_model.active == True,  # noqa
        decayed_active + slots <= limit_model.limit,
    )

    statement = (
        sa.update(limit_model)
        .where(matches_and_fits)
        .values(
            active_slots=decayed_active + slots,
            denied_slots=decayed_denied,
        )
        .execution_options(synchronize_session=False)
    )

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)

237 

238 

@db_injector
async def bulk_decrement_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
    occupancy_seconds: Optional[float] = None,
) -> bool:
    """
    Release ``slots`` slots on each listed active concurrency limit.

    For every active limit in ``concurrency_limit_ids``:

    - ``active_slots`` becomes its decayed value minus ``slots``, floored at 0;
    - ``denied_slots`` is refreshed to its decayed value.

    When a truthy ``occupancy_seconds`` is supplied and ``slots`` is positive,
    the per-slot occupancy (clamped to at least
    ``MINIMUM_OCCUPANCY_SECONDS_PER_SLOT``) is also folded into
    ``avg_slot_occupancy_seconds``. The update is a weighted moving average
    over ``limit * OCCUPANCY_SAMPLES_MULTIPLIER`` samples.

    Args:
        db: Injected database interface providing the ORM models.
        session: Active async session used for the UPDATE.
        concurrency_limit_ids: Ids of the limits to release slots on.
        slots: Number of slots to release from each limit.
        occupancy_seconds: Total seconds the slots were held, if known.

    Returns:
        True only if every listed id was updated.
    """
    decayed_active_slots = active_slots_after_decay(db)

    query = (
        sa.update(db.ConcurrencyLimitV2)
        .where(
            sa.and_(
                db.ConcurrencyLimitV2.id.in_(concurrency_limit_ids),
                db.ConcurrencyLimitV2.active == True,  # noqa
            )
        )
        .values(
            active_slots=sa.case(
                (decayed_active_slots - slots < 0, 0),
                else_=decayed_active_slots - slots,
            ),
            denied_slots=denied_slots_after_decay(db),
        )
    )

    # With no slots there is nothing to attribute the occupancy to, and
    # dividing by ``slots`` would raise ZeroDivisionError.
    if occupancy_seconds and slots > 0:
        occupancy_seconds_per_slot = max(
            occupancy_seconds / slots, MINIMUM_OCCUPANCY_SECONDS_PER_SLOT
        )

        # Window size for the moving average. A limit of 0 is clamped to 1 so
        # the SQL denominator can never be zero.
        sample_window = (
            greatest(db.ConcurrencyLimitV2.limit, 1) * OCCUPANCY_SAMPLES_MULTIPLIER
        )

        query = query.values(
            # Update the average occupancy seconds per slot as a weighted
            # average over the last `limit * OCCUPANCY_SAMPLES_MULTIPLIER` samples.
            avg_slot_occupancy_seconds=db.ConcurrencyLimitV2.avg_slot_occupancy_seconds
            + (occupancy_seconds_per_slot / sample_window)
            - (db.ConcurrencyLimitV2.avg_slot_occupancy_seconds / sample_window),
        )

    result = await session.execute(query)
    return result.rowcount == len(concurrency_limit_ids)

285 

286 

@db_injector
async def bulk_update_denied_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Record ``slots`` newly denied slots on each listed active concurrency limit.

    Each matching row's ``denied_slots`` is set to its decayed value plus
    ``slots``.

    Returns:
        True only if every listed id was updated.
    """
    limit_model = db.ConcurrencyLimitV2
    new_denied_slots = denied_slots_after_decay(db) + slots

    target_rows = sa.and_(
        limit_model.id.in_(concurrency_limit_ids),
        limit_model.active == True,  # noqa
    )
    statement = (
        sa.update(limit_model)
        .where(target_rows)
        .values(denied_slots=new_denied_slots)
    )

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)