Coverage for polar/organization/repository.py: 37%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1from collections.abc import Sequence 1a

2from uuid import UUID 1a

3 

4from sqlalchemy import Select, func, select 1a

5 

6from polar.auth.models import AuthSubject, is_organization, is_user 1a

7from polar.kit.repository import ( 1a

8 RepositoryBase, 

9 RepositorySoftDeletionIDMixin, 

10 RepositorySoftDeletionMixin, 

11 RepositorySortingMixin, 

12 SortingClause, 

13) 

14from polar.kit.repository.base import Options 1a

15from polar.models import Account, Customer, Organization, User, UserOrganization 1a

16from polar.models.organization_review import OrganizationReview 1a

17from polar.postgres import AsyncSession 1a

18 

19from .sorting import OrganizationSortProperty 1a

20 

21 

22class OrganizationRepository( 1a

23 RepositorySortingMixin[Organization, OrganizationSortProperty], 

24 RepositorySoftDeletionIDMixin[Organization, UUID], 

25 RepositorySoftDeletionMixin[Organization], 

26 RepositoryBase[Organization], 

27): 

28 model = Organization 1a

29 

30 async def get_by_id( 1a

31 self, 

32 id: UUID, 

33 *, 

34 options: Options = (), 

35 include_deleted: bool = False, 

36 include_blocked: bool = False, 

37 ) -> Organization | None: 

38 statement = ( 

39 self.get_base_statement(include_deleted=include_deleted) 

40 .where(self.model.id == id) 

41 .options(*options) 

42 ) 

43 

44 if not include_blocked: 

45 statement = statement.where(self.model.blocked_at.is_(None)) 

46 

47 return await self.get_one_or_none(statement) 

48 

49 async def get_by_slug(self, slug: str) -> Organization | None: 1a

50 statement = self.get_base_statement().where(Organization.slug == slug) 

51 return await self.get_one_or_none(statement) 

52 

53 async def slug_exists(self, slug: str) -> bool: 1a

54 """Check if slug exists, including soft-deleted organizations. 

55 

56 Soft-deleted organizations are included to prevent slug reuse, 

57 ensuring backoffice links continue to work. 

58 """ 

59 statement = self.get_base_statement(include_deleted=True).where( 

60 Organization.slug == slug 

61 ) 

62 result = await self.get_one_or_none(statement) 

63 return result is not None 

64 

65 async def get_by_customer(self, customer_id: UUID) -> Organization: 1a

66 statement = ( 

67 self.get_base_statement() 

68 .join(Customer, Customer.organization_id == Organization.id) 

69 .where(Customer.id == customer_id) 

70 ) 

71 return await self.get_one(statement) 

72 

73 async def get_all_by_user(self, user: UUID) -> Sequence[Organization]: 1a

74 statement = ( 

75 self.get_base_statement() 

76 .join(UserOrganization) 

77 .where( 

78 UserOrganization.user_id == user, 

79 UserOrganization.deleted_at.is_(None), 

80 Organization.blocked_at.is_(None), 

81 ) 

82 ) 

83 return await self.get_all(statement) 

84 

85 async def get_all_by_account( 1a

86 self, account: UUID, *, options: Options = () 

87 ) -> Sequence[Organization]: 

88 statement = ( 

89 self.get_base_statement() 

90 .where( 

91 Organization.account_id == account, 

92 Organization.blocked_at.is_(None), 

93 ) 

94 .options(*options) 

95 ) 

96 return await self.get_all(statement) 

97 

98 def get_sorting_clause(self, property: OrganizationSortProperty) -> SortingClause: 1a

99 match property: 

100 case OrganizationSortProperty.created_at: 

101 return self.model.created_at 

102 case OrganizationSortProperty.slug: 

103 return self.model.slug 

104 case OrganizationSortProperty.organization_name: 

105 return self.model.name 

106 case OrganizationSortProperty.next_review_threshold: 

107 return self.model.next_review_threshold 

108 case OrganizationSortProperty.days_in_status: 

109 # Calculate days since status was last updated 

110 return ( 

111 func.extract( 

112 "epoch", 

113 func.now() 

114 - func.coalesce( 

115 self.model.status_updated_at, self.model.modified_at 

116 ), 

117 ) 

118 / 86400 

119 ) 

120 

121 def get_readable_statement( 1a

122 self, auth_subject: AuthSubject[User | Organization] 

123 ) -> Select[tuple[Organization]]: 

124 statement = self.get_base_statement().where(Organization.blocked_at.is_(None)) 1b

125 

126 if is_user(auth_subject): 126 ↛ 136line 126 didn't jump to line 136 because the condition on line 126 was always true1b

127 user = auth_subject.subject 1b

128 statement = statement.where( 1b

129 Organization.id.in_( 

130 select(UserOrganization.organization_id).where( 

131 UserOrganization.user_id == user.id, 

132 UserOrganization.deleted_at.is_(None), 

133 ) 

134 ) 

135 ) 

136 elif is_organization(auth_subject): 

137 statement = statement.where( 

138 Organization.id == auth_subject.subject.id, 

139 ) 

140 

141 return statement 1b

142 

143 async def get_admin_user( 1a

144 self, session: AsyncSession, organization: Organization 

145 ) -> User | None: 

146 """Get the admin user of the organization from the associated account.""" 

147 if not organization.account_id: 

148 return None 

149 

150 statement = ( 

151 select(User) 

152 .join(Account, Account.admin_id == User.id) 

153 .where( 

154 Account.id == organization.account_id, 

155 User.deleted_at.is_(None), 

156 ) 

157 ) 

158 result = await session.execute(statement) 

159 return result.unique().scalar_one_or_none() 

160 

161 

class OrganizationReviewRepository(RepositoryBase[OrganizationReview]):
    """Data access for ``OrganizationReview`` rows."""

    model = OrganizationReview

    async def get_by_organization(
        self, organization_id: UUID
    ) -> OrganizationReview | None:
        """Look up the live (non-deleted) review of an organization.

        Args:
            organization_id: ID of the organization whose review is wanted.

        Returns:
            The review whose ``deleted_at`` is NULL, or ``None`` if absent.
        """
        review_filters = (
            OrganizationReview.organization_id == organization_id,
            OrganizationReview.deleted_at.is_(None),
        )
        query = self.get_base_statement().where(*review_filters)
        return await self.get_one_or_none(query)