diff --git a/projects/fal/src/fal/cli/_utils.py b/projects/fal/src/fal/cli/_utils.py index 0f5812642..74f021af8 100644 --- a/projects/fal/src/fal/cli/_utils.py +++ b/projects/fal/src/fal/cli/_utils.py @@ -4,6 +4,11 @@ from dataclasses import dataclass, field from typing import Any, Optional +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name +from rich.panel import Panel +from rich.text import Text + from fal.api import Options from fal.project import find_project_root, find_pyproject_toml, parse_pyproject_toml from fal.sdk import AuthModeLiteral, DeploymentStrategyLiteral @@ -189,3 +194,85 @@ def _validate_requirements(requirements: Any) -> None: def _validate_str_list(field_name: str, value: Any) -> None: if not (isinstance(value, list) and all(isinstance(item, str) for item in value)): raise ValueError(f"{field_name} must be a list of strings.") + + +def warn_if_vulnerable_torch( + requirements: list[str] | list[list[str]], + console: Any, +) -> None: + """Warn if requirements include a torch version with known serialization + vulnerabilities. + + PyTorch <= 2.5 is vulnerable to arbitrary code execution via malicious model + files due to unsafe pickle deserialization. + See https://github.com/pytorch/pytorch/issues/31875 + """ + if not console.is_terminal: + return + + # Flatten potentially layered requirements (list[str] or list[list[str]]) + flat_reqs: list[str] = [] + for item in requirements: + if isinstance(item, list): + flat_reqs.extend(item) + else: + flat_reqs.append(item) + + torch_req = None + for req_str in flat_reqs: + try: + req = Requirement(req_str) + except Exception: + continue + if canonicalize_name(req.name) == "torch": + torch_req = req + break + + if torch_req is None: + return + + _MITIGATION = ( + "In a future release, fal will automatically set " + "[bold cyan]weights_only=True[/bold cyan]\n" + "on [bold cyan]torch.load()[/bold cyan] calls for torch versions prior to 2.6 " + "as a mitigation,\n" + "but upgrading to [bold cyan]torch>=2.6[/bold cyan] is strongly recommended." + ) + + specifier = torch_req.specifier + if not specifier: + return + + # A version is safe if major > 2, or major == 2 and minor >= 6. + def _is_safe_version(version_str: str) -> bool: + parts = version_str.split(".") + try: + major, minor = int(parts[0]), int(parts[1]) + except (IndexError, ValueError): + return False + return major > 2 or (major == 2 and minor >= 6) + + # Safe only if there's an explicit constraint guaranteeing a safe minimum: + # ==2.6.x, >=2.6, >2.5.x, etc. Upper-bound operators (<, <=) don't + # guarantee safety so we ignore them here and warn. + guaranteeing_ops = {"==", "===", ">=", ">"} + if any( + c.operator in guaranteeing_ops and _is_safe_version(c.version) + for c in specifier + ): + return + detail = ( + "Versions [bold]2.5 and earlier[/bold] of pytorch have a known " + "serialization vulnerability\n" + "that may allow malicious model files to execute arbitrary code.\n\n" + + _MITIGATION + ) + + panel = Panel( + Text.from_markup(detail), + title="[bold red]Security Warning[/bold red]", + border_style="red", + padding=(1, 2), + expand=False, + ) + console.print(panel) diff --git a/projects/fal/src/fal/cli/run.py b/projects/fal/src/fal/cli/run.py index fa9ec852a..cabe16d51 100644 --- a/projects/fal/src/fal/cli/run.py +++ b/projects/fal/src/fal/cli/run.py @@ -2,7 +2,12 @@ from dataclasses import replace from pathlib import Path -from ._utils import AppData, get_app_data_from_toml, is_app_name +from ._utils import ( + AppData, + get_app_data_from_toml, + is_app_name, + warn_if_vulnerable_torch, +) from .parser import FalClientParser, RefAction, add_env_argument @@ -61,6 +66,11 @@ def _run(args): limit_max_requests=args.limit_max_requests, ) + warn_if_vulnerable_torch( + loaded.function.options.environment.get("requirements", []), + args.console, + ) + isolated_function = loaded.function if args.machine_type is not None: isolated_function.options.host["machine_type"] = args.machine_type diff --git a/projects/fal/tests/unit/cli/test_run.py b/projects/fal/tests/unit/cli/test_run.py index 7757ffee5..1b74342e2 100644 --- a/projects/fal/tests/unit/cli/test_run.py +++ b/projects/fal/tests/unit/cli/test_run.py @@ -3,11 +3,90 @@ import pytest +from fal.cli._utils import warn_if_vulnerable_torch from fal.cli.main import parse_args from fal.cli.run import _run from fal.project import find_project_root +def make_console(is_terminal=True): + console = MagicMock() + console.is_terminal = is_terminal + return console + + +class TestWarnIfVulnerableTorch: + def test_no_torch_in_requirements(self): + console = make_console() + warn_if_vulnerable_torch(["numpy==1.26.4", "pillow"], console) + console.print.assert_not_called() + + def test_empty_requirements(self): + console = make_console() + warn_if_vulnerable_torch([], console) + console.print.assert_not_called() + + def test_torch_pinned_to_vulnerable_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.4.0"], console) + console.print.assert_called_once() + + def test_torch_pinned_to_latest_vulnerable_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.5.0"], console) + console.print.assert_called_once() + + def test_torch_pinned_to_vulnerable_patch_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.4.1"], console) + console.print.assert_called_once() + + def test_torch_pinned_to_safe_version(self): + console = make_console() + warn_if_vulnerable_torch(["torch==2.6.0"], console) + console.print.assert_not_called() + + def test_torch_with_safe_lower_bound(self): + console = make_console() + warn_if_vulnerable_torch(["torch>=2.6"], console) + console.print.assert_not_called() + + def test_torch_specifier_spanning_boundary(self): + console = make_console() + warn_if_vulnerable_torch(["torch>=2.4,<2.6"], console) + console.print.assert_called_once() + + def test_torch_with_vulnerable_upper_bound(self): + console = make_console() + warn_if_vulnerable_torch(["torch<=2.5"], console) + console.print.assert_called_once() + + def test_torch_with_exclusive_vulnerable_upper_bound(self): + console = make_console() + warn_if_vulnerable_torch(["torch<2.6"], console) + console.print.assert_called_once() + + def test_torch_without_version_pin(self): + console = make_console() + warn_if_vulnerable_torch(["torch"], console) + console.print.assert_not_called() + + def test_layered_requirements_with_vulnerable_torch(self): + console = make_console() + warn_if_vulnerable_torch([["torch==2.4.0"], ["flash-attn"]], console) + console.print.assert_called_once() + + def test_layered_requirements_with_safe_torch(self): + console = make_console() + warn_if_vulnerable_torch([["torch==2.6.0"], ["flash-attn"]], console) + console.print.assert_not_called() + + def test_non_terminal_console_suppresses_warning(self): + console = make_console(is_terminal=False) + warn_if_vulnerable_torch(["torch==2.4.0"], console) + console.print.assert_not_called() + + def test_run(): args = parse_args(["run", "/my/path.py::myfunc"]) assert args.func == _run