Coverage for polar/product/repository.py: 29%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1from uuid import UUID 1a

2 

3from sqlalchemy import Select, case, func, select 1a

4from sqlalchemy.orm import contains_eager, joinedload, selectinload 1a

5 

6from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a

7from polar.kit.repository import ( 1a

8 Options, 

9 RepositoryBase, 

10 RepositorySoftDeletionIDMixin, 

11 RepositorySoftDeletionMixin, 

12 RepositorySortingMixin, 

13 SortingClause, 

14) 

15from polar.models import ( 1a

16 CheckoutProduct, 

17 Product, 

18 ProductPrice, 

19 ProductPriceCustom, 

20 ProductPriceFixed, 

21 UserOrganization, 

22) 

23from polar.models.product_price import ProductPriceAmountType 1a

24from polar.postgres import sql 1a

25 

26from .sorting import ProductSortProperty 1a

27 

28 

29class ProductRepository( 1a

30 RepositorySortingMixin[Product, ProductSortProperty], 

31 RepositorySoftDeletionIDMixin[Product, UUID], 

32 RepositorySoftDeletionMixin[Product], 

33 RepositoryBase[Product], 

34): 

35 model = Product 1a

36 

37 async def get_by_id_and_organization( 1a

38 self, 

39 id: UUID, 

40 organization_id: UUID, 

41 *, 

42 options: Options = (), 

43 ) -> Product | None: 

44 statement = ( 

45 self.get_base_statement() 

46 .where(Product.id == id, Product.organization_id == organization_id) 

47 .options(*options) 

48 ) 

49 return await self.get_one_or_none(statement) 

50 

51 async def get_by_id_and_checkout( 1a

52 self, 

53 id: UUID, 

54 checkout_id: UUID, 

55 *, 

56 options: Options = (), 

57 ) -> Product | None: 

58 statement = ( 

59 self.get_base_statement() 

60 .join(CheckoutProduct, onclause=Product.id == CheckoutProduct.product_id) 

61 .where( 

62 Product.id == id, 

63 CheckoutProduct.checkout_id == checkout_id, 

64 ) 

65 .options(*options) 

66 ) 

67 return await self.get_one_or_none(statement) 

68 

69 def get_eager_options(self) -> Options: 1a

70 return ( 

71 joinedload(Product.organization), 

72 selectinload(Product.product_medias), 

73 selectinload(Product.attached_custom_fields), 

74 selectinload(Product.all_prices), 

75 ) 

76 

77 def get_readable_statement( 1a

78 self, auth_subject: AuthSubject[User | Organization] 

79 ) -> Select[tuple[Product]]: 

80 statement = self.get_base_statement() 

81 

82 if is_user(auth_subject): 

83 user = auth_subject.subject 

84 statement = statement.where( 

85 Product.organization_id.in_( 

86 select(UserOrganization.organization_id).where( 

87 UserOrganization.user_id == user.id, 

88 UserOrganization.deleted_at.is_(None), 

89 ) 

90 ) 

91 ) 

92 elif is_organization(auth_subject): 

93 statement = statement.where( 

94 Product.organization_id == auth_subject.subject.id 

95 ) 

96 

97 return statement 

98 

99 async def count_by_organization_id( 1a

100 self, 

101 organization_id: UUID, 

102 *, 

103 is_archived: bool | None = None, 

104 ) -> int: 

105 """Count products for an organization with optional archived filter.""" 

106 statement = sql.select(sql.func.count(Product.id)).where( 

107 Product.organization_id == organization_id, 

108 Product.deleted_at.is_(None), 

109 ) 

110 

111 if is_archived is not None: 

112 statement = statement.where(Product.is_archived.is_(is_archived)) 

113 

114 count = await self.session.scalar(statement) 

115 return count or 0 

116 

117 def get_sorting_clause(self, property: ProductSortProperty) -> SortingClause: 1a

118 match property: 

119 case ProductSortProperty.created_at: 

120 return Product.created_at 

121 case ProductSortProperty.product_name: 

122 return Product.name 

123 case ProductSortProperty.price_amount_type: 

124 return case( 

125 ( 

126 ProductPrice.amount_type == ProductPriceAmountType.free, 

127 1, 

128 ), 

129 ( 

130 ProductPrice.amount_type == ProductPriceAmountType.custom, 

131 2, 

132 ), 

133 ( 

134 ProductPrice.amount_type == ProductPriceAmountType.fixed, 

135 3, 

136 ), 

137 ) 

138 case ProductSortProperty.price_amount: 

139 return case( 

140 ( 

141 ProductPrice.amount_type == ProductPriceAmountType.free, 

142 -2, 

143 ), 

144 ( 

145 ProductPrice.amount_type == ProductPriceAmountType.custom, 

146 func.coalesce(ProductPriceCustom.minimum_amount, -1), 

147 ), 

148 ( 

149 ProductPrice.amount_type == ProductPriceAmountType.fixed, 

150 ProductPriceFixed.price_amount, 

151 ), 

152 ) 

153 

154 

class ProductPriceRepository(
    RepositorySoftDeletionIDMixin[ProductPrice, UUID],
    RepositorySoftDeletionMixin[ProductPrice],
    RepositoryBase[ProductPrice],
):
    """Data-access helpers for ``ProductPrice`` rows."""

    model = ProductPrice

    async def get_readable_by_id(
        self,
        id: UUID,
        auth_subject: AuthSubject[User | Organization],
        *,
        options: Options = (),
    ) -> ProductPrice | None:
        """Fetch one price by ID if it is readable by ``auth_subject``.

        Returns ``None`` when the price does not exist or is not visible.
        """
        readable_prices = self.get_readable_statement(auth_subject)
        query = readable_prices.where(ProductPrice.id == id).options(*options)
        return await self.get_one_or_none(query)

    async def get_by_stripe_price_id(
        self, stripe_price_id: str, *, options: Options = ()
    ) -> ProductPrice | None:
        """Fetch one price by its Stripe price ID.

        The column is looked up on ``ProductPrice.__table__`` rather than via a
        mapped attribute — presumably because ``stripe_price_id`` is not exposed
        as one on ``ProductPrice``; verify against the model.
        """
        stripe_price_id_column = ProductPrice.__table__.c["stripe_price_id"]
        query = self.get_base_statement().where(
            stripe_price_id_column == stripe_price_id
        )
        return await self.get_one_or_none(query.options(*options))

    def get_eager_options(self) -> Options:
        """Loader options: joined-load the owning product."""
        load_product = joinedload(ProductPrice.product)
        return (load_product,)

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[ProductPrice]]:
        """Return a price query limited to what ``auth_subject`` may read.

        Prices are joined to their ``Product`` and access is decided by the
        product's organization:

        - User subject: organizations the user belongs to (membership rows
          whose ``deleted_at`` is NULL).
        - Organization subject: only that organization.
        - Otherwise: no extra filter.

        NOTE(review): the joined ``Product`` is not filtered on ``deleted_at``
        here — confirm whether prices of soft-deleted products should remain
        readable.
        """
        joined = self.get_base_statement().join(
            Product, Product.id == ProductPrice.product_id
        )
        # Fill ProductPrice.product from the already-joined Product row.
        base = joined.options(contains_eager(ProductPrice.product))

        if is_user(auth_subject):
            member_organization_ids = select(UserOrganization.organization_id).where(
                UserOrganization.user_id == auth_subject.subject.id,
                UserOrganization.deleted_at.is_(None),
            )
            return base.where(Product.organization_id.in_(member_organization_ids))

        if is_organization(auth_subject):
            return base.where(
                Product.organization_id == auth_subject.subject.id,
            )

        return base