Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions projects/fal/src/fal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fal.api import function as isolated # noqa: F401
from fal.app import App, endpoint, realtime, wrap_app # noqa: F401
from fal.container import ContainerImage
from fal.helpers import warm_dir, warm_file
from fal.sdk import FalServerlessKeyCredentials, HealthCheck
from fal.sync import sync_dir

Expand All @@ -26,6 +27,8 @@
"HealthCheck",
"FalServerlessKeyCredentials",
"sync_dir",
"warm_file",
"warm_dir",
"__version__",
"version_tuple",
"ContainerImage",
Expand Down
39 changes: 35 additions & 4 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, ClassVar, Optional

import fastapi
Expand All @@ -27,9 +28,11 @@
from fal.api import (
function as fal_function,
)
from fal.app_files import get_app_files_relative_path, include_app_files_path
from fal.auth import key_credentials
from fal.container import ContainerImage
from fal.exceptions import FalServerlessException, RequestCancelledException
from fal.helpers import warm_dir
from fal.logging import get_logger
from fal.realtime import realtime # noqa: F401
from fal.ref import get_current_app, set_current_app
Expand Down Expand Up @@ -59,10 +62,10 @@


async def _call_any_fn(fn, *args, **kwargs):
if inspect.iscoroutinefunction(fn):
return await fn(*args, **kwargs)
else:
return fn(*args, **kwargs)
result = fn(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result


async def open_isolate_channel(address: str) -> async_grpc.Channel | None:
Expand Down Expand Up @@ -512,6 +515,8 @@ class App(BaseServable):
Default excludes `.pyc`, `__pycache__`, `.git`, `.DS_Store`.
app_files_context_dir: Base directory for resolving app_files paths.
Defaults to the directory containing the app file.
data_dirs: Directories to pre-read in parallel during setup. This is
primarily useful for warming model weights stored under `/data`.
request_timeout: Maximum seconds for a single request. None for default.
startup_timeout: Maximum seconds for app startup/setup. None for default.
min_concurrency: Minimum warm instances to keep running. Set to 1+ to
Expand Down Expand Up @@ -555,6 +560,7 @@ class App(BaseServable):
app_files: ClassVar[list[str]] = []
app_files_ignore: ClassVar[list[str]] = DEFAULT_APP_FILES_IGNORE
app_files_context_dir: ClassVar[Optional[str]] = None
data_dirs: ClassVar[list[str]] = []
request_timeout: ClassVar[Optional[int]] = None
startup_timeout: ClassVar[Optional[int]] = None
min_concurrency: ClassVar[Optional[int]] = None
Expand Down Expand Up @@ -613,6 +619,9 @@ def __init_subclass__(cls, **kwargs):
"app_files_context_dir is only supported when app_files is provided"
)

if cls.data_dirs:
cls.host_kwargs["data_dirs"] = cls.data_dirs

if cls.min_concurrency is not None:
cls.host_kwargs["min_concurrency"] = cls.min_concurrency

Expand Down Expand Up @@ -790,6 +799,28 @@ async def lifespan(self, app: fastapi.FastAPI):
except (ImportError, AttributeError):
pass

# We want to not do any directory changes for container apps,
# since we don't have explicit checks to see the kind of app
# We check for app_files here and check kind and app_files earlier
# to ensure that container apps don't have app_files
if self.app_files:
# For app_files deployments (always use /app)
app_files_relative_path = get_app_files_relative_path(
self.local_file_path or os.getcwd(), self.app_files_context_dir
)
include_app_files_path(app_files_relative_path)
elif self.image is not None:
# For containers, add the working directory to sys.path
# isolate's runpy.run_path() overrides sys.path[0],
# so the working directory is never added to sys.path
sys.path.insert(0, "")

for directory in self.host_kwargs.get("data_dirs", []):
warm_path = Path(directory).expanduser()
if not warm_path.is_absolute():
warm_path = Path("/data") / warm_path
warm_dir(os.fspath(warm_path))

