Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/database/configurations.py: 52%

217 statements  


from __future__ import annotations

import sqlite3
import ssl
import traceback
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop, get_running_loop
from collections.abc import AsyncGenerator, Hashable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from contextvars import ContextVar
from functools import partial
from typing import Any, Optional

import sqlalchemy as sa
from sqlalchemy import AdaptedConnection, event
from sqlalchemy.dialects.sqlite import aiosqlite
from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    AsyncSessionTransaction,
    create_async_engine,
)
from sqlalchemy.pool import ConnectionPoolEntry
from typing_extensions import TypeAlias

from prefect._internal.observability import configure_logfire
from prefect.settings import (
    PREFECT_API_DATABASE_CONNECTION_TIMEOUT,
    PREFECT_API_DATABASE_ECHO,
    PREFECT_API_DATABASE_TIMEOUT,
    PREFECT_TESTING_UNIT_TEST_MODE,
    get_current_settings,
)
from prefect.utilities.asyncutils import add_event_loop_shutdown_callback

logfire: Any | None = configure_logfire()

SQLITE_BEGIN_MODE: ContextVar[Optional[str]] = ContextVar(  # novm
    "SQLITE_BEGIN_MODE", default=None
)

_EngineCacheKey: TypeAlias = tuple[AbstractEventLoop, str, bool, Optional[float]]
ENGINES: dict[_EngineCacheKey, AsyncEngine] = {}
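
# Illustrative note (not part of the original module): engines are cached per running
# event loop, so a given configuration yields one AsyncEngine per loop. The cache key
# has the shape
#
#     cache_key = (get_running_loop(), connection_url, echo, timeout)
#
# A second task on the same loop reuses the cached engine; a new loop (common in
# tests) builds its own engine, which is disposed when that loop shuts down.
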

class ConnectionTracker:
    """A test utility which tracks the connections given out by a connection pool, to
    make it easy to see which connections are currently checked out and open."""

    all_connections: dict[AdaptedConnection, list[str]]
    open_connections: dict[AdaptedConnection, list[str]]
    left_field_closes: dict[AdaptedConnection, list[str]]
    connects: int
    closes: int
    active: bool

    def __init__(self) -> None:
        self.active = False
        self.all_connections = {}
        self.open_connections = {}
        self.left_field_closes = {}
        self.connects = 0
        self.closes = 0

    def track_pool(self, pool: sa.pool.Pool) -> None:
        event.listen(pool, "connect", self.on_connect)
        event.listen(pool, "close", self.on_close)
        event.listen(pool, "close_detached", self.on_close_detached)

    def on_connect(
        self,
        adapted_connection: AdaptedConnection,
        connection_record: ConnectionPoolEntry,
    ) -> None:
        self.all_connections[adapted_connection] = traceback.format_stack()
        self.open_connections[adapted_connection] = traceback.format_stack()
        self.connects += 1

    def on_close(
        self,
        adapted_connection: AdaptedConnection,
        connection_record: ConnectionPoolEntry,
    ) -> None:
        try:
            del self.open_connections[adapted_connection]
        except KeyError:
            self.left_field_closes[adapted_connection] = traceback.format_stack()
        self.closes += 1

    def on_close_detached(
        self,
        adapted_connection: AdaptedConnection,
    ) -> None:
        try:
            del self.open_connections[adapted_connection]
        except KeyError:
            self.left_field_closes[adapted_connection] = traceback.format_stack()
        self.closes += 1

    def clear(self) -> None:
        self.all_connections.clear()
        self.open_connections.clear()
        self.left_field_closes.clear()
        self.connects = 0
        self.closes = 0


TRACKER: ConnectionTracker = ConnectionTracker()
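
# Illustrative sketch (not part of the original module): how a test might use the
# module-level TRACKER to look for leaked connections. The in-memory SQLite
# configuration below is just one possible choice.
async def _example_connection_tracking() -> None:
    TRACKER.active = True  # must be enabled before the engine (and its pool) is created
    config = AioSqliteConfiguration(connection_url="sqlite+aiosqlite:///:memory:")
    engine = await config.engine()  # engine() calls TRACKER.track_pool(engine.pool)

    async with engine.connect():
        pass  # the first checkout creates a DBAPI connection, firing on_connect

    await engine.dispose()  # closing pooled connections fires on_close
    print(
        f"connects={TRACKER.connects}, closes={TRACKER.closes}, "
        f"still open={len(TRACKER.open_connections)}"
    )
    TRACKER.clear()
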

class BaseDatabaseConfiguration(ABC):
    """
    Abstract base class used to inject database connection configuration into Prefect.

    This configuration is responsible for defining how the Prefect REST API creates
    and manages database connections and sessions.
    """

    def __init__(
        self,
        connection_url: str,
        echo: Optional[bool] = None,
        timeout: Optional[float] = None,
        connection_timeout: Optional[float] = None,
        sqlalchemy_pool_size: Optional[int] = None,
        sqlalchemy_max_overflow: Optional[int] = None,
        connection_app_name: Optional[str] = None,
        statement_cache_size: Optional[int] = None,
        prepared_statement_cache_size: Optional[int] = None,
    ) -> None:
        self.connection_url = connection_url
        self.echo: bool = echo or PREFECT_API_DATABASE_ECHO.value()
        self.timeout: Optional[float] = timeout or PREFECT_API_DATABASE_TIMEOUT.value()
        self.connection_timeout: Optional[float] = (
            connection_timeout or PREFECT_API_DATABASE_CONNECTION_TIMEOUT.value()
        )
        self.sqlalchemy_pool_size: Optional[int] = (
            sqlalchemy_pool_size
            or get_current_settings().server.database.sqlalchemy.pool_size
        )
        self.sqlalchemy_max_overflow: Optional[int] = (
            sqlalchemy_max_overflow
            or get_current_settings().server.database.sqlalchemy.max_overflow
        )
        self.connection_app_name: Optional[str] = (
            connection_app_name
            or get_current_settings().server.database.sqlalchemy.connect_args.application_name
        )
        self.statement_cache_size: Optional[int] = (
            statement_cache_size
            or get_current_settings().server.database.sqlalchemy.connect_args.statement_cache_size
        )
        self.prepared_statement_cache_size: Optional[int] = (
            prepared_statement_cache_size
            or get_current_settings().server.database.sqlalchemy.connect_args.prepared_statement_cache_size
        )

    def unique_key(self) -> tuple[Hashable, ...]:
        """
        Returns a key used to determine whether to instantiate a new DB interface.
        """
        return (self.__class__, self.connection_url)

    @abstractmethod
    async def engine(self) -> AsyncEngine:
        """Returns a SQLAlchemy engine"""

    @abstractmethod
    async def session(self, engine: AsyncEngine) -> AsyncSession:
        """
        Retrieves a SQLAlchemy session for an engine.
        """

    @abstractmethod
    async def create_db(
        self, connection: AsyncConnection, base_metadata: sa.MetaData
    ) -> None:
        """Create the database"""

    @abstractmethod
    async def drop_db(
        self, connection: AsyncConnection, base_metadata: sa.MetaData
    ) -> None:
        """Drop the database"""

    @abstractmethod
    def is_inmemory(self) -> bool:
        """Returns true if database is run in memory"""

    @abstractmethod
    def begin_transaction(
        self, session: AsyncSession, with_for_update: bool = False
    ) -> AbstractAsyncContextManager[AsyncSessionTransaction]:
        """Enter a transaction for a session"""
        pass
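
# Illustrative sketch (not part of the original module): unique_key() lets callers keep
# one database interface per (configuration class, connection URL) pair. The cache and
# helper below are hypothetical.
_example_interfaces: dict[tuple[Hashable, ...], Any] = {}


def _example_get_interface(db_config: BaseDatabaseConfiguration) -> Any:
    key = db_config.unique_key()
    if key not in _example_interfaces:
        _example_interfaces[key] = object()  # stand-in for building a real DB interface
    return _example_interfaces[key]
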

class AsyncPostgresConfiguration(BaseDatabaseConfiguration):
    async def engine(self) -> AsyncEngine:
        """Retrieves an async SQLAlchemy engine.

        The connection string, echo flag, and statement timeout are taken from
        `self.connection_url`, `self.echo`, and `self.timeout`; engines are cached
        per event loop.

        Returns:
            AsyncEngine: a SQLAlchemy engine
        """

        loop = get_running_loop()

        cache_key = (
            loop,
            self.connection_url,
            self.echo,
            self.timeout,
        )
        if cache_key not in ENGINES:
            kwargs: dict[str, Any] = (
                get_current_settings().server.database.sqlalchemy.model_dump(
                    mode="json", exclude={"connect_args"}
                )
            )
            connect_args: dict[str, Any] = {}

            if self.timeout is not None:
                connect_args["command_timeout"] = self.timeout

            if self.connection_timeout is not None:
                connect_args["timeout"] = self.connection_timeout

            if self.statement_cache_size is not None:
                connect_args["statement_cache_size"] = self.statement_cache_size

            if self.prepared_statement_cache_size is not None:
                connect_args["prepared_statement_cache_size"] = (
                    self.prepared_statement_cache_size
                )

            if self.connection_app_name is not None:
                connect_args["server_settings"] = dict(
                    application_name=self.connection_app_name
                )

            if get_current_settings().server.database.sqlalchemy.connect_args.tls.enabled:
                tls_config = (
                    get_current_settings().server.database.sqlalchemy.connect_args.tls
                )

                pg_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)

                if tls_config.ca_file:
                    pg_ctx = ssl.create_default_context(
                        purpose=ssl.Purpose.SERVER_AUTH, cafile=tls_config.ca_file
                    )

                pg_ctx.minimum_version = ssl.TLSVersion.TLSv1_2

                if tls_config.cert_file and tls_config.key_file:
                    pg_ctx.load_cert_chain(
                        certfile=tls_config.cert_file, keyfile=tls_config.key_file
                    )

                pg_ctx.check_hostname = tls_config.check_hostname
                pg_ctx.verify_mode = ssl.CERT_REQUIRED
                connect_args["ssl"] = pg_ctx

            if connect_args:
                kwargs["connect_args"] = connect_args

            if self.sqlalchemy_pool_size is not None:
                kwargs["pool_size"] = self.sqlalchemy_pool_size

            if self.sqlalchemy_max_overflow is not None:
                kwargs["max_overflow"] = self.sqlalchemy_max_overflow

            engine = create_async_engine(
                self.connection_url,
                echo=self.echo,
                # "pre-ping" connections upon checkout to ensure they have not been
                # closed on the server side
                pool_pre_ping=True,
                # Use connections in LIFO order to help reduce connections
                # after spiky load and in general increase the likelihood
                # that a given connection pulled from the pool will be
                # usable.
                pool_use_lifo=True,
                **kwargs,
            )

            if logfire:
                logfire.instrument_sqlalchemy(engine)  # pyright: ignore

            if TRACKER.active:
                TRACKER.track_pool(engine.pool)

            ENGINES[cache_key] = engine
            await self.schedule_engine_disposal(cache_key)
        return ENGINES[cache_key]

    async def schedule_engine_disposal(self, cache_key: _EngineCacheKey) -> None:
        """
        Dispose of an engine once the event loop is closing.

        See caveats at `add_event_loop_shutdown_callback`.

        We attempted to lazily clean up old engines when new engines are created, but
        if the loop the engine is attached to is already closed then the connections
        cannot be cleaned up properly and warnings are displayed.

        Engine disposal should only be important when running the application
        ephemerally. Notably, this is an issue in our tests where many short-lived event
        loops and engines are created which can consume all of the available database
        connection slots. Users operating at a scale where connection limits are
        encountered should be encouraged to use a standalone server.
        """

        async def dispose_engine(cache_key: _EngineCacheKey) -> None:
            engine = ENGINES.pop(cache_key, None)
            if engine:
                await engine.dispose()

        await add_event_loop_shutdown_callback(partial(dispose_engine, cache_key))

    async def session(self, engine: AsyncEngine) -> AsyncSession:
        """
        Retrieves a SQLAlchemy session for an engine.

        Args:
            engine: a sqlalchemy engine
        """
        return AsyncSession(engine, expire_on_commit=False)

    @asynccontextmanager
    async def begin_transaction(
        self, session: AsyncSession, with_for_update: bool = False
    ) -> AsyncGenerator[AsyncSessionTransaction, None]:
        # `with_for_update` is for SQLite only. For Postgres, lock the row on read
        # for update instead.
        async with session.begin() as transaction:
            yield transaction

    async def create_db(
        self, connection: AsyncConnection, base_metadata: sa.MetaData
    ) -> None:
        """Create the database"""

        await connection.run_sync(base_metadata.create_all)

    async def drop_db(
        self, connection: AsyncConnection, base_metadata: sa.MetaData
    ) -> None:
        """Drop the database"""

        await connection.run_sync(base_metadata.drop_all)

    def is_inmemory(self) -> bool:
        """Returns true if database is run in memory"""

        return False
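
# Illustrative sketch (not part of the original module): one way the configuration
# above can be used end to end. The connection URL is a placeholder.
async def _example_postgres_usage() -> None:
    config = AsyncPostgresConfiguration(
        connection_url="postgresql+asyncpg://user:password@localhost:5432/prefect"
    )
    engine = await config.engine()  # created once per event loop, then cached in ENGINES
    session = await config.session(engine)
    async with session:
        async with config.begin_transaction(session):
            result = await session.execute(sa.text("SELECT 1"))
            print(result.scalar_one())
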

class AioSqliteConfiguration(BaseDatabaseConfiguration):
    MIN_SQLITE_VERSION = (3, 24, 0)

    async def engine(self) -> AsyncEngine:
        """Retrieves an async SQLAlchemy engine.

        The connection string, echo flag, and statement timeout are taken from
        `self.connection_url`, `self.echo`, and `self.timeout`; engines are cached
        per event loop.

        Returns:
            AsyncEngine: a SQLAlchemy engine
        """

        if sqlite3.sqlite_version_info < self.MIN_SQLITE_VERSION:  # coverage: never true
            required = ".".join(str(v) for v in self.MIN_SQLITE_VERSION)
            raise RuntimeError(
                f"Prefect requires sqlite >= {required} but we found version "
                f"{sqlite3.sqlite_version}"
            )

        kwargs: dict[str, Any] = dict()

        loop = get_running_loop()

        cache_key = (loop, self.connection_url, self.echo, self.timeout)
        if cache_key not in ENGINES:
            # apply database timeout
            if self.timeout is not None:  # coverage: always true
                kwargs["connect_args"] = dict(timeout=self.timeout)

            # use `named` paramstyle for sqlite instead of `qmark`; in very rare
            # circumstances, we've seen aiosqlite pass parameters in the wrong
            # order; by using named parameters we avoid this issue
            # see https://github.com/PrefectHQ/prefect/pull/6702
            kwargs["paramstyle"] = "named"

            # ensure a long-lasting pool is used with in-memory databases
            # because they disappear when the last connection closes
            if ":memory:" in self.connection_url:  # coverage: never true
                kwargs.update(
                    poolclass=sa.pool.AsyncAdaptedQueuePool,
                    pool_size=1,
                    max_overflow=0,
                    pool_recycle=-1,
                )

            engine = create_async_engine(self.connection_url, echo=self.echo, **kwargs)
            event.listen(engine.sync_engine, "connect", self.setup_sqlite)
            event.listen(engine.sync_engine, "begin", self.begin_sqlite_stmt)

            if logfire:  # coverage: never true
                logfire.instrument_sqlalchemy(engine)  # pyright: ignore

            if TRACKER.active:  # coverage: never true
                TRACKER.track_pool(engine.pool)

            ENGINES[cache_key] = engine
            await self.schedule_engine_disposal(cache_key)
        return ENGINES[cache_key]

    async def schedule_engine_disposal(self, cache_key: _EngineCacheKey) -> None:
        """
        Dispose of an engine once the event loop is closing.

        See caveats at `add_event_loop_shutdown_callback`.

        We attempted to lazily clean up old engines when new engines are created, but
        if the loop the engine is attached to is already closed then the connections
        cannot be cleaned up properly and warnings are displayed.

        Engine disposal should only be important when running the application
        ephemerally. Notably, this is an issue in our tests where many short-lived event
        loops and engines are created which can consume all of the available database
        connection slots. Users operating at a scale where connection limits are
        encountered should be encouraged to use a standalone server.
        """

        async def dispose_engine(cache_key: _EngineCacheKey) -> None:
            engine = ENGINES.pop(cache_key, None)
            if engine:
                await engine.dispose()

        await add_event_loop_shutdown_callback(partial(dispose_engine, cache_key))

    def setup_sqlite(self, conn: DBAPIConnection, record: ConnectionPoolEntry) -> None:
        """Issue PRAGMA statements to SQLITE on connect. PRAGMAs only last for the
        duration of the connection. See https://www.sqlite.org/pragma.html for more info.
        """
        # workaround sqlite transaction behavior
        # coverage: this check was always true in the recorded runs
        if isinstance(conn, aiosqlite.AsyncAdapt_aiosqlite_connection):
            self.begin_sqlite_conn(conn)

        cursor = conn.cursor()

        # write to a write-ahead-log instead and regularly commit the changes
        # this allows multiple concurrent readers even during a write transaction
        # even with the WAL we can get busy errors if we have transactions that:
        # - t1 reads from a database
        # - t2 inserts to the database
        # - t1 tries to insert to the database
        # this can be resolved by using the IMMEDIATE transaction mode in t1
        cursor.execute("PRAGMA journal_mode = WAL;")

        # enable foreign keys
        cursor.execute("PRAGMA foreign_keys = ON;")

        # disable legacy alter table behavior as it will cause problems during
        # migrations when tables are renamed as references would otherwise be retained
        # in some locations
        # https://www.sqlite.org/pragma.html#pragma_legacy_alter_table
        cursor.execute("PRAGMA legacy_alter_table=OFF")

        # when using the WAL, we do not need to sync changes on every write. sqlite
        # recommends using 'normal' mode which is much faster
        cursor.execute("PRAGMA synchronous = NORMAL;")

        # a higher cache size (default of 2000) for more aggressive performance
        cursor.execute("PRAGMA cache_size = 20000;")

        # wait for this amount of time while a table is locked
        # before returning and raising an error
        # setting the value very high allows for more 'concurrency'
        # without running into errors, but may result in slow api calls
        if PREFECT_TESTING_UNIT_TEST_MODE.value() is True:
            cursor.execute("PRAGMA busy_timeout = 5000;")  # 5s
        else:
            cursor.execute("PRAGMA busy_timeout = 60000;")  # 60s

        # `PRAGMA temp_store = memory;` moves temporary tables from disk into RAM
        # this supposedly speeds up reads, but it seems to actually
        # decrease overall performance, see https://github.com/PrefectHQ/prefect/pull/14812
        # cursor.execute("PRAGMA temp_store = memory;")

        cursor.close()
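
    # Illustrative note (not part of the original module): PRAGMAs are per-connection,
    # so the settings above are applied to every connection the pool creates. The same
    # effect can be seen on a plain sqlite3 connection, e.g.:
    #
    #     import sqlite3
    #     conn = sqlite3.connect("prefect.db")  # hypothetical database file
    #     conn.execute("PRAGMA journal_mode = WAL;")
    #     print(conn.execute("PRAGMA journal_mode;").fetchone())  # ('wal',)
    #     conn.close()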

    def begin_sqlite_conn(
        self, conn: aiosqlite.AsyncAdapt_aiosqlite_connection
    ) -> None:
        # disable pysqlite's emitting of the BEGIN statement entirely.
        # also stops it from emitting COMMIT before any DDL.
        # requires `begin_sqlite_stmt`
        # see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
        conn.isolation_level = None

    def begin_sqlite_stmt(self, conn: sa.Connection) -> None:
        # emit our own BEGIN
        # requires `begin_sqlite_conn`
        # see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
        mode = SQLITE_BEGIN_MODE.get()
        if mode is not None:
            conn.exec_driver_sql(f"BEGIN {mode}")

        # Note this is intentionally a no-op if there is no BEGIN MODE set
        # This allows us to use SQLite's default behavior for reads which do not need
        # to be wrapped in a long-running transaction

    @asynccontextmanager
    async def begin_transaction(
        self, session: AsyncSession, with_for_update: bool = False
    ) -> AsyncGenerator[AsyncSessionTransaction, None]:
        token = SQLITE_BEGIN_MODE.set("IMMEDIATE" if with_for_update else "DEFERRED")

        try:
            async with session.begin() as transaction:
                yield transaction
        finally:
            SQLITE_BEGIN_MODE.reset(token)
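
    # Illustrative note (not part of the original module): begin_transaction and
    # begin_sqlite_stmt cooperate through the SQLITE_BEGIN_MODE context variable.
    # A hypothetical caller needing write-lock semantics would do:
    #
    #     async with config.begin_transaction(session, with_for_update=True):
    #         ...  # the "begin" event handler emits "BEGIN IMMEDIATE" for this block
    #
    # With with_for_update=False the mode is "DEFERRED", matching SQLite's default.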

    async def session(self, engine: AsyncEngine) -> AsyncSession:
        """
        Retrieves a SQLAlchemy session for an engine.

        Args:
            engine: a sqlalchemy engine
        """
        return AsyncSession(engine, expire_on_commit=False)

    async def create_db(
        self, connection: AsyncConnection, base_metadata: sa.MetaData
    ) -> None:
        """Create the database"""

        await connection.run_sync(base_metadata.create_all)

    async def drop_db(
        self, connection: AsyncConnection, base_metadata: sa.MetaData
    ) -> None:
        """Drop the database"""

        await connection.run_sync(base_metadata.drop_all)

    def is_inmemory(self) -> bool:
        """Returns true if database is run in memory"""

        return ":memory:" in self.connection_url or "mode=memory" in self.connection_url
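

# Illustrative sketch (not part of the original module): how is_inmemory() behaves for
# a few representative SQLite URLs. The URLs are examples only.
def _example_is_inmemory() -> None:
    for url, expected in [
        ("sqlite+aiosqlite:///:memory:", True),  # contains ":memory:"
        ("sqlite+aiosqlite:///file.db?mode=memory", True),  # contains "mode=memory"
        ("sqlite+aiosqlite:////var/lib/prefect/prefect.db", False),
    ]:
        config = AioSqliteConfiguration(connection_url=url)
        assert config.is_inmemory() is expected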