Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/utilities/database.py: 82%

258 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Utilities for interacting with Prefect REST API database and ORM layer. 

3 

4Prefect supports both SQLite and Postgres. Many of these utilities 

5allow the Prefect REST API to seamlessly switch between the two. 

6""" 

7 

8from __future__ import annotations 1a

9 

10import datetime 1a

11import json 1a

12import operator 1a

13import re 1a

14import uuid 1a

15from functools import partial 1a

16from typing import ( 1a

17 TYPE_CHECKING, 

18 Any, 

19 Callable, 

20 Optional, 

21 Type, 

22 Union, 

23 overload, 

24) 

25from zoneinfo import ZoneInfo 1a

26 

27import pydantic 1a

28import sqlalchemy as sa 1a

29from sqlalchemy.dialects import postgresql, sqlite 1a

30from sqlalchemy.dialects.postgresql.operators import ( 1a

31 # these are all incompletely annotated 

32 ASTEXT, # type: ignore 

33 CONTAINS, # type: ignore 

34 HAS_ALL, # type: ignore 

35 HAS_ANY, # type: ignore 

36) 

37from sqlalchemy.ext.compiler import compiles 1a

38from sqlalchemy.orm import Session 1a

39from sqlalchemy.sql import functions, schema 1a

40from sqlalchemy.sql.compiler import SQLCompiler 1a

41from sqlalchemy.sql.operators import OperatorType 1a

42from sqlalchemy.sql.visitors import replacement_traverse 1a

43from sqlalchemy.types import CHAR, TypeDecorator, TypeEngine 1a

44from typing_extensions import ( 1a

45 Concatenate, 

46 ParamSpec, 

47 TypeAlias, 

48 TypeVar, 

49) 

50 

51from prefect.types._datetime import DateTime 1a

52 

# Generic parameters shared by the typing helpers in this module.
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)
T = TypeVar("T", infer_variance=True)

# Either a SQL column expression producing a T, or a plain Python T.
_SQLExpressionOrLiteral: TypeAlias = Union[sa.SQLColumnExpression[T], T]
# Callable shapes used by the `db_injector` overloads: the `_DB*` variants
# take a `PrefectDBInterface` argument, the plain variants do not.
_Function = Callable[P, R]
_Method = Callable[Concatenate[T, P], R]
_DBFunction: TypeAlias = Callable[Concatenate["PrefectDBInterface", P], R]
_DBMethod: TypeAlias = Callable[Concatenate[T, "PrefectDBInterface", P], R]

# Zero-width pattern matching the position before every uppercase letter,
# except at the very start of the string. Presumably used with `.sub("_", ...)`
# to convert CamelCase names to snake_case -- confirm against callers.
CAMEL_TO_SNAKE: re.Pattern[str] = re.compile(r"(?<!^)(?=[A-Z])")

64 

65if TYPE_CHECKING: 65 ↛ 66line 65 didn't jump to line 66 because the condition on line 65 was never true1a

66 from prefect.server.database.interface import PrefectDBInterface 

67 

68 

@overload
def db_injector(func: _DBMethod[T, P, R]) -> _Method[T, P, R]: ...


@overload
def db_injector(func: _DBFunction[P, R]) -> _Function[P, R]: ...


def db_injector(
    func: Union[_DBMethod[T, P, R], _DBFunction[P, R]],
) -> Union[_Method[T, P, R], _Function[P, R]]:
    """Delegate to `prefect.server.database.db_injector`.

    The overloads describe the typing contract: the `PrefectDBInterface`
    parameter of `func` is removed from the signature of the returned callable.
    """
    # Imported at call time rather than at module import time; presumably to
    # avoid a circular import with `prefect.server.database` -- confirm.
    from prefect.server.database import db_injector as _real_db_injector

    return _real_db_injector(func)

83 

84 

class GenerateUUID(functions.FunctionElement[uuid.UUID]):
    """
    Platform-independent UUID default generator.

    This element carries no SQL of its own; the SQL it renders to is
    provided per dialect by the `compiles`-decorated functions registered
    for this class (PostgreSQL and SQLite).
    """

    # Function name attached to the element.
    name = "uuid_default"

93 

94 

@compiles(GenerateUUID, "postgresql")
def generate_uuid_postgresql(
    element: GenerateUUID, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """
    Render `GenerateUUID` for PostgreSQL as a call to `GEN_RANDOM_UUID()`.

    NOTE(review): on PostgreSQL versions before 13 this function is provided
    by the pgcrypto extension; newer versions ship it in core -- confirm the
    minimum supported server version.

    Neither `element`, `compiler` nor `kwargs` influence the output.
    """

    return "(GEN_RANDOM_UUID())"

104 

105 

@compiles(GenerateUUID, "sqlite")
def generate_uuid_sqlite(
    element: GenerateUUID, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """
    Render `GenerateUUID` for SQLite.

    SQLite has no UUID generator, so the expression concatenates lowercase
    hex output of `randomblob()` into the 8-4-4-4-12 layout, forcing the
    version digit to `4` and picking the variant digit from `89ab`. This is
    sufficient for our purposes of having a random ID that is compatible
    with the UUID spec.

    Neither `element`, `compiler` nor `kwargs` influence the output.
    """

    return """
    (
        lower(hex(randomblob(4)))
        || '-'
        || lower(hex(randomblob(2)))
        || '-4'
        || substr(lower(hex(randomblob(2))),2)
        || '-'
        || substr('89ab',abs(random()) % 4 + 1, 1)
        || substr(lower(hex(randomblob(2))),2)
        || '-'
        || lower(hex(randomblob(6)))
    )
    """

131 

132 

class Timestamp(TypeDecorator[datetime.datetime]):
    """TypeDecorator that ensures that timestamps have a timezone.

    Naive datetimes are rejected when binding parameters. For SQLite, bound
    values are converted to UTC before storage (SQLite stores naive
    timestamps). Values read back are always returned as UTC-aware datetimes.
    """

    impl: TypeEngine[Any] | type[TypeEngine[Any]] = sa.TIMESTAMP(timezone=True)
    cache_ok: bool | None = True

    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
        """Select the concrete column type for the given dialect."""
        concrete: TypeEngine[Any]
        if dialect.name == "postgresql":
            concrete = postgresql.TIMESTAMP(timezone=True)
        elif dialect.name == "sqlite":
            # see the sqlite.DATETIME docstring on the particulars of the storage
            # format. Note that the sqlite implementations for timestamp and interval
            # arithmetic in this module would require updating if a different
            # format was to be configured here.
            concrete = sqlite.DATETIME()
        else:
            concrete = sa.TIMESTAMP(timezone=True)
        return dialect.type_descriptor(concrete)

    def process_bind_param(
        self,
        value: Optional[datetime.datetime],
        dialect: sa.Dialect,
    ) -> Optional[datetime.datetime]:
        """Validate (and, for SQLite, normalise to UTC) an outgoing timestamp.

        Raises:
            ValueError: if `value` is a naive datetime.
        """
        if value is None:
            return None
        if value.tzinfo is None:
            raise ValueError("Timestamps must have a timezone.")
        if dialect.name == "sqlite":
            return value.astimezone(ZoneInfo("UTC"))
        return value

    def process_result_value(
        self,
        value: Optional[datetime.datetime],
        dialect: sa.Dialect,
    ) -> Optional[datetime.datetime]:
        """Return a loaded timestamp as a UTC-aware datetime (or None)."""
        if value is None:
            return None
        utc = ZoneInfo("UTC")
        if value.tzinfo is None:
            # naive values are interpreted as already being in UTC
            return value.replace(tzinfo=utc)
        return value.astimezone(utc)

181 

182 

class UUID(TypeDecorator[uuid.UUID]):
    """
    Platform-independent UUID type.

    On PostgreSQL the native UUID column type is used; every other dialect
    gets a CHAR(36) column holding the hyphenated hex string.
    """

    impl: type[TypeEngine[Any]] | TypeEngine[Any] = TypeEngine
    cache_ok: bool | None = True

    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
        """Select the concrete column type for the given dialect."""
        column_type: TypeEngine[Any] = (
            postgresql.UUID() if dialect.name == "postgresql" else CHAR(36)
        )
        return dialect.type_descriptor(column_type)

    def process_bind_param(
        self, value: Optional[Union[str, uuid.UUID]], dialect: sa.Dialect
    ) -> Optional[str]:
        """Convert an outgoing UUID (or UUID string) to its string form."""
        if value is None:
            return None
        if dialect.name == "postgresql" or isinstance(value, uuid.UUID):
            return str(value)
        # non-PostgreSQL string input: parse it first so the stored text is
        # the canonical hyphenated form (and invalid input raises)
        return str(uuid.UUID(value))

    def process_result_value(
        self, value: Optional[Union[str, uuid.UUID]], dialect: sa.Dialect
    ) -> Optional[uuid.UUID]:
        """Convert a loaded value to `uuid.UUID`; None passes through."""
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)

222 

223 

class JSON(TypeDecorator[Any]):
    """
    JSON type that returns SQLAlchemy's dialect-specific JSON types, where
    possible. Uses generic JSON otherwise.

    The "base" type is postgresql.JSONB to expose useful methods prior
    to SQL compilation
    """

    impl: type[postgresql.JSONB] | type[TypeEngine[Any]] | TypeEngine[Any] = (
        postgresql.JSONB
    )
    cache_ok: bool | None = True

    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
        """Select the concrete JSON column type for the given dialect.

        All variants are configured with `none_as_null=True`.
        """
        json_type: TypeEngine[Any]
        if dialect.name == "postgresql":
            json_type = postgresql.JSONB(none_as_null=True)
        elif dialect.name == "sqlite":
            json_type = sqlite.JSON(none_as_null=True)
        else:
            json_type = sa.JSON(none_as_null=True)
        return dialect.type_descriptor(json_type)

    def process_bind_param(
        self, value: Optional[Any], dialect: sa.Dialect
    ) -> Optional[Any]:
        """Prepares the given value to be used as a JSON field in a parameter binding

        Falsy values (None, empty containers, 0, ...) are returned unchanged.
        """
        if not value:
            return value

        # PostgreSQL does not support the floating point extrema values `NaN`,
        # `-Infinity`, or `Infinity`
        # https://www.postgresql.org/docs/current/datatype-json.html#JSON-TYPE-MAPPING-TABLE
        #
        # SQLite supports storing and retrieving full JSON values that include
        # `NaN`, `-Infinity`, or `Infinity`, but any query that requires SQLite to parse
        # the value (like `json_extract`) will fail.
        #
        # Round-trip the value through the json module so that any `NaN`,
        # `-Infinity`, or `Infinity` is replaced with `None`. See more about
        # `parse_constant` at
        # https://docs.python.org/3/library/json.html#json.load.
        serialized = json.dumps(value)
        return json.loads(serialized, parse_constant=lambda _constant: None)

265 

266 

class Pydantic(TypeDecorator[T]):
    """
    A pydantic type that converts inserted parameters to
    JSON-compatible Python objects and converts read values to the
    pydantic type.

    Args:
        pydantic_type: the type (or typing special form) values are
            validated against, via a `pydantic.TypeAdapter`.
        sa_column_type: optional replacement for the default `JSON`
            implementation type.
    """

    impl = JSON
    cache_ok: bool | None = True

    @overload
    def __init__(
        self,
        pydantic_type: type[T],
        sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None,
    ) -> None: ...

    # This overload is needed to allow for typing special forms (e.g.
    # Union[...], etc.) as these can't be married with `type[...]`. Also see
    # https://github.com/pydantic/pydantic/pull/8923
    @overload
    def __init__(
        self: "Pydantic[Any]",
        pydantic_type: Any,
        sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None,
    ) -> None: ...

    def __init__(
        self,
        pydantic_type: type[T],
        sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None,
    ) -> None:
        super().__init__()
        self._pydantic_type = pydantic_type
        self._adapter = pydantic.TypeAdapter(self._pydantic_type)
        if sa_column_type is not None:
            self.impl: type[JSON] | type[TypeEngine[Any]] | TypeEngine[Any] = (
                sa_column_type
            )

    def process_bind_param(
        self, value: Optional[T], dialect: sa.Dialect
    ) -> Optional[Any]:
        """Validate `value` and convert it to JSON-compatible Python objects.

        Returns None for None; otherwise the result of dumping the validated
        value in pydantic's "json" mode (not a JSON string).
        """
        if value is None:
            return None

        validated = self._adapter.validate_python(value)

        # sqlalchemy requires the bind parameter's value to be a python-native
        # collection of JSON-compatible objects. we achieve that by asking
        # pydantic to dump the value in "json" mode, which yields python-native
        # objects restricted to JSON-compatible types.
        return self._adapter.dump_python(validated, mode="json")

    def process_result_value(
        self, value: Optional[Any], dialect: sa.Dialect
    ) -> Optional[T]:
        """Validate a loaded JSON value into the pydantic type; None passes through."""
        if value is None:
            return None
        return self._adapter.validate_python(value)

325 

326 

def bindparams_from_clause(
    query: sa.ClauseElement,
) -> dict[str, sa.BindParameter[Any]]:
    """Retrieve all non-anonymous bind parameters defined in a SQL clause

    Returns a mapping of bind parameter key to bind parameter. If several
    parameters share a key, the last one encountered is kept.
    """
    # we could use `traverse(query, {}, {"bindparam": some_list.append})` too,
    # but this private method builds on the SQLA query caching infrastructure
    # and so is more efficient.
    named: dict[str, sa.BindParameter[Any]] = {}
    for param in query._get_embedded_bindparams():  # pyright: ignore[reportPrivateUsage]
        # Anonymous keys are always a printf-style template that starts with
        # '%([seed]'; the seed is the id() of the bind parameter itself.
        anonymous_prefix = f"%({id(param)}"
        if param.key.startswith(anonymous_prefix):
            continue
        named[param.key] = param
    return named

341 

342 

343# Platform-independent datetime and timedelta arithmetic functions 

344 

345 

class date_add(functions.GenericFunction[DateTime]):
    """Platform-independent way to add a timestamp and an interval

    The result type is `Timestamp`; dialect-specific SQL is supplied by the
    `compiles` hooks registered for this class.
    """

    type: Timestamp = Timestamp()
    inherit_cache: bool = True

    def __init__(
        self,
        dt: _SQLExpressionOrLiteral[datetime.datetime],
        interval: _SQLExpressionOrLiteral[datetime.timedelta],
        **kwargs: Any,
    ):
        # coerce both operands so literals are bound with the right SQL types
        timestamp_operand = sa.type_coerce(dt, Timestamp())
        interval_operand = sa.type_coerce(interval, sa.Interval())
        super().__init__(timestamp_operand, interval_operand, **kwargs)

363 

364 

class interval_add(functions.GenericFunction[datetime.timedelta]):
    """Platform-independent way to add two intervals.

    The result type is `sa.Interval`; dialect-specific SQL is supplied by the
    `compiles` hooks registered for this class.
    """

    type: sa.Interval = sa.Interval()
    inherit_cache: bool = True

    def __init__(
        self,
        i1: _SQLExpressionOrLiteral[datetime.timedelta],
        i2: _SQLExpressionOrLiteral[datetime.timedelta],
        **kwargs: Any,
    ):
        # coerce both operands so literals are bound as intervals
        first = sa.type_coerce(i1, sa.Interval())
        second = sa.type_coerce(i2, sa.Interval())
        super().__init__(first, second, **kwargs)

382 

383 

class date_diff(functions.GenericFunction[datetime.timedelta]):
    """Platform-independent difference of two timestamps. Computes d1 - d2.

    The result type is `sa.Interval`; dialect-specific SQL is supplied by the
    `compiles` hooks registered for this class.
    """

    type: sa.Interval = sa.Interval()
    inherit_cache: bool = True

    def __init__(
        self,
        d1: _SQLExpressionOrLiteral[datetime.datetime],
        d2: _SQLExpressionOrLiteral[datetime.datetime],
        **kwargs: Any,
    ) -> None:
        # coerce both operands so literals are bound as timestamps
        minuend = sa.type_coerce(d1, Timestamp())
        subtrahend = sa.type_coerce(d2, Timestamp())
        super().__init__(minuend, subtrahend, **kwargs)

399 

400 

class date_diff_seconds(functions.GenericFunction[float]):
    """Platform-independent calculation of the number of seconds between two timestamps or from 'now'

    With a single timestamp, the dialect-specific compilers subtract it from
    the current time; with two, they compute `dt1 - dt2`.
    """

    type: Type[sa.REAL[float]] = sa.REAL
    inherit_cache: bool = True

    def __init__(
        self,
        dt1: _SQLExpressionOrLiteral[datetime.datetime],
        dt2: Optional[_SQLExpressionOrLiteral[datetime.datetime]] = None,
        **kwargs: Any,
    ) -> None:
        # the second timestamp is only passed on as a clause when provided
        operands = [sa.type_coerce(dt1, Timestamp())]
        if dt2 is not None:
            operands.append(sa.type_coerce(dt2, Timestamp()))
        super().__init__(*operands, **kwargs)

417 

418 

419# timestamp and interval arithmetic implementations for PostgreSQL 

420 

421 

@compiles(date_add, "postgresql")
@compiles(interval_add, "postgresql")
@compiles(date_diff, "postgresql")
def datetime_or_interval_add_postgresql(
    element: Union[date_add, interval_add, date_diff],
    compiler: SQLCompiler,
    **kwargs: Any,
) -> str:
    """Compile `date_add`, `interval_add` and `date_diff` for PostgreSQL.

    The element's clauses are combined with the native operators:
    subtraction for `date_diff`, addition for the other two.
    """
    if isinstance(element, date_diff):
        combine = operator.sub
    else:
        combine = operator.add
    expression = combine(*element.clauses)
    return compiler.process(expression, **kwargs)

432 

433 

@compiles(date_diff_seconds, "postgresql")
def date_diff_seconds_postgresql(
    element: date_diff_seconds, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Compile `date_diff_seconds` for PostgreSQL.

    Renders `extract('epoch', a - b)` with both operands passed through
    `timezone('UTC', ...)`.
    """
    # either 1 or 2 timestamps; if 1, subtract from 'now'
    operands: list[sa.ColumnElement[datetime.datetime]] = list(element.clauses)
    if len(operands) == 1:
        operands.insert(0, sa.func.now())
    utc_operands = [sa.func.timezone("UTC", operand) for operand in operands]
    difference = operator.sub(*utc_operands)
    return compiler.process(sa.func.extract("epoch", difference), **kwargs)

