diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 0896395cc4..ac61b98f78 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -372,18 +372,25 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( is_debug=False, shape=(512, 1024, 8192), # 4*T0, 4*T1, 8192 block=(128, 256, 256), + wave_shape=(1, 4), eliminate_epilogue=True, + no_unroll=False, ): """Preshuffle-B MXFP4 GEMM using C++ WaveASM backend.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, block, wave_shape=wave_shape + ) options.backend = "asm" options.use_buffer_ops = True - options.wave_runtime = True options.use_wave_asm_backend = True + options.wave_runtime = True options.dump_intermediates = "build/intermediates" options.eliminate_epilogue = eliminate_epilogue schedule = get_mxfp4_asymmetric_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + eliminate_epilogue=eliminate_epilogue, + is_bscale_shuffled=True, + unroll_factor=1 if no_unroll else 2, + unroll_kernel=not no_unroll, ) options.print_ir_after = "all" if is_debug else [] options = set_default_run_config(options) @@ -444,5 +451,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( args.shape, args.block, args.eliminate_epilogue, + args.wave_shape, + no_unroll=args.no_unroll, ) exit(0 if success else 1) diff --git a/examples/python/utils.py b/examples/python/utils.py index 7c104dfe66..37238f35f4 100644 --- a/examples/python/utils.py +++ b/examples/python/utils.py @@ -36,6 +36,17 @@ def parse_args(): default=None, help="Enable epilogue elimination (true/false)", ) + parser.add_argument( + "--wave_shape", + type=str, + default=None, + help="Wave shape, e.g. 2,2", + ) + parser.add_argument( + "--no-unroll", + action="store_true", + help="Use nounroll (unroll_factor=1) schedule variant", + ) args = parser.parse_args() @@ -44,6 +55,8 @@ def parse_args(): args.shape = tuple(map(int, args.shape.split(","))) if isinstance(args.block, str): args.block = tuple(map(int, args.block.split(","))) + if isinstance(args.wave_shape, str): + args.wave_shape = tuple(map(int, args.wave_shape.split(","))) return args @@ -64,6 +77,8 @@ def run_test( shape=None, block=None, eliminate_epilogue=None, + wave_shape=None, + no_unroll=False, ): """Run a test function multiple times.""" if test_name not in module_globals: @@ -78,6 +93,10 @@ def run_test( kwargs["block"] = block if eliminate_epilogue is not None: kwargs["eliminate_epilogue"] = eliminate_epilogue + if wave_shape is not None: + kwargs["wave_shape"] = wave_shape + if no_unroll: + kwargs["no_unroll"] = True for i in range(repeat): try: diff --git a/tests/kernel/wave/asm/test_waveasm_e2e.py b/tests/kernel/wave/asm/test_waveasm_e2e.py index 4e2dec946e..5d4eff8700 100644 --- a/tests/kernel/wave/asm/test_waveasm_e2e.py +++ b/tests/kernel/wave/asm/test_waveasm_e2e.py @@ -1142,6 +1142,7 @@ def _dbuf_mxfp4_helper( wave_shape=None, reorder_workgroups=None, eliminate_epilogue=False, + no_unroll=False, ): """Shared helper for double-buffered MXFP4 scheduled GEMM tests. @@ -1201,7 +1202,10 @@ def _dbuf_mxfp4_helper( options.eliminate_epilogue = eliminate_epilogue if use_schedule: schedule = get_mxfp4_asymmetric_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + eliminate_epilogue=eliminate_epilogue, + is_bscale_shuffled=True, + unroll_factor=1 if no_unroll else 2, + unroll_kernel=not no_unroll, ) else: schedule = None @@ -1445,6 +1449,44 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend( ) +@pytest.mark.parametrize("eliminate_epilogue", [True], ids=["ee"]) +@pytest.mark.parametrize( + "shape,block,wave_shape", + [ + pytest.param((1024, 896, 8192), (256, 224, 256), (2, 2), id="256x224x256"), + ], +) +def test_dbuf_4wave_mxfp4_nounroll_gemm_cpp_backend( + shape, + block, + wave_shape, + eliminate_epilogue, + compiler, + dump_asm, +): + """End-to-end test for asymmetric MXFP4 GEMM with no-unroll schedule. + + The no-unroll schedule (unroll_factor=1) reduces register pressure by + not unrolling the K-loop body, allowing larger block sizes like + 256x224x256 that would otherwise exceed the 256-VGPR hardware limit + with the standard asymmetric schedule. + """ + _dbuf_mxfp4_helper( + shape=shape, + block=block, + num_waves=4, + use_stagger=False, + compiler=compiler, + dump_asm=dump_asm, + use_buffer_ops=True, + use_schedule=True, + output_dtype="f32", + wave_shape=wave_shape, + eliminate_epilogue=eliminate_epilogue, + no_unroll=True, + ) + + @pytest.mark.parametrize( "shape,block,wave_shape", [ diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 6069941284..0a42463441 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1573,7 +1573,10 @@ def mxfp4_dbuf_schedule(): def get_mxfp4_asymmetric_schedule( - eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False + eliminate_epilogue: bool = False, + is_bscale_shuffled: bool = False, + unroll_factor: int = 2, + unroll_kernel: bool = True, ): """Return an asymmetric-prefetch MXFP4 schedule for wave_compile(). @@ -1645,7 +1648,7 @@ def mxfp4_dbuf_schedule(): # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 - pipeline_loop.unroll_factor = 2 + pipeline_loop.unroll_factor = unroll_factor with pipeline_loop as pl: pl.set_stage( @@ -1783,68 +1786,126 @@ def mxfp4_dbuf_schedule(): # Interleave MFMAs with memory ops (matching aiter f4gemm pattern). # Clamp start_offsets so they fit within each partition when the M # tile count is odd (e.g. 7 tiles split into 4+3). - base_offsets = [0, 3, 2, 0] - base_intervals = [4, 4, 2, 4] - def _clamp_offsets(n, offsets): return [min(o, max(0, n - 1)) for o in offsets] - interleaved_mma_0 = tkw.interleave_operations( - base_ops=loop_scaled_mma_0, - interleaved_ops=[ - loop_g2v_b, - loop_shared_load_a_1, - loop_shared_load_a_scale_1, - loop_g2v_b_scale, - ], - intervals=base_intervals, - start_offsets=_clamp_offsets(len(loop_scaled_mma_0), base_offsets), - start_after_groups=[[], [], [1], [0]], - ) - - interleaved_mma_1 = tkw.interleave_operations( - base_ops=loop_scaled_mma_1, - interleaved_ops=[ - loop_g2s_a, - loop_shared_load_a_0, - loop_shared_load_a_scale_0, - loop_g2s_a_scale, - ], - intervals=base_intervals, - start_offsets=_clamp_offsets(len(loop_scaled_mma_1), base_offsets), - start_after_groups=[[], [], [1], [0]], - ) - - loop_B_g2v_bs = len(loop_g2v_b) + ( - len(loop_g2v_b_scale) // b_scale_shuffling_factor - ) loop_A_s2v_bs = len(loop_g2s_a) + len(loop_g2s_a_scale) - clusters = [ - tkw.cluster( - [ - loop_bitcast_a_0, - loop_bitcast_a_scale_0, - loop_bitcast_b, - loop_bitcast_b_scale, - tkw.SchedulingBarrier([]), - interleaved_mma_0, - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=loop_B_g2v_bs, ds=0), - tkw.SchedulingBarrier([]), + + if unroll_kernel: + base_offsets = [0, 3, 2, 0] + base_intervals = [4, 4, 2, 4] + + interleaved_mma_0 = tkw.interleave_operations( + base_ops=loop_scaled_mma_0, + interleaved_ops=[ + loop_g2v_b, + loop_shared_load_a_1, + loop_shared_load_a_scale_1, + loop_g2v_b_scale, ], - ), - tkw.cluster( - [ - loop_bitcast_a_1, - loop_bitcast_a_scale_1, - tkw.SchedulingBarrier([]), - interleaved_mma_1, - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), - tkw.SchedulingBarrier([]), - ] - ), - ] + intervals=base_intervals, + start_offsets=_clamp_offsets(len(loop_scaled_mma_0), base_offsets), + start_after_groups=[[], [], [1], [0]], + ) + + interleaved_mma_1 = tkw.interleave_operations( + base_ops=loop_scaled_mma_1, + interleaved_ops=[ + loop_g2s_a, + loop_shared_load_a_0, + loop_shared_load_a_scale_0, + loop_g2s_a_scale, + ], + intervals=base_intervals, + start_offsets=_clamp_offsets(len(loop_scaled_mma_1), base_offsets), + start_after_groups=[[], [], [1], [0]], + ) + + loop_B_g2v_bs = len(loop_g2v_b) + ( + len(loop_g2v_b_scale) // b_scale_shuffling_factor + ) + clusters = [ + tkw.cluster( + [ + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b, + loop_bitcast_b_scale, + tkw.SchedulingBarrier([]), + interleaved_mma_0, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_B_g2v_bs, ds=0), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + tkw.SchedulingBarrier([]), + interleaved_mma_1, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), + tkw.SchedulingBarrier([]), + ] + ), + ] + else: + interleaved_mma_0 = tkw.interleave_operations( + base_ops=loop_scaled_mma_0, + interleaved_ops=[ + loop_shared_load_a_1, + loop_shared_load_a_scale_1, + ], + intervals=[4, 2], + start_offsets=_clamp_offsets(len(loop_scaled_mma_0), [3, 2]), + start_after_groups=[[], [0]], + ) + + interleaved_mma_1 = tkw.interleave_operations( + base_ops=loop_scaled_mma_1, + interleaved_ops=[ + loop_shared_load_a_0, + loop_shared_load_a_scale_0, + ], + intervals=[4, 2], + start_offsets=_clamp_offsets(len(loop_scaled_mma_1), [3, 2]), + start_after_groups=[[], [0]], + ) + + clusters = [ + tkw.cluster( + [ + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b, + loop_bitcast_b_scale, + tkw.SchedulingBarrier([]), + interleaved_mma_0, + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + tkw.SchedulingBarrier([]), + interleaved_mma_1, + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + loop_g2v_b, + loop_g2v_b_scale, + loop_g2s_a, + loop_g2s_a_scale, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), + tkw.SchedulingBarrier([]), + ], + ), + ] if eliminate_epilogue: clusters += prologue_clusters @@ -1992,8 +2053,9 @@ def split_by_iteration(nodes, key="name"): tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - unroll_factor = 2 - tkw.unroll(pipeline_loop.KERNEL, unroll_factor) + tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters_itr0) + if unroll_kernel: + tkw.unroll(pipeline_loop.KERNEL, unroll_factor) tkw.insert_at_start( pipeline_loop.KERNEL,