Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/concurrency_limits_v2.py: 74%
85 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1from typing import List, Optional, Sequence, Union 1b
2from uuid import UUID 1b
4import sqlalchemy as sa 1b
5from sqlalchemy.ext.asyncio import AsyncSession 1b
6from sqlalchemy.sql.elements import ColumnElement 1b
8import prefect.server.schemas as schemas 1b
9from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1b
10from prefect.server.utilities.database import greatest, least 1b
11from prefect.settings import get_current_settings 1b
def active_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build the SQL expression for `active_slots` once decay has been applied.

    The stored `active_slots` value is reduced by
    `floor(slot_decay_per_second * date_diff_seconds(updated))` and the result
    is floored at zero, so the expression never evaluates to a negative count.
    """
    limits = db.ConcurrencyLimitV2

    # Whole slots that have decayed away, based on the per-second decay rate
    # and the value the database reports for `date_diff_seconds(updated)`.
    decayed_slots = sa.func.floor(
        limits.slot_decay_per_second * sa.func.date_diff_seconds(limits.updated)
    )

    return greatest(0, limits.active_slots - decayed_slots)
def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build the SQL expression for `denied_slots` once decay has been applied.

    Rows whose `slot_decay_per_second` is greater than 0 (rate limits) decay at
    that rate. All other rows (concurrency limits) decay at
    `1 / avg_slot_occupancy_seconds`, where the occupancy is first clamped to
    lie between `MINIMUM_OCCUPANCY_SECONDS_PER_SLOT` and the maximum wait
    configured for the limit.

    The clamp mirrors the retry-after calculation so that denied slots do not
    accumulate when clients retry faster than the unclamped decay rate.
    """
    limits = db.ConcurrencyLimitV2
    settings = get_current_settings()

    # Limits whose name starts with "tag:" use the task tag wait setting; every
    # other limit uses the general maximum slot wait.
    tag_wait = sa.literal(settings.server.tasks.tag_concurrency_slot_wait_seconds)
    default_wait = sa.literal(
        settings.server.concurrency.maximum_concurrency_slot_wait_seconds
    )
    max_wait_for_limit = sa.case(
        (limits.name.like("tag:%"), tag_wait),
        else_=default_wait,
    )

    # Upper-bound the occupancy by the max wait, then lower-bound it by the
    # minimum so the division below can never be by zero.
    occupancy_capped_above = least(
        sa.cast(limits.avg_slot_occupancy_seconds, sa.Float),
        max_wait_for_limit,
    )
    clamped_occupancy = greatest(
        sa.literal(MINIMUM_OCCUPANCY_SECONDS_PER_SLOT),
        occupancy_capped_above,
    )

    # Rate limits keep their configured decay rate untouched; concurrency
    # limits decay at the inverse of the clamped occupancy.
    decay_rate_per_second = sa.case(
        (limits.slot_decay_per_second > 0.0, limits.slot_decay_per_second),
        else_=(1.0 / clamped_occupancy),
    )

    decayed_slots = sa.func.floor(
        decay_rate_per_second * sa.func.date_diff_seconds(limits.updated)
    )
    return greatest(0, limits.denied_slots - decayed_slots)
# Scales the number of samples that feed the running average of occupancy
# seconds per slot: the average is weighted over
# `limit * OCCUPANCY_SAMPLES_MULTIPLIER` samples.
OCCUPANCY_SAMPLES_MULTIPLIER: int = 2

# Lower bound applied to occupancy seconds per slot so the average can never
# fall low enough to cause divide-by-zero errors.
MINIMUM_OCCUPANCY_SECONDS_PER_SLOT: float = 0.1
@db_injector
async def create_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: Union[
        schemas.actions.ConcurrencyLimitV2Create, schemas.core.ConcurrencyLimitV2
    ],
) -> orm_models.ConcurrencyLimitV2:
    """
    Persist a new concurrency limit and return its ORM model.

    The schema object is dumped to a dict of field values, used to construct
    the ORM model, added to the session and flushed. The session is not
    committed here.
    """
    field_values = concurrency_limit.model_dump()
    orm_limit = db.ConcurrencyLimitV2(**field_values)

    session.add(orm_limit)
    await session.flush()

    return orm_limit
