Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/concurrency_limits_v2.py: 31%
85 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 13:38 +0000
1from typing import List, Optional, Sequence, Union 1a
2from uuid import UUID 1a
4import sqlalchemy as sa 1a
5from sqlalchemy.ext.asyncio import AsyncSession 1a
6from sqlalchemy.sql.elements import ColumnElement 1a
8import prefect.server.schemas as schemas 1a
9from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1a
10from prefect.server.utilities.database import greatest, least 1a
11from prefect.settings import get_current_settings 1a
def active_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build the SQL expression for `active_slots` after applying decay.

    The stored `active_slots` is reduced by
    `floor(slot_decay_per_second * date_diff_seconds(updated))` and the result
    is floored at 0, so the expression never evaluates to a negative count.
    """
    limit_model = db.ConcurrencyLimitV2

    # NOTE(review): `date_diff_seconds(updated)` is presumably the number of
    # seconds elapsed since the row was last updated -- confirm against the
    # database helper that implements it.
    elapsed_seconds = sa.func.date_diff_seconds(limit_model.updated)
    decayed_slots = sa.func.floor(limit_model.slot_decay_per_second * elapsed_seconds)

    return greatest(0, limit_model.active_slots - decayed_slots)
def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]:
    """
    Build the SQL expression for `denied_slots` after applying decay.

    Rows whose `slot_decay_per_second` is greater than 0 (rate limits) decay at
    exactly that rate. All other rows (concurrency limits) decay at
    `1 / occupancy`, where `occupancy` is `avg_slot_occupancy_seconds` clamped
    between `MINIMUM_OCCUPANCY_SECONDS_PER_SLOT` and a maximum wait taken from
    the current settings.

    The clamping matches the retry-after calculation to prevent denied_slots
    from accumulating when clients retry faster than the unclamped decay rate.
    The result is floored at 0.
    """
    limit_model = db.ConcurrencyLimitV2
    server_settings = get_current_settings().server

    # Upper bound for the occupancy: limits whose name starts with "tag:" use
    # the task-tag wait setting, everything else uses the general maximum.
    tag_wait = sa.literal(server_settings.tasks.tag_concurrency_slot_wait_seconds)
    default_wait = sa.literal(
        server_settings.concurrency.maximum_concurrency_slot_wait_seconds
    )
    occupancy_ceiling = sa.case(
        (limit_model.name.like("tag:%"), tag_wait),
        else_=default_wait,
    )

    # The lower bound keeps the occupancy strictly positive so the division
    # below cannot divide by zero.
    raw_occupancy = sa.cast(limit_model.avg_slot_occupancy_seconds, sa.Float)
    bounded_occupancy = greatest(
        sa.literal(MINIMUM_OCCUPANCY_SECONDS_PER_SLOT),
        least(raw_occupancy, occupancy_ceiling),
    )

    # Rate limits use their configured decay unclamped; concurrency limits use
    # the reciprocal of the clamped occupancy.
    is_rate_limit = limit_model.slot_decay_per_second > 0.0
    decay_per_second = sa.case(
        (is_rate_limit, limit_model.slot_decay_per_second),
        else_=(1.0 / bounded_occupancy),
    )

    elapsed_seconds = sa.func.date_diff_seconds(limit_model.updated)
    decayed_slots = sa.func.floor(decay_per_second * elapsed_seconds)

    return greatest(0, limit_model.denied_slots - decayed_slots)
# OCCUPANCY_SAMPLES_MULTIPLIER is used to determine how many samples to use when
# calculating the average occupancy seconds per slot: the running average
# maintained by `bulk_decrement_active_slots` is weighted over
# `limit * OCCUPANCY_SAMPLES_MULTIPLIER` samples.
OCCUPANCY_SAMPLES_MULTIPLIER: int = 2

# MINIMUM_OCCUPANCY_SECONDS_PER_SLOT is used to prevent the average occupancy
# from dropping too low and causing divide by zero errors. It is the lower
# bound for the clamped occupancy in `denied_slots_after_decay` and for each
# per-slot occupancy sample recorded by `bulk_decrement_active_slots`.
MINIMUM_OCCUPANCY_SECONDS_PER_SLOT: float = 0.1
@db_injector
async def create_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: Union[
        schemas.actions.ConcurrencyLimitV2Create, schemas.core.ConcurrencyLimitV2
    ],
) -> orm_models.ConcurrencyLimitV2:
    """
    Create a concurrency limit row from the given schema object.

    The schema is dumped to keyword arguments for the ORM model, the new model
    is added to the session, and the session is flushed (not committed) before
    the model is returned.
    """
    field_values = concurrency_limit.model_dump()
    record = db.ConcurrencyLimitV2(**field_values)

    session.add(record)
    await session.flush()

    return record
@db_injector
async def read_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> Union[orm_models.ConcurrencyLimitV2, None]:
    """
    Read a single concurrency limit by id or by name.

    When both identifiers are given, `concurrency_limit_id` takes precedence
    and `name` is ignored.

    Returns:
        The matching ORM model, or None when no row matches.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    if not (concurrency_limit_id or name):
        raise ValueError("Must provide either concurrency_limit_id or name")

    limit_model = db.ConcurrencyLimitV2
    if concurrency_limit_id:
        criterion = limit_model.id == concurrency_limit_id
    else:
        criterion = limit_model.name == name

    outcome = await session.execute(sa.select(limit_model).where(criterion))
    return outcome.scalar()
