Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -3628,6 +3628,54 @@ def body_fn(iter, x):
sqnr = compute_error(output_ref, output)
self.assertGreater(sqnr, 35, f"SQNR too low: {sqnr} dB")

def test_nested_while_loop_convert_preserves_subgraphs(self):
    """Nested ``while_loop`` subgraphs must survive ``convert_pt2e``.

    Regression test for https://github.com/pytorch/ao/issues/4455.

    The inner loop's cond/body subgraphs are referenced only from within
    the outer body subgraph, never from the top-level graph. The cleanup
    step in convert() (``delete_all_unused_submodules``) looks only at the
    top-level graph, so it used to drop those inner subgraphs; the
    resulting dangling references made graph linting in DuplicateDQPass
    fail.
    """
    from torch._higher_order_ops.while_loop import while_loop

    class NestedWhileLoopModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            # Outer loop: three iterations, each running the inner loop.
            def outer_cond(i, x):
                return i < 3

            # Inner loop: two iterations of the quantizable linear layer.
            def inner_cond(j, y):
                return j < 2

            def inner_body(j, y):
                return j + 1, self.linear(y)

            def outer_body(i, x):
                _, y = while_loop(inner_cond, inner_body, (torch.tensor(0), x))
                return i + 1, y

            _, result = while_loop(
                outer_cond, outer_body, (torch.tensor(0), x)
            )
            return result

    model = NestedWhileLoopModel().eval()
    example_inputs = (torch.randn(5),)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())

    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    # Calibrate the observers with one forward pass.
    with torch.no_grad():
        prepared(*example_inputs)

    # Previously raised in DuplicateDQPass ("references nonexistent attribute").
    converted = convert_pt2e(prepared)
    # The converted model must still be runnable end to end.
    with torch.no_grad():
        converted(*example_inputs)


@skipIfNoQNNPACK
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
Expand Down
50 changes: 49 additions & 1 deletion torchao/quantization/pt2e/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,54 @@ def convert_weighted_module(
setattr(modules[parent_name], name, ref_qmodule)


def _referenced_submodule_paths(gm: GraphModule) -> set[str]:
    """Return the fully-qualified paths of every submodule/attribute that is
    reachable through ``get_attr`` / ``call_module`` references, starting at
    ``gm``'s own graph and following referenced children recursively.

    ``GraphModule.delete_all_unused_submodules`` only inspects the immediate
    graph, so a submodule referenced solely from inside a nested control-flow
    subgraph (the ``while_loop`` / ``scan`` / ``cond`` body subgraphs that
    ``torch.export`` stores as child modules) is wrongly treated as unused.

    Only children that are themselves referenced are descended into:

    * an unreferenced (dead) child subgraph contributes nothing, so the
      modules it points at are not reported as live;
    * a referenced plain ``torch.nn.Module`` container is walked through, so
      a ``GraphModule`` stored underneath it (target ``"box.sub"``) still has
      its graph scanned.

    Args:
        gm: Root graph module whose references should be collected.

    Returns:
        The set of dotted paths, relative to ``gm``, of each referenced
        target together with all of its intermediate parent paths (a target
        ``"a.b.c"`` yields ``"a"``, ``"a.b"`` and ``"a.b.c"``).
    """
    used: set[str] = set()

    def visit(module: torch.nn.Module, prefix: str) -> None:
        # Only GraphModules own a graph whose nodes can reference children;
        # targets in that graph are relative to ``module`` itself.
        if isinstance(module, GraphModule):
            for node in module.graph.nodes:
                if node.op not in ("call_module", "get_attr"):
                    continue
                if not isinstance(node.target, str):
                    continue
                acc = ""
                for part in node.target.split("."):
                    acc = f"{acc}.{part}" if acc else part
                    used.add(f"{prefix}{acc}")
        # Descend only into children something actually points at. The
        # references of this level were recorded above, so membership in
        # ``used`` is already final for direct children here.
        for name, child in module.named_children():
            child_path = f"{prefix}{name}"
            if child_path in used:
                visit(child, f"{child_path}.")

    visit(gm, "")
    return used


def _delete_all_unused_submodules(model: GraphModule) -> None:
    """Drop unused submodules from ``model`` without breaking nested subgraphs.

    ``model.delete_all_unused_submodules()`` decides what is "used" from the
    top-level graph alone, so it may remove submodules that nested
    control-flow subgraphs still reference; the dangling references then fail
    graph linting later on (e.g. in ``DuplicateDQPass``). To compensate, every
    referenced submodule is remembered before the stock cleanup runs and is
    re-attached afterwards if the cleanup removed it.

    Args:
        model: Graph module to clean up in place.
    """
    keep = _referenced_submodule_paths(model)

    # Hold on to the module objects now; once detached they can no longer be
    # looked up by path.
    saved = {}
    for path, submodule in model.named_modules():
        if path in keep:
            saved[path] = submodule

    model.delete_all_unused_submodules()

    # Shallow paths first, so a parent is back in place before its children
    # are re-added beneath it.
    by_depth = sorted(saved.items(), key=lambda item: item[0].count("."))
    for path, submodule in by_depth:
        try:
            model.get_submodule(path)
        except AttributeError:
            # The stock cleanup removed something that is still referenced.
            model.add_submodule(path, submodule)


def convert(
model: GraphModule,
is_reference: bool = False,
Expand Down Expand Up @@ -1190,7 +1238,7 @@ def convert(
# removes qconfig and activation_post_process modules
if _remove_qconfig_flag:
_remove_qconfig(model)
model.delete_all_unused_submodules()
_delete_all_unused_submodules(model)
model.meta.pop("_observed_graph_module_attrs", None)
return model

Expand Down