Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/database/query_components.py: 46%

249 statements  

coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1 import datetime 1a

2 from abc import ABC, abstractmethod 1a

3 from collections import defaultdict 1a

4 from collections.abc import Hashable, Iterable, Sequence 1a

5 from functools import cached_property 1a

6 from typing import ( 1a

7 Any,

8 ClassVar,

9 Literal,

10 NamedTuple,

11 Optional,

12 Union,

13 cast,

14 )

15 from uuid import UUID 1a

16

17 import sqlalchemy as sa 1a

18 from cachetools import Cache, TTLCache 1a

19 from jinja2 import Environment, PackageLoader, select_autoescape 1a

20 from sqlalchemy import orm 1a

21 from sqlalchemy.dialects import postgresql, sqlite 1a

22 from sqlalchemy.exc import NoResultFound 1a

23 from sqlalchemy.ext.asyncio import AsyncSession 1a

24 from sqlalchemy.sql.type_api import TypeEngine 1a

25 from typing_extensions import TypeVar 1a

26

27 from prefect.server import models, schemas 1a

28 from prefect.server.database import orm_models 1a

29 from prefect.server.database.dependencies import db_injector 1a

30 from prefect.server.database.interface import PrefectDBInterface 1a

31 from prefect.server.exceptions import FlowRunGraphTooLarge, ObjectNotFoundError 1a

32 from prefect.server.schemas.graph import Edge, Graph, GraphArtifact, GraphState, Node 1a

33 from prefect.server.schemas.states import StateType 1a

34 from prefect.server.utilities.database import UUID as UUIDTypeDecorator 1a

35 from prefect.server.utilities.database import Timestamp, bindparams_from_clause 1a

36 from prefect.types._datetime import DateTime 1a

37

38 T = TypeVar("T", infer_variance=True) 1a

39 

40 

41 class FlowRunGraphV2Node(NamedTuple): 1a

42 kind: Literal["flow-run", "task-run"] 1a

43 id: UUID 1a

44 label: str 1a

45 state_type: StateType 1a

46 start_time: DateTime 1a

47 end_time: Optional[DateTime] 1a

48 parent_ids: Optional[list[UUID]] 1a

49 child_ids: Optional[list[UUID]] 1a

50 encapsulating_ids: Optional[list[UUID]] 1a

51 

52 

53 ONE_HOUR = 60 * 60 1a

54 

55 

56 jinja_env: Environment = Environment( 1a

57 loader=PackageLoader("prefect.server.database", package_path="sql"), 

58 autoescape=select_autoescape(), 

59 trim_blocks=True, 

60) 

61 
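# --- Editor's note: a hedged sketch, not part of this module. It shows how the
# --- environment above is used further down -- each dialect class supplies a
# --- template path that is rendered to raw SQL and wrapped in sa.text(), as in
# --- get_scheduled_flow_runs_from_work_pool below:
#
#     template = jinja_env.get_template("postgres/get-runs-from-worker-queues.sql.jinja")
#     raw_query = sa.text(
#         template.render(
#             work_pool_ids=None, work_queue_ids=None,
#             respect_queue_priorities=False,
#             scheduled_before=None, scheduled_after=None,
#         )
#     )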

62 

63 class BaseQueryComponents(ABC): 1a

64 """ 

65 Abstract base class used to inject dialect-specific SQL operations into Prefect. 

66 """ 

67 

68 _configuration_cache: ClassVar[Cache[str, dict[str, Any]]] = TTLCache( 1a

69 maxsize=100, ttl=ONE_HOUR 

70 ) 

71 

72 def unique_key(self) -> tuple[Hashable, ...]: 1a

73 """ 

74 Returns a key used to determine whether to instantiate a new DB interface. 

75 """ 

76 return (self.__class__,) 1adbce

77 

78 # --- dialect-specific SqlAlchemy bindings 

79 

80 @abstractmethod 1a

81 def insert( 1a

82 self, obj: type[orm_models.Base] 

83 ) -> Union[postgresql.Insert, sqlite.Insert]: 

84 """dialect-specific insert statement""" 

85 

86 # --- dialect-specific JSON handling 

87 

88 @property 1a

89 @abstractmethod 1a

90 def uses_json_strings(self) -> bool: 1a

91 """specifies whether the configured dialect returns JSON as strings""" 

92 

93 @abstractmethod 1a

94 def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: 1a

95 """casts to JSON object if necessary""" 

96 

97 @abstractmethod 1a

98 def build_json_object( 1a

99 self, *args: Union[str, sa.ColumnElement[Any]] 

100 ) -> sa.ColumnElement[Any]: 

101 """builds a JSON object from sequential key-value pairs""" 

102 

103 @abstractmethod 1a

104 def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: 1a

105 """aggregates a JSON array""" 

106 

107 # --- dialect-optimized subqueries 

108 

109 @abstractmethod 1a

110 def make_timestamp_intervals( 1a

111 self, 

112 start_time: datetime.datetime, 

113 end_time: datetime.datetime, 

114 interval: datetime.timedelta, 

115 ) -> sa.Select[tuple[datetime.datetime, datetime.datetime]]: ... 

116 

117 @abstractmethod 1a

118 def set_state_id_on_inserted_flow_runs_statement( 1a

119 self, 

120 inserted_flow_run_ids: Sequence[UUID], 

121 insert_flow_run_states: Iterable[dict[str, Any]], 

122 ) -> sa.Update: ... 

123 

124 @db_injector 1a

125 def get_scheduled_flow_runs_from_work_queues( 1a

126 self, 

127 db: PrefectDBInterface, 

128 limit_per_queue: Optional[int] = None, 

129 work_queue_ids: Optional[list[UUID]] = None, 

130 scheduled_before: Optional[DateTime] = None, 

131 ) -> sa.Select[tuple[orm_models.FlowRun, UUID]]: 

132 """ 

133 Returns all scheduled runs in work queues, subject to provided parameters. 

134 

135 This query returns a `(orm_models.FlowRun, orm_models.WorkQueue.id)` pair; calling 

136 `result.all()` will return both; calling `result.scalars().unique().all()` 

137 will return only the flow runs, since `scalars()` keeps only the first element of each row (see the usage sketch after this method). 

138 """ 

139 

140 FlowRun, WorkQueue = db.FlowRun, db.WorkQueue 

141 

142 # get any work queues that have a concurrency limit, and compute available 

143 # slots as their limit less the number of running flows 

144 concurrency_queues = ( 

145 sa.select( 

146 WorkQueue.id, 

147 sa.func.greatest( 

148 0, 

149 WorkQueue.concurrency_limit - sa.func.count(FlowRun.id), 

150 ).label("available_slots"), 

151 ) 

152 .select_from(WorkQueue) 

153 .join( 

154 FlowRun, 

155 sa.and_( 

156 FlowRun.work_queue_name == WorkQueue.name, 

157 FlowRun.state_type.in_( 

158 (StateType.RUNNING, StateType.PENDING, StateType.CANCELLING) 

159 ), 

160 ), 

161 isouter=True, 

162 ) 

163 .where(WorkQueue.concurrency_limit.is_not(None)) 

164 .group_by(WorkQueue.id) 

165 .cte("concurrency_queues") 

166 ) 

167 

168 # use the available slots information to generate a join 

169 # for all scheduled runs 

170 scheduled_flow_runs, join_criteria = self._get_scheduled_flow_runs_join( 

171 work_queue_query=concurrency_queues, 

172 limit_per_queue=limit_per_queue, 

173 scheduled_before=scheduled_before, 

174 ) 

175 

176 # starting with the work queue table, join the limited queues to get the 

177 # concurrency information and the scheduled flow runs to load all applicable 

178 # runs. this will return all the scheduled runs allowed by the parameters 

179 query = ( 

180 # return a flow run and work queue id 

181 sa.select( 

182 orm.aliased(FlowRun, scheduled_flow_runs), WorkQueue.id.label("wq_id") 

183 ) 

184 .select_from(WorkQueue) 

185 .join( 

186 concurrency_queues, 

187 WorkQueue.id == concurrency_queues.c.id, 

188 isouter=True, 

189 ) 

190 .join(scheduled_flow_runs, join_criteria) 

191 .where( 

192 WorkQueue.is_paused.is_(False), 

193 WorkQueue.id.in_(work_queue_ids) if work_queue_ids else sa.true(), 

194 ) 

195 .order_by( 

196 scheduled_flow_runs.c.next_scheduled_start_time, 

197 scheduled_flow_runs.c.id, 

198 ) 

199 ) 

200 

201 return query 

202 
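# --- Editor's note: a hedged usage sketch for the method above, not part of this
# --- module. It assumes an AsyncSession `session` and that these components are
# --- reachable as `db.queries`; it illustrates the two result shapes described in
# --- the docstring:
#
#     query = db.queries.get_scheduled_flow_runs_from_work_queues(limit_per_queue=5)
#     result = await session.execute(query)
#     pairs = result.all()                    # [(FlowRun, work_queue_id), ...]
#     runs = result.scalars().unique().all()  # [FlowRun, ...] -- first column of each row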

203 @db_injector 1a

204 def _get_scheduled_flow_runs_join( 1a

205 self, 

206 db: PrefectDBInterface, 

207 work_queue_query: sa.CTE, 

208 limit_per_queue: Optional[int], 

209 scheduled_before: Optional[DateTime], 

210 ) -> tuple[sa.FromClause, sa.ColumnExpressionArgument[bool]]: 

211 """Used by self.get_scheduled_flow_runs_from_work_queue, allowing just 

212 this function to be changed on a per-dialect basis""" 

213 

214 FlowRun = db.FlowRun 

215 

216 # precompute for readability 

217 scheduled_before_clause = ( 

218 FlowRun.next_scheduled_start_time <= scheduled_before 

219 if scheduled_before is not None 

220 else sa.true() 

221 ) 

222 

223 # get scheduled flow runs with lateral join where the limit is the 

224 # available slots per queue 

225 scheduled_flow_runs = ( 

226 sa.select(FlowRun) 

227 .where( 

228 FlowRun.work_queue_name == db.WorkQueue.name, 

229 FlowRun.state_type == StateType.SCHEDULED, 

230 scheduled_before_clause, 

231 ) 

232 .with_for_update(skip_locked=True) 

233 # priority given to runs with earlier next_scheduled_start_time 

234 .order_by(FlowRun.next_scheduled_start_time) 

235 # if null, no limit will be applied 

236 .limit(sa.func.least(limit_per_queue, work_queue_query.c.available_slots)) 

237 .lateral("scheduled_flow_runs") 

238 ) 

239 

240 # Perform a cross-join 

241 join_criteria = sa.true() 

242 

243 return scheduled_flow_runs, join_criteria 

244 

245 # ------------------------------------------------------- 

246 # Workers 

247 # ------------------------------------------------------- 

248 

249 @property 1a

250 @abstractmethod 1a

251 def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: 1a

252 """ 

253 Template for the query to get scheduled flow runs from a work pool 

254 """ 

255 

256 @db_injector 1a

257 async def get_scheduled_flow_runs_from_work_pool( 1a

258 self, 

259 db: PrefectDBInterface, 

260 session: AsyncSession, 

261 limit: Optional[int] = None, 

262 worker_limit: Optional[int] = None, 

263 queue_limit: Optional[int] = None, 

264 work_pool_ids: Optional[list[UUID]] = None, 

265 work_queue_ids: Optional[list[UUID]] = None, 

266 scheduled_before: Optional[DateTime] = None, 

267 scheduled_after: Optional[DateTime] = None, 

268 respect_queue_priorities: bool = False, 

269 ) -> list[schemas.responses.WorkerFlowRunResponse]: 

270 template = jinja_env.get_template( 

271 self._get_scheduled_flow_runs_from_work_pool_template_path 

272 ) 

273 

274 raw_query = sa.text( 

275 template.render( 

276 work_pool_ids=work_pool_ids, 

277 work_queue_ids=work_queue_ids, 

278 respect_queue_priorities=respect_queue_priorities, 

279 scheduled_before=scheduled_before, 

280 scheduled_after=scheduled_after, 

281 ) 

282 ) 

283 

284 bindparams: list[sa.BindParameter[Any]] = [] 

285 

286 if scheduled_before: 

287 bindparams.append( 

288 sa.bindparam("scheduled_before", scheduled_before, type_=Timestamp) 

289 ) 

290 

291 if scheduled_after: 

292 bindparams.append( 

293 sa.bindparam("scheduled_after", scheduled_after, type_=Timestamp) 

294 ) 

295 

296 # if work pool IDs were provided, bind them 

297 if work_pool_ids: 

298 assert all(isinstance(i, UUID) for i in work_pool_ids) 

299 bindparams.append( 

300 sa.bindparam( 

301 "work_pool_ids", 

302 work_pool_ids, 

303 expanding=True, 

304 type_=UUIDTypeDecorator, 

305 ) 

306 ) 

307 

308 # if work queue IDs were provided, bind them 

309 if work_queue_ids: 

310 assert all(isinstance(i, UUID) for i in work_queue_ids) 

311 bindparams.append( 

312 sa.bindparam( 

313 "work_queue_ids", 

314 work_queue_ids, 

315 expanding=True, 

316 type_=UUIDTypeDecorator, 

317 ) 

318 ) 

319 

320 query = raw_query.bindparams( 

321 *bindparams, 

322 limit=1000 if limit is None else limit, 

323 worker_limit=1000 if worker_limit is None else worker_limit, 

324 queue_limit=1000 if queue_limit is None else queue_limit, 

325 ) 

326 

327 FlowRun = db.FlowRun 

328 orm_query = ( 

329 sa.select( 

330 sa.column("run_work_pool_id", UUIDTypeDecorator), 

331 sa.column("run_work_queue_id", UUIDTypeDecorator), 

332 FlowRun, 

333 ) 

334 .from_statement(query) 

335 # indicate that the state relationship isn't being loaded 

336 .options(orm.noload(FlowRun.state)) 

337 ) 

338 

339 result: sa.Result[ 

340 tuple[UUID, UUID, orm_models.FlowRun] 

341 ] = await session.execute(orm_query) 

342 

343 return [ 

344 schemas.responses.WorkerFlowRunResponse( 

345 work_pool_id=run_work_pool_id, 

346 work_queue_id=run_work_queue_id, 

347 flow_run=schemas.core.FlowRun.model_validate( 

348 flow_run, from_attributes=True 

349 ), 

350 ) 

351 for (run_work_pool_id, run_work_queue_id, flow_run) in result.t 

352 ] 

353 

354 @db_injector 1a

355 async def read_configuration_value( 1a

356 self, db: PrefectDBInterface, session: AsyncSession, key: str 

357 ) -> Optional[dict[str, Any]]: 

358 """ 

359 Read a configuration value by key. 

360 

361 Configuration values should not be changed at run time, so retrieved 

362 values are cached in memory. 

363 

364 The main use of configurations is encrypting blocks; caching the value here 

365 speeds up nested block document queries. 

366 """ 

367 Configuration = db.Configuration 1b

368 value = None 1b

369 try: 1b

370 value = self._configuration_cache[key] 1b

371 except KeyError: 1b

372 query = sa.select(Configuration).where(Configuration.key == key) 1b

373 if (configuration := await session.scalar(query)) is not None: 1bc

374 value = self._configuration_cache[key] = configuration.value 

375 return value 

376 

377 def clear_configuration_value_cache_for_key(self, key: str) -> None: 1a

378 """Removes a configuration key from the cache.""" 

379 self._configuration_cache.pop(key, None) 1c

380 
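# --- Editor's note: a hedged sketch, not part of this module, of the read-through
# --- caching behaviour documented above; the key name is illustrative only.
#
#     value = await db.queries.read_configuration_value(session, "ENCRYPTION_KEY")  # DB query, result cached
#     value = await db.queries.read_configuration_value(session, "ENCRYPTION_KEY")  # served from TTLCache (ttl=ONE_HOUR)
#     db.queries.clear_configuration_value_cache_for_key("ENCRYPTION_KEY")          # next read hits the DB again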

381 @cached_property 1a

382 def _flow_run_graph_v2_query(self): 1a

383 query = self._build_flow_run_graph_v2_query() 

384 param_names = set(bindparams_from_clause(query)) 

385 required = {"flow_run_id", "max_nodes", "since"} 

386 assert param_names >= required, ( 

387 "_build_flow_run_graph_v2_query result is missing required bind params: " 

388 f"{sorted(required - param_names)}" 

389 ) 

390 return query 

391 

392 @abstractmethod 1a

393 def _build_flow_run_graph_v2_query(self) -> sa.Select[FlowRunGraphV2Node]: 1a

394 """The flow run graph query, per database flavour 

395 

396 The query must accept the following bind parameters: 

397 

398 flow_run_id: UUID 

399 since: DateTime 

400 max_nodes: int 

401 

402 """ 

403 

404 @db_injector 1a

405 async def flow_run_graph_v2( 1a

406 self, 

407 db: PrefectDBInterface, 

408 session: AsyncSession, 

409 flow_run_id: UUID, 

410 since: DateTime, 

411 max_nodes: int, 

412 max_artifacts: int, 

413 ) -> Graph: 

414 """Returns the query that selects all of the nodes and edges for a flow run graph (version 2).""" 

415 FlowRun = db.FlowRun 

416 result = await session.execute( 

417 sa.select( 

418 sa.func.coalesce( 

419 FlowRun.start_time, FlowRun.expected_start_time, type_=Timestamp 

420 ), 

421 FlowRun.end_time, 

422 ).where(FlowRun.id == flow_run_id) 

423 ) 

424 try: 

425 start_time, end_time = result.t.one() 

426 except NoResultFound: 

427 raise ObjectNotFoundError(f"Flow run {flow_run_id} not found") 

428 

429 query = self._flow_run_graph_v2_query 

430 results = await session.execute( 

431 query, 

432 params=dict(flow_run_id=flow_run_id, since=since, max_nodes=max_nodes + 1), 

433 ) 

434 

435 graph_artifacts = await self._get_flow_run_graph_artifacts( 

436 db, session, flow_run_id, max_artifacts 

437 ) 

438 graph_states = await self._get_flow_run_graph_states(session, flow_run_id) 

439 

440 nodes: list[tuple[UUID, Node]] = [] 

441 root_node_ids: list[UUID] = [] 

442 

443 for row in results.t: 

444 if not row.parent_ids: 

445 root_node_ids.append(row.id) 

446 

447 nodes.append( 

448 ( 

449 row.id, 

450 Node( 

451 kind=row.kind, 

452 id=row.id, 

453 label=row.label, 

454 state_type=row.state_type, 

455 start_time=row.start_time, 

456 end_time=row.end_time, 

457 parents=[Edge(id=id) for id in row.parent_ids or []], 

458 children=[Edge(id=id) for id in row.child_ids or []], 

459 encapsulating=[ 

460 Edge(id=id) 

461 # ensure encapsulating_ids is deduplicated 

462 for id in dict.fromkeys(row.encapsulating_ids or ()) 

463 ], 

464 artifacts=graph_artifacts.get(row.id, []), 

465 ), 

466 ) 

467 ) 

468 

469 if len(nodes) > max_nodes: 

470 raise FlowRunGraphTooLarge( 

471 f"The graph of flow run {flow_run_id} has more than " 

472 f"{max_nodes} nodes." 

473 ) 

474 

475 return Graph( 

476 start_time=start_time, 

477 end_time=end_time, 

478 root_node_ids=root_node_ids, 

479 nodes=nodes, 

480 artifacts=graph_artifacts.get(None, []), 

481 states=graph_states, 

482 ) 

483 

484 async def _get_flow_run_graph_artifacts( 1ac

485 self, 

486 db: PrefectDBInterface, 

487 session: AsyncSession, 

488 flow_run_id: UUID, 

489 max_artifacts: int, 

490 ) -> dict[Optional[UUID], list[GraphArtifact]]: 

491 """Get the artifacts for a flow run grouped by task run id. 

492 

493 Does not recurse into subflows. 

494 Artifacts for the flow run without a task run id are grouped under None. 

495 """ 

496 Artifact, ArtifactCollection = db.Artifact, db.ArtifactCollection 

497 

498 query = ( 

499 sa.select(Artifact, ArtifactCollection.id.label("latest_in_collection_id")) 

500 .where(Artifact.flow_run_id == flow_run_id, Artifact.type != "result") 

501 .join( 

502 ArtifactCollection, 

503 onclause=sa.and_( 

504 ArtifactCollection.key == Artifact.key, 

505 ArtifactCollection.latest_id == Artifact.id, 

506 ), 

507 isouter=True, 

508 ) 

509 .order_by(Artifact.created.asc()) 

510 .limit(max_artifacts) 

511 ) 

512 

513 results = await session.execute(query) 

514 

515 artifacts_by_task: dict[Optional[UUID], list[GraphArtifact]] = defaultdict(list) 

516 for artifact, latest_in_collection_id in results.t: 

517 artifacts_by_task[artifact.task_run_id].append( 

518 GraphArtifact( 

519 id=artifact.id, 

520 created=artifact.created, 

521 key=artifact.key, 

522 type=artifact.type, 

523 # We're only using the data field for progress artifacts for now 

524 data=artifact.data if artifact.type == "progress" else None, 

525 is_latest=artifact.key is None 

526 or latest_in_collection_id is not None, # pyright: ignore[reportUnnecessaryComparison] 

527 ) 

528 ) 

529 

530 return dict(artifacts_by_task) 

531 

532 async def _get_flow_run_graph_states( 1a

533 self, session: AsyncSession, flow_run_id: UUID 

534 ) -> list[GraphState]: 

535 """Get the flow run states for a flow run graph.""" 

536 states = await models.flow_run_states.read_flow_run_states(session, flow_run_id) 

537 return [ 

538 GraphState.model_validate(state, from_attributes=True) for state in states 

539 ] 

540 

541 

542 class AsyncPostgresQueryComponents(BaseQueryComponents): 1a

543 # --- Postgres-specific SqlAlchemy bindings 

544 

545 def insert(self, obj: type[orm_models.Base]) -> postgresql.Insert: 1a

546 return postgresql.insert(obj) 

547 

548 # --- Postgres-specific JSON handling 

549 

550 @property 1a

551 def uses_json_strings(self) -> bool: 1a

552 return False 

553 

554 def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: 1a

555 return json_obj 

556 

557 def build_json_object( 1a

558 self, *args: Union[str, sa.ColumnElement[Any]] 

559 ) -> sa.ColumnElement[Any]: 

560 return sa.func.jsonb_build_object(*args) 

561 

562 def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: 1a

563 return sa.func.jsonb_agg(json_array) 

564 

565 # --- Postgres-optimized subqueries 

566 

567 def make_timestamp_intervals( 1a

568 self, 

569 start_time: datetime.datetime, 

570 end_time: datetime.datetime, 

571 interval: datetime.timedelta, 

572 ) -> sa.Select[tuple[datetime.datetime, datetime.datetime]]: 

573 dt = sa.func.generate_series( 

574 start_time, end_time, interval, type_=Timestamp() 

575 ).column_valued("dt") 

576 return ( 

577 sa.select( 

578 dt.label("interval_start"), 

579 sa.type_coerce( 

580 dt + sa.bindparam("interval", interval, type_=sa.Interval()), 

581 type_=Timestamp(), 

582 ).label("interval_end"), 

583 ) 

584 .where(dt < end_time) 

585 .limit(500) # grab at most 500 intervals 

586 ) 

587 

588 @db_injector 1a

589 def set_state_id_on_inserted_flow_runs_statement( 1a

590 self, 

591 db: PrefectDBInterface, 

592 inserted_flow_run_ids: Sequence[UUID], 

593 insert_flow_run_states: Iterable[dict[str, Any]], 

594 ) -> sa.Update: 

595 """Given a list of flow run ids and associated states, set the state_id 

596 to the appropriate state for all flow runs""" 

597 # postgres supports `UPDATE ... FROM` syntax 

598 FlowRun, FlowRunState = db.FlowRun, db.FlowRunState 

599 stmt = ( 

600 sa.update(FlowRun) 

601 .where( 

602 FlowRun.id.in_(inserted_flow_run_ids), 

603 FlowRunState.flow_run_id == FlowRun.id, 

604 FlowRunState.id.in_([r["id"] for r in insert_flow_run_states]), 

605 ) 

606 .values(state_id=FlowRunState.id) 

607 # no need to synchronize as these flow runs are entirely new 

608 .execution_options(synchronize_session=False) 

609 ) 

610 return stmt 

611 

612 @property 1a

613 def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: 1a

614 """ 

615 Template for the query to get scheduled flow runs from a work pool 

616 """ 

617 return "postgres/get-runs-from-worker-queues.sql.jinja" 

618 

619 @db_injector 1a

620 def _build_flow_run_graph_v2_query( 1a

621 self, db: PrefectDBInterface 

622 ) -> sa.Select[FlowRunGraphV2Node]: 

623 """Postgresql version of the V2 FlowRun graph data query 

624 

625 This SQLA query is built just once and then cached per DB interface 

626 

627 """ 

628 # the parameters this query takes as inputs 

629 param_flow_run_id = sa.bindparam("flow_run_id", type_=UUIDTypeDecorator) 

630 param_since = sa.bindparam("since", type_=Timestamp) 

631 param_max_nodes = sa.bindparam("max_nodes", type_=sa.Integer) 

632 

633 Flow, FlowRun, TaskRun = db.Flow, db.FlowRun, db.TaskRun 

634 input = sa.func.jsonb_each(TaskRun.task_inputs).table_valued( 

635 "key", "value", name="input" 

636 ) 

637 argument = ( 

638 sa.func.jsonb_array_elements(input.c.value, type_=postgresql.JSONB()) 

639 .table_valued(sa.column("value", postgresql.JSONB())) 

640 .render_derived(name="argument") 

641 ) 

642 edges = ( 

643 sa.select( 

644 sa.case((FlowRun.id.is_not(None), "flow-run"), else_="task-run").label( 

645 "kind" 

646 ), 

647 sa.func.coalesce(FlowRun.id, TaskRun.id).label("id"), 

648 sa.func.coalesce(Flow.name + " / " + FlowRun.name, TaskRun.name).label( 

649 "label" 

650 ), 

651 sa.func.coalesce(FlowRun.state_type, TaskRun.state_type).label( 

652 "state_type" 

653 ), 

654 sa.func.coalesce( 

655 FlowRun.start_time, 

656 FlowRun.expected_start_time, 

657 TaskRun.start_time, 

658 TaskRun.expected_start_time, 

659 ).label("start_time"), 

660 sa.func.coalesce( 

661 FlowRun.end_time, 

662 TaskRun.end_time, 

663 sa.case( 

664 ( 

665 TaskRun.state_type == StateType.COMPLETED, 

666 TaskRun.expected_start_time, 

667 ), 

668 else_=sa.null(), 

669 ), 

670 ).label("end_time"), 

671 sa.cast(argument.c.value["id"].astext, type_=UUIDTypeDecorator).label( 

672 "parent" 

673 ), 

674 (input.c.key == "__parents__").label("has_encapsulating_task"), 

675 ) 

676 .join_from(TaskRun, input, onclause=sa.true(), isouter=True) 

677 .join(argument, onclause=sa.true(), isouter=True) 

678 .join( 

679 FlowRun, 

680 isouter=True, 

681 onclause=FlowRun.parent_task_run_id == TaskRun.id, 

682 ) 

683 .join(Flow, isouter=True, onclause=Flow.id == FlowRun.flow_id) 

684 .where( 

685 TaskRun.flow_run_id == param_flow_run_id, 

686 TaskRun.state_type != StateType.PENDING, 

687 sa.func.coalesce( 

688 FlowRun.start_time, 

689 FlowRun.expected_start_time, 

690 TaskRun.start_time, 

691 TaskRun.expected_start_time, 

692 ).is_not(None), 

693 ) 

694 # -- the order here is important to speed up building the two sets of 

695 # -- edges in the with_parents and with_children CTEs below 

696 .order_by(sa.func.coalesce(FlowRun.id, TaskRun.id)) 

697 ).cte("edges") 

698 children, parents = edges.alias("children"), edges.alias("parents") 

699 with_encapsulating = ( 

700 sa.select( 

701 children.c.id, 

702 sa.func.array_agg( 

703 postgresql.aggregate_order_by(parents.c.id, parents.c.start_time) 

704 ).label("encapsulating_ids"), 

705 ) 

706 .join(parents, onclause=parents.c.id == children.c.parent) 

707 .where(children.c.has_encapsulating_task.is_(True)) 

708 .group_by(children.c.id) 

709 ).cte("with_encapsulating") 

710 with_parents = ( 

711 sa.select( 

712 children.c.id, 

713 sa.func.array_agg( 

714 postgresql.aggregate_order_by(parents.c.id, parents.c.start_time) 

715 ).label("parent_ids"), 

716 ) 

717 .join(parents, onclause=parents.c.id == children.c.parent) 

718 .where(children.c.has_encapsulating_task.is_distinct_from(True)) 

719 .group_by(children.c.id) 

720 .cte("with_parents") 

721 ) 

722 with_children = ( 

723 sa.select( 

724 parents.c.id, 

725 sa.func.array_agg( 

726 postgresql.aggregate_order_by(children.c.id, children.c.start_time) 

727 ).label("child_ids"), 

728 ) 

729 .join(children, onclause=children.c.parent == parents.c.id) 

730 .where(children.c.has_encapsulating_task.is_distinct_from(True)) 

731 .group_by(parents.c.id) 

732 .cte("with_children") 

733 ) 

734 

735 graph = ( 

736 sa.select( 

737 edges.c.kind, 

738 edges.c.id, 

739 edges.c.label, 

740 edges.c.state_type, 

741 edges.c.start_time, 

742 edges.c.end_time, 

743 with_parents.c.parent_ids, 

744 with_children.c.child_ids, 

745 with_encapsulating.c.encapsulating_ids, 

746 ) 

747 .distinct(edges.c.id) 

748 .join(with_parents, isouter=True, onclause=with_parents.c.id == edges.c.id) 

749 .join( 

750 with_children, isouter=True, onclause=with_children.c.id == edges.c.id 

751 ) 

752 .join( 

753 with_encapsulating, 

754 isouter=True, 

755 onclause=with_encapsulating.c.id == edges.c.id, 

756 ) 

757 .cte("nodes") 

758 ) 

759 query = ( 

760 sa.select( 

761 graph.c.kind, 

762 graph.c.id, 

763 graph.c.label, 

764 graph.c.state_type, 

765 graph.c.start_time, 

766 graph.c.end_time, 

767 graph.c.parent_ids, 

768 graph.c.child_ids, 

769 graph.c.encapsulating_ids, 

770 ) 

771 .where(sa.or_(graph.c.end_time.is_(None), graph.c.end_time >= param_since)) 

772 .order_by(graph.c.start_time, graph.c.end_time) 

773 .limit(param_max_nodes) 

774 ) 

775 return cast(sa.Select[FlowRunGraphV2Node], query) 

776 

777 

778 class UUIDList(sa.TypeDecorator[list[UUID]]): 1a

779 """Map a JSON list of strings back to a list of UUIDs at the result loading stage""" 

780 

781 impl: Union[TypeEngine[Any], type[TypeEngine[Any]]] = sa.JSON() 1a

782 cache_ok: Optional[bool] = True 1a

783 

784 def process_result_value( 1a

785 self, value: Optional[list[Union[str, UUID]]], dialect: sa.Dialect 

786 ) -> Optional[list[UUID]]: 

787 if value is None: 

788 return value 

789 return [v if isinstance(v, UUID) else UUID(v) for v in value] 

790 
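# --- Editor's note: a hedged sketch, not part of this module, of how UUIDList is
# --- applied in the SQLite graph query below -- sa.type_coerce makes SQLAlchemy run
# --- process_result_value on fetch, turning a JSON array of strings into list[UUID]:
#
#     stmt = sa.select(sa.type_coerce(graph.c.parent_ids, UUIDList))
#     parent_ids = (await session.execute(stmt)).scalar()   # list[UUID] or None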

791 

792 class AioSqliteQueryComponents(BaseQueryComponents): 1a

793 # --- Sqlite-specific SqlAlchemy bindings 

794 

795 def insert(self, obj: type[orm_models.Base]) -> sqlite.Insert: 1a

796 return sqlite.insert(obj) 1b

797 

798 # --- Sqlite-specific JSON handling 

799 

800 @property 1a

801 def uses_json_strings(self) -> bool: 1a

802 return True 

803 

804 def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: 1a

805 return sa.func.json(json_obj) 

806 

807 def build_json_object( 1a

808 self, *args: Union[str, sa.ColumnElement[Any]] 

809 ) -> sa.ColumnElement[Any]: 

810 return sa.func.json_object(*args) 

811 

812 def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: 1a

813 return sa.func.json_group_array(json_array) 

814 

815 # --- Sqlite-optimized subqueries 

816 

817 def make_timestamp_intervals( 1a

818 self, 

819 start_time: datetime.datetime, 

820 end_time: datetime.datetime, 

821 interval: datetime.timedelta, 

822 ) -> sa.Select[tuple[datetime.datetime, datetime.datetime]]: 

823 start = sa.bindparam("start_time", start_time, Timestamp) 

824 # subtract interval because recursive where clauses are effectively evaluated on a t-1 lag 

825 stop = sa.bindparam("end_time", end_time - interval, Timestamp) 

826 step = sa.bindparam("interval", interval, sa.Interval) 

827 

828 one = sa.literal(1, literal_execute=True) 

829 

830 # recursive CTE to mimic the behavior of `generate_series`, which is 

831 # only available as a compiled extension 

832 base_case = sa.select( 

833 start.label("interval_start"), 

834 sa.func.date_add(start, step).label("interval_end"), 

835 one.label("counter"), 

836 ).cte(recursive=True) 

837 recursive_case = sa.select( 

838 base_case.c.interval_end, 

839 sa.func.date_add(base_case.c.interval_end, step), 

840 base_case.c.counter + one, 

841 ).where( 

842 base_case.c.interval_start < stop, 

843 # don't compute more than 500 intervals 

844 base_case.c.counter < 500, 

845 ) 

846 cte = base_case.union_all(recursive_case) 

847 

848 return sa.select(cte.c.interval_start, cte.c.interval_end) 

849 
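# --- Editor's note: a hedged illustration, not part of this module, of what the
# --- recursive CTE above yields for a three-hour window with a one-hour interval --
# --- the same (interval_start, interval_end) rows Postgres gets from generate_series:
#
#     2025-01-01 00:00  ->  2025-01-01 01:00
#     2025-01-01 01:00  ->  2025-01-01 02:00
#     2025-01-01 02:00  ->  2025-01-01 03:00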

850 @db_injector 1a

851 def set_state_id_on_inserted_flow_runs_statement( 1a

852 self, 

853 db: PrefectDBInterface, 

854 inserted_flow_run_ids: Sequence[UUID], 

855 insert_flow_run_states: Iterable[dict[str, Any]], 

856 ) -> sa.Update: 

857 """Given a list of flow run ids and associated states, set the state_id 

858 to the appropriate state for all flow runs""" 

859 fr_model, frs_model = db.FlowRun, db.FlowRunState 

860 # sqlite requires a correlated subquery to update from another table 

861 subquery = ( 

862 sa.select(frs_model.id) 

863 .where( 

864 frs_model.flow_run_id == fr_model.id, 

865 frs_model.id.in_([r["id"] for r in insert_flow_run_states]), 

866 ) 

867 .limit(1) 

868 .scalar_subquery() 

869 ) 

870 stmt = ( 

871 sa.update(fr_model) 

872 .where( 

873 fr_model.id.in_(inserted_flow_run_ids), 

874 ) 

875 .values(state_id=subquery) 

876 # no need to synchronize as these flow runs are entirely new 

877 .execution_options(synchronize_session=False) 

878 ) 

879 return stmt 

880 

881 @db_injector 1a

882 def _get_scheduled_flow_runs_join( 1a

883 self, 

884 db: PrefectDBInterface, 

885 work_queue_query: sa.CTE, 

886 limit_per_queue: Optional[int], 

887 scheduled_before: Optional[DateTime], 

888 ) -> tuple[sa.FromClause, sa.ColumnExpressionArgument[bool]]: 

889 # precompute for readability 

890 FlowRun = db.FlowRun 

891 

892 scheduled_before_clause = ( 

893 FlowRun.next_scheduled_start_time <= scheduled_before 

894 if scheduled_before is not None 

895 else sa.true() 

896 ) 

897 

898 # select scheduled flow runs, ordered by scheduled start time per queue 

899 scheduled_flow_runs = ( 

900 sa.select( 

901 ( 

902 sa.func.row_number() 

903 .over( 

904 partition_by=[FlowRun.work_queue_name], 

905 order_by=FlowRun.next_scheduled_start_time, 

906 ) 

907 .label("rank") 

908 ), 

909 FlowRun, 

910 ) 

911 .where(FlowRun.state_type == StateType.SCHEDULED, scheduled_before_clause) 

912 .subquery("scheduled_flow_runs") 

913 ) 

914 

915 # sqlite short-circuits the `min` comparison on nulls, so we use `999999` 

916 # as an "unlimited" limit. 

917 limit = 999999 if limit_per_queue is None else limit_per_queue 

918 

919 # in the join, only keep flow runs whose rank is less than or equal to the 

920 # available slots for each queue 

921 join_criteria = sa.and_( 

922 scheduled_flow_runs.c.work_queue_name == db.WorkQueue.name, 

923 scheduled_flow_runs.c.rank 

924 <= sa.func.min( 

925 sa.func.coalesce(work_queue_query.c.available_slots, limit), limit 

926 ), 

927 ) 

928 return scheduled_flow_runs, join_criteria 

929 

930 # ------------------------------------------------------- 

931 # Workers 

932 # ------------------------------------------------------- 

933 

934 @property 1a

935 def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: 1a

936 """ 

937 Template for the query to get scheduled flow runs from a work pool 

938 """ 

939 return "sqlite/get-runs-from-worker-queues.sql.jinja" 

940 

941 @db_injector 1a

942 def _build_flow_run_graph_v2_query( 1a

943 self, db: PrefectDBInterface 

944 ) -> sa.Select[FlowRunGraphV2Node]: 

945 """Postgresql version of the V2 FlowRun graph data query 

946 

947 This SQLA query is built just once and then cached per DB interface 

948 

949 """ 

950 # the parameters this query takes as inputs 

951 param_flow_run_id = sa.bindparam("flow_run_id", type_=UUIDTypeDecorator) 

952 param_since = sa.bindparam("since", type_=Timestamp) 

953 param_max_nodes = sa.bindparam("max_nodes", type_=sa.Integer) 

954 

955 Flow, FlowRun, TaskRun = db.Flow, db.FlowRun, db.TaskRun 

956 input = sa.func.json_each(TaskRun.task_inputs).table_valued( 

957 "key", "value", name="input" 

958 ) 

959 argument = sa.func.json_each( 

960 input.c.value, type_=postgresql.JSON() 

961 ).table_valued("key", sa.column("value", postgresql.JSON()), name="argument") 

962 edges = ( 

963 sa.select( 

964 sa.case((FlowRun.id.is_not(None), "flow-run"), else_="task-run").label( 

965 "kind" 

966 ), 

967 sa.func.coalesce(FlowRun.id, TaskRun.id).label("id"), 

968 sa.func.coalesce(Flow.name + " / " + FlowRun.name, TaskRun.name).label( 

969 "label" 

970 ), 

971 sa.func.coalesce(FlowRun.state_type, TaskRun.state_type).label( 

972 "state_type" 

973 ), 

974 sa.func.coalesce( 

975 FlowRun.start_time, 

976 FlowRun.expected_start_time, 

977 TaskRun.start_time, 

978 TaskRun.expected_start_time, 

979 ).label("start_time"), 

980 sa.func.coalesce( 

981 FlowRun.end_time, 

982 TaskRun.end_time, 

983 sa.case( 

984 ( 

985 TaskRun.state_type == StateType.COMPLETED, 

986 TaskRun.expected_start_time, 

987 ), 

988 else_=sa.null(), 

989 ), 

990 ).label("end_time"), 

991 argument.c.value["id"].astext.label("parent"), 

992 (input.c.key == "__parents__").label("has_encapsulating_task"), 

993 ) 

994 .join_from(TaskRun, input, onclause=sa.true(), isouter=True) 

995 .join(argument, onclause=sa.true(), isouter=True) 

996 .join( 

997 FlowRun, 

998 isouter=True, 

999 onclause=FlowRun.parent_task_run_id == TaskRun.id, 

1000 ) 

1001 .join(Flow, isouter=True, onclause=Flow.id == FlowRun.flow_id) 

1002 .where( 

1003 TaskRun.flow_run_id == param_flow_run_id, 

1004 TaskRun.state_type != StateType.PENDING, 

1005 sa.func.coalesce( 

1006 FlowRun.start_time, 

1007 FlowRun.expected_start_time, 

1008 TaskRun.start_time, 

1009 TaskRun.expected_start_time, 

1010 ).is_not(None), 

1011 ) 

1012 # -- the order here is important to speed up building the two sets of 

1013 # -- edges in the with_parents and with_children CTEs below 

1014 .order_by(sa.func.coalesce(FlowRun.id, TaskRun.id)) 

1015 ).cte("edges") 

1016 children, parents = edges.alias("children"), edges.alias("parents") 

1017 with_encapsulating = ( 

1018 sa.select( 

1019 children.c.id, 

1020 sa.func.json_group_array(parents.c.id).label("encapsulating_ids"), 

1021 ) 

1022 .join(parents, onclause=parents.c.id == children.c.parent) 

1023 .where(children.c.has_encapsulating_task.is_(True)) 

1024 .group_by(children.c.id) 

1025 ).cte("with_encapsulating") 

1026 with_parents = ( 

1027 sa.select( 

1028 children.c.id, 

1029 sa.func.json_group_array(parents.c.id).label("parent_ids"), 

1030 ) 

1031 .join(parents, onclause=parents.c.id == children.c.parent) 

1032 .where(children.c.has_encapsulating_task.is_distinct_from(True)) 

1033 .group_by(children.c.id) 

1034 .cte("with_parents") 

1035 ) 

1036 with_children = ( 

1037 sa.select( 

1038 parents.c.id, sa.func.json_group_array(children.c.id).label("child_ids") 

1039 ) 

1040 .join(children, onclause=children.c.parent == parents.c.id) 

1041 .where(children.c.has_encapsulating_task.is_distinct_from(True)) 

1042 .group_by(parents.c.id) 

1043 .cte("with_children") 

1044 ) 

1045 

1046 graph = ( 

1047 sa.select( 

1048 edges.c.kind, 

1049 edges.c.id, 

1050 edges.c.label, 

1051 edges.c.state_type, 

1052 edges.c.start_time, 

1053 edges.c.end_time, 

1054 with_parents.c.parent_ids, 

1055 with_children.c.child_ids, 

1056 with_encapsulating.c.encapsulating_ids, 

1057 ) 

1058 .distinct() 

1059 .join(with_parents, isouter=True, onclause=with_parents.c.id == edges.c.id) 

1060 .join( 

1061 with_children, isouter=True, onclause=with_children.c.id == edges.c.id 

1062 ) 

1063 .join( 

1064 with_encapsulating, 

1065 isouter=True, 

1066 onclause=with_encapsulating.c.id == edges.c.id, 

1067 ) 

1068 .cte("nodes") 

1069 ) 

1070 

1071 query = ( 

1072 sa.select( 

1073 graph.c.kind, 

1074 graph.c.id, 

1075 graph.c.label, 

1076 graph.c.state_type, 

1077 graph.c.start_time, 

1078 graph.c.end_time, 

1079 sa.type_coerce(graph.c.parent_ids, UUIDList), 

1080 sa.type_coerce(graph.c.child_ids, UUIDList), 

1081 sa.type_coerce(graph.c.encapsulating_ids, UUIDList), 

1082 ) 

1083 .where(sa.or_(graph.c.end_time.is_(None), graph.c.end_time >= param_since)) 

1084 .order_by(graph.c.start_time, graph.c.end_time) 

1085 .limit(param_max_nodes) 

1086 ) 

1087 return cast(sa.Select[FlowRunGraphV2Node], query)