@db_injector
async def read_all_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    limit: int,
    offset: int,
) -> Sequence[orm_models.ConcurrencyLimitV2]:
    """
    Read a page of concurrency limits ordered by name.

    `offset` and `limit` are applied to the query only when they are not None;
    passing None for either skips that clause.
    """
    limit_model = db.ConcurrencyLimitV2
    statement = sa.select(limit_model).order_by(limit_model.name)

    if offset is not None:
        statement = statement.offset(offset)
    if limit is not None:
        statement = statement.limit(limit)

    rows = (await session.execute(statement)).scalars()
    return rows.unique().all()
@db_injector
async def update_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: schemas.actions.ConcurrencyLimitV2Update,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Update a concurrency limit identified by id or by name.

    Only the fields explicitly set on `concurrency_limit` are written
    (`model_dump(exclude_unset=True)`). When both identifiers are given,
    `concurrency_limit_id` takes precedence and `name` is ignored.

    Returns:
        True if a row was updated; False if no matching limit exists or the
        update affected no rows.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    # Validate before touching the database. Previously this check sat after
    # the read below, where it could never fire: `read_concurrency_limit`
    # raises the same error for the same condition.
    if not concurrency_limit_id and not name:
        raise ValueError("Must provide either concurrency_limit_id or name")

    current_concurrency_limit = await read_concurrency_limit(
        session, concurrency_limit_id=concurrency_limit_id, name=name
    )
    if not current_concurrency_limit:
        return False

    where = (
        db.ConcurrencyLimitV2.id == concurrency_limit_id
        if concurrency_limit_id
        else db.ConcurrencyLimitV2.name == name
    )

    result = await session.execute(
        sa.update(db.ConcurrencyLimitV2)
        .where(where)
        .values(**concurrency_limit.model_dump(exclude_unset=True))
    )

    return result.rowcount > 0
@db_injector
async def delete_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> bool:
    """
    Delete a concurrency limit identified by id or by name.

    When both identifiers are given, `concurrency_limit_id` takes precedence
    and `name` is ignored.

    Returns:
        True if at least one row was deleted, False otherwise.

    Raises:
        ValueError: if neither `concurrency_limit_id` nor `name` is provided.
    """
    if not (concurrency_limit_id or name):
        raise ValueError("Must provide either concurrency_limit_id or name")

    limit_model = db.ConcurrencyLimitV2
    if concurrency_limit_id:
        criterion = limit_model.id == concurrency_limit_id
    else:
        criterion = limit_model.name == name

    outcome = await session.execute(sa.delete(limit_model).where(criterion))
    return outcome.rowcount > 0
@db_injector
async def bulk_read_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    names: List[str],
) -> List[orm_models.ConcurrencyLimitV2]:
    """
    Read every concurrency limit whose name appears in `names`.

    Names with no matching row are silently absent from the result, so the
    returned list may be shorter than `names`.
    """
    name_filter = db.ConcurrencyLimitV2.name.in_(names)
    statement = sa.select(db.ConcurrencyLimitV2).where(name_filter)

    outcome = await session.execute(statement)
    return list(outcome.scalars().all())
