Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions projects/fal/src/fal/cli/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from dataclasses import dataclass, field
from typing import Any, Optional

from packaging.requirements import Requirement
from packaging.utils import canonicalize_name
from rich.panel import Panel
from rich.text import Text

from fal.api import Options
from fal.project import find_project_root, find_pyproject_toml, parse_pyproject_toml
from fal.sdk import AuthModeLiteral, DeploymentStrategyLiteral
Expand Down Expand Up @@ -189,3 +194,85 @@ def _validate_requirements(requirements: Any) -> None:
def _validate_str_list(field_name: str, value: Any) -> None:
if not (isinstance(value, list) and all(isinstance(item, str) for item in value)):
raise ValueError(f"{field_name} must be a list of strings.")


def warn_if_vulnerable_torch(
requirements: list[str] | list[list[str]],
console: Any,
) -> None:
"""Warn if requirements include a torch version with known serialization
vulnerabilities.

PyTorch <= 2.5 is vulnerable to arbitrary code execution via malicious model
files due to unsafe pickle deserialization.
See https://github.com/pytorch/pytorch/issues/31875
"""
if not console.is_terminal:
return

# Flatten potentially layered requirements (list[str] or list[list[str]])
flat_reqs: list[str] = []
for item in requirements:
if isinstance(item, list):
flat_reqs.extend(item)
else:
flat_reqs.append(item)

torch_req = None
for req_str in flat_reqs:
try:
req = Requirement(req_str)
except Exception:
continue
if canonicalize_name(req.name) == "torch":
torch_req = req
break

if torch_req is None:
return

_MITIGATION = (
"In a future release, fal will automatically set "
"[bold cyan]weights_only=True[/bold cyan]\n"
"on [bold cyan]torch.load()[/bold cyan] calls for torch versions prior to 2.6 "
"as a mitigation,\n"
"but upgrading to [bold cyan]torch>=2.6[/bold cyan] is strongly recommended."
)

specifier = torch_req.specifier
if not specifier:
return

# A version is safe if major > 2, or major == 2 and minor >= 6.
def _is_safe_version(version_str: str) -> bool:
parts = version_str.split(".")
try:
major, minor = int(parts[0]), int(parts[1])
except (IndexError, ValueError):
return False
return major > 2 or (major == 2 and minor >= 6)
Comment on lines +247 to +253

@badayvedat badayvedat Apr 13, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fails on something like this, while it should not

from fal.cli._utils import warn_if_vulnerable_torch

from rich.console import Console
from rich.theme import Theme

console = Console(theme=Theme(), soft_wrap=True)

warn_if_vulnerable_torch(["torch>3"], console=console)


# Safe only if there's an explicit constraint guaranteeing a safe minimum:
# ==2.6.x, >=2.6, >2.5.x, etc. Upper-bound operators (<, <=) don't
# guarantee safety so we ignore them here and warn.
guaranteeing_ops = {"==", "===", ">=", ">"}
if any(
c.operator in guaranteeing_ops and _is_safe_version(c.version)
for c in specifier
):
return
detail = (
"Versions [bold]2.5 and earlier[/bold] of pytorch have a known "
"serialization vulnerability\n"
"that may allow malicious model files to execute arbitrary code.\n\n"
+ _MITIGATION
)

panel = Panel(
Text.from_markup(detail),
title="[bold red]Security Warning[/bold red]",
border_style="red",
padding=(1, 2),
expand=False,
)
console.print(panel)
12 changes: 11 additions & 1 deletion projects/fal/src/fal/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from dataclasses import replace
from pathlib import Path

from ._utils import AppData, get_app_data_from_toml, is_app_name
from ._utils import (
AppData,
get_app_data_from_toml,
is_app_name,
warn_if_vulnerable_torch,
)
from .parser import FalClientParser, RefAction, add_env_argument


Expand Down Expand Up @@ -61,6 +66,11 @@ def _run(args):
limit_max_requests=args.limit_max_requests,
)

warn_if_vulnerable_torch(
loaded.function.options.environment.get("requirements", []),
args.console,
)

isolated_function = loaded.function
if args.machine_type is not None:
isolated_function.options.host["machine_type"] = args.machine_type
Expand Down
79 changes: 79 additions & 0 deletions projects/fal/tests/unit/cli/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,90 @@

import pytest

from fal.cli._utils import warn_if_vulnerable_torch
from fal.cli.main import parse_args
from fal.cli.run import _run
from fal.project import find_project_root


def make_console(is_terminal=True):
console = MagicMock()
console.is_terminal = is_terminal
return console


class TestWarnIfVulnerableTorch:
def test_no_torch_in_requirements(self):
console = make_console()
warn_if_vulnerable_torch(["numpy==1.26.4", "pillow"], console)
console.print.assert_not_called()

def test_empty_requirements(self):
console = make_console()
warn_if_vulnerable_torch([], console)
console.print.assert_not_called()

def test_torch_pinned_to_vulnerable_version(self):
console = make_console()
warn_if_vulnerable_torch(["torch==2.4.0"], console)
console.print.assert_called_once()

def test_torch_pinned_to_latest_vulnerable_version(self):
console = make_console()
warn_if_vulnerable_torch(["torch==2.5.0"], console)
console.print.assert_called_once()

def test_torch_pinned_to_vulnerable_patch_version(self):
console = make_console()
warn_if_vulnerable_torch(["torch==2.4.1"], console)
console.print.assert_called_once()

def test_torch_pinned_to_safe_version(self):
console = make_console()
warn_if_vulnerable_torch(["torch==2.6.0"], console)
console.print.assert_not_called()

def test_torch_with_safe_lower_bound(self):
console = make_console()
warn_if_vulnerable_torch(["torch>=2.6"], console)
console.print.assert_not_called()

def test_torch_specifier_spanning_boundary(self):
console = make_console()
warn_if_vulnerable_torch(["torch>=2.4,<2.6"], console)
console.print.assert_called_once()

def test_torch_with_vulnerable_upper_bound(self):
console = make_console()
warn_if_vulnerable_torch(["torch<=2.5"], console)
console.print.assert_called_once()

def test_torch_with_exclusive_vulnerable_upper_bound(self):
console = make_console()
warn_if_vulnerable_torch(["torch<2.6"], console)
console.print.assert_called_once()

def test_torch_without_version_pin(self):
console = make_console()
warn_if_vulnerable_torch(["torch"], console)
console.print.assert_not_called()

def test_layered_requirements_with_vulnerable_torch(self):
console = make_console()
warn_if_vulnerable_torch([["torch==2.4.0"], ["flash-attn"]], console)
console.print.assert_called_once()

def test_layered_requirements_with_safe_torch(self):
console = make_console()
warn_if_vulnerable_torch([["torch==2.6.0"], ["flash-attn"]], console)
console.print.assert_not_called()

def test_non_terminal_console_suppresses_warning(self):
console = make_console(is_terminal=False)
warn_if_vulnerable_torch(["torch==2.4.0"], console)
console.print.assert_not_called()


def test_run():
args = parse_args(["run", "/my/path.py::myfunc"])
assert args.func == _run
Expand Down
Loading