From 1ba1086f28dfc33e61df0e7a4e990db3f804272a Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sat, 20 Jun 2026 04:11:57 +0200 Subject: [PATCH] fix(lora): correctly detect flattened layer keys in LayerPatcher The patcher decided whether a patch's layer keys were flattened (legacy underscore-joined) or real dotted module paths by inspecting only the first key. For FLUX.2 Klein diffusers LoRAs whose first converted layer is a dotless top-level module (e.g. `context_embedder`), the whole patch was misclassified as flattened, causing `assert "." not in layer_key` to fail on subsequent dotted keys and crashing LoRA application. Inspect all keys instead: a flattened key never contains a dot, so the patch is flattened only if no key contains one. Add a regression test covering the mixed dotless/dotted key ordering. --- invokeai/backend/patches/layer_patcher.py | 9 +++- tests/backend/patches/test_layer_patcher.py | 60 +++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py index fbfcd04de20..ce0573565b9 100644 --- a/invokeai/backend/patches/layer_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -86,8 +86,13 @@ def apply_smart_model_patch( # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly # without searching, but some legacy code still uses flattened keys. - first_key = next(iter(patch.layers.keys())) - layer_keys_are_flattened = "." not in first_key + # + # We inspect *all* keys rather than just the first one: a flattened key never contains a dot, so the patch is + # only flattened if no key contains a dot. Checking only the first key misclassifies non-flattened patches whose + # first layer happens to target a top-level module with a single-token name (e.g. a Flux2 diffusers LoRA whose + # first converted layer is `lora_transformer-context_embedder`), causing a spurious assertion failure on + # subsequent dotted keys. + layer_keys_are_flattened = not any("." in key for key in patch.layers.keys()) prefix_len = len(prefix) diff --git a/tests/backend/patches/test_layer_patcher.py b/tests/backend/patches/test_layer_patcher.py index 76712b92197..35b81f4f573 100644 --- a/tests/backend/patches/test_layer_patcher.py +++ b/tests/backend/patches/test_layer_patcher.py @@ -31,6 +31,66 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear_layer_2(self.linear_layer_1(x)) +class DummyModuleWithNestedLayer(torch.nn.Module): + """A model with both a top-level layer (dotless key) and a nested submodule (dotted key).""" + + def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype): + super().__init__() + self.top_layer = torch.nn.Linear(in_features, in_features, device=device, dtype=dtype) + self.block = DummyModuleWithOneLayer(in_features, out_features, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(self.top_layer(x)) + + +@torch.no_grad() +def test_apply_smart_model_patches_mixed_dotted_and_dotless_keys(): + """Regression test: a patch whose first layer key is a dotless top-level module (e.g. a Flux2 diffusers + LoRA whose first converted layer is `context_embedder`) followed by dotted nested keys must not be + misclassified as a flattened patch. Inspecting only the first key would set layer_keys_are_flattened=True + and trip `assert "." not in layer_key` in _get_submodule on the subsequent dotted keys. + """ + dtype = torch.float16 + in_features = 4 + out_features = 8 + lora_rank = 2 + model = DummyModuleWithNestedLayer(in_features, out_features, device="cpu", dtype=dtype) + apply_custom_layers_to_model(model) + + # Insertion order matters: the dotless top-level key comes first, the dotted nested key second. + lora_layers = { + "top_layer": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, in_features), device="cpu", dtype=dtype), + "lora_up.weight": torch.ones((in_features, lora_rank), device="cpu", dtype=dtype), + }, + ), + "block.linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, in_features), device="cpu", dtype=dtype), + "lora_up.weight": torch.ones((out_features, lora_rank), device="cpu", dtype=dtype), + }, + ), + } + lora = ModelPatchRaw(lora_layers) + assert next(iter(lora.layers.keys())) == "top_layer" # the dotless key is first + + input = torch.randn(1, in_features, device="cpu", dtype=dtype) + output_before_patch = model(input) + + # This previously raised AssertionError on the "block.linear_layer_1" key. + with LayerPatcher.apply_smart_model_patches( + model=model, patches=[(lora, 0.5)], prefix="", dtype=dtype, force_direct_patching=True + ): + output_during_patch = model(input) + + output_after_patch = model(input) + + # Both layers were actually patched (output changed), and unpatching restored the original output. + assert not torch.allclose(output_before_patch, output_during_patch) + assert torch.allclose(output_before_patch, output_after_patch) + + @pytest.mark.parametrize( "device", [