@db_injector
async def bulk_increment_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Add `slots` active slots to each of the given limits in one UPDATE.

    A row is updated only if it is in `concurrency_limit_ids`, is active, and
    its decayed `active_slots` plus `slots` does not exceed its `limit`.
    Updated rows also have `denied_slots` rewritten to its decayed value.

    Returns:
        True only when the number of updated rows equals
        `len(concurrency_limit_ids)`.
    """
    limit_model = db.ConcurrencyLimitV2
    decayed_active = active_slots_after_decay(db)
    decayed_denied = denied_slots_after_decay(db)

    has_capacity = sa.and_(
        limit_model.id.in_(concurrency_limit_ids),
        limit_model.active == True,  # noqa
        decayed_active + slots <= limit_model.limit,
    )

    statement = sa.update(limit_model).where(has_capacity)
    statement = statement.values(
        active_slots=decayed_active + slots,
        denied_slots=decayed_denied,
    )
    statement = statement.execution_options(synchronize_session=False)

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)
@db_injector
async def bulk_decrement_active_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
    occupancy_seconds: Optional[float] = None,
) -> bool:
    """
    Release `slots` active slots from each of the given limits in one UPDATE.

    For every active row in `concurrency_limit_ids`, `active_slots` becomes
    its decayed value minus `slots` (never below 0) and `denied_slots` is
    rewritten to its decayed value. When a truthy `occupancy_seconds` is
    given, `avg_slot_occupancy_seconds` is also moved toward the new per-slot
    sample as a weighted running average.

    Args:
        concurrency_limit_ids: ids of the limits to release slots from.
        slots: number of slots to release from each limit.
        occupancy_seconds: total time the released slots were held; a value
            of None or 0 leaves the running average untouched.

    Returns:
        True only when the number of updated rows equals
        `len(concurrency_limit_ids)`.
    """
    # Build the decayed expression once and reuse it in both CASE branches.
    decayed_active = active_slots_after_decay(db)

    query = (
        sa.update(db.ConcurrencyLimitV2)
        .where(
            sa.and_(
                db.ConcurrencyLimitV2.id.in_(concurrency_limit_ids),
                db.ConcurrencyLimitV2.active == True,  # noqa
            )
        )
        .values(
            active_slots=sa.case(
                (decayed_active - slots < 0, 0),
                else_=decayed_active - slots,
            ),
            denied_slots=denied_slots_after_decay(db),
        )
    )

    if occupancy_seconds:
        occupancy_seconds_per_slot = max(
            occupancy_seconds / slots, MINIMUM_OCCUPANCY_SECONDS_PER_SLOT
        )

        # Number of samples the running average is weighted over. It is
        # floored at 1 so a row whose `limit` is 0 cannot trigger a division
        # by zero in the database; for any limit >= 1 the value is unchanged.
        sample_count = greatest(
            1, db.ConcurrencyLimitV2.limit * OCCUPANCY_SAMPLES_MULTIPLIER
        )

        query = query.values(
            # Update the average occupancy seconds per slot as a weighted
            # average over the last `limit * OCCUPANCY_SAMPLES_MULTIPLIER` samples.
            avg_slot_occupancy_seconds=db.ConcurrencyLimitV2.avg_slot_occupancy_seconds
            + (occupancy_seconds_per_slot / sample_count)
            - (db.ConcurrencyLimitV2.avg_slot_occupancy_seconds / sample_count),
        )

    result = await session.execute(query)
    return result.rowcount == len(concurrency_limit_ids)
@db_injector
async def bulk_update_denied_slots(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_ids: List[UUID],
    slots: int,
) -> bool:
    """
    Add `slots` to the decayed `denied_slots` of each given limit.

    Only active rows whose id is in `concurrency_limit_ids` are updated.

    Returns:
        True only when the number of updated rows equals
        `len(concurrency_limit_ids)`.
    """
    limit_model = db.ConcurrencyLimitV2

    is_targeted = sa.and_(
        limit_model.id.in_(concurrency_limit_ids),
        limit_model.active == True,  # noqa
    )
    new_denied_slots = denied_slots_after_decay(db) + slots

    statement = (
        sa.update(limit_model).where(is_targeted).values(denied_slots=new_denied_slots)
    )

    outcome = await session.execute(statement)
    return outcome.rowcount == len(concurrency_limit_ids)