Coverage for polar/notifications/service.py: 49%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from collections.abc import Sequence 1a

2from uuid import UUID 1a

3 

4from pydantic import BaseModel, TypeAdapter 1a

5from sqlalchemy import desc 1a

6from sqlalchemy.orm import joinedload 1a

7 

8from polar.kit.extensions.sqlalchemy import sql 1a

9from polar.models.notification import Notification 1a

10from polar.models.user_notification import UserNotification 1a

11from polar.postgres import AsyncSession 1a

12from polar.user_organization.service import ( 1a

13 user_organization as user_organization_service, 

14) 

15from polar.worker import enqueue_job 1a

16 

17from .notification import Notification as NotificationSchema 1a

18from .notification import NotificationPayload, NotificationType 1a

19 

20 

class PartialNotification(BaseModel):
    """A notification's ``type`` and ``payload`` without a recipient.

    ``NotificationsService.send_to_user`` combines this with a ``user_id``
    to build the persisted ``Notification`` row.
    """

    # Which kind of notification this is.
    type: NotificationType
    # Type-specific data; serialized with ``model_dump(mode="json")`` on send.
    payload: NotificationPayload

24 

25 

class NotificationsService:
    """Create, fetch and track read state of user notifications."""

    # Cached validator for ``NotificationSchema``. Constructing a TypeAdapter
    # builds a new validator/serializer, so it is created once (lazily, on
    # first use) and reused instead of being rebuilt on every
    # ``parse_payload`` call.
    _notification_adapter: "TypeAdapter[NotificationSchema] | None" = None

    async def get(self, session: AsyncSession, id: UUID) -> Notification | None:
        """Return the notification with primary key ``id``, or ``None``.

        The related ``user`` is loaded in the same query via ``joinedload``.
        """
        stmt = (
            sql.select(Notification)
            .where(Notification.id == id)
            .options(joinedload(Notification.user))
        )

        res = await session.execute(stmt)
        return res.scalars().unique().one_or_none()

    async def get_for_user(
        self, session: AsyncSession, user_id: UUID
    ) -> Sequence[Notification]:
        """Return up to 100 notifications for ``user_id``, newest first.

        Ordering is by ``Notification.created_at`` descending.
        """
        stmt = (
            sql.select(Notification)
            .where(Notification.user_id == user_id)
            .order_by(desc(Notification.created_at))
            .limit(100)
        )

        res = await session.execute(stmt)
        return res.scalars().unique().all()

    async def send_to_user(
        self,
        session: AsyncSession,
        user_id: UUID,
        notif: PartialNotification,
    ) -> bool:
        """Persist a notification for ``user_id`` and enqueue its delivery jobs.

        The session is flushed (not committed) before the
        ``notifications.send`` and ``notifications.push`` jobs are enqueued,
        so the row has been written and ``notification.id`` is assigned.

        Returns:
            Always ``True``.
        """
        notification = Notification(
            user_id=user_id,
            type=notif.type,
            payload=notif.payload.model_dump(mode="json"),
        )

        session.add(notification)
        await session.flush()
        enqueue_job("notifications.send", notification_id=notification.id)
        enqueue_job("notifications.push", notification_id=notification.id)
        return True

    async def send_to_org_members(
        self,
        session: AsyncSession,
        org_id: UUID,
        notif: PartialNotification,
    ) -> None:
        """Send ``notif`` to each member of organization ``org_id``.

        Members come from ``user_organization_service.list_by_org``. Each
        member is sent a separate notification via ``send_to_user``, one
        after another.
        """
        members = await user_organization_service.list_by_org(session, org_id)
        for member in members:
            await self.send_to_user(
                session=session,
                user_id=member.user_id,
                notif=notif,
            )

    def parse_payload(self, n: Notification) -> NotificationPayload:
        """Validate the ORM notification ``n`` against the schema and return its payload.

        Raises:
            pydantic.ValidationError: if ``n`` does not match ``NotificationSchema``.
        """
        adapter = type(self)._notification_adapter
        if adapter is None:
            # Built on first use only; reused by all later calls.
            adapter = TypeAdapter(NotificationSchema)
            type(self)._notification_adapter = adapter
        notification = adapter.validate_python(n)
        return notification.payload

    async def get_user_last_read(
        self, session: AsyncSession, user_id: UUID
    ) -> UUID | None:
        """Return the id of the last notification ``user_id`` has read.

        Returns ``None`` when the user has no ``UserNotification`` row.
        """
        stmt = sql.select(UserNotification).where(UserNotification.user_id == user_id)
        res = await session.execute(stmt)
        user_notif = res.scalar_one_or_none()
        return user_notif.last_read_notification_id if user_notif else None

    async def set_user_last_read(
        self, session: AsyncSession, user_id: UUID, notification_id: UUID
    ) -> None:
        """Record ``notification_id`` as the last notification read by ``user_id``.

        Performs an upsert: if a ``UserNotification`` row already exists for
        ``user_id``, its ``last_read_notification_id`` is overwritten. The
        statement is executed but not committed here.
        """
        stmt = (
            sql.insert(UserNotification)
            .values(user_id=user_id, last_read_notification_id=notification_id)
            .on_conflict_do_update(
                index_elements=[UserNotification.user_id],
                set_={"last_read_notification_id": notification_id},
            )
        )
        await session.execute(stmt)

109 

110 

# Module-level shared instance of the notifications service.
notifications: NotificationsService = NotificationsService()