# Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/utilities/database.py: 82%
# 258 statements (coverage.py v7.10.6, created at 2025-12-05 10:48 +0000)
1"""
2Utilities for interacting with Prefect REST API database and ORM layer.
4Prefect supports both SQLite and Postgres. Many of these utilities
5allow the Prefect REST API to seamlessly switch between the two.
6"""
8from __future__ import annotations 1a
10import datetime 1a
11import json 1a
12import operator 1a
13import re 1a
14import uuid 1a
15from functools import partial 1a
16from typing import ( 1a
17 TYPE_CHECKING,
18 Any,
19 Callable,
20 Optional,
21 Type,
22 Union,
23 overload,
24)
25from zoneinfo import ZoneInfo 1a
27import pydantic 1a
28import sqlalchemy as sa 1a
29from sqlalchemy.dialects import postgresql, sqlite 1a
30from sqlalchemy.dialects.postgresql.operators import ( 1a
31 # these are all incompletely annotated
32 ASTEXT, # type: ignore
33 CONTAINS, # type: ignore
34 HAS_ALL, # type: ignore
35 HAS_ANY, # type: ignore
36)
37from sqlalchemy.ext.compiler import compiles 1a
38from sqlalchemy.orm import Session 1a
39from sqlalchemy.sql import functions, schema 1a
40from sqlalchemy.sql.compiler import SQLCompiler 1a
41from sqlalchemy.sql.operators import OperatorType 1a
42from sqlalchemy.sql.visitors import replacement_traverse 1a
43from sqlalchemy.types import CHAR, TypeDecorator, TypeEngine 1a
44from typing_extensions import ( 1a
45 Concatenate,
46 ParamSpec,
47 TypeAlias,
48 TypeVar,
49)
51from prefect.types._datetime import DateTime 1a
53P = ParamSpec("P") 1a
54R = TypeVar("R", infer_variance=True) 1a
55T = TypeVar("T", infer_variance=True) 1a
57_SQLExpressionOrLiteral: TypeAlias = Union[sa.SQLColumnExpression[T], T] 1a
58_Function = Callable[P, R] 1a
59_Method = Callable[Concatenate[T, P], R] 1a
60_DBFunction: TypeAlias = Callable[Concatenate["PrefectDBInterface", P], R] 1a
61_DBMethod: TypeAlias = Callable[Concatenate[T, "PrefectDBInterface", P], R] 1a
63CAMEL_TO_SNAKE: re.Pattern[str] = re.compile(r"(?<!^)(?=[A-Z])") 1a
65if TYPE_CHECKING: 65 ↛ 66line 65 didn't jump to line 66 because the condition on line 65 was never true1a
66 from prefect.server.database.interface import PrefectDBInterface
69@overload 1a
70def db_injector(func: _DBMethod[T, P, R]) -> _Method[T, P, R]: ... 70 ↛ exitline 70 didn't return from function 'db_injector' because 1a
73@overload 1a
74def db_injector(func: _DBFunction[P, R]) -> _Function[P, R]: ... 74 ↛ exitline 74 didn't return from function 'db_injector' because 1a
77def db_injector( 1a
78 func: Union[_DBMethod[T, P, R], _DBFunction[P, R]],
79) -> Union[_Method[T, P, R], _Function[P, R]]:
80 from prefect.server.database import db_injector 1a
82 return db_injector(func) 1a
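

# Illustrative sketch (hypothetical function): decorating with `db_injector`
# injects the active `PrefectDBInterface` as the first argument, so callers
# omit it:
#
#     @db_injector
#     async def count_runs(db: "PrefectDBInterface") -> int:
#         ...  # query using the injected database interface
#
#     # callers invoke count_runs() with no `db` argument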


class GenerateUUID(functions.FunctionElement[uuid.UUID]):
    """
    Platform-independent UUID default generator.
    Note the actual functionality for this class is specified in the
    `compiles`-decorated functions below
    """

    name = "uuid_default"


@compiles(GenerateUUID, "postgresql")
def generate_uuid_postgresql(
    element: GenerateUUID, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """
    Generates a random UUID in Postgres; requires the pgcrypto extension.
    """

    return "(GEN_RANDOM_UUID())"


@compiles(GenerateUUID, "sqlite")
def generate_uuid_sqlite(
    element: GenerateUUID, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """
    Generates a random UUID in other databases (SQLite) by concatenating
    bytes in a way that approximates a UUID hex representation. This is
    sufficient for our purposes of having a random client-generated ID
    that is compatible with a UUID spec.
    """

    return """
    (
        lower(hex(randomblob(4)))
        || '-'
        || lower(hex(randomblob(2)))
        || '-4'
        || substr(lower(hex(randomblob(2))),2)
        || '-'
        || substr('89ab',abs(random()) % 4 + 1, 1)
        || substr(lower(hex(randomblob(2))),2)
        || '-'
        || lower(hex(randomblob(6)))
    )
    """


class Timestamp(TypeDecorator[datetime.datetime]):
    """TypeDecorator that ensures that timestamps have a timezone.

    For SQLite, all timestamps are converted to UTC (since they are stored
    as naive timestamps without timezones) and recovered as UTC.
    """

    impl: TypeEngine[Any] | type[TypeEngine[Any]] = sa.TIMESTAMP(timezone=True)
    cache_ok: bool | None = True

    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(postgresql.TIMESTAMP(timezone=True))
        elif dialect.name == "sqlite":
            # see the sqlite.DATETIME docstring on the particulars of the storage
            # format. Note that the sqlite implementations for timestamp and interval
            # arithmetic below would require updating if a different format was to
            # be configured here.
            return dialect.type_descriptor(sqlite.DATETIME())
        else:
            return dialect.type_descriptor(sa.TIMESTAMP(timezone=True))

    def process_bind_param(
        self,
        value: Optional[datetime.datetime],
        dialect: sa.Dialect,
    ) -> Optional[datetime.datetime]:
        if value is None:
            return None
        else:
            if value.tzinfo is None:
                raise ValueError("Timestamps must have a timezone.")
            elif dialect.name == "sqlite":
                return value.astimezone(ZoneInfo("UTC"))
            else:
                return value

    def process_result_value(
        self,
        value: Optional[datetime.datetime],
        dialect: sa.Dialect,
    ) -> Optional[datetime.datetime]:
        # retrieve timestamps in their native timezone (or UTC)
        if value is not None:
            if value.tzinfo is None:
                return value.replace(tzinfo=ZoneInfo("UTC"))
            else:
                return value.astimezone(ZoneInfo("UTC"))
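

# Illustrative round trip on SQLite (example values): a tz-aware bind value is
# normalized to UTC before storage, and the naive stored value is tagged as
# UTC when read back:
#
#     bind:   2025-01-01 12:00:00+01:00   ->  stored as 2025-01-01 11:00:00
#     result: 2025-01-01 11:00:00 (naive) ->  2025-01-01 11:00:00+00:00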


class UUID(TypeDecorator[uuid.UUID]):
    """
    Platform-independent UUID type.

    Uses PostgreSQL's UUID type, otherwise uses
    CHAR(36), storing as stringified hex values with
    hyphens.
    """

    impl: type[TypeEngine[Any]] | TypeEngine[Any] = TypeEngine
    cache_ok: bool | None = True

    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(postgresql.UUID())
        else:
            return dialect.type_descriptor(CHAR(36))

    def process_bind_param(
        self, value: Optional[Union[str, uuid.UUID]], dialect: sa.Dialect
    ) -> Optional[str]:
        if value is None:
            return None
        elif dialect.name == "postgresql":
            return str(value)
        elif isinstance(value, uuid.UUID):
            return str(value)
        else:
            return str(uuid.UUID(value))

    def process_result_value(
        self, value: Optional[Union[str, uuid.UUID]], dialect: sa.Dialect
    ) -> Optional[uuid.UUID]:
        if value is None:
            return value
        else:
            if not isinstance(value, uuid.UUID):
                value = uuid.UUID(value)
            return value
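

# Illustrative behavior: on non-PostgreSQL dialects the canonical 36-character
# hyphenated form is stored, even for unhyphenated hex input:
#
#     >>> str(uuid.UUID("12345678123456781234567812345678"))
#     '12345678-1234-5678-1234-567812345678'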


class JSON(TypeDecorator[Any]):
    """
    JSON type that returns SQLAlchemy's dialect-specific JSON types, where
    possible. Uses generic JSON otherwise.

    The "base" type is postgresql.JSONB to expose useful methods prior
    to SQL compilation
    """

    impl: type[postgresql.JSONB] | type[TypeEngine[Any]] | TypeEngine[Any] = (
        postgresql.JSONB
    )
    cache_ok: bool | None = True

    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(postgresql.JSONB(none_as_null=True))
        elif dialect.name == "sqlite":
            return dialect.type_descriptor(sqlite.JSON(none_as_null=True))
        else:
            return dialect.type_descriptor(sa.JSON(none_as_null=True))

    def process_bind_param(
        self, value: Optional[Any], dialect: sa.Dialect
    ) -> Optional[Any]:
        """Prepares the given value to be used as a JSON field in a parameter binding"""
        if not value:
            return value

        # PostgreSQL does not support the floating point extrema values `NaN`,
        # `-Infinity`, or `Infinity`
        # https://www.postgresql.org/docs/current/datatype-json.html#JSON-TYPE-MAPPING-TABLE
        #
        # SQLite supports storing and retrieving full JSON values that include
        # `NaN`, `-Infinity`, or `Infinity`, but any query that requires SQLite to parse
        # the value (like `json_extract`) will fail.
        #
        # Replace any `NaN`, `-Infinity`, or `Infinity` values with `None` in the
        # returned value. See more about `parse_constant` at
        # https://docs.python.org/3/library/json.html#json.load.
        return json.loads(json.dumps(value), parse_constant=lambda c: None)
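

# Illustrative example of the constant replacement performed above: the
# out-of-range float survives serialization, then parses back as None:
#
#     >>> json.loads(json.dumps({"x": float("inf")}), parse_constant=lambda c: None)
#     {'x': None}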


class Pydantic(TypeDecorator[T]):
    """
    A pydantic type that converts inserted parameters to
    json and converts read values to the pydantic type.
    """

    impl = JSON
    cache_ok: bool | None = True

    @overload
    def __init__(
        self,
        pydantic_type: type[T],
        sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None,
    ) -> None: ...

    # This overload is needed to allow for typing special forms (e.g.
    # Union[...], etc.) as these can't be married with `type[...]`. Also see
    # https://github.com/pydantic/pydantic/pull/8923
    @overload
    def __init__(
        self: "Pydantic[Any]",
        pydantic_type: Any,
        sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None,
    ) -> None: ...

    def __init__(
        self,
        pydantic_type: type[T],
        sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None,
    ) -> None:
        super().__init__()
        self._pydantic_type = pydantic_type
        self._adapter = pydantic.TypeAdapter(self._pydantic_type)
        if sa_column_type is not None:
            self.impl: type[JSON] | type[TypeEngine[Any]] | TypeEngine[Any] = (
                sa_column_type
            )

    def process_bind_param(
        self, value: Optional[T], dialect: sa.Dialect
    ) -> Optional[str]:
        if value is None:
            return None

        value = self._adapter.validate_python(value)

        # sqlalchemy requires the bind parameter's value to be a python-native
        # collection of JSON-compatible objects. we achieve that by dumping the
        # value to a json string using the pydantic JSON encoder and re-parsing
        # it into a python-native form.
        return self._adapter.dump_python(value, mode="json")

    def process_result_value(
        self, value: Optional[Any], dialect: sa.Dialect
    ) -> Optional[T]:
        if value is not None:
            return self._adapter.validate_python(value)
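

# Illustrative column definition (hypothetical model and field names): values
# are validated on write and revalidated into the pydantic type on read:
#
#     class RetryPolicy(pydantic.BaseModel):
#         retries: int = 0
#
#     # on a declarative model:
#     #     policy = sa.Column(Pydantic(RetryPolicy))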


def bindparams_from_clause(
    query: sa.ClauseElement,
) -> dict[str, sa.BindParameter[Any]]:
    """Retrieve all non-anonymous bind parameters defined in a SQL clause"""
    # we could use `traverse(query, {}, {"bindparam": some_list.append})` too,
    # but this private method builds on the SQLA query caching infrastructure
    # and so is more efficient.
    return {
        bp.key: bp
        for bp in query._get_embedded_bindparams()  # pyright: ignore[reportPrivateUsage]
        # Anonymous keys are always a printf-style template that starts with '%([seed]'
        # the seed is the id() of the bind parameter itself.
        if not bp.key.startswith(f"%({id(bp)}")
    }
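

# Illustrative usage: named parameters are returned; anonymous parameters
# (e.g. those generated for literal values) are filtered out:
#
#     >>> query = sa.select(sa.column("name")).where(
#     ...     sa.column("name") == sa.bindparam("flow_name")
#     ... )
#     >>> sorted(bindparams_from_clause(query))
#     ['flow_name']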


# Platform-independent datetime and timedelta arithmetic functions


class date_add(functions.GenericFunction[DateTime]):
    """Platform-independent way to add a timestamp and an interval"""

    type: Timestamp = Timestamp()
    inherit_cache: bool = True

    def __init__(
        self,
        dt: _SQLExpressionOrLiteral[datetime.datetime],
        interval: _SQLExpressionOrLiteral[datetime.timedelta],
        **kwargs: Any,
    ):
        super().__init__(
            sa.type_coerce(dt, Timestamp()),
            sa.type_coerce(interval, sa.Interval()),
            **kwargs,
        )


class interval_add(functions.GenericFunction[datetime.timedelta]):
    """Platform-independent way to add two intervals."""

    type: sa.Interval = sa.Interval()
    inherit_cache: bool = True

    def __init__(
        self,
        i1: _SQLExpressionOrLiteral[datetime.timedelta],
        i2: _SQLExpressionOrLiteral[datetime.timedelta],
        **kwargs: Any,
    ):
        super().__init__(
            sa.type_coerce(i1, sa.Interval()),
            sa.type_coerce(i2, sa.Interval()),
            **kwargs,
        )


class date_diff(functions.GenericFunction[datetime.timedelta]):
    """Platform-independent difference of two timestamps. Computes d1 - d2."""

    type: sa.Interval = sa.Interval()
    inherit_cache: bool = True

    def __init__(
        self,
        d1: _SQLExpressionOrLiteral[datetime.datetime],
        d2: _SQLExpressionOrLiteral[datetime.datetime],
        **kwargs: Any,
    ) -> None:
        super().__init__(
            sa.type_coerce(d1, Timestamp()), sa.type_coerce(d2, Timestamp()), **kwargs
        )


class date_diff_seconds(functions.GenericFunction[float]):
    """Platform-independent calculation of the number of seconds between two timestamps or from 'now'"""

    type: Type[sa.REAL[float]] = sa.REAL
    inherit_cache: bool = True

    def __init__(
        self,
        dt1: _SQLExpressionOrLiteral[datetime.datetime],
        dt2: Optional[_SQLExpressionOrLiteral[datetime.datetime]] = None,
        **kwargs: Any,
    ) -> None:
        args = (sa.type_coerce(dt1, Timestamp()),)
        if dt2 is not None:
            args = (*args, sa.type_coerce(dt2, Timestamp()))
        super().__init__(*args, **kwargs)
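

# Illustrative expressions (with a hypothetical mapped model `Run` that has
# Timestamp columns `start_time` and `end_time`):
#
#     date_add(Run.start_time, datetime.timedelta(hours=1))
#     date_diff(Run.end_time, Run.start_time)          # end_time - start_time
#     date_diff_seconds(Run.end_time, Run.start_time)  # as a REAL in seconds
#     date_diff_seconds(Run.start_time)                # seconds elapsed since start_time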


# timestamp and interval arithmetic implementations for PostgreSQL


@compiles(date_add, "postgresql")
@compiles(interval_add, "postgresql")
@compiles(date_diff, "postgresql")
def datetime_or_interval_add_postgresql(
    element: Union[date_add, interval_add, date_diff],
    compiler: SQLCompiler,
    **kwargs: Any,
) -> str:
    operation = operator.sub if isinstance(element, date_diff) else operator.add
    return compiler.process(operation(*element.clauses), **kwargs)


@compiles(date_diff_seconds, "postgresql")
def date_diff_seconds_postgresql(
    element: date_diff_seconds, compiler: SQLCompiler, **kwargs: Any
) -> str:
    # either 1 or 2 timestamps; if 1, subtract from 'now'
    dts: list[sa.ColumnElement[datetime.datetime]] = list(element.clauses)
    if len(dts) == 1:
        dts = [sa.func.now(), *dts]
    as_utc = (sa.func.timezone("UTC", dt) for dt in dts)
    return compiler.process(sa.func.extract("epoch", operator.sub(*as_utc)), **kwargs)


# SQLite implementations for the Timestamp and Interval arithmetic functions.
#
# The following concepts are at play here:
#
# - By default, SQLAlchemy stores Timestamp values formatted as ISO8601 strings
#   (with a space between the date and the time parts), with microsecond precision.
# - SQLAlchemy stores Interval values as a Timestamp, offset from the UNIX epoch.
# - SQLite processes timestamp values with _at most_ millisecond precision, and
#   only if you use the `julianday()` function or the 'subsec' modifier for
#   the `unixepoch()` function (the latter requires SQLite 3.42.0, released
#   2023-05-16)
#
# In order for arithmetic to work well, you need to convert timestamps to
# fractional [Julian day numbers][JDN], and intervals to a real number
# by subtracting the UNIX epoch from their Julian day number representation.
#
# Once the result has been computed, the result needs to be converted back
# to an ISO8601 formatted string including any milliseconds. For an
# interval result, that means adding the UNIX epoch offset to it first.
#
# [JDN]: https://en.wikipedia.org/wiki/Julian_day

# SQLite strftime() format to output ISO8601 date and time with milliseconds.
# This format must be parseable by the `datetime.fromisoformat()` function,
# or if the SQLite implementation for Timestamp below is configured with a
# regex, then it must match that regex.
#
# SQLite only provides millisecond precision, but past versions of SQLAlchemy
# defaulted to parsing with a regex that would treat the fractional part as a
# value in microseconds. To ensure maximum compatibility the current format
# should continue to format the fractional seconds as microseconds, so 6 digits.
SQLITE_DATETIME_FORMAT = sa.literal("%Y-%m-%d %H:%M:%f000", literal_execute=True)
"""The SQLite timestamp output format as a SQL literal string constant"""

SQLITE_EPOCH_JULIANDAYNUMBER = sa.literal(2440587.5, literal_execute=True)
"""The UNIX epoch, 1970-01-01T00:00:00.000000Z, expressed as a fractional Julian day number"""

SECONDS_PER_DAY = sa.literal(24 * 60 * 60.0, literal_execute=True)
"""The number of seconds in a day as a SQL literal, to convert fractional Julian days to seconds"""

_sqlite_now_constant = sa.literal("now", literal_execute=True)
"""The 'now' string constant, passed to SQLite datetime functions"""

_sqlite_strftime = partial(sa.func.strftime, SQLITE_DATETIME_FORMAT)
"""Format SQLite timestamp to a SQLAlchemy-compatible string"""


def _sqlite_strfinterval(
    offset: sa.ColumnElement[float],
) -> sa.ColumnElement[datetime.datetime]:
    """Format interval offset to a SQLAlchemy-compatible string"""
    return _sqlite_strftime(SQLITE_EPOCH_JULIANDAYNUMBER + offset)


def _sqlite_interval_offset(
    interval: _SQLExpressionOrLiteral[datetime.timedelta],
) -> sa.ColumnElement[float]:
    """Convert an interval value to a fractional Julian day number REAL offset from the UNIX epoch"""
    return sa.func.julianday(interval) - SQLITE_EPOCH_JULIANDAYNUMBER
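

# Worked example (illustrative): a one-hour interval binds as the timestamp
# 1970-01-01 01:00:00, whose Julian day number is 2440587.5416666...;
# subtracting SQLITE_EPOCH_JULIANDAYNUMBER leaves 0.0416666... = 1/24, i.e.
# one hour expressed as a fraction of a Julian day.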


@compiles(functions.now, "sqlite")
def current_timestamp_sqlite(
    element: functions.now, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Generates the current timestamp for SQLite"""
    return compiler.process(_sqlite_strftime(_sqlite_now_constant), **kwargs)


@compiles(date_add, "sqlite")
def date_add_sqlite(element: date_add, compiler: SQLCompiler, **kwargs: Any) -> str:
    dt, interval = element.clauses
    jdn, offset = sa.func.julianday(dt), _sqlite_interval_offset(interval)
    # dt + interval, as fractional Julian day number values
    return compiler.process(_sqlite_strftime(jdn + offset), **kwargs)


@compiles(interval_add, "sqlite")
def interval_add_sqlite(
    element: interval_add, compiler: SQLCompiler, **kwargs: Any
) -> str:
    offsets = map(_sqlite_interval_offset, element.clauses)
    # interval + interval, as fractional Julian day number values
    return compiler.process(_sqlite_strfinterval(operator.add(*offsets)), **kwargs)


@compiles(date_diff, "sqlite")
def date_diff_sqlite(element: date_diff, compiler: SQLCompiler, **kwargs: Any) -> str:
    jdns = map(sa.func.julianday, element.clauses)
    # timestamp - timestamp, as fractional Julian day number values
    return compiler.process(_sqlite_strfinterval(operator.sub(*jdns)), **kwargs)


@compiles(date_diff_seconds, "sqlite")
def date_diff_seconds_sqlite(
    element: date_diff_seconds, compiler: SQLCompiler, **kwargs: Any
) -> str:
    # either 1 or 2 timestamps; if 1, subtract from 'now'
    dts: list[sa.ColumnElement[Any]] = list(element.clauses)
    if len(dts) == 1:
        dts = [_sqlite_now_constant, *dts]
    as_jdn = (sa.func.julianday(dt) for dt in dts)
    # timestamp - timestamp, as a fractional Julian day number, times the number of seconds in a day
    return compiler.process(operator.sub(*as_jdn) * SECONDS_PER_DAY, **kwargs)
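

# To inspect what these functions compile to per dialect (illustrative; both
# dialect modules are imported at the top of this file):
#
#     expr = date_add(
#         datetime.datetime.now(datetime.timezone.utc), datetime.timedelta(days=1)
#     )
#     print(expr.compile(dialect=sqlite.dialect()))
#     print(expr.compile(dialect=postgresql.dialect()))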


# PostgreSQL JSON(B) Comparator operators ported to SQLite


def _is_literal(elem: Any) -> bool:
    """Element is not a SQLAlchemy SQL construct"""
    # Copied from sqlalchemy.sql.coercions._is_literal
    return not (
        isinstance(elem, (sa.Visitable, schema.SchemaEventTarget))
        or hasattr(elem, "__clause_element__")
    )


def _postgresql_array_to_json_array(
    elem: sa.ColumnElement[Any],
) -> sa.ColumnElement[Any]:
    """Replace any postgresql array() literals with a json_array() function call

    Because an _empty_ array leads to a PostgreSQL error, array() is often
    coupled with a cast(); this function replaces arrays with or without
    such a cast.

    This allows us to map the postgres JSONB.has_any / JSONB.has_all operand to
    SQLite.

    Returns the updated expression.
    """

    def _replacer(element: Any, **kwargs: Any) -> Optional[Any]:
        # either array(...), or cast(array(...), ...)
        if isinstance(element, sa.Cast):
            element = element.clause
        if isinstance(element, postgresql.array):
            return sa.func.json_array(*element.clauses)
        return None

    opts: dict[str, Any] = {}
    return replacement_traverse(elem, opts, _replacer)


def _json_each(elem: sa.ColumnElement[Any]) -> sa.TableValuedAlias:
    """SQLite json_each() table-valued construct

    Configures a SQLAlchemy table-valued object with the minimum
    column definitions and correct configuration.
    """
    return sa.func.json_each(elem).table_valued("key", "value", joins_implicitly=True)


# sqlite JSON operator implementations.


def _sqlite_json_astext(
    element: sa.BinaryExpression[Any],
) -> sa.BinaryExpression[Any]:
    """Map postgres JSON.astext / JSONB.astext (`->>`) to sqlite json_extract()

    Without the `as_string()` call, SQLAlchemy outputs json_quote(json_extract(...))
    """
    return element.left[element.right].as_string()


def _sqlite_json_contains(
    element: sa.BinaryExpression[bool],
) -> sa.ColumnElement[bool]:
    """Map JSONB.contains() and JSONB.has_all() to a SQLite expression"""
    # left can be a JSON value as a (Python) literal, or a SQL expression for a JSON value
    # right can be a SQLA postgresql.array() literal or a SQL expression for a
    # JSON array (for .has_all()) or it can be a JSON value as a (Python)
    # literal or a SQL expression for a JSON object (for .contains())
    left, right = element.left, element.right

    # if either top-level operand is literal, convert to a JSON bindparam
    if _is_literal(left):
        left = sa.bindparam("haystack", left, expanding=True, type_=JSON)
    if _is_literal(right):
        right = sa.bindparam("needles", right, expanding=True, type_=JSON)
    else:
        # convert the array() literal used in JSONB.has_all() to a JSON array.
        right = _postgresql_array_to_json_array(right)

    jleft, jright = _json_each(left), _json_each(right)

    # compute containment by counting the number of distinct matches between
    # the left items and the right items (e.g. the number of rows resulting
    # from a join) and seeing if it reaches the number of distinct keys in the
    # right operand.
    #
    # note that using distinct emulates postgres behavior to disregard duplicates
    distinct_matches = (
        sa.select(sa.func.count(sa.distinct(jleft.c.value)))
        .join(jright, onclause=jleft.c.value == jright.c.value)
        .scalar_subquery()
    )

    distinct_keys = sa.select(
        sa.func.count(sa.distinct(jright.c.value))
    ).scalar_subquery()

    return distinct_matches >= distinct_keys


def _sqlite_json_has_any(element: sa.BinaryExpression[bool]) -> sa.ColumnElement[bool]:
    """Map JSONB.has_any() to a SQLite expression"""
    # left can be a JSON value as a (Python) literal, or a SQL expression for a JSON value
    # right can be a SQLA postgresql.array() literal or a SQL expression for a JSON array
    left, right = element.left, element.right

    # convert the array() literal used in JSONB.has_any() to a JSON array.
    right = _postgresql_array_to_json_array(right)

    jleft, jright = _json_each(left), _json_each(right)

    # deal with "json array ?| [value, ...]" vs "json object ?| [key, ...]" tests
    # if left is a JSON object, match keys, else match values; the latter works
    # for arrays and all JSON scalar types
    json_object = sa.literal("object", literal_execute=True)
    left_elem = sa.case(
        (sa.func.json_type(element.left) == json_object, jleft.c.key),
        else_=jleft.c.value,
    )

    return sa.exists().where(left_elem == jright.c.value)


# Map of SQLA postgresql JSON/JSONB operators to functions that rewrite a
# BinaryExpression using such an operator into its SQLite equivalent.
_sqlite_json_operator_map: dict[
    OperatorType, Callable[[sa.BinaryExpression[Any]], sa.ColumnElement[Any]]
] = {
    ASTEXT: _sqlite_json_astext,
    CONTAINS: _sqlite_json_contains,
    HAS_ALL: _sqlite_json_contains,  # "has all" is equivalent to "contains"
    HAS_ANY: _sqlite_json_has_any,
}


@compiles(sa.BinaryExpression, "sqlite")
def sqlite_json_operators(
    element: sa.BinaryExpression[Any],
    compiler: SQLCompiler,
    override_operator: Optional[OperatorType] = None,
    **kwargs: Any,
) -> str:
    """Intercept the PostgreSQL-only JSON / JSONB operators and translate them to SQLite"""
    operator = override_operator or element.operator
    if (handler := _sqlite_json_operator_map.get(operator)) is not None:
        return compiler.process(handler(element), **kwargs)
    # ignore reason: SQLA compilation hooks are not as well covered with type annotations
    return compiler.visit_binary(element, override_operator=operator, **kwargs)  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
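

# Illustrative queries (with a hypothetical `Run` model whose `tags` column
# uses the JSON type above, so its comparator is postgresql.JSONB): thanks to
# the hook above, these PostgreSQL-style expressions also compile on SQLite:
#
#     Run.tags.contains(["prod"])                     # JSONB `@>` (CONTAINS)
#     Run.tags.has_any(postgresql.array(["a", "b"]))  # JSONB `?|` (HAS_ANY)
#     Run.tags["env"].astext == "prod"                # JSONB `->>` (ASTEXT)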


class greatest(functions.ReturnTypeFromArgs[T]):
    inherit_cache: bool = True


@compiles(greatest, "sqlite")
def sqlite_greatest_as_max(
    element: greatest[Any], compiler: SQLCompiler, **kwargs: Any
) -> str:
    # TODO: SQLite MAX() is very close to PostgreSQL GREATEST(), *except* when
    # it comes to nulls: SQLite MAX() returns NULL if _any_ clause is NULL,
    # whereas PostgreSQL GREATEST() only returns NULL if _all_ clauses are NULL.
    #
    # A work-around is to use MAX() as an aggregate function instead, in a
    # subquery. This, however, would probably require a VALUES-like construct
    # that SQLA doesn't currently support for SQLite. You can [provide
    # compilation hooks for
    # this](https://github.com/sqlalchemy/sqlalchemy/issues/7228#issuecomment-1746837960)
    # but this would only be worth it if sa.func.greatest() starts being used on
    # values that include NULLs. Up until the time of this comment this hasn't
    # been an issue.
    return compiler.process(sa.func.max(*element.clauses), **kwargs)


class least(functions.ReturnTypeFromArgs[T]):
    inherit_cache: bool = True


@compiles(least, "sqlite")
def sqlite_least_as_min(
    element: least[Any], compiler: SQLCompiler, **kwargs: Any
) -> str:
    # SQLite doesn't have LEAST(), use MIN() instead.
    # Note: Like MAX(), SQLite MIN() returns NULL if _any_ clause is NULL,
    # whereas PostgreSQL LEAST() only returns NULL if _all_ clauses are NULL.
    return compiler.process(sa.func.min(*element.clauses), **kwargs)
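

# Illustrative usage (with the hypothetical `Run` model from earlier
# comments): as GenericFunction subclasses, these register under their class
# names, so `sa.func.greatest(...)` resolves to the class above and compiles
# to GREATEST(...) on PostgreSQL but MAX(...) on SQLite:
#
#     sa.select(greatest(Run.created, Run.updated))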


def get_dialect(obj: Union[str, Session, sa.Engine]) -> type[sa.Dialect]:
    """
    Get the dialect of a session, engine, or connection url.

    Primary use case is figuring out whether the Prefect REST API is communicating with
    SQLite or Postgres.

    Example:
        ```python
        from prefect.settings import PREFECT_API_DATABASE_CONNECTION_URL
        from prefect.server.utilities.database import get_dialect

        dialect = get_dialect(PREFECT_API_DATABASE_CONNECTION_URL.value())
        if dialect.name == "sqlite":
            print("Using SQLite!")
        else:
            print("Using Postgres!")
        ```
    """
    if isinstance(obj, Session):
        assert obj.bind is not None
        obj = obj.bind.engine if isinstance(obj.bind, sa.Connection) else obj.bind

    if isinstance(obj, sa.engine.Engine):
        url = obj.url
    else:
        url = sa.engine.url.make_url(obj)

    return url.get_dialect()