Coverage for polar/eventstream/endpoints.py: 30%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import asyncio 1a

2from collections.abc import AsyncGenerator 1a

3from typing import Any 1a

4 

5import structlog 1a

6from fastapi import Depends, Request 1a

7from redis.exceptions import ConnectionError 1a

8from sse_starlette.sse import EventSourceResponse 1a

9from uvicorn import Server 1a

10 

11from polar.auth.dependencies import WebUserRead 1a

12from polar.exceptions import ResourceNotFound 1a

13from polar.organization.schemas import OrganizationID 1a

14from polar.organization.service import organization as organization_service 1a

15from polar.postgres import AsyncSession, get_db_session 1a

16from polar.redis import Redis, get_redis 1a

17from polar.routing import APIRouter 1a

18 

19from .service import Receivers 1a

20 

# Every event-stream route is mounted under /stream and kept out of the
# public OpenAPI schema.
router = APIRouter(include_in_schema=False, prefix="/stream", tags=["stream"])

log = structlog.get_logger()

24 

25 

26def _uvicorn_should_exit() -> bool: 1a

27 """ 

28 Hacky way to check if Uvicorn server is shutting down, by retrieving 

29 it from the running asyncio tasks. 

30 

31 We do this because the exit signal handler monkey-patch made by sse_starlette 

32 doesn't work when running Uvicorn from the CLI, 

33 preventing a graceful shutdown when a SSE connection is open. 

34 """ 

35 try: 

36 for task in asyncio.all_tasks(): 

37 coroutine = task.get_coro() 

38 if coroutine is not None: 

39 frame = coroutine.cr_frame # type: ignore 

40 if frame is not None: 

41 args = frame.f_locals 

42 if self := args.get("self"): 

43 if isinstance(self, Server): 

44 return self.should_exit 

45 except RuntimeError: 

46 pass 

47 return False 

48 

49 

async def subscribe(
    redis: Redis,
    channels: list[str],
    request: Request,
) -> AsyncGenerator[Any, Any]:
    """
    Relay Redis pub/sub messages published on ``channels`` to an SSE client.

    Yields the ``data`` field of every non-subscription message, logging each
    payload at info level. It stops when the client disconnects or when
    ``_uvicorn_should_exit()`` reports that Uvicorn is shutting down.

    Cleanup is left entirely to the ``async with`` block. The pub/sub object
    is closed by its async context manager on every exit path: loop exit,
    client disconnect, cancellation, a Redis ``ConnectionError``, or the
    consumer closing this generator. Exceptions propagate unchanged.
    """
    async with redis.pubsub() as pubsub:
        await pubsub.subscribe(*channels)

        while not _uvicorn_should_exit():
            if await request.is_disconnected():
                break

            # Waits for up to 10s for a new message, so while idle the
            # disconnect / shutdown checks above are re-evaluated regularly.
            message = await pubsub.get_message(
                ignore_subscribe_messages=True,
                timeout=10.0,
            )
            if message is None:
                continue

            log.info("redis.pubsub", message=message["data"])
            yield message["data"]

79 

80 

@router.get("/user")
async def user_stream(
    request: Request,
    auth_subject: WebUserRead,
    redis: Redis = Depends(get_redis),
) -> EventSourceResponse:
    """
    Open a server-sent event stream for the authenticated user.

    The stream relays messages published on the channels returned by
    ``Receivers(user_id=...).get_channels()`` for the current user.
    """
    user_channels = Receivers(user_id=auth_subject.subject.id).get_channels()
    events = subscribe(redis, user_channels, request)
    return EventSourceResponse(events)

89 

90 

@router.get("/organizations/{id}")
async def org_stream(
    # ``id`` shadows the builtin but must match the ``{id}`` path segment.
    id: OrganizationID,
    request: Request,
    auth_subject: WebUserRead,
    redis: Redis = Depends(get_redis),
    session: AsyncSession = Depends(get_db_session),
) -> EventSourceResponse:
    """
    Open a server-sent event stream scoped to a single organization.

    The stream relays messages published on the channels returned by
    ``Receivers(user_id=..., organization_id=...).get_channels()``.

    Raises:
        ResourceNotFound: when ``organization_service.get`` returns ``None``
            for this auth subject and ``id`` (presumably a missing or
            inaccessible organization — confirm against the service).
    """
    org = await organization_service.get(session, auth_subject, id)
    if org is None:
        raise ResourceNotFound()

    org_channels = Receivers(
        user_id=auth_subject.subject.id,
        organization_id=org.id,
    ).get_channels()
    return EventSourceResponse(subscribe(redis, org_channels, request))