Coverage for polar/file/service.py: 27%
85 statements
coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
import uuid
from collections.abc import Sequence
from datetime import datetime

import structlog

from polar.auth.models import AuthSubject
from polar.integrations.aws.s3 import S3FileError
from polar.kit.pagination import PaginationParams
from polar.models import Organization, ProductMedia, User
from polar.models.file import File, ProductMediaFile
from polar.postgres import AsyncReadSession, AsyncSession, sql

from .repository import FileRepository
from .s3 import S3_SERVICES
from .schemas import (
    FileCreate,
    FileDownload,
    FilePatch,
    FileUpload,
    FileUploadCompleted,
)

log = structlog.get_logger()


class FileError(S3FileError): ...


class FileService:
    async def list(
        self,
        session: AsyncReadSession,
        auth_subject: AuthSubject[User | Organization],
        *,
        organization_id: Sequence[uuid.UUID] | None = None,
        ids: Sequence[uuid.UUID] | None = None,
        pagination: PaginationParams,
    ) -> tuple[Sequence[File], int]:
        repository = FileRepository.from_session(session)

        statement = repository.get_readable_statement(auth_subject).where(
            File.is_uploaded.is_(True)
        )

        if organization_id is not None:
            statement = statement.where(File.organization_id.in_(organization_id))

        if ids is not None:
            statement = statement.where(File.id.in_(ids))

        return await repository.paginate(
            statement, limit=pagination.limit, page=pagination.page
        )

    async def get(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> File | None:
        repository = FileRepository.from_session(session)
        statement = repository.get_readable_statement(auth_subject).where(File.id == id)
        return await repository.get_one_or_none(statement)

    async def patch(
        self,
        session: AsyncSession,
        *,
        file: File,
        patches: FilePatch,
    ) -> File:
        changes = False
        if patches.name:
            file.name = patches.name
            changes = True

        if patches.version:
            file.version = patches.version
            changes = True

        if not changes:
            return file

        session.add(file)
        await session.flush()
        return file

    async def generate_presigned_upload(
        self,
        session: AsyncSession,
        *,
        organization: Organization,
        create_schema: FileCreate,
    ) -> FileUpload:
        s3_service = S3_SERVICES[create_schema.service]
        upload = s3_service.create_multipart_upload(
            create_schema, namespace=create_schema.service.value
        )

        instance = File(
            organization=organization,
            service=create_schema.service,
            is_enabled=True,
            is_uploaded=False,
            **upload.model_dump(exclude={"upload", "organization_id", "size_readable"}),
        )
        session.add(instance)
        await session.flush()
        assert instance.id is not None

        return FileUpload(
            is_uploaded=instance.is_uploaded,
            version=instance.version,
            service=create_schema.service,
            **upload.model_dump(),
        )

    async def complete_upload(
        self,
        session: AsyncSession,
        *,
        file: File,
        completed_schema: FileUploadCompleted,
    ) -> File:
        s3_service = S3_SERVICES[file.service]
        s3file = s3_service.complete_multipart_upload(completed_schema)

        file.is_uploaded = True

        if s3file.checksum_etag:
            file.checksum_etag = s3file.checksum_etag

        if s3file.last_modified_at:
            file.last_modified_at = s3file.last_modified_at

        if s3file.storage_version:
            file.storage_version = s3file.storage_version

        session.add(file)
        await session.flush()
        assert file.checksum_etag is not None
        assert file.last_modified_at is not None

        return file

    def generate_download_url(self, file: File) -> tuple[str, datetime]:
        """Generate a presigned download URL for a file."""
        s3_service = S3_SERVICES[file.service]
        return s3_service.generate_presigned_download_url(
            path=file.path,
            filename=file.name,
            mime_type=file.mime_type,
        )

    def generate_downloadable_schema(self, file: File) -> FileDownload:
        url, expires_at = self.generate_download_url(file)
        return FileDownload.from_presigned(file, url=url, expires_at=expires_at)

    async def delete(self, session: AsyncSession, *, file: File) -> bool:
        file.set_deleted_at()
        session.add(file)
        assert file.deleted_at is not None

        # Delete ProductMedia association table records
        statement = sql.delete(ProductMedia).where(ProductMedia.file_id == file.id)
        await session.execute(statement)

        s3_service = S3_SERVICES[file.service]
        deleted = s3_service.delete_file(file.path)
        log.info("file.delete", file_id=file.id, s3_deleted=deleted)
        return True

    async def get_selectable_product_media_file(
        self,
        session: AsyncSession,
        id: uuid.UUID,
        *,
        organization_id: uuid.UUID,
    ) -> ProductMediaFile | None:
        statement = sql.select(ProductMediaFile).where(
            File.id == id,
            File.organization_id == organization_id,
            File.is_uploaded.is_(True),
            File.is_enabled.is_(True),
            File.deleted_at.is_(None),
        )
        result = await session.execute(statement)
        return result.scalar_one_or_none()


file = FileService()
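
The sketch below (not part of service.py) illustrates how a caller might drive the two-step upload flow and then hand back a download schema, using only the methods shown above. It assumes the caller already has an AsyncSession, an AuthSubject, and an Organization from the surrounding application, and that the FileUpload schema returned by generate_presigned_upload exposes the new file's id; those wiring details are assumptions, not confirmed by this file.

from polar.file.service import file as file_service


async def upload_and_share_example(
    session: AsyncSession,
    auth_subject: AuthSubject[User | Organization],
    organization: Organization,
    create_schema: FileCreate,
    completed_schema: FileUploadCompleted,
) -> FileDownload | None:
    # Step 1: create the File row and the presigned multipart upload.
    upload = await file_service.generate_presigned_upload(
        session, organization=organization, create_schema=create_schema
    )
    # ... the client uploads each part to the presigned URLs in `upload` ...

    # Step 2: look the record back up and mark the upload as complete.
    # (Assumes `upload.id` carries the id of the File created in step 1.)
    record = await file_service.get(session, auth_subject, upload.id)
    if record is None:
        return None
    record = await file_service.complete_upload(
        session, file=record, completed_schema=completed_schema
    )

    # Step 3: wrap a presigned download URL in a FileDownload schema.
    return file_service.generate_downloadable_schema(record)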