Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/concurrency_limits.py: 32%

67 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1""" 

2Functions for interacting with concurrency limit ORM objects. 

3Intended for internal use by the Prefect REST API. 

4""" 

5 

6from datetime import timedelta 1a

7from typing import List, Optional, Sequence, Union 1a

8from uuid import UUID 1a

9 

10import sqlalchemy as sa 1a

11from sqlalchemy.ext.asyncio import AsyncSession 1a

12 

13import prefect.server.schemas as schemas 1a

14from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1a

15from prefect.types._datetime import now 1a

16 

# Clients that create V1 limits cannot renew leases, so their leases are given a
# very long TTL (36,500 days, roughly a century) to stay compatible.
V1_LEASE_TTL = timedelta(days=36_500)

19 

20 

@db_injector
async def create_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit: schemas.core.ConcurrencyLimit,
) -> orm_models.ConcurrencyLimit:
    """
    Inserts a concurrency limit, or updates the existing row on conflict.

    The row is inserted from `concurrency_limit`. If it conflicts on
    `db.orm.concurrency_limit_unique_upsert_columns`, only the existing row's
    `concurrency_limit` value and `updated` timestamp are overwritten. The row for
    the limit's tag is then re-selected with `populate_existing=True`, so any
    instance already held by the session is refreshed, and returned.

    The `concurrency_limit` argument itself is not modified.

    Args:
        db: the injected database interface
        session: a database session
        concurrency_limit: the concurrency limit to create or update

    Returns:
        orm_models.ConcurrencyLimit: the stored concurrency limit
    """
    insert_values = concurrency_limit.model_dump_for_orm(exclude_unset=False)
    # Timestamps are omitted from the INSERT; presumably the column defaults
    # populate them.
    insert_values.pop("created")
    insert_values.pop("updated")
    concurrency_tag = insert_values["tag"]

    # set `updated` manually
    # known limitation of `on_conflict_do_update`, will not use `Column.onupdate`
    # https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#the-set-clause
    # The timestamp goes straight into the SET values instead of onto the
    # caller's schema object, so the argument is left untouched.
    update_values = concurrency_limit.model_dump_for_orm(
        include={"concurrency_limit"}
    )
    update_values["updated"] = now("UTC")

    insert_stmt = (
        db.queries.insert(db.ConcurrencyLimit)
        .values(**insert_values)
        .on_conflict_do_update(
            index_elements=db.orm.concurrency_limit_unique_upsert_columns,
            set_=update_values,
        )
    )

    await session.execute(insert_stmt)

    query = (
        sa.select(db.ConcurrencyLimit)
        .where(db.ConcurrencyLimit.tag == concurrency_tag)
        # refresh any identity-mapped instance with the freshly upserted values
        .execution_options(populate_existing=True)
    )

    result = await session.execute(query)
    return result.scalar_one()

58 

59 

@db_injector
async def read_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: UUID,
) -> Union[orm_models.ConcurrencyLimit, None]:
    """
    Looks up a single concurrency limit by its id.

    The row is read without any lock. If the result feeds orchestration
    decisions, simultaneous readers may race and briefly exceed the limit.

    Args:
        db: the injected database interface
        session: a database session
        concurrency_limit_id: id of the concurrency limit to read

    Returns:
        the matching concurrency limit, or `None` if no row has that id
    """
    limit_model = db.ConcurrencyLimit
    by_id = sa.select(limit_model).where(limit_model.id == concurrency_limit_id)
    return (await session.execute(by_id)).scalar()

77 

78 

@db_injector
async def read_concurrency_limit_by_tag(
    db: PrefectDBInterface,
    session: AsyncSession,
    tag: str,
) -> Union[orm_models.ConcurrencyLimit, None]:
    """
    Looks up a single concurrency limit by its tag.

    The row is read without any lock. If the result feeds orchestration
    decisions, simultaneous readers may race and briefly exceed the limit.

    Args:
        db: the injected database interface
        session: a database session
        tag: the tag whose concurrency limit should be read

    Returns:
        the matching concurrency limit, or `None` if no row has that tag
    """
    limit_model = db.ConcurrencyLimit
    by_tag = sa.select(limit_model).where(limit_model.tag == tag)
    return (await session.execute(by_tag)).scalar()

94 

95 

