Coverage for polar/kit/trial.py: 50%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1from datetime import datetime 1ab

2from enum import StrEnum 1ab

3from typing import Self 1ab

4 

5from dateutil.relativedelta import relativedelta 1ab

6from pydantic import BaseModel, Field, model_validator 1ab

7from pydantic_core import PydanticCustomError 1ab

8from sqlalchemy import Integer 1ab

9from sqlalchemy.orm import Mapped, mapped_column 1ab

10 

11from polar.kit.extensions.sqlalchemy.types import StringEnum 1ab

12 

13 

14class TrialInterval(StrEnum): 1ab

15 day = "day" 1ab

16 week = "week" 1ab

17 month = "month" 1ab

18 year = "year" 1ab

19 

20 def get_end(self, d: datetime, count: int) -> datetime: 1ab

21 match self: 

22 case TrialInterval.day: 

23 return d + relativedelta(days=count) 

24 case TrialInterval.week: 

25 return d + relativedelta(weeks=count) 

26 case TrialInterval.month: 

27 return d + relativedelta(months=count) 

28 case TrialInterval.year: 

29 return d + relativedelta(years=count) 

30 

31 

class TrialConfigurationMixin:
    """Mixin declaring the mapped columns that store a trial configuration.

    Both columns are nullable and default to ``None``. Together they describe
    a trial as ``trial_interval_count`` units of ``trial_interval``.
    """

    # Interval unit of the trial, stored via the StringEnum column type.
    trial_interval: Mapped[TrialInterval | None] = mapped_column(
        StringEnum(TrialInterval),
        nullable=True,
        default=None,
    )

    # Number of ``trial_interval`` units the trial spans.
    trial_interval_count: Mapped[int | None] = mapped_column(
        Integer,
        nullable=True,
        default=None,
    )

39 

40 

class TrialConfigurationInputMixin(BaseModel):
    """Input-schema fields for configuring an optional trial period.

    ``trial_interval`` and ``trial_interval_count`` are all-or-nothing: either
    both are ``None`` or both are set. The count, when given, must be between
    1 and 1000 inclusive.
    """

    trial_interval: TrialInterval | None = Field(
        default=None,
        description="The interval unit for the trial period.",
    )
    trial_interval_count: int | None = Field(
        default=None,
        description="The number of interval units for the trial period.",
        ge=1,
        le=1000,
    )

    @model_validator(mode="after")
    def is_complete_configuration(self) -> Self:
        """Reject a trial configuration that sets only one of its two fields.

        Returns:
            The validated model, unchanged.

        Raises:
            PydanticCustomError: of type ``"missing"`` when exactly one of
                ``trial_interval`` / ``trial_interval_count`` is ``None``.
                Pydantic surfaces it as a validation error on the model.
        """
        interval_absent = self.trial_interval is None
        count_absent = self.trial_interval_count is None

        # Both absent and both present are valid; a mismatch means a
        # half-specified trial.
        if interval_absent != count_absent:
            raise PydanticCustomError(
                "missing",
                "Both trial_interval and trial_interval_count must be set together.",
            )

        return self

64 

65 

class TrialConfigurationOutputMixin(BaseModel):
    """Output-schema fields exposing a trial period configuration.

    Neither ``Field`` declares a default, so both keys are required when the
    model is constructed, although their values may be ``None``.
    """

    # Interval unit of the trial, or None.
    trial_interval: TrialInterval | None = Field(
        description="The interval unit for the trial period.",
    )

    # Number of ``trial_interval`` units, or None.
    trial_interval_count: int | None = Field(
        description="The number of interval units for the trial period.",
    )