444 

445 

446# SQLite implementations for the Timestamp and Interval arithmetic functions. 

447# 

448# The following concepts are at play here: 

449# 

450# - By default, SQLAlchemy stores Timestamp values formatted as ISO8601 strings 

451# (with a space between the date and the time parts), with microsecond precision. 

452# - SQLAlchemy stores Interval values as a Timestamp, offset from the UNIX epoch. 

453# - SQLite processes timestamp values with _at most_ millisecond precision, and 

454# only if you use the `juliandate()` function or the 'subsec' modifier for 

455# the `unixepoch()` function (the latter requires SQLite 3.42.0, released 

456# 2023-05-16) 

457# 

458# In order for arithmetic to work well, you need to convert timestamps to 

459# fractional [Julian day numbers][JDN], and intervals to a real number 

460# by subtracting the UNIX epoch from their Julian day number representation. 

461# 

462# Once the result has been computed, the result needs to be converted back 

463# to an ISO8601 formatted string including any milliseconds. For an 

464# interval result, that means adding the UNIX epoch offset to it first. 

465# 

466# [JDN]: https://en.wikipedia.org/wiki/Julian_day 

467 

468# SQLite strftime() format to output ISO8601 date and time with milliseconds 

469# This format must be parseable by the `datetime.fromisoformat()` function, 

