Coverage for polar/file/service.py: 27%

85 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import uuid 1a

2from collections.abc import Sequence 1a

3from datetime import datetime 1a

4 

5import structlog 1a

6 

7from polar.auth.models import AuthSubject 1a

8from polar.integrations.aws.s3 import S3FileError 1a

9from polar.kit.pagination import PaginationParams 1a

10from polar.models import Organization, ProductMedia, User 1a

11from polar.models.file import File, ProductMediaFile 1a

12from polar.postgres import AsyncReadSession, AsyncSession, sql 1a

13 

14from .repository import FileRepository 1a

15from .s3 import S3_SERVICES 1a

16from .schemas import ( 1a

17 FileCreate, 

18 FileDownload, 

19 FilePatch, 

20 FileUpload, 

21 FileUploadCompleted, 

22) 

23 

# Module-level structlog logger used by the service methods in this file.
log = structlog.get_logger()

25 

26 

class FileError(S3FileError):
    """Error type of the file service; a plain subclass of ``S3FileError``."""

28 

29 

class FileService:
    """Operations on ``File`` records and their S3-backed content."""

    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        ids: Sequence[uuid.UUID] | None = None,
        pagination: PaginationParams,
    ) -> tuple[Sequence[File], int]:
        """Return one page of uploaded files for ``auth_subject``.

        The query starts from the repository's readable statement for
        ``auth_subject`` and keeps only rows with ``is_uploaded`` set.
        ``organization_id`` and ``ids`` each add an ``IN`` filter when they
        are not ``None``.

        Returns whatever ``FileRepository.paginate`` yields for the requested
        ``pagination.limit`` / ``pagination.page`` (annotated as the page of
        files plus an int, presumably the total count).
        """
        repo = FileRepository.from_session(session)

        # Collect every filter first, then apply them in a single ``where``.
        conditions = [File.is_uploaded.is_(True)]
        if organization_id is not None:
            conditions.append(File.organization_id.in_(organization_id))
        if ids is not None:
            conditions.append(File.id.in_(ids))

        query = repo.get_readable_statement(auth_subject).where(*conditions)
        return await repo.paginate(
            query,
            limit=pagination.limit,
            page=pagination.page,
        )

    async def get(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> File | None:
        """Fetch the file with the given ``id`` from the readable statement.

        Returns ``None`` when the repository finds no matching row.
        """
        repo = FileRepository.from_session(session)
        query = repo.get_readable_statement(auth_subject)
        query = query.where(File.id == id)
        return await repo.get_one_or_none(query)

    async def patch(
        self,
        session: AsyncSession,
        *,
        file: File,
        patches: FilePatch,
    ) -> File:
        """Copy the truthy ``name`` / ``version`` values of ``patches`` to ``file``.

        Falsy values (``None``, empty string, ...) are ignored. The session is
        only touched (``add`` + ``flush``) when at least one field was copied.
        The same ``file`` object is returned either way.
        """
        dirty = False
        for field_name in ("name", "version"):
            new_value = getattr(patches, field_name)
            if new_value:
                setattr(file, field_name, new_value)
                dirty = True

        if dirty:
            session.add(file)
            await session.flush()
        return file

    async def generate_presigned_upload(
        self,
        session: AsyncSession,
        *,
        organization: Organization,
        create_schema: FileCreate,
    ) -> FileUpload:
        """Start a multipart upload and record a not-yet-uploaded ``File``.

        The S3 service is looked up from ``create_schema.service`` and its
        ``value`` is used as the namespace of the multipart upload. A ``File``
        row is created from the upload data (minus ``upload``,
        ``organization_id`` and ``size_readable``) with ``is_enabled=True``
        and ``is_uploaded=False``, then flushed.

        Returns a ``FileUpload`` built from the full upload data plus the
        flushed row's ``is_uploaded`` and ``version``.
        """
        service = create_schema.service
        multipart = S3_SERVICES[service].create_multipart_upload(
            create_schema, namespace=service.value
        )

        column_values = multipart.model_dump(
            exclude={"upload", "organization_id", "size_readable"}
        )
        record = File(
            organization=organization,
            service=service,
            is_enabled=True,
            is_uploaded=False,
            **column_values,
        )
        session.add(record)
        await session.flush()
        assert record.id is not None

        return FileUpload(
            is_uploaded=record.is_uploaded,
            version=record.version,
            service=service,
            **multipart.model_dump(),
        )

    async def complete_upload(
        self,
        session: AsyncSession,
        *,
        file: File,
        completed_schema: FileUploadCompleted,
    ) -> File:
        """Complete the multipart upload of ``file`` and mark it uploaded.

        Truthy ``checksum_etag``, ``last_modified_at`` and ``storage_version``
        values reported by the S3 service are copied onto ``file`` before it
        is added to the session and flushed. After the flush, ``checksum_etag``
        and ``last_modified_at`` are asserted to be set.
        """
        completed = S3_SERVICES[file.service].complete_multipart_upload(
            completed_schema
        )

        file.is_uploaded = True
        for field_name in ("checksum_etag", "last_modified_at", "storage_version"):
            reported = getattr(completed, field_name)
            if reported:
                setattr(file, field_name, reported)

        session.add(file)
        await session.flush()
        assert file.checksum_etag is not None
        assert file.last_modified_at is not None

        return file

    def generate_download_url(self, file: File) -> tuple[str, datetime]:
        """Generate a presigned download URL for a file.

        Delegates to the S3 service registered for ``file.service``, passing
        the file's ``path``, ``name`` (as the filename) and ``mime_type``.
        """
        return S3_SERVICES[file.service].generate_presigned_download_url(
            path=file.path,
            filename=file.name,
            mime_type=file.mime_type,
        )

    def generate_downloadable_schema(self, file: File) -> FileDownload:
        """Build a ``FileDownload`` for ``file`` from a fresh presigned URL."""
        signed_url, expiry = self.generate_download_url(file)
        return FileDownload.from_presigned(file, url=signed_url, expires_at=expiry)

    async def delete(self, session: AsyncSession, *, file: File) -> bool:
        """Mark ``file`` as deleted and remove its storage object.

        Steps, in order: call ``file.set_deleted_at()`` and add the file to
        the session, delete the ``ProductMedia`` rows pointing at it, then ask
        the S3 service to delete ``file.path``. The S3 result is only logged;
        the method always returns ``True``.
        """
        file.set_deleted_at()
        session.add(file)
        assert file.deleted_at is not None

        # Delete ProductMedia association table records
        await session.execute(
            sql.delete(ProductMedia).where(ProductMedia.file_id == file.id)
        )

        s3_deleted = S3_SERVICES[file.service].delete_file(file.path)
        log.info("file.delete", file_id=file.id, s3_deleted=s3_deleted)
        return True

    async def get_selectable_product_media_file(
        self,
        session: AsyncSession,
        id: uuid.UUID,
        *,
        organization_id: uuid.UUID,
    ) -> ProductMediaFile | None:
        """Look up a ``ProductMediaFile`` by ``id`` within ``organization_id``.

        Only rows that are uploaded, enabled and have no ``deleted_at`` match.
        Returns ``None`` when nothing matches.
        """
        criteria = (
            File.id == id,
            File.organization_id == organization_id,
            File.is_uploaded.is_(True),
            File.is_enabled.is_(True),
            File.deleted_at.is_(None),
        )
        rows = await session.execute(sql.select(ProductMediaFile).where(*criteria))
        return rows.scalar_one_or_none()

190 

191 

# Module-level FileService instance created at import time.
file = FileService()