diff --git a/src/bentoml/_internal/frameworks/pytorch.py b/src/bentoml/_internal/frameworks/pytorch.py index 6f442edcc38..e167c8d4ad9 100644 --- a/src/bentoml/_internal/frameworks/pytorch.py +++ b/src/bentoml/_internal/frameworks/pytorch.py @@ -77,6 +77,13 @@ def load_model( ) weight_file = bentoml_model.path_of(MODEL_FILENAME) + # `save_model` serializes the whole model object (not just a state dict) via + # `torch.load`'s pickle path, so it must be loaded with `weights_only=False`. + # PyTorch >= 2.6 flipped the default to `weights_only=True`, which cannot + # unpickle arbitrary classes and breaks loading. The model store is a trusted, + # BentoML-produced artifact, so default to `weights_only=False` while still + # allowing the caller to override it through `torch_load_args`. + torch_load_args.setdefault("weights_only", False) with Path(weight_file).open("rb") as file: model: "torch.nn.Module" = torch.load( file, map_location=device_id, **torch_load_args diff --git a/tests/integration/frameworks/test_pytorch_unit.py b/tests/integration/frameworks/test_pytorch_unit.py index 26284dc03c6..1ff4c5f88d2 100644 --- a/tests/integration/frameworks/test_pytorch_unit.py +++ b/tests/integration/frameworks/test_pytorch_unit.py @@ -3,10 +3,22 @@ import pytest import torch +import bentoml +from bentoml._internal.configuration.containers import BentoMLContainer from bentoml._internal.frameworks.pytorch import PyTorchTensorContainer +from bentoml._internal.models import ModelStore from bentoml._internal.runner.container import AutoContainer +class _Net(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(4, 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(x) + + @pytest.mark.parametrize("batch_axis", [0, 1]) def test_pytorch_container(batch_axis: int): one_batch = torch.arange(6).reshape(2, 3) @@ -39,3 +51,22 @@ def test_pytorch_container(batch_axis: int): AutoContainer.from_payload(AutoContainer.to_payload(one_batch, batch_dim=0)) == one_batch ).all() + + +def test_load_model_defaults_to_weights_only_false(tmp_path): + # Regression test for #5365: PyTorch >= 2.6 defaults `torch.load` to + # `weights_only=True`, which cannot unpickle the whole-model artifact that + # `save_model` writes via cloudpickle. `load_model` must default to + # `weights_only=False` so a trusted, BentoML-produced model loads correctly, + # while still honoring an explicit override passed by the caller. + BentoMLContainer.model_store.set(ModelStore(str(tmp_path))) + try: + saved = bentoml.pytorch.save_model("weights_only_model", _Net()) + + loaded = bentoml.pytorch.load_model(saved) + assert isinstance(loaded, torch.nn.Module) + + with pytest.raises(Exception): + bentoml.pytorch.load_model(saved, weights_only=True) + finally: + BentoMLContainer.model_store.reset()