_print_python_packages()
setup_started_at = time.perf_counter()
await _call_any_fn(self.setup)
Expand Down
89 changes: 89 additions & 0 deletions projects/fal/src/fal/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

import concurrent.futures
import os
from os import PathLike
from pathlib import Path
from typing import Iterator

from fal.logging import get_logger

DEFAULT_WARM_DIR_PARALLELISM = 32
DEFAULT_WARM_FILE_CHUNK_SIZE = 1024 * 1024

logger = get_logger(__name__)


def _iter_regular_files(directory: Path) -> Iterator[Path]:
with os.scandir(directory) as entries:
for entry in entries:
if entry.is_file(follow_symlinks=False):
yield Path(entry.path)
elif entry.is_dir(follow_symlinks=False):
yield from _iter_regular_files(Path(entry.path))


def warm_file(
file_path: str | PathLike[str],
chunk_size: int = DEFAULT_WARM_FILE_CHUNK_SIZE,
) -> None:
"""Pre-read a file into the OS page cache."""

if chunk_size < 1:
raise ValueError("chunk_size must be greater than or equal to 1")

path = Path(file_path).expanduser()
if path.is_symlink():
logger.info("Skipping warm_file for symlink", path=os.fspath(path))
return
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
if path.is_dir():
raise IsADirectoryError(f"Expected a file: {path}")
if not path.is_file():
raise ValueError(f"Expected a regular file: {path}")

with path.open("rb") as file:
while file.read(chunk_size):
pass


def warm_dir(
directory: str | PathLike[str],
parallelism: int = DEFAULT_WARM_DIR_PARALLELISM,
) -> None:
"""Pre-read all files in a directory into the OS page cache."""

if parallelism < 1:
raise ValueError("parallelism must be greater than or equal to 1")

path = Path(directory).expanduser()
if not path.exists():
raise FileNotFoundError(f"Directory not found: {path}")
if not path.is_dir():
raise NotADirectoryError(f"Expected a directory: {path}")

if parallelism == 1:
for file_path in _iter_regular_files(path):
warm_file(file_path)
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=parallelism) as executor:
pending = set()

for file_path in _iter_regular_files(path):
pending.add(executor.submit(warm_file, file_path))

# Keep a small bounded queue of work instead of materializing the
# full directory tree before starting any reads.
if len(pending) >= parallelism * 2:
done, pending = concurrent.futures.wait(
pending,
return_when=concurrent.futures.FIRST_COMPLETED,
)
for future in done:
future.result()

for future in concurrent.futures.as_completed(pending):
future.result()

logger.info("Warmed directory into OS page cache", path=os.fspath(path))
66 changes: 65 additions & 1 deletion projects/fal/tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from fastapi import WebSocket
from fastapi import FastAPI, WebSocket
from pydantic import BaseModel

import fal
Expand Down Expand Up @@ -370,6 +370,70 @@ def test_load_function_from_allows_container_without_ref():
assert loaded.function.options is options


@pytest.mark.asyncio
async def test_app_lifespan_warms_data_dirs_before_setup(
isolate_agent_env, monkeypatch: pytest.MonkeyPatch
):
import fal.app as fal_app_module

calls: list[tuple[str, str | None]] = []

monkeypatch.setattr(
fal_app_module,
"warm_dir",
lambda directory: calls.append(("warm", directory)),
)

class WarmedApp(App):
data_dirs = ["/data/models", "/data/tokenizer"]

def setup(self):
calls.append(("setup", None))

app = WarmedApp()

async with app.lifespan(FastAPI()):
pass

assert calls == [
("warm", "/data/models"),
("warm", "/data/tokenizer"),
("setup", None),
]


@pytest.mark.asyncio
async def test_app_lifespan_resolves_relative_data_dirs_from_data(
isolate_agent_env, monkeypatch: pytest.MonkeyPatch
):
import fal.app as fal_app_module

calls: list[tuple[str, str | None]] = []

