Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/dependencies.py: 35%

80 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 11:21 +0000

1""" 

2Utilities for injecting FastAPI dependencies. 

3""" 

4 

5from __future__ import annotations 1a

6 

7import logging 1a

8import re 1a

9from base64 import b64decode 1a

10from typing import Annotated, Any, Optional 1a

11from uuid import UUID 1a

12 

13from docket import Docket as Docket_ 1a

14from fastapi import Body, Depends, Header, HTTPException 1a

15from packaging.version import Version 1a

16from starlette.requests import Request 1a

17 

18from prefect._internal.compatibility.starlette import status 1a

19from prefect.server import schemas 1a

20from prefect.settings import PREFECT_API_DEFAULT_LIMIT 1a

21 

22 

def provide_request_api_version(
    x_prefect_api_version: str = Header(None),
) -> Version | None:
    """
    Parse the 'X-PREFECT-API-VERSION' request header into a `Version`.

    Returns:
        The parsed version, or None when the header is missing or empty.

    Raises:
        HTTPException: 400 when the header is not three dot-separated integers
            ('x.y.z') or cannot be parsed as a version.
    """
    if not x_prefect_api_version:
        return None

    try:
        # Exactly three integer components are required; unpacking a list of
        # the wrong length raises ValueError just like a non-integer component.
        _major, _minor, _patch = [int(v) for v in x_prefect_api_version.split(".")]
        # `int()` accepts forms such as " 2", "+1" or "1_0" that `Version`
        # rejects with InvalidVersion (a ValueError subclass). Building the
        # Version inside the try turns those into a 400 instead of a 500.
        version = Version(x_prefect_api_version)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Invalid X-PREFECT-API-VERSION header format. Expected header in format"
                f" 'x.y.z' but received {x_prefect_api_version}"
            ),
        )
    return version

41 

42 

class EnforceMinimumAPIVersion:
    """
    FastAPI dependency that rejects requests whose declared API version is older
    than a configured minimum.

    The request version is read from the 'X-PREFECT-API-VERSION' header and
    compared component-wise (major, then minor, then patch) against the
    minimum. Requests that omit the header are assumed to target the latest API
    and are allowed through.
    """

    def __init__(self, minimum_api_version: str, logger: logging.Logger):
        """
        Args:
            minimum_api_version: the oldest accepted version; only its first
                three dot-separated integer components are used.
            logger: logger used to report rejected requests.
        """
        self.minimum_api_version = minimum_api_version
        versions = [int(v) for v in minimum_api_version.split(".")]
        self.api_major: int = versions[0]
        self.api_minor: int = versions[1]
        self.api_patch: int = versions[2]
        self.logger = logger

    async def __call__(
        self,
        x_prefect_api_version: str = Header(None),
    ) -> None:
        """
        Validate the request's API version header.

        Returns without error when the header is missing or empty, or when the
        requested version is at least `minimum_api_version`.

        Raises:
            HTTPException: 400 if the header is not three dot-separated
                integers, or if it is older than `minimum_api_version`.
        """
        request_version = x_prefect_api_version

        # if no version header, assume latest and continue
        if not request_version:
            return

        # parse version; a wrong number of components also raises ValueError
        try:
            major, minor, patch = [int(v) for v in request_version.split(".")]
        except ValueError:
            await self._notify_of_invalid_value(request_version)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "Invalid X-PREFECT-API-VERSION header format. "
                    f"Expected header in format 'x.y.z' but received {request_version}"
                ),
            )

        # tuple comparison is lexicographic: major, then minor, then patch
        if (major, minor, patch) < (self.api_major, self.api_minor, self.api_patch):
            await self._notify_of_outdated_version(request_version)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    f"The request specified API version {request_version} but this "
                    f"server requires version {self.minimum_api_version} or higher."
                ),
            )

    async def _notify_of_invalid_value(self, request_version: str):
        """Log a request whose version header could not be parsed."""
        self.logger.error(
            "Invalid X-PREFECT-API-VERSION header format: '%s'", request_version
        )

    async def _notify_of_outdated_version(self, request_version: str):
        """Log a request whose version is below the configured minimum."""
        self.logger.error(
            "X-PREFECT-API-VERSION header specifies version '%s' "
            "but minimum allowed version is '%s'",
            request_version,
            self.minimum_api_version,
        )

103 

104 

