Coverage for polar/integrations/aws/s3/service.py: 26%

88 statements  

coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

import base64
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

import botocore
import structlog
from botocore.exceptions import ClientError

from polar.kit.utils import generate_uuid, utc_now

from .client import client, get_client
from .exceptions import S3FileError
from .schemas import (
    S3File,
    S3FileCreate,
    S3FileCreatePart,
    S3FileUpload,
    S3FileUploadCompleted,
    S3FileUploadMultipart,
    S3FileUploadPart,
    get_downloadable_content_disposition,
)

if TYPE_CHECKING:
    from mypy_boto3_s3.client import S3Client
    from mypy_boto3_s3.type_defs import PutObjectRequestTypeDef

log = structlog.get_logger()


class S3Service:
    def __init__(
        self,
        bucket: str,
        presign_ttl: int = 600,
        client: "S3Client" = client,
    ):
        self.bucket = bucket
        self.presign_ttl = presign_ttl
        self.client = client

    def upload(
        self,
        data: bytes,
        path: str,
        mime_type: str,
        checksum_sha256_base64: str | None = None,
    ) -> str:
        """
        Uploads a file directly to S3.

        Mostly useful for files we generate on the backend, like invoices or exports.
        """
        request: PutObjectRequestTypeDef = {
            "Bucket": self.bucket,
            "Key": path,
            "Body": data,
            "ContentType": mime_type,
        }
        if checksum_sha256_base64 is not None:
            request["ChecksumAlgorithm"] = "SHA256"
            request["ChecksumSHA256"] = checksum_sha256_base64

        self.client.put_object(**request)
        return path

    def create_multipart_upload(
        self, data: S3FileCreate, namespace: str = ""
    ) -> S3FileUpload:
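        """
        Initiate a multipart upload on S3 and presign an upload URL for
        each of the requested parts.

        Raises S3FileError if the organization ID is missing or if S3
        does not return an upload ID.
        """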

        if not data.organization_id:
            raise S3FileError("Organization ID is required")

        file_uuid = generate_uuid()
        # Each organization gets its own directory,
        # containing one directory per file: {file_uuid}/{data.name}.
        # This allows multiple files to share the same name.
        path = f"{namespace}/{data.organization_id}/{file_uuid}/{data.name}"

        file = S3File(
            id=file_uuid,
            organization_id=data.organization_id,
            name=data.name,
            path=path,
            mime_type=data.mime_type,
            size=data.size,
            storage_version=None,
            checksum_etag=None,
            checksum_sha256_base64=None,
            checksum_sha256_hex=None,
            last_modified_at=None,
        )

        if data.checksum_sha256_base64:
            sha256_base64 = data.checksum_sha256_base64
            file.checksum_sha256_base64 = sha256_base64
            file.checksum_sha256_hex = base64.b64decode(sha256_base64).hex()

        multipart_upload = self.client.create_multipart_upload(
            Bucket=self.bucket,
            Key=file.path,
            ContentType=file.mime_type,
            ChecksumAlgorithm="SHA256",
            Metadata=file.to_metadata(),
        )
        multipart_upload_id = multipart_upload.get("UploadId")
        if not multipart_upload_id:
            log.error(
                "aws.s3",
                organization_id=file.organization_id,
                filename=file.name,
                mime_type=file.mime_type,
                size=file.size,
                error="No upload ID returned from S3",
            )
            raise S3FileError("No upload ID returned from S3")

        parts = self.generate_presigned_upload_parts(
            path=file.path,
            parts=data.upload.parts,
            upload_id=multipart_upload_id,
        )

        upload = S3FileUpload(
            upload=S3FileUploadMultipart(
                id=multipart_upload_id,
                # Also keep a shorthand for the path here, for use during upload
                path=file.path,
                parts=parts,
            ),
            **file.model_dump(),
        )
        return upload

    def generate_presigned_upload_parts(
        self,
        *,
        path: str,
        parts: list[S3FileCreatePart],
        upload_id: str,
    ) -> list[S3FileUploadPart]:
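        """
        Presign an upload_part URL for each part, valid for presign_ttl
        seconds, along with the headers the client must send for it.
        """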

        ret = []
        expires_in = self.presign_ttl
        for part in parts:
            signed_post_url = self.client.generate_presigned_url(
                "upload_part",
                Params=dict(
                    UploadId=upload_id,
                    Bucket=self.bucket,
                    Key=path,
                    **part.get_boto3_arguments(),
                ),
                ExpiresIn=expires_in,
            )
            presign_expires_at = utc_now() + timedelta(seconds=expires_in)
            headers = S3FileUploadPart.generate_headers(part.checksum_sha256_base64)
            ret.append(
                S3FileUploadPart(
                    number=part.number,
                    chunk_start=part.chunk_start,
                    chunk_end=part.chunk_end,
                    checksum_sha256_base64=part.checksum_sha256_base64,
                    url=signed_post_url,
                    expires_at=presign_expires_at,
                    headers=headers,
                )
            )
        return ret

    def get_object_or_raise(self, path: str, s3_version_id: str = "") -> dict[str, Any]:
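        """Fetch the object from S3, raising S3FileError if it cannot be retrieved."""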

        try:
            obj = self.client.get_object(
                Bucket=self.bucket,
                Key=path,
                VersionId=s3_version_id,
                ChecksumMode="ENABLED",
            )
        except ClientError as e:
            raise S3FileError("No object on S3") from e

        return cast(dict[str, Any], obj)

    def get_head_or_raise(self, path: str, s3_version_id: str = "") -> dict[str, Any]:
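        """Fetch the object's HEAD metadata from S3, raising S3FileError on failure."""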

        try:
            head = self.client.head_object(
                Bucket=self.bucket, Key=path, VersionId=s3_version_id
            )
        except ClientError as e:
            raise S3FileError("No metadata from S3") from e

        return cast(dict[str, Any], head)

    def complete_multipart_upload(self, data: S3FileUploadCompleted) -> S3File:
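        """
        Finalize a multipart upload on S3 and build the resulting S3File
        from the stored object's HEAD metadata.
        """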

        boto_arguments = data.get_boto3_arguments()
        response = self.client.complete_multipart_upload(
            Bucket=self.bucket, Key=data.path, **boto_arguments
        )
        if not response:
            raise S3FileError("No response from S3")

        version_id = response.get("VersionId", "")
        head = self.get_head_or_raise(data.path, s3_version_id=version_id)
        file = S3File.from_head(data.path, head)
        return file

    def generate_presigned_download_url(
        self,
        *,
        path: str,
        filename: str,
        mime_type: str,
    ) -> tuple[str, datetime]:
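        """
        Presign a download URL for the object at the given path, returned
        together with its expiry time. Response headers are overridden so
        the object is served with the given filename and MIME type.
        """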

        expires_in = self.presign_ttl
        presign_from = utc_now()
        signed_download_url = self.client.generate_presigned_url(
            "get_object",
            Params=dict(
                Bucket=self.bucket,
                Key=path,
                ResponseContentDisposition=get_downloadable_content_disposition(
                    filename
                ),
                ResponseContentType=mime_type,
            ),
            ExpiresIn=expires_in,
        )

        presign_expires_at = presign_from + timedelta(seconds=expires_in)
        return (signed_download_url, presign_expires_at)

    def get_public_url(self, path: str) -> str:
        # This is apparently the *only* way to get a public URL with boto3,
        # apart from building a URL manually 🙄
        # Ref: https://stackoverflow.com/a/48197923
        unsigned_client = get_client(signature_version=botocore.UNSIGNED)
        return unsigned_client.generate_presigned_url(
            "get_object", ExpiresIn=0, Params=dict(Bucket=self.bucket, Key=path)
        )

    def delete_file(self, path: str) -> bool:
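        """Delete the object from S3, returning whether a delete marker was created."""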

        deleted = self.client.delete_object(Bucket=self.bucket, Key=path)
        return deleted.get("DeleteMarker", False)