Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/events/storage/database.py: 20%

132 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence 1a

2 

3import pydantic 1a

4import sqlalchemy as sa 1a

5from sqlalchemy.ext.asyncio import AsyncSession 1a

6from sqlalchemy.orm import aliased 1a

7 

8from prefect.logging.loggers import get_logger 1a

9from prefect.server.database import ( 1a

10 PrefectDBInterface, 

11 db_injector, 

12 provide_database_interface, 

13) 

14from prefect.server.events.counting import Countable, TimeUnit 1a

15from prefect.server.events.filters import EventFilter, EventOrder 1a

16from prefect.server.events.schemas.events import EventCount, ReceivedEvent 1a

17from prefect.server.events.storage import ( 1a

18 INTERACTIVE_PAGE_SIZE, 

19 from_page_token, 

20 process_time_based_counts, 

21 to_page_token, 

22) 

23from prefect.server.utilities.database import get_dialect 1a

24from prefect.settings import PREFECT_API_DATABASE_CONNECTION_URL 1a

25 

# These imports are needed only for type annotations, so they are deferred to
# type-checking time to avoid runtime import cost and potential import cycles.
if TYPE_CHECKING:
    import logging

    from prefect.server.database.orm_models import ORMEvent

# Module-level logger namespaced to this module.
logger: "logging.Logger" = get_logger(__name__)

32 

33 

@db_injector
def build_distinct_queries(
    db: PrefectDBInterface,
    events_filter: EventFilter,
) -> list[sa.Column["ORMEvent"]]:
    """Return the Event columns to deduplicate on, derived from the filter.

    Currently only `resource_id` is supported, enabled when the filter's
    resource criteria request distinct results. Returns an empty list when no
    deduplication was requested.
    """
    field_names: list[str] = []
    if events_filter.resource and events_filter.resource.distinct:
        field_names.append("resource_id")
    return [getattr(db.Event, name) for name in field_names]

45 

46 

async def query_events(
    session: AsyncSession,
    filter: EventFilter,
    page_size: int = INTERACTIVE_PAGE_SIZE,
) -> tuple[list[ReceivedEvent], int, Optional[str]]:
    """Query the first page of events matching the filter.

    Args:
        session: a database session
        filter: filter criteria for events
        page_size: maximum number of events to return in this page

    Returns:
        The page of events, the total count of matching events, and a token
        for fetching the next page (None when there are no further pages).
    """
    assert isinstance(session, AsyncSession)
    total = await raw_count_events(session, filter)
    rows = await read_events(session, filter, limit=page_size, offset=0)
    events = [
        ReceivedEvent.model_validate(row, from_attributes=True) for row in rows
    ]
    next_token = to_page_token(filter, total, page_size, 0)
    return events, total, next_token

58 

59 

async def query_next_page(
    session: AsyncSession,
    page_token: str,
) -> tuple[list[ReceivedEvent], int, Optional[str]]:
    """Continue a paged event query from an opaque page token.

    Args:
        session: a database session
        page_token: the token returned by a previous query_events/query_next_page

    Returns:
        The next page of events, the total count captured when paging began,
        and a token for the following page (None when exhausted).
    """
    assert isinstance(session, AsyncSession)
    # The token carries the original filter plus paging state.
    filter, total, page_size, offset = from_page_token(page_token)
    rows = await read_events(session, filter, limit=page_size, offset=offset)
    events = [
        ReceivedEvent.model_validate(row, from_attributes=True) for row in rows
    ]
    next_token = to_page_token(filter, total, page_size, offset)
    return events, total, next_token

70 

71 

async def count_events(
    session: AsyncSession,
    filter: EventFilter,
    countable: Countable,
    time_unit: TimeUnit,
    time_interval: float,
) -> list[EventCount]:
    """Count matching events grouped by the given countable dimension.

    Args:
        session: a database session
        filter: filter criteria for events
        countable: the dimension to group counts by
        time_unit: the unit for time-based bucketing
        time_interval: the bucket width, in time_units

    Returns:
        One EventCount per bucket of the countable dimension.
    """
    # Reject bucketings that would produce too many/too few buckets up front.
    time_unit.validate_buckets(
        filter.occurred.since, filter.occurred.until, time_interval
    )

    query = countable.get_database_query(filter, time_unit, time_interval)
    result = await session.execute(query)

    counts = pydantic.TypeAdapter(list[EventCount]).validate_python(
        result.mappings().all()
    )

    # Time-based groupings need empty buckets filled in and labels normalized.
    if countable in (Countable.day, Countable.time):
        return process_time_based_counts(filter, time_unit, time_interval, counts)
    return counts

