Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions invokeai/backend/patches/layer_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,13 @@ def apply_smart_model_patch(
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
first_key = next(iter(patch.layers.keys()))
layer_keys_are_flattened = "." not in first_key
#
# We inspect *all* keys rather than just the first one: a flattened key never contains a dot, so the patch is
# only flattened if no key contains a dot. Checking only the first key misclassifies non-flattened patches whose
# first layer happens to target a top-level module with a single-token name (e.g. a Flux2 diffusers LoRA whose
# first converted layer is `lora_transformer-context_embedder`), causing a spurious assertion failure on
# subsequent dotted keys.
layer_keys_are_flattened = not any("." in key for key in patch.layers.keys())

prefix_len = len(prefix)

Expand Down
60 changes: 60 additions & 0 deletions tests/backend/patches/test_layer_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,66 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_layer_2(self.linear_layer_1(x))


class DummyModuleWithNestedLayer(torch.nn.Module):
"""A model with both a top-level layer (dotless key) and a nested submodule (dotted key)."""

def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
super().__init__()
self.top_layer = torch.nn.Linear(in_features, in_features, device=device, dtype=dtype)
self.block = DummyModuleWithOneLayer(in_features, out_features, device=device, dtype=dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.block(self.top_layer(x))


@torch.no_grad()
def test_apply_smart_model_patches_mixed_dotted_and_dotless_keys():
"""Regression test: a patch whose first layer key is a dotless top-level module (e.g. a Flux2 diffusers
LoRA whose first converted layer is `context_embedder`) followed by dotted nested keys must not be
misclassified as a flattened patch. Inspecting only the first key would set layer_keys_are_flattened=True
and trip `assert "." not in layer_key` in _get_submodule on the subsequent dotted keys.
"""
dtype = torch.float16
in_features = 4
out_features = 8
lora_rank = 2
model = DummyModuleWithNestedLayer(in_features, out_features, device="cpu", dtype=dtype)
apply_custom_layers_to_model(model)

# Insertion order matters: the dotless top-level key comes first, the dotted nested key second.
lora_layers = {
"top_layer": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, in_features), device="cpu", dtype=dtype),
"lora_up.weight": torch.ones((in_features, lora_rank), device="cpu", dtype=dtype),
},
),
"block.linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, in_features), device="cpu", dtype=dtype),
"lora_up.weight": torch.ones((out_features, lora_rank), device="cpu", dtype=dtype),
},
),
}
lora = ModelPatchRaw(lora_layers)
assert next(iter(lora.layers.keys())) == "top_layer" # the dotless key is first

input = torch.randn(1, in_features, device="cpu", dtype=dtype)
output_before_patch = model(input)

# This previously raised AssertionError on the "block.linear_layer_1" key.
with LayerPatcher.apply_smart_model_patches(
model=model, patches=[(lora, 0.5)], prefix="", dtype=dtype, force_direct_patching=True
):
output_during_patch = model(input)

output_after_patch = model(input)

# Both layers were actually patched (output changed), and unpatching restored the original output.
assert not torch.allclose(output_before_patch, output_during_patch)
assert torch.allclose(output_before_patch, output_after_patch)


@pytest.mark.parametrize(
"device",
[
Expand Down
Loading