470# or, if the SQLite implementation for Timestamp above is configured with a 

471# regex, it must match that regex. 

472# 

473# SQLite only provides millisecond precision, but past versions of SQLAlchemy 

474# defaulted to parsing with a regex that would treat fractional as a value in 

475# microseconds. To ensure maximum compatibility the current format should 

476# continue to format the fractional seconds as microseconds, so 6 digits. 

# `%f` renders seconds with a millisecond fraction; the literal `000` suffix
# pads the fraction to 6 digits so it reads as microseconds.
SQLITE_DATETIME_FORMAT = sa.literal("%Y-%m-%d %H:%M:%f000", literal_execute=True)
"""The SQLite timestamp output format as a SQL literal string constant"""


# All constants here use `literal_execute=True`, so they are rendered inline
# in the SQL text at execution time instead of being sent as bound parameters.
SQLITE_EPOCH_JULIANDAYNUMBER = sa.literal(2440587.5, literal_execute=True)
"""The UNIX epoch, 1970-01-01T00:00:00.000000Z, expressed as a fractional Julain day number"""
SECONDS_PER_DAY = sa.literal(24 * 60 * 60.0, literal_execute=True)
"""The number of seconds in a day as a SQL literal, to convert fractional Julian days to seconds"""

_sqlite_now_constant = sa.literal("now", literal_execute=True)
"""The 'now' string constant, passed to SQLite datetime functions"""
# strftime() call with the output format pre-bound as the first argument.
_sqlite_strftime = partial(sa.func.strftime, SQLITE_DATETIME_FORMAT)
"""Format SQLite timestamp to a SQLAlchemy-compatible string"""