def LimitBody() -> Any:
    """
    Create a `fastapi.Depends` that resolves a `limit: int` request-body field.

    When the client omits `limit`, the PREFECT_API_DEFAULT_LIMIT setting is used
    instead. That setting is also the upper bound. A resolved limit that is
    negative or larger than the setting is rejected with HTTP 422.
    """

    def get_limit(
        limit: int = Body(
            None,
            description="Defaults to PREFECT_API_DEFAULT_LIMIT if not provided.",
        ),
    ):
        # Looked up on every call rather than once when LimitBody() runs.
        max_limit = PREFECT_API_DEFAULT_LIMIT.value()
        if limit is None:
            limit = max_limit

        if limit < 0:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="Invalid limit: must be greater than or equal to 0.",
            )
        if limit > max_limit:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Invalid limit: must be less than or equal to {max_limit}.",
            )
        return limit

    return Depends(get_limit)

132 

133 

def _decode_automation_name(prefect_automation_name: str) -> Optional[str]:
    """
    Base64-decode an automation-name header value into UTF-8 text.

    Returns None when the value is not valid base64 or does not decode as UTF-8.
    """
    try:
        return b64decode(prefect_automation_name.encode()).decode()
    except ValueError:
        # binascii.Error, UnicodeEncodeError and UnicodeDecodeError are all
        # ValueError subclasses, so this covers every expected decode failure.
        return None


def get_created_by(
    prefect_automation_id: Optional[UUID] = Header(None, include_in_schema=False),
    prefect_automation_name: Optional[str] = Header(None, include_in_schema=False),
) -> Optional[schemas.core.CreatedBy]:
    """A dependency that returns the provenance information to use when creating objects
    during this API call.

    Returns an AUTOMATION `CreatedBy` only when both automation headers are
    present and the name header decodes to non-empty text; otherwise None.
    """
    if prefect_automation_id and prefect_automation_name:
        display_value = _decode_automation_name(prefect_automation_name)
        if display_value:
            return schemas.core.CreatedBy(
                id=prefect_automation_id,
                type="AUTOMATION",
                display_value=display_value,
            )

    return None


def get_updated_by(
    prefect_automation_id: Optional[UUID] = Header(None, include_in_schema=False),
    prefect_automation_name: Optional[str] = Header(None, include_in_schema=False),
) -> Optional[schemas.core.UpdatedBy]:
    """A dependency that returns the provenance information to use when updating objects
    during this API call.

    The automation-name header is decoded exactly as in `get_created_by`, so
    both provenance records display the same human-readable name. Previously
    the raw base64 text was stored here.
    """
    if prefect_automation_id and prefect_automation_name:
        display_value = _decode_automation_name(prefect_automation_name)
        if display_value:
            return schemas.core.UpdatedBy(
                id=prefect_automation_id,
                type="AUTOMATION",
                display_value=display_value,
            )

    return None

170 

171 

def is_ephemeral_request(request: Request) -> bool:
    """
    Report whether *request* was addressed to an ephemeral server.

    The check is a plain substring search for "ephemeral-prefect" anywhere in
    the request's base URL.
    """
    base_url = str(request.base_url)
    return base_url.find("ephemeral-prefect") != -1

177 

178 

# Matches "prefect/<x.y.z[suffix]> (API <anything without spaces>)"; group 1 is
# the client version, optionally followed by lowercase/digit/'.'/'+' suffixes.
PREFECT_CLIENT_USER_AGENT_PATTERN = re.compile(
    r"^prefect/(\d+\.\d+\.\d+(?:[a-z.+0-9]+)?) \(API \S+\)$"
)


def get_prefect_client_version(
    user_agent: Annotated[Optional[str], Header(include_in_schema=False)] = None,
) -> Optional[str]:
    """
    Extract the Prefect client version advertised in the User-Agent header.

    Prefect clients send a User-Agent of the form
        f"prefect/{prefect.__version__} (API {constants.SERVER_API_VERSION})"
    A missing, empty, or non-matching header yields None.
    """
    found = PREFECT_CLIENT_USER_AGENT_PATTERN.match(user_agent) if user_agent else None
    return found.group(1) if found is not None else None

199 

200 

def docket(request: Request) -> Docket_:
    """Dependency that returns the Docket instance stored on the application's state."""
    app_state = request.app.state
    return app_state.docket

203 

204 

# Annotated alias: a parameter declared as `Docket` is resolved by FastAPI
# through the `docket` dependency function.
Docket = Annotated[Docket_, Depends(dependency=docket)]