94 

95 

@db_injector
async def raw_count_events(
    db: PrefectDBInterface,
    session: AsyncSession,
    events_filter: EventFilter,
) -> int:
    """
    Count events from the database with the given filter.

    Only returns the count and does not return any addition metadata. For additional
    metadata, use `count_events`.

    Args:
        session: a database session
        events_filter: filter criteria for events

    Returns:
        The count of events in the database that match the filter criteria.
    """
    # When the filter requests distinct results, count distinct tuples of the
    # distinct columns rather than raw rows.
    distinct_columns = build_distinct_queries(events_filter)
    if distinct_columns:
        count_expression = sa.func.count(sa.distinct(*distinct_columns))
    else:
        count_expression = sa.func.count()

    query = (
        sa.select(count_expression)
        .select_from(db.Event)
        .where(sa.and_(*events_filter.build_where_clauses()))
    )
    result = await session.execute(query)
    return result.scalar() or 0

127 

128 

@db_injector
async def read_events(
    db: PrefectDBInterface,
    session: AsyncSession,
    events_filter: EventFilter,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Sequence["ORMEvent"]:
    """
    Read events from the Postgres database.

    Args:
        session: a Postgres events session.
        filter: filter criteria for events.
        limit: limit for the query.
        offset: offset for the query.

    Returns:
        A list of events ORM objects.
    """
    # Results are always ordered by the occurred timestamp; the filter decides
    # the direction.
    direction = sa.desc if events_filter.order == EventOrder.DESC else sa.asc
    where_clauses = events_filter.build_where_clauses()

    distinct_columns = build_distinct_queries(events_filter)
    if distinct_columns:
        # Deduplicate via a window function: number the rows within each
        # partition of the distinct columns, then keep only the first row of
        # each partition.
        row_number = (
            sa.func.row_number()
            .over(
                partition_by=distinct_columns,
                order_by=direction(db.Event.occurred),
            )
            .label("row_number")
        )
        ranked = (
            sa.select(db.Event, row_number)
            .where(sa.and_(*where_clauses))  # apply the same filters here
            .subquery()
        )
        # Alias the subquery so we can select Event entities from it.
        deduplicated = aliased(db.Event, ranked)
        select_events_query = (
            sa.select(deduplicated)
            .where(ranked.c.row_number == 1)
            .order_by(direction(ranked.c.occurred))
        )
    else:
        # No deduplication requested: a straightforward filtered, ordered query.
        select_events_query = (
            sa.select(db.Event)
            .where(sa.and_(*where_clauses))
            .order_by(direction(db.Event.occurred))
        )

    if limit is not None:
        # Never exceed the filter's logical limit, and never go negative.
        limit = max(0, min(limit, events_filter.logical_limit))
        select_events_query = select_events_query.limit(limit=limit)
    if offset is not None:
        select_events_query = select_events_query.offset(offset=offset)

    logger.debug("Running PostgreSQL query: %s", select_events_query)

    result = await session.execute(select_events_query)
    return result.scalars().unique().all()

198 

199 

async def write_events(session: AsyncSession, events: list[ReceivedEvent]) -> None:
    """
    Write events to the database, dispatching on the configured dialect.

    Args:
        session: a database session
        events: the events to insert
    """
    if not events:
        return

    dialect = get_dialect(PREFECT_API_DATABASE_CONNECTION_URL.value())
    if dialect.name == "postgresql":
        writer = _write_postgres_events
    else:
        writer = _write_sqlite_events
    await writer(session, events)

214 

215 

@db_injector
async def _write_sqlite_events(
    db: PrefectDBInterface, session: AsyncSession, events: list[ReceivedEvent]
) -> None:
    """
    Write events to the SQLite database.

    SQLite does not support the `RETURNING` clause with SQLAlchemy < 2, so we need to
    check for existing events before inserting them.

    Args:
        session: a SQLite events session
        events: the events to insert
    """
    for batch in _in_safe_batches(events):
        event_ids = {event.id for event in batch}
        result = await session.scalars(
            sa.select(db.Event.id).where(db.Event.id.in_(event_ids))
        )
        # Use a set for O(1) membership tests; a list would make the filter
        # below O(n) per event.
        existing_event_ids = set(result.all())
        events_to_insert = [
            event for event in batch if event.id not in existing_event_ids
        ]
        if not events_to_insert:
            # Every event in this batch was a duplicate; issuing an INSERT with
            # an empty VALUES list is an error, so skip this batch entirely.
            continue

        event_rows = [event.as_database_row() for event in events_to_insert]
        await session.execute(db.queries.insert(db.Event).values(event_rows))

        resource_rows: list[dict[str, Any]] = []
        for event in events_to_insert:
            resource_rows.extend(event.as_database_resource_rows())

        if not resource_rows:
            continue

        await session.execute(db.queries.insert(db.EventResource).values(resource_rows))

250 

251 

@db_injector
async def _write_postgres_events(
    db: PrefectDBInterface, session: AsyncSession, events: list[ReceivedEvent]
) -> None:
    """
    Write events to the Postgres database.

    Args:
        session: a Postgres events session
        events: the events to insert
    """
    for batch in _in_safe_batches(events):
        rows = [event.as_database_row() for event in batch]
        insert_statement = (
            db.queries.insert(db.Event)
            .on_conflict_do_nothing()
            .returning(db.Event.id)
            .values(rows)
        )
        result = await session.scalars(insert_statement)
        inserted_event_ids = set(result.all())

        resource_rows: list[dict[str, Any]] = []
        for event in batch:
            # If the event wasn't inserted, it was a duplicate, so its related
            # resources would have been inserted already — skip them.
            if event.id in inserted_event_ids:
                resource_rows.extend(event.as_database_resource_rows())

        if not resource_rows:
            continue

        await session.execute(db.queries.insert(db.EventResource).values(resource_rows))

286 

287 

def get_max_query_parameters() -> int:
    """Return the bound-parameter limit for the configured database dialect."""
    dialect = get_dialect(PREFECT_API_DATABASE_CONNECTION_URL.value())
    return 32_767 if dialect.name == "postgresql" else 999

294 

295 

# Events require a fixed number of parameters per event,...
def get_number_of_event_fields() -> int:
    """Return the column count of the events table (parameters per event row)."""
    # Use len() rather than calling __len__() directly.
    return len(provide_database_interface().Event.__table__.columns)

299 

300 

# ...plus a variable number of parameters per resource...
def get_number_of_resource_fields() -> int:
    """Return the column count of the event-resources table (parameters per resource row)."""
    # Use len() rather than calling __len__() directly.
    return len(provide_database_interface().EventResource.__table__.columns)

304 

305 

def _in_safe_batches(
    events: list[ReceivedEvent],
) -> Generator[list[ReceivedEvent], None, None]:
    """
    Yield batches of events sized so that a single INSERT for the batch (event
    rows plus their resource rows) stays under the database's bound-parameter
    limit.

    Args:
        events: the events to partition into batches

    Yields:
        Non-empty lists of events, in the original order.
    """
    batch: list[ReceivedEvent] = []
    parameters_used = 0
    max_query_parameters = get_max_query_parameters()
    number_of_event_fields = get_number_of_event_fields()
    number_of_resource_fields = get_number_of_resource_fields()

    for event in events:
        # Parameters this event will consume: one row of event fields plus one
        # row of resource fields per involved resource.
        these_parameters = number_of_event_fields + (
            len(event.involved_resources) * number_of_resource_fields
        )
        if parameters_used + these_parameters < max_query_parameters:
            batch.append(event)
            parameters_used += these_parameters
        else:
            # Guard: never yield an empty batch (possible when a single event's
            # parameters alone reach the limit).
            if batch:
                yield batch
            batch = [event]
            # BUG FIX: the fresh batch already contains `event`, so it must
            # start at this event's parameter count. Resetting to 0 (as before)
            # left the first event of every subsequent batch uncounted,
            # allowing batches to exceed the parameter limit.
            parameters_used = these_parameters

    if batch:
        yield batch