Coverage for polar/user_organization/service.py: 31%
69 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1from collections.abc import Sequence 1a
2from uuid import UUID 1a
4from sqlalchemy import Select, func 1a
5from sqlalchemy.orm import joinedload 1a
7from polar.exceptions import PolarError 1a
8from polar.kit.utils import utc_now 1a
9from polar.models import UserOrganization 1a
10from polar.postgres import AsyncReadSession, AsyncSession, sql 1a
class UserOrganizationError(PolarError):
    """Common base for the errors raised by the user-organization service."""
class OrganizationNotFound(UserOrganizationError):
    """Error for an organization id that could not be resolved.

    The offending id is kept on ``organization_id``. ``404`` is forwarded as
    the second positional argument of the parent constructor (presumably the
    HTTP status code; confirm against ``PolarError``).
    """

    def __init__(self, organization_id: UUID) -> None:
        self.organization_id = organization_id
        super().__init__(f"Organization with id {organization_id} not found.", 404)
class UserNotMemberOfOrganization(UserOrganizationError):
    """Error for a user that has no membership in the given organization.

    Both ids are kept on the instance. ``404`` is forwarded as the second
    positional argument of the parent constructor (presumably the HTTP status
    code; confirm against ``PolarError``).
    """

    def __init__(self, user_id: UUID, organization_id: UUID) -> None:
        self.user_id, self.organization_id = user_id, organization_id
        super().__init__(
            f"User with id {user_id} is not a member of organization {organization_id}.",
            404,
        )
class CannotRemoveOrganizationAdmin(UserOrganizationError):
    """Error for an attempt to remove the admin user of an organization.

    Both ids are kept on the instance. ``403`` is forwarded as the second
    positional argument of the parent constructor (presumably the HTTP status
    code; confirm against ``PolarError``).
    """

    def __init__(self, user_id: UUID, organization_id: UUID) -> None:
        self.user_id, self.organization_id = user_id, organization_id
        super().__init__(
            f"Cannot remove user {user_id} - they are the admin of organization {organization_id}.",
            403,
        )
class UserOrganizationService:
    """Queries and mutations on ``UserOrganization`` membership rows.

    Every statement built here filters on ``deleted_at IS NULL``; removing a
    member sets ``deleted_at`` instead of deleting the row.
    """

    async def list_by_org(
        self, session: AsyncReadSession, org_id: UUID
    ) -> Sequence[UserOrganization]:
        """Return the non-deleted memberships of organization ``org_id``.

        The ``user`` and ``organization`` relationships are joined-loaded.
        """
        query = sql.select(UserOrganization).where(
            UserOrganization.organization_id == org_id,
            UserOrganization.deleted_at.is_(None),
        )
        query = query.options(
            joinedload(UserOrganization.user),
            joinedload(UserOrganization.organization),
        )

        result = await session.execute(query)
        return result.scalars().unique().all()

    async def list_by_user_id(
        self, session: AsyncSession, user_id: UUID
    ) -> Sequence[UserOrganization]:
        """Return the non-deleted memberships of ``user_id``, oldest first.

        The ordering (``created_at`` ascending) comes from the default
        ``ordered=True`` of ``_get_list_by_user_id_query``.
        """
        result = await session.execute(self._get_list_by_user_id_query(user_id))
        return result.scalars().unique().all()

    async def get_user_organization_count(
        self, session: AsyncSession, user_id: UUID
    ) -> int:
        """Count the non-deleted memberships of ``user_id``.

        Returns ``0`` when the scalar result is falsy (``None`` or ``0``).
        """
        memberships = self._get_list_by_user_id_query(user_id, ordered=False)
        count_query = memberships.with_only_columns(
            func.count(UserOrganization.organization_id)
        )
        result = await session.execute(count_query)
        return result.scalar() or 0

    async def get_by_user_and_org(
        self,
        session: AsyncSession,
        user_id: UUID,
        organization_id: UUID,
    ) -> UserOrganization | None:
        """Return the non-deleted membership linking the user to the organization.

        Returns ``None`` when no such row exists. The ``user`` and
        ``organization`` relationships are joined-loaded.
        """
        query = sql.select(UserOrganization).where(
            UserOrganization.user_id == user_id,
            UserOrganization.organization_id == organization_id,
            UserOrganization.deleted_at.is_(None),
        )
        query = query.options(
            joinedload(UserOrganization.user),
            joinedload(UserOrganization.organization),
        )

        result = await session.execute(query)
        return result.scalars().unique().one_or_none()

    async def remove_member(
        self,
        session: AsyncSession,
        user_id: UUID,
        organization_id: UUID,
    ) -> None:
        """Soft-delete the membership of ``user_id`` in ``organization_id``.

        Only executes an UPDATE that sets ``deleted_at`` to ``utc_now()`` on
        the matching non-deleted rows. No existence or permission checks are
        performed here; see ``remove_member_safe`` for the checked variant.
        """
        soft_delete = sql.update(UserOrganization).where(
            UserOrganization.user_id == user_id,
            UserOrganization.organization_id == organization_id,
            UserOrganization.deleted_at.is_(None),
        )
        soft_delete = soft_delete.values(deleted_at=utc_now())
        await session.execute(soft_delete)

    async def remove_member_safe(
        self,
        session: AsyncSession,
        user_id: UUID,
        organization_id: UUID,
    ) -> None:
        """
        Remove a member from an organization after validating the request.

        The checks run in this order: the organization must exist, the user
        must have a non-deleted membership in it, and, when the organization
        has an ``account_id``, the user must not be its admin user.

        Raises:
            OrganizationNotFound: If no organization has the given id
            UserNotMemberOfOrganization: If the user has no active membership
            CannotRemoveOrganizationAdmin: If the user is the organization admin
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import; confirm before moving).
        from polar.organization.repository import OrganizationRepository

        repository = OrganizationRepository.from_session(session)
        organization = await repository.get_by_id(organization_id)
        if not organization:
            raise OrganizationNotFound(organization_id)

        # The user must currently be a member.
        membership = await self.get_by_user_and_org(session, user_id, organization_id)
        if not membership:
            raise UserNotMemberOfOrganization(user_id, organization_id)

        # The admin lookup is only attempted for organizations with an account.
        if organization.account_id:
            admin = await repository.get_admin_user(session, organization)
            if admin and admin.id == user_id:
                raise CannotRemoveOrganizationAdmin(user_id, organization_id)

        # All checks passed: soft-delete the membership.
        await self.remove_member(session, user_id, organization_id)

    def _get_list_by_user_id_query(
        self, user_id: UUID, ordered: bool = True
    ) -> Select[tuple[UserOrganization]]:
        """Build the SELECT for the non-deleted memberships of ``user_id``.

        The ``user`` and ``organization`` relationships are joined-loaded.
        When ``ordered`` is true the statement is sorted by ``created_at``
        ascending; otherwise no ORDER BY is added.
        """
        query = (
            sql.select(UserOrganization)
            .where(
                UserOrganization.user_id == user_id,
                UserOrganization.deleted_at.is_(None),
            )
            .options(
                joinedload(UserOrganization.user),
                joinedload(UserOrganization.organization),
            )
        )
        return query.order_by(UserOrganization.created_at.asc()) if ordered else query
# Module-level service instance, created once at import time.
user_organization = UserOrganizationService()