Coverage for polar/worker/_sqlalchemy.py: 46%

42 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import contextlib 1ab

2from collections.abc import AsyncIterator 1ab

3 

4import dramatiq 1ab

5import structlog 1ab

6from dramatiq.asyncio import get_event_loop_thread 1ab

7 

8from polar.kit.db.postgres import AsyncSessionMaker as AsyncSessionMakerType 1ab

9from polar.kit.db.postgres import create_async_sessionmaker 1ab

10from polar.logfire import instrument_sqlalchemy 1ab

11from polar.logging import Logger 1ab

12from polar.postgres import AsyncEngine, AsyncSession, create_async_engine 1ab

13 

# Module-level structured logger used by the engine lifecycle helpers below.
log: Logger = structlog.get_logger()

# Worker-wide database state. Both start unset (None) and are assigned by
# SQLAlchemyMiddleware.before_worker_boot.
_sqlalchemy_async_sessionmaker: AsyncSessionMakerType | None = None
_sqlalchemy_engine: AsyncEngine | None = None

18 

19 

async def dispose_sqlalchemy_engine() -> None:
    """
    Dispose of the module-level async engine, if one exists, and forget it.

    Calling this again after a successful disposal is a no-op, because the
    global is reset to ``None`` once ``dispose()`` has completed.
    """
    global _sqlalchemy_engine
    engine = _sqlalchemy_engine
    if engine is None:
        # Never created, or already disposed by an earlier call.
        return
    await engine.dispose()
    log.info("Disposed SQLAlchemy engine")
    # NOTE(review): _sqlalchemy_async_sessionmaker is left untouched here, so
    # sessions requested afterwards still target the disposed engine —
    # confirm this is intended.
    _sqlalchemy_engine = None

26 

27 

class SQLAlchemyMiddleware(dramatiq.Middleware):
    """
    Middleware managing the lifecycle of the database engine and sessionmaker.

    The engine and sessionmaker live in module-level globals. They are created
    in ``before_worker_boot``, and the engine is disposed in
    ``after_worker_shutdown``.
    """

    @classmethod
    def get_async_session(cls) -> contextlib.AbstractAsyncContextManager[AsyncSession]:
        """
        Return a new session context manager from the worker's sessionmaker.

        Raises:
            RuntimeError: if the sessionmaker has not been initialized yet
                (it is assigned in ``before_worker_boot``).
        """
        # Read-only access to the module global, so no ``global`` statement.
        sessionmaker = _sqlalchemy_async_sessionmaker
        if sessionmaker is None:
            raise RuntimeError("SQLAlchemy not initialized")
        return sessionmaker()

    def before_worker_boot(
        self, broker: dramatiq.Broker, worker: dramatiq.Worker
    ) -> None:
        """
        Create the worker's async engine and sessionmaker, then instrument it.

        Args:
            broker: The dramatiq broker (unused).
            worker: The dramatiq worker being booted (unused).
        """
        global _sqlalchemy_engine, _sqlalchemy_async_sessionmaker
        _sqlalchemy_engine = create_async_engine("worker")
        _sqlalchemy_async_sessionmaker = create_async_sessionmaker(_sqlalchemy_engine)
        # Instrumentation is applied to the underlying synchronous engine.
        instrument_sqlalchemy([_sqlalchemy_engine.sync_engine])
        log.info("Created database engine")

    def after_worker_shutdown(
        self, broker: dramatiq.Broker, worker: dramatiq.Worker
    ) -> None:
        """
        Dispose of the engine by running the async disposal on dramatiq's
        event-loop thread.

        Args:
            broker: The dramatiq broker (unused).
            worker: The dramatiq worker being shut down (unused).

        Raises:
            RuntimeError: if no event-loop thread is available. Presumably
                this means dramatiq's AsyncIO middleware is not installed or
                has already stopped — confirm middleware ordering.
        """
        event_loop_thread = get_event_loop_thread()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an AttributeError on None.
        if event_loop_thread is None:
            raise RuntimeError(
                "Cannot dispose SQLAlchemy engine: no dramatiq event loop thread"
            )
        event_loop_thread.run_coroutine(dispose_sqlalchemy_engine())

55 

56 

@contextlib.asynccontextmanager
async def AsyncSessionMaker() -> AsyncIterator[AsyncSession]:
    """
    Context manager to handle a database session taken from the middleware context.

    Yields a session obtained from ``SQLAlchemyMiddleware.get_async_session()``.
    When the ``async with`` body completes normally, the session is committed.
    When the body raises, the session is rolled back and the exception is
    re-raised.

    Raises:
        RuntimeError: from ``SQLAlchemyMiddleware.get_async_session`` when the
            sessionmaker has not been initialized.
    """
    async with SQLAlchemyMiddleware.get_async_session() as session:
        try:
            yield session
        except BaseException:
            # Deliberately BaseException (exactly what a bare ``except:``
            # catches): cancellation (asyncio.CancelledError) and
            # KeyboardInterrupt must also roll back before propagating.
            await session.rollback()
            raise
        else:
            await session.commit()