Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions examples/python/7.1_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,18 +372,25 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp(
is_debug=False,
shape=(512, 1024, 8192), # 4*T0, 4*T1, 8192
block=(128, 256, 256),
wave_shape=(1, 4),
eliminate_epilogue=True,
no_unroll=False,
):
"""Preshuffle-B MXFP4 GEMM using C++ WaveASM backend."""
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
shape, block, wave_shape=wave_shape
)
options.backend = "asm"
options.use_buffer_ops = True
options.wave_runtime = True
options.use_wave_asm_backend = True
options.wave_runtime = True
options.dump_intermediates = "build/intermediates"
options.eliminate_epilogue = eliminate_epilogue
schedule = get_mxfp4_asymmetric_schedule(
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
eliminate_epilogue=eliminate_epilogue,
is_bscale_shuffled=True,
unroll_factor=1 if no_unroll else 2,
unroll_kernel=not no_unroll,
)
options.print_ir_after = "all" if is_debug else []
options = set_default_run_config(options)
Expand Down Expand Up @@ -444,5 +451,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm(
args.shape,
args.block,
args.eliminate_epilogue,
args.wave_shape,
no_unroll=args.no_unroll,
)
exit(0 if success else 1)
19 changes: 19 additions & 0 deletions examples/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ def parse_args():
default=None,
help="Enable epilogue elimination (true/false)",
)
parser.add_argument(
"--wave_shape",
type=str,
default=None,
help="Wave shape, e.g. 2,2",
)
parser.add_argument(
"--no-unroll",
action="store_true",
help="Use nounroll (unroll_factor=1) schedule variant",
)

args = parser.parse_args()

Expand All @@ -44,6 +55,8 @@ def parse_args():
args.shape = tuple(map(int, args.shape.split(",")))
if isinstance(args.block, str):
args.block = tuple(map(int, args.block.split(",")))
if isinstance(args.wave_shape, str):
args.wave_shape = tuple(map(int, args.wave_shape.split(",")))

return args

