Coverage for polar/kit/trial.py: 50%
40 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 15:52 +0000
1from datetime import datetime 1ab
2from enum import StrEnum 1ab
3from typing import Self 1ab
5from dateutil.relativedelta import relativedelta 1ab
6from pydantic import BaseModel, Field, model_validator 1ab
7from pydantic_core import PydanticCustomError 1ab
8from sqlalchemy import Integer 1ab
9from sqlalchemy.orm import Mapped, mapped_column 1ab
11from polar.kit.extensions.sqlalchemy.types import StringEnum 1ab
14class TrialInterval(StrEnum): 1ab
15 day = "day" 1ab
16 week = "week" 1ab
17 month = "month" 1ab
18 year = "year" 1ab
20 def get_end(self, d: datetime, count: int) -> datetime: 1ab
21 match self:
22 case TrialInterval.day:
23 return d + relativedelta(days=count)
24 case TrialInterval.week:
25 return d + relativedelta(weeks=count)
26 case TrialInterval.month:
27 return d + relativedelta(months=count)
28 case TrialInterval.year:
29 return d + relativedelta(years=count)
class TrialConfigurationMixin:
    """SQLAlchemy mixin adding optional trial-period columns to a model.

    Both columns are nullable and default to ``None``, so a model can be
    persisted without any trial configured.
    """

    # Unit of the trial period, stored through the project's StringEnum type.
    trial_interval: Mapped[TrialInterval | None] = mapped_column(
        StringEnum(TrialInterval), default=None, nullable=True
    )
    # How many ``trial_interval`` units the trial period lasts.
    trial_interval_count: Mapped[int | None] = mapped_column(
        Integer, default=None, nullable=True
    )
class TrialConfigurationInputMixin(BaseModel):
    """Pydantic input fields for configuring an optional trial period.

    ``trial_interval`` and ``trial_interval_count`` are all-or-nothing:
    either both are provided or both are left as ``None``.
    """

    trial_interval: TrialInterval | None = Field(
        default=None, description="The interval unit for the trial period."
    )
    trial_interval_count: int | None = Field(
        default=None,
        description="The number of interval units for the trial period.",
        ge=1,
        le=1000,
    )

    @model_validator(mode="after")
    def is_complete_configuration(self) -> Self:
        """Reject a partial trial configuration.

        :return: The validated model, unchanged.
        :raises PydanticCustomError: With error type ``missing`` when exactly
            one of ``trial_interval`` / ``trial_interval_count`` is ``None``
            (pydantic surfaces this to callers as a ``ValidationError``).
        """
        interval_unset = self.trial_interval is None
        count_unset = self.trial_interval_count is None
        # Both set or both unset is valid; a mismatch means only half was given.
        if interval_unset != count_unset:
            raise PydanticCustomError(
                "missing",
                "Both trial_interval and trial_interval_count must be set together.",
            )
        return self
class TrialConfigurationOutputMixin(BaseModel):
    """Pydantic output fields exposing a trial-period configuration.

    Neither field declares a default, so both are required keys in the
    schema, while each may hold ``None``.
    """

    trial_interval: TrialInterval | None = Field(
        description="The interval unit for the trial period.",
    )
    trial_interval_count: int | None = Field(
        description="The number of interval units for the trial period.",
    )