Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from fal.compat import run_in_thread
from fal.ref import get_current_app
from fal.toolkit.file.providers.fal import (
FalCDNFileRepository,
FalFileRepository,
FalFileRepositoryV2,
FalFileRepositoryV3,
Expand All @@ -45,7 +44,6 @@
"in_memory": lambda: InMemoryRepository(),
"gcp_storage": lambda: GoogleStorageRepository(),
"r2": lambda: R2Repository(),
"cdn": lambda: FalCDNFileRepository(),
}


Expand All @@ -61,7 +59,7 @@ def get_builtin_repository(id: RepositoryId | FileRepository) -> FileRepository:
get_builtin_repository.__module__ = "__main__"

DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal_v3"
FALLBACK_REPOSITORY: list[FileRepository | RepositoryId] = ["cdn", "fal"]
FALLBACK_REPOSITORY: list[FileRepository | RepositoryId] = ["fal"]
OBJECT_LIFECYCLE_PREFERENCE_KEY = "x-fal-object-lifecycle-preference"


Expand Down
42 changes: 0 additions & 42 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from fal.toolkit.file.types import FileData, FileRepository
from fal.toolkit.utils.retry import retry

_FAL_CDN = "https://fal.media"
_FAL_CDN_V3 = "https://v3.fal.media"


