Coverage for polar/worker/_sqlalchemy.py: 46%
42 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import contextlib 1ab
2from collections.abc import AsyncIterator 1ab
4import dramatiq 1ab
5import structlog 1ab
6from dramatiq.asyncio import get_event_loop_thread 1ab
8from polar.kit.db.postgres import AsyncSessionMaker as AsyncSessionMakerType 1ab
9from polar.kit.db.postgres import create_async_sessionmaker 1ab
10from polar.logfire import instrument_sqlalchemy 1ab
11from polar.logging import Logger 1ab
12from polar.postgres import AsyncEngine, AsyncSession, create_async_engine 1ab
# Module-wide structured logger.
log: Logger = structlog.get_logger()

# Process-wide database state. Both start unset; SQLAlchemyMiddleware assigns
# them in before_worker_boot, and dispose_sqlalchemy_engine clears the engine.
_sqlalchemy_engine: AsyncEngine | None = None
_sqlalchemy_async_sessionmaker: AsyncSessionMakerType | None = None
async def dispose_sqlalchemy_engine() -> None:
    """
    Dispose of the module-level SQLAlchemy engine, if one is set.

    Does nothing when no engine is set. Otherwise the engine is disposed,
    the event is logged, and the module-level reference is reset to None.
    """
    global _sqlalchemy_engine

    engine = _sqlalchemy_engine
    if engine is None:
        return

    await engine.dispose()
    log.info("Disposed SQLAlchemy engine")
    _sqlalchemy_engine = None
class SQLAlchemyMiddleware(dramatiq.Middleware):
    """
    Dramatiq middleware owning the worker's database engine and sessionmaker.

    Both are created before the worker boots; the engine is disposed after
    the worker shuts down.
    """

    @classmethod
    def get_async_session(cls) -> contextlib.AbstractAsyncContextManager[AsyncSession]:
        """
        Return a new session context manager from the module-level sessionmaker.

        Raises:
            RuntimeError: If the sessionmaker has not been created yet.
        """
        sessionmaker = _sqlalchemy_async_sessionmaker
        if sessionmaker is not None:
            return sessionmaker()
        raise RuntimeError("SQLAlchemy not initialized")

    def before_worker_boot(
        self, broker: dramatiq.Broker, worker: dramatiq.Worker
    ) -> None:
        """
        Create the engine and sessionmaker, then instrument the engine.
        """
        global _sqlalchemy_engine, _sqlalchemy_async_sessionmaker

        engine = create_async_engine("worker")
        _sqlalchemy_engine = engine
        _sqlalchemy_async_sessionmaker = create_async_sessionmaker(engine)
        # Instrumentation is given the underlying sync engine.
        instrument_sqlalchemy([engine.sync_engine])
        log.info("Created database engine")

    def after_worker_shutdown(
        self, broker: dramatiq.Broker, worker: dramatiq.Worker
    ) -> None:
        """
        Run the engine disposal coroutine on dramatiq's event loop thread.
        """
        loop_thread = get_event_loop_thread()
        assert loop_thread is not None
        loop_thread.run_coroutine(dispose_sqlalchemy_engine())
@contextlib.asynccontextmanager
async def AsyncSessionMaker() -> AsyncIterator[AsyncSession]:
    """
    Yield a database session obtained from SQLAlchemyMiddleware.

    The session is committed when the managed block finishes without raising.
    If the block raises, the session is rolled back and the exception is
    re-raised.
    """
    async with SQLAlchemyMiddleware.get_async_session() as db_session:
        try:
            yield db_session
        except BaseException:
            # Catch everything, not just Exception subclasses, so any
            # interruption of the managed block triggers a rollback.
            await db_session.rollback()
            raise
        # Only reached when the managed block did not raise.
        await db_session.commit()