Coverage for polar/billing_entry/repository.py: 43%
38 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1from collections.abc import AsyncGenerator, Sequence 1a
2from datetime import datetime 1a
3from uuid import UUID 1a
5from sqlalchemy import Select, func, update 1a
6from sqlalchemy.orm.strategy_options import contains_eager 1a
8from polar.config import settings 1a
9from polar.kit.repository import ( 1a
10 Options,
11 RepositoryBase,
12 RepositorySoftDeletionIDMixin,
13 RepositorySoftDeletionMixin,
14)
15from polar.models import BillingEntry 1a
16from polar.models.product_price import ProductPrice, ProductPriceMeteredUnit 1a
class BillingEntryRepository(
    RepositorySoftDeletionIDMixin[BillingEntry, UUID],
    RepositorySoftDeletionMixin[BillingEntry],
    RepositoryBase[BillingEntry],
):
    """
    Repository for reading and updating BillingEntry rows.

    Every read method builds on get_pending_by_subscription_statement(), which
    selects the entries of one subscription whose order_item_id is NULL. Those
    entries are called "pending" throughout this class.
    """

    model = BillingEntry

    async def update_order_item_id(
        self, billing_entries: Sequence[UUID], order_item_id: UUID
    ) -> None:
        """
        Set order_item_id on the given billing entries with one bulk UPDATE.

        Only entries whose order_item_id is still NULL are updated; entries
        that already reference an order item are left untouched. Nothing is
        executed when billing_entries is empty.
        """
        if not billing_entries:
            # An empty ID list cannot match any row, so skip the round-trip.
            return

        statement = (
            update(self.model)
            .where(
                self.model.id.in_(billing_entries),
                self.model.order_item_id.is_(None),
            )
            .values(order_item_id=order_item_id)
        )
        await self.session.execute(statement)

    async def get_pending_by_subscription(
        self, subscription_id: UUID, *, options: Options = ()
    ) -> Sequence[BillingEntry]:
        """
        Get all pending billing entries of a subscription.

        The options are forwarded to get_pending_by_subscription_statement(),
        and the resulting statement is run through self.get_all().
        """
        statement = self.get_pending_by_subscription_statement(
            subscription_id, options=options
        )
        return await self.get_all(statement)

    async def get_static_pending_by_subscription(
        self, subscription_id: UUID
    ) -> AsyncGenerator[BillingEntry]:
        """
        Stream the pending billing entries of a subscription whose price is static.

        The product price is joined to filter on ProductPrice.is_static, and
        contains_eager() reuses that same join to populate
        BillingEntry.product_price on each yielded entry.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .join(BillingEntry.product_price)
            .where(ProductPrice.is_static.is_(True))
            .options(contains_eager(BillingEntry.product_price))
        )
        async for result in self.stream(statement):
            yield result

    async def get_pending_metered_by_subscription_tuples(
        self, subscription_id: UUID
    ) -> AsyncGenerator[tuple[UUID, UUID, datetime, datetime]]:
        """
        Get pending metered billing entries grouped by (product_price_id, meter_id).

        Returns tuples of (product_price_id, meter_id, start_timestamp, end_timestamp),
        where start_timestamp is the earliest start and end_timestamp the latest
        end among the grouped entries.

        For summable aggregations (count, sum): Each tuple represents entries to bill separately.

        For non-summable aggregations (max, min, avg, unique): Multiple tuples for the same
        meter_id will be returned (one per price), but only the first is processed by the
        service layer - the rest are skipped. The active price is determined from
        subscription.subscription_product_prices, not from these tuples.

        Tuples are ordered by meter_id, then by product_price_id, so the tuple
        that comes first for a given meter is the same on every call.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .join(
                ProductPriceMeteredUnit,
                BillingEntry.product_price_id == ProductPriceMeteredUnit.id,
            )
            .with_only_columns(
                BillingEntry.product_price_id,
                ProductPriceMeteredUnit.meter_id,
                func.min(BillingEntry.start_timestamp),
                func.max(BillingEntry.end_timestamp),
            )
            .group_by(BillingEntry.product_price_id, ProductPriceMeteredUnit.meter_id)
            .order_by(None)  # Clear existing ORDER BY from base statement
            .order_by(
                ProductPriceMeteredUnit.meter_id.asc(),
                # Tie-break within a meter; both columns are in the GROUP BY.
                BillingEntry.product_price_id.asc(),
            )
        )
        results = await self.session.stream(
            statement,
            execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
        )
        try:
            async for result in results:
                yield result._tuple()
        finally:
            # Close the streamed result even if the consumer stops iterating
            # early or an exception propagates through the generator.
            await results.close()

    async def get_pending_ids_by_subscription_and_price(
        self, subscription_id: UUID, product_price_id: UUID
    ) -> Sequence[UUID]:
        """
        Get all pending billing entry IDs for a subscription and a single price.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .with_only_columns(BillingEntry.id)
            .where(BillingEntry.product_price_id == product_price_id)
        )
        results = await self.session.execute(statement)
        return results.scalars().unique().all()

    async def get_pending_ids_by_subscription_and_meter(
        self, subscription_id: UUID, meter_id: UUID
    ) -> Sequence[UUID]:
        """
        Get all pending billing entry IDs for a subscription and meter across all prices.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .join(
                ProductPriceMeteredUnit,
                BillingEntry.product_price_id == ProductPriceMeteredUnit.id,
            )
            .with_only_columns(BillingEntry.id)
            .where(ProductPriceMeteredUnit.meter_id == meter_id)
        )
        results = await self.session.execute(statement)
        return results.scalars().unique().all()

    def get_pending_by_subscription_statement(
        self, subscription_id: UUID, *, options: Options = ()
    ) -> Select[tuple["BillingEntry"]]:
        """
        Build the base SELECT for the pending billing entries of a subscription.

        Starts from self.get_base_statement(), keeps the entries of the given
        subscription whose order_item_id is NULL, orders them by
        product_price_id ascending and applies the given options.
        """
        return (
            self.get_base_statement()
            .where(
                BillingEntry.order_item_id.is_(None),
                BillingEntry.subscription_id == subscription_id,
            )
            .order_by(BillingEntry.product_price_id.asc())
            .options(*options)
        )
131 return (
132 self.get_base_statement()
133 .where(
134 BillingEntry.order_item_id.is_(None),
135 BillingEntry.subscription_id == subscription_id,
136 )
137 .order_by(BillingEntry.product_price_id.asc())
138 .options(*options)
139 )