Coverage for polar/product/repository.py: 29%
61 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1from uuid import UUID 1a
3from sqlalchemy import Select, case, func, select 1a
4from sqlalchemy.orm import contains_eager, joinedload, selectinload 1a
6from polar.auth.models import AuthSubject, Organization, User, is_organization, is_user 1a
7from polar.kit.repository import ( 1a
8 Options,
9 RepositoryBase,
10 RepositorySoftDeletionIDMixin,
11 RepositorySoftDeletionMixin,
12 RepositorySortingMixin,
13 SortingClause,
14)
15from polar.models import ( 1a
16 CheckoutProduct,
17 Product,
18 ProductPrice,
19 ProductPriceCustom,
20 ProductPriceFixed,
21 UserOrganization,
22)
23from polar.models.product_price import ProductPriceAmountType 1a
24from polar.postgres import sql 1a
26from .sorting import ProductSortProperty 1a
29class ProductRepository( 1a
30 RepositorySortingMixin[Product, ProductSortProperty],
31 RepositorySoftDeletionIDMixin[Product, UUID],
32 RepositorySoftDeletionMixin[Product],
33 RepositoryBase[Product],
34):
35 model = Product 1a
37 async def get_by_id_and_organization( 1a
38 self,
39 id: UUID,
40 organization_id: UUID,
41 *,
42 options: Options = (),
43 ) -> Product | None:
44 statement = (
45 self.get_base_statement()
46 .where(Product.id == id, Product.organization_id == organization_id)
47 .options(*options)
48 )
49 return await self.get_one_or_none(statement)
51 async def get_by_id_and_checkout( 1a
52 self,
53 id: UUID,
54 checkout_id: UUID,
55 *,
56 options: Options = (),
57 ) -> Product | None:
58 statement = (
59 self.get_base_statement()
60 .join(CheckoutProduct, onclause=Product.id == CheckoutProduct.product_id)
61 .where(
62 Product.id == id,
63 CheckoutProduct.checkout_id == checkout_id,
64 )
65 .options(*options)
66 )
67 return await self.get_one_or_none(statement)
69 def get_eager_options(self) -> Options: 1a
70 return (
71 joinedload(Product.organization),
72 selectinload(Product.product_medias),
73 selectinload(Product.attached_custom_fields),
74 selectinload(Product.all_prices),
75 )
77 def get_readable_statement( 1a
78 self, auth_subject: AuthSubject[User | Organization]
79 ) -> Select[tuple[Product]]:
80 statement = self.get_base_statement()
82 if is_user(auth_subject):
83 user = auth_subject.subject
84 statement = statement.where(
85 Product.organization_id.in_(
86 select(UserOrganization.organization_id).where(
87 UserOrganization.user_id == user.id,
88 UserOrganization.deleted_at.is_(None),
89 )
90 )
91 )
92 elif is_organization(auth_subject):
93 statement = statement.where(
94 Product.organization_id == auth_subject.subject.id
95 )
97 return statement
99 async def count_by_organization_id( 1a
100 self,
101 organization_id: UUID,
102 *,
103 is_archived: bool | None = None,
104 ) -> int:
105 """Count products for an organization with optional archived filter."""
106 statement = sql.select(sql.func.count(Product.id)).where(
107 Product.organization_id == organization_id,
108 Product.deleted_at.is_(None),
109 )
111 if is_archived is not None:
112 statement = statement.where(Product.is_archived.is_(is_archived))
114 count = await self.session.scalar(statement)
115 return count or 0
117 def get_sorting_clause(self, property: ProductSortProperty) -> SortingClause: 1a
118 match property:
119 case ProductSortProperty.created_at:
120 return Product.created_at
121 case ProductSortProperty.product_name:
122 return Product.name
123 case ProductSortProperty.price_amount_type:
124 return case(
125 (
126 ProductPrice.amount_type == ProductPriceAmountType.free,
127 1,
128 ),
129 (
130 ProductPrice.amount_type == ProductPriceAmountType.custom,
131 2,
132 ),
133 (
134 ProductPrice.amount_type == ProductPriceAmountType.fixed,
135 3,
136 ),
137 )
138 case ProductSortProperty.price_amount:
139 return case(
140 (
141 ProductPrice.amount_type == ProductPriceAmountType.free,
142 -2,
143 ),
144 (
145 ProductPrice.amount_type == ProductPriceAmountType.custom,
146 func.coalesce(ProductPriceCustom.minimum_amount, -1),
147 ),
148 (
149 ProductPrice.amount_type == ProductPriceAmountType.fixed,
150 ProductPriceFixed.price_amount,
151 ),
152 )
class ProductPriceRepository(
    RepositorySoftDeletionIDMixin[ProductPrice, UUID],
    RepositorySoftDeletionMixin[ProductPrice],
    RepositoryBase[ProductPrice],
):
    """Repository for ``ProductPrice`` rows.

    Combines the soft-deletion mixins from ``polar.kit.repository``. Read
    access to a price is decided by the organization of its parent
    ``Product``.
    """

    model = ProductPrice

    async def get_readable_by_id(
        self,
        id: UUID,
        auth_subject: AuthSubject[User | Organization],
        *,
        options: Options = (),
    ) -> ProductPrice | None:
        """Fetch price ``id`` if ``auth_subject`` may read it, else ``None``."""
        readable = self.get_readable_statement(auth_subject)
        return await self.get_one_or_none(
            readable.where(ProductPrice.id == id).options(*options)
        )

    async def get_by_stripe_price_id(
        self, stripe_price_id: str, *, options: Options = ()
    ) -> ProductPrice | None:
        """Fetch the price whose ``stripe_price_id`` table column matches."""
        # The filter uses the raw table column instead of a mapped attribute.
        # NOTE(review): presumably ``ProductPrice.stripe_price_id`` is not a
        # plain column attribute on the model — confirm before changing this.
        stripe_column = ProductPrice.__table__.c["stripe_price_id"]
        query = self.get_base_statement().where(stripe_column == stripe_price_id)
        return await self.get_one_or_none(query.options(*options))

    def get_eager_options(self) -> Options:
        """Loader options for prices: join-load the parent product."""
        product_loader = joinedload(ProductPrice.product)
        return (product_loader,)

    def get_readable_statement(
        self, auth_subject: AuthSubject[User | Organization]
    ) -> Select[tuple[ProductPrice]]:
        """Build a SELECT of the prices ``auth_subject`` is allowed to read.

        The parent ``Product`` is inner-joined, and ``contains_eager``
        populates ``ProductPrice.product`` from that join.

        - User subject: prices of products in organizations where the user
          has a ``UserOrganization`` row whose ``deleted_at`` is NULL.
        - Organization subject: prices of that organization's products.
        - Any other subject: no extra filter is applied.

        NOTE(review): ``Product.deleted_at`` is not filtered here, so prices
        of soft-deleted products may still be readable. Confirm this is
        intended.
        """
        base = (
            self.get_base_statement()
            .join(Product, Product.id == ProductPrice.product_id)
            .options(contains_eager(ProductPrice.product))
        )

        if is_user(auth_subject):
            member_org_ids = select(UserOrganization.organization_id).where(
                UserOrganization.user_id == auth_subject.subject.id,
                UserOrganization.deleted_at.is_(None),
            )
            return base.where(Product.organization_id.in_(member_org_ids))

        if is_organization(auth_subject):
            return base.where(
                Product.organization_id == auth_subject.subject.id,
            )

        return base