diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index b226203fb7..1848a84875 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -3628,6 +3628,54 @@ def body_fn(iter, x):
         sqnr = compute_error(output_ref, output)
         self.assertGreater(sqnr, 35, f"SQNR too low: {sqnr} dB")
 
+    def test_nested_while_loop_convert_preserves_subgraphs(self):
+        """Regression test for https://github.com/pytorch/ao/issues/4455.
+
+        When a while_loop body contains a nested while_loop, the inner cond/body
+        subgraphs are referenced only from inside the outer body subgraph. The
+        cleanup in convert() (``delete_all_unused_submodules``) only inspects the
+        top-level graph, so it used to delete those inner subgraphs, leaving
+        dangling references that crashed graph linting in DuplicateDQPass.
+        """
+        from torch._higher_order_ops.while_loop import while_loop
+
+        class NestedWhileLoopModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                def inner_cond(j, y):
+                    return j < 2
+
+                def inner_body(j, y):
+                    return j + 1, self.linear(y)
+
+                def outer_cond(i, x):
+                    return i < 3
+
+                def outer_body(i, x):
+                    _, y = while_loop(inner_cond, inner_body, (torch.tensor(0), x))
+                    return i + 1, y
+
+                _, result = while_loop(
+                    outer_cond, outer_body, (torch.tensor(0), x)
+                )
+                return result
+
+        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
+        m = NestedWhileLoopModel().eval()
+        example_inputs = (torch.randn(5),)
+
+        m_export = torch.export.export(m, example_inputs).module()
+        m_prepared = prepare_pt2e(m_export, quantizer)
+        with torch.no_grad():
+            m_prepared(*example_inputs)
+        # Previously raised in DuplicateDQPass ("references nonexistent attribute").
+        m_converted = convert_pt2e(m_prepared)
+        with torch.no_grad():
+            m_converted(*example_inputs)
+
 
 @skipIfNoQNNPACK
 class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py
index 16e64b5b08..ec1f8869bf 100644
--- a/torchao/quantization/pt2e/convert.py
+++ b/torchao/quantization/pt2e/convert.py
@@ -930,6 +930,66 @@ def convert_weighted_module(
         setattr(modules[parent_name], name, ref_qmodule)
 
 
+def _referenced_submodule_paths(gm: GraphModule) -> set[str]:
+    """Return every fully-qualified path that a *live* graph references.
+
+    A path is recorded for each ``get_attr`` / ``call_module`` target in ``gm``'s
+    own graph and in the graph of every nested ``GraphModule`` child that is
+    itself referenced by an already-scanned graph. Each dotted prefix of a target
+    is recorded too, so intermediate containers count as referenced.
+
+    ``GraphModule.delete_all_unused_submodules`` only inspects the immediate
+    graph, so a submodule referenced solely from inside a nested control-flow
+    subgraph (the ``while_loop`` / ``scan`` / ``cond`` body subgraphs that
+    ``torch.export`` stores as child modules) is wrongly treated as unused.
+
+    Children that no scanned graph references are deliberately not descended
+    into: they are dead, so anything only they point at is dead as well and
+    must not be kept alive (or resurrected) on their behalf.
+    """
+    used: set[str] = set()
+
+    def visit(module: torch.nn.Module, prefix: str) -> None:
+        if not isinstance(module, GraphModule):
+            return
+        for node in module.graph.nodes:
+            if node.op in ("call_module", "get_attr") and isinstance(node.target, str):
+                acc = ""
+                for part in node.target.split("."):
+                    acc = f"{acc}.{part}" if acc else part
+                    used.add(f"{prefix}{acc}")
+        for name, child in module.named_children():
+            path = f"{prefix}{name}"
+            # Ancestors are scanned before descendants, so by now ``used`` holds
+            # every reference that could name this child.
+            if path in used:
+                visit(child, f"{path}.")
+
+    visit(gm, "")
+    return used
+
+
+def _delete_all_unused_submodules(model: GraphModule) -> None:
+    """Recursion-aware replacement for ``model.delete_all_unused_submodules()``.
+
+    The stock method only looks at the top-level graph, so it can delete
+    submodules that are still referenced from nested control-flow subgraphs,
+    leaving dangling references that later fail graph linting (e.g. in
+    ``DuplicateDQPass``). We snapshot those references and restore any submodule
+    that the stock cleanup incorrectly removed.
+    """
+    referenced = _referenced_submodule_paths(model)
+    snapshot = {name: mod for name, mod in model.named_modules() if name in referenced}
+    model.delete_all_unused_submodules()
+    # Restore parents before children so intermediate paths exist on re-add;
+    # otherwise ``add_submodule`` would install empty placeholder modules.
+    for name in sorted(snapshot, key=lambda n: n.count(".")):
+        try:
+            model.get_submodule(name)
+        except AttributeError:
+            model.add_submodule(name, snapshot[name])
+
+
 def convert(
     model: GraphModule,
     is_reference: bool = False,
@@ -1190,7 +1250,7 @@ def convert(
 
     # removes qconfig and activation_post_process modules
     if _remove_qconfig_flag:
         _remove_qconfig(model)
-    model.delete_all_unused_submodules()
+    _delete_all_unused_submodules(model)
     model.meta.pop("_observed_graph_module_attrs", None)
     return model