@db_injector
async def read_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> Union[orm_models.ConcurrencyLimitV2, None]:
    """
    Look up a single concurrency limit by id or by name.

    When `concurrency_limit_id` is provided it takes precedence and `name` is
    ignored; otherwise the lookup is by `name`.

    Returns:
        The matching ORM model, or None when no row matches.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    if concurrency_limit_id:
        criterion = db.ConcurrencyLimitV2.id == concurrency_limit_id
    elif name:
        criterion = db.ConcurrencyLimitV2.name == name
    else:
        raise ValueError("Must provide either concurrency_limit_id or name")

    statement = sa.select(db.ConcurrencyLimitV2).where(criterion)
    rows = await session.execute(statement)
    return rows.scalar()
@db_injector
async def read_all_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    limit: int,
    offset: int,
) -> Sequence[orm_models.ConcurrencyLimitV2]:
    """
    Read one page of concurrency limits, ordered by name.

    `limit` and `offset` are each applied to the query only when they are not
    None.

    Returns:
        The de-duplicated ORM models for the requested page.
    """
    limits = db.ConcurrencyLimitV2
    statement = sa.select(limits).order_by(limits.name)

    # Paging clauses are independent of one another, so each is attached on
    # its own.
    if limit is not None:
        statement = statement.limit(limit)
    if offset is not None:
        statement = statement.offset(offset)

    rows = await session.execute(statement)
    return rows.scalars().unique().all()
@db_injector
async def update_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: schemas.actions.ConcurrencyLimitV2Update,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Update the concurrency limit identified by id or by name.

    Only the fields explicitly set on `concurrency_limit` are written. When
    `concurrency_limit_id` is provided it takes precedence and `name` is
    ignored.

    Returns:
        True when a row was updated; False when no matching limit exists or
        the update affected no rows.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    # Validate the arguments before touching the database. This check used to
    # sit after the lookup below, where it could never run because the lookup
    # raises the same error for the same condition.
    if not concurrency_limit_id and not name:
        raise ValueError("Must provide either concurrency_limit_id or name")

    current_concurrency_limit = await read_concurrency_limit(
        session, concurrency_limit_id=concurrency_limit_id, name=name
    )
    if not current_concurrency_limit:
        return False

    where = (
        db.ConcurrencyLimitV2.id == concurrency_limit_id
        if concurrency_limit_id
        else db.ConcurrencyLimitV2.name == name
    )

    result = await session.execute(
        sa.update(db.ConcurrencyLimitV2)
        .where(where)
        .values(**concurrency_limit.model_dump(exclude_unset=True))
    )

    return result.rowcount > 0
@db_injector
async def delete_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Delete the concurrency limit identified by id or by name.

    When `concurrency_limit_id` is provided it takes precedence and `name` is
    ignored; otherwise the row is matched by `name`.

    Returns:
        True when at least one row was deleted, False otherwise.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    if concurrency_limit_id:
        criterion = db.ConcurrencyLimitV2.id == concurrency_limit_id
    elif name:
        criterion = db.ConcurrencyLimitV2.name == name
    else:
        raise ValueError("Must provide either concurrency_limit_id or name")

    statement = sa.delete(db.ConcurrencyLimitV2).where(criterion)
    outcome = await session.execute(statement)
    return outcome.rowcount > 0
@db_injector
async def bulk_read_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    names: List[str],
) -> List[orm_models.ConcurrencyLimitV2]:
    """
    Read every concurrency limit whose name appears in `names`.

    Names without a matching row are simply absent from the result.
    """
    name_matches = db.ConcurrencyLimitV2.name.in_(names)
    statement = sa.select(db.ConcurrencyLimitV2).where(name_matches)

    rows = await session.execute(statement)
    return list(rows.scalars().all())
@db_injector
async def bulk_increment_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Add `slots` active slots to each of the given limits, if they all fit.

    A row is updated only when it is listed in `concurrency_limit_ids`, is
    active, and its decayed `active_slots` plus `slots` does not exceed its
    `limit`. Updated rows also have `denied_slots` rewritten to its decayed
    value.

    Returns:
        True when the number of updated rows equals
        `len(concurrency_limit_ids)`, False otherwise.
    """
    limits = db.ConcurrencyLimitV2
    decayed_active = active_slots_after_decay(db)
    decayed_denied = denied_slots_after_decay(db)

    # Every requested limit must be active and have room for the new slots.
    has_capacity = sa.and_(
        limits.id.in_(concurrency_limit_ids),
        limits.active == True,  # noqa
        decayed_active + slots <= limits.limit,
    )

    statement = (
        sa.update(limits)
        .where(has_capacity)
        .values(
            active_slots=decayed_active + slots,
            denied_slots=decayed_denied,
        )
    ).execution_options(synchronize_session=False)

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)
@db_injector
async def bulk_decrement_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
    occupancy_seconds: Optional[float] = None,
) -> bool:
    """
    Remove `slots` active slots from each of the given active limits.

    `active_slots` is set to its decayed value minus `slots`, never below
    zero, and `denied_slots` is rewritten to its decayed value. When
    `occupancy_seconds` is truthy, `avg_slot_occupancy_seconds` is also nudged
    toward the newly observed occupancy per slot; a value of None or 0 leaves
    the average untouched.

    Returns:
        True when the number of updated rows equals
        `len(concurrency_limit_ids)`, False otherwise.
    """
    limits = db.ConcurrencyLimitV2
    remaining_active = active_slots_after_decay(db) - slots

    statement = (
        sa.update(limits)
        .where(
            sa.and_(
                limits.id.in_(concurrency_limit_ids),
                limits.active == True,  # noqa
            )
        )
        .values(
            # Clamp at zero so releasing more slots than are held cannot
            # leave a negative count.
            active_slots=sa.case((remaining_active < 0, 0), else_=remaining_active),
            denied_slots=denied_slots_after_decay(db),
        )
    )

    if occupancy_seconds:
        per_slot_seconds = max(
            occupancy_seconds / slots, MINIMUM_OCCUPANCY_SECONDS_PER_SLOT
        )
        sample_window = limits.limit * OCCUPANCY_SAMPLES_MULTIPLIER

        statement = statement.values(
            # Update the average occupancy seconds per slot as a weighted
            # average over the last `limit * OCCUPANCY_SAMPLES_MULTIPLIER`
            # samples: add the new sample's share, drop the old average's share.
            avg_slot_occupancy_seconds=limits.avg_slot_occupancy_seconds
            + (per_slot_seconds / sample_window)
            - (limits.avg_slot_occupancy_seconds / sample_window),
        )

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)
@db_injector
async def bulk_update_denied_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Add `slots` to the decayed `denied_slots` of each given active limit.

    Returns:
        True when the number of updated rows equals
        `len(concurrency_limit_ids)`, False otherwise.
    """
    limits = db.ConcurrencyLimitV2

    # Only active limits that were explicitly requested are touched.
    targets = sa.and_(
        limits.id.in_(concurrency_limit_ids),
        limits.active == True,  # noqa
    )
    new_denied = denied_slots_after_decay(db) + slots

    statement = sa.update(limits).where(targets).values(denied_slots=new_denied)

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)