Coverage for polar/kit/routing.py: 94%
88 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1import functools 1a
2import inspect 1a
3from collections.abc import Callable 1a
4from typing import Any 1a
6from fastapi import APIRouter as _APIRouter 1a
7from fastapi.routing import APIRoute 1a
8from sqlalchemy.ext.asyncio import AsyncSession 1a
10from polar.config import settings 1a
11from polar.kit.pagination import ListResource 1a
12from polar.openapi import APITag 1a
class AutoCommitAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that automatically
    commits the session after the endpoint is called.

    It allows to directly return ORM objects from the endpoint
    without having to call `session.commit()` before returning.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        # Wrap before handing the endpoint to FastAPI so every call goes
        # through the auto-commit logic.
        endpoint = self.wrap_endpoint(endpoint)
        super().__init__(path, endpoint, **kwargs)

    def wrap_endpoint(self, endpoint: Callable[..., Any]) -> Callable[..., Any]:
        """
        Wrap `endpoint` so that the first `AsyncSession` found among its
        arguments is committed after the endpoint returns.

        The commit only happens when the endpoint returns normally; if it
        raises, the exception propagates and no commit is issued.

        Args:
            endpoint: The route handler to wrap.

        Returns:
            An async callable with the same metadata as `endpoint`.
        """

        # `functools.wraps` copies the endpoint's metadata and sets
        # `__wrapped__`, which `inspect.signature` follows, so signature-based
        # introspection still sees the original parameters.
        @functools.wraps(endpoint)
        async def wrapped_endpoint(*args: Any, **kwargs: Any) -> Any:
            session: AsyncSession | None = None
            # Unpack BOTH positional and keyword arguments; `args` itself must
            # be splatted, otherwise a positionally-passed session is missed.
            for arg in (*args, *kwargs.values()):
                if isinstance(arg, AsyncSession):
                    session = arg
                    break

            # NOTE(review): assumes `endpoint` is a coroutine function; a sync
            # endpoint's return value would not be awaitable here.
            response = await endpoint(*args, **kwargs)

            if session is not None:
                await session.commit()

            return response

        return wrapped_endpoint
class IncludedInSchemaAPIRoute(APIRoute):
    """
    `APIRoute` subclass that derives `include_in_schema` from the route's tags.

    Routes explicitly excluded from the schema are left untouched. Otherwise:

    * tagged `APITag.private`: included only when `settings.is_development()`;
    * tagged `APITag.public`: included;
    * neither tag: excluded.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        route_tags = self.tags

        # Respect an explicit opt-out from the schema.
        if not self.include_in_schema:
            return

        if APITag.private in route_tags:
            self.include_in_schema = settings.is_development()
        else:
            self.include_in_schema = APITag.public in route_tags
class SpeakeasyNameOverrideAPIRoute(APIRoute):
    """
    `APIRoute` subclass that sets the `x-speakeasy-name-override` OpenAPI
    extension to the endpoint function's `__name__`, unless the route's
    `openapi_extra` already defines that key.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        function_name = endpoint.__name__
        current_extra = self.openapi_extra or {}

        # An explicitly provided override wins over the function name.
        if "x-speakeasy-name-override" in current_extra:
            return

        self.openapi_extra = {
            **current_extra,
            "x-speakeasy-name-override": function_name,
        }
class SpeakeasyIgnoreAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that automatically adds `x-speakeasy-ignore` property
    to the OpenAPI schema if `APITag.public` is missing from the route's tags.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        # Only routes tagged public are kept for Speakeasy; everything else is
        # marked with `x-speakeasy-ignore: true`, preserving any other extras.
        if APITag.public not in self.tags:
            openapi_extra = self.openapi_extra or {}
            self.openapi_extra = {**openapi_extra, "x-speakeasy-ignore": True}
class SpeakeasyGroupAPIRoute(APIRoute):
    """
    `APIRoute` subclass that sets the `x-speakeasy-group` OpenAPI extension to
    the route's non-`APITag` tags joined with ".", when there are any.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        # Generic tags (members of `APITag`) do not participate in grouping.
        group_parts = [str(tag) for tag in self.tags if tag not in APITag]
        if not group_parts:
            return

        current_extra = self.openapi_extra or {}
        self.openapi_extra = {
            **current_extra,
            "x-speakeasy-group": ".".join(group_parts),
        }
class SpeakeasyPaginationAPIRoute(APIRoute):
    """
    `APIRoute` subclass that adds an offset/limit `x-speakeasy-pagination`
    OpenAPI extension when the route's response model is a class that has
    `ListResource` in its MRO.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        model = self.response_model
        # Non-class response models (e.g. unions) are skipped by `isclass`.
        returns_list_resource = (
            model is not None
            and inspect.isclass(model)
            and ListResource in model.mro()
        )
        if not returns_list_resource:
            return

        # A fresh payload is built for every route instance.
        page_input = {"name": "page", "in": "parameters", "type": "page"}
        limit_input = {"name": "limit", "in": "parameters", "type": "limit"}
        pagination = {
            "type": "offsetLimit",
            "inputs": [page_input, limit_input],
            "outputs": {
                "results": "$.items",
                "numPages": "$.pagination.max_page",
            },
        }

        current_extra = self.openapi_extra or {}
        self.openapi_extra = {**current_extra, "x-speakeasy-pagination": pagination}
class SpeakeasyMCPAPIRoute(APIRoute):
    """
    `APIRoute` subclass that always sets the `x-speakeasy-mcp` OpenAPI
    extension.

    Routes without `APITag.mcp` get `{"disabled": True}`. Routes with it are
    enabled with scopes: "read" when every HTTP method is GET/HEAD/OPTIONS,
    otherwise "write", followed by the non-`APITag` tags joined with "."
    (if any).
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        current_extra = self.openapi_extra or {}

        if APITag.mcp not in self.tags:
            self.openapi_extra = {
                **current_extra,
                "x-speakeasy-mcp": {"disabled": True},
            }
            return

        read_only = all(
            verb in {"GET", "HEAD", "OPTIONS"} for verb in self.methods
        )
        mcp_scopes = ["read" if read_only else "write"]

        resource_tags = [str(tag) for tag in self.tags if tag not in APITag]
        if resource_tags:
            mcp_scopes.append(".".join(resource_tags))

        self.openapi_extra = {
            **current_extra,
            "x-speakeasy-mcp": {"disabled": False, "scopes": mcp_scopes},
        }
181def _inherit_signature_from[**P, T]( 1a
182 _to: Callable[P, T],
183) -> Callable[[Callable[..., T]], Callable[P, T]]:
184 return lambda x: x # pyright: ignore 1a
def get_api_router_class(route_class: type[APIRoute]) -> type[_APIRouter]:
    """
    Build an `APIRouter` subclass whose instances always use `route_class`.

    Any `route_class` keyword given to the router's constructor is replaced
    by the one passed here.
    """

    class _CustomAPIRouter(_APIRouter):
        @_inherit_signature_from(_APIRouter.__init__)
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            forced_kwargs = {**kwargs, "route_class": route_class}
            super().__init__(*args, **forced_kwargs)

    return _CustomAPIRouter
# Public API of this module: the router-class factory and the route classes.
__all__ = [
    "get_api_router_class",
    "AutoCommitAPIRoute",
    "IncludedInSchemaAPIRoute",
    "SpeakeasyGroupAPIRoute",
    "SpeakeasyIgnoreAPIRoute",
    "SpeakeasyNameOverrideAPIRoute",
    "SpeakeasyPaginationAPIRoute",
    "SpeakeasyMCPAPIRoute",
]