Coverage for polar/eventstream/endpoints.py: 30%
61 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import asyncio 1a
2from collections.abc import AsyncGenerator 1a
3from typing import Any 1a
5import structlog 1a
6from fastapi import Depends, Request 1a
7from redis.exceptions import ConnectionError 1a
8from sse_starlette.sse import EventSourceResponse 1a
9from uvicorn import Server 1a
11from polar.auth.dependencies import WebUserRead 1a
12from polar.exceptions import ResourceNotFound 1a
13from polar.organization.schemas import OrganizationID 1a
14from polar.organization.service import organization as organization_service 1a
15from polar.postgres import AsyncSession, get_db_session 1a
16from polar.redis import Redis, get_redis 1a
17from polar.routing import APIRouter 1a
19from .service import Receivers 1a
# Router for the event-stream endpoints; the whole group is excluded from
# the generated OpenAPI schema.
router = APIRouter(
    prefix="/stream",
    tags=["stream"],
    include_in_schema=False,
)

# Module-level structured logger used by the stream generator.
log = structlog.get_logger()
def _uvicorn_should_exit() -> bool:
    """Report whether the running Uvicorn server has been asked to shut down.

    Hacky way to check if the Uvicorn server is shutting down: scan the
    running asyncio tasks for a coroutine whose frame has a local ``self``
    bound to a ``uvicorn.Server`` instance, and read its ``should_exit`` flag.

    We do this because the exit signal handler monkey-patch made by
    sse_starlette doesn't work when running Uvicorn from the CLI, which would
    otherwise prevent a graceful shutdown while an SSE connection is open.

    Returns:
        The server's ``should_exit`` flag when a ``Server`` is found among the
        running tasks; ``False`` when there is no running event loop or no
        task belongs to a ``Server``.
    """
    try:
        tasks = asyncio.all_tasks()
    except RuntimeError:
        # asyncio.all_tasks() raises RuntimeError when no loop is running.
        return False

    for task in tasks:
        # get_coro() may return None, and a Task can wrap any object that
        # satisfies collections.abc.Coroutine, which need not expose
        # ``cr_frame``. Probe the attribute instead of assuming it exists, so
        # one unusual task cannot crash the caller's streaming loop.
        frame = getattr(task.get_coro(), "cr_frame", None)
        if frame is None:
            # The coroutine has finished or has no inspectable Python frame.
            continue
        candidate = frame.f_locals.get("self")
        if isinstance(candidate, Server):
            return candidate.should_exit
    return False
async def subscribe(
    redis: Redis,
    channels: list[str],
    request: Request,
) -> AsyncGenerator[Any, Any]:
    """Relay Redis pub/sub payloads from *channels* to an SSE response.

    Subscribes to every channel in *channels*, then repeatedly polls for the
    next message and yields its ``"data"`` field.

    The loop ends when either of these happens:
    * ``_uvicorn_should_exit()`` reports a server shutdown.
    * The client disconnects. In this case the pubsub is closed before
      breaking out.

    If polling or yielding is cancelled, or a Redis ``ConnectionError``
    occurs, the pubsub is closed and the exception is re-raised.
    """
    async with redis.pubsub() as pubsub:
        await pubsub.subscribe(*channels)

        while not _uvicorn_should_exit():
            client_gone = await request.is_disconnected()
            if client_gone:
                await pubsub.close()
                break

            try:
                # Block for at most 10s, so disconnects and shutdowns are
                # re-checked regularly even when the channels are quiet.
                incoming = await pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=10.0,
                )
                if incoming is None:
                    continue
                payload = incoming["data"]
                log.info("redis.pubsub", message=payload)
                yield payload
            except (asyncio.CancelledError, ConnectionError):
                await pubsub.close()
                raise
@router.get("/user")
async def user_stream(
    request: Request,
    auth_subject: WebUserRead,
    redis: Redis = Depends(get_redis),
) -> EventSourceResponse:
    """Open an SSE stream for the authenticated user.

    The stream relays messages published on the channels that ``Receivers``
    derives from the current user's id.
    """
    user_channels = Receivers(user_id=auth_subject.subject.id).get_channels()
    event_source = subscribe(redis, user_channels, request)
    return EventSourceResponse(event_source)
@router.get("/organizations/{id}")
async def org_stream(
    id: OrganizationID,
    request: Request,
    auth_subject: WebUserRead,
    redis: Redis = Depends(get_redis),
    session: AsyncSession = Depends(get_db_session),
) -> EventSourceResponse:
    """Open an SSE stream scoped to the current user and one organization.

    The organization is loaded through ``organization_service.get`` using the
    caller's auth subject. A ``None`` result is reported as
    ``ResourceNotFound``. Otherwise, the stream relays messages from the
    channels ``Receivers`` derives for the user/organization pair.
    """
    organization = await organization_service.get(session, auth_subject, id)
    if organization is None:
        raise ResourceNotFound()

    scoped_receivers = Receivers(
        user_id=auth_subject.subject.id,
        organization_id=organization.id,
    )
    event_source = subscribe(redis, scoped_receivers.get_channels(), request)
    return EventSourceResponse(event_source)