From 809d2cb8e4069dc045583e4a714c0ef201a1f46d Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Fri, 13 Mar 2026 10:59:37 -0600 Subject: [PATCH 1/2] Add schedule for 256x224x256 macro tile It is a no_unroll schedule to get under the register budget. This gets the macro tile functional with the waveasm backend. For the 7.1 example, it adds - `--wave_shape` flag -- Previously (1,4) was hard-coded, but the 256x224x256 tile needed (2, 2) because the N dimension was not divisible by 4 after pipelining... I think that was the reason we chose it. - `--no-unroll` flag to access the new no_unroll schedule. The particular 7.1 example target for this work was `python examples/python/7.1_schedule.py --block 256,224,256 --shape 1024,896,8192 --wave_shape 2,2 --no-unroll --test test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp` This also adds an e2e waveasm test. At this stage no real effort has been made to make the schedule performant, just to get it working. Signed-off-by: William G Hatch --- examples/python/7.1_schedule.py | 22 +- examples/python/utils.py | 19 + tests/kernel/wave/asm/test_waveasm_e2e.py | 51 ++- wave_lang/kernel/wave/schedules/__init__.py | 2 + .../schedules/gemm_mxfp4_double_buffer.py | 364 ++++++++++++++++++ 5 files changed, 450 insertions(+), 8 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 0896395cc4..2aab8b8b86 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -27,6 +27,7 @@ get_mxfp4_dbuf_pingpong_schedule, get_mxfp4_dbuf_mixed_pingpong_schedule, get_mxfp4_asymmetric_schedule, + get_mxfp4_asymmetric_nounroll_schedule, get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule, get_mxfp4_dbuf_pingpong_schedule_Bshuffled, get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds, @@ -372,19 +373,28 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( is_debug=False, shape=(512, 1024, 8192), # 4*T0, 4*T1, 8192 block=(128, 256, 256), + wave_shape=(1, 4), eliminate_epilogue=True, + no_unroll=False, ): 
"""Preshuffle-B MXFP4 GEMM using C++ WaveASM backend.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, block, wave_shape=wave_shape + ) options.backend = "asm" options.use_buffer_ops = True - options.wave_runtime = True options.use_wave_asm_backend = True + options.wave_runtime = True options.dump_intermediates = "build/intermediates" options.eliminate_epilogue = eliminate_epilogue - schedule = get_mxfp4_asymmetric_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True - ) + if no_unroll: + schedule = get_mxfp4_asymmetric_nounroll_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + else: + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) options.print_ir_after = "all" if is_debug else [] options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) @@ -444,5 +454,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( args.shape, args.block, args.eliminate_epilogue, + args.wave_shape, + no_unroll=args.no_unroll, ) exit(0 if success else 1) diff --git a/examples/python/utils.py b/examples/python/utils.py index 7c104dfe66..37238f35f4 100644 --- a/examples/python/utils.py +++ b/examples/python/utils.py @@ -36,6 +36,17 @@ def parse_args(): default=None, help="Enable epilogue elimination (true/false)", ) + parser.add_argument( + "--wave_shape", + type=str, + default=None, + help="Wave shape, e.g. 
2,2", + ) + parser.add_argument( + "--no-unroll", + action="store_true", + help="Use nounroll (unroll_factor=1) schedule variant", + ) args = parser.parse_args() @@ -44,6 +55,8 @@ def parse_args(): args.shape = tuple(map(int, args.shape.split(","))) if isinstance(args.block, str): args.block = tuple(map(int, args.block.split(","))) + if isinstance(args.wave_shape, str): + args.wave_shape = tuple(map(int, args.wave_shape.split(","))) return args @@ -64,6 +77,8 @@ def run_test( shape=None, block=None, eliminate_epilogue=None, + wave_shape=None, + no_unroll=False, ): """Run a test function multiple times.""" if test_name not in module_globals: @@ -78,6 +93,10 @@ def run_test( kwargs["block"] = block if eliminate_epilogue is not None: kwargs["eliminate_epilogue"] = eliminate_epilogue + if wave_shape is not None: + kwargs["wave_shape"] = wave_shape + if no_unroll: + kwargs["no_unroll"] = True for i in range(repeat): try: diff --git a/tests/kernel/wave/asm/test_waveasm_e2e.py b/tests/kernel/wave/asm/test_waveasm_e2e.py index 4e2dec946e..98359a7610 100644 --- a/tests/kernel/wave/asm/test_waveasm_e2e.py +++ b/tests/kernel/wave/asm/test_waveasm_e2e.py @@ -1142,6 +1142,7 @@ def _dbuf_mxfp4_helper( wave_shape=None, reorder_workgroups=None, eliminate_epilogue=False, + no_unroll=False, ): """Shared helper for double-buffered MXFP4 scheduled GEMM tests. 
@@ -1168,6 +1169,7 @@ def _dbuf_mxfp4_helper( from wave_lang.kernel.wave.schedules import ( get_mxfp4_dbuf_schedule, get_mxfp4_asymmetric_schedule, + get_mxfp4_asymmetric_nounroll_schedule, ) from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType from wave_lang.kernel.wave.utils.run_utils import set_default_run_config @@ -1200,9 +1202,14 @@ def _dbuf_mxfp4_helper( ) options.eliminate_epilogue = eliminate_epilogue if use_schedule: - schedule = get_mxfp4_asymmetric_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True - ) + if no_unroll: + schedule = get_mxfp4_asymmetric_nounroll_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + else: + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) else: schedule = None options.schedule = SchedulingType.NONE @@ -1445,6 +1452,44 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend( ) +@pytest.mark.parametrize("eliminate_epilogue", [True], ids=["ee"]) +@pytest.mark.parametrize( + "shape,block,wave_shape", + [ + pytest.param((1024, 896, 8192), (256, 224, 256), (2, 2), id="256x224x256"), + ], +) +def test_dbuf_4wave_mxfp4_nounroll_gemm_cpp_backend( + shape, + block, + wave_shape, + eliminate_epilogue, + compiler, + dump_asm, +): + """End-to-end test for asymmetric MXFP4 GEMM with no-unroll schedule. + + The no-unroll schedule (unroll_factor=1) reduces register pressure by + not unrolling the K-loop body, allowing larger block sizes like + 256x224x256 that would otherwise exceed the 256-VGPR hardware limit + with the standard asymmetric schedule. 
+ """ + _dbuf_mxfp4_helper( + shape=shape, + block=block, + num_waves=4, + use_stagger=False, + compiler=compiler, + dump_asm=dump_asm, + use_buffer_ops=True, + use_schedule=True, + output_dtype="f32", + wave_shape=wave_shape, + eliminate_epilogue=eliminate_epilogue, + no_unroll=True, + ) + + @pytest.mark.parametrize( "shape,block,wave_shape", [ diff --git a/wave_lang/kernel/wave/schedules/__init__.py b/wave_lang/kernel/wave/schedules/__init__.py index 4ec56863ab..7640de303e 100644 --- a/wave_lang/kernel/wave/schedules/__init__.py +++ b/wave_lang/kernel/wave/schedules/__init__.py @@ -18,6 +18,7 @@ get_mxfp4_dbuf_mixed_pingpong_schedule, get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule, get_mxfp4_asymmetric_schedule, + get_mxfp4_asymmetric_nounroll_schedule, get_mxfp4_dbuf_pingpong_schedule_Bshuffled, get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds, ) @@ -34,6 +35,7 @@ "get_mxfp4_dbuf_pingpong_schedule_Bshuffled", "get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds", "get_mxfp4_asymmetric_schedule", + "get_mxfp4_asymmetric_nounroll_schedule", "get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule", "get_attention_prefetch_schedule", ] diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 6069941284..f31d768c58 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -2004,3 +2004,367 @@ def split_by_iteration(nodes, key="name"): ) return mxfp4_dbuf_schedule + + +def get_mxfp4_asymmetric_nounroll_schedule( + eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False +): + """Asymmetric-prefetch MXFP4 schedule with unroll_factor=1. + + Same 3-stage pipeline as get_mxfp4_asymmetric_schedule but without + kernel body unrolling. This keeps VGPR pressure low enough for large + tiles (e.g. 256x224x256 with wave_shape 2x2) while still using a + standard epilogue for pipeline draining. 
+ """ + M = tkl.sym.M + + @wave_schedule.wave_schedule() + def mxfp4_nounroll_schedule(): + k_loop = tkw.get_node_by_tag("k_loop") + + all_read_a = tkw.get_node_by_tag("read_a") + g2s_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) + s2v_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) + + all_read_a_scale = tkw.get_node_by_tag("read_a_scale") + g2s_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.GatherToLDS) + s2v_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) + + s2v_a_0, s2v_a_1 = tkw.partition_by_dim(s2v_a, dim=M, num_partitions=2) + s2v_a_scale_0, s2v_a_scale_1 = tkw.partition_by_dim( + s2v_a_scale, dim=M, num_partitions=2 + ) + + g2v_b = tkw.get_node_by_tag("read_b") + g2v_b_scale = tkw.get_node_by_tag("read_b_scale") + + bitcast_a = tkw.get_node_by_tag("bitcast_a") + bitcast_a_scale = tkw.get_node_by_tag("bitcast_a_scale") + bitcast_b = tkw.get_node_by_tag("bitcast_b") + bitcast_b_scale = tkw.get_node_by_tag("bitcast_b_scale") + + scaled_mma = tkw.get_node_by_tag("scaled_mma") + + pipeline_loop = tkw.pipeline(k_loop, eliminate_epilogue=eliminate_epilogue) + pipeline_loop.multi_buffer_count = 2 + pipeline_loop.unroll_factor = 1 + + with pipeline_loop as pl: + pl.set_stage( + [ + (g2s_a, g2s_a_scale), + (), + (), + ], + ) + pl.set_stage( + [ + (g2v_b, g2v_b_scale), + (s2v_a_0, s2v_a_scale_0), + (), + ], + ) + pl.set_stage( + [ + (s2v_a_1, s2v_a_scale_1), + (bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale), + (scaled_mma,), + ], + ) + + if is_bscale_shuffled: + b_scale_shuffling_factor = 4 + else: + b_scale_shuffling_factor = 1 + + num_pf_iters = 2 + + # ----------------------------------------------------------------- + # Prologue + # ----------------------------------------------------------------- + prologue_g2s_a = tkw.filter_nodes(g2s_a, subgraph=pipeline_loop.PROLOGUE) + prologue_g2s_a_scale = tkw.filter_nodes( + g2s_a_scale, subgraph=pipeline_loop.PROLOGUE + ) + prologue_g2v_b = tkw.filter_nodes(g2v_b, 
subgraph=pipeline_loop.PROLOGUE) + prologue_g2v_b_scale = tkw.filter_nodes( + g2v_b_scale, subgraph=pipeline_loop.PROLOGUE + ) + prologue_s2v_a_0 = tkw.filter_nodes(s2v_a_0, subgraph=pipeline_loop.PROLOGUE) + prologue_s2v_a_scale_0 = tkw.filter_nodes( + s2v_a_scale_0, subgraph=pipeline_loop.PROLOGUE + ) + + A_g2s_total = len(prologue_g2s_a) + len(prologue_g2s_a_scale) + A_g2s_per_iter = A_g2s_total // num_pf_iters + B_g2v_prologue = len(prologue_g2v_b) + ( + len(prologue_g2v_b_scale) // b_scale_shuffling_factor + ) + + prologue_clusters = [ + tkw.cluster( + [ + prologue_g2s_a, + prologue_g2s_a_scale, + prologue_g2v_b, + tkw.SchedulingBarrier([]), + prologue_g2v_b_scale, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=0), + tkw.SchedulingBarrier([]), + prologue_s2v_a_0, + prologue_s2v_a_scale_0, + ], + ) + ] + + # ----------------------------------------------------------------- + # Kernel (main loop body) + # ----------------------------------------------------------------- + loop_g2s_a = tkw.filter_nodes(g2s_a, subgraph=pipeline_loop.KERNEL) + loop_g2s_a_scale = tkw.filter_nodes(g2s_a_scale, subgraph=pipeline_loop.KERNEL) + loop_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.KERNEL) + loop_g2v_b_scale = tkw.filter_nodes(g2v_b_scale, subgraph=pipeline_loop.KERNEL) + loop_shared_load_a_0 = tkw.filter_nodes(s2v_a_0, subgraph=pipeline_loop.KERNEL) + loop_shared_load_a_scale_0 = tkw.filter_nodes( + s2v_a_scale_0, subgraph=pipeline_loop.KERNEL + ) + loop_shared_load_a_1 = tkw.filter_nodes(s2v_a_1, subgraph=pipeline_loop.KERNEL) + loop_shared_load_a_scale_1 = tkw.filter_nodes( + s2v_a_scale_1, subgraph=pipeline_loop.KERNEL + ) + loop_bitcast_a = tkw.filter_nodes(bitcast_a, subgraph=pipeline_loop.KERNEL) + loop_bitcast_a_scale = tkw.filter_nodes( + bitcast_a_scale, subgraph=pipeline_loop.KERNEL + ) + loop_bitcast_b = tkw.filter_nodes(bitcast_b, subgraph=pipeline_loop.KERNEL) + loop_bitcast_b_scale = tkw.filter_nodes( + bitcast_b_scale, 
subgraph=pipeline_loop.KERNEL + ) + loop_scaled_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.KERNEL) + + loop_scaled_mma_0, loop_scaled_mma_1 = tkw.partition_by_dim( + loop_scaled_mma, dim=M, num_partitions=2 + ) + loop_bitcast_a_0, loop_bitcast_a_1 = tkw.partition_by_dim( + loop_bitcast_a, dim=M, num_partitions=2 + ) + loop_bitcast_a_scale_0, loop_bitcast_a_scale_1 = tkw.partition_by_dim( + loop_bitcast_a_scale, dim=M, num_partitions=2 + ) + + interleaved_mma_0 = tkw.interleave_operations( + base_ops=loop_scaled_mma_0, + interleaved_ops=[ + loop_shared_load_a_1, + loop_shared_load_a_scale_1, + ], + intervals=[4, 2], + start_offsets=[3, 2], + start_after_groups=[[], [0]], + ) + + interleaved_mma_1 = tkw.interleave_operations( + base_ops=loop_scaled_mma_1, + interleaved_ops=[ + loop_shared_load_a_0, + loop_shared_load_a_scale_0, + ], + intervals=[4, 2], + start_offsets=[3, 2], + start_after_groups=[[], [0]], + ) + + loop_A_s2v_bs = len(loop_g2s_a) + len(loop_g2s_a_scale) + kernel_clusters = [ + tkw.cluster( + [ + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b, + loop_bitcast_b_scale, + tkw.SchedulingBarrier([]), + interleaved_mma_0, + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + tkw.SchedulingBarrier([]), + interleaved_mma_1, + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + loop_g2v_b, + loop_g2v_b_scale, + loop_g2s_a, + loop_g2s_a_scale, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), + tkw.SchedulingBarrier([]), + ], + ), + ] + + if eliminate_epilogue: + kernel_clusters += prologue_clusters + tkw.reorder_graph(pipeline_loop.KERNEL, kernel_clusters) + else: + # ----------------------------------------------------------------- + # Epilogue: two drain iterations for the 3-stage pipeline. + # Schedule drain 0's compute before drain 1's loads so both + # iterations' live registers don't overlap and exceed the + # 256 VGPR budget. 
+ # ----------------------------------------------------------------- + epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) + epilogue_g2v_b_scale = tkw.filter_nodes( + g2v_b_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_0 = tkw.filter_nodes( + s2v_a_0, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_scale_0 = tkw.filter_nodes( + s2v_a_scale_0, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_1 = tkw.filter_nodes( + s2v_a_1, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_scale_1 = tkw.filter_nodes( + s2v_a_scale_1, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_a = tkw.filter_nodes( + bitcast_a, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_a_scale = tkw.filter_nodes( + bitcast_a_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_b = tkw.filter_nodes( + bitcast_b, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_b_scale = tkw.filter_nodes( + bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) + + def split_by_iteration(nodes, key="name"): + itr0 = [] + itr1 = [] + for node in nodes: + value = getattr(node, key) + if "1_2" in value: + itr0.append(node) + elif "2_2" in value: + itr1.append(node) + else: + raise ValueError(f"Unknown {key} for node: {value}") + return itr0, itr1 + + epilogue_mma_itr0, epilogue_mma_itr1 = split_by_iteration(epilogue_mma) + epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( + epilogue_s2v_a_1 + ) + epilogue_s2v_a_scale_1_itr0, epilogue_s2v_a_scale_1_itr1 = ( + split_by_iteration(epilogue_s2v_a_scale_1) + ) + epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( + epilogue_bitcast_a + ) + epilogue_bitcast_a_scale_itr0, epilogue_bitcast_a_scale_itr1 = ( + split_by_iteration(epilogue_bitcast_a_scale) + ) + epilogue_bitcast_b_itr0, epilogue_bitcast_b_itr1 = split_by_iteration( + epilogue_bitcast_b + ) + epilogue_bitcast_b_scale_itr0, 
epilogue_bitcast_b_scale_itr1 = ( + split_by_iteration(epilogue_bitcast_b_scale) + ) + + epilogue_mma_itr0_0, epilogue_mma_itr0_1 = tkw.partition_by_dim( + epilogue_mma_itr0, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_itr0_0, epilogue_bitcast_a_itr0_1 = tkw.partition_by_dim( + epilogue_bitcast_a_itr0, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_scale_itr0_0, epilogue_bitcast_a_scale_itr0_1 = ( + tkw.partition_by_dim( + epilogue_bitcast_a_scale_itr0, dim=M, num_partitions=2 + ) + ) + + epilogue_mma_itr1_0, epilogue_mma_itr1_1 = tkw.partition_by_dim( + epilogue_mma_itr1, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_itr1_0, epilogue_bitcast_a_itr1_1 = tkw.partition_by_dim( + epilogue_bitcast_a_itr1, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_scale_itr1_0, epilogue_bitcast_a_scale_itr1_1 = ( + tkw.partition_by_dim( + epilogue_bitcast_a_scale_itr1, dim=M, num_partitions=2 + ) + ) + + epilogue_clusters = [ + # Drain iteration 0: complete compute for K-tile N-2 + tkw.cluster( + [ + epilogue_bitcast_a_itr0_0, + epilogue_bitcast_a_scale_itr0_0, + epilogue_bitcast_b_itr0, + epilogue_bitcast_b_scale_itr0, + tkw.SchedulingBarrier([]), + epilogue_mma_itr0_0, + epilogue_g2v_b, + epilogue_s2v_a_1_itr0, + epilogue_g2v_b_scale, + epilogue_s2v_a_scale_1_itr0, + epilogue_bitcast_a_itr0_1, + epilogue_bitcast_a_scale_itr0_1, + ], + ), + tkw.cluster( + [ + epilogue_mma_itr0_1, + tkw.SchedulingBarrier([]), + epilogue_s2v_a_0, + epilogue_s2v_a_scale_0, + ], + ), + # Drain iteration 1: final compute for K-tile N-1 + tkw.cluster( + [ + epilogue_bitcast_a_itr1_0, + epilogue_bitcast_a_scale_itr1_0, + epilogue_bitcast_b_itr1, + epilogue_bitcast_b_scale_itr1, + tkw.SchedulingBarrier([]), + epilogue_mma_itr1_0, + epilogue_s2v_a_1_itr1, + epilogue_s2v_a_scale_1_itr1, + ], + ), + tkw.cluster( + [ + epilogue_bitcast_a_itr1_1, + epilogue_bitcast_a_scale_itr1_1, + epilogue_mma_itr1_1, + ], + ), + ] + + tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) + 
tkw.reorder_graph(pipeline_loop.KERNEL, kernel_clusters) + tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters) + + tkw.insert_at_start( + pipeline_loop.KERNEL, + tkw.MemoryCounterWaitBarrier(load=A_g2s_per_iter, ds=0), + ) + tkw.insert_after( + pipeline_loop.KERNEL, tkw.MemoryCounterWaitBarrier(load=0, ds=0) + ) + + return mxfp4_nounroll_schedule From b782e12ddc6dbd2b0fe024df4e7de047216eb64a Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Mon, 16 Mar 2026 17:01:00 -0600 Subject: [PATCH 2/2] Consolidate asymmetric nounroll schedule into parameterized asymmetric schedule The no-unroll path needs a different kernel interleaving strategy than the unrolled path: 2-group interleaving (shared A loads interleaved with MMA) with B loads and G2S prefetches in a separate third cluster, rather than 4-group interleaving that folds B loads and G2S directly into the two MMA clusters. The 4-group pattern was designed for the unrolled kernel where the larger loop body can absorb the extra live values; with unroll_factor=1 the tighter loop needs the third cluster to keep VGPR pressure in check. 
--- examples/python/7.1_schedule.py | 15 +- tests/kernel/wave/asm/test_waveasm_e2e.py | 15 +- wave_lang/kernel/wave/schedules/__init__.py | 2 - .../schedules/gemm_mxfp4_double_buffer.py | 526 ++++-------------- 4 files changed, 124 insertions(+), 434 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 2aab8b8b86..ac61b98f78 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -27,7 +27,6 @@ get_mxfp4_dbuf_pingpong_schedule, get_mxfp4_dbuf_mixed_pingpong_schedule, get_mxfp4_asymmetric_schedule, - get_mxfp4_asymmetric_nounroll_schedule, get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule, get_mxfp4_dbuf_pingpong_schedule_Bshuffled, get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds, @@ -387,14 +386,12 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( options.wave_runtime = True options.dump_intermediates = "build/intermediates" options.eliminate_epilogue = eliminate_epilogue - if no_unroll: - schedule = get_mxfp4_asymmetric_nounroll_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True - ) - else: - schedule = get_mxfp4_asymmetric_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True - ) + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, + is_bscale_shuffled=True, + unroll_factor=1 if no_unroll else 2, + unroll_kernel=not no_unroll, + ) options.print_ir_after = "all" if is_debug else [] options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) diff --git a/tests/kernel/wave/asm/test_waveasm_e2e.py b/tests/kernel/wave/asm/test_waveasm_e2e.py index 98359a7610..5d4eff8700 100644 --- a/tests/kernel/wave/asm/test_waveasm_e2e.py +++ b/tests/kernel/wave/asm/test_waveasm_e2e.py @@ -1169,7 +1169,6 @@ def _dbuf_mxfp4_helper( from wave_lang.kernel.wave.schedules import ( get_mxfp4_dbuf_schedule, get_mxfp4_asymmetric_schedule, - get_mxfp4_asymmetric_nounroll_schedule, ) from 
wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType from wave_lang.kernel.wave.utils.run_utils import set_default_run_config @@ -1202,14 +1201,12 @@ def _dbuf_mxfp4_helper( ) options.eliminate_epilogue = eliminate_epilogue if use_schedule: - if no_unroll: - schedule = get_mxfp4_asymmetric_nounroll_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True - ) - else: - schedule = get_mxfp4_asymmetric_schedule( - eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True - ) + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, + is_bscale_shuffled=True, + unroll_factor=1 if no_unroll else 2, + unroll_kernel=not no_unroll, + ) else: schedule = None options.schedule = SchedulingType.NONE diff --git a/wave_lang/kernel/wave/schedules/__init__.py b/wave_lang/kernel/wave/schedules/__init__.py index 7640de303e..4ec56863ab 100644 --- a/wave_lang/kernel/wave/schedules/__init__.py +++ b/wave_lang/kernel/wave/schedules/__init__.py @@ -18,7 +18,6 @@ get_mxfp4_dbuf_mixed_pingpong_schedule, get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule, get_mxfp4_asymmetric_schedule, - get_mxfp4_asymmetric_nounroll_schedule, get_mxfp4_dbuf_pingpong_schedule_Bshuffled, get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds, ) @@ -35,7 +34,6 @@ "get_mxfp4_dbuf_pingpong_schedule_Bshuffled", "get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds", "get_mxfp4_asymmetric_schedule", - "get_mxfp4_asymmetric_nounroll_schedule", "get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule", "get_attention_prefetch_schedule", ] diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index f31d768c58..0a42463441 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1573,7 +1573,10 @@ def mxfp4_dbuf_schedule(): def get_mxfp4_asymmetric_schedule( - eliminate_epilogue: bool = False, is_bscale_shuffled: 
bool = False + eliminate_epilogue: bool = False, + is_bscale_shuffled: bool = False, + unroll_factor: int = 2, + unroll_kernel: bool = True, ): """Return an asymmetric-prefetch MXFP4 schedule for wave_compile(). @@ -1645,7 +1648,7 @@ def mxfp4_dbuf_schedule(): # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 - pipeline_loop.unroll_factor = 2 + pipeline_loop.unroll_factor = unroll_factor with pipeline_loop as pl: pl.set_stage( @@ -1783,443 +1786,131 @@ def mxfp4_dbuf_schedule(): # Interleave MFMAs with memory ops (matching aiter f4gemm pattern). # Clamp start_offsets so they fit within each partition when the M # tile count is odd (e.g. 7 tiles split into 4+3). - base_offsets = [0, 3, 2, 0] - base_intervals = [4, 4, 2, 4] - def _clamp_offsets(n, offsets): return [min(o, max(0, n - 1)) for o in offsets] - interleaved_mma_0 = tkw.interleave_operations( - base_ops=loop_scaled_mma_0, - interleaved_ops=[ - loop_g2v_b, - loop_shared_load_a_1, - loop_shared_load_a_scale_1, - loop_g2v_b_scale, - ], - intervals=base_intervals, - start_offsets=_clamp_offsets(len(loop_scaled_mma_0), base_offsets), - start_after_groups=[[], [], [1], [0]], - ) - - interleaved_mma_1 = tkw.interleave_operations( - base_ops=loop_scaled_mma_1, - interleaved_ops=[ - loop_g2s_a, - loop_shared_load_a_0, - loop_shared_load_a_scale_0, - loop_g2s_a_scale, - ], - intervals=base_intervals, - start_offsets=_clamp_offsets(len(loop_scaled_mma_1), base_offsets), - start_after_groups=[[], [], [1], [0]], - ) - - loop_B_g2v_bs = len(loop_g2v_b) + ( - len(loop_g2v_b_scale) // b_scale_shuffling_factor - ) loop_A_s2v_bs = len(loop_g2s_a) + len(loop_g2s_a_scale) - clusters = [ - tkw.cluster( - [ - loop_bitcast_a_0, - loop_bitcast_a_scale_0, - loop_bitcast_b, - loop_bitcast_b_scale, - tkw.SchedulingBarrier([]), - interleaved_mma_0, - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=loop_B_g2v_bs, ds=0), - tkw.SchedulingBarrier([]), - ], - ), - tkw.cluster( - [ - 
loop_bitcast_a_1, - loop_bitcast_a_scale_1, - tkw.SchedulingBarrier([]), - interleaved_mma_1, - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), - tkw.SchedulingBarrier([]), - ] - ), - ] - - if eliminate_epilogue: - clusters += prologue_clusters - tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - else: - epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) - epilogue_g2v_b_scale = tkw.filter_nodes( - g2v_b_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_0 = tkw.filter_nodes( - s2v_a_0, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_scale_0 = tkw.filter_nodes( - s2v_a_scale_0, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_1 = tkw.filter_nodes( - s2v_a_1, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_scale_1 = tkw.filter_nodes( - s2v_a_scale_1, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_a = tkw.filter_nodes( - bitcast_a, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_a_scale = tkw.filter_nodes( - bitcast_a_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_b = tkw.filter_nodes( - bitcast_b, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_b_scale = tkw.filter_nodes( - bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) + if unroll_kernel: + base_offsets = [0, 3, 2, 0] + base_intervals = [4, 4, 2, 4] - def split_by_iteration(nodes, key="name"): - # TODO: Replace name-based splitting with a - # pipeline_drain_iteration attribute (analogous to - # unroll_iteration). expanded_dims can't be used here because - # loop_reconstruction copies them verbatim for both drain - # iterations. 
- itr0 = [] - itr1 = [] - for node in nodes: - value = getattr(node, key) - if "1_2" in value: - itr0.append(node) - elif "2_2" in value: - itr1.append(node) - else: - raise ValueError(f"Unknown {key} for node: {value}") - return itr0, itr1 - - epilogue_mma_itr0, epilogue_mma_itr1 = split_by_iteration(epilogue_mma) - epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( - epilogue_s2v_a_1 - ) - ( - epilogue_s2v_a_scale_1_itr0, - epilogue_s2v_a_scale_1_itr1, - ) = split_by_iteration(epilogue_s2v_a_scale_1) - epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( - epilogue_bitcast_a - ) - epilogue_bitcast_a_scale_itr0, epilogue_bitcast_a_scale_itr1 = ( - split_by_iteration(epilogue_bitcast_a_scale) - ) - epilogue_bitcast_b_itr0, epilogue_bitcast_b_itr1 = split_by_iteration( - epilogue_bitcast_b - ) - epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = ( - split_by_iteration(epilogue_bitcast_b_scale) + interleaved_mma_0 = tkw.interleave_operations( + base_ops=loop_scaled_mma_0, + interleaved_ops=[ + loop_g2v_b, + loop_shared_load_a_1, + loop_shared_load_a_scale_1, + loop_g2v_b_scale, + ], + intervals=base_intervals, + start_offsets=_clamp_offsets(len(loop_scaled_mma_0), base_offsets), + start_after_groups=[[], [], [1], [0]], ) - epilogue_mma_itr0_0, epilogue_mma_itr0_1 = tkw.partition_by_dim( - epilogue_mma_itr0, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_itr0_0, epilogue_bitcast_a_itr0_1 = tkw.partition_by_dim( - epilogue_bitcast_a_itr0, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_scale_itr0_0, epilogue_bitcast_a_scale_itr0_1 = ( - tkw.partition_by_dim( - epilogue_bitcast_a_scale_itr0, dim=M, num_partitions=2 - ) + interleaved_mma_1 = tkw.interleave_operations( + base_ops=loop_scaled_mma_1, + interleaved_ops=[ + loop_g2s_a, + loop_shared_load_a_0, + loop_shared_load_a_scale_0, + loop_g2s_a_scale, + ], + intervals=base_intervals, + start_offsets=_clamp_offsets(len(loop_scaled_mma_1), base_offsets), + 
start_after_groups=[[], [], [1], [0]], ) - epilogue_mma_itr1_0, epilogue_mma_itr1_1 = tkw.partition_by_dim( - epilogue_mma_itr1, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_itr1_0, epilogue_bitcast_a_itr1_1 = tkw.partition_by_dim( - epilogue_bitcast_a_itr1, dim=M, num_partitions=2 + loop_B_g2v_bs = len(loop_g2v_b) + ( + len(loop_g2v_b_scale) // b_scale_shuffling_factor ) - epilogue_bitcast_a_scale_itr1_0, epilogue_bitcast_a_scale_itr1_1 = ( - tkw.partition_by_dim( - epilogue_bitcast_a_scale_itr1, dim=M, num_partitions=2 - ) - ) - - epilogue_clusters_itr0 = [ + clusters = [ tkw.cluster( [ - epilogue_bitcast_a_itr0_0, - epilogue_bitcast_a_scale_itr0_0, - epilogue_bitcast_b_itr0, - epilogue_bitcast_b_scale_itr0, + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b, + loop_bitcast_b_scale, + tkw.SchedulingBarrier([]), + interleaved_mma_0, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_B_g2v_bs, ds=0), tkw.SchedulingBarrier([]), - epilogue_mma_itr0_0, - epilogue_g2v_b, - epilogue_s2v_a_1_itr0, - epilogue_g2v_b_scale, - epilogue_s2v_a_scale_1_itr0, - epilogue_bitcast_a_itr0_1, - epilogue_bitcast_a_scale_itr0_1, ], ), tkw.cluster( [ - epilogue_mma_itr0_1, + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + tkw.SchedulingBarrier([]), + interleaved_mma_1, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), + tkw.SchedulingBarrier([]), + ] + ), + ] + else: + interleaved_mma_0 = tkw.interleave_operations( + base_ops=loop_scaled_mma_0, + interleaved_ops=[ + loop_shared_load_a_1, + loop_shared_load_a_scale_1, + ], + intervals=[4, 2], + start_offsets=_clamp_offsets(len(loop_scaled_mma_0), [3, 2]), + start_after_groups=[[], [0]], + ) + + interleaved_mma_1 = tkw.interleave_operations( + base_ops=loop_scaled_mma_1, + interleaved_ops=[ + loop_shared_load_a_0, + loop_shared_load_a_scale_0, + ], + intervals=[4, 2], + start_offsets=_clamp_offsets(len(loop_scaled_mma_1), [3, 2]), + start_after_groups=[[], [0]], + ) + + 
clusters = [ + tkw.cluster( + [ + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b, + loop_bitcast_b_scale, + tkw.SchedulingBarrier([]), + interleaved_mma_0, tkw.SchedulingBarrier([]), - epilogue_s2v_a_0, - epilogue_s2v_a_scale_0, ], ), tkw.cluster( [ - epilogue_bitcast_a_itr1_0, - epilogue_bitcast_a_scale_itr1_0, - epilogue_bitcast_b_itr1, - epilogue_bitcast_b_scale_itr1, + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + tkw.SchedulingBarrier([]), + interleaved_mma_1, tkw.SchedulingBarrier([]), - epilogue_mma_itr1_0, - epilogue_s2v_a_1_itr1, - epilogue_s2v_a_scale_1_itr1, ], ), tkw.cluster( [ - epilogue_bitcast_a_itr1_1, - epilogue_bitcast_a_scale_itr1_1, - epilogue_mma_itr1_1, + loop_g2v_b, + loop_g2v_b_scale, + loop_g2s_a, + loop_g2s_a_scale, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), + tkw.SchedulingBarrier([]), ], ), ] - tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) - tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - unroll_factor = 2 - tkw.unroll(pipeline_loop.KERNEL, unroll_factor) - - tkw.insert_at_start( - pipeline_loop.KERNEL, - tkw.MemoryCounterWaitBarrier(load=A_g2s_per_iter, ds=0), - ) - tkw.insert_after( - pipeline_loop.KERNEL, tkw.MemoryCounterWaitBarrier(load=0, ds=0) - ) - - return mxfp4_dbuf_schedule - - -def get_mxfp4_asymmetric_nounroll_schedule( - eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False -): - """Asymmetric-prefetch MXFP4 schedule with unroll_factor=1. - - Same 3-stage pipeline as get_mxfp4_asymmetric_schedule but without - kernel body unrolling. This keeps VGPR pressure low enough for large - tiles (e.g. 256x224x256 with wave_shape 2x2) while still using a - standard epilogue for pipeline draining. 
- """ - M = tkl.sym.M - - @wave_schedule.wave_schedule() - def mxfp4_nounroll_schedule(): - k_loop = tkw.get_node_by_tag("k_loop") - - all_read_a = tkw.get_node_by_tag("read_a") - g2s_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) - s2v_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) - - all_read_a_scale = tkw.get_node_by_tag("read_a_scale") - g2s_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.GatherToLDS) - s2v_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) - - s2v_a_0, s2v_a_1 = tkw.partition_by_dim(s2v_a, dim=M, num_partitions=2) - s2v_a_scale_0, s2v_a_scale_1 = tkw.partition_by_dim( - s2v_a_scale, dim=M, num_partitions=2 - ) - - g2v_b = tkw.get_node_by_tag("read_b") - g2v_b_scale = tkw.get_node_by_tag("read_b_scale") - - bitcast_a = tkw.get_node_by_tag("bitcast_a") - bitcast_a_scale = tkw.get_node_by_tag("bitcast_a_scale") - bitcast_b = tkw.get_node_by_tag("bitcast_b") - bitcast_b_scale = tkw.get_node_by_tag("bitcast_b_scale") - - scaled_mma = tkw.get_node_by_tag("scaled_mma") - - pipeline_loop = tkw.pipeline(k_loop, eliminate_epilogue=eliminate_epilogue) - pipeline_loop.multi_buffer_count = 2 - pipeline_loop.unroll_factor = 1 - - with pipeline_loop as pl: - pl.set_stage( - [ - (g2s_a, g2s_a_scale), - (), - (), - ], - ) - pl.set_stage( - [ - (g2v_b, g2v_b_scale), - (s2v_a_0, s2v_a_scale_0), - (), - ], - ) - pl.set_stage( - [ - (s2v_a_1, s2v_a_scale_1), - (bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale), - (scaled_mma,), - ], - ) - - if is_bscale_shuffled: - b_scale_shuffling_factor = 4 - else: - b_scale_shuffling_factor = 1 - - num_pf_iters = 2 - - # ----------------------------------------------------------------- - # Prologue - # ----------------------------------------------------------------- - prologue_g2s_a = tkw.filter_nodes(g2s_a, subgraph=pipeline_loop.PROLOGUE) - prologue_g2s_a_scale = tkw.filter_nodes( - g2s_a_scale, subgraph=pipeline_loop.PROLOGUE - ) - prologue_g2v_b = tkw.filter_nodes(g2v_b, 
subgraph=pipeline_loop.PROLOGUE) - prologue_g2v_b_scale = tkw.filter_nodes( - g2v_b_scale, subgraph=pipeline_loop.PROLOGUE - ) - prologue_s2v_a_0 = tkw.filter_nodes(s2v_a_0, subgraph=pipeline_loop.PROLOGUE) - prologue_s2v_a_scale_0 = tkw.filter_nodes( - s2v_a_scale_0, subgraph=pipeline_loop.PROLOGUE - ) - - A_g2s_total = len(prologue_g2s_a) + len(prologue_g2s_a_scale) - A_g2s_per_iter = A_g2s_total // num_pf_iters - B_g2v_prologue = len(prologue_g2v_b) + ( - len(prologue_g2v_b_scale) // b_scale_shuffling_factor - ) - - prologue_clusters = [ - tkw.cluster( - [ - prologue_g2s_a, - prologue_g2s_a_scale, - prologue_g2v_b, - tkw.SchedulingBarrier([]), - prologue_g2v_b_scale, - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=0), - tkw.SchedulingBarrier([]), - prologue_s2v_a_0, - prologue_s2v_a_scale_0, - ], - ) - ] - - # ----------------------------------------------------------------- - # Kernel (main loop body) - # ----------------------------------------------------------------- - loop_g2s_a = tkw.filter_nodes(g2s_a, subgraph=pipeline_loop.KERNEL) - loop_g2s_a_scale = tkw.filter_nodes(g2s_a_scale, subgraph=pipeline_loop.KERNEL) - loop_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.KERNEL) - loop_g2v_b_scale = tkw.filter_nodes(g2v_b_scale, subgraph=pipeline_loop.KERNEL) - loop_shared_load_a_0 = tkw.filter_nodes(s2v_a_0, subgraph=pipeline_loop.KERNEL) - loop_shared_load_a_scale_0 = tkw.filter_nodes( - s2v_a_scale_0, subgraph=pipeline_loop.KERNEL - ) - loop_shared_load_a_1 = tkw.filter_nodes(s2v_a_1, subgraph=pipeline_loop.KERNEL) - loop_shared_load_a_scale_1 = tkw.filter_nodes( - s2v_a_scale_1, subgraph=pipeline_loop.KERNEL - ) - loop_bitcast_a = tkw.filter_nodes(bitcast_a, subgraph=pipeline_loop.KERNEL) - loop_bitcast_a_scale = tkw.filter_nodes( - bitcast_a_scale, subgraph=pipeline_loop.KERNEL - ) - loop_bitcast_b = tkw.filter_nodes(bitcast_b, subgraph=pipeline_loop.KERNEL) - loop_bitcast_b_scale = tkw.filter_nodes( - bitcast_b_scale, 
subgraph=pipeline_loop.KERNEL - ) - loop_scaled_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.KERNEL) - - loop_scaled_mma_0, loop_scaled_mma_1 = tkw.partition_by_dim( - loop_scaled_mma, dim=M, num_partitions=2 - ) - loop_bitcast_a_0, loop_bitcast_a_1 = tkw.partition_by_dim( - loop_bitcast_a, dim=M, num_partitions=2 - ) - loop_bitcast_a_scale_0, loop_bitcast_a_scale_1 = tkw.partition_by_dim( - loop_bitcast_a_scale, dim=M, num_partitions=2 - ) - - interleaved_mma_0 = tkw.interleave_operations( - base_ops=loop_scaled_mma_0, - interleaved_ops=[ - loop_shared_load_a_1, - loop_shared_load_a_scale_1, - ], - intervals=[4, 2], - start_offsets=[3, 2], - start_after_groups=[[], [0]], - ) - - interleaved_mma_1 = tkw.interleave_operations( - base_ops=loop_scaled_mma_1, - interleaved_ops=[ - loop_shared_load_a_0, - loop_shared_load_a_scale_0, - ], - intervals=[4, 2], - start_offsets=[3, 2], - start_after_groups=[[], [0]], - ) - - loop_A_s2v_bs = len(loop_g2s_a) + len(loop_g2s_a_scale) - kernel_clusters = [ - tkw.cluster( - [ - loop_bitcast_a_0, - loop_bitcast_a_scale_0, - loop_bitcast_b, - loop_bitcast_b_scale, - tkw.SchedulingBarrier([]), - interleaved_mma_0, - tkw.SchedulingBarrier([]), - ], - ), - tkw.cluster( - [ - loop_bitcast_a_1, - loop_bitcast_a_scale_1, - tkw.SchedulingBarrier([]), - interleaved_mma_1, - tkw.SchedulingBarrier([]), - ], - ), - tkw.cluster( - [ - loop_g2v_b, - loop_g2v_b_scale, - loop_g2s_a, - loop_g2s_a_scale, - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=loop_A_s2v_bs, ds=0), - tkw.SchedulingBarrier([]), - ], - ), - ] - if eliminate_epilogue: - kernel_clusters += prologue_clusters - tkw.reorder_graph(pipeline_loop.KERNEL, kernel_clusters) + clusters += prologue_clusters + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) else: - # ----------------------------------------------------------------- - # Epilogue: two drain iterations for the 3-stage pipeline. 
- # Schedule drain 0's compute before drain 1's loads so both - # iterations' live registers don't overlap and exceed the - # 256 VGPR budget. - # ----------------------------------------------------------------- epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) epilogue_g2v_b_scale = tkw.filter_nodes( g2v_b_scale, subgraph=pipeline_loop.EPILOGUE @@ -2248,9 +1939,15 @@ def mxfp4_nounroll_schedule(): epilogue_bitcast_b_scale = tkw.filter_nodes( bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE ) + epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) def split_by_iteration(nodes, key="name"): + # TODO: Replace name-based splitting with a + # pipeline_drain_iteration attribute (analogous to + # unroll_iteration). expanded_dims can't be used here because + # loop_reconstruction copies them verbatim for both drain + # iterations. itr0 = [] itr1 = [] for node in nodes: @@ -2267,9 +1964,10 @@ def split_by_iteration(nodes, key="name"): epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( epilogue_s2v_a_1 ) - epilogue_s2v_a_scale_1_itr0, epilogue_s2v_a_scale_1_itr1 = ( - split_by_iteration(epilogue_s2v_a_scale_1) - ) + ( + epilogue_s2v_a_scale_1_itr0, + epilogue_s2v_a_scale_1_itr1, + ) = split_by_iteration(epilogue_s2v_a_scale_1) epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( epilogue_bitcast_a ) @@ -2307,8 +2005,7 @@ def split_by_iteration(nodes, key="name"): ) ) - epilogue_clusters = [ - # Drain iteration 0: complete compute for K-tile N-2 + epilogue_clusters_itr0 = [ tkw.cluster( [ epilogue_bitcast_a_itr0_0, @@ -2333,7 +2030,6 @@ def split_by_iteration(nodes, key="name"): epilogue_s2v_a_scale_0, ], ), - # Drain iteration 1: final compute for K-tile N-1 tkw.cluster( [ epilogue_bitcast_a_itr1_0, @@ -2356,8 +2052,10 @@ def split_by_iteration(nodes, key="name"): ] tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) - tkw.reorder_graph(pipeline_loop.KERNEL, kernel_clusters) - 
tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters) + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters_itr0) + if unroll_kernel: + tkw.unroll(pipeline_loop.KERNEL, unroll_factor) tkw.insert_at_start( pipeline_loop.KERNEL, @@ -2367,4 +2065,4 @@ def split_by_iteration(nodes, key="name"): pipeline_loop.KERNEL, tkw.MemoryCounterWaitBarrier(load=0, ds=0) ) - return mxfp4_nounroll_schedule + return mxfp4_dbuf_schedule