@db_injector
async def reset_concurrency_limit_by_tag(
    db: PrefectDBInterface,
    session: AsyncSession,
    tag: str,
    slot_override: Optional[List[UUID]] = None,
) -> Union[orm_models.ConcurrencyLimit, None]:
    """
    Resets the active slots of the concurrency limit for `tag`.

    If `slot_override` is given, the active slots are replaced with the string
    form of those ids; otherwise they are cleared. The change is made only on
    the ORM object loaded into `session`; this function does not flush or
    commit.

    Args:
        db: the injected database interface
        session: a database session
        tag: the tag whose concurrency limit should be reset
        slot_override: ids to use as the new active slots instead of clearing them

    Returns:
        the updated concurrency limit, or `None` if no limit exists for `tag`
    """
    lookup = sa.select(db.ConcurrencyLimit).where(db.ConcurrencyLimit.tag == tag)
    existing = (await session.execute(lookup)).scalar()
    if not existing:
        return existing

    existing.active_slots = (
        [] if slot_override is None else [str(slot_id) for slot_id in slot_override]
    )
    return existing

115 

116 

@db_injector
async def filter_concurrency_limits_for_orchestration(
    db: PrefectDBInterface,
    session: AsyncSession,
    tags: List[str],
) -> Sequence[orm_models.ConcurrencyLimit]:
    """
    Reads the concurrency limits whose tag is in `tags`, locking the selected rows.

    The query uses `with_for_update()` (`SELECT ... FOR UPDATE` on backends that
    support row locks). This prevents simultaneous read race conditions from
    allowing the concurrency limit on these tags to be temporarily exceeded.
    Rows are ordered by tag, presumably so that locks are always taken in a
    consistent order.

    Args:
        db: the injected database interface
        session: a database session
        tags: the tags whose concurrency limits should be read

    Returns:
        the matching concurrency limits; an empty list, without querying, when
        `tags` is empty
    """
    if not tags:
        return []

    limit_model = db.ConcurrencyLimit
    locking_query = (
        sa.select(limit_model)
        .where(limit_model.tag.in_(tags))
        .order_by(limit_model.tag)
        .with_for_update()
    )
    rows = await session.execute(locking_query)
    return rows.scalars().all()

140 

141 

@db_injector
async def delete_concurrency_limit(
    db: PrefectDBInterface,
    session: AsyncSession,
    concurrency_limit_id: UUID,
) -> bool:
    """
    Deletes the concurrency limit with the given id.

    Args:
        db: the injected database interface
        session: a database session
        concurrency_limit_id: id of the concurrency limit to delete

    Returns:
        `True` if the DELETE reported at least one affected row, else `False`
    """
    limit_model = db.ConcurrencyLimit
    outcome = await session.execute(
        sa.delete(limit_model).where(limit_model.id == concurrency_limit_id)
    )
    deleted_rows = outcome.rowcount
    return deleted_rows > 0

154 

155 

@db_injector
async def delete_concurrency_limit_by_tag(
    db: PrefectDBInterface,
    session: AsyncSession,
    tag: str,
) -> bool:
    """
    Deletes the concurrency limit for the given tag.

    Args:
        db: the injected database interface
        session: a database session
        tag: the tag whose concurrency limit should be deleted

    Returns:
        `True` if the DELETE reported at least one affected row, else `False`
    """
    limit_model = db.ConcurrencyLimit
    outcome = await session.execute(
        sa.delete(limit_model).where(limit_model.tag == tag)
    )
    deleted_rows = outcome.rowcount
    return deleted_rows > 0

166 

167 

@db_injector
async def read_concurrency_limits(
    db: PrefectDBInterface,
    session: AsyncSession,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Sequence[orm_models.ConcurrencyLimit]:
    """
    Reads concurrency limits, ordered by tag.

    The rows are read without any lock. If used for orchestration, simultaneous
    read race conditions might allow the concurrency limit to be temporarily
    exceeded.

    Args:
        db: the injected database interface
        session: A database session
        limit: Query limit; no limit is applied when `None`
        offset: Query offset; no offset is applied when `None`

    Returns:
        List[orm_models.ConcurrencyLimit]: concurrency limits
    """
    statement = sa.select(db.ConcurrencyLimit).order_by(db.ConcurrencyLimit.tag)
    if offset is not None:
        statement = statement.offset(offset)
    if limit is not None:
        statement = statement.limit(limit)

    scalars = (await session.execute(statement)).scalars()
    return scalars.unique().all()