From f994de8e406b3ba835ca50fd0d4089ac0f648052 Mon Sep 17 00:00:00 2001 From: Tobias Wennergren Date: Tue, 7 Apr 2026 14:25:49 -0700 Subject: [PATCH 1/2] feat: warn on vulnerable torch versions in fal run Adds a security warning in `fal run` when the app's requirements include a torch version <= 2.5, which has known serialization vulnerabilities allowing malicious model files to execute arbitrary code. The warning also notes that a future release will automatically set weights_only=True on torch.load() calls for torch < 2.6 as a mitigation. Co-Authored-By: Claude Sonnet 4.6 --- projects/fal/src/fal/cli/_utils.py | 83 +++++++++++++++++++++++++ projects/fal/src/fal/cli/run.py | 7 ++- projects/fal/tests/unit/cli/test_run.py | 79 +++++++++++++++++++++++ 3 files changed, 168 insertions(+), 1 deletion(-) diff --git a/projects/fal/src/fal/cli/_utils.py b/projects/fal/src/fal/cli/_utils.py index 0f5812642..44d5b84eb 100644 --- a/projects/fal/src/fal/cli/_utils.py +++ b/projects/fal/src/fal/cli/_utils.py @@ -189,3 +189,86 @@ def _validate_requirements(requirements: Any) -> None: def _validate_str_list(field_name: str, value: Any) -> None: if not (isinstance(value, list) and all(isinstance(item, str) for item in value)): raise ValueError(f"{field_name} must be a list of strings.") + + +def warn_if_vulnerable_torch( + requirements: list[str] | list[list[str]], + console: Any, +) -> None: + """Warn if requirements include a torch version with known serialization vulnerabilities. + + PyTorch <= 2.5 is vulnerable to arbitrary code execution via malicious model files + due to unsafe pickle deserialization. See https://github.com/pytorch/pytorch/issues/31875 + """ + from packaging.requirements import Requirement + from packaging.utils import canonicalize_name + from rich.panel import Panel + from rich.text import Text + + if not console.is_terminal: + return + + # Flatten potentially layered requirements (list[str] or list[list[str]]) + flat_reqs: list[str] = [] + for item in requirements: + if isinstance(item, list): + flat_reqs.extend(item) + else: + flat_reqs.append(item) + + torch_req = None + for req_str in flat_reqs: + try: + req = Requirement(req_str) + except Exception: + continue + if canonicalize_name(req.name) == "torch": + torch_req = req + break + + if torch_req is None: + return + + _MITIGATION = ( + "In a future release, fal will automatically set [bold cyan]weights_only=True[/bold cyan]\n" + "on [bold cyan]torch.load()[/bold cyan] calls for torch versions prior to 2.6 as a mitigation,\n" + "but upgrading to [bold cyan]torch>=2.6[/bold cyan] is strongly recommended." + ) + + specifier = torch_req.specifier + if not specifier: + return + + # A version is safe if major > 2, or major == 2 and minor >= 6. + def _is_safe_version(version_str: str) -> bool: + parts = version_str.split(".") + try: + major, minor = int(parts[0]), int(parts[1]) + except (IndexError, ValueError): + return False + return major > 2 or (major == 2 and minor >= 6) + + # Safe only if there's an explicit constraint guaranteeing a safe minimum: + # ==2.6.x, >=2.6, >2.5.x, etc. Upper-bound operators (<, <=) don't + # guarantee safety so we ignore them here and warn. + guaranteeing_ops = {"==", "===", ">=", ">"} + if any( + c.operator in guaranteeing_ops and _is_safe_version(c.version) + for c in specifier + ): + return + detail = ( + f"[bold]torch{specifier}[/bold] can resolve to a vulnerable version.\n" + "Versions [bold]2.5 and earlier[/bold] have a known serialization vulnerability\n" + "that may allow malicious model files to execute arbitrary code.\n\n" + + _MITIGATION + ) + + panel = Panel( + Text.from_markup(detail), + title="[bold red]Security Warning[/bold red]", + border_style="red", + padding=(1, 2), + expand=False, + ) + console.print(panel) diff --git a/projects/fal/src/fal/cli/run.py b/projects/fal/src/fal/cli/run.py index fa9ec852a..a8c70fda4 100644 --- a/projects/fal/src/fal/cli/run.py +++ b/projects/fal/src/fal/cli/run.py @@ -2,7 +2,7 @@ from dataclasses import replace from pathlib import Path -from ._utils import AppData, get_app_data_from_toml, is_app_name +from ._utils import AppData, get_app_data_from_toml, is_app_name, warn_if_vulnerable_torch from .parser import FalClientParser, RefAction, add_env_argument @@ -61,6 +61,11 @@ def _run(args): limit_max_requests=args.limit_max_requests, ) + warn_if_vulnerable_torch( + loaded.function.options.environment.get("requirements", []), + args.console, + ) + isolated_function = loaded.function if args.machine_type is not None: isolated_function.options.host["machine_type"] = args.machine_type diff --git a/projects/fal/tests/unit/cli/test_run.py b/projects/fal/tests/unit/cli/test_run.py index 7757ffee5..1b74342e2 100644 --- a/projects/fal/tests/unit/cli/test_run.py +++ b/projects/fal/tests/unit/cli/test_run.py @@ -3,11 +3,90 @@ import pytest +from fal.cli._utils import warn_if_vulnerable_torch from fal.cli.main import parse_args from fal.cli.run import _run from fal.project import find_project_root +def make_console(is_terminal=True): + console = MagicMock() + console.is_terminal = is_terminal + return console + + +class TestWarnIfVulnerableTorch: + def test_no_torch_in_requirements(self): + console = make_console() + warn_if_vulnerable_torch(["numpy==1.26.4", "pillow"], console) + console.print.assert_not_called() + + def test_empty_requirements(self): + console = make_console() + warn_if_vulnerable_torch([], console) + console.print.assert_not_called() + + def test_torch_pinned_to_vulnerable_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.4.0"], console) + console.print.assert_called_once() + + def test_torch_pinned_to_latest_vulnerable_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.5.0"], console) + console.print.assert_called_once() + + def test_torch_pinned_to_vulnerable_patch_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.4.1"], console) + console.print.assert_called_once() + + def test_torch_pinned_to_safe_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.6.0"], console) + console.print.assert_not_called() + + def test_torch_with_safe_lower_bound(self): + console = make_console() + warn_if_vulnerable_torch(["torch>=2.6"], console) + console.print.assert_not_called() + + def test_torch_specifier_spanning_boundary(self): + console = make_console() + warn_if_vulnerable_torch(["torch>=2.4,<2.6"], console) + console.print.assert_called_once() + + def test_torch_with_vulnerable_upper_bound(self): + console = make_console() + warn_if_vulnerable_torch(["torch<=2.5"], console) + console.print.assert_called_once() + + def test_torch_with_exclusive_vulnerable_upper_bound(self): + console = make_console() + warn_if_vulnerable_torch(["torch<2.6"], console) + console.print.assert_called_once() + + def test_torch_without_version_pin(self): + console = make_console() + warn_if_vulnerable_torch(["torch"], console) + console.print.assert_not_called() + + def test_layered_requirements_with_vulnerable_torch(self): + console = make_console() + warn_if_vulnerable_torch([["torch==2.4.0"], ["flash-attn"]], console) + console.print.assert_called_once() + + def test_layered_requirements_with_safe_torch(self): + console = make_console() + warn_if_vulnerable_torch([["torch==2.6.0"], ["flash-attn"]], console) + console.print.assert_not_called() + + def test_non_terminal_console_suppresses_warning(self): + console = make_console(is_terminal=False) + warn_if_vulnerable_torch(["torch==2.4.0"], console) + console.print.assert_not_called() + + def test_run(): args = parse_args(["run", "/my/path.py::myfunc"]) assert args.func == _run From 51c7e608f1644794927d9249797558af12c03450 Mon Sep 17 00:00:00 2001 From: Tobias Wennergren Date: Tue, 7 Apr 2026 16:59:10 -0700 Subject: [PATCH 2/2] Linting --- projects/fal/src/fal/cli/_utils.py | 28 ++++++++++++++++------------ projects/fal/src/fal/cli/run.py | 7 ++++++- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/projects/fal/src/fal/cli/_utils.py b/projects/fal/src/fal/cli/_utils.py index 44d5b84eb..74f021af8 100644 --- a/projects/fal/src/fal/cli/_utils.py +++ b/projects/fal/src/fal/cli/_utils.py @@ -4,6 +4,11 @@ from dataclasses import dataclass, field from typing import Any, Optional +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name +from rich.panel import Panel +from rich.text import Text + from fal.api import Options from fal.project import find_project_root, find_pyproject_toml, parse_pyproject_toml from fal.sdk import AuthModeLiteral, DeploymentStrategyLiteral @@ -195,16 +200,13 @@ def warn_if_vulnerable_torch( requirements: list[str] | list[list[str]], console: Any, ) -> None: - """Warn if requirements include a torch version with known serialization vulnerabilities. + """Warn if requirements include a torch version with known serialization + vulnerabilities. - PyTorch <= 2.5 is vulnerable to arbitrary code execution via malicious model files - due to unsafe pickle deserialization. See https://github.com/pytorch/pytorch/issues/31875 + PyTorch <= 2.5 is vulnerable to arbitrary code execution via malicious model + files due to unsafe pickle deserialization. + See https://github.com/pytorch/pytorch/issues/31875 """ - from packaging.requirements import Requirement - from packaging.utils import canonicalize_name - from rich.panel import Panel - from rich.text import Text - if not console.is_terminal: return @@ -230,8 +232,10 @@ def warn_if_vulnerable_torch( return _MITIGATION = ( - "In a future release, fal will automatically set [bold cyan]weights_only=True[/bold cyan]\n" - "on [bold cyan]torch.load()[/bold cyan] calls for torch versions prior to 2.6 as a mitigation,\n" + "In a future release, fal will automatically set " + "[bold cyan]weights_only=True[/bold cyan]\n" + "on [bold cyan]torch.load()[/bold cyan] calls for torch versions prior to 2.6 " + "as a mitigation,\n" "but upgrading to [bold cyan]torch>=2.6[/bold cyan] is strongly recommended." ) @@ -258,8 +262,8 @@ def _is_safe_version(version_str: str) -> bool: ): return detail = ( - f"[bold]torch{specifier}[/bold] can resolve to a vulnerable version.\n" - "Versions [bold]2.5 and earlier[/bold] have a known serialization vulnerability\n" + "Versions [bold]2.5 and earlier[/bold] of pytorch have a known " + "serialization vulnerability\n" "that may allow malicious model files to execute arbitrary code.\n\n" + _MITIGATION ) diff --git a/projects/fal/src/fal/cli/run.py b/projects/fal/src/fal/cli/run.py index a8c70fda4..cabe16d51 100644 --- a/projects/fal/src/fal/cli/run.py +++ b/projects/fal/src/fal/cli/run.py @@ -2,7 +2,12 @@ from dataclasses import replace from pathlib import Path -from ._utils import AppData, get_app_data_from_toml, is_app_name, warn_if_vulnerable_torch +from ._utils import ( + AppData, + get_app_data_from_toml, + is_app_name, + warn_if_vulnerable_torch, +) from .parser import FalClientParser, RefAction, add_env_argument