Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/database/query_components.py: 54%
249 statements
coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1import datetime 1a
2from abc import ABC, abstractmethod 1a
3from collections import defaultdict 1a
4from collections.abc import Hashable, Iterable, Sequence 1a
5from functools import cached_property 1a
6from typing import ( 1a
7 Any,
8 ClassVar,
9 Literal,
10 NamedTuple,
11 Optional,
12 Union,
13 cast,
14)
15from uuid import UUID 1a
17import sqlalchemy as sa 1a
18from cachetools import Cache, TTLCache 1a
19from jinja2 import Environment, PackageLoader, select_autoescape 1a
20from sqlalchemy import orm 1a
21from sqlalchemy.dialects import postgresql, sqlite 1a
22from sqlalchemy.exc import NoResultFound 1a
23from sqlalchemy.ext.asyncio import AsyncSession 1a
24from sqlalchemy.sql.type_api import TypeEngine 1a
25from typing_extensions import TypeVar 1a
27from prefect.server import models, schemas 1a
28from prefect.server.database import orm_models 1a
29from prefect.server.database.dependencies import db_injector 1a
30from prefect.server.database.interface import PrefectDBInterface 1a
31from prefect.server.exceptions import FlowRunGraphTooLarge, ObjectNotFoundError 1a
32from prefect.server.schemas.graph import Edge, Graph, GraphArtifact, GraphState, Node 1a
33from prefect.server.schemas.states import StateType 1a
34from prefect.server.utilities.database import UUID as UUIDTypeDecorator 1a
35from prefect.server.utilities.database import Timestamp, bindparams_from_clause 1a
36from prefect.types._datetime import DateTime 1a
38T = TypeVar("T", infer_variance=True) 1a
41class FlowRunGraphV2Node(NamedTuple): 1a
42 kind: Literal["flow-run", "task-run"] 1a
43 id: UUID 1a
44 label: str 1a
45 state_type: StateType 1a
46 start_time: DateTime 1a
47 end_time: Optional[DateTime] 1a
48 parent_ids: Optional[list[UUID]] 1a
49 child_ids: Optional[list[UUID]] 1a
50 encapsulating_ids: Optional[list[UUID]] 1a
53ONE_HOUR = 60 * 60 1a
56jinja_env: Environment = Environment( 1a
57 loader=PackageLoader("prefect.server.database", package_path="sql"),
58 autoescape=select_autoescape(),
59 trim_blocks=True,
60)
63class BaseQueryComponents(ABC): 1a
64 """
65 Abstract base class used to inject dialect-specific SQL operations into Prefect.
66 """
68 _configuration_cache: ClassVar[Cache[str, dict[str, Any]]] = TTLCache( 1a
69 maxsize=100, ttl=ONE_HOUR
70 )
72 def unique_key(self) -> tuple[Hashable, ...]: 1a
73 """
74 Returns a key used to determine whether to instantiate a new DB interface.
75 """
76 return (self.__class__,) 1aicejbkflghd
78 # --- dialect-specific SqlAlchemy bindings
80 @abstractmethod 1a
81 def insert( 1a
82 self, obj: type[orm_models.Base]
83 ) -> Union[postgresql.Insert, sqlite.Insert]:
84 """dialect-specific insert statement"""
86 # --- dialect-specific JSON handling
88 @property 1a
89 @abstractmethod 1a
90 def uses_json_strings(self) -> bool: 1a
91 """specifies whether the configured dialect returns JSON as strings"""
93 @abstractmethod 1a
94 def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: 1a
95 """casts to JSON object if necessary"""
97 @abstractmethod 1a
98 def build_json_object( 1a
99 self, *args: Union[str, sa.ColumnElement[Any]]
100 ) -> sa.ColumnElement[Any]:
101 """builds a JSON object from sequential key-value pairs"""
103 @abstractmethod 1a
104 def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: 1a
105 """aggregates a JSON array"""
107 # --- dialect-optimized subqueries
109 @abstractmethod 1a
110 def make_timestamp_intervals(  [110 ↛ exit: didn't return from function 'make_timestamp_intervals'] 1a
111 self,
112 start_time: datetime.datetime,
113 end_time: datetime.datetime,
114 interval: datetime.timedelta,
115 ) -> sa.Select[tuple[datetime.datetime, datetime.datetime]]: ...
117 @abstractmethod 1a
118 def set_state_id_on_inserted_flow_runs_statement(  [118 ↛ exit: didn't return from function 'set_state_id_on_inserted_flow_runs_statement'] 1a
119 self,
120 inserted_flow_run_ids: Sequence[UUID],
121 insert_flow_run_states: Iterable[dict[str, Any]],
122 ) -> sa.Update: ...
124 @db_injector 1a
125 def get_scheduled_flow_runs_from_work_queues( 1a
126 self,
127 db: PrefectDBInterface,
128 limit_per_queue: Optional[int] = None,
129 work_queue_ids: Optional[list[UUID]] = None,
130 scheduled_before: Optional[DateTime] = None,
131 ) -> sa.Select[tuple[orm_models.FlowRun, UUID]]:
132 """
133 Returns all scheduled runs in work queues, subject to provided parameters.
135 This query returns a `(orm_models.FlowRun, orm_models.WorkQueue.id)` pair; calling
136 `result.all()` will return both; calling `result.scalars().unique().all()`
137 will return only the flow run because it grabs the first result.
138 """
140 FlowRun, WorkQueue = db.FlowRun, db.WorkQueue
142 # get any work queues that have a concurrency limit, and compute available
143 # slots as their limit less the number of running flows
144 concurrency_queues = (
145 sa.select(
146 WorkQueue.id,
147 sa.func.greatest(
148 0,
149 WorkQueue.concurrency_limit - sa.func.count(FlowRun.id),
150 ).label("available_slots"),
151 )
152 .select_from(WorkQueue)
153 .join(
154 FlowRun,
155 sa.and_(
156 FlowRun.work_queue_name == WorkQueue.name,
157 FlowRun.state_type.in_(
158 (StateType.RUNNING, StateType.PENDING, StateType.CANCELLING)
159 ),
160 ),
161 isouter=True,
162 )
163 .where(WorkQueue.concurrency_limit.is_not(None))
164 .group_by(WorkQueue.id)
165 .cte("concurrency_queues")
166 )
168 # use the available slots information to generate a join
169 # for all scheduled runs
170 scheduled_flow_runs, join_criteria = self._get_scheduled_flow_runs_join(
171 work_queue_query=concurrency_queues,
172 limit_per_queue=limit_per_queue,
173 scheduled_before=scheduled_before,
174 )
176 # starting with the work queue table, join the limited queues to get the
177 # concurrency information and the scheduled flow runs to load all applicable
178 # runs. this will return all the scheduled runs allowed by the parameters
179 query = (
180 # return a flow run and work queue id
181 sa.select(
182 orm.aliased(FlowRun, scheduled_flow_runs), WorkQueue.id.label("wq_id")
183 )
184 .select_from(WorkQueue)
185 .join(
186 concurrency_queues,
187 WorkQueue.id == concurrency_queues.c.id,
188 isouter=True,
189 )
190 .join(scheduled_flow_runs, join_criteria)
191 .where(
192 WorkQueue.is_paused.is_(False),
193 WorkQueue.id.in_(work_queue_ids) if work_queue_ids else sa.true(),
194 )
195 .order_by(
196 scheduled_flow_runs.c.next_scheduled_start_time,
197 scheduled_flow_runs.c.id,
198 )
199 )
201 return query
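    # Editor's sketch (not part of the original file): how the statement built by
    # get_scheduled_flow_runs_from_work_queues is typically consumed, matching the
    # docstring above. `queries` (a concrete components instance) and `session`
    # (an AsyncSession) are assumptions for illustration.
    #
    #     query = queries.get_scheduled_flow_runs_from_work_queues(limit_per_queue=10)
    #     result = await session.execute(query)
    #     for flow_run, work_queue_id in result.all():
    #         ...  # each row is a (FlowRun, work queue id) pair
    #     flow_runs = (await session.execute(query)).scalars().unique().all()
    #     # .scalars().unique().all() keeps only the FlowRun column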
203 @db_injector 1a
204 def _get_scheduled_flow_runs_join( 1a
205 self,
206 db: PrefectDBInterface,
207 work_queue_query: sa.CTE,
208 limit_per_queue: Optional[int],
209 scheduled_before: Optional[DateTime],
210 ) -> tuple[sa.FromClause, sa.ColumnExpressionArgument[bool]]:
211 """Used by self.get_scheduled_flow_runs_from_work_queue, allowing just
212 this function to be changed on a per-dialect basis"""
214 FlowRun = db.FlowRun
216 # precompute for readability
217 scheduled_before_clause = (
218 FlowRun.next_scheduled_start_time <= scheduled_before
219 if scheduled_before is not None
220 else sa.true()
221 )
223 # get scheduled flow runs with lateral join where the limit is the
224 # available slots per queue
225 scheduled_flow_runs = (
226 sa.select(FlowRun)
227 .where(
228 FlowRun.work_queue_name == db.WorkQueue.name,
229 FlowRun.state_type == StateType.SCHEDULED,
230 scheduled_before_clause,
231 )
232 .with_for_update(skip_locked=True)
233 # priority given to runs with earlier next_scheduled_start_time
234 .order_by(FlowRun.next_scheduled_start_time)
235 # if null, no limit will be applied
236 .limit(sa.func.least(limit_per_queue, work_queue_query.c.available_slots))
237 .lateral("scheduled_flow_runs")
238 )
240 # Perform a cross-join
241 join_criteria = sa.true()
243 return scheduled_flow_runs, join_criteria
245 # -------------------------------------------------------
246 # Workers
247 # -------------------------------------------------------
249 @property 1a
250 @abstractmethod 1a
251 def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: 1a
252 """
253 Template for the query to get scheduled flow runs from a work pool
254 """
256 @db_injector 1a
257 async def get_scheduled_flow_runs_from_work_pool( 1a
258 self,
259 db: PrefectDBInterface,
260 session: AsyncSession,
261 limit: Optional[int] = None,
262 worker_limit: Optional[int] = None,
263 queue_limit: Optional[int] = None,
264 work_pool_ids: Optional[list[UUID]] = None,
265 work_queue_ids: Optional[list[UUID]] = None,
266 scheduled_before: Optional[DateTime] = None,
267 scheduled_after: Optional[DateTime] = None,
268 respect_queue_priorities: bool = False,
269 ) -> list[schemas.responses.WorkerFlowRunResponse]:
270 template = jinja_env.get_template( 1b
271 self._get_scheduled_flow_runs_from_work_pool_template_path
272 )
274 raw_query = sa.text( 1b
275 template.render(
276 work_pool_ids=work_pool_ids,
277 work_queue_ids=work_queue_ids,
278 respect_queue_priorities=respect_queue_priorities,
279 scheduled_before=scheduled_before,
280 scheduled_after=scheduled_after,
281 )
282 )
284 bindparams: list[sa.BindParameter[Any]] = [] 1b
286 if scheduled_before:  [286 ↛ 287: line 286 didn't jump to line 287 because the condition on line 286 was never true] 1b
287 bindparams.append(
288 sa.bindparam("scheduled_before", scheduled_before, type_=Timestamp)
289 )
291 if scheduled_after:  [291 ↛ 292: line 291 didn't jump to line 292 because the condition on line 291 was never true] 1b
292 bindparams.append(
293 sa.bindparam("scheduled_after", scheduled_after, type_=Timestamp)
294 )
296 # if work pool IDs were provided, bind them
297 if work_pool_ids:  [297 ↛ 309: line 297 didn't jump to line 309 because the condition on line 297 was always true] 1b
298 assert all(isinstance(i, UUID) for i in work_pool_ids) 1b
299 bindparams.append( 1b
300 sa.bindparam(
301 "work_pool_ids",
302 work_pool_ids,
303 expanding=True,
304 type_=UUIDTypeDecorator,
305 )
306 )
308 # if work queue IDs were provided, bind them
309 if work_queue_ids:  [309 ↛ 310: line 309 didn't jump to line 310 because the condition on line 309 was never true] 1b
310 assert all(isinstance(i, UUID) for i in work_queue_ids)
311 bindparams.append(
312 sa.bindparam(
313 "work_queue_ids",
314 work_queue_ids,
315 expanding=True,
316 type_=UUIDTypeDecorator,
317 )
318 )
320 query = raw_query.bindparams( 1b
321 *bindparams,
322 limit=1000 if limit is None else limit,
323 worker_limit=1000 if worker_limit is None else worker_limit,
324 queue_limit=1000 if queue_limit is None else queue_limit,
325 )
327 FlowRun = db.FlowRun 1b
328 orm_query = ( 1b
329 sa.select(
330 sa.column("run_work_pool_id", UUIDTypeDecorator),
331 sa.column("run_work_queue_id", UUIDTypeDecorator),
332 FlowRun,
333 )
334 .from_statement(query)
335 # indicate that the state relationship isn't being loaded
336 .options(orm.noload(FlowRun.state))
337 )
339 result: sa.Result[
340 tuple[UUID, UUID, orm_models.FlowRun]
341 ] = await session.execute(orm_query)
343 return [
344 schemas.responses.WorkerFlowRunResponse(
345 work_pool_id=run_work_pool_id,
346 work_queue_id=run_work_queue_id,
347 flow_run=schemas.core.FlowRun.model_validate(
348 flow_run, from_attributes=True
349 ),
350 )
351 for (run_work_pool_id, run_work_queue_id, flow_run) in result.t
352 ]
354 @db_injector 1a
355 async def read_configuration_value( 1a
356 self, db: PrefectDBInterface, session: AsyncSession, key: str
357 ) -> Optional[dict[str, Any]]:
358 """
359 Read a configuration value by key.
361 Configuration values should not be changed at run time, so retrieved
362 values are cached in memory.
364 The main use of configurations is encrypting blocks; caching these values
365 speeds up nested block document queries.
366 """
367 Configuration = db.Configuration 1c
368 value = None 1c
369 try: 1c
370 value = self._configuration_cache[key] 1c
371 except KeyError: 1c
372 query = sa.select(Configuration).where(Configuration.key == key) 1c
373 if (configuration := await session.scalar(query)) is not None:  [373 ↛ anywhere: line 373 didn't jump anywhere: it always raised an exception] 1ce
374 value = self._configuration_cache[key] = configuration.value
375 return value
377 def clear_configuration_value_cache_for_key(self, key: str) -> None: 1a
378 """Removes a configuration key from the cache."""
379 self._configuration_cache.pop(key, None) 1e
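    # Editor's sketch (not part of the original file): the caching contract of
    # read_configuration_value. The first read for a key hits the database; later
    # reads within the one-hour TTL are served from _configuration_cache until the
    # key is explicitly cleared. `queries`, `session`, and the key string are
    # assumptions for illustration.
    #
    #     value = await queries.read_configuration_value(session, key="some-config-key")
    #     cached = await queries.read_configuration_value(session, key="some-config-key")  # no query issued
    #     queries.clear_configuration_value_cache_for_key("some-config-key")
    #     fresh = await queries.read_configuration_value(session, key="some-config-key")  # queries the DB again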
381 @cached_property 1a
382 def _flow_run_graph_v2_query(self): 1a
383 query = self._build_flow_run_graph_v2_query()
384 param_names = set(bindparams_from_clause(query))
385 required = {"flow_run_id", "max_nodes", "since"}
386 assert param_names >= required, (
387 "_build_flow_run_graph_v2_query result is missing required bind params: "
388 f"{sorted(required - param_names)}"
389 )
390 return query
392 @abstractmethod 1a
393 def _build_flow_run_graph_v2_query(self) -> sa.Select[FlowRunGraphV2Node]: 1a
394 """The flow run graph query, per database flavour
396 The query must accept the following bind parameters:
398 flow_run_id: UUID
399 since: DateTime
400 max_nodes: int
402 """
404 @db_injector 1a
405 async def flow_run_graph_v2( 1a
406 self,
407 db: PrefectDBInterface,
408 session: AsyncSession,
409 flow_run_id: UUID,
410 since: DateTime,
411 max_nodes: int,
412 max_artifacts: int,
413 ) -> Graph:
414 """Returns the query that selects all of the nodes and edges for a flow run graph (version 2)."""
415 FlowRun = db.FlowRun 1bd
416 result = await session.execute( 1bd
417 sa.select(
418 sa.func.coalesce(
419 FlowRun.start_time, FlowRun.expected_start_time, type_=Timestamp
420 ),
421 FlowRun.end_time,
422 ).where(FlowRun.id == flow_run_id)
423 )
424 try:
425 start_time, end_time = result.t.one()
426 except NoResultFound:
427 raise ObjectNotFoundError(f"Flow run {flow_run_id} not found")
429 query = self._flow_run_graph_v2_query
430 results = await session.execute(
431 query,
432 params=dict(flow_run_id=flow_run_id, since=since, max_nodes=max_nodes + 1),
433 )
435 graph_artifacts = await self._get_flow_run_graph_artifacts(
436 db, session, flow_run_id, max_artifacts
437 )
438 graph_states = await self._get_flow_run_graph_states(session, flow_run_id)
440 nodes: list[tuple[UUID, Node]] = []
441 root_node_ids: list[UUID] = []
443 for row in results.t:
444 if not row.parent_ids:
445 root_node_ids.append(row.id)
447 nodes.append(
448 (
449 row.id,
450 Node(
451 kind=row.kind,
452 id=row.id,
453 label=row.label,
454 state_type=row.state_type,
455 start_time=row.start_time,
456 end_time=row.end_time,
457 parents=[Edge(id=id) for id in row.parent_ids or []],
458 children=[Edge(id=id) for id in row.child_ids or []],
459 encapsulating=[
460 Edge(id=id)
461 # ensure encapsulating_ids is deduplicated
462 for id in dict.fromkeys(row.encapsulating_ids or ())
463 ],
464 artifacts=graph_artifacts.get(row.id, []),
465 ),
466 )
467 )
469 if len(nodes) > max_nodes:
470 raise FlowRunGraphTooLarge(
471 f"The graph of flow run {flow_run_id} has more than "
472 f"{max_nodes} nodes."
473 )
475 return Graph(
476 start_time=start_time,
477 end_time=end_time,
478 root_node_ids=root_node_ids,
479 nodes=nodes,
480 artifacts=graph_artifacts.get(None, []),
481 states=graph_states,
482 )
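    # Editor's sketch (not part of the original file): calling flow_run_graph_v2.
    # @db_injector supplies the db argument, so callers pass only the remaining
    # parameters; `queries`, `session`, `flow_run_id`, and `since` are assumptions.
    # Raises ObjectNotFoundError if the flow run does not exist and
    # FlowRunGraphTooLarge if more than max_nodes nodes would be returned.
    #
    #     graph = await queries.flow_run_graph_v2(
    #         session=session,
    #         flow_run_id=flow_run_id,
    #         since=since,            # timezone-aware lower bound on node end_time
    #         max_nodes=10_000,
    #         max_artifacts=10_000,
    #     )
    #     print(graph.root_node_ids, len(graph.nodes))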
484 async def _get_flow_run_graph_artifacts( 1ae
485 self,
486 db: PrefectDBInterface,
487 session: AsyncSession,
488 flow_run_id: UUID,
489 max_artifacts: int,
490 ) -> dict[Optional[UUID], list[GraphArtifact]]:
491 """Get the artifacts for a flow run grouped by task run id.
493 Does not recurse into subflows.
494 Artifacts for the flow run without a task run id are grouped under None.
495 """
496 Artifact, ArtifactCollection = db.Artifact, db.ArtifactCollection
498 query = (
499 sa.select(Artifact, ArtifactCollection.id.label("latest_in_collection_id"))
500 .where(Artifact.flow_run_id == flow_run_id, Artifact.type != "result")
501 .join(
502 ArtifactCollection,
503 onclause=sa.and_(
504 ArtifactCollection.key == Artifact.key,
505 ArtifactCollection.latest_id == Artifact.id,
506 ),
507 isouter=True,
508 )
509 .order_by(Artifact.created.asc())
510 .limit(max_artifacts)
511 )
513 results = await session.execute(query)
515 artifacts_by_task: dict[Optional[UUID], list[GraphArtifact]] = defaultdict(list)
516 for artifact, latest_in_collection_id in results.t:
517 artifacts_by_task[artifact.task_run_id].append(
518 GraphArtifact(
519 id=artifact.id,
520 created=artifact.created,
521 key=artifact.key,
522 type=artifact.type,
523 # We're only using the data field for progress artifacts for now
524 data=artifact.data if artifact.type == "progress" else None,
525 is_latest=artifact.key is None
526 or latest_in_collection_id is not None, # pyright: ignore[reportUnnecessaryComparison]
527 )
528 )
530 return dict(artifacts_by_task)
532 async def _get_flow_run_graph_states( 1a
533 self, session: AsyncSession, flow_run_id: UUID
534 ) -> list[GraphState]:
535 """Get the flow run states for a flow run graph."""
536 states = await models.flow_run_states.read_flow_run_states(session, flow_run_id)
537 return [
538 GraphState.model_validate(state, from_attributes=True) for state in states
539 ]
542class AsyncPostgresQueryComponents(BaseQueryComponents): 1a
543 # --- Postgres-specific SqlAlchemy bindings
545 def insert(self, obj: type[orm_models.Base]) -> postgresql.Insert: 1a
546 return postgresql.insert(obj)
548 # --- Postgres-specific JSON handling
550 @property 1a
551 def uses_json_strings(self) -> bool: 1a
552 return False
554 def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: 1a
555 return json_obj
557 def build_json_object( 1a
558 self, *args: Union[str, sa.ColumnElement[Any]]
559 ) -> sa.ColumnElement[Any]:
560 return sa.func.jsonb_build_object(*args)
562 def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: 1a
563 return sa.func.jsonb_agg(json_array)
565 # --- Postgres-optimized subqueries
567 def make_timestamp_intervals( 1a
568 self,
569 start_time: datetime.datetime,
570 end_time: datetime.datetime,
571 interval: datetime.timedelta,
572 ) -> sa.Select[tuple[datetime.datetime, datetime.datetime]]:
573 dt = sa.func.generate_series(
574 start_time, end_time, interval, type_=Timestamp()
575 ).column_valued("dt")
576 return (
577 sa.select(
578 dt.label("interval_start"),
579 sa.type_coerce(
580 dt + sa.bindparam("interval", interval, type_=sa.Interval()),
581 type_=Timestamp(),
582 ).label("interval_end"),
583 )
584 .where(dt < end_time)
585 .limit(500) # grab at most 500 intervals
586 )
588 @db_injector 1a
589 def set_state_id_on_inserted_flow_runs_statement( 1a
590 self,
591 db: PrefectDBInterface,
592 inserted_flow_run_ids: Sequence[UUID],
593 insert_flow_run_states: Iterable[dict[str, Any]],
594 ) -> sa.Update:
595 """Given a list of flow run ids and associated states, set the state_id
596 to the appropriate state for all flow runs"""
597 # postgres supports `UPDATE ... FROM` syntax
598 FlowRun, FlowRunState = db.FlowRun, db.FlowRunState
599 stmt = (
600 sa.update(FlowRun)
601 .where(
602 FlowRun.id.in_(inserted_flow_run_ids),
603 FlowRunState.flow_run_id == FlowRun.id,
604 FlowRunState.id.in_([r["id"] for r in insert_flow_run_states]),
605 )
606 .values(state_id=FlowRunState.id)
607 # no need to synchronize as these flow runs are entirely new
608 .execution_options(synchronize_session=False)
609 )
610 return stmt
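    # Editor's sketch (not part of the original file): the statement above relies on
    # Postgres' UPDATE ... FROM behaviour (the extra FlowRunState table referenced in
    # the WHERE clause becomes the FROM list), rendering roughly as:
    #
    #     UPDATE flow_run SET state_id = flow_run_state.id
    #     FROM flow_run_state
    #     WHERE flow_run.id IN (...)
    #       AND flow_run_state.flow_run_id = flow_run.id
    #       AND flow_run_state.id IN (...)
    #
    # and would be executed as e.g. `await session.execute(stmt)` after the flow runs
    # and their states have been bulk-inserted. The SQLite subclass builds the same
    # update with a correlated scalar subquery instead (see below).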
612 @property 1a
613 def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: 1a
614 """
615 Template for the query to get scheduled flow runs from a work pool
616 """
617 return "postgres/get-runs-from-worker-queues.sql.jinja"
619 @db_injector 1a
620 def _build_flow_run_graph_v2_query( 1a
621 self, db: PrefectDBInterface
622 ) -> sa.Select[FlowRunGraphV2Node]:
623 """Postgresql version of the V2 FlowRun graph data query
625 This SQLA query is built just once and then cached per DB interface
627 """
628 # the parameters this query takes as inputs
629 param_flow_run_id = sa.bindparam("flow_run_id", type_=UUIDTypeDecorator)
630 param_since = sa.bindparam("since", type_=Timestamp)
631 param_max_nodes = sa.bindparam("max_nodes", type_=sa.Integer)
633 Flow, FlowRun, TaskRun = db.Flow, db.FlowRun, db.TaskRun
634 input = sa.func.jsonb_each(TaskRun.task_inputs).table_valued(
635 "key", "value", name="input"
636 )
637 argument = (
638 sa.func.jsonb_array_elements(input.c.value, type_=postgresql.JSONB())
639 .table_valued(sa.column("value", postgresql.JSONB()))
640 .render_derived(name="argument")
641 )
642 edges = (
643 sa.select(
644 sa.case((FlowRun.id.is_not(None), "flow-run"), else_="task-run").label(
645 "kind"
646 ),
647 sa.func.coalesce(FlowRun.id, TaskRun.id).label("id"),
648 sa.func.coalesce(Flow.name + " / " + FlowRun.name, TaskRun.name).label(
649 "label"
650 ),
651 sa.func.coalesce(FlowRun.state_type, TaskRun.state_type).label(
652 "state_type"
653 ),
654 sa.func.coalesce(
655 FlowRun.start_time,
656 FlowRun.expected_start_time,
657 TaskRun.start_time,
658 TaskRun.expected_start_time,
659 ).label("start_time"),
660 sa.func.coalesce(
661 FlowRun.end_time,
662 TaskRun.end_time,
663 sa.case(
664 (
665 TaskRun.state_type == StateType.COMPLETED,
666 TaskRun.expected_start_time,
667 ),
668 else_=sa.null(),
669 ),
670 ).label("end_time"),
671 sa.cast(argument.c.value["id"].astext, type_=UUIDTypeDecorator).label(
672 "parent"
673 ),
674 (input.c.key == "__parents__").label("has_encapsulating_task"),
675 )
676 .join_from(TaskRun, input, onclause=sa.true(), isouter=True)
677 .join(argument, onclause=sa.true(), isouter=True)
678 .join(
679 FlowRun,
680 isouter=True,
681 onclause=FlowRun.parent_task_run_id == TaskRun.id,
682 )
683 .join(Flow, isouter=True, onclause=Flow.id == FlowRun.flow_id)
684 .where(
685 TaskRun.flow_run_id == param_flow_run_id,
686 TaskRun.state_type != StateType.PENDING,
687 sa.func.coalesce(
688 FlowRun.start_time,
689 FlowRun.expected_start_time,
690 TaskRun.start_time,
691 TaskRun.expected_start_time,
692 ).is_not(None),
693 )
694 # -- the order here is important to speed up building the two sets of
695 # -- edges in the with_parents and with_children CTEs below
696 .order_by(sa.func.coalesce(FlowRun.id, TaskRun.id))
697 ).cte("edges")
698 children, parents = edges.alias("children"), edges.alias("parents")
699 with_encapsulating = (
700 sa.select(
701 children.c.id,
702 sa.func.array_agg(
703 postgresql.aggregate_order_by(parents.c.id, parents.c.start_time)
704 ).label("encapsulating_ids"),
705 )
706 .join(parents, onclause=parents.c.id == children.c.parent)
707 .where(children.c.has_encapsulating_task.is_(True))
708 .group_by(children.c.id)
709 ).cte("with_encapsulating")
710 with_parents = (
711 sa.select(
712 children.c.id,
713 sa.func.array_agg(
714 postgresql.aggregate_order_by(parents.c.id, parents.c.start_time)
715 ).label("parent_ids"),
716 )
717 .join(parents, onclause=parents.c.id == children.c.parent)
718 .where(children.c.has_encapsulating_task.is_distinct_from(True))
719 .group_by(children.c.id)
720 .cte("with_parents")
721 )
722 with_children = (
723 sa.select(
724 parents.c.id,
725 sa.func.array_agg(
726 postgresql.aggregate_order_by(children.c.id, children.c.start_time)
727 ).label("child_ids"),
728 )
729 .join(children, onclause=children.c.parent == parents.c.id)
730 .where(children.c.has_encapsulating_task.is_distinct_from(True))
731 .group_by(parents.c.id)
732 .cte("with_children")
733 )
735 graph = (
736 sa.select(
737 edges.c.kind,
738 edges.c.id,
739 edges.c.label,
740 edges.c.state_type,
741 edges.c.start_time,
742 edges.c.end_time,
743 with_parents.c.parent_ids,
744 with_children.c.child_ids,
745 with_encapsulating.c.encapsulating_ids,
746 )
747 .distinct(edges.c.id)
748 .join(with_parents, isouter=True, onclause=with_parents.c.id == edges.c.id)
749 .join(
750 with_children, isouter=True, onclause=with_children.c.id == edges.c.id
751 )
752 .join(
753 with_encapsulating,
754 isouter=True,
755 onclause=with_encapsulating.c.id == edges.c.id,
756 )
757 .cte("nodes")
758 )
759 query = (
760 sa.select(
761 graph.c.kind,
762 graph.c.id,
763 graph.c.label,
764 graph.c.state_type,
765 graph.c.start_time,
766 graph.c.end_time,
767 graph.c.parent_ids,
768 graph.c.child_ids,
769 graph.c.encapsulating_ids,
770 )
771 .where(sa.or_(graph.c.end_time.is_(None), graph.c.end_time >= param_since))
772 .order_by(graph.c.start_time, graph.c.end_time)
773 .limit(param_max_nodes)
774 )
775 return cast(sa.Select[FlowRunGraphV2Node], query)
778class UUIDList(sa.TypeDecorator[list[UUID]]): 1a
779 """Map a JSON list of strings back to a list of UUIDs at the result loading stage"""
781 impl: Union[TypeEngine[Any], type[TypeEngine[Any]]] = sa.JSON() 1a
782 cache_ok: Optional[bool] = True 1a
784 def process_result_value( 1a
785 self, value: Optional[list[Union[str, UUID]]], dialect: sa.Dialect
786 ) -> Optional[list[UUID]]:
787 if value is None:
788 return value
789 return [v if isinstance(v, UUID) else UUID(v) for v in value]
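# Editor's sketch (not part of the original file): what UUIDList does at result
# loading time. SQLite hands JSON arrays back as lists of strings, which
# process_result_value converts to UUID objects (already-parsed UUIDs pass through;
# the dialect argument is unused).
#
#     decoder = UUIDList()
#     decoder.process_result_value(
#         ["00000000-0000-0000-0000-000000000001", UUID(int=2)], sqlite.dialect()
#     )
#     # -> [UUID('00000000-0000-0000-0000-000000000001'), UUID(int=2)]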
792class AioSqliteQueryComponents(BaseQueryComponents): 1a
793 # --- Sqlite-specific SqlAlchemy bindings
795 def insert(self, obj: type[orm_models.Base]) -> sqlite.Insert: 1a
796 return sqlite.insert(obj) 1cbfghd
798 # --- Sqlite-specific JSON handling
800 @property 1a
801 def uses_json_strings(self) -> bool: 1a
802 return True 1bd
804 def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: 1a
805 return sa.func.json(json_obj) 1bd
807 def build_json_object( 1a
808 self, *args: Union[str, sa.ColumnElement[Any]]
809 ) -> sa.ColumnElement[Any]:
810 return sa.func.json_object(*args)
812 def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: 1a
813 return sa.func.json_group_array(json_array) 1bd
815 # --- Sqlite-optimized subqueries
817 def make_timestamp_intervals( 1a
818 self,
819 start_time: datetime.datetime,
820 end_time: datetime.datetime,
821 interval: datetime.timedelta,
822 ) -> sa.Select[tuple[datetime.datetime, datetime.datetime]]:
823 start = sa.bindparam("start_time", start_time, Timestamp)
824 # subtract interval because recursive where clauses are effectively evaluated on a t-1 lag
825 stop = sa.bindparam("end_time", end_time - interval, Timestamp)
826 step = sa.bindparam("interval", interval, sa.Interval)
828 one = sa.literal(1, literal_execute=True)
830 # recursive CTE to mimic the behavior of `generate_series`, which is
831 # only available as a compiled extension
832 base_case = sa.select(
833 start.label("interval_start"),
834 sa.func.date_add(start, step).label("interval_end"),
835 one.label("counter"),
836 ).cte(recursive=True)
837 recursive_case = sa.select(
838 base_case.c.interval_end,
839 sa.func.date_add(base_case.c.interval_end, step),
840 base_case.c.counter + one,
841 ).where(
842 base_case.c.interval_start < stop,
843 # don't compute more than 500 intervals
844 base_case.c.counter < 500,
845 )
846 cte = base_case.union_all(recursive_case)
848 return sa.select(cte.c.interval_start, cte.c.interval_end)
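    # Editor's note (not part of the original file): the recursive CTE above is meant
    # to yield the same rows as the Postgres generate_series version. For example,
    # with start=00:00, end=03:00, interval=1h it produces
    # (00:00, 01:00), (01:00, 02:00), (02:00, 03:00): the `end_time - interval` bound
    # compensates for the t-1 lag of the recursive WHERE clause, and the counter
    # caps the series at 500 intervals, matching the Postgres version's LIMIT 500.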
850 @db_injector 1a
851 def set_state_id_on_inserted_flow_runs_statement( 1a
852 self,
853 db: PrefectDBInterface,
854 inserted_flow_run_ids: Sequence[UUID],
855 insert_flow_run_states: Iterable[dict[str, Any]],
856 ) -> sa.Update:
857 """Given a list of flow run ids and associated states, set the state_id
858 to the appropriate state for all flow runs"""
859 fr_model, frs_model = db.FlowRun, db.FlowRunState
860 # sqlite requires a correlated subquery to update from another table
861 subquery = (
862 sa.select(frs_model.id)
863 .where(
864 frs_model.flow_run_id == fr_model.id,
865 frs_model.id.in_([r["id"] for r in insert_flow_run_states]),
866 )
867 .limit(1)
868 .scalar_subquery()
869 )
870 stmt = (
871 sa.update(fr_model)
872 .where(
873 fr_model.id.in_(inserted_flow_run_ids),
874 )
875 .values(state_id=subquery)
876 # no need to synchronize as these flow runs are entirely new
877 .execution_options(synchronize_session=False)
878 )
879 return stmt
881 @db_injector 1a
882 def _get_scheduled_flow_runs_join( 1a
883 self,
884 db: PrefectDBInterface,
885 work_queue_query: sa.CTE,
886 limit_per_queue: Optional[int],
887 scheduled_before: Optional[DateTime],
888 ) -> tuple[sa.FromClause, sa.ColumnExpressionArgument[bool]]:
889 # precompute for readability
890 FlowRun = db.FlowRun
892 scheduled_before_clause = (
893 FlowRun.next_scheduled_start_time <= scheduled_before
894 if scheduled_before is not None
895 else sa.true()
896 )
898 # select scheduled flow runs, ordered by scheduled start time per queue
899 scheduled_flow_runs = (
900 sa.select(
901 (
902 sa.func.row_number()
903 .over(
904 partition_by=[FlowRun.work_queue_name],
905 order_by=FlowRun.next_scheduled_start_time,
906 )
907 .label("rank")
908 ),
909 FlowRun,
910 )
911 .where(FlowRun.state_type == StateType.SCHEDULED, scheduled_before_clause)
912 .subquery("scheduled_flow_runs")
913 )
915 # sqlite short-circuits the `min` comparison on nulls, so we use `999999`
916 # as an "unlimited" limit.
917 limit = 999999 if limit_per_queue is None else limit_per_queue
919 # in the join, only keep flow runs whose rank is less than or equal to the
920 # available slots for each queue
921 join_criteria = sa.and_(
922 scheduled_flow_runs.c.work_queue_name == db.WorkQueue.name,
923 scheduled_flow_runs.c.rank
924 <= sa.func.min(
925 sa.func.coalesce(work_queue_query.c.available_slots, limit), limit
926 ),
927 )
928 return scheduled_flow_runs, join_criteria
930 # -------------------------------------------------------
931 # Workers
932 # -------------------------------------------------------
934 @property 1a
935 def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: 1a
936 """
937 Template for the query to get scheduled flow runs from a work pool
938 """
939 return "sqlite/get-runs-from-worker-queues.sql.jinja" 1b
941 @db_injector 1a
942 def _build_flow_run_graph_v2_query( 1a
943 self, db: PrefectDBInterface
944 ) -> sa.Select[FlowRunGraphV2Node]:
945 """Postgresql version of the V2 FlowRun graph data query
947 This SQLA query is built just once and then cached per DB interface
949 """
950 # the parameters this query takes as inputs
951 param_flow_run_id = sa.bindparam("flow_run_id", type_=UUIDTypeDecorator)
952 param_since = sa.bindparam("since", type_=Timestamp)
953 param_max_nodes = sa.bindparam("max_nodes", type_=sa.Integer)
955 Flow, FlowRun, TaskRun = db.Flow, db.FlowRun, db.TaskRun
956 input = sa.func.json_each(TaskRun.task_inputs).table_valued(
957 "key", "value", name="input"
958 )
959 argument = sa.func.json_each(
960 input.c.value, type_=postgresql.JSON()
961 ).table_valued("key", sa.column("value", postgresql.JSON()), name="argument")
962 edges = (
963 sa.select(
964 sa.case((FlowRun.id.is_not(None), "flow-run"), else_="task-run").label(
965 "kind"
966 ),
967 sa.func.coalesce(FlowRun.id, TaskRun.id).label("id"),
968 sa.func.coalesce(Flow.name + " / " + FlowRun.name, TaskRun.name).label(
969 "label"
970 ),
971 sa.func.coalesce(FlowRun.state_type, TaskRun.state_type).label(
972 "state_type"
973 ),
974 sa.func.coalesce(
975 FlowRun.start_time,
976 FlowRun.expected_start_time,
977 TaskRun.start_time,
978 TaskRun.expected_start_time,
979 ).label("start_time"),
980 sa.func.coalesce(
981 FlowRun.end_time,
982 TaskRun.end_time,
983 sa.case(
984 (
985 TaskRun.state_type == StateType.COMPLETED,
986 TaskRun.expected_start_time,
987 ),
988 else_=sa.null(),
989 ),
990 ).label("end_time"),
991 argument.c.value["id"].astext.label("parent"),
992 (input.c.key == "__parents__").label("has_encapsulating_task"),
993 )
994 .join_from(TaskRun, input, onclause=sa.true(), isouter=True)
995 .join(argument, onclause=sa.true(), isouter=True)
996 .join(
997 FlowRun,
998 isouter=True,
999 onclause=FlowRun.parent_task_run_id == TaskRun.id,
1000 )
1001 .join(Flow, isouter=True, onclause=Flow.id == FlowRun.flow_id)
1002 .where(
1003 TaskRun.flow_run_id == param_flow_run_id,
1004 TaskRun.state_type != StateType.PENDING,
1005 sa.func.coalesce(
1006 FlowRun.start_time,
1007 FlowRun.expected_start_time,
1008 TaskRun.start_time,
1009 TaskRun.expected_start_time,
1010 ).is_not(None),
1011 )
1012 # -- the order here is important to speed up building the two sets of
1013 # -- edges in the with_parents and with_children CTEs below
1014 .order_by(sa.func.coalesce(FlowRun.id, TaskRun.id))
1015 ).cte("edges")
1016 children, parents = edges.alias("children"), edges.alias("parents")
1017 with_encapsulating = (
1018 sa.select(
1019 children.c.id,
1020 sa.func.json_group_array(parents.c.id).label("encapsulating_ids"),
1021 )
1022 .join(parents, onclause=parents.c.id == children.c.parent)
1023 .where(children.c.has_encapsulating_task.is_(True))
1024 .group_by(children.c.id)
1025 ).cte("with_encapsulating")
1026 with_parents = (
1027 sa.select(
1028 children.c.id,
1029 sa.func.json_group_array(parents.c.id).label("parent_ids"),
1030 )
1031 .join(parents, onclause=parents.c.id == children.c.parent)
1032 .where(children.c.has_encapsulating_task.is_distinct_from(True))
1033 .group_by(children.c.id)
1034 .cte("with_parents")
1035 )
1036 with_children = (
1037 sa.select(
1038 parents.c.id, sa.func.json_group_array(children.c.id).label("child_ids")
1039 )
1040 .join(children, onclause=children.c.parent == parents.c.id)
1041 .where(children.c.has_encapsulating_task.is_distinct_from(True))
1042 .group_by(parents.c.id)
1043 .cte("with_children")
1044 )
1046 graph = (
1047 sa.select(
1048 edges.c.kind,
1049 edges.c.id,
1050 edges.c.label,
1051 edges.c.state_type,
1052 edges.c.start_time,
1053 edges.c.end_time,
1054 with_parents.c.parent_ids,
1055 with_children.c.child_ids,
1056 with_encapsulating.c.encapsulating_ids,
1057 )
1058 .distinct()
1059 .join(with_parents, isouter=True, onclause=with_parents.c.id == edges.c.id)
1060 .join(
1061 with_children, isouter=True, onclause=with_children.c.id == edges.c.id
1062 )
1063 .join(
1064 with_encapsulating,
1065 isouter=True,
1066 onclause=with_encapsulating.c.id == edges.c.id,
1067 )
1068 .cte("nodes")
1069 )
1071 query = (
1072 sa.select(
1073 graph.c.kind,
1074 graph.c.id,
1075 graph.c.label,
1076 graph.c.state_type,
1077 graph.c.start_time,
1078 graph.c.end_time,
1079 sa.type_coerce(graph.c.parent_ids, UUIDList),
1080 sa.type_coerce(graph.c.child_ids, UUIDList),
1081 sa.type_coerce(graph.c.encapsulating_ids, UUIDList),
1082 )
1083 .where(sa.or_(graph.c.end_time.is_(None), graph.c.end_time >= param_since))
1084 .order_by(graph.c.start_time, graph.c.end_time)
1085 .limit(param_max_nodes)
1086 )
1087 return cast(sa.Select[FlowRunGraphV2Node], query)