Coverage for polar/billing_entry/repository.py: 43%

38 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1from collections.abc import AsyncGenerator, Sequence 1a

2from datetime import datetime 1a

3from uuid import UUID 1a

4 

5from sqlalchemy import Select, func, update 1a

6from sqlalchemy.orm.strategy_options import contains_eager 1a

7 

8from polar.config import settings 1a

9from polar.kit.repository import ( 1a

10 Options, 

11 RepositoryBase, 

12 RepositorySoftDeletionIDMixin, 

13 RepositorySoftDeletionMixin, 

14) 

15from polar.models import BillingEntry 1a

16from polar.models.product_price import ProductPrice, ProductPriceMeteredUnit 1a

17 

18 

class BillingEntryRepository(
    RepositorySoftDeletionIDMixin[BillingEntry, UUID],
    RepositorySoftDeletionMixin[BillingEntry],
    RepositoryBase[BillingEntry],
):
    """Data-access layer for :class:`BillingEntry` rows.

    Most queries here operate on *pending* entries: entries belonging to a
    given subscription that are not yet linked to an order item
    (``order_item_id IS NULL``).
    """

    model = BillingEntry

    async def update_order_item_id(
        self, billing_entries: Sequence[UUID], order_item_id: UUID
    ) -> None:
        """Link billing entries to an order item.

        Only entries whose ID is in ``billing_entries`` *and* that are not yet
        linked to any order item are updated; entries that already have an
        ``order_item_id`` are left untouched.

        :param billing_entries: IDs of the billing entries to link.
        :param order_item_id: ID of the order item to attach them to.
        """
        statement = (
            update(self.model)
            .where(
                self.model.id.in_(billing_entries),
                # Never overwrite an existing order item link.
                self.model.order_item_id.is_(None),
            )
            .values(order_item_id=order_item_id)
        )
        await self.session.execute(statement)

    async def get_pending_by_subscription(
        self, subscription_id: UUID, *, options: Options = ()
    ) -> Sequence[BillingEntry]:
        """Return all pending billing entries of a subscription.

        :param subscription_id: ID of the subscription.
        :param options: Extra loader options applied to the query.
        :return: Pending entries, ordered by ``product_price_id`` ascending.
        """
        statement = self.get_pending_by_subscription_statement(
            subscription_id, options=options
        )
        return await self.get_all(statement)

    async def get_static_pending_by_subscription(
        self, subscription_id: UUID
    ) -> AsyncGenerator[BillingEntry]:
        """Stream pending billing entries of a subscription with a static price.

        Only entries whose ``ProductPrice.is_static`` is true are yielded. The
        ``product_price`` relationship is populated from the join itself via
        ``contains_eager``, so it is available on each yielded entry.

        :param subscription_id: ID of the subscription.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .join(BillingEntry.product_price)
            .where(ProductPrice.is_static.is_(True))
            .options(contains_eager(BillingEntry.product_price))
        )
        async for result in self.stream(statement):
            yield result

    async def get_pending_metered_by_subscription_tuples(
        self, subscription_id: UUID
    ) -> AsyncGenerator[tuple[UUID, UUID, datetime, datetime]]:
        """
        Get pending metered billing entries grouped by (product_price_id, meter_id).

        Returns tuples of (product_price_id, meter_id, start_timestamp, end_timestamp),
        where start_timestamp is the earliest ``start_timestamp`` and
        end_timestamp the latest ``end_timestamp`` among the group's entries.

        For summable aggregations (count, sum): Each tuple represents entries to bill separately.
        For non-summable aggregations (max, min, avg, unique): Multiple tuples for the same
        meter_id will be returned (one per price), but only the first is processed by the
        service layer - the rest are skipped. The active price is determined from
        subscription.subscription_product_prices, not from these tuples.

        Tuples are ordered by meter_id, then product_price_id, so which tuple
        comes first for a given meter is deterministic across executions.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .join(
                ProductPriceMeteredUnit,
                BillingEntry.product_price_id == ProductPriceMeteredUnit.id,
            )
            .with_only_columns(
                BillingEntry.product_price_id,
                ProductPriceMeteredUnit.meter_id,
                func.min(BillingEntry.start_timestamp),
                func.max(BillingEntry.end_timestamp),
            )
            .group_by(BillingEntry.product_price_id, ProductPriceMeteredUnit.meter_id)
            .order_by(None)  # Clear existing ORDER BY from base statement
            # Tie-break on product_price_id: without it, the order of several
            # prices sharing one meter is unspecified, and callers that only
            # look at the first tuple per meter would be non-deterministic.
            .order_by(
                ProductPriceMeteredUnit.meter_id.asc(),
                BillingEntry.product_price_id.asc(),
            )
        )
        results = await self.session.stream(
            statement,
            execution_options={"yield_per": settings.DATABASE_STREAM_YIELD_PER},
        )
        try:
            async for result in results:
                yield result._tuple()
        finally:
            # Close the streamed result even if the consumer stops early or an
            # error is raised mid-iteration.
            await results.close()

    async def get_pending_ids_by_subscription_and_price(
        self, subscription_id: UUID, product_price_id: UUID
    ) -> Sequence[UUID]:
        """Return IDs of pending billing entries for a subscription and price.

        :param subscription_id: ID of the subscription.
        :param product_price_id: ID of the product price to filter on.
        :return: Distinct billing entry IDs.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .with_only_columns(BillingEntry.id)
            .where(BillingEntry.product_price_id == product_price_id)
        )
        results = await self.session.execute(statement)
        return results.scalars().unique().all()

    async def get_pending_ids_by_subscription_and_meter(
        self, subscription_id: UUID, meter_id: UUID
    ) -> Sequence[UUID]:
        """
        Get all pending billing entry IDs for a subscription and meter across all prices.

        :param subscription_id: ID of the subscription.
        :param meter_id: ID of the meter; every metered-unit price using it matches.
        :return: Distinct billing entry IDs.
        """
        statement = (
            self.get_pending_by_subscription_statement(subscription_id)
            .join(
                ProductPriceMeteredUnit,
                BillingEntry.product_price_id == ProductPriceMeteredUnit.id,
            )
            .with_only_columns(BillingEntry.id)
            .where(ProductPriceMeteredUnit.meter_id == meter_id)
        )
        results = await self.session.execute(statement)
        return results.scalars().unique().all()

    def get_pending_by_subscription_statement(
        self, subscription_id: UUID, *, options: Options = ()
    ) -> Select[tuple[BillingEntry]]:
        """Build the base query for pending billing entries of a subscription.

        Starts from ``get_base_statement()`` (presumably already excluding
        soft-deleted rows via the soft-deletion mixin — TODO confirm) and keeps
        only entries of ``subscription_id`` with no ``order_item_id``, ordered
        by ``product_price_id`` ascending.

        :param subscription_id: ID of the subscription.
        :param options: Extra loader options applied to the statement.
        """
        return (
            self.get_base_statement()
            .where(
                BillingEntry.order_item_id.is_(None),
                BillingEntry.subscription_id == subscription_id,
            )
            .order_by(BillingEntry.product_price_id.asc())
            .options(*options)
        )