Expand Down Expand Up @@ -1292,47 +1291,6 @@ def save(
return f"data:{file.content_type};base64,{b64encode(file.data).decode('utf-8')}"


@dataclass
class FalCDNFileRepository(FileRepository):
def save(
self,
file: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
headers = {
**self.auth_headers,
"Accept": "application/json",
"Content-Type": file.content_type,
"X-Fal-File-Name": file.file_name,
}

_object_lifecycle_headers(headers, object_lifecycle_preference)

url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"
request = Request(url, headers=headers, method="POST", data=file.data)
try:
with _maybe_retry_request(request) as response:
result = json.load(response)
except HTTPError as e:
raise FileUploadException(
f"Error initiating upload. Status {e.status}: {e.reason}"
)

access_url = result["access_url"]
return access_url

@property
def auth_headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {_require_auth_credentials().token}",
"User-Agent": USER_AGENT,
}


@dataclass
class FalFileRepositoryV3(FileRepository):
@property
Expand Down
4 changes: 1 addition & 3 deletions projects/fal/src/fal/toolkit/file/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def __init__(
self.file_name = file_name


RepositoryId = Literal[
"fal", "fal_v2", "fal_v3", "in_memory", "gcp_storage", "r2", "cdn"
]
RepositoryId = Literal["fal", "fal_v2", "fal_v3", "in_memory", "gcp_storage", "r2"]


@dataclass
Expand Down
14 changes: 0 additions & 14 deletions projects/fal/tests/integration/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
download_file,
download_model_weights,
)
from fal.toolkit import (
Image as FalImage,
)
from fal.toolkit.file.file import CompressedFile
from fal.toolkit.utils.download_utils import _git_rev_parse, _hash_url

Expand Down Expand Up @@ -572,17 +569,6 @@ def init_compressed_file_on_fal(input: TestInput) -> int:
assert len(extracted_file_paths) == 3


@pytest.mark.flaky(max_runs=3)
def test_fal_cdn(isolated_client):
@isolated_client(requirements=[f"pydantic=={pydantic_version}", "tomli"])
def upload_to_fal_cdn() -> FalImage:
return FalImage.from_bytes(b"0", "jpeg", repository="cdn")

uploaded_image = upload_to_fal_cdn()

assert uploaded_image


def test_download_file_with_slash_in_filename():
from fal.toolkit.utils.download_utils import _filename_from_response
from fal.toolkit.utils.ssrf import SafeResponse
Expand Down
46 changes: 0 additions & 46 deletions projects/fal/tests/unit/toolkit/file/providers/test_cdn_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,52 +278,6 @@ def test_fal_file_repository_v3_auth_headers_with_bearer_credentials():
assert headers["Authorization"] == "Bearer jwt-token"


def test_fal_cdn_file_repository_auth_headers_with_key_credentials():
"""FalCDNFileRepository emits `Bearer <id>:<secret>` even for Key creds."""
with patch.object(
providers,
"fetch_auth_credentials",
return_value=AuthCredentials("Key", "key_id:key_secret"),
):
repo = providers.FalCDNFileRepository()
headers = repo.auth_headers

# CDN endpoint expects Bearer scheme regardless of how creds were obtained.
assert headers["Authorization"] == "Bearer key_id:key_secret"
assert headers["User-Agent"] == providers.USER_AGENT


def test_fal_cdn_file_repository_auth_headers_with_bearer_credentials():
"""FalCDNFileRepository emits `Bearer <jwt>` for auth0 bearer creds."""
with patch.object(
providers,
"fetch_auth_credentials",
return_value=AuthCredentials("Bearer", "jwt-token"),
):
repo = providers.FalCDNFileRepository()
headers = repo.auth_headers

assert headers["Authorization"] == "Bearer jwt-token"
assert headers["User-Agent"] == providers.USER_AGENT


def test_fal_cdn_file_repository_raises_file_upload_exception_when_missing():
"""FalCDNFileRepository surfaces missing creds as FileUploadException."""
from fal.exceptions.auth import UnauthenticatedException
from fal.toolkit.exceptions import FileUploadException

with patch.object(
providers,
"fetch_auth_credentials",
side_effect=UnauthenticatedException(),
):
repo = providers.FalCDNFileRepository()
with pytest.raises(FileUploadException) as excinfo:
_ = repo.auth_headers

assert isinstance(excinfo.value.__cause__, UnauthenticatedException)


def test_internal_multipart_upload_v3_auth_headers_include_cdn_token():
"""InternalMultipartUploadV3 auth_headers should include CDN token."""
cdn_token = "test-cdn-token"
Expand Down
102 changes: 3 additions & 99 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from PIL import Image

AnyJSON = Dict[str, Any]
UploadRepositoryId = Literal["fal_v3", "cdn", "fal"]
UploadRepositoryId = Literal["fal_v3", "fal"]
LifecyclePreferencePayload = Dict[str, Any]
ObjectExpiration = Union[
Literal["never", "immediate", "1h", "1d", "7d", "30d", "1y"],
Expand All @@ -81,9 +81,8 @@
REALTIME_URL_FORMAT = f"wss://{FAL_RUN_HOST}/"
REST_URL = "https://rest.fal.ai"
CDN_URL = "https://v3.fal.media"
FAL_CDN_FALLBACK_URL = os.environ.get("FAL_CDN_HOST", "https://fal.media")
DEFAULT_UPLOAD_REPOSITORY: UploadRepositoryId = "fal_v3"
DEFAULT_UPLOAD_FALLBACK_REPOSITORY: list[UploadRepositoryId] = ["cdn", "fal"]
DEFAULT_UPLOAD_FALLBACK_REPOSITORY: list[UploadRepositoryId] = ["fal"]
USER_AGENT = f"fal-client/{__version__} (python)"

MIN_REQUEST_TIMEOUT_SECONDS = 1
Expand Down Expand Up @@ -1170,12 +1169,6 @@ async def _async_maybe_retry_request(
raise RuntimeError("Failed to perform request")


def _cdn_auth_header(auth: AuthCredentials) -> str:
if auth.scheme.lower() == "key":
return f"Bearer {auth.token}"
return auth.header_value


def _object_lifecycle_headers(
headers: dict[str, str],
object_lifecycle_preference: LifecyclePreferencePayload | None,
Expand Down Expand Up @@ -1234,19 +1227,6 @@ def _normalize_upload_lifecycle(
return normalized or None


def _cdn_upload_headers(
auth: AuthCredentials,
content_type: str,
file_name: str | None,
object_lifecycle_preference: LifecyclePreferencePayload | None,
) -> dict[str, str]:
headers = {"Content-Type": content_type, "Authorization": _cdn_auth_header(auth)}
if file_name is not None:
headers["X-Fal-File-Name"] = file_name
_object_lifecycle_headers(headers, object_lifecycle_preference)
return headers


def _storage_upload_headers(
auth: AuthCredentials,
object_lifecycle_preference: LifecyclePreferencePayload | None,
Expand All @@ -1264,7 +1244,7 @@ def _normalize_upload_repositories(
repository: UploadRepositoryId | None,
fallback_repository: UploadRepositoryId | list[UploadRepositoryId] | None,
) -> list[UploadRepositoryId]:
allowed = {"fal_v3", "cdn", "fal"}
allowed = {"fal_v3", "fal"}
if repository is None:
repository = DEFAULT_UPLOAD_REPOSITORY

Expand Down Expand Up @@ -1366,27 +1346,6 @@ def _upload_v3(
return response.json()["access_url"]


def _upload_cdn(
client: httpx.Client,
auth: AuthCredentials,
*,
data: bytes,
content_type: str,
file_name: str | None,
object_lifecycle_preference: LifecyclePreferencePayload | None = None,
) -> str:
response = _maybe_retry_request(
client,
"POST",
FAL_CDN_FALLBACK_URL + "/files/upload",
content=data,
headers=_cdn_upload_headers(
auth, content_type, file_name, object_lifecycle_preference
),
)
return response.json()["access_url"]


async def _async_upload_v3(
client: httpx.AsyncClient,
*,
Expand All @@ -1403,27 +1362,6 @@ async def _async_upload_v3(
return response.json()["access_url"]


async def _async_upload_cdn(
client: httpx.AsyncClient,
auth: AuthCredentials,
*,
data: bytes,
content_type: str,
file_name: str | None,
object_lifecycle_preference: LifecyclePreferencePayload | None = None,
) -> str:
response = await _async_maybe_retry_request(
client,
"POST",
FAL_CDN_FALLBACK_URL + "/files/upload",
content=data,
headers=_cdn_upload_headers(
auth, content_type, file_name, object_lifecycle_preference
),
)
return response.json()["access_url"]


def _try_upload_with_fallback(
attempts: list[tuple[str, Callable[[], str]]],
) -> str:
Expand Down Expand Up @@ -1996,25 +1934,6 @@ async def _v3_attempt() -> str:
for repo in repository_chain:
if repo == "fal_v3":
attempts.append(("fal_v3", _v3_attempt))
elif repo == "cdn":
if auth is None:
auth = await self._auth
if client is None:
client = await self._client
attempts.append(
(
"cdn",
partial(
_async_upload_cdn,
client,
auth,
data=data,
content_type=content_type,
file_name=file_name,
object_lifecycle_preference=resolved_lifecycle,
),
)
)
elif repo == "fal":
if auth is None:
auth = await self._auth
Expand Down Expand Up @@ -2513,21 +2432,6 @@ def upload(
partial(_upload_v3, client, data=data, headers=headers),
)
)
elif repo == "cdn":
attempts.append(
(
"cdn",
partial(
_upload_cdn,
self._client,
auth,
data=data,
content_type=content_type,
file_name=file_name,
object_lifecycle_preference=resolved_lifecycle,
),
)
)
elif repo == "fal":
attempts.append(
(
Expand Down
Loading
Loading