Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
CDN_URL = "https://v3.fal.media"
USER_AGENT = "fal-client/0.2.2 (python)"

SIGNED_URL_DURATION = 600


@dataclass
class CDNToken:
Expand Down Expand Up @@ -153,6 +155,7 @@ def __init__(
chunk_size: int | None = None,
content_type: str | None = None,
max_concurrency: int | None = None,
private: bool = False,
) -> None:
self.file_name = file_name
self._client = client
Expand All @@ -163,6 +166,7 @@ def __init__(
self._access_url: str | None = None
self._upload_id: str | None = None
self._parts: list[dict] = []
self._private = private

@property
def access_url(self) -> str:
Expand Down Expand Up @@ -191,6 +195,7 @@ def create(self):
self._client,
"POST",
url,
params={"private": True} if self._private else {},
headers={
**self.auth_headers,
"Accept": "application/json",
Expand Down Expand Up @@ -248,6 +253,7 @@ def save(
content_type: str | None = None,
chunk_size: int | None = None,
max_concurrency: int | None = None,
private: bool = False,
):
import concurrent.futures

Expand All @@ -258,6 +264,7 @@ def save(
chunk_size=chunk_size,
content_type=content_type,
max_concurrency=max_concurrency,
private=private,
)
multipart.create()
parts = math.ceil(len(data) / multipart.chunk_size)
Expand Down Expand Up @@ -285,6 +292,7 @@ def save_file(
chunk_size: int | None = None,
content_type: str | None = None,
max_concurrency: int | None = None,
private: bool = False,
) -> str:
import concurrent.futures

Expand All @@ -297,6 +305,7 @@ def save_file(
chunk_size=chunk_size,
content_type=content_type,
max_concurrency=max_concurrency,
private=private,
)
multipart.create()
parts = math.ceil(size / multipart.chunk_size)
Expand Down Expand Up @@ -329,6 +338,7 @@ def __init__(
chunk_size: int | None = None,
content_type: str | None = None,
max_concurrency: int | None = None,
private: bool = False,
) -> None:
self.file_name = file_name
self._client = client
Expand All @@ -339,6 +349,7 @@ def __init__(
self._access_url: str | None = None
self._upload_id: str | None = None
self._parts: list[dict] = []
self._private = private

@property
def access_url(self) -> str:
Expand Down Expand Up @@ -368,6 +379,7 @@ async def create(self):
self._client,
"POST",
url,
params={"private": True} if self._private else {},
headers={
**headers,
"Accept": "application/json",
Expand Down Expand Up @@ -427,6 +439,7 @@ async def save(
content_type: str | None = None,
chunk_size: int | None = None,
max_concurrency: int | None = None,
private: bool = False,
) -> str:
multipart = cls(
file_name=file_name,
Expand All @@ -435,6 +448,7 @@ async def save(
chunk_size=chunk_size,
content_type=content_type,
max_concurrency=max_concurrency,
private=private,
)
await multipart.create()
parts = math.ceil(len(data) / multipart.chunk_size)
Expand Down Expand Up @@ -469,6 +483,7 @@ async def save_file(
chunk_size: int | None = None,
content_type: str | None = None,
max_concurrency: int | None = None,
private: bool = False,
) -> str:
file_name = os.path.basename(file_path)
size = os.path.getsize(file_path)
Expand All @@ -479,6 +494,7 @@ async def save_file(
chunk_size=chunk_size,
content_type=content_type,
max_concurrency=max_concurrency,
private=private,
)
await multipart.create()
parts = math.ceil(size / multipart.chunk_size)
Expand Down Expand Up @@ -1114,6 +1130,7 @@ async def cancel(self) -> None:
class AsyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 120.0
acl_enabled: bool = False

def _get_key(self) -> str:
if self.key is None:
Expand Down Expand Up @@ -1163,6 +1180,22 @@ async def _get_realtime_token(
)
return _parse_token_response(response.json())

async def _get_signed_url(self, access_url: str) -> str:
if not access_url.startswith(CDN_URL + "/files/b/"):
raise ValueError(f"Invalid access URL: {access_url}")

client = await self._get_cdn_client()
response = await _async_maybe_retry_request(
client,
"POST",
f"{access_url}/sign",
json={"duration": SIGNED_URL_DURATION, "scope": ["read"]},
)
_raise_for_status(response)
signed_url = response.text

return signed_url

async def run(
self,
application: str,
Expand Down Expand Up @@ -1335,26 +1368,34 @@ async def upload(
if len(data) > MULTIPART_THRESHOLD:
if file_name is None:
file_name = "upload.bin"
return await AsyncMultipartUpload.save(
access_url = await AsyncMultipartUpload.save(
client=client,
token_manager=self._token_manager,
file_name=file_name,
data=data,
content_type=content_type,
private=self.acl_enabled,
)
if self.acl_enabled:
access_url = await self._get_signed_url(access_url)
return access_url

headers = {"Content-Type": content_type}
if file_name is not None:
headers["X-Fal-File-Name"] = file_name

response = await client.post(
CDN_URL + "/files/upload",
params={"private": True} if self.acl_enabled else {},
content=data,
headers=headers,
)
_raise_for_status(response)

return response.json()["access_url"]
access_url = response.json()["access_url"]
if self.acl_enabled:
access_url = await self._get_signed_url(access_url)
return access_url

async def upload_file(self, path: os.PathLike) -> str:
"""Upload a file from the local filesystem to the CDN and return the access URL."""
Expand All @@ -1365,12 +1406,16 @@ async def upload_file(self, path: os.PathLike) -> str:

if os.path.getsize(path) > MULTIPART_THRESHOLD:
client = await self._get_cdn_client()
return await AsyncMultipartUpload.save_file(
access_url = await AsyncMultipartUpload.save_file(
file_path=str(path),
client=client,
token_manager=self._token_manager,
content_type=mime_type,
private=self.acl_enabled,
)
if self.acl_enabled:
access_url = await self._get_signed_url(access_url)
return access_url

with open(path, "rb") as file:
return await self.upload(
Expand Down Expand Up @@ -1425,6 +1470,7 @@ async def ws_connect(
class SyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 120.0
acl_enabled: bool = False

def _get_key(self) -> str:
if self.key is None:
Expand Down Expand Up @@ -1475,6 +1521,20 @@ def _get_realtime_token(
)
return _parse_token_response(response.json())

def _get_signed_url(self, access_url: str) -> str:
if not access_url.startswith(CDN_URL + "/files/b/"):
raise ValueError(f"Invalid access URL: {access_url}")

client = self._get_cdn_client()
response = client.post(
f"{access_url}/sign",
json={"duration": SIGNED_URL_DURATION, "scope": ["read"]},
)
_raise_for_status(response)
signed_url = response.text

return signed_url

def run(
self,
application: str,
Expand Down Expand Up @@ -1643,13 +1703,17 @@ def upload(
if len(data) > MULTIPART_THRESHOLD:
if file_name is None:
file_name = "upload.bin"
return MultipartUpload.save(
access_url = MultipartUpload.save(
client=client,
token_manager=self._token_manager,
file_name=file_name,
data=data,
content_type=content_type,
private=self.acl_enabled,
)
if self.acl_enabled:
access_url = self._get_signed_url(access_url)
return access_url

headers = {"Content-Type": content_type}
if file_name is not None:
Expand All @@ -1662,7 +1726,10 @@ def upload(
)
_raise_for_status(response)

return response.json()["access_url"]
access_url = response.json()["access_url"]
if self.acl_enabled:
access_url = self._get_signed_url(access_url)
return access_url

def upload_file(self, path: os.PathLike) -> str:
"""Upload a file from the local filesystem to the CDN and return the access URL."""
Expand All @@ -1673,12 +1740,15 @@ def upload_file(self, path: os.PathLike) -> str:

if os.path.getsize(path) > MULTIPART_THRESHOLD:
client = self._get_cdn_client()
return MultipartUpload.save_file(
access_url = MultipartUpload.save_file(
file_path=str(path),
client=client,
token_manager=self._token_manager,
content_type=mime_type,
)
if self.acl_enabled:
access_url = self._get_signed_url(access_url)
return access_url

with open(path, "rb") as file:
return self.upload(file.read(), mime_type, file_name=os.path.basename(path))
Expand Down
Loading