Coverage for polar/organization/repository.py: 37%
64 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1from collections.abc import Sequence 1a
2from uuid import UUID 1a
4from sqlalchemy import Select, func, select 1a
6from polar.auth.models import AuthSubject, is_organization, is_user 1a
7from polar.kit.repository import ( 1a
8 RepositoryBase,
9 RepositorySoftDeletionIDMixin,
10 RepositorySoftDeletionMixin,
11 RepositorySortingMixin,
12 SortingClause,
13)
14from polar.kit.repository.base import Options 1a
15from polar.models import Account, Customer, Organization, User, UserOrganization 1a
16from polar.models.organization_review import OrganizationReview 1a
17from polar.postgres import AsyncSession 1a
19from .sorting import OrganizationSortProperty 1a
22class OrganizationRepository( 1a
23 RepositorySortingMixin[Organization, OrganizationSortProperty],
24 RepositorySoftDeletionIDMixin[Organization, UUID],
25 RepositorySoftDeletionMixin[Organization],
26 RepositoryBase[Organization],
27):
28 model = Organization 1a
30 async def get_by_id( 1a
31 self,
32 id: UUID,
33 *,
34 options: Options = (),
35 include_deleted: bool = False,
36 include_blocked: bool = False,
37 ) -> Organization | None:
38 statement = (
39 self.get_base_statement(include_deleted=include_deleted)
40 .where(self.model.id == id)
41 .options(*options)
42 )
44 if not include_blocked:
45 statement = statement.where(self.model.blocked_at.is_(None))
47 return await self.get_one_or_none(statement)
49 async def get_by_slug(self, slug: str) -> Organization | None: 1a
50 statement = self.get_base_statement().where(Organization.slug == slug)
51 return await self.get_one_or_none(statement)
53 async def slug_exists(self, slug: str) -> bool: 1a
54 """Check if slug exists, including soft-deleted organizations.
56 Soft-deleted organizations are included to prevent slug reuse,
57 ensuring backoffice links continue to work.
58 """
59 statement = self.get_base_statement(include_deleted=True).where(
60 Organization.slug == slug
61 )
62 result = await self.get_one_or_none(statement)
63 return result is not None
65 async def get_by_customer(self, customer_id: UUID) -> Organization: 1a
66 statement = (
67 self.get_base_statement()
68 .join(Customer, Customer.organization_id == Organization.id)
69 .where(Customer.id == customer_id)
70 )
71 return await self.get_one(statement)
73 async def get_all_by_user(self, user: UUID) -> Sequence[Organization]: 1a
74 statement = (
75 self.get_base_statement()
76 .join(UserOrganization)
77 .where(
78 UserOrganization.user_id == user,
79 UserOrganization.deleted_at.is_(None),
80 Organization.blocked_at.is_(None),
81 )
82 )
83 return await self.get_all(statement)
85 async def get_all_by_account( 1a
86 self, account: UUID, *, options: Options = ()
87 ) -> Sequence[Organization]:
88 statement = (
89 self.get_base_statement()
90 .where(
91 Organization.account_id == account,
92 Organization.blocked_at.is_(None),
93 )
94 .options(*options)
95 )
96 return await self.get_all(statement)
98 def get_sorting_clause(self, property: OrganizationSortProperty) -> SortingClause: 1a
99 match property:
100 case OrganizationSortProperty.created_at:
101 return self.model.created_at
102 case OrganizationSortProperty.slug:
103 return self.model.slug
104 case OrganizationSortProperty.organization_name:
105 return self.model.name
106 case OrganizationSortProperty.next_review_threshold:
107 return self.model.next_review_threshold
108 case OrganizationSortProperty.days_in_status:
109 # Calculate days since status was last updated
110 return (
111 func.extract(
112 "epoch",
113 func.now()
114 - func.coalesce(
115 self.model.status_updated_at, self.model.modified_at
116 ),
117 )
118 / 86400
119 )
121 def get_readable_statement( 1a
122 self, auth_subject: AuthSubject[User | Organization]
123 ) -> Select[tuple[Organization]]:
124 statement = self.get_base_statement().where(Organization.blocked_at.is_(None)) 1b
126 if is_user(auth_subject): 126 ↛ 136line 126 didn't jump to line 136 because the condition on line 126 was always true1b
127 user = auth_subject.subject 1b
128 statement = statement.where( 1b
129 Organization.id.in_(
130 select(UserOrganization.organization_id).where(
131 UserOrganization.user_id == user.id,
132 UserOrganization.deleted_at.is_(None),
133 )
134 )
135 )
136 elif is_organization(auth_subject):
137 statement = statement.where(
138 Organization.id == auth_subject.subject.id,
139 )
141 return statement 1b
143 async def get_admin_user( 1a
144 self, session: AsyncSession, organization: Organization
145 ) -> User | None:
146 """Get the admin user of the organization from the associated account."""
147 if not organization.account_id:
148 return None
150 statement = (
151 select(User)
152 .join(Account, Account.admin_id == User.id)
153 .where(
154 Account.id == organization.account_id,
155 User.deleted_at.is_(None),
156 )
157 )
158 result = await session.execute(statement)
159 return result.unique().scalar_one_or_none()
class OrganizationReviewRepository(RepositoryBase[OrganizationReview]):
    """Data access for ``OrganizationReview`` rows."""

    model = OrganizationReview

    async def get_by_organization(
        self, organization_id: UUID
    ) -> OrganizationReview | None:
        """Fetch the non-deleted review attached to an organization.

        Args:
            organization_id: ID of the organization whose review is wanted.

        Returns:
            The matching review whose ``deleted_at`` is NULL, or ``None``.
        """
        for_organization = OrganizationReview.organization_id == organization_id
        not_deleted = OrganizationReview.deleted_at.is_(None)
        query = self.get_base_statement().where(for_organization, not_deleted)
        return await self.get_one_or_none(query)