Coverage for polar/kit/routing.py: 94%

88 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 16:17 +0000

1import functools 1a

2import inspect 1a

3from collections.abc import Callable 1a

4from typing import Any 1a

5 

6from fastapi import APIRouter as _APIRouter 1a

7from fastapi.routing import APIRoute 1a

8from sqlalchemy.ext.asyncio import AsyncSession 1a

9 

10from polar.config import settings 1a

11from polar.kit.pagination import ListResource 1a

12from polar.openapi import APITag 1a

13 

14 

class AutoCommitAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that automatically
    commits the session after the endpoint is called.

    It allows to directly return ORM objects from the endpoint
    without having to call `session.commit()` before returning.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        """
        Build the route around a committing wrapper of `endpoint`.

        :param path: Route path, forwarded to `APIRoute`.
        :param endpoint: The endpoint callable; it is awaited by the wrapper.
        :param kwargs: Remaining `APIRoute` keyword arguments, forwarded as-is.
        """
        # Wrap first so that `APIRoute` is initialized with the wrapper.
        endpoint = self.wrap_endpoint(endpoint)
        super().__init__(path, endpoint, **kwargs)

    def wrap_endpoint(self, endpoint: Callable[..., Any]) -> Callable[..., Any]:
        """
        Return a coroutine function that awaits `endpoint` and then commits
        the first `AsyncSession` found among the call arguments.

        Positional arguments are inspected before keyword argument values.
        If no `AsyncSession` is passed, nothing is committed. If `endpoint`
        raises, the exception propagates and no commit is attempted.

        :param endpoint: The endpoint callable to wrap.
        :return: The wrapping coroutine function, carrying the metadata
            of `endpoint` through `functools.wraps`.
        """

        @functools.wraps(endpoint)
        async def wrapped_endpoint(*args: Any, **kwargs: Any) -> Any:
            session: AsyncSession | None = None
            # `args` must be unpacked: iterating over `(args, ...)` would test
            # the whole positional tuple against `AsyncSession`, so a session
            # passed positionally would never be detected.
            for arg in (*args, *kwargs.values()):
                if isinstance(arg, AsyncSession):
                    session = arg
                    break

            response = await endpoint(*args, **kwargs)

            # Commit only after the endpoint completed without raising.
            if session is not None:
                await session.commit()

            return response

        return wrapped_endpoint

45 

46 

class IncludedInSchemaAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that decides the `include_in_schema` property
    from the route tags.

    Routes created with `include_in_schema` falsy are left untouched.
    Otherwise:

    - tagged `APITag.private`: set to the result of `settings.is_development()`;
    - else tagged `APITag.public`: set to `True`;
    - else: set to `False`.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)

        # An explicit opt-out by the caller always wins.
        if not self.include_in_schema:
            return

        route_tags = self.tags
        if APITag.private in route_tags:
            # The private tag takes precedence over the public one.
            visible = settings.is_development()
        else:
            visible = APITag.public in route_tags
        self.include_in_schema = visible

63 

64 

class SpeakeasyNameOverrideAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that fills in the `x-speakeasy-name-override`
    OpenAPI extension with the name of the route function.

    A value already present in `openapi_extra` is never overwritten.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        function_name = endpoint.__name__
        current_extra = self.openapi_extra or {}

        # Respect an override explicitly provided on the route.
        if "x-speakeasy-name-override" in current_extra:
            return

        # Build a new dict rather than mutating the one passed by the caller.
        merged_extra = dict(current_extra)
        merged_extra["x-speakeasy-name-override"] = function_name
        self.openapi_extra = merged_extra

80 

81 

class SpeakeasyIgnoreAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that sets the `x-speakeasy-ignore` OpenAPI
    extension to `True` when the route is not tagged with `APITag.public`.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)

        # Routes tagged as public keep their `openapi_extra` unchanged.
        if APITag.public in self.tags:
            return

        # Copy before updating so the caller's dict is not mutated.
        merged_extra = dict(self.openapi_extra or {})
        merged_extra["x-speakeasy-ignore"] = True
        self.openapi_extra = merged_extra

94 

95 

class SpeakeasyGroupAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that sets the `x-speakeasy-group` OpenAPI
    extension to the dot-joined list of route tags that are not part of
    `APITag`.

    Nothing is added when every tag of the route belongs to `APITag`.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)

        # Keep the tags in their declared order; only `APITag` ones are dropped.
        group_parts: list[str] = []
        for tag in self.tags:
            if tag not in APITag:
                group_parts.append(str(tag))

        if not group_parts:
            return

        merged_extra = dict(self.openapi_extra or {})
        merged_extra["x-speakeasy-group"] = ".".join(group_parts)
        self.openapi_extra = merged_extra

111 

112 

class SpeakeasyPaginationAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that sets the `x-speakeasy-pagination` OpenAPI
    extension when the route's response model is a class having
    `ListResource` in its MRO.

    The extension declares an `offsetLimit` pagination driven by the `page`
    and `limit` parameters, reading results from `$.items` and the number of
    pages from `$.pagination.max_page`.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)

        model = self.response_model
        # `inspect.isclass` guards the `.mro()` call against response models
        # that are not plain classes.
        is_list_resource = (
            model is not None
            and inspect.isclass(model)
            and ListResource in model.mro()
        )
        if not is_list_resource:
            return

        page_input = {"name": "page", "in": "parameters", "type": "page"}
        limit_input = {"name": "limit", "in": "parameters", "type": "limit"}
        pagination_spec = {
            "type": "offsetLimit",
            "inputs": [page_input, limit_input],
            "outputs": {
                "results": "$.items",
                "numPages": "$.pagination.max_page",
            },
        }

        merged_extra = dict(self.openapi_extra or {})
        merged_extra["x-speakeasy-pagination"] = pagination_spec
        self.openapi_extra = merged_extra

150 

151 

class SpeakeasyMCPAPIRoute(APIRoute):
    """
    A subclass of `APIRoute` that always sets the `x-speakeasy-mcp` OpenAPI
    extension.

    Routes tagged with `APITag.mcp` get `{"disabled": False, "scopes": [...]}`,
    every other route gets `{"disabled": True}`.

    The scopes start with `"read"` when all the route methods are among
    GET, HEAD and OPTIONS, and with `"write"` otherwise. The dot-joined
    tags that are not part of `APITag` are appended as a second scope
    when there are any.
    """

    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
        super().__init__(path, endpoint, **kwargs)
        base_extra = self.openapi_extra or {}

        if APITag.mcp not in self.tags:
            self.openapi_extra = {**base_extra, "x-speakeasy-mcp": {"disabled": True}}
            return

        # Read scope only if every method of the route is a safe one.
        read_only = {"GET", "HEAD", "OPTIONS"}.issuperset(self.methods)
        scopes = ["read" if read_only else "write"]

        resource_tags = [str(tag) for tag in self.tags if tag not in APITag]
        if resource_tags:
            scopes.append(".".join(resource_tags))

        self.openapi_extra = {
            **base_extra,
            "x-speakeasy-mcp": {"disabled": False, "scopes": scopes},
        }

179 

180 

181def _inherit_signature_from[**P, T]( 1a

182 _to: Callable[P, T], 

183) -> Callable[[Callable[..., T]], Callable[P, T]]: 

184 return lambda x: x # pyright: ignore 1a

185 

186 

def get_api_router_class(route_class: type[APIRoute]) -> type[_APIRouter]:
    """
    Build an `APIRouter` subclass bound to the given `route_class`.

    Every instance of the returned class is initialized with
    `route_class=route_class`; a `route_class` keyword argument passed by
    the caller is replaced.
    """

    class _CustomAPIRouter(_APIRouter):
        @_inherit_signature_from(_APIRouter.__init__)
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # Listing `route_class` last makes it win over a caller-supplied one.
            forced_kwargs = {**kwargs, "route_class": route_class}
            super().__init__(*args, **forced_kwargs)

    return _CustomAPIRouter

199 

200 

# Public API of the module: the router class factory and the route classes.
__all__ = [
    "get_api_router_class",
    "AutoCommitAPIRoute",
    "IncludedInSchemaAPIRoute",
    "SpeakeasyGroupAPIRoute",
    "SpeakeasyIgnoreAPIRoute",
    "SpeakeasyNameOverrideAPIRoute",
    "SpeakeasyPaginationAPIRoute",
    "SpeakeasyMCPAPIRoute",
]