Expand All @@ -64,6 +77,8 @@ def run_test(
shape=None,
block=None,
eliminate_epilogue=None,
wave_shape=None,
no_unroll=False,
):
"""Run a test function multiple times."""
if test_name not in module_globals:
Expand All @@ -78,6 +93,10 @@ def run_test(
kwargs["block"] = block
if eliminate_epilogue is not None:
kwargs["eliminate_epilogue"] = eliminate_epilogue
if wave_shape is not None:
kwargs["wave_shape"] = wave_shape
if no_unroll:
kwargs["no_unroll"] = True

for i in range(repeat):
try:
Expand Down
44 changes: 43 additions & 1 deletion tests/kernel/wave/asm/test_waveasm_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,7 @@ def _dbuf_mxfp4_helper(
wave_shape=None,
reorder_workgroups=None,
eliminate_epilogue=False,
no_unroll=False,
):
"""Shared helper for double-buffered MXFP4 scheduled GEMM tests.

Expand Down Expand Up @@ -1201,7 +1202,10 @@ def _dbuf_mxfp4_helper(
options.eliminate_epilogue = eliminate_epilogue
if use_schedule:
schedule = get_mxfp4_asymmetric_schedule(
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
eliminate_epilogue=eliminate_epilogue,
is_bscale_shuffled=True,
unroll_factor=1 if no_unroll else 2,
unroll_kernel=not no_unroll,
)
else:
schedule = None
Expand Down Expand Up @@ -1445,6 +1449,44 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
)


@pytest.mark.parametrize("eliminate_epilogue", [True], ids=["ee"])
@pytest.mark.parametrize(
"shape,block,wave_shape",
[
pytest.param((1024, 896, 8192), (256, 224, 256), (2, 2), id="256x224x256"),
],
)
def test_dbuf_4wave_mxfp4_nounroll_gemm_cpp_backend(
shape,
block,
wave_shape,
eliminate_epilogue,
compiler,
dump_asm,
):
"""End-to-end test for asymmetric MXFP4 GEMM with no-unroll schedule.

The no-unroll schedule (unroll_factor=1) reduces register pressure by
not unrolling the K-loop body, allowing larger block sizes like
256x224x256 that would otherwise exceed the 256-VGPR hardware limit
with the standard asymmetric schedule.
"""
_dbuf_mxfp4_helper(
shape=shape,
block=block,
num_waves=4,
use_stagger=False,
compiler=compiler,
dump_asm=dump_asm,
use_buffer_ops=True,
use_schedule=True,
output_dtype="f32",
wave_shape=wave_shape,
eliminate_epilogue=eliminate_epilogue,
no_unroll=True,
)


@pytest.mark.parametrize(
"shape,block,wave_shape",
[
Expand Down
184 changes: 123 additions & 61 deletions wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,10 @@ def mxfp4_dbuf_schedule():


def get_mxfp4_asymmetric_schedule(
eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False
eliminate_epilogue: bool = False,
is_bscale_shuffled: bool = False,
unroll_factor: int = 2,
unroll_kernel: bool = True,
):
"""Return an asymmetric-prefetch MXFP4 schedule for wave_compile().

Expand Down Expand Up @@ -1645,7 +1648,7 @@ def mxfp4_dbuf_schedule():

# This forces the pipeline to use double buffering
pipeline_loop.multi_buffer_count = 2
pipeline_loop.unroll_factor = 2
pipeline_loop.unroll_factor = unroll_factor

with pipeline_loop as pl:
pl.set_stage(
Expand Down Expand Up @@ -1783,68 +1786,126 @@ def mxfp4_dbuf_schedule():
# Interleave MFMAs with memory ops (matching aiter f4gemm pattern).
# Clamp start_offsets so they fit within each partition when the M
# tile count is odd (e.g. 7 tiles split into 4+3).
base_offsets = [0, 3, 2, 0]
base_intervals = [4, 4, 2, 4]

def _clamp_offsets(n, offsets):
return [min(o, max(0, n - 1)) for o in offsets]

interleaved_mma_0 = tkw.interleave_operations(
base_ops=loop_scaled_mma_0,
interleaved_ops=[
loop_g2v_b,
loop_shared_load_a_1,
loop_shared_load_a_scale_1,
loop_g2v_b_scale,
],
intervals=base_intervals,
start_offsets=_clamp_offsets(len(loop_scaled_mma_0), base_offsets),
start_after_groups=[[], [], [1], [0]],
)

interleaved_mma_1 = tkw.interleave_operations(
base_ops=loop_scaled_mma_1,
interleaved_ops=[
loop_g2s_a,
loop_shared_load_a_0,
loop_shared_load_a_scale_0,
loop_g2s_a_scale,
],
intervals=base_intervals,
start_offsets=_clamp_offsets(len(loop_scaled_mma_1), base_offsets),
start_after_groups=[[], [], [1], [0]],
)

loop_B_g2v_bs = len(loop_g2v_b) + (
len(loop_g2v_b_scale) // b_scale_shuffling_factor
)
loop_A_s2v_bs = len(loop_g2s_a) + len(loop_g2s_a_scale)
clusters = [
tkw.cluster(
[
loop_bitcast_a_0,
loop_bitcast_a_scale_0,
loop_bitcast_b,
loop_bitcast_b_scale,
tkw.SchedulingBarrier([]),
interleaved_mma_0,
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWaitBarrier(load=loop_B_g2v_bs, ds=0),
tkw.SchedulingBarrier([]),

if unroll_kernel:
base_offsets = [0, 3, 2, 0]
base_intervals = [4, 4, 2, 4]

interleaved_mma_0 = tkw.interleave_operations(
base_ops=loop_scaled_mma_0,
interleaved_ops=[
loop_g2v_b,
loop_shared_load_a_1,
loop_shared_load_a_scale_1,
loop_g2v_b_scale,
],
),
tkw.cluster(
[
loop_bitcast_a_1,
loop_bitcast_a_scale_1,
tkw.SchedulingBarrier([]),
interleaved_mma_1,
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0),
tkw.SchedulingBarrier([]),
]
),
]
intervals=base_intervals,
start_offsets=_clamp_offsets(len(loop_scaled_mma_0), base_offsets),
start_after_groups=[[], [], [1], [0]],
)

interleaved_mma_1 = tkw.interleave_operations(
base_ops=loop_scaled_mma_1,
interleaved_ops=[
loop_g2s_a,
loop_shared_load_a_0,
loop_shared_load_a_scale_0,
loop_g2s_a_scale,
],
intervals=base_intervals,
start_offsets=_clamp_offsets(len(loop_scaled_mma_1), base_offsets),
start_after_groups=[[], [], [1], [0]],
)

loop_B_g2v_bs = len(loop_g2v_b) + (
len(loop_g2v_b_scale) // b_scale_shuffling_factor
)
clusters = [
tkw.cluster(
[
loop_bitcast_a_0,
loop_bitcast_a_scale_0,
loop_bitcast_b,
loop_bitcast_b_scale,
tkw.SchedulingBarrier([]),
interleaved_mma_0,
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWaitBarrier(load=loop_B_g2v_bs, ds=0),
tkw.SchedulingBarrier([]),
],
),
tkw.cluster(
[
loop_bitcast_a_1,
loop_bitcast_a_scale_1,
tkw.SchedulingBarrier([]),
interleaved_mma_1,
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0),
tkw.SchedulingBarrier([]),
]
),
]
else:
interleaved_mma_0 = tkw.interleave_operations(
base_ops=loop_scaled_mma_0,
interleaved_ops=[
loop_shared_load_a_1,
loop_shared_load_a_scale_1,
],
intervals=[4, 2],
start_offsets=_clamp_offsets(len(loop_scaled_mma_0), [3, 2]),
start_after_groups=[[], [0]],
)

interleaved_mma_1 = tkw.interleave_operations(
base_ops=loop_scaled_mma_1,
interleaved_ops=[
loop_shared_load_a_0,
loop_shared_load_a_scale_0,
],
intervals=[4, 2],
start_offsets=_clamp_offsets(len(loop_scaled_mma_1), [3, 2]),
start_after_groups=[[], [0]],
)

clusters = [
tkw.cluster(
[
loop_bitcast_a_0,
loop_bitcast_a_scale_0,
loop_bitcast_b,
loop_bitcast_b_scale,
tkw.SchedulingBarrier([]),
interleaved_mma_0,
tkw.SchedulingBarrier([]),
],
),
tkw.cluster(
[
loop_bitcast_a_1,
loop_bitcast_a_scale_1,
tkw.SchedulingBarrier([]),
interleaved_mma_1,
tkw.SchedulingBarrier([]),
],
),
tkw.cluster(
[
loop_g2v_b,
loop_g2v_b_scale,
loop_g2s_a,
loop_g2s_a_scale,
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0),
tkw.SchedulingBarrier([]),
],
),
]

if eliminate_epilogue:
clusters += prologue_clusters
Expand Down Expand Up @@ -1992,8 +2053,9 @@ def split_by_iteration(nodes, key="name"):

tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters)
tkw.reorder_graph(pipeline_loop.KERNEL, clusters)
unroll_factor = 2
tkw.unroll(pipeline_loop.KERNEL, unroll_factor)
tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters_itr0)
if unroll_kernel:
tkw.unroll(pipeline_loop.KERNEL, unroll_factor)

tkw.insert_at_start(
pipeline_loop.KERNEL,
Expand Down
Loading