monkeypatch.setattr(
fal_app_module,
"warm_dir",
lambda directory: calls.append(("warm", str(directory))),
)

class WarmedApp(App):
data_dirs = ["models", "nested/tokenizer"]

def setup(self):
calls.append(("setup", None))

app = WarmedApp()

async with app.lifespan(FastAPI()):
pass

assert calls == [
("warm", "/data/models"),
("warm", "/data/nested/tokenizer"),
("setup", None),
]


def test_wrap_app_allows_resolver_with_container_kind():
from fal.app import wrap_app

Expand Down
117 changes: 117 additions & 0 deletions projects/fal/tests/unit/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

from pathlib import Path

import pytest

import fal.helpers as helpers


def test_warm_file_reads_file_and_validates_inputs(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
logs: list[tuple[str, dict[str, str]]] = []

class FakeLogger:
def info(self, event: str, **kwargs: str) -> None:
logs.append((event, kwargs))

monkeypatch.setattr(helpers, "logger", FakeLogger())

file_path = tmp_path / "model.bin"
file_path.write_bytes(b"abcdef")

helpers.warm_file(file_path, chunk_size=2)

with pytest.raises(ValueError, match="chunk_size"):
helpers.warm_file(file_path, chunk_size=0)

with pytest.raises(FileNotFoundError, match="File not found"):
helpers.warm_file(tmp_path / "missing.bin")

with pytest.raises(IsADirectoryError, match="Expected a file"):
helpers.warm_file(tmp_path)

symlink_path = tmp_path / "model-link.bin"
try:
symlink_path.symlink_to(file_path)
except OSError:
pytest.skip("symlinks are not supported in this environment")

helpers.warm_file(symlink_path)
assert logs == [
("Skipping warm_file for symlink", {"path": str(symlink_path)}),
]


def test_warm_dir_warms_nested_files_with_parallelism(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
logs: list[tuple[str, dict[str, str]]] = []

class FakeLogger:
def info(self, event: str, **kwargs: str) -> None:
logs.append((event, kwargs))

monkeypatch.setattr(helpers, "logger", FakeLogger())

top_level_file = tmp_path / "model.bin"
top_level_file.write_bytes(b"top-level")

nested_dir = tmp_path / "nested"
nested_dir.mkdir()
nested_file = nested_dir / "tokenizer.json"
nested_file.write_text("{}", encoding="utf-8")

symlink_path = tmp_path / "model-link.bin"
try:
symlink_path.symlink_to(top_level_file)
except OSError:
symlink_path = None

warmed_paths: list[Path] = []

def fake_warm_file(
file_path: str | Path,
chunk_size: int = helpers.DEFAULT_WARM_FILE_CHUNK_SIZE,
) -> None:
del chunk_size
warmed_paths.append(Path(file_path))

monkeypatch.setattr(helpers, "warm_file", fake_warm_file)

helpers.warm_dir(tmp_path, parallelism=2)

assert set(warmed_paths) == {top_level_file, nested_file}
if symlink_path is not None:
assert symlink_path not in warmed_paths
assert logs == [
("Warmed directory into OS page cache", {"path": str(tmp_path)}),
]


def test_warm_dir_skips_empty_directories(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
(tmp_path / "empty").mkdir()
monkeypatch.setattr(
helpers,
"warm_file",
lambda *args, **kwargs: pytest.fail("warm_file should not be called"),
)

helpers.warm_dir(tmp_path)


def test_warm_dir_validates_inputs(tmp_path: Path):
file_path = tmp_path / "model.bin"
file_path.write_bytes(b"abcdef")

with pytest.raises(ValueError, match="parallelism"):
helpers.warm_dir(tmp_path, parallelism=0)

with pytest.raises(FileNotFoundError, match="Directory not found"):
helpers.warm_dir(tmp_path / "missing")

with pytest.raises(NotADirectoryError, match="Expected a directory"):
helpers.warm_dir(file_path)
Loading