490 

491 

def _sqlite_strfinterval(
    offset: sa.ColumnElement[float],
) -> sa.ColumnElement[datetime.datetime]:
    """Format an interval offset as a SQLAlchemy-compatible string.

    The offset (in fractional days) is added to the UNIX epoch's Julian day
    number and the sum is formatted as a timestamp string.
    """
    julian_day = SQLITE_EPOCH_JULIANDAYNUMBER + offset
    return _sqlite_strftime(julian_day)

497 

498 

def _sqlite_interval_offset(
    interval: _SQLExpressionOrLiteral[datetime.timedelta],
) -> sa.ColumnElement[float]:
    """Convert an interval value to a fractional Julian day REAL offset from the UNIX epoch"""
    interval_julian_day = sa.func.julianday(interval)
    return interval_julian_day - SQLITE_EPOCH_JULIANDAYNUMBER

504 

505 

@compiles(functions.now, "sqlite")
def current_timestamp_sqlite(
    element: functions.now, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Generates the current timestamp for SQLite

    Renders `now()` as a `strftime()` call on the 'now' constant, using this
    module's timestamp output format.
    """
    formatted_now = _sqlite_strftime(_sqlite_now_constant)
    return compiler.process(formatted_now, **kwargs)

512 

513 

@compiles(date_add, "sqlite")
def date_add_sqlite(element: date_add, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Compile `date_add` for SQLite using Julian day arithmetic."""
    dt, interval = element.clauses
    # dt + interval, as fractional Julian day number values
    result_julian_day = sa.func.julianday(dt) + _sqlite_interval_offset(interval)
    return compiler.process(_sqlite_strftime(result_julian_day), **kwargs)

520 

521 

@compiles(interval_add, "sqlite")
def interval_add_sqlite(
    element: interval_add, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Compile `interval_add` for SQLite using Julian day arithmetic."""
    day_offsets = [_sqlite_interval_offset(clause) for clause in element.clauses]
    # interval + interval, as fractional Julian day number values
    total_offset = operator.add(*day_offsets)
    return compiler.process(_sqlite_strfinterval(total_offset), **kwargs)

529 

530 

@compiles(date_diff, "sqlite")
def date_diff_sqlite(element: date_diff, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Compile `date_diff` for SQLite using Julian day arithmetic."""
    julian_days = [sa.func.julianday(clause) for clause in element.clauses]
    # timestamp - timestamp, as fractional Julian day number values
    day_difference = operator.sub(*julian_days)
    return compiler.process(_sqlite_strfinterval(day_difference), **kwargs)

536 

537 

@compiles(date_diff_seconds, "sqlite")
def date_diff_seconds_sqlite(
    element: date_diff_seconds, compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Compile `date_diff_seconds` for SQLite.

    The difference of the operands' Julian day numbers is multiplied by the
    number of seconds in a day.
    """
    # either 1 or 2 timestamps; if 1, subtract from 'now'
    operands: list[sa.ColumnElement[Any]] = list(element.clauses)
    if len(operands) == 1:
        operands.insert(0, _sqlite_now_constant)
    julian_days = [sa.func.julianday(operand) for operand in operands]
    # timestamp - timestamp, as a fractional Julian day number, times the number of seconds in a day
    elapsed_days = operator.sub(*julian_days)
    return compiler.process(elapsed_days * SECONDS_PER_DAY, **kwargs)

549 

550 

551# PostgreSQL JSON(B) Comparator operators ported to SQLite 

552 

553 

def _is_literal(elem: Any) -> bool:
    """Return True when `elem` is not a SQLAlchemy SQL construct"""
    # Copied from sqlalchemy.sql.coercions._is_literal
    if isinstance(elem, (sa.Visitable, schema.SchemaEventTarget)):
        return False
    # objects exposing `__clause_element__` can be turned into SQL constructs
    return not hasattr(elem, "__clause_element__")

561 

562 

def _postgresql_array_to_json_array(
    elem: sa.ColumnElement[Any],
) -> sa.ColumnElement[Any]:
    """Replace any postgresql array() literals with a json_array() function call

    Because an _empty_ array leads to a PostgreSQL error, array() is often
    coupled with a cast(); this function replaces arrays with or without
    such a cast.

    This allows us to map the postgres JSONB.has_any / JSONB.has_all operand to
    SQLite.

    Returns the updated expression.

    """

    def _array_to_json_array(node: Any, **kwargs: Any) -> Optional[Any]:
        # either array(...), or cast(array(...), ...); look through the cast
        candidate = node.clause if isinstance(node, sa.Cast) else node
        if not isinstance(candidate, postgresql.array):
            # returning None leaves this node for the traversal to handle
            return None
        return sa.func.json_array(*candidate.clauses)

    traverse_options: dict[str, Any] = {}
    return replacement_traverse(elem, traverse_options, _array_to_json_array)

589 

590 

def _json_each(elem: sa.ColumnElement[Any]) -> sa.TableValuedAlias:
    """SQLite json_each() table-valued construct

    Wraps `json_each(elem)` as a table-valued object exposing only the
    `key` and `value` columns, configured with `joins_implicitly=True`.

    """
    json_each_call = sa.func.json_each(elem)
    return json_each_call.table_valued("key", "value", joins_implicitly=True)

599 

600 

601# sqlite JSON operator implementations. 

602 

603 

604def _sqlite_json_astext( 1a

605 element: sa.BinaryExpression[Any], 

606) -> sa.BinaryExpression[Any]: 

607 """Map postgres JSON.astext / JSONB.astext (`->>`) to sqlite json_extract() 

608 

609 Without the `as_string()` call, SQLAlchemy outputs json_quote(json_extract(...)) 

610 

611 """ 

612 return element.left[element.right].as_string() 1b

613 

614 

def _sqlite_json_contains(
    element: sa.BinaryExpression[bool],
) -> sa.ColumnElement[bool]:
    """Map JSONB.contains() and JSONB.has_all() to a SQLite expression"""
    # the left operand (haystack) can be a JSON value as a (Python) literal, or
    # a SQL expression for a JSON value.
    # the right operand (needles) can be a SQLA postgresql.array() literal or a
    # SQL expression for a JSON array (for .has_all()) or it can be a JSON value
    # as a (Python) literal or a SQL expression for a JSON object (for
    # .contains())
    haystack, needles = element.left, element.right

    # if either top-level operand is literal, convert to a JSON bindparam
    if _is_literal(haystack):
        haystack = sa.bindparam("haystack", haystack, expanding=True, type_=JSON)
    if _is_literal(needles):
        needles = sa.bindparam("needles", needles, expanding=True, type_=JSON)
    else:
        # convert the array() literal used in JSONB.has_all() to a JSON array.
        needles = _postgresql_array_to_json_array(needles)

    haystack_rows = _json_each(haystack)
    needle_rows = _json_each(needles)

    # compute containment by counting the number of distinct matches between
    # the haystack items and the needle items (e.g. the number of rows
    # resulting from a join) and seeing if it is at least the number of
    # distinct values in the needles operand.
    #
    # note that using distinct emulates postgres behavior to disregard duplicates
    matched_count = (
        sa.select(sa.func.count(sa.distinct(haystack_rows.c.value)))
        .join(needle_rows, onclause=haystack_rows.c.value == needle_rows.c.value)
        .scalar_subquery()
    )
    needle_count = sa.select(
        sa.func.count(sa.distinct(needle_rows.c.value))
    ).scalar_subquery()

    return matched_count >= needle_count

652 

653 

def _sqlite_json_has_any(element: sa.BinaryExpression[bool]) -> sa.ColumnElement[bool]:
    """Map JSONB.has_any() to a SQLite expression"""
    # the left operand can be a JSON value as a (Python) literal, or a SQL
    # expression for a JSON value; the right operand can be a SQLA
    # postgresql.array() literal or a SQL expression for a JSON array
    haystack = element.left

    # convert the array() literal used in JSONB.has_any() to a JSON array.
    candidates = _postgresql_array_to_json_array(element.right)

    haystack_rows = _json_each(haystack)
    candidate_rows = _json_each(candidates)

    # deal with "json array ?| [value, ...]"" vs "json object ?| [key, ...]" tests
    # if left is a JSON object, match keys, else match values; the latter works
    # for arrays and all JSON scalar types
    object_type_name = sa.literal("object", literal_execute=True)
    is_json_object = sa.func.json_type(haystack) == object_type_name
    comparable = sa.case(
        (is_json_object, haystack_rows.c.key),
        else_=haystack_rows.c.value,
    )

    return sa.exists().where(comparable == candidate_rows.c.value)

675 

676 

677# Map of SQLA postgresql JSON/JSONB operators and a function to rewrite 

678# a BinaryExpression with such an operator to their SQLite equivalent. 

# Looked up by operator when compiling binary expressions for SQLite; each
# value receives the original BinaryExpression and returns the replacement
# expression to compile in its place.
_sqlite_json_operator_map: dict[
    OperatorType, Callable[[sa.BinaryExpression[Any]], sa.ColumnElement[Any]]
] = {
    ASTEXT: _sqlite_json_astext,
    CONTAINS: _sqlite_json_contains,
    HAS_ALL: _sqlite_json_contains,  # "has all" is equivalent to "contains"
    HAS_ANY: _sqlite_json_has_any,
}

687 

688 

@compiles(sa.BinaryExpression, "sqlite")
def sqlite_json_operators(
    element: sa.BinaryExpression[Any],
    compiler: SQLCompiler,
    override_operator: Optional[OperatorType] = None,
    **kwargs: Any,
) -> str:
    """Intercept the PostgreSQL-only JSON / JSONB operators and translate them to SQLite

    Expressions whose operator has no registered handler are compiled by the
    compiler's regular `visit_binary` implementation.
    """
    # named so as not to shadow the imported `operator` module
    effective_operator = override_operator or element.operator
    handler = _sqlite_json_operator_map.get(effective_operator)
    if handler is None:
        # ignore reason: SQLA compilation hooks are not as well covered with type annotations
        return compiler.visit_binary(element, override_operator=effective_operator, **kwargs)  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
    return compiler.process(handler(element), **kwargs)

702 

703 

class greatest(functions.ReturnTypeFromArgs[T]):
    """SQL `greatest()` function whose return type is taken from its arguments.

    A SQLite-specific `compiles` hook in this module renders it as `max()`.
    """

    inherit_cache: bool = True

706 

707 

@compiles(greatest, "sqlite")
def sqlite_greatest_as_max(
    element: greatest[Any], compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Compile `greatest(...)` for SQLite as a multi-argument `max(...)` call."""
    # TODO: SQLite MAX() is very close to PostgreSQL GREATEST(), *except* when
    # it comes to nulls: SQLite MAX() returns NULL if _any_ clause is NULL,
    # whereas PostgreSQL GREATEST() only returns NULL if _all_ clauses are NULL.
    #
    # A work-around is to use MAX() as an aggregate function instead, in a
    # subquery. This, however, would probably require a VALUES-like construct
    # that SQLA doesn't currently support for SQLite. You can [provide
    # compilation hooks for
    # this](https://github.com/sqlalchemy/sqlalchemy/issues/7228#issuecomment-1746837960)
    # but this would only be worth it if sa.func.greatest() starts being used on
    # values that include NULLs. Up until the time of this comment this hasn't
    # been an issue.
    max_call = sa.func.max(*element.clauses)
    return compiler.process(max_call, **kwargs)

725 

726 

class least(functions.ReturnTypeFromArgs[T]):
    """SQL `least()` function whose return type is taken from its arguments.

    A SQLite-specific `compiles` hook in this module renders it as `min()`.
    """

    inherit_cache: bool = True

729 

730 

@compiles(least, "sqlite")
def sqlite_least_as_min(
    element: least[Any], compiler: SQLCompiler, **kwargs: Any
) -> str:
    """Compile `least(...)` for SQLite as a multi-argument `min(...)` call."""
    # SQLite doesn't have LEAST(), use MIN() instead.
    # Note: Like MAX(), SQLite MIN() returns NULL if _any_ clause is NULL,
    # whereas PostgreSQL LEAST() only returns NULL if _all_ clauses are NULL.
    min_call = sa.func.min(*element.clauses)
    return compiler.process(min_call, **kwargs)

739 

740 

def get_dialect(obj: Union[str, Session, sa.Engine]) -> type[sa.Dialect]:
    """
    Get the dialect of a session, engine, or connection url.

    Primary use case is figuring out whether the Prefect REST API is communicating with
    SQLite or Postgres.

    Args:
        obj: a connection URL string, a `Session`, or an `Engine`. For a
            session, the engine behind its bind is used.

    Returns:
        The dialect class for the resolved URL.

    Raises:
        ValueError: if `obj` is a session without a bind.

    Example:
        ```python
        from prefect.settings import PREFECT_API_DATABASE_CONNECTION_URL
        from prefect.server.utilities.database import get_dialect

        dialect = get_dialect(PREFECT_API_DATABASE_CONNECTION_URL.value())
        if dialect.name == "sqlite":
            print("Using SQLite!")
        else:
            print("Using Postgres!")
        ```
    """
    if isinstance(obj, Session):
        bind = obj.bind
        # explicit check rather than `assert`, which is stripped under `-O`
        if bind is None:
            raise ValueError(
                "Cannot determine the dialect of a session that has no bind."
            )
        # a session may be bound to a connection; use that connection's engine
        obj = bind.engine if isinstance(bind, sa.Connection) else bind

    if isinstance(obj, sa.engine.Engine):
        url = obj.url
    else:
        url = sa.engine.url.make_url(obj)

    return url.get_dialect()