From 133a93c0a29d3bdffc040c66f61b6770c4454f01 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Fri, 20 Mar 2026 10:23:40 -0700 Subject: [PATCH 1/5] Waveasm + frontend fixes for functional 256x192 MXFP4 GEMM (#1159) This branch contains various fixes needed to get a functional MXFP4 GEMM through waveasm. --------- Signed-off-by: Sanket Pandit --- docs/scc-sgpr-promotion-investigation.md | 180 +++++++ examples/python/7.1_schedule.py | 15 +- examples/python/test_sympy_diff.py | 155 ++++++ .../wave/dynamic_shapes_preshuffle_mxfp4.py | 27 +- .../unittests/index_mapping_simplify_test.py | 50 +- tests/unittests/simplify_floordiv_test.py | 20 +- wave_lang/kernel/_support/indexing.py | 57 +- .../compiler/wave_codegen/read_write.py | 444 ++++++++++----- .../wave/analysis/annotate_iv_strides.py | 70 +++ .../analysis/partition_strided_operators.py | 13 +- wave_lang/kernel/wave/compile.py | 22 +- .../kernel/wave/index_mapping_simplify.py | 147 +++-- wave_lang/kernel/wave/opsel_scaled_mfma.py | 291 +++++++++- .../kernel/wave/preshuffle_scale_to_shared.py | 5 - .../schedules/gemm_mxfp4_double_buffer.py | 4 +- .../wave/scheduling/loop_reconstruction.py | 12 +- wave_lang/kernel/wave/scheduling/schedule.py | 2 - .../wave/templates/tagged_mxfp4_gemm.py | 4 + wave_lang/kernel/wave/utils/mapping_utils.py | 403 ++++++++++++++ wave_lang/kernel/wave/utils/symbol_utils.py | 179 ++++--- .../waveasm/Dialect/WaveASMInterfaces.h | 14 + .../waveasm/Dialect/WaveASMInterfaces.td | 6 + waveasm/include/waveasm/Dialect/WaveASMOps.td | 146 +++-- waveasm/include/waveasm/Transforms/Passes.td | 34 +- waveasm/include/waveasm/Transforms/RegAlloc.h | 145 ++--- .../waveasm/Transforms/TranslateFromMLIR.h | 8 +- waveasm/lib/Transforms/AssemblyEmitter.cpp | 50 +- .../BufferLoadStrengthReduction.cpp | 48 ++ waveasm/lib/Transforms/CMakeLists.txt | 9 +- waveasm/lib/Transforms/LinearScanPass.cpp | 131 +++++ waveasm/lib/Transforms/LinearScanRegAlloc.cpp | 45 +- .../lib/Transforms/LiteralMaterialization.cpp | 9 +- 
waveasm/lib/Transforms/Liveness.cpp | 67 +-- waveasm/lib/Transforms/SCCVerifier.cpp | 138 +++++ waveasm/lib/Transforms/ScopedCSE.cpp | 5 + waveasm/lib/Transforms/TranslateFromMLIR.cpp | 61 ++- waveasm/lib/Transforms/VGPRCompaction.cpp | 505 ++++++++++++++++++ .../Transforms/handlers/AMDGPUHandlers.cpp | 40 +- .../Transforms/handlers/AffineHandlers.cpp | 105 ++-- .../lib/Transforms/handlers/ArithHandlers.cpp | 309 ++++++++--- waveasm/lib/Transforms/handlers/Handlers.h | 258 +++++++++ .../Transforms/handlers/MemRefHandlers.cpp | 9 +- .../waveasm-translate/waveasm-translate.cpp | 8 + waveasm/waveasm_e2e.py | 2 + 44 files changed, 3599 insertions(+), 653 deletions(-) create mode 100644 docs/scc-sgpr-promotion-investigation.md create mode 100644 examples/python/test_sympy_diff.py mode change 100644 => 100755 wave_lang/kernel/compiler/wave_codegen/read_write.py create mode 100644 wave_lang/kernel/wave/analysis/annotate_iv_strides.py create mode 100644 waveasm/lib/Transforms/SCCVerifier.cpp create mode 100644 waveasm/lib/Transforms/VGPRCompaction.cpp diff --git a/docs/scc-sgpr-promotion-investigation.md b/docs/scc-sgpr-promotion-investigation.md new file mode 100644 index 0000000000..b868032975 --- /dev/null +++ b/docs/scc-sgpr-promotion-investigation.md @@ -0,0 +1,180 @@ +# SCC & SGPR Promotion Investigation + +## Summary + +This document captures the findings from investigating SCC (Scalar Condition Code) tracking and VGPR→SGPR promotion in the WaveASM backend. The work spans SCC infrastructure, handler-level SALU promotions, and the SRD scalar select issue (now resolved — root cause was a register allocator bug where initial SRD precolored registers were DCE'd). 
+ +## What Was Built + +### SCC Infrastructure (committed, working) + +- **SCCDef / SCCUse traits** on all SALU ops — every op now declares whether it writes or reads SCC +- **SCC verifier pass** (`--waveasm-scc-verifier`) — walks IR in emission order, catches SCC clobbers between producer and consumer +- Key insight: `Pure` in MLIR ODS = `NoMemoryEffect + AlwaysSpeculatableImplTrait`. Composing these with `SCCDef` gives identical MLIR pass behavior while enabling SCC verification. Adding bare `NativeOpTrait` to existing ODS classes changes tablegen output and alters MLIR pass behavior — must use the explicit composition. +- `S_CSELECT_B32` carries `SCCUse` trait (not Pure, not CSE-eligible — result depends on implicit SCC) + +### Handler SALU Promotions (committed, working) + +| Handler | Change | Impact | +|---------|--------|--------| +| `handleArithOrI` | `emitOr` helper (S_OR_B32 when both scalar) | No trigger in GEMM kernel | +| `handleArithXorI` | `emitXor` helper (S_XOR_B32) | No trigger | +| `handleArithDivUI` | Uses `emitLshr` (has scalar path) | No trigger | +| `handleArithRemUI` | Uses `emitAnd` (has scalar path) | No trigger | +| `handleArithMaxSI/UI` | `emitMaxI32/U32` (S_MAX_I32 when scalar) | **1 v_max → s_max** | +| `handleArithMinSI/UI` | `emitMinI32/U32` (S_MIN_I32 when scalar) | No trigger | +| `handleArithCmpI` | SGPR accepted in V_CMP (constant bus) | **2 v_mov_b32 eliminated** | +| `handleArithSelect` | SGPR cond accepted in V_CMP_NE_U32 | No trigger | +| AffineHandlers OR | `emitOr` instead of `ensureBothVGPR + V_OR_B32` | No trigger | + +### `emitScalarCmp` Helper (committed) + +Shared inline function in `Handlers.h` that emits `S_CMP_*` for any `arith::CmpIPredicate`. Used by `handleArithCmpI` and available for `AMDGPUHandlers.cpp`. 
+ +## The VGPR Pressure Problem + +### Numbers (256x192 GEMM, our kernel vs aiter reference) + +| Category | Ours | Reference | Delta | +|----------|------|-----------|-------| +| VALU address/control in loop | 21 VGPRs | 0 | **+21** | +| MFMA scale operands | 24 | 10 | +14 | +| buffer_load voffset addresses | 24 | 10 | +14 | +| Peak arch VGPRs | ~276 | ~227 | **+49** | + +The biggest single win is eliminating VALU from the loop body (+21 VGPRs). + +### Root Cause: cmpi→select→fat_raw_buffer_cast Chain + +In the MLIR, the loop body has: + +```mlir +%next_K = affine.apply (s0 + 2)[%loop_iv] // scalar +%K_bound = affine.apply (s0 ceildiv 256)[%K] // scalar +%cond = arith.cmpi slt, %next_K, %K_bound // uniform comparison +%validBytes = arith.select %cond, %bufSize, 0 // branchless guard +%srd = amdgpu.fat_raw_buffer_cast ... validBytes(%validBytes) +``` + +All values are provably uniform (scalar). The aiter reference emits this as 2 SALU ops: + +```asm +s_cmp_lt_u32 0x200, s51 ; scalar comparison +s_cselect_b32 s61, s61, 0 ; scalar select → SRD soffset +``` + +Our code emits 7 VALU ops + v_readfirstlane because: +1. `handleArithCmpI` emits `V_CMP + materializeVCCToBoolVGPR` (VGPR boolean) +2. `handleArithSelect` emits `V_CMP_NE_U32 + V_CNDMASK_B32` (VGPR result) +3. `emitSrdNumRecords` does `v_readfirstlane_b32` to extract back to SGPR + +## What Was Tried (and Failed) + +### Approach 1: CmpI Fusion in handleArithCmpI + +**Idea:** When both cmpi operands are scalar, emit `s_cmp + s_cselect(1, 0)` to produce an SGPR boolean instead of VGPR. + +**Result:** Memory fault. The SGPR boolean propagates through `arith.extui` and `arith.addi` chains, changing the type of downstream values from VGPR to SGPR. This cascading type change alters register allocation for buffer descriptors, corrupting SRD addresses. + +**Why:** The SGPR result enters the mapper and changes every downstream consumer's type decision. 
Values that were VGPR (with v_readfirstlane extraction) become SGPR, shifting the entire allocation picture. + +### Approach 2: CmpI Fusion in handleArithSelect + +**Idea:** Add a fusion path that detects `arith.select(arith.cmpi(scalar, scalar), scalar, scalar)` and emits `s_cmp + s_cselect` directly. + +**Result:** Memory fault when the select result feeds non-SRD consumers (cascading SGPR changes). With user-safety guards (only fire for specific downstream patterns), the fusion never triggers for the loop-body pattern. + +### Approach 3: Targeted SRD Scalar Select in emitSrdNumRecords + +**Idea:** Only in `emitSrdNumRecords`, detect the `arith.select(arith.cmpi)` pattern feeding `fat_raw_buffer_cast`'s `validBytes` and emit `s_cmp + s_cselect` directly into the precolored SRD word 2 (`PSRegType`). + +**Result:** Memory fault ("Write access to a read-only page"). Tried four emission variants: +- Direct `S_CSELECT_B32` into `PSRegType(srdBase+2)` with `DCEProtectOp` +- `S_CSELECT_B32` into virtual SGPR → `S_MOV_B32` copy to PSReg +- `S_CSELECT_B32` into virtual SGPR → `S_ADD_U32` copy to PSReg (clobbers SCC — eliminated as cause) +- Fresh `S_MOV_B32` copy of trueOp inside loop body before `S_CSELECT_B32` + +All produce correct-looking assembly. The SCC verifier reports no hazards. Register allocation appears correct (validBytes SGPR kept alive from prologue through loop body, not clobbered by intermediates). + +### Root Cause (RESOLVED) + +The assembly emitted by all three approaches was semantically correct. +The real bug was in the register allocator's interaction with `RawOp`-based +SRD setup: + +**Bug chain:** + +1. `emitSRDPrologue` creates a `PrecoloredSRegOp` for each initial SRD + (e.g., arg4 output buffer at s[36:39]) and fills it via `RawOp`s + (`s_mov_b64`, `s_mov_b32`). + +2. The SSA result of the `PrecoloredSRegOp` is mapped to the memref via + `mapper.mapValue()`. 
For some SRDs (data/scale buffers used inside the + loop), downstream SSA uses keep the value alive. For others (e.g., the + output buffer SRD), the only downstream references are `RawOp`s that + copy the physical registers directly (e.g., `s_mov_b64 s[80:81], s[36:37]` + in the epilogue). These `RawOp`s create no SSA uses. + +3. `CanonicalizerPass` (DCE) removes the unused `PrecoloredSRegOp`. + +4. `LinearScanPass` walks `PrecoloredSRegOp` ops to build `reservedSGPRs`. + Since the op was DCE'd, s[36:39] is NOT reserved. + +5. The linear scan allocator freely assigns s36/s37 to loop body temp + values (soffset, next-K computation), **clobbering the SRD base address**. + +6. After the loop, the epilogue `RawOp` copies garbage from s[36:37] to + s[80:81], forming a corrupted buffer descriptor → GPU fault. + +**Why it only manifested with s_cselect:** The s_cselect change referenced +the SGPR validBytes value inside the loop, extending its liveness from the +prologue through the loop body. This increased SGPR pressure forced the +linear scan allocator to reclaim s[36:39] (which it thought was free). +In the baseline, the allocator had enough free SGPRs and happened to never +assign anything to s[36:39]. + +**Fix:** Add `DCEProtectOp` after each initial SRD's `PrecoloredSRegOp` +in `emitSRDPrologue`. This prevents DCE from removing the op, so +`LinearScanPass` always reserves the physical SRD registers. + +```cpp +auto srdReg = PrecoloredSRegOp::create(builder, loc, srdType, srdBase, 4); +// ... RawOps for s_mov_b64, s_mov_b32 ... +DCEProtectOp::create(builder, loc, srdReg); // prevent DCE +``` + +**Impact:** The fix is correct for all kernels, not just the s_cselect case. +Any kernel where SGPR pressure pushes the allocator into the SRD register +range would have hit this bug. The SCC/s_cselect work merely exposed it. + +## V_CNDMASK_B32 Constant Bus Issue + +Attempted to allow one SGPR in `V_CNDMASK_B32` (VOP3 constant bus allows one SGPR + VCC). 
This caused test failures — `v_cndmask_b32` VOP3 encoding with VCC counts as one constant bus slot, and adding an SGPR as src0/src1 requires a SECOND slot, exceeding the limit on some instructions. + +## Files Modified + +| File | Changes | +|------|---------| +| `WaveASMInterfaces.h/.td` | SCCDef, SCCUse trait definitions | +| `WaveASMOps.td` | SCC-aware op classes, trait assignments | +| `SCCVerifier.cpp` | SCC verification pass | +| `ScopedCSE.cpp` | SCCUse exclusion from CSE | +| `Handlers.h` | emitOr, emitXor, emitMinI32/U32, emitMaxI32/U32, emitScalarCmp | +| `ArithHandlers.cpp` | SALU paths for OR/XOR/Min/Max/DivUI/RemUI, V_CMP constant bus | +| `AffineHandlers.cpp` | emitOr for non-overlapping Add | +| `AMDGPUHandlers.cpp` | s_cmp + s_cselect path in emitSrdNumRecords | +| `TranslateFromMLIR.cpp` | DCEProtect for initial SRD PrecoloredSRegOps | +| `Passes.td` | SCC verifier pass registration | +| `CMakeLists.txt` | SCCVerifier.cpp | +| `compile.py` | `--waveasm-scc-verifier` in pipeline | +| `waveasm_e2e.py` | Same | + +## Commits + +1. `d54817dc` — SCC verifier pass and trait infrastructure +2. `bc9726f8` — Migrate all SALU ops to SCC-aware trait classes +3. `d3e69752` — SALU paths for OR, XOR, DivUI, RemUI handlers +4. `3bcf0b85` — Eliminate unnecessary SGPR-to-VGPR coercions before V_CMP +5. `c22ce840` — emitScalarCmp helper and emitMin/emitMax SALU promotion +6. `8c65e6e0` — Move emitScalarCmp to Handlers.h +7. 
`963444f3` — SRD scalar select TODO with failure analysis diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index be229ee4da..d97102576c 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -375,7 +375,7 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( eliminate_epilogue=True, ): """Preshuffle-B MXFP4 GEMM using C++ WaveASM backend.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(2, 2), reorder_workgroups=True) options.backend = "asm" options.use_buffer_ops = True options.wave_runtime = True @@ -394,7 +394,6 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( f"MXFP GEMM preshuffle-B 4-wave (WaveASM) epilogue elimination={eliminate_epilogue} PASSED" ) - def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( is_debug=False, shape=(1024, 1024, 8192), @@ -428,12 +427,10 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256), - eliminate_epilogue=False, + eliminate_epilogue=True, ): """Preshuffle-B MXFP4 GEMM with dynamic M, N, K.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( - shape, block, wave_shape=(1, 4), reorder_workgroups=False - ) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(2, 2), reorder_workgroups=True) # Make M, N, K dynamic so the compiler does not specialize on problem size. 
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] for sym in dynamic_symbols: @@ -445,6 +442,8 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( options.wave_runtime = True options.eliminate_epilogue = eliminate_epilogue options.dump_intermediates = "build/intermediates/" + options.print_mlir_file = "gemm_mxfp4_dbuf_4wave_asymmetric.mlir" + options.print_mlir = True schedule = get_mxfp4_asymmetric_schedule( eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True ) @@ -453,9 +452,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print( - "MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!" - ) + print("MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!") if __name__ == "__main__": diff --git a/examples/python/test_sympy_diff.py b/examples/python/test_sympy_diff.py new file mode 100644 index 0000000000..a9efd6a6b2 --- /dev/null +++ b/examples/python/test_sympy_diff.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +""" +Test suite for mem_simplify — the domain-specific memory access simplifier. + +Verifies that mem_simplify handles all the algebraic identities needed for +tensor memory indexing without relying on sympy.simplify(). 
+ +Usage inside docker: + source /workspace/wave/.venv/bin/activate + python test_sympy_diff.py +""" + +import sympy +from sympy import Symbol, Mod, floor, simplify, expand, Integer + +from wave_lang.kernel.wave.utils.mapping_utils import ( + mem_simplify, + linearize_dims, + _eval_concrete_floor_mod, + _expand_mod, +) + +print(f"SymPy version: {sympy.__version__}") +print("=" * 70) + +K = Symbol("_K_div_256", integer=True, positive=True) +idx0 = Symbol("$index0", integer=True, nonnegative=True) +idx1 = Symbol("$index1", integer=True, nonnegative=True) + +PASS = 0 +FAIL = 0 + + +def check(name, got, expected): + global PASS, FAIL + ok = sympy.expand(got - expected) == 0 if not isinstance(got, bool) else got == expected + if ok: + PASS += 1 + else: + FAIL += 1 + print(f" [FAIL] {name}") + print(f" got: {got}") + print(f" expected: {expected}") + print() + + +# ── 1. The fundamental identity: a*floor(b/a) + Mod(b,a) == b ────────── +print("\n── 1. mem_simplify on floor/Mod identity ──") + +for b_val in [32, 64, 96, 128, 256]: + expr = K * floor(b_val / K) + Mod(b_val, K) + check(f"K*floor({b_val}/K)+Mod({b_val},K)", mem_simplify(expr), b_val) + +for b_val in [32, 64, 96, 128]: + expr = 8 * K * floor(b_val / K) + 8 * Mod(b_val, K) + check(f"8*K*floor({b_val}/K)+8*Mod({b_val},K)", mem_simplify(expr), 8 * b_val) + +print(f" Section 1: {PASS} passed") +p1 = PASS + +# ── 2. linearize_dims: the 2D-to-flat round-trip ─────────────────────── +print("\n── 2. 
linearize_dims ──") + +Ks = 8 * K +for flat_val in [0, 256, 512, 768, 1024]: + dim0 = floor(flat_val / Ks) + dim1 = Mod(flat_val, Ks) + result = linearize_dims([dim0, dim1], [Ks, Integer(1)]) + check(f"linearize_dims(flat={flat_val})", result, flat_val) + +E = 64 * Mod(idx0, 4) + 4 * Mod(idx1, 16) + 256 * floor(idx0 / 8) +dim0_sym = floor(E / Ks) +dim1_sym = Mod(E, Ks) +result_sym = linearize_dims([dim0_sym, dim1_sym], [Ks, Integer(1)]) +check("linearize_dims(symbolic flat)", mem_simplify(result_sym - E), 0) + +print(f" Section 2: {PASS - p1} passed") +p2 = PASS + +# ── 3. _eval_concrete_floor_mod ──────────────────────────────────────── +print("\n── 3. Concrete floor/Mod evaluation ──") + +check("floor(32/8)", _eval_concrete_floor_mod(floor(Integer(32) / 8)), Integer(4)) +check("floor(7/3)", _eval_concrete_floor_mod(floor(Integer(7) / 3)), Integer(2)) +check("Mod(256, 64)", _eval_concrete_floor_mod(Mod(256, 64)), Integer(0)) +check("Mod(10, 3)", _eval_concrete_floor_mod(Mod(10, 3)), Integer(1)) +check("Mod(0, K)", _eval_concrete_floor_mod(Mod(0, K)), Integer(0)) +check("floor(sym) unchanged", _eval_concrete_floor_mod(floor(K / 8)), floor(K / 8)) + +print(f" Section 3: {PASS - p2} passed") +p3 = PASS + +# ── 4. mem_simplify preserves already-simple expressions ─────────────── +print("\n── 4. Passthrough cases ──") + +check("integer", mem_simplify(Integer(42)), Integer(42)) +check("symbol", mem_simplify(K), K) +check("linear expr", mem_simplify(3 * K + 7), 3 * K + 7) +check("floor(sym/n) unchanged", mem_simplify(floor(K / 8)), floor(K / 8)) + +print(f" Section 4: {PASS - p3} passed") +p4 = PASS + +# ── 5. _expand_mod ───────────────────────────────────────────────────── +print("\n── 5. 
_expand_mod ──") + +check("expand_mod(Mod(32,K))", _expand_mod(Mod(32, K)), 32 - K * floor(32 / K)) +expr5 = 8 * K * floor(32 / K) + 8 * Mod(32, K) +check("expand_mod + expand cancels", expand(_expand_mod(expr5)), Integer(256)) + +print(f" Section 5: {PASS - p4} passed") +p5 = PASS + +# ── 6. The full scale-mapping simulation ─────────────────────────────── +print("\n── 6. Scale mapping addr simulation ──") + +b_scale_flat = ( + 256 * K * floor(idx1 / 32) + + 64 * Mod(idx0, 4) + + 4 * Mod(idx1, 16) + + 256 * floor(idx0 / 8) + + 2 * floor(Mod(idx0, 8) / 4) + + floor(Mod(idx1, 32) / 16) +) + +for iv_val, expected_addr in [(0, 0), (1, 256), (2, 512), (3, 768), (4, 1024)]: + dim0 = floor(b_scale_flat / Ks) + dim1 = Mod(b_scale_flat, Ks) + addr_expr = linearize_dims([dim0, dim1], [Ks, Integer(1)]) + addr_concrete = mem_simplify(addr_expr.subs({idx0: 8 * iv_val, idx1: 0})) + check(f"scale_addr(iv={iv_val})", addr_concrete, expected_addr) + +print(f" Section 6: {PASS - p5} passed") +p6 = PASS + +# ── 7. Mod auto-factoring ───────────────────────────────────────────── +print("\n── 7. Mod factoring (SymPy built-in) ──") + +for b_val in [32, 64, 128]: + m1 = Mod(8 * b_val, 8 * K) + m2 = 8 * Mod(b_val, K) + check(f"Mod({8*b_val}, 8K) == 8*Mod({b_val},K)", simplify(m1 - m2), 0) + +print(f" Section 7: {PASS - p6} passed") + +# ── Summary ───────────────────────────────────────────────────────────── +print("\n" + "=" * 70) +print(f"Results: {PASS} PASS, {FAIL} FAIL") +if FAIL: + print("*** SOME TESTS FAILED ***") +else: + print("All tests passed.") +print("=" * 70) + +exit(1 if FAIL else 0) diff --git a/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py b/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py index 8b5ed6e413..a87157a0a0 100644 --- a/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py +++ b/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py @@ -51,20 +51,17 @@ def test_dynamic_preshuffle_b_mxfp4(): # 1. 
Dynamic index arguments for M, N, K in function signature. # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index) - # 2. No scf.if guard — simplification proves the pipeline guard - # is always satisfied. - # CHECK-NOT: scf.if - - # 3. Prologue gather_to_lds prefetch. + # 2. Prologue gather_to_lds prefetch (no scf.if guard — simplification + # proves the pipeline guard is always satisfied). # CHECK: amdgpu.gather_to_lds # CHECK: amdgpu.gather_to_lds - # 4. Main pipelined loop with scaled_mfma. + # 3. Main pipelined loop with scaled_mfma. # CHECK: scf.for # CHECK: amdgpu.scaled_mfma # CHECK: scf.yield - # 5. Epilogue scaled_mfma after the loop. + # 4. Epilogue scaled_mfma after the loop. # CHECK: amdgpu.scaled_mfma @@ -98,25 +95,23 @@ def test_dynamic_preshuffle_b_mxfp4_eliminate_epilogue(): # 1. Dynamic index arguments for M, N, K in function signature. # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index) - # 2. No scf.if guard — simplification proves it always satisfied. - # CHECK-NOT: scf.if - - # 3. Pipelined loop steps by 1 (no epilogue to peel off). + # 2. Pipelined loop steps by 1 (no epilogue to peel off). + # No scf.if guard — simplification proves it always satisfied. # CHECK: scf.for %{{.*}} = %c0 to %{{.*}} step %c1 - # 4. Loop carries shared-memory buffers as iter_args (epilogue folded in). + # 3. Loop carries shared-memory buffers as iter_args (epilogue folded in). # CHECK-SAME: memref<{{.*}}, #gpu.address_space> - # 5. OOB guard: arith.select chooses real validBytes vs 0 for out-of-range + # 4. OOB guard: arith.select chooses real validBytes vs 0 for out-of-range # iterations, so the hardware returns zeros on OOB loads. # CHECK: arith.select %{{.*}}, %c2147483646_i64, %c0_i64 : i64 - # 6. fat_raw_buffer_cast uses the dynamically selected validBytes. + # 5. 
fat_raw_buffer_cast uses the dynamically selected validBytes. # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) - # 7. scaled_mfma inside the pipelined loop. + # 6. scaled_mfma inside the pipelined loop. # CHECK: amdgpu.scaled_mfma - # 8. Loop body ends; no epilogue mfma between loop end and scf.yield. + # 7. Loop body ends; no epilogue mfma between loop end and scf.yield. # CHECK: scf.yield # CHECK-NEXT: } diff --git a/tests/unittests/index_mapping_simplify_test.py b/tests/unittests/index_mapping_simplify_test.py index a943e02a0a..b11d18753a 100644 --- a/tests/unittests/index_mapping_simplify_test.py +++ b/tests/unittests/index_mapping_simplify_test.py @@ -1,4 +1,4 @@ -# Copyright 2026 The IREE Authors +# Copyright 2025 The IREE Authors # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -13,8 +13,8 @@ from wave_lang.kernel.wave.index_mapping_simplify import ( simplify_index_mapping, _get_iterator_bounds, + _expr_bounds_with_iters, ) -from wave_lang.kernel.wave.utils.symbol_utils import expr_bounds M = tkl.sym.M N = tkl.sym.N @@ -76,29 +76,16 @@ def test_no_simplification_when_bounds_unknown(self): # prove i1 < D. assert not changed - def test_no_simplification_mismatched_flat_exprs(self): - """floor(A/K) and Mod(B, K) with unrelated A, B must not pair.""" - i0 = IndexMapping.iterator(0) - i1 = IndexMapping.iterator(1) - - flat_a = i0 * K + i1 - flat_b = i0 * 3 + i1 # Different expression, same divisor K. 
- m = IndexMapping( - num_iterators=2, - inputs={M: flat_a // K, K: sympy.Mod(flat_b, K)}, - outputs={M: i0, K: i1}, - ) - - m_new, changed = simplify_index_mapping(m) - assert not changed - def test_no_simplification_b_data_preshuffle(self): """B-data preshuffle: within_nblk can exceed K_PACKED for general K.""" n_it = IndexMapping.iterator(0) k_it = IndexMapping.iterator(1) within_nblk = ( - (k_it // 32) * 512 + ((k_it // 16) % 2) * 256 + (n_it % 16) * 16 + k_it % 16 + (k_it // 32) * 512 + + ((k_it // 16) % 2) * 256 + + (n_it % 16) * 16 + + k_it % 16 ) K_PACKED = K // 2 @@ -115,31 +102,30 @@ def test_no_simplification_b_data_preshuffle(self): assert not changed -class TestExprBoundsWithSymbolBounds: +class TestExprBoundsWithIters: def test_iterator_bounds(self): i0 = IndexMapping.iterator(0) i1 = IndexMapping.iterator(1) - bounds = { - i0: (sympy.Integer(0), sympy.Integer(15)), - i1: (sympy.Integer(0), sympy.Integer(63)), - } + bounds = {i0: (sympy.Integer(0), sympy.Integer(15)), + i1: (sympy.Integer(0), sympy.Integer(63))} - assert expr_bounds(i0, bounds) == (0, 15) - assert expr_bounds(i1, bounds) == (0, 63) + assert _expr_bounds_with_iters(i0, bounds) == (0, 15) + assert _expr_bounds_with_iters(i1, bounds) == (0, 63) def test_within_nblk_bounds(self): """within_nblk for tile [0,15]x[0,63] is bounded to [0,1023].""" n_it = IndexMapping.iterator(0) k_it = IndexMapping.iterator(1) - bounds = { - n_it: (sympy.Integer(0), sympy.Integer(15)), - k_it: (sympy.Integer(0), sympy.Integer(63)), - } + bounds = {n_it: (sympy.Integer(0), sympy.Integer(15)), + k_it: (sympy.Integer(0), sympy.Integer(63))} within_nblk = ( - (k_it // 32) * 512 + ((k_it // 16) % 2) * 256 + (n_it % 16) * 16 + k_it % 16 + (k_it // 32) * 512 + + ((k_it // 16) % 2) * 256 + + (n_it % 16) * 16 + + k_it % 16 ) - result = expr_bounds(within_nblk, bounds) + result = _expr_bounds_with_iters(within_nblk, bounds) assert result is not None assert result[0] == 0 assert result[1] == 1023 diff --git 
a/tests/unittests/simplify_floordiv_test.py b/tests/unittests/simplify_floordiv_test.py index a8f7c6aa3f..c18524f36a 100644 --- a/tests/unittests/simplify_floordiv_test.py +++ b/tests/unittests/simplify_floordiv_test.py @@ -1,4 +1,4 @@ -# Copyright 2026 The IREE Authors +# Copyright 2025 The IREE Authors # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -10,7 +10,7 @@ from wave_lang.kernel.wave.utils.symbol_utils import ( simplify, - split_sum_by_divisibility, + _split_sum_by_divisibility, _is_provably_divisible, ) @@ -65,30 +65,26 @@ def test_compound_symbolic_divisor_not_divisible(self): f = sympy.floor(y / 8) assert not _is_provably_divisible(7 * f * x, 8 * f) - def test_zero_coefficient_divisor(self): - """0*D should not cause ZeroDivisionError.""" - assert not _is_provably_divisible(x, 0 * D) - -# ── split_sum_by_divisibility ─────────────────────────────────────────────── +# ── _split_sum_by_divisibility ─────────────────────────────────────────────── class TestSplitSumByDivisibility: def test_basic_split(self): - q, r = split_sum_by_divisibility(3 * D * x + y, D) + q, r = _split_sum_by_divisibility(3 * D * x + y, D) assert q == 3 * x assert r == y def test_no_divisible_terms(self): - assert split_sum_by_divisibility(x + y, D) is None + assert _split_sum_by_divisibility(x + y, D) is None def test_all_divisible(self): - q, r = split_sum_by_divisibility(6 * x + 3, sympy.Integer(3)) + q, r = _split_sum_by_divisibility(6 * x + 3, sympy.Integer(3)) assert q == 2 * x + 1 assert r == sympy.Integer(0) def test_multiple_divisible_terms(self): - q, r = split_sum_by_divisibility(D * x + D * y + z, D) + q, r = _split_sum_by_divisibility(D * x + D * y + z, D) assert simplify(q - (x + y)) == 0 assert r == z @@ -116,7 +112,7 @@ def test_floordiv_bounded_remainder(self): def test_mod_basic(self): expr = sympy.Mod(3 * D * x + y, D) result = 
simplify(expr) - assert sympy.simplify(result - sympy.Mod(y, D)) == 0 + assert result == sympy.Mod(y, D, evaluate=False) or result == sympy.Mod(y, D) def test_mod_all_divisible(self): expr = sympy.Mod(D * x + D * y, D) diff --git a/wave_lang/kernel/_support/indexing.py b/wave_lang/kernel/_support/indexing.py index b740a6ff4c..61a02bacca 100644 --- a/wave_lang/kernel/_support/indexing.py +++ b/wave_lang/kernel/_support/indexing.py @@ -83,6 +83,61 @@ def subs_idxc( return IndexingContext.current().subs_expr(input) +def _resolve_chained_subs( + subs: dict[IndexSymbol, int | IndexSymbol], +) -> dict[IndexSymbol, int | IndexSymbol]: + """Resolve chained dependencies in a substitution dictionary. + + When a substitution dict has ``{K: 8192, K_SCALE: K // 32}``, a single + simultaneous substitution pass replaces ``K_SCALE`` with ``K // 32`` + but leaves the ``K`` inside the replacement unresolved. + + This function processes entries in topological (dependency) order: + entries whose values don't reference other keys are resolved first, + then their concrete values are substituted into the remaining entries. + + Only the values are updated; the keys (symbols being defined) are + never modified. 
+ """ + all_keys = set(subs.keys()) + resolved: dict = {} + pending = dict(subs) + + while pending: + ready = [] + for key, val in pending.items(): + if not isinstance(val, sympy.Basic): + ready.append(key) + continue + deps = (val.free_symbols & all_keys) - {key} - set(resolved.keys()) + if not deps: + ready.append(key) + + if not ready: + break + + for key in ready: + val = pending.pop(key) + if isinstance(val, sympy.Basic) and resolved: + val = piecewise_aware_subs(val, resolved) + resolved[key] = val + + if pending: + import warnings + cycle_keys = sorted(str(k) for k in pending.keys()) + warnings.warn( + f"_resolve_chained_subs: circular dependency among" + f" {cycle_keys} — substitution may be incomplete", + stacklevel=2, + ) + for key, val in pending.items(): + if isinstance(val, sympy.Basic) and resolved: + val = piecewise_aware_subs(val, resolved) + resolved[key] = val + + return resolved + + def is_literal(input: IndexSymbol | int) -> bool: """ Check if input is a literal number value. 
@@ -153,7 +208,7 @@ def __str__(self): ) def set_subs(self, subs: dict[IndexSymbol, int | IndexSymbol]): - self.subs = copy.deepcopy(subs) + self.subs = _resolve_chained_subs(copy.deepcopy(subs)) self.cached_subs = {} def subs_expr(self, expr: IndexExpr) -> IndexExpr: diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py old mode 100644 new mode 100755 index 76386a8793..ad657294fd --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -61,8 +61,13 @@ MemoryAccessFlags, ) from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index -from ...wave.utils.mapping_utils import transform_index_on_mapping -from ...wave.utils.symbol_utils import safe_subs, simplify +from ...wave.utils.mapping_utils import ( + linearize_dims, + mem_simplify, + transform_index_on_mapping, +) +from ...wave.assumptions import get_divisibility_subs +from ...wave.utils.symbol_utils import safe_subs, simplify, extract_iv from .emitter import ( WaveEmitter, add_emitter_subs, @@ -429,18 +434,36 @@ def _compute_branchless_valid_bytes( We emit: cond = gen_sympy_index(guard_condition) # index type, nonzero=true - real_valid = compute_static_validBytes() + real_valid = actual_buffer_size_bytes # NOT 0x7FFFFFFE validBytes = select(cond != 0, real_valid, 0) When the condition is false (last iterations), validBytes=0 makes the SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op. + + When the condition is true, NUM_RECORDS=real buffer size so the + hardware clamps any OOB flat addresses to return 0 — no per-load + software bounds checking needed. 
""" uint64 = IntegerType.get_signless(64) total_bytes = _compute_total_valid_bytes( elem_type, symbolic_shape, use_real_bounds=True ) - real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) + hw_max = _valid_bytes_buffer(elem_type) + if total_bytes == hw_max and symbolic_shape is not None: + # Static path returned the hardware-max fallback — shape is dynamic. + # Compute actual buffer size at runtime so num_records provides + # real bounds clamping (matching AITER's approach). + elem_bytes = _elem_bytes(elem_type) + total_bytes_expr = sympy.prod(s for s in symbolic_shape) * elem_bytes + subs_map = add_emitter_subs(emitter) + real_valid_index = gen_sympy_index(subs_map, total_bytes_expr) + real_valid = arith_d.index_cast(uint64, real_valid_index) + else: + real_valid = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) + zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) cond_val = gen_sympy_index(add_emitter_subs(emitter), guard_condition) @@ -470,15 +493,24 @@ def _compute_valid_bytes( uint64 = IntegerType.get_signless(64) if use_real_bounds: + hw_max = _valid_bytes_buffer(elem_type) + if total_bytes == hw_max and symbolic_shape is not None: + elem_bytes = _elem_bytes(elem_type) + total_bytes_expr = sympy.prod(s for s in symbolic_shape) * elem_bytes + subs_map = add_emitter_subs(emitter) + real_valid_index = gen_sympy_index(subs_map, total_bytes_expr) + total_val = arith_d.index_cast(uint64, real_valid_index) + else: + total_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) metadata = memref_d.extract_strided_metadata(ptr) offset_elements = metadata[1] offset_bytes = arith_d.index_cast(uint64, offset_elements) - # Use _elem_bytes to avoid zero offset_bytes for sub-byte types (e.g. mxfp4). 
elem_bytes_val = arith_d.constant( uint64, get_constant_attr(_elem_bytes(elem_type), uint64) ) offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) - total_val = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) return arith_d.subi(total_val, offset_bytes) return arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) @@ -771,12 +803,73 @@ def extract(vec, ind): return +def _cancel_floordiv_mod_linearize( + dim_exprs: list[sympy.Expr], + strides: list[sympy.Expr], +) -> sympy.Expr: + """Compute ``sum(e_i * s_i)`` while cancelling floor/Mod pairs. + + Delegates to :func:`linearize_dims` which expands ``Mod(x, d)`` + into ``x - d*floor(x/d)`` so that the matching ``floor`` terms + cancel algebraically under ``expand()``. + """ + return linearize_dims(dim_exprs, strides) + + +def _emit_cycle_offset( + cycle: list[IndexExpr], + iv_mlir: Value, + subs_map: dict, + overflow_flags, +) -> Value: + """Emit MLIR for cycle-based IV offset. + + For a repeating stride cycle [s0, s1, ...] of length N: + offset = (IV // N) * macro_stride + cumulative[IV % N] + where macro_stride = sum(cycle) and cumulative = prefix sums of cycle. 
+ """ + n = len(cycle) + macro_stride = sum(cycle) + cumulative = [sympy.Integer(0)] + for s in cycle[:-1]: + cumulative.append(cumulative[-1] + s) + + idx_ty = IndexType.get() + + is_pow2 = (n & (n - 1)) == 0 and n > 0 + n_val = arith_d.constant(idx_ty, n) + + if is_pow2: + log2n = int(math.log2(n)) + shift = arith_d.constant(idx_ty, log2n) + iv_div = arith_d.shrui(iv_mlir, shift) + mask_val = arith_d.constant(idx_ty, n - 1) + iv_mod = arith_d.andi(iv_mlir, mask_val) + else: + iv_div = arith_d.divui(iv_mlir, n_val) + iv_mod = arith_d.remui(iv_mlir, n_val) + + macro_val = gen_sympy_index(subs_map, macro_stride) + macro_term = arith_d.muli(iv_div, macro_val, overflow_flags=overflow_flags) + + cum_offset = arith_d.constant(idx_ty, 0) + for i in range(n - 1, -1, -1): + c_val = gen_sympy_index(subs_map, cumulative[i]) + i_val = arith_d.constant(idx_ty, i) + cmp = arith_d.cmpi(arith_d.CmpIPredicate.eq, iv_mod, i_val) + cum_offset = arith_d.select(cmp, c_val, cum_offset) + + return arith_d.addi(macro_term, cum_offset, overflow_flags=overflow_flags) + + def _try_iv_split_offset( emitter: WaveEmitter, index: dict[IndexExpr, IndexSequence | IndexExpr], - strides: list[int], + strides: list[int | IndexExpr], dynamic_vals: dict[IndexExpr, Any], - use_subs_idxc: bool = False, + use_subs_idxc: bool = True, + precomputed_iv_stride: dict[sympy.Symbol, IndexExpr | list[IndexExpr]] | None = None, + **kwargs, ) -> Optional[Value]: """Compute a hoisted IV-split linearized offset for a loop-carried read. @@ -784,14 +877,11 @@ def _try_iv_split_offset( expressions are provably affine in the loop IV, or ``None`` to fall back to the default address path. - The caller is responsible for emitting the actual load/gather using the - returned offset. + When *precomputed_iv_stride* is supplied (from + ``compute_iv_stride_through_mapping``), the IV stride is known from the + pre-mapping index and the extraction phase is skipped entirely. 
- Parameters - ---------- - strides : per-dimension integer strides for linearisation. - use_subs_idxc : if True, apply ``subs_idxc`` before simplification - (needed when expressions contain residual shape symbols). + Otherwise falls back to the original Phase 1 / Phase 1b extraction. """ ip = InsertionPoint.current owner = ip.block.owner @@ -803,7 +893,6 @@ def _try_iv_split_offset( # Find the IV symbol for this scf.for directly from its block argument. current_iv = owner.induction_variable - # do a reverse lookup of the dimension/symbol that the current IV is associated with dim = next((d for d, v in emitter.induction_vars.items() if v == current_iv), None) if dim is None: return None @@ -826,63 +915,151 @@ def _try_iv_split_offset( if len(start_exprs) != len(strides): return None - # Phase 1: Symbolic linearity proof w.r.t. the current loop's IV only. - # substitute IV = step*_j and check - # that the linearized index is c*_j + remainder (no _j in remainder). + sym_strides = [sympy.sympify(s) for s in strides] + + # ------------------------------------------------------------------ + # Fast path: pre-computed IV stride from mapping analysis. 
+ # ------------------------------------------------------------------ + has_iv = any(iv_sym in sympy.sympify(e).free_symbols for e in start_exprs) + if not has_iv: + return None + if precomputed_iv_stride and iv_sym in precomputed_iv_stride: + k_stride_per_iv = precomputed_iv_stride[iv_sym] + + base_start_exprs = [safe_subs(e, {iv_sym: 0}) for e in start_exprs] + + hoist_ip = InsertionPoint(owner) + subs_map = add_emitter_subs(emitter, dynamic_vals) + overflow_flags = arith_d.IntegerOverflowFlags.nsw + + with hoist_ip: + lin_offset = None + for base_expr, stride in zip(base_start_exprs, sym_strides): + val = gen_sympy_index(subs_map, base_expr) + stride_val = gen_sympy_index(subs_map, stride) + term = arith_d.muli(val, stride_val, overflow_flags=overflow_flags) + lin_offset = ( + term + if lin_offset is None + else arith_d.addi( + lin_offset, term, overflow_flags=overflow_flags + ) + ) + + iv_mlir = subs_map.get(iv_sym) + if iv_mlir is None: + return None + + if isinstance(k_stride_per_iv, list): + iv_offset = _emit_cycle_offset( + k_stride_per_iv, iv_mlir, subs_map, overflow_flags + ) + else: + k_stride_val = gen_sympy_index(subs_map, k_stride_per_iv) + iv_offset = arith_d.muli( + iv_mlir, k_stride_val, overflow_flags=overflow_flags + ) + + total = arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) + return total + + # ------------------------------------------------------------------ + # Original extraction path (Phase 1 / Phase 1b). 
+ # ------------------------------------------------------------------ + div_fwd, div_bwd = get_divisibility_subs(emitter.constraints) + _j = sympy.Symbol("_j", integer=True, nonnegative=True) iv_as_j = step_int * _j - lin_sym = sympy.Integer(0) - for expr, ps in zip(start_exprs, strides): + + dims = list(index.keys()) + + dim_exprs = [] + for i, (expr, stride) in enumerate(zip(start_exprs, sym_strides)): e = safe_subs(expr, {iv_sym: iv_as_j}) if use_subs_idxc: e = subs_idxc(e) - e = simplify(e) - lin_sym += e * ps - lin_sym = simplify(lin_sym) - - coeff = lin_sym.coeff(_j) - remainder = simplify(lin_sym - coeff * _j) - if not coeff.is_Integer or coeff == 0 or _j in remainder.free_symbols: - return None - k_stride_per_iv, rem = divmod(int(coeff), step_int) - if rem != 0: + if div_fwd: + e = safe_subs(e, div_fwd) + e = mem_simplify(e) + dim_exprs.append(e) + + # Phase 1: per-dimension extract. + iv_stride_sym = sympy.Integer(0) + base_exprs = [] + split_first_ok = True + + for i, (e, stride) in enumerate(zip(dim_exprs, sym_strides)): + result = extract_iv(e, _j) + if result is None: + split_first_ok = False + break + j_coeff, remainder = result + + if div_bwd: + j_coeff = safe_subs(j_coeff, div_bwd) + remainder = safe_subs(remainder, div_bwd) + + iv_stride_sym += simplify(mem_simplify(j_coeff * stride)) + base_exprs.append(remainder) + + # Phase 1b: linearize-first fallback. 
+ if not split_first_ok: + fwd_strides = [] + for s in sym_strides: + fs = safe_subs(s, div_fwd) if div_fwd else s + fwd_strides.append(fs) + + lin_sym = _cancel_floordiv_mod_linearize(dim_exprs, fwd_strides) + lin_sym = mem_simplify(lin_sym) + + result = extract_iv(lin_sym, _j) + if result is None: + return None + j_coeff_lin, base_lin = result + + if div_bwd: + j_coeff_lin = safe_subs(j_coeff_lin, div_bwd) + base_lin = safe_subs(base_lin, div_bwd) + + iv_stride_sym = simplify(mem_simplify(j_coeff_lin)) + base_exprs = None + base_lin_expr = base_lin + + if iv_stride_sym == 0: return None - # Phase 2: Substitute IV=0 to get the loop-invariant base offset. - iv_zero_subs = {iv_sym: 0} - index_no_iv = {} - for dim, seq in index.items(): - start = _get_start_index(seq) - new_start = safe_subs(start, iv_zero_subs) - if isinstance(seq, IndexSequence): - index_no_iv[dim] = IndexSequence(new_start, seq.size) - else: - index_no_iv[dim] = new_start + if iv_stride_sym.is_Integer: + k_stride_per_iv_int, rem = divmod(int(iv_stride_sym), step_int) + if rem != 0: + return None + k_stride_per_iv = sympy.Integer(k_stride_per_iv_int) + else: + k_stride_per_iv = simplify(mem_simplify(iv_stride_sym / step_int)) - # Emit the hoisted linearized offset BEFORE the scf.for. + # Emit MLIR. 
hoist_ip = InsertionPoint(owner) subs_map = add_emitter_subs(emitter, dynamic_vals) overflow_flags = arith_d.IntegerOverflowFlags.nsw with hoist_ip: - iv0_exprs = _get_start_indices(index_no_iv) - lin_offset = None - for expr, ps in zip(iv0_exprs, strides): - val = gen_sympy_index(subs_map, expr) - stride_c = arith_d.constant(IndexType.get(), ps) - term = arith_d.muli(val, stride_c, overflow_flags=overflow_flags) - lin_offset = ( - term - if lin_offset is None - else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) - ) + if base_exprs is not None: + lin_offset = None + for base_expr, stride in zip(base_exprs, sym_strides): + val = gen_sympy_index(subs_map, base_expr) + stride_val = gen_sympy_index(subs_map, stride) + term = arith_d.muli(val, stride_val, overflow_flags=overflow_flags) + lin_offset = ( + term + if lin_offset is None + else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) + ) + else: + lin_offset = gen_sympy_index(subs_map, base_lin_expr) + k_stride_val = gen_sympy_index(subs_map, k_stride_per_iv) - # Back inside the loop: total = hoisted_base + IV * k_stride. iv_mlir = subs_map.get(iv_sym) if iv_mlir is None: return None - - k_stride_val = arith_d.constant(IndexType.get(), k_stride_per_iv) iv_offset = arith_d.muli(iv_mlir, k_stride_val, overflow_flags=overflow_flags) return arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) @@ -961,9 +1138,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_global_mem = kb_src.type.memory_space is None buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem - # Set by _emit_wide_read in partition_strided_operators.py when merging - # reads with per-element bounds. Buffer-ops rely on SRD bounds checking - # instead, so the precomputed mask is only emitted for non-buffer-ops. 
+ iv_stride_from_mapping = node.meta.get("iv_stride", None) precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None) if precomputed_mask_expr is not None and not buffer_ops_enabled: mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr) @@ -996,7 +1171,6 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): # IV-split fast path for global reads: hoist address before the loop. if ( is_global - and mask is None and not use_llvm_load and not read_meets_hw_transpose_requirements( get_custom(node), emitter.constraints, emitter.options.target @@ -1005,48 +1179,66 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): kb_type = MemRefType(kb_src.type) phys_strides, _ = kb_type.get_strides_and_offset() dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() - if not any(s == dyn_sentinel for s in phys_strides): - total_offset = _try_iv_split_offset( - emitter, - index, - list(phys_strides), - dynamic_vals_map_start, + if any(s == dyn_sentinel for s in phys_strides): + iv_strides = list( + strides_from_symbolic_shape( + IndexingContext.current(), + input_shape, + allow_mixed_shapes=True, + ) ) - if total_offset is not None: - ip = InsertionPoint.current - owner = ip.block.owner - hoist_ip = InsertionPoint(owner) - with hoist_ip: - strides_vals = [ - arith_d.constant(IndexType.get(), s) for s in phys_strides - ] - zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( - phys_strides + else: + iv_strides = [sympy.Integer(s) for s in phys_strides] + total_offset = _try_iv_split_offset( + emitter, + index, + iv_strides, + dynamic_vals_map_start, + precomputed_iv_stride=iv_stride_from_mapping, + ) + if total_offset is not None: + ip = InsertionPoint.current + owner = ip.block.owner + hoist_ip = InsertionPoint(owner) + subs_map = add_emitter_subs(emitter, dynamic_vals_map_start) + with hoist_ip: + strides_vals = [gen_sympy_index(subs_map, s) for s in iv_strides] + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( + iv_strides + 
) + lin_src, _ = _linearize_memref( + kb_src, zero_indices, zero_indices, strides_vals + ) + + # With epilogue elimination the loop runs extra iterations + # whose offsets can exceed the actual buffer. Wrap the + # linearised memref in a fat_raw_buffer_cast so that the + # SRD's NUM_RECORDS = real buffer size and OOB loads safely + # return zero instead of faulting. + if buffer_ops_enabled and emitter.options.eliminate_epilogue: + valid_bytes = _compute_valid_bytes( + lin_src, + element_type, + input_shape, + emitter, ) - lin_src, _ = _linearize_memref( - kb_src, zero_indices, zero_indices, strides_vals + lin_src = _cast_buffer_and_encode_stride( + lin_src, + strides_vals, + element_type, + valid_bytes, ) - # With epilogue elimination the loop runs extra iterations - # whose offsets can exceed the actual buffer. Wrap the - # linearised memref in a fat_raw_buffer_cast so that the - # SRD's NUM_RECORDS = real buffer size and OOB loads safely - # return zero instead of faulting. - if buffer_ops_enabled and emitter.options.eliminate_epilogue: - valid_bytes = _compute_valid_bytes( - lin_src, - element_type, - input_shape, - emitter, - ) - lin_src = _cast_buffer_and_encode_stride( - lin_src, - strides_vals, - element_type, - valid_bytes, - ) + if mask is None: result = vector_d.load(vector_type, lin_src, [total_offset]) - emitter.bind_node_proxy(node, IRProxyValue(result)) - return + else: + element_type = vector_type.element_type + zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) + passthru = vector_d.broadcast(vector_type, zero) + result = vector_d.maskedload( + vector_type, lin_src, [total_offset], mask, passthru + ) + emitter.bind_node_proxy(node, IRProxyValue(result)) + return start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, index, dynamic_vals_map_start @@ -1383,6 +1575,7 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src_dynamic_vals_map_start = {} dst_dynamic_vals_map_start = {} + 
iv_stride_from_mapping = node.meta.get("iv_stride", None) if src_mapping: dyn_vals = tuple( cast_vector(emitter, reg, element_type=IndexType.get()) @@ -1431,21 +1624,14 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start) strides = [gen_sympy_index(subs_map, s) for s in sym_stride_vals] - # IV-split: try hoisting the src offset before the loop. - try: - sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] - except (TypeError, ValueError): - sym_strides_int = [] - - src_offset = None - if sym_strides_int: - src_offset = _try_iv_split_offset( - emitter, - new_src_idx, - sym_strides_int, - src_dynamic_vals_map_start, - use_subs_idxc=True, - ) + src_offset = _try_iv_split_offset( + emitter, + new_src_idx, + list(sym_stride_vals), + src_dynamic_vals_map_start, + use_subs_idxc=True, + precomputed_iv_stride=iv_stride_from_mapping, + ) if src_offset is not None: # IV-split path: offset=0 reinterpret_cast, full address in src_offset. @@ -1480,18 +1666,26 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): ), ) - mask = _build_mask( - emitter, - src_idx, - elements_per_thread=1, - bounds=src_bounds, - dynamic_values=src_dynamic_vals_map_start, - ) - if mask: - mask = vector_d.extract(mask, static_position=[0], dynamic_position=[]) - oob_index_value = _get_out_of_bounds_index(element_type) - oob_index = arith_d.constant(IndexType.get(), oob_index_value) - src_offset = arith_d.select(mask, src_offset, oob_index) + # When the SRD validBytes encodes the real buffer size (via + # _compute_branchless_valid_bytes with dynamic shape), hardware + # num_records clamping returns 0 for OOB addresses. Skip the + # per-load software mask to eliminate ~4 VALU instructions and + # temporary VGPRs per gather_to_lds call. 
+ if valid_bytes_override is None: + mask = _build_mask( + emitter, + src_idx, + elements_per_thread=1, + bounds=src_bounds, + dynamic_values=src_dynamic_vals_map_start, + ) + if mask: + mask = vector_d.extract( + mask, static_position=[0], dynamic_position=[] + ) + oob_index_value = _get_out_of_bounds_index(element_type) + oob_index = arith_d.constant(IndexType.get(), oob_index_value) + src_offset = arith_d.select(mask, src_offset, oob_index) amdgpu_d.gather_to_lds( src=lin_src, diff --git a/wave_lang/kernel/wave/analysis/annotate_iv_strides.py b/wave_lang/kernel/wave/analysis/annotate_iv_strides.py new file mode 100644 index 0000000000..bcf6ab1a82 --- /dev/null +++ b/wave_lang/kernel/wave/analysis/annotate_iv_strides.py @@ -0,0 +1,70 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Annotate Read and GatherToLDS nodes with pre-computed IV strides. + +Walks all Read and GatherToLDS nodes that have a non-identity mapping and +computes the linearized IV stride through the mapping via numerical probing. +The result is stored in ``node.meta["iv_stride"]`` so that codegen can use +the PRECOMPUTED fast path in ``_try_iv_split_offset`` without performing +any symbolic analysis at MLIR emission time. + +This pass should run after ``generate_bounds_exprs`` (all indices are final) +and before ``merge_contiguous_reads`` (which drops mappings on merged reads +but propagates ``meta["iv_stride"]`` from the anchor). 
+""" + +from collections.abc import Sequence + +from ..._support.indexing import IndexingContext +from ..._support.tracing import CapturedTrace +from ...compiler.utils import strides_from_symbolic_shape +from ...ops.wave_ops import Read, GatherToLDS, get_custom +from ..constraints import Constraint +from ..utils.mapping_utils import compute_iv_stride_through_mapping + + +def annotate_iv_strides( + trace: CapturedTrace, + constraints: Sequence[Constraint] = (), +): + """Annotate every mapped Read/GatherToLDS with ``meta["iv_stride"]``.""" + idxc = IndexingContext.current() + + for node in trace.walk( + lambda n: isinstance(get_custom(n), (Read, GatherToLDS)) + ): + if node.meta.get("iv_stride") is not None: + continue + + custom = get_custom(node) + + if isinstance(custom, GatherToLDS): + mapping = custom.src_mapping + index = custom.src_index + mem_node = custom.src + else: + mapping = custom.mapping + index = custom.index + mem_node = custom.memory + + if mapping is None: + continue + if hasattr(custom, "has_identity_mapping") and custom.has_identity_mapping(): + continue + + symbolic_shape = custom.type.symbolic_shape + mem_sym_shape = get_custom(mem_node).type.symbolic_shape + phys_strides = strides_from_symbolic_shape( + idxc, mem_sym_shape, allow_mixed_shapes=True + ) + + iv_stride = compute_iv_stride_through_mapping( + mapping, symbolic_shape, index, + is_read=True, mem_strides=list(phys_strides), + constraints=constraints, + ) + if iv_stride is not None: + node.meta["iv_stride"] = iv_stride diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 4101739042..a7a6f1f314 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -32,7 +32,6 @@ Write, get_custom, ) -from ..region_canonicalization import RegionFormat, requires_region_format from ..assumptions import get_divisibility_subs 
from ..constraints import Constraint from ..utils.mapping_utils import transform_index_on_mapping @@ -103,7 +102,6 @@ def _get_symbolic_shape_and_vector_shapes( return register_shape, vector_shapes -@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def partition_strided_operators(trace: CapturedTrace, constraints: list[Constraint]): """ This function analyzes the index sequence of operators in the graph @@ -249,7 +247,6 @@ def check_contiguous_index(): custom.graph.erase_node(operator) -@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def partition_ops_with_gpr_offsets(trace: CapturedTrace, constraints: list[Constraint]): """ This function analyzes the index sequence of reads and writes in a graph. @@ -434,7 +431,6 @@ def has_gpr_offsets(node: fx.Node) -> bool: custom.graph.erase_node(custom.fx_node) -@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def merge_contiguous_reads( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -805,6 +801,9 @@ def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source, mask_expr=N wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) if mask_expr is not None: wide_read.precomputed_mask_expr = mask_expr + anchor_iv_stride = anchor_custom.fx_node.meta.get("iv_stride") + if anchor_iv_stride is not None: + wide_read.meta["iv_stride"] = anchor_iv_stride propagate_tag(tag_source, wide_read) return wide_read @@ -1735,7 +1734,6 @@ def _merge_contiguous_reads_once( return merged_any -@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def partition_gather_like_ops( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -1921,7 +1919,6 @@ def _simplify_mapping( ) -@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def simplify_indices(trace: CapturedTrace, constraints: Sequence[Constraint] = ()): """Pre-simplify index expressions on all ops. 
@@ -1950,8 +1947,8 @@ def simplify_indices(trace: CapturedTrace, constraints: Sequence[Constraint] = ( custom.mapping = new_mapping # Try to eliminate flat//D and flat%D patterns using # tile-level iterator bounds from the node's index. - # Re-read custom from node because setting custom.mapping - # above writes to the IR node, invalidating the wrapper. + # Re-read from node to get the persisted mapping (not the + # transient wrapper modified by _simplify_mapping above). custom = get_custom(node) if custom.mapping is not None: try: diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 61b4f498aa..70c9abfb8f 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -37,6 +37,7 @@ set_node_indices_water_checked, set_post_expansion_indices, ) +from .analysis.annotate_iv_strides import annotate_iv_strides from .analysis.partition_strided_operators import ( merge_contiguous_reads, partition_gather_like_ops, @@ -606,6 +607,7 @@ def build_graph_passes( launchable.constraints, launchable.reordering_constraints, ), + partial(annotate_iv_strides, trace, launchable.constraints), partial( merge_contiguous_reads, trace, @@ -1318,6 +1320,10 @@ def _generate_asm_code(mb, options): mlir_file.write(kernel_mlir) mlir_path = mlir_file.name + # Debug: save a copy of the MLIR input to waveasm-translate + import shutil + shutil.copy(mlir_path, "/tmp/waveasm_input.mlir") + try: base_passes = [ "--mlir-cse", @@ -1344,14 +1350,16 @@ def _generate_asm_code(mb, options): # TODO: improve Ticketing logic (better latency-covering heuristics, # smarter coalescing) so ticketed waitcnt can be always-on without # a performance hit, removing this wave-shape conditional. 
- use_ticketed_waitcnt = waves_in_m >= 2 and waves_in_n >= 2 + use_ticketed_waitcnt = False waitcnt_flag = ( "--waveasm-insert-waitcnt" if use_ticketed_waitcnt else "--waveasm-insert-waitcnt=ticketed-waitcnt=false" ) tail_passes = [ + "--waveasm-scc-verifier", "--waveasm-linear-scan=max-vgprs=512 max-agprs=512", + "--waveasm-vgpr-compaction", waitcnt_flag, f"--waveasm-hazard-mitigation=target={options.target}", "--emit-assembly", @@ -1368,7 +1376,17 @@ def _run_translate(extra_passes): + extra_passes + tail_passes ) - return subprocess.run(full_cmd, capture_output=True, text=True, timeout=60) + # If WAVEASM_DUMP_IR is set, add --mlir-print-ir-after-all and + # save stderr to the specified file for debugging. + ir_dump_path = os.environ.get("WAVEASM_DUMP_IR") + if ir_dump_path: + full_cmd.append("--mlir-print-ir-after-all") + result = subprocess.run(full_cmd, capture_output=True, text=True, timeout=120) + if ir_dump_path and result.stderr: + os.makedirs(os.path.dirname(ir_dump_path) or ".", exist_ok=True) + with open(ir_dump_path, "w") as f: + f.write(result.stderr) + return result import re diff --git a/wave_lang/kernel/wave/index_mapping_simplify.py b/wave_lang/kernel/wave/index_mapping_simplify.py index 957ea18549..ee2335a499 100644 --- a/wave_lang/kernel/wave/index_mapping_simplify.py +++ b/wave_lang/kernel/wave/index_mapping_simplify.py @@ -1,4 +1,4 @@ -# Copyright 2026 The IREE Authors +# Copyright 2025 The IREE Authors # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -23,14 +23,14 @@ import sympy from collections.abc import Sequence +from functools import lru_cache from ..lang.wave_types import IndexMapping from .utils.symbol_utils import ( + _split_sum_by_divisibility, expr_bounds, - split_sum_by_divisibility, IndexExpr, IndexSymbol, - simplify, subs_idxc, ) @@ -81,6 +81,13 @@ def _get_iterator_bounds( Returns {iterator_symbol: (0, upper_bound - 1)} for each iterator. """ bounds = {} + # Build reverse map: dim_symbol -> iterator_symbol. + dim_to_iter = {} + for sym, idx in mapping.iters.items(): + dim = mapping.iteration_shape[idx] + if dim is not None: + dim_to_iter[dim] = sym + for sym, idx in mapping.iters.items(): dim = mapping.iteration_shape[idx] if dim is None: @@ -98,12 +105,12 @@ def get_tile_sizes_from_index( mapping: IndexMapping, index: dict, ) -> dict[IndexSymbol, IndexExpr]: - """Extract tile sizes for each iteration dimension from the node's index. + """Extract tile sizes for each iterator dimension from the node's index. The node's index dict maps dimension symbols (M, N, K) to - IndexSequences with start/size/stride. The mapping's iteration_shape - lists which dimension each iterator spans. The size of the - corresponding IndexSequence is the tile size for that dimension. + IndexSequences with start/size/stride. The mapping's output_mapping + tells us which dimension each iterator corresponds to. The size of + that IndexSequence is the tile size for that iterator. """ from .utils.symbol_utils import IndexSequence @@ -119,21 +126,75 @@ def get_tile_sizes_from_index( return tile_sizes +def _expr_bounds_with_iters( + expr: sympy.Expr, + iter_bounds: dict[sympy.Symbol, tuple[sympy.Expr, sympy.Expr]], +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Compute expression bounds using iterator upper bounds. + + Extends expr_bounds by substituting known iterator ranges. 
+ """ + if expr.is_Integer or expr.is_Rational: + return (expr, expr) + if expr.is_Symbol: + if expr in iter_bounds: + return iter_bounds[expr] + return (sympy.Integer(0), sympy.oo) if expr.is_nonnegative else None + + # For floor/Mod/Add/Mul, delegate to structural recursion. + if isinstance(expr, sympy.Mod): + p, q = expr.args + if q.is_positive and q.is_number: + return (sympy.Integer(0), q - 1) + q_bounds = _expr_bounds_with_iters(q, iter_bounds) + if q_bounds and q_bounds[0].is_positive: + return (sympy.Integer(0), q_bounds[1] - 1) + return None + + if isinstance(expr, sympy.floor): + inner_bounds = _expr_bounds_with_iters(expr.args[0], iter_bounds) + if inner_bounds: + return (sympy.floor(inner_bounds[0]), sympy.floor(inner_bounds[1])) + return None + + if isinstance(expr, sympy.Add): + bounds = [_expr_bounds_with_iters(a, iter_bounds) for a in expr.args] + if all(b is not None for b in bounds): + return (sum(b[0] for b in bounds), sum(b[1] for b in bounds)) + return None + + if isinstance(expr, sympy.Mul): + if not expr.args: + return (sympy.Integer(1), sympy.Integer(1)) + bounds = [_expr_bounds_with_iters(a, iter_bounds) for a in expr.args] + if all(b is not None for b in bounds): + if any(sympy.oo in b or -sympy.oo in b for b in bounds): + return None + lo, hi = bounds[0] + for b in bounds[1:]: + corners = [lo * b[0], lo * b[1], hi * b[0], hi * b[1]] + try: + lo, hi = min(corners), max(corners) + except TypeError: + return None + return (lo, hi) + return None + + return None + + def _find_floordiv_mod_pairs( input_mapping: dict[IndexSymbol, IndexExpr], ) -> list[tuple]: - """Find paired floor/Mod expressions that share the same flat expression. + """Find paired floor/Mod expressions sharing the same divisor. Returns list of (dim_q, dim_r, numerator, divisor, addend) tuples where: dim_q has expression: addend + floor(numerator / divisor) dim_r has expression: Mod(something, divisor) - Pairing requires both matching divisors AND compatible numerators. 
- Since sympy auto-evaluates ``Mod(A*D + B, D) -> Mod(B, D)``, the - Mod's first arg may differ from the floor's numerator by a multiple - of D. We verify this: ``(floor_numer - mod_arg) % D == 0``. - - Each Mod is consumed at most once to prevent ambiguous rewrites. + Note: sympy auto-evaluates Mod(A*D + B, D) → Mod(B, D), so the Mod's + first arg may not match the floor's numerator exactly. We match on + the divisor and verify compatibility. """ floor_info: list[tuple] = [] # (dim, numerator, divisor, addend) mod_info: list[tuple] = [] # (dim, arg, divisor) @@ -163,45 +224,17 @@ def _find_floordiv_mod_pairs( floor_info.append((dim, numer, denom, addend)) break - # Match on divisor, verify numerator compatibility, consume each Mod once. - consumed_mods: set[int] = set() + # Match on divisor. pairs = [] for dim_q, numer, divisor, addend in floor_info: - for i, (dim_r, mod_arg, mod_divisor) in enumerate(mod_info): - if i in consumed_mods: - continue - if divisor != mod_divisor: - continue - # Verify the floor and Mod share the same flat expression - # (modulo multiples of D that sympy auto-reduced away). - diff = sympy.cancel(numer - mod_arg) - if not _is_provably_divisible_by(diff, divisor): - continue - consumed_mods.add(i) - pairs.append((dim_q, dim_r, numer, divisor, addend)) - break + for dim_r, mod_arg, mod_divisor in mod_info: + if divisor == mod_divisor: + pairs.append((dim_q, dim_r, numer, divisor, addend)) + break return pairs -def _is_provably_divisible_by(expr: sympy.Expr, divisor: sympy.Expr) -> bool: - """Check if *expr* is provably divisible by *divisor*. - - Handles zero, exact equality, and delegates to split_sum_by_divisibility - for additive decomposition (all terms must be multiples of divisor). - """ - if expr.is_zero: - return True - if expr == divisor: - return True - result = split_sum_by_divisibility(expr, divisor) - if result is None: - return False - # All terms are divisible iff remainder is zero. 
- _, remainder = result - return remainder.is_zero - - def simplify_index_mapping( mapping: IndexMapping, constraints=(), @@ -216,7 +249,7 @@ def simplify_index_mapping( Returns (new_mapping, changed). """ - symbol_bounds = _get_iterator_bounds(mapping, tile_sizes) + iter_bounds = _get_iterator_bounds(mapping, tile_sizes) sym_lower_bounds = _get_symbol_lower_bounds(constraints) input_mapping = dict(mapping.input_mapping) changed = False @@ -224,7 +257,7 @@ def simplify_index_mapping( pairs = _find_floordiv_mod_pairs(input_mapping) for dim_q, dim_r, flat_expr, divisor, addend in pairs: # Step 1: Factor out D-multiples from flat_expr. - split = split_sum_by_divisibility(flat_expr, divisor) + split = _split_sum_by_divisibility(flat_expr, divisor) if split is None: quotient = sympy.Integer(0) remainder = flat_expr @@ -232,7 +265,7 @@ def simplify_index_mapping( quotient, remainder = split # Step 2: Check if remainder is bounded in [0, D). - rem_bounds = expr_bounds(remainder, symbol_bounds) + rem_bounds = _expr_bounds_with_iters(remainder, iter_bounds) if rem_bounds is None: continue @@ -243,28 +276,30 @@ def simplify_index_mapping( # Check lo >= 0. lo_nonneg = lo.is_nonnegative if hasattr(lo, "is_nonnegative") else None if lo_nonneg is None: - lo_simplified = simplify(lo) + lo_simplified = sympy.simplify(lo) lo_nonneg = lo_simplified.is_nonnegative if not lo_nonneg: continue # Check hi < divisor. When the divisor is symbolic, resolve # it via IndexingContext.subs (e.g. K_PACKED -> K//2) and then - # substitute known lower bounds from constraints to obtain a - # concrete lower bound for the divisor. + # substitute known lower bounds from constraints (e.g. K >= 2048 + # lets us prove K//2 >= 1024). 
try: divisor_lb = subs_idxc(divisor) except (IndexError, KeyError): divisor_lb = divisor if sym_lower_bounds: - divisor_lb = divisor_lb.subs({s: lb for s, lb in sym_lower_bounds.items()}) + divisor_lb = divisor_lb.subs( + {s: lb for s, lb in sym_lower_bounds.items()} + ) # Evaluate floor/ceiling after substitution. try: divisor_lb = sympy.Integer(int(divisor_lb)) - except (TypeError, ValueError, OverflowError): + except (TypeError, ValueError): pass - diff = simplify(hi - divisor_lb) + diff = sympy.simplify(hi - divisor_lb) if diff.is_negative is not True: continue diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index eacebd963e..3db1031c5b 100644 --- a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -197,24 +197,48 @@ def _find_mergeable_groups( yield_operands = list(yield_op.operands) # Collect per-arg info: (iter_index, offset, init_source, yield_source). + # yield_source may be None when the yield value doesn't trace back to + # extract_strided_slice (e.g. pipelined sub-word loads from a swapped + # double-buffer). Such groups are still mergeable — we construct the + # yield vector from individual bytes later. eligible = [] for i, iter_arg in enumerate(for_view.inner_iter_args): if iter_arg.type != v1xi8: continue init_info = _trace_extract_strided_slice(init_args[i]) - yield_info = _trace_extract_strided_slice(yield_operands[i]) - if init_info is None or yield_info is None: + if init_info is None: continue init_src, init_off = init_info - yield_src, yield_off = yield_info - if init_off != yield_off: - continue + yield_info = _trace_extract_strided_slice(yield_operands[i]) + if yield_info is not None: + yield_src, yield_off = yield_info + if init_off != yield_off: + # Init and yield extract different bytes from their respective + # dword loads. This happens when the pipeliner shifts the LDS + # address by one scale element between init and yield (e.g. 
+ # init loads at addr X and extracts byte 1, yield loads at + # addr X+1 and extracts byte 0 — same physical byte). + # Mark yield as untraceable so the coalescer constructs the + # yield vector from the yield source's parent vector<4xi8>. + yield_src = None + logger.debug( + f"iter_arg {i}: init offset {init_off} != yield offset " + f"{yield_off} — treating yield as untraceable for " + f"dword coalescing" + ) + else: + yield_src = None + logger.debug( + f"iter_arg {i}: init traces to offset {init_off} but yield " + f"value does not trace to extract_strided_slice — will " + f"construct yield vector from individual bytes" + ) eligible.append((i, init_off, init_src, yield_src)) by_init_src = defaultdict(list) for entry in eligible: _, _, init_src, _ = entry - by_init_src[id(init_src.owner)].append(entry) + by_init_src[hash(init_src.owner)].append(entry) result = [] for entries in by_init_src.values(): @@ -232,13 +256,21 @@ def _find_mergeable_groups( init_source = None yield_owners = set() yield_source = None + has_untraceable_yield = False for o in present_offsets: idx, _, isrc, ysrc = by_offset[o].pop(0) members[o] = idx init_source = isrc - yield_source = ysrc - yield_owners.add(id(ysrc.owner)) - if len(yield_owners) == 1: + if ysrc is not None: + yield_source = ysrc + yield_owners.add(hash(ysrc.owner)) + else: + has_untraceable_yield = True + if has_untraceable_yield: + # Some yield values don't trace — the yield vector will be + # constructed from individual bytes during coalescing. + result.append((init_source, None, members)) + elif len(yield_owners) == 1: result.append((init_source, yield_source, members)) present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] @@ -381,6 +413,61 @@ def make_extract_slice(source: Value, offset: int): logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args") + # For groups whose yield values don't trace to extract_strided_slice + # (e.g. 
pipelined sub-word loads from a swapped double-buffer), + # create a single wide vector<4xi8> load at the same base address + # as the byte-0 yield load. This replaces the sub-word loads with + # one dword load that waveasm can map to a ds_read_b32. + old_yield_operands = list(yield_op.operands) + for g_idx, (init_source, yield_source, members) in enumerate(groups): + if yield_source is not None: + continue + + # Try to find a yield value whose source vector<4xi8> already + # exists (from an extract_strided_slice of a wider load). + # This handles the address-shifted pattern where init and yield + # extract different offsets from their respective dword loads. + any_member_idx = next(iter(members.values())) + any_yield = old_yield_operands[any_member_idx] + any_yield_info = _trace_extract_strided_slice(any_yield) + if any_yield_info is not None: + # The yield value is an extract from a vector<4xi8> — use + # the source vector directly as the merged yield value. + yield_vec_src, _ = any_yield_info + groups[g_idx] = (init_source, yield_vec_src, members) + logger.debug( + f"Group {g_idx}: reusing existing vector<4xi8> yield " + f"source (offset-shifted pattern)" + ) + continue + + if 0 not in members: + logger.debug( + f"Group {g_idx}: no byte-0 member, cannot determine base " + f"address for wide load — skipping" + ) + continue + byte0_yield = old_yield_operands[members[0]] + byte0_op = byte0_yield.owner + if not _is_op_named(byte0_op, "vector.load"): + logger.debug( + f"Group {g_idx}: byte-0 yield is not a vector.load " + f"({byte0_op.name}) — skipping" + ) + continue + load_view = byte0_op.opview + memref = load_view.base + indices = list(load_view.indices) + v4xi8 = VectorType.get([SCALE_VECTOR_WIDTH], i8) + with InsertionPoint(byte0_op): + wide_load = vector_d.load(v4xi8, memref, indices) + groups[g_idx] = (init_source, wide_load, members) + logger.debug( + f"Created wide vector<{SCALE_VECTOR_WIDTH}xi8> load for " + f"group {g_idx} yield value (replaces 
def _get_affine_constant_offset(op) -> Optional[int]:
    """Return the trailing ``+ <int>`` constant of an affine.apply's map.

    The map attribute is inspected through its printed form:

    * ``affine_map<()[s0,s1] -> (expr + 2)>`` gives ``2``.
    * a map whose text merely ends in ``)>`` (no trailing ``+ <int>``)
      gives ``0``.
    * ``None`` is returned when *op* is not an ``affine.apply`` or when the
      printed map does not end in ``)>``.

    NOTE(review): a map ending in ``- <int>`` does not match the first
    pattern, so it is reported as offset ``0`` — confirm callers never rely
    on negative trailing constants.
    """
    if not _is_op_named(op, "affine.apply"):
        return None
    import re

    printed_map = str(op.attributes["map"])
    trailing_const = re.search(r"\+\s*(\d+)\)\s*>\s*$", printed_map)
    if trailing_const is not None:
        return int(trailing_const.group(1))
    closes_cleanly = re.search(r"\)\s*>\s*$", printed_map) is not None
    return 0 if closes_cleanly else None
+ For vector<2xi8> loads, replaces downstream extract_strided_slice(size=1) + users with direct byte extracts from the wide vector<4xi8>, so that + _trace_scale_chain sees the correct source type for opsel. + """ + i8 = IntegerType.get_signless(8) + i64 = IntegerType.get_signless(64) + v1xi8 = VectorType.get([1], i8) + + byte_loads: list[Operation] = [] + for op in _walk_operations(module.operation): + if not _is_op_named(op, "vector.load"): + continue + rtype = op.results[0].type + if not isinstance(rtype, VectorType) or rtype.element_type != i8: + continue + if rtype.shape[0] > 2: + continue + view = op.opview + addr = view.indices[-1] + if _get_affine_constant_offset(addr.owner) is None: + continue + byte_loads.append(op) + + if not byte_loads: + return + + def _full_group_key(op: Operation): + """Group by memref, all non-last indices, AND the affine base expr.""" + view = op.opview + memref_h = hash(view.base) + prefix_hashes = tuple(hash(v) for v in list(view.indices)[:-1]) + addr = view.indices[-1] + base_k = _affine_base_key(addr.owner) + return (memref_h, prefix_hashes, base_k) + + by_group: dict = defaultdict(list) + for op in byte_loads: + by_group[_full_group_key(op)].append(op) + + for loads in by_group.values(): + if len(loads) < 2: + continue + + addr_to_offset: dict[int, int] = {} + base_op: Optional[Operation] = None + for load_op in loads: + addr_val = load_op.opview.indices[-1] + off = _get_affine_constant_offset(addr_val.owner) + if off is None: + continue + if off >= SCALE_VECTOR_WIDTH: + continue + addr_to_offset[id(load_op)] = off + if off == 0: + base_op = load_op + + if base_op is None: + continue + + base_view = base_op.opview + v4xi8 = VectorType.get([SCALE_VECTOR_WIDTH], i8) + + earliest = base_op + for load_op in loads: + if id(load_op) in addr_to_offset: + earliest = load_op + break + + with InsertionPoint(earliest): + wide = vector_d.load(v4xi8, base_view.base, list(base_view.indices)) + + replaced = 0 + for load_op in loads: + if 
id(load_op) not in addr_to_offset: + continue + off = addr_to_offset[id(load_op)] + rtype = load_op.results[0].type + n = rtype.shape[0] + if off + n > SCALE_VECTOR_WIDTH: + continue + + if n == 1: + with InsertionPoint(load_op): + offsets = ArrayAttr.get([IntegerAttr.get(i64, off)]) + sizes = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + strides = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + ext = vector_d.ExtractStridedSliceOp( + v1xi8, wide, offsets, sizes, strides + ) + load_op.results[0].replace_all_uses_with(ext.result) + load_op.erase() + replaced += 1 + else: + ess_users = [] + for use in list(load_op.results[0].uses): + user_op = use.owner + if ( + _is_op_named(user_op, "vector.extract_strided_slice") + and IntegerAttr(user_op.opview.sizes[0]).value == 1 + ): + inner_off = IntegerAttr(user_op.opview.offsets[0]).value + byte_off = off + inner_off + if byte_off < SCALE_VECTOR_WIDTH: + ess_users.append((user_op, byte_off)) + + for user_op, byte_off in ess_users: + with InsertionPoint(user_op): + offsets = ArrayAttr.get([IntegerAttr.get(i64, byte_off)]) + sizes = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + strides = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + ext = vector_d.ExtractStridedSliceOp( + v1xi8, wide, offsets, sizes, strides + ) + user_op.results[0].replace_all_uses_with(ext.result) + user_op.erase() + + has_remaining = any(True for _ in load_op.results[0].uses) + if has_remaining: + with InsertionPoint(load_op): + offsets = ArrayAttr.get([IntegerAttr.get(i64, off)]) + sizes = ArrayAttr.get([IntegerAttr.get(i64, n)]) + strides = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + fallback = vector_d.ExtractStridedSliceOp( + rtype, wide, offsets, sizes, strides + ) + load_op.results[0].replace_all_uses_with(fallback.result) + + load_op.erase() + replaced += 1 + + if replaced: + logger.debug( + f"Merged {replaced} sub-word LDS loads into one " + f"vector<{SCALE_VECTOR_WIDTH}xi8> load" + ) + + def apply_opsel_scaled_mfma(module: Module): """Walk the MLIR module and 
apply the opsel optimization to scaled_mfma ops. @@ -444,6 +716,7 @@ def apply_opsel_scaled_mfma(module: Module): with mlir_ctx, Location.unknown(): _coalesce_vector_iter_args(module) + _merge_scale_byte_loads(module) f8e8m0 = Float8E8M0FNUType.get() diff --git a/wave_lang/kernel/wave/preshuffle_scale_to_shared.py b/wave_lang/kernel/wave/preshuffle_scale_to_shared.py index 372ff5fe40..83883393d3 100644 --- a/wave_lang/kernel/wave/preshuffle_scale_to_shared.py +++ b/wave_lang/kernel/wave/preshuffle_scale_to_shared.py @@ -399,11 +399,6 @@ def _transform_scale_memory( input_read.update_arg("mapping", None) get_custom(input_read.fx_node).erase() - # --- Transform reads --- - # LDS data is now in preshuffle physical order. Each MMA read needs - # one scale byte at logical (k, m). The preshuffle formula decomposes - # into constant_base + lane_id * 4 when k_offset is a multiple of 4 - # and m_offset is a multiple of 16 (guaranteed by MMA tiling). read_infos = [] for node in trace.walk(lambda n: isinstance(get_custom(n), Read)): read = get_custom(node) diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 6069941284..82c271fc87 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1992,8 +1992,8 @@ def split_by_iteration(nodes, key="name"): tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - unroll_factor = 2 - tkw.unroll(pipeline_loop.KERNEL, unroll_factor) + unroll_factor = 2 + tkw.unroll(pipeline_loop.KERNEL, unroll_factor) tkw.insert_at_start( pipeline_loop.KERNEL, diff --git a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py index 4ab7793018..b5aa885075 100644 --- a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py +++ 
b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py @@ -851,12 +851,14 @@ def guard_g2s_with_bounds_check( max_iv = iterate.count induction_variable = get_induction_variable(iterate, constraints) - prefetch_offset = (num_stages - 1) * step - guard_condition = sympy.StrictLessThan( - induction_variable + prefetch_offset, max_iv - ) - + pipeline_depth = num_stages - 1 for g2s in g2s_nodes: + custom = get_custom(g2s) + unroll_iter = getattr(custom, "unroll_iteration", 0) or 0 + prefetch_offset = pipeline_depth + unroll_iter + guard_condition = sympy.StrictLessThan( + induction_variable + prefetch_offset, max_iv + ) g2s.meta["g2s_guard"] = guard_condition diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 231755a184..4548c0f929 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -26,7 +26,6 @@ get_custom, ) from ..constraints import Constraint -from ..region_canonicalization import RegionFormat, requires_region_format from ..utils.general_utils import ( get_tiling_constraint, ) @@ -669,7 +668,6 @@ def propagate_scheduling_parameters_to_iter_args( } -@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def schedule_graph( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index b7b013460b..b595bbbb8d 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -436,6 +436,10 @@ def get_tagged_mxfp4_gemm_preshuffle_b( # K is always large enough for software pipelining. constraints += [tkw.Assumption(K > BLOCK_K * 6)] + # # K >= 2048 ensures the B-data preshuffle within_nblk (max 1023) + # # is always < K_PACKED (= K/2 >= 1024), eliminating dynamic floordiv. 
def _eval_concrete_floor_mod(expr: sympy.Expr) -> sympy.Expr:
    """Collapse ``floor``/``Mod`` nodes whose arguments are concrete.

    Expressions built symbolically and substituted afterwards can retain
    wrappers such as ``floor(4)`` or ``Mod(0, K)``.  Every ``floor`` and
    ``Mod`` node in *expr* is visited and replaced by a plain integer when
    its value is known; all other nodes are left untouched.
    """
    # Fast path: nothing to rewrite.
    if not expr.has(sympy.floor, sympy.Mod):
        return expr

    def _is_floor_or_mod(node):
        return isinstance(node, (sympy.floor, sympy.Mod))

    def _fold(node):
        if isinstance(node, sympy.floor):
            (argument,) = node.args
            if argument.is_number:
                return sympy.Integer(int(sympy.floor(argument)))
        elif isinstance(node, sympy.Mod):
            dividend, modulus = node.args
            if dividend.is_Integer:
                if modulus.is_Integer and int(modulus) != 0:
                    return sympy.Integer(int(dividend) % int(modulus))
                # A zero dividend folds to 0 even for a symbolic modulus.
                if int(dividend) == 0:
                    return sympy.Integer(0)
        return node

    return expr.replace(_is_floor_or_mod, _fold)
def compute_iv_stride_through_mapping(
    mapping: IndexMapping,
    symbolic_shape: tuple[IndexExpr, ...],
    index: dict[IndexExpr, IndexSequence],
    is_read: bool = True,
    mem_strides: list[IndexExpr] | None = None,
    constraints: Sequence = (),
) -> dict[sympy.Symbol, IndexExpr | list[IndexExpr]] | None:
    """Compute each IV symbol's linearized stride through a mapping.

    Uses numerical probing: evaluates the linearized address at iv=0,1,...,P,
    takes consecutive differences, and detects constant or cyclic stride
    patterns.  This handles symbolic divisors (e.g. ``floor(K/32)``) that
    defeat a plain ``sympy.coeff(_iv)`` approach.

    Parameters
    ----------
    mapping : the index mapping whose input (read) or output (write)
        expressions are linearized.
    symbolic_shape : shape used to derive memory strides when
        ``mem_strides`` is not given.
    index : per-dimension index; the ``start`` of each entry is searched
        for induction-variable symbols (names starting with
        ``_INDUCTION_PREFIX``).
    is_read : selects the input side of the mapping when true, the output
        side otherwise.
    mem_strides : optional physical memory strides.  When provided these are
        used for linearization instead of strides derived from
        ``symbolic_shape``.
    constraints : constraint sequence passed to ``get_divisibility_subs``;
        its forward substitutions are applied before probing and its
        backward substitutions to the result.

    Returns
    -------
    ``{iv_sym: stride}`` (constant) or ``{iv_sym: [s0, s1, ...]}``
    (repeating cycle).  ``None`` when no IV with a concrete integer
    coefficient is found, or when probing fails for any IV.

    Diagnostics are emitted through ``logging`` at DEBUG level rather than
    written to stdout.
    """
    # Function-scope import so the block does not depend on a module-level
    # logger being defined elsewhere in the file.
    import logging

    log = logging.getLogger(__name__)
    debug_enabled = log.isEnabledFor(logging.DEBUG)

    iters = mapping.iters

    # iv symbol -> (iterator it drives, integer coefficient of the IV).
    iv_info: dict[sympy.Symbol, tuple[sympy.Symbol, int]] = {}
    for dim_sym, iter_sym in mapping.output_mapping.items():
        seq = index.get(dim_sym)
        if seq is None:
            continue
        start = sympy.sympify(seq.start if isinstance(seq, IndexSequence) else seq)
        expanded_start = None  # expanded lazily, at most once per dimension
        for sym in start.free_symbols:
            if not str(sym).startswith(_INDUCTION_PREFIX):
                continue
            if expanded_start is None:
                expanded_start = sympy.expand(start)
            coeff = expanded_start.coeff(sym)
            concrete = subs_idxc(coeff)
            if not isinstance(concrete, (int, sympy.Integer)):
                log.debug(
                    "IV coeff for %s is non-concrete: coeff=%s resolved=%s"
                    " (type=%s) - skipping this IV",
                    sym,
                    coeff,
                    concrete,
                    type(concrete).__name__,
                )
                continue
            iv_info[sym] = (iter_sym, int(concrete))

    if not iv_info:
        return None

    if debug_enabled:
        log.debug("compute_iv_stride_through_mapping is_read=%s", is_read)
        log.debug("  iters: %s", dict(iters))
        for iv_sym, (iv_iter, cc) in iv_info.items():
            log.debug("  IV %s -> iter=%s coeff=%s", iv_sym, iv_iter, cc)

    if is_read:
        map_dims = mapping.input_shape
        raw_exprs = mapping.map_input_indices(map_dims)
    else:
        map_dims = mapping.output_shape
        raw_exprs = mapping.map_output_indices(map_dims)

    idxc = IndexingContext.current()
    dim_exprs = [subs_idxc(e) for e in raw_exprs]

    if debug_enabled:
        for i, (raw, resolved) in enumerate(zip(raw_exprs, dim_exprs)):
            log.debug("  dim[%d] raw=%s -> resolved=%s", i, raw, resolved)

    if mem_strides is None:
        symbolic_shape_resolved = tuple(infer_dim(d) for d in symbolic_shape)
        mem_strides = strides_from_symbolic_shape(
            idxc, symbolic_shape_resolved, allow_mixed_shapes=True
        )

    if debug_enabled:
        stride_free = set()
        for s in mem_strides:
            stride_free |= sympy.sympify(s).free_symbols
        log.debug(
            "  mem_strides=%s (symbolic=%s)",
            mem_strides,
            sorted(str(s) for s in stride_free) if stride_free else "none",
        )

    div_fwd, div_bwd = get_divisibility_subs(constraints)
    if div_fwd:
        fwd_dict = dict(div_fwd)
        dim_exprs = [sympy.sympify(e).subs(fwd_dict) for e in dim_exprs]
        mem_strides = [sympy.sympify(s).subs(fwd_dict) for s in mem_strides]
        log.debug("  divisibility fwd subs: %s", fwd_dict)
        log.debug("  dims after div subs: %s", dim_exprs)
        log.debug("  mem_strides after div subs: %s", mem_strides)
    else:
        # No explicit divisibility assumptions: fall back to the
        # floor(sym/n) -> sym/n substitutions implied by the strides.
        floor_subs = _infer_floor_to_exact(mem_strides)
        if floor_subs:
            dim_exprs = [sympy.sympify(e).subs(floor_subs) for e in dim_exprs]
            log.debug("  floor_to_exact subs (fallback): %s", floor_subs)
            log.debug("  dims after subs: %s", dim_exprs)

    result: dict[sympy.Symbol, IndexExpr | list[IndexExpr]] = {}

    for iv_sym, (iv_iter, concrete_coeff) in iv_info.items():
        stride_or_cycle = _probe_iv_stride(
            dim_exprs, mem_strides, iters, iv_iter, concrete_coeff
        )
        if stride_or_cycle is None:
            # One unanalysable IV invalidates the whole mapping.
            log.debug(
                "  _probe_iv_stride returned None for IV %s"
                " - returning None for entire mapping",
                iv_sym,
            )
            return None
        result[iv_sym] = stride_or_cycle

    if div_bwd:
        bwd_dict = dict(div_bwd)

        def _bwd(v):
            if isinstance(v, list):
                return [mem_simplify(sympy.sympify(x).subs(bwd_dict)) for x in v]
            return mem_simplify(sympy.sympify(v).subs(bwd_dict))

        result = {k: _bwd(v) for k, v in result.items()}

    if debug_enabled:
        for iv_sym, val in result.items():
            log.debug("  RESULT %s -> %s", iv_sym, val)

    return result
def _probe_iv_stride(
    dim_exprs: list[IndexExpr],
    mem_strides: list[IndexExpr],
    iters: dict,
    iv_iter: sympy.Symbol,
    concrete_coeff: int,
) -> IndexExpr | list[IndexExpr] | None:
    """Compute the IV stride through the linearized address.

    1. Linearize ``dim_exprs * mem_strides`` symbolically via
       ``linearize_dims`` + ``mem_simplify`` to obtain the flat address
       expression.
    2. Compute the probe depth ``P`` from that expression with
       ``_compute_probe_depth`` (after ``subs_idxc``, so it sees the same
       symbol resolution as the concrete evaluations below).
    3. Evaluate ``P + 1`` concrete addresses, with every iterator other
       than *iv_iter* set to 0, and take the ``P`` consecutive differences.
    4. Return the shortest prefix of the differences that repeats over the
       whole probe window.

    Returns a single ``sympy.Integer`` (constant stride) or a list of them
    (repeating cycle).  Returns ``None`` when a probed address still has
    free symbols or is not an exact integer.  A ``ValueError`` raised by
    ``_compute_probe_depth`` propagates to the caller.

    Diagnostics are emitted through ``logging`` rather than stdout.
    """
    # Function-scope import so the block does not depend on a module-level
    # logger being defined elsewhere in the file.
    import logging

    log = logging.getLogger(__name__)
    log.debug("_probe_iv_stride iv_iter=%s coeff=%s", iv_iter, concrete_coeff)

    flat_expr = mem_simplify(linearize_dims(dim_exprs, mem_strides))
    iv_flat = flat_expr.subs(
        {
            it: (concrete_coeff * sympy.Symbol("_iv") if it == iv_iter else 0)
            for it in iters
        }
    )
    iv_flat = subs_idxc(iv_flat)
    probe_depth = _compute_probe_depth(iv_flat, concrete_coeff)
    log.debug("  probe_depth=%s", probe_depth)

    def _linearized_addr(iv_val: int) -> IndexExpr:
        point = {
            it: (concrete_coeff * iv_val if it == iv_iter else 0) for it in iters
        }
        resolved = [subs_idxc(dim_expr.subs(point)) for dim_expr in dim_exprs]
        return mem_simplify(subs_idxc(linearize_dims(resolved, mem_strides)))

    addrs: list[int] = []
    for iv in range(probe_depth + 1):
        addr = _linearized_addr(iv)
        free = getattr(addr, "free_symbols", set())
        if free:
            log.warning(
                "addr[iv=%d] = %s still has free symbols %s;"
                " fix chained symbolic dependencies upstream",
                iv,
                addr,
                free,
            )
            return None
        value = int(addr)
        if value != addr:
            # int() would silently truncate a fractional address and
            # produce a wrong stride; treat it as an analysis failure.
            log.warning("addr[iv=%d] = %s is not an integer", iv, addr)
            return None
        addrs.append(value)

    diffs = [nxt - cur for cur, nxt in zip(addrs, addrs[1:])]
    log.debug("  addrs=%s diffs=%s", addrs, diffs)

    # A cycle of length probe_depth trivially matches, so the search below
    # always terminates with a result (probe_depth is at least 1).
    cycle_len = next(
        n
        for n in range(1, probe_depth + 1)
        if all(diffs[i] == diffs[i % n] for i in range(probe_depth))
    )
    cycle = [sympy.Integer(d) for d in diffs[:cycle_len]]
    if cycle_len == 1:
        log.debug("  -> constant stride %s (probe_depth=%s)", cycle[0], probe_depth)
        return cycle[0]
    log.debug(
        "  -> cyclic stride (len=%s): %s (probe_depth=%s)",
        cycle_len,
        cycle,
        probe_depth,
    )
    return cycle
- - Free symbols not in *symbol_bounds* are assumed to be non-negative - integers (hardware indices). Returns (lo, hi) or None if bounds - cannot be determined. - """ - hashable = ( - tuple(sorted(symbol_bounds.items(), key=lambda kv: str(kv[0]))) - if symbol_bounds - else None - ) - return _expr_bounds_cached(expr, hashable) - - @lru_cache(maxsize=1024) -def _expr_bounds_cached( - expr: sympy.Expr, - symbol_bounds_tuple: tuple | None = None, -) -> tuple[sympy.Expr, sympy.Expr] | None: - """Cached implementation of expr_bounds. +def expr_bounds(expr: sympy.Expr) -> tuple[sympy.Expr, sympy.Expr] | None: + """Compute (lo, hi) bounds for a sympy expression via interval arithmetic. - Takes *symbol_bounds* as a hashable tuple of (symbol, (lo, hi)) pairs. + Free symbols are assumed to be non-negative integers (hardware indices). + Returns (lo, hi) or None if bounds cannot be determined. """ - sb = dict(symbol_bounds_tuple) if symbol_bounds_tuple else None - if expr.is_Integer or expr.is_Rational: return (expr, expr) if expr.is_Symbol: - if sb and expr in sb: - return sb[expr] return (sympy.Integer(0), sympy.oo) if expr.is_nonnegative else None if isinstance(expr, sympy.Mod): p, q = expr.args if q.is_positive and q.is_number: - p_b = _expr_bounds_cached(p, symbol_bounds_tuple) - try: - if p_b and p_b[0] >= 0 and p_b[1] < q: - return p_b - except TypeError: - pass # Symbolic comparison -- fall through to default. + p_bounds = expr_bounds(p) + if p_bounds and p_bounds[0] >= 0 and p_bounds[1] < q: + return p_bounds return (sympy.Integer(0), q - 1) - # Symbolic modulus: Mod(p, q) is in [0, q-1] when q > 0. - # Use the upper bound of q as a conservative ceiling. 
- q_b = _expr_bounds_cached(q, symbol_bounds_tuple) - if q_b and q_b[0].is_positive: - return (sympy.Integer(0), q_b[1] - 1) return None if isinstance(expr, sympy.floor): - inner_b = _expr_bounds_cached(expr.args[0], symbol_bounds_tuple) - if inner_b: - return (sympy.floor(inner_b[0]), sympy.floor(inner_b[1])) + inner_bounds = expr_bounds(expr.args[0]) + if inner_bounds: + return (sympy.floor(inner_bounds[0]), sympy.floor(inner_bounds[1])) return None if isinstance(expr, sympy.Add): - bounds = [_expr_bounds_cached(a, symbol_bounds_tuple) for a in expr.args] + bounds = [expr_bounds(a) for a in expr.args] if all(b is not None for b in bounds): return (sum(b[0] for b in bounds), sum(b[1] for b in bounds)) return None if isinstance(expr, sympy.Mul): if not expr.args: return (sympy.Integer(1), sympy.Integer(1)) - bounds = [_expr_bounds_cached(a, symbol_bounds_tuple) for a in expr.args] + bounds = [expr_bounds(a) for a in expr.args] if all(b is not None for b in bounds): # Bail out if any bound is infinite (0 * oo = NaN). if any(sympy.oo in b or -sympy.oo in b for b in bounds): @@ -140,10 +103,7 @@ def _expr_bounds_cached( lo, hi = bounds[0] for b in bounds[1:]: corners = [lo * b[0], lo * b[1], hi * b[0], hi * b[1]] - try: - lo, hi = min(corners), max(corners) - except TypeError: - return None + lo, hi = min(corners), max(corners) return (lo, hi) return None return None @@ -172,8 +132,6 @@ def _is_provably_divisible(term: sympy.Expr, divisor: sympy.Expr) -> bool: # Decompose the divisor into numeric and symbolic parts. # E.g. 8*floor(...) 
-> (8, floor(...)) div_coeff, div_sym = _split_coeff(divisor) - if div_coeff.is_zero: - return False if isinstance(term, sympy.Mul): # Check if term contains div_sym as a multiplicative factor @@ -209,7 +167,9 @@ def _split_coeff(expr: sympy.Expr) -> tuple[sympy.Integer, sympy.Expr]: return (sympy.Integer(1), expr) -def _contains_factor(factors: list[sympy.Expr], target: sympy.Expr) -> bool: +def _contains_factor( + factors: list[sympy.Expr], target: sympy.Expr +) -> bool: """Check if *target* appears as a factor in *factors* (possibly nested).""" for f in factors: if f == target: @@ -220,7 +180,7 @@ def _contains_factor(factors: list[sympy.Expr], target: sympy.Expr) -> bool: return False -def split_sum_by_divisibility( +def _split_sum_by_divisibility( expr: sympy.Expr, divisor: sympy.Expr ) -> tuple[sympy.Expr, sympy.Expr] | None: """Split *expr* into ``(quotient, remainder)`` such that @@ -368,7 +328,7 @@ def transform_floor_div(expr): numer, denom = inner.as_numer_denom() if denom == 1: return None - result = split_sum_by_divisibility(numer, denom) + result = _split_sum_by_divisibility(numer, denom) if result is None: return None quotient, remainder = result @@ -389,7 +349,7 @@ def transform_mod_div(expr): if not isinstance(expr, sympy.Mod): return None p, q = expr.args - result = split_sum_by_divisibility(p, q) + result = _split_sum_by_divisibility(p, q) if result is None: return None _quotient, remainder = result @@ -405,9 +365,7 @@ def transform_mod_div(expr): pass # Symbolic comparison — can't determine. 
return sympy.Mod(remainder, q, evaluate=False) - expr = expr.replace( - lambda e: transform_floor_div(e) is not None, transform_floor_div - ) + expr = expr.replace(lambda e: transform_floor_div(e) is not None, transform_floor_div) expr = expr.replace(lambda e: transform_mod_div(e) is not None, transform_mod_div) expr = expr.replace(lambda e: transform_mod(e) is not None, transform_mod) expr = expr.replace(lambda e: transform_floor(e) is not None, transform_floor) @@ -417,6 +375,103 @@ def transform_mod_div(expr): _simplify_cache: dict[sympy.Basic, sympy.Expr] = {} +def extract_iv( + expr: sympy.Expr, + iv: sympy.Symbol, +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Split *expr* into ``(iv_coeff, base)`` such that + ``expr == iv_coeff * iv + base`` and ``iv`` does not appear in ``base``. + + Uses ``sympy.Expr.coeff`` for the linear case. When *iv* survives in the + remainder (e.g. because ``simplify`` folded it into floor/Mod), applies the + general integer-division identity to pull *iv* out: + + * ``floor((c*iv + r) / d) = floor(c/d)*iv + floor((Mod(c,d)*iv + r)/d)`` + * ``Mod(c*iv + r, m) = Mod(Mod(c,m)*iv + r, m)`` + + Returns ``None`` if *iv* cannot be fully separated. + """ + expanded = sympy.expand(expr) + coeff = expanded.coeff(iv) + base = simplify(expanded - coeff * iv) + if iv not in base.free_symbols: + return (coeff, base) + + # iv is stuck inside floor/Mod — apply decomposition identities. + return _extract_iv_from_floor_mod(expr, iv) + + +def _extract_iv_from_floor_mod( + expr: sympy.Expr, + iv: sympy.Symbol, +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Apply floor/Mod integer-division identities to separate *iv*. + + Walks the expression tree. 
For each ``floor(numer/denom)`` or + ``Mod(value, modulus)`` that contains *iv*, decomposes the *iv*-coefficient + using the identities: + + * ``floor((c*iv + r) / d) = floor(c/d)*iv + floor((Mod(c,d)*iv + r)/d)`` + * ``Mod(c*iv + r, m) = Mod(Mod(c,m)*iv + r, m)`` + + After rewriting, the result is expanded and the linear *iv* term is + split off again. If *iv* still remains in the base (e.g. inside + ``floor(Mod(c,d)*iv/d + ...)``), the separation fails and ``None`` is + returned; no numeric probing is performed. The split succeeds only when + the residual *iv* term vanishes (e.g. if ``Mod(c, d)`` reduces to zero). + """ + + def _rewrite_floor(arg): + numer, denom = arg.as_numer_denom() + numer = sympy.expand(numer) + iv_coeff = numer.coeff(iv) + if iv_coeff == 0: + return sympy.floor(arg) + rest = numer - iv_coeff * iv + return ( + sympy.floor(iv_coeff / denom) * iv + + sympy.floor( + (sympy.Mod(iv_coeff, denom, evaluate=False) * iv + rest) / denom + ) + ) + + def _rewrite_mod(*args): + value, modulus = args + value_expanded = sympy.expand(value) + iv_coeff = value_expanded.coeff(iv) + if iv_coeff == 0: + return sympy.Mod(value, modulus, evaluate=False) + rest = value_expanded - iv_coeff * iv + return sympy.Mod( + sympy.Mod(iv_coeff, modulus, evaluate=False) * iv + rest, + modulus, + evaluate=False, + ) + + rewritten = expr.replace(sympy.floor, _rewrite_floor) + rewritten = rewritten.replace(sympy.Mod, _rewrite_mod) + + expanded = sympy.expand(rewritten) + coeff = expanded.coeff(iv) + base = expanded - coeff * iv + if iv not in base.free_symbols: + return (simplify(coeff), simplify(base)) + + return None + +def simplify_divisor_multiples(expr: sympy.Expr) -> sympy.Expr: + """Factor out divisor-multiples from floor/Mod without expand/cancel. + + Applies only the ``transform_floor_div`` and ``transform_mod_div`` + rewrites from ``_custom_simplify_once``.
This is safe for complex + post-substitution expressions where ``sympy.expand`` / ``sympy.cancel`` + would destroy the expression structure. + """ + if not isinstance(expr, sympy.Basic) or expr.is_Atom: + return expr + return _custom_simplify_once(expr) + + def simplify(expr: sympy.Expr) -> sympy.Expr: """Simplify a sympy expression using interval arithmetic and cancel. diff --git a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h index b401f5ee8c..0167f3e175 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h +++ b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h @@ -47,6 +47,20 @@ class ControlFlowOp : public TraitBase {}; template class SpecialRegOp : public TraitBase {}; +/// Trait for operations that implicitly write the SCC (Scalar Condition Code) +/// flag as a hardware side effect. Ops with this trait must not be placed +/// between an SCC-producing op and its consumer (ConditionOp, s_cselect, +/// s_addc_u32). The SCC verifier pass checks this invariant. +template +class SCCDef : public TraitBase {}; + +/// Trait for operations that implicitly read the SCC flag. +/// The result depends on the current SCC value, so these ops are NOT +/// eligible for CSE (two identical operands can produce different results +/// if SCC differs). +template +class SCCUse : public TraitBase {}; + } // namespace OpTrait } // namespace mlir diff --git a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td index c8c143b84d..6a68ea10eb 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td +++ b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td @@ -264,4 +264,10 @@ def WaveASM_ControlFlowOp : NativeOpTrait<"ControlFlowOp">; // Trait for special register operations (M0, VCC, EXEC) def WaveASM_SpecialRegOp : NativeOpTrait<"SpecialRegOp">; +// Trait for operations that implicitly write the SCC flag. 
+def WaveASM_SCCDef : NativeOpTrait<"SCCDef">; + +// Trait for operations that implicitly read the SCC flag. +def WaveASM_SCCUse : NativeOpTrait<"SCCUse">; + #endif // WaveASM_DIALECT_WAVEASMINTERFACES diff --git a/waveasm/include/waveasm/Dialect/WaveASMOps.td b/waveasm/include/waveasm/Dialect/WaveASMOps.td index a37ff3391d..e36602bf46 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMOps.td +++ b/waveasm/include/waveasm/Dialect/WaveASMOps.td @@ -83,6 +83,10 @@ class VALUCmpOp traits = []> } // SALU Unary: dst = op(src) +// Pure — does NOT clobber SCC on hardware. +// Used for: s_mov_b32, s_mov_b64, s_sext_i32_i8, s_sext_i32_i16. +// NOTE: SCC-clobbering unary ops (s_not, s_brev, etc.) are defined with +// SALUUnaryWithSCCOp instead of this class. class SALUUnaryOp traits = []> : WAVEASMOp { let arguments = (ins WaveASM_SRegOrImm:$src); @@ -90,7 +94,25 @@ class SALUUnaryOp traits = []> let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; } +// SALU Unary with SCC clobber: dst = op(src), SCC = f(result) +// Writes SCC on hardware but result depends only on operands. +// Uses NoMemoryEffect + AlwaysSpeculatableImplTrait (equivalent to Pure) +// so MLIR's DCE and LICM still work, plus SCCDef for the SCC verifier. +class SALUUnaryWithSCCOp traits = []> + : WAVEASMOp { + let arguments = (ins WaveASM_SRegOrImm:$src); + let results = (outs WaveASM_AnySGPR:$dst); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; +} + // SALU Binary: dst = op(src0, src1) +// Pure — does NOT clobber SCC on hardware. +// Used for: s_mul_i32, s_mul_hi_*, s_bfm_*, s_pack_*. +// NOTE: SCC-clobbering binary ops (s_and, s_or, s_lshl, etc.) are defined +// with SALUBinaryWithSCCOp; s_cselect_b32 reads SCC and is a standalone op.
class SALUBinaryOp traits = []> : WAVEASMOp { let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); @@ -98,6 +120,21 @@ class SALUBinaryOp traits = []> let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst)"; } +// SALU Binary with SCC clobber: dst = op(src0, src1), SCC = f(result) +// Writes SCC on hardware but result depends only on operands. +// Uses NoMemoryEffect + AlwaysSpeculatableImplTrait (equivalent to Pure) +// so MLIR's DCE and LICM still work, plus SCCDef so the SCC verifier +// can identify these as SCC-clobbering. +class SALUBinaryWithSCCOp traits = []> + : WAVEASMOp { + let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); + let results = (outs WaveASM_AnySGPR:$dst); + let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst)"; +} + // SALU Binary with carry/SCC: dst, scc = op(src0, src1) // s_add_u32 and s_addc_u32 set SCC (carry flag) as a side effect. // We model SCC as an explicit sreg result so the verifier can prevent @@ -107,7 +144,7 @@ class SALUBinaryOp traits = []> // a true SSA dependency. Currently the chain is only enforced by emission // ordering, not by the SSA graph. class SALUBinaryWithCarryOp traits = []> - : WAVEASMOp { + : WAVEASMOp { let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); let results = (outs WaveASM_AnySGPR:$dst, WaveASM_AnySGPR:$scc); let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst) `,` type($scc)"; @@ -123,7 +160,7 @@ class SALUBinaryWithCarryOp traits = []> // since the live range is very short (s_cmp -> condition terminator). // TODO: Consider a dedicated SCCType to avoid the wasted register slot. 
class SALUCmpOp traits = []> - : WAVEASMOp { + : WAVEASMOp { let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); let results = (outs WaveASM_AnySGPR:$result); // SCC result (see note above) let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result)"; @@ -698,26 +735,29 @@ def WaveASM_V_CMP_GE_U64 : VALUCmpOp<"v_cmp_ge_u64">; // SALU Unary Instructions //===----------------------------------------------------------------------===// +// No SCC clobber def WaveASM_S_MOV_B32 : SALUUnaryOp<"s_mov_b32">; def WaveASM_S_MOV_B64 : SALUUnaryOp<"s_mov_b64">; -def WaveASM_S_NOT_B32 : SALUUnaryOp<"s_not_b32">; -def WaveASM_S_NOT_B64 : SALUUnaryOp<"s_not_b64">; -def WaveASM_S_BREV_B32 : SALUUnaryOp<"s_brev_b32">; -def WaveASM_S_BREV_B64 : SALUUnaryOp<"s_brev_b64">; -def WaveASM_S_BCNT0_I32_B32 : SALUUnaryOp<"s_bcnt0_i32_b32">; -def WaveASM_S_BCNT0_I32_B64 : SALUUnaryOp<"s_bcnt0_i32_b64">; -def WaveASM_S_BCNT1_I32_B32 : SALUUnaryOp<"s_bcnt1_i32_b32">; -def WaveASM_S_BCNT1_I32_B64 : SALUUnaryOp<"s_bcnt1_i32_b64">; -def WaveASM_S_FF0_I32_B32 : SALUUnaryOp<"s_ff0_i32_b32">; -def WaveASM_S_FF0_I32_B64 : SALUUnaryOp<"s_ff0_i32_b64">; -def WaveASM_S_FF1_I32_B32 : SALUUnaryOp<"s_ff1_i32_b32">; -def WaveASM_S_FF1_I32_B64 : SALUUnaryOp<"s_ff1_i32_b64">; -def WaveASM_S_FLBIT_I32_B32 : SALUUnaryOp<"s_flbit_i32_b32">; -def WaveASM_S_FLBIT_I32_B64 : SALUUnaryOp<"s_flbit_i32_b64">; -def WaveASM_S_ABS_I32 : SALUUnaryOp<"s_abs_i32">; def WaveASM_S_SEXT_I32_I8 : SALUUnaryOp<"s_sext_i32_i8">; def WaveASM_S_SEXT_I32_I16 : SALUUnaryOp<"s_sext_i32_i16">; +// Clobber SCC (SCC = f(result)) +def WaveASM_S_NOT_B32 : SALUUnaryWithSCCOp<"s_not_b32">; +def WaveASM_S_NOT_B64 : SALUUnaryWithSCCOp<"s_not_b64">; +def WaveASM_S_BREV_B32 : SALUUnaryWithSCCOp<"s_brev_b32">; +def WaveASM_S_BREV_B64 : SALUUnaryWithSCCOp<"s_brev_b64">; +def WaveASM_S_BCNT0_I32_B32 : SALUUnaryWithSCCOp<"s_bcnt0_i32_b32">; +def WaveASM_S_BCNT0_I32_B64 : 
SALUUnaryWithSCCOp<"s_bcnt0_i32_b64">; +def WaveASM_S_BCNT1_I32_B32 : SALUUnaryWithSCCOp<"s_bcnt1_i32_b32">; +def WaveASM_S_BCNT1_I32_B64 : SALUUnaryWithSCCOp<"s_bcnt1_i32_b64">; +def WaveASM_S_FF0_I32_B32 : SALUUnaryWithSCCOp<"s_ff0_i32_b32">; +def WaveASM_S_FF0_I32_B64 : SALUUnaryWithSCCOp<"s_ff0_i32_b64">; +def WaveASM_S_FF1_I32_B32 : SALUUnaryWithSCCOp<"s_ff1_i32_b32">; +def WaveASM_S_FF1_I32_B64 : SALUUnaryWithSCCOp<"s_ff1_i32_b64">; +def WaveASM_S_FLBIT_I32_B32 : SALUUnaryWithSCCOp<"s_flbit_i32_b32">; +def WaveASM_S_FLBIT_I32_B64 : SALUUnaryWithSCCOp<"s_flbit_i32_b64">; +def WaveASM_S_ABS_I32 : SALUUnaryWithSCCOp<"s_abs_i32">; + //===----------------------------------------------------------------------===// // SALU Binary Instructions //===----------------------------------------------------------------------===// @@ -728,48 +768,58 @@ def WaveASM_S_ADDC_U32 : SALUBinaryWithCarryOp<"s_addc_u32">; def WaveASM_S_ADD_I32 : SALUBinaryWithCarryOp<"s_add_i32">; def WaveASM_S_SUB_U32 : SALUBinaryWithCarryOp<"s_sub_u32">; def WaveASM_S_SUB_I32 : SALUBinaryWithCarryOp<"s_sub_i32">; +// Conditional select: dst = SCC ? src0 : src1 (reads SCC, no SCC write). +// NOT Pure/ArithmeticOp — result depends on implicit SCC state, so NOT +// CSE-eligible (two s_cselect_b32 with same operands can give different +// results if SCC differs). NoMemoryEffect allows DCE of unused results. 
+def WaveASM_S_CSELECT_B32 : WAVEASMOp<"s_cselect_b32", + [WaveASM_SCCUse, NoMemoryEffect]> { + let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); + let results = (outs WaveASM_AnySGPR:$dst); + let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst)"; +} // Multiplication does not set SCC def WaveASM_S_MUL_I32 : SALUBinaryOp<"s_mul_i32">; def WaveASM_S_MUL_HI_U32 : SALUBinaryOp<"s_mul_hi_u32">; def WaveASM_S_MUL_HI_I32 : SALUBinaryOp<"s_mul_hi_i32">; -// Bitwise -def WaveASM_S_AND_B32 : SALUBinaryOp<"s_and_b32">; -def WaveASM_S_AND_B64 : SALUBinaryOp<"s_and_b64">; -def WaveASM_S_OR_B32 : SALUBinaryOp<"s_or_b32">; -def WaveASM_S_OR_B64 : SALUBinaryOp<"s_or_b64">; -def WaveASM_S_XOR_B32 : SALUBinaryOp<"s_xor_b32">; -def WaveASM_S_XOR_B64 : SALUBinaryOp<"s_xor_b64">; -def WaveASM_S_ANDN2_B32 : SALUBinaryOp<"s_andn2_b32">; -def WaveASM_S_ANDN2_B64 : SALUBinaryOp<"s_andn2_b64">; -def WaveASM_S_ORN2_B32 : SALUBinaryOp<"s_orn2_b32">; -def WaveASM_S_ORN2_B64 : SALUBinaryOp<"s_orn2_b64">; -def WaveASM_S_NAND_B32 : SALUBinaryOp<"s_nand_b32">; -def WaveASM_S_NAND_B64 : SALUBinaryOp<"s_nand_b64">; -def WaveASM_S_NOR_B32 : SALUBinaryOp<"s_nor_b32">; -def WaveASM_S_NOR_B64 : SALUBinaryOp<"s_nor_b64">; -def WaveASM_S_XNOR_B32 : SALUBinaryOp<"s_xnor_b32">; -def WaveASM_S_XNOR_B64 : SALUBinaryOp<"s_xnor_b64">; +// Bitwise — all clobber SCC (SCC = result != 0) +def WaveASM_S_AND_B32 : SALUBinaryWithSCCOp<"s_and_b32">; +def WaveASM_S_AND_B64 : SALUBinaryWithSCCOp<"s_and_b64">; +def WaveASM_S_OR_B32 : SALUBinaryWithSCCOp<"s_or_b32">; +def WaveASM_S_OR_B64 : SALUBinaryWithSCCOp<"s_or_b64">; +def WaveASM_S_XOR_B32 : SALUBinaryWithSCCOp<"s_xor_b32">; +def WaveASM_S_XOR_B64 : SALUBinaryWithSCCOp<"s_xor_b64">; +def WaveASM_S_ANDN2_B32 : SALUBinaryWithSCCOp<"s_andn2_b32">; +def WaveASM_S_ANDN2_B64 : SALUBinaryWithSCCOp<"s_andn2_b64">; +def WaveASM_S_ORN2_B32 : SALUBinaryWithSCCOp<"s_orn2_b32">; +def WaveASM_S_ORN2_B64 : 
SALUBinaryWithSCCOp<"s_orn2_b64">; +def WaveASM_S_NAND_B32 : SALUBinaryWithSCCOp<"s_nand_b32">; +def WaveASM_S_NAND_B64 : SALUBinaryWithSCCOp<"s_nand_b64">; +def WaveASM_S_NOR_B32 : SALUBinaryWithSCCOp<"s_nor_b32">; +def WaveASM_S_NOR_B64 : SALUBinaryWithSCCOp<"s_nor_b64">; +def WaveASM_S_XNOR_B32 : SALUBinaryWithSCCOp<"s_xnor_b32">; +def WaveASM_S_XNOR_B64 : SALUBinaryWithSCCOp<"s_xnor_b64">; // Shifts -def WaveASM_S_LSHL_B32 : SALUBinaryOp<"s_lshl_b32">; -def WaveASM_S_LSHL_B64 : SALUBinaryOp<"s_lshl_b64">; -def WaveASM_S_LSHR_B32 : SALUBinaryOp<"s_lshr_b32">; -def WaveASM_S_LSHR_B64 : SALUBinaryOp<"s_lshr_b64">; -def WaveASM_S_ASHR_I32 : SALUBinaryOp<"s_ashr_i32">; -def WaveASM_S_ASHR_I64 : SALUBinaryOp<"s_ashr_i64">; +def WaveASM_S_LSHL_B32 : SALUBinaryWithSCCOp<"s_lshl_b32">; +def WaveASM_S_LSHL_B64 : SALUBinaryWithSCCOp<"s_lshl_b64">; +def WaveASM_S_LSHR_B32 : SALUBinaryWithSCCOp<"s_lshr_b32">; +def WaveASM_S_LSHR_B64 : SALUBinaryWithSCCOp<"s_lshr_b64">; +def WaveASM_S_ASHR_I32 : SALUBinaryWithSCCOp<"s_ashr_i32">; +def WaveASM_S_ASHR_I64 : SALUBinaryWithSCCOp<"s_ashr_i64">; // Min/Max -def WaveASM_S_MIN_I32 : SALUBinaryOp<"s_min_i32">; -def WaveASM_S_MIN_U32 : SALUBinaryOp<"s_min_u32">; -def WaveASM_S_MAX_I32 : SALUBinaryOp<"s_max_i32">; -def WaveASM_S_MAX_U32 : SALUBinaryOp<"s_max_u32">; +def WaveASM_S_MIN_I32 : SALUBinaryWithSCCOp<"s_min_i32">; +def WaveASM_S_MIN_U32 : SALUBinaryWithSCCOp<"s_min_u32">; +def WaveASM_S_MAX_I32 : SALUBinaryWithSCCOp<"s_max_i32">; +def WaveASM_S_MAX_U32 : SALUBinaryWithSCCOp<"s_max_u32">; // Misc -def WaveASM_S_BFE_U32 : SALUBinaryOp<"s_bfe_u32">; -def WaveASM_S_BFE_I32 : SALUBinaryOp<"s_bfe_i32">; -def WaveASM_S_BFE_U64 : SALUBinaryOp<"s_bfe_u64">; -def WaveASM_S_BFE_I64 : SALUBinaryOp<"s_bfe_i64">; +def WaveASM_S_BFE_U32 : SALUBinaryWithSCCOp<"s_bfe_u32">; +def WaveASM_S_BFE_I32 : SALUBinaryWithSCCOp<"s_bfe_i32">; +def WaveASM_S_BFE_U64 : SALUBinaryWithSCCOp<"s_bfe_u64">; +def WaveASM_S_BFE_I64 : 
SALUBinaryWithSCCOp<"s_bfe_i64">; def WaveASM_S_BFM_B32 : SALUBinaryOp<"s_bfm_b32">; def WaveASM_S_BFM_B64 : SALUBinaryOp<"s_bfm_b64">; def WaveASM_S_PACK_LL_B32_B16 : SALUBinaryOp<"s_pack_ll_b32_b16">; diff --git a/waveasm/include/waveasm/Transforms/Passes.td b/waveasm/include/waveasm/Transforms/Passes.td index 0bea664cec..674bd21193 100644 --- a/waveasm/include/waveasm/Transforms/Passes.td +++ b/waveasm/include/waveasm/Transforms/Passes.td @@ -38,8 +38,8 @@ def WAVEASMLinearScan : Pass<"waveasm-linear-scan"> { let options = [ Option<"maxVGPRs", "max-vgprs", "int64_t", "256", "Maximum VGPRs available">, - Option<"maxSGPRs", "max-sgprs", "int64_t", "104", - "Maximum SGPRs available">, + Option<"maxSGPRs", "max-sgprs", "int64_t", "102", + "Maximum user-addressable SGPRs (s0-s101 on GFX9+)">, Option<"maxAGPRs", "max-agprs", "int64_t", "256", "Maximum AGPRs available"> ]; @@ -246,6 +246,36 @@ def WAVEASMExtractScalarization ]; } +//===----------------------------------------------------------------------===// +// SCC Verifier Pass +//===----------------------------------------------------------------------===// + +def WAVEASMSCCVerifier : Pass<"waveasm-scc-verifier"> { + let summary = "Verify SCC (Scalar Condition Code) liveness invariants"; + let description = [{ + Checks that no SCC-clobbering SALU instruction sits between an + SCC-producing op and its consumer. Uses isa<> checks (not traits) + to avoid changing ODS-generated code for existing op classes. 
+ }]; + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + +//===----------------------------------------------------------------------===// +// VGPR Compaction Pass +//===----------------------------------------------------------------------===// + +def WAVEASMVGPRCompaction : Pass<"waveasm-vgpr-compaction"> { + let summary = "Compact physical VGPR assignments to reduce peak register count"; + let description = [{ + Re-assigns physical VGPRs after linear scan allocation to eliminate + fragmentation caused by interleaved buffer_load and ds_read destinations. + Uses a shortest-first greedy strategy: short-lived values get low + registers, long-lived values get high registers, eliminating gaps + from interleaving. + }]; + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + //===----------------------------------------------------------------------===// // Memory Offset Optimization Pass //===----------------------------------------------------------------------===// diff --git a/waveasm/include/waveasm/Transforms/RegAlloc.h b/waveasm/include/waveasm/Transforms/RegAlloc.h index 523fec57ae..d3d5dee371 100644 --- a/waveasm/include/waveasm/Transforms/RegAlloc.h +++ b/waveasm/include/waveasm/Transforms/RegAlloc.h @@ -11,6 +11,7 @@ #include "waveasm/Dialect/WaveASMOps.h" #include "waveasm/Dialect/WaveASMTypes.h" #include "waveasm/Transforms/Liveness.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" @@ -66,126 +67,148 @@ class PhysicalMapping { // Register Pool //===----------------------------------------------------------------------===// -/// Pool of available physical registers +/// Pool of available physical registers. +/// Uses a BitVector for O(1) per-register operations and cache-friendly +/// scanning. Typical GPU register files (256 VGPRs, 102 SGPRs, 256 AGPRs) +/// fit in a few 64-bit words.
class RegPool { public: RegPool(RegClass regClass, int64_t maxRegs, const llvm::DenseSet &reserved) - : regClass(regClass), maxRegs(maxRegs) { - // Initialize free list with all non-reserved registers - for (int64_t i = 0; i < maxRegs; ++i) { - if (!reserved.contains(i)) { - freeList.push_back(i); - } + : regClass(regClass), maxRegs(maxRegs), free(maxRegs, true) { + for (int64_t r : reserved) { + if (r >= 0 && r < maxRegs) + free.reset(r); } } - /// Check if a register is currently in the free list + /// O(1) free-check via bit test. bool isFree(int64_t reg) const { - return std::find(freeList.begin(), freeList.end(), reg) != freeList.end(); + return reg >= 0 && reg < maxRegs && free.test(reg); } - /// Reserve a specific register (for precoloring) + /// Reserve a specific register range (for precoloring / re-reservation). void reserve(int64_t reg, int64_t size) { for (int64_t i = 0; i < size; ++i) { int64_t r = reg + i; - auto it = std::find(freeList.begin(), freeList.end(), r); - if (it != freeList.end()) { - freeList.erase(it); - allocated.insert(r); + if (r < maxRegs && free.test(r)) { + free.reset(r); + ++currentUsage; } } updatePeak(); } - /// Allocate a single register - /// Returns -1 if allocation fails + /// Allocate the lowest-numbered free register. + /// Returns -1 if no register is available. int64_t allocSingle() { - if (freeList.empty()) + int idx = free.find_first(); + if (idx < 0) return -1; - - int64_t reg = freeList.front(); - freeList.erase(freeList.begin()); - allocated.insert(reg); + free.reset(idx); + ++currentUsage; updatePeak(); - return reg; + return idx; } - /// Allocate a contiguous range of registers with alignment - /// Returns base register index, or -1 if allocation fails + /// Allocate a contiguous range of registers with the given alignment. + /// Scans by alignment stride for O(maxRegs/alignment) candidate checks. + /// Returns base register index, or -1 if allocation fails. 
int64_t allocRange(int64_t size, int64_t alignment) { if (size <= 0) return -1; - llvm::DenseSet freeSet(freeList.begin(), freeList.end()); + for (int64_t c = 0; c + size <= maxRegs; c += alignment) { + bool allFree = true; + for (int64_t o = 0; o < size; ++o) { + if (!free.test(c + o)) { + allFree = false; + break; + } + } + if (allFree) { + for (int64_t o = 0; o < size; ++o) + free.reset(c + o); + currentUsage += size; + updatePeak(); + return c; + } + } + return -1; + } - for (int64_t candidate : freeList) { - // Check alignment - if (candidate % alignment != 0) - continue; + /// Allocate the highest-numbered free register below ceiling. + int64_t allocSingleFromTop(int64_t ceiling = -1) { + int64_t cap = (ceiling > 0 && ceiling <= maxRegs) ? ceiling : maxRegs; + // Scan from cap-1 downward for the first free register. + for (int64_t i = cap - 1; i >= 0; --i) { + if (free.test(i)) { + free.reset(i); + ++currentUsage; + updatePeak(); + return i; + } + } + return -1; + } + + /// Allocate a contiguous range from the top of a capped region. + /// Scans from `ceiling` downward to pack long-lived values at the top + /// of the USED range, not the top of the entire register file. + int64_t allocRangeFromTop(int64_t size, int64_t alignment, + int64_t ceiling = -1) { + if (size <= 0) + return -1; - // Check if all registers in range are free + int64_t cap = (ceiling > 0 && ceiling <= maxRegs) ? 
ceiling : maxRegs; + int64_t highestBase = ((cap - size) / alignment) * alignment; + for (int64_t c = highestBase; c >= 0; c -= alignment) { bool allFree = true; - for (int64_t offset = 0; offset < size; ++offset) { - int64_t reg = candidate + offset; - if (reg >= maxRegs || !freeSet.contains(reg)) { + for (int64_t o = 0; o < size; ++o) { + if (!free.test(c + o)) { allFree = false; break; } } - if (allFree) { - // Allocate the range - for (int64_t offset = 0; offset < size; ++offset) { - int64_t reg = candidate + offset; - freeList.erase(std::find(freeList.begin(), freeList.end(), reg)); - allocated.insert(reg); - } + for (int64_t o = 0; o < size; ++o) + free.reset(c + o); + currentUsage += size; updatePeak(); - return candidate; + return c; } } - - return -1; // Allocation failed + return -1; } - /// Free a single register + /// Free a single register back to the pool. void freeSingle(int64_t reg) { - if (!allocated.contains(reg)) + if (reg < 0 || reg >= maxRegs || free.test(reg)) return; - - allocated.erase(reg); - - // Insert in sorted order - auto it = std::lower_bound(freeList.begin(), freeList.end(), reg); - freeList.insert(it, reg); + free.set(reg); + --currentUsage; } - /// Free a range of registers + /// Free a contiguous range of registers. 
void freeRange(int64_t base, int64_t size) { for (int64_t offset = 0; offset < size; ++offset) { freeSingle(base + offset); } } - /// Get peak usage int64_t getPeakUsage() const { return peak; } - /// Get current usage - int64_t getCurrentUsage() const { return allocated.size(); } + int64_t getCurrentUsage() const { return currentUsage; } - /// Get register class RegClass getRegClass() const { return regClass; } private: - void updatePeak() { - peak = std::max(peak, static_cast(allocated.size())); - } + void updatePeak() { peak = std::max(peak, currentUsage); } RegClass regClass; int64_t maxRegs; - llvm::SmallVector freeList; // Sorted - llvm::DenseSet allocated; + llvm::BitVector free; + int64_t currentUsage = 0; int64_t peak = 0; }; diff --git a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h index 8f3902fca1..2fc46f39cd 100644 --- a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h +++ b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h @@ -389,8 +389,7 @@ class TranslationContext { /// Get next available swizzle SRD index (for cache swizzle SRDs and /// per-workgroup SRD base adjustments). /// These are allocated after all regular SRDs, computed in emitSRDPrologue(). - /// Each allocation reserves 5 SGPRs: s[N..N+3] for the SRD quad plus - /// s[N+4] as a temporary for the multiply low-half result. + /// Each allocation reserves 4 SGPRs: s[N..N+3] for the SRD quad. int64_t getNextSwizzleSRDIndex() { if (nextSwizzleSRDIndex < 0) { int64_t maxSrdEnd = 24; @@ -401,9 +400,8 @@ class TranslationContext { nextSwizzleSRDIndex = (maxSrdEnd + 3) & ~3; // Align to 4 } int64_t idx = nextSwizzleSRDIndex; - // 4 SRD SGPRs + 1 temp for byteOffLo, padded to next 4-aligned index - // (SRD buffer descriptors require 4-SGPR alignment on AMDGCN). - nextSwizzleSRDIndex = (idx + 5 + 3) & ~3; + // SRD buffer descriptors require 4-SGPR alignment on AMDGCN. 
+ nextSwizzleSRDIndex = idx + 4; return idx; } diff --git a/waveasm/lib/Transforms/AssemblyEmitter.cpp b/waveasm/lib/Transforms/AssemblyEmitter.cpp index 860f4c1273..f6ce8afb8b 100644 --- a/waveasm/lib/Transforms/AssemblyEmitter.cpp +++ b/waveasm/lib/Transforms/AssemblyEmitter.cpp @@ -166,9 +166,16 @@ std::string KernelGenerator::emitBufferStore(Operation *op, // Use VMEMStoreOpInterface to access operands by name if (auto storeOp = dyn_cast(op)) { std::string vdata = resolveValue(storeOp.getData()); - std::string voffset = resolveValue(storeOp.getVoffset()); + Value voffsetVal = storeOp.getVoffset(); std::string srd = resolveValue(storeOp.getSaddr()); - result += " " + vdata + ", " + voffset + ", " + srd + ", 0 offen"; + + if (isa(voffsetVal.getType())) { + result += " " + vdata + ", off, " + srd + ", 0"; + } else { + std::string voffset = resolveValue(voffsetVal); + result += " " + vdata + ", " + voffset + ", " + srd + ", 0 offen"; + } + if (auto instOffsetAttr = op->getAttrOfType("instOffset")) { int64_t offset = instOffsetAttr.getInt(); if (offset > 0) { @@ -533,16 +540,18 @@ std::optional KernelGenerator::generateOp(Operation *op) { std::string src = resolveValue(srcVal); std::string lines; if (isAGPR) { - // v_accvgpr_write_b32 requires a VGPR source in this backend. - // Materialize immediate sources into the reserved scratch VGPR. + auto [isLit, litVal] = getLiteralValue(srcVal); std::string writeSrc = src; - if (srcIsImm) { + if (srcIsImm && !(isLit && isInlineConstant(litVal))) { + // Non-inline literal: materialize into scratch VGPR first. lines += " v_mov_b32 " + formatVGPRRange(kScratchVGPR, 1) + ", " + src; writeSrc = formatVGPRRange(kScratchVGPR, 1); peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 1); invalidateScratchCache(); } + // Inline constants (e.g. 0) go directly into + // v_accvgpr_write_b32 without a scratch VGPR. 
for (int64_t i = 0; i < size; ++i) { if (!lines.empty()) lines += "\n"; @@ -563,6 +572,12 @@ std::optional KernelGenerator::generateOp(Operation *op) { if (isAGPR) { if (srcIsImm) { std::string src = resolveValue(srcVal); + auto [isLit, litVal] = getLiteralValue(srcVal); + if (isLit && isInlineConstant(litVal)) { + // Inline constant: emit directly without scratch VGPR. + return " v_accvgpr_write_b32 " + resolveValue(result) + ", " + + src; + } std::string scratch = formatVGPRRange(kScratchVGPR, 1); peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 1); invalidateScratchCache(); @@ -695,10 +710,7 @@ std::optional KernelGenerator::generateOp(Operation *op) { SmallVector handled(pendingCopies.size(), false); - // Allocate swap temps once and reuse across all swaps. - // Swaps are emitted sequentially so the temp is dead after - // each 3-instruction sequence and can be reused. - int64_t vSwapTemp = -1; + // SGPR swaps still need a temp; VGPR swaps use v_swap_b32. int64_t sSwapTemp = -1; for (size_t i = 0; i < pendingCopies.size(); ++i) { @@ -728,13 +740,7 @@ std::optional KernelGenerator::generateOp(Operation *op) { } int64_t regA = pendingCopies[i].dst; int64_t regB = pendingCopies[j].dst; - if (vSwapTemp < 0) { - vSwapTemp = peakVGPRs; - peakVGPRs = std::max(peakVGPRs, vSwapTemp + 1); - } - os << " v_mov_b32 v" << vSwapTemp << ", v" << regA << "\n"; - os << " v_mov_b32 v" << regA << ", v" << regB << "\n"; - os << " v_mov_b32 v" << regB << ", v" << vSwapTemp << "\n"; + os << " v_swap_b32 v" << regA << ", v" << regB << "\n"; handled[i] = true; handled[j] = true; break; @@ -842,14 +848,22 @@ std::optional KernelGenerator::generateOp(Operation *op) { .Case( [&](YieldOp) -> std::optional { return std::nullopt; }) .Case( + S_CMP_GE_U32, S_CMP_NE_U32, S_CMP_LT_I32, S_CMP_EQ_I32, + S_CMP_LE_I32, S_CMP_GT_I32, S_CMP_GE_I32, S_CMP_NE_I32>( [&](auto cmpOp) -> std::optional { llvm::StringRef opName = cmpOp->getName().getStringRef(); llvm::StringRef mnemonic = opName; if 
(opName.starts_with("waveasm.")) { mnemonic = opName.drop_front(8); } + // s_cmp_ne_* → s_cmp_lg_* (ISA mnemonic) + std::string mnemStr; + if (mnemonic.contains("_ne_")) { + mnemStr = mnemonic.str(); + size_t pos = mnemStr.find("_ne_"); + mnemStr.replace(pos, 4, "_lg_"); + mnemonic = mnemStr; + } llvm::SmallVector operands; for (Value operand : cmpOp->getOperands()) { operands.push_back(resolveValue(operand)); diff --git a/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp b/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp index e2672df944..185fc93b76 100644 --- a/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp +++ b/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp @@ -616,6 +616,53 @@ static void applyStrengthReduction(LoopOp loopOp) { loopOp.erase(); } +// Peephole: when a buffer_load has voffset = V_ADD_U32(vgpr, sgpr) and +// soffset = 0, fold the SGPR addend into soffset. This avoids a VALU +// instruction per load by using the hardware scalar offset field. +static void peepholeSoffsetFold(Operation *root) { + root->walk([&](Operation *op) { + if (!isBufferLoad(op) && !isBufferLoadLDS(op)) + return; + if (op->getNumOperands() < 3) + return; + + unsigned soffsetIdx = 2; + auto soffsetConst = getConstantValue(op->getOperand(soffsetIdx)); + if (!soffsetConst || *soffsetConst != 0) + return; + + unsigned voffsetIdx = getVoffsetIdx(op); + Value voffset = op->getOperand(voffsetIdx); + auto addOp = voffset.getDefiningOp(); + if (!addOp) + return; + + Value src0 = addOp.getSrc0(); + Value src1 = addOp.getSrc1(); + Value vgprPart = nullptr; + Value sgprPart = nullptr; + + if (isVGPRType(src0.getType()) && isSGPRType(src1.getType())) { + vgprPart = src0; + sgprPart = src1; + } else if (isSGPRType(src0.getType()) && isVGPRType(src1.getType())) { + vgprPart = src1; + sgprPart = src0; + } + if (!vgprPart || !sgprPart) + return; + + // Only fold if the V_ADD_U32 is used exclusively by buffer_loads. 
+ for (Operation *user : addOp->getUsers()) { + if (!isBufferLoad(user) && !isBufferLoadLDS(user)) + return; + } + + op->setOperand(voffsetIdx, vgprPart); + op->setOperand(soffsetIdx, sgprPart); + }); +} + struct BufferLoadStrengthReductionPass : public waveasm::impl::WAVEASMBufferLoadStrengthReductionBase< BufferLoadStrengthReductionPass> { @@ -628,6 +675,7 @@ struct BufferLoadStrengthReductionPass module->walk([&](LoopOp loopOp) { loops.push_back(loopOp); }); for (auto loopOp : loops) applyStrengthReduction(loopOp); + peepholeSoffsetFold(module); } }; diff --git a/waveasm/lib/Transforms/CMakeLists.txt b/waveasm/lib/Transforms/CMakeLists.txt index 20919ac70b..f1a1bdae5f 100644 --- a/waveasm/lib/Transforms/CMakeLists.txt +++ b/waveasm/lib/Transforms/CMakeLists.txt @@ -14,11 +14,9 @@ foreach(src ${HANDLERS_SRCS}) endforeach() add_mlir_dialect_library(MLIRWaveASMTransforms - ArithLegalization.cpp AssemblyEmitter.cpp BufferLoadStrengthReduction.cpp ExtractScalarization.cpp - GPUModuleToBinary.cpp HazardMitigation.cpp LinearScanPass.cpp LinearScanRegAlloc.cpp @@ -31,9 +29,10 @@ add_mlir_dialect_library(MLIRWaveASMTransforms Peephole.cpp RegionBuilder.cpp ScalePackElimination.cpp + SCCVerifier.cpp ScopedCSE.cpp Ticketing.cpp - TranslateFromLLVMDialect.cpp + VGPRCompaction.cpp TranslateFromMLIR.cpp ${HANDLERS_FULL_PATHS} @@ -50,12 +49,8 @@ add_mlir_dialect_library(MLIRWaveASMTransforms MLIRFuncDialect MLIRGPUDialect MLIRIR - MLIRLLVMDialect MLIRMathDialect MLIRMemRefDialect - MLIRPass - MLIRROCDLDialect - MLIRROCDLTarget MLIRSCFDialect MLIRSupport MLIRVectorDialect diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp index fef599358d..2012a96d3e 100644 --- a/waveasm/lib/Transforms/LinearScanPass.cpp +++ b/waveasm/lib/Transforms/LinearScanPass.cpp @@ -56,6 +56,131 @@ static Type makePhysicalType(MLIRContext *ctx, Type virtualType, return virtualType; } 
+//===----------------------------------------------------------------------===// +// Rematerialization +//===----------------------------------------------------------------------===// + +/// Return true when \p op sits inside a LoopOp body at any nesting depth. +static bool isInsideLoopBody(Operation *op) { + for (Operation *p = op->getParentOp(); p; p = p->getParentOp()) { + if (isa(p)) + return true; + } + return false; +} + +/// Check if an operation can be cheaply rematerialized (cloned near each +/// use to shorten its live range). +/// +/// Accepted patterns: +/// - V_MOV_B32 with immediate operands (accumulator zero-init) +/// - V_MOV_B32 with SGPR source (scalar-to-vector address copy) +/// - S_MOV_B32 with immediate operands (scalar constant materialisation) +static bool isRematerializableOp(Operation *op) { + if (op->getNumResults() != 1) + return false; + + if (isa(op)) { + for (Value operand : op->getOperands()) { + auto *defOp = operand.getDefiningOp(); + if (!defOp) + return false; + if (isa(defOp)) + continue; + // SGPR source: the scalar value dominates the original def, so it + // also dominates every use site where we might place a clone. + if (isSGPRType(operand.getType())) + continue; + return false; + } + return true; + } + + if (isa(op)) { + for (Value operand : op->getOperands()) { + auto *defOp = operand.getDefiningOp(); + if (!defOp || !isa(defOp)) + return false; + } + return true; + } + + return false; +} + +/// Rematerialize cheap-to-compute VGPR values by cloning their defining ops +/// near each use site, shortening live ranges and reducing peak register +/// pressure at the cost of slightly increased code size. +/// +/// A `v_mov_b32 %v, 0` defined at instruction 5 and used at instruction 100 +/// holds a VGPR for 95 instructions. After rematerialization the clone +/// appears at instruction 99 with a 1-instruction live range, freeing the +/// VGPR(s) for the other 94 instructions. For multi-register results +/// (e.g. 
4-wide accumulators) this frees 4 VGPRs across that span. +static void rematerializeCheapOps(ProgramOp program) { + constexpr int64_t kMinRematRangeLength = 10; + + LivenessInfo liveness = computeLiveness(program); + + llvm::SmallVector candidates; + program.walk([&](Operation *op) { + if (!isRematerializableOp(op)) + return; + Value result = op->getResult(0); + // Accept both VGPR and SGPR results (V_MOV_B32 -> VGPR, S_MOV_B32 -> SGPR). + if (!isVGPRType(result.getType()) && !isSGPRType(result.getType())) + return; + const LiveRange *range = liveness.getRange(result); + if (!range || range->length() <= kMinRematRangeLength) + return; + candidates.push_back(op); + }); + + for (Operation *op : candidates) { + Value result = op->getResult(0); + bool defOutsideLoop = !isInsideLoopBody(op); + bool isVGPR = isVGPRType(result.getType()); + + llvm::SmallVector uses; + for (OpOperand &use : result.getUses()) + uses.push_back(&use); + + if (uses.empty()) + continue; + + // Clone once per unique user operation to avoid redundant copies when + // the same op references the value in multiple operand slots. + llvm::DenseMap cloneCache; + + for (OpOperand *use : uses) { + Operation *user = use->getOwner(); + + // Preserve VALU-free loop bodies: when a VGPR-producing op is defined + // outside a loop but a use is inside, skip that use. The value's live + // range already spans the entire loop body (Pass 2b extends it), so + // cloning inside wouldn't reduce in-loop pressure — it would only add + // a VALU instruction to the critical loop path. + // SALU ops (S_MOV_B32) don't use the VALU pipeline, so they're fine. 
+ if (defOutsideLoop && isVGPR && isInsideLoopBody(user)) + continue; + + auto it = cloneCache.find(user); + if (it != cloneCache.end()) { + use->set(it->second); + } else { + OpBuilder builder(user); + Operation *clone = builder.clone(*op); + Value cloned = clone->getResult(0); + use->set(cloned); + cloneCache[user] = cloned; + } + } + + if (result.use_empty()) + op->erase(); + } +} + namespace { //===----------------------------------------------------------------------===// @@ -215,6 +340,12 @@ struct LinearScanPass } }); + // Rematerialize cheap VGPR ops (v_mov_b32 from immediates) near their + // use sites to shorten live ranges and reduce peak register pressure. + // This must run after duplicate-init-arg handling (which creates new + // V_MOV_B32 ops) and before allocation (which consumes the IR). + rematerializeCheapOps(program); + // Create allocator with precolored values and tied operands. // MFMA ties come from the local tiedPairs map; loop ties come from // the TiedValueClasses built during liveness analysis (see below). diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp index 88e38a98bd..429c6cf3b6 100644 --- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp +++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp @@ -70,15 +70,25 @@ static void insertActiveRange(llvm::SmallVectorImpl &active, active.insert(insertPos, newRange); } -/// Try to allocate a physical register from the pool. -/// Returns the allocated register index, or std::nullopt on failure. +/// Try to allocate a physical register from the pool (lowest-first). static std::optional tryAllocate(RegPool &pool, int64_t size, int64_t alignment) { int64_t physReg = (size == 1) ? pool.allocSingle() : pool.allocRange(size, alignment); - if (physReg < 0) { + if (physReg < 0) + return std::nullopt; + return physReg; +} + +/// Try to allocate from the top of a capped region (highest-first). 
+static std::optional tryAllocateFromTop(RegPool &pool, int64_t size, + int64_t alignment, + int64_t ceiling) { + int64_t physReg = (size == 1) + ? pool.allocSingleFromTop(ceiling) + : pool.allocRangeFromTop(size, alignment, ceiling); + if (physReg < 0) return std::nullopt; - } return physReg; } @@ -167,8 +177,31 @@ allocateRegClass(ArrayRef ranges, RegPool &pool, } } - // Allocate physical register(s) from the pool - physReg = tryAllocate(pool, range.size, range.alignment); + // For VGPRs with size > 1 (dwordx2/x4), allocate long-lived ranges + // from the top of the expected register usage and short-lived ranges + // from the bottom. This separates interleaved buffer_load (long-lived + // prefetch) and ds_read (short-lived consumed) destinations into + // contiguous regions, reducing fragmentation and peak VGPR count. + // The ceiling is maxPressure (peak simultaneous VGPRs from liveness), + // not maxRegs (512), to avoid allocating into the AGPR region. + bool useBidirectional = + pool.getRegClass() == RegClass::VGPR && range.size > 1; + if (useBidirectional) { + int64_t rangeLength = range.end - range.start; + // Ranges longer than 3/4 of ranges.back().end get allocated from the top. + // This targets buffer_load prefetch values (which span almost the + // entire loop body) while leaving ds_read values (consumed within + // one half) at the bottom. 
+ int64_t maxEnd = ranges.back().end; + int64_t threshold = (maxEnd * 3) / 4; + if (rangeLength > threshold) + physReg = tryAllocateFromTop(pool, range.size, range.alignment, + maxPressure); + else + physReg = tryAllocate(pool, range.size, range.alignment); + } + if (!physReg) + physReg = tryAllocate(pool, range.size, range.alignment); if (!physReg) { return program.emitOpError() diff --git a/waveasm/lib/Transforms/LiteralMaterialization.cpp b/waveasm/lib/Transforms/LiteralMaterialization.cpp index 8adea968cf..c090272a23 100644 --- a/waveasm/lib/Transforms/LiteralMaterialization.cpp +++ b/waveasm/lib/Transforms/LiteralMaterialization.cpp @@ -33,11 +33,12 @@ static bool isInlineConstant(int64_t val) { } static const llvm::StringSet<> &getVOP2Instructions() { + // v_add_u32, v_sub_u32, v_subrev_u32 are VOP3-only on GFX9+ (the VOP2 + // carry-producing variants are v_add_co_u32 / v_sub_co_u32 / + // v_subrev_co_u32). VOP3 does not support literal operands. static const llvm::StringSet<> kVOP2 = { - "v_add_u32", "v_sub_u32", "v_subrev_u32", "v_and_b32", - "v_or_b32", "v_xor_b32", "v_lshlrev_b32", "v_lshrrev_b32", - "v_ashrrev_i32", "v_max_u32", "v_min_u32", "v_add_i32", - "v_sub_i32", + "v_and_b32", "v_or_b32", "v_xor_b32", "v_lshlrev_b32", + "v_lshrrev_b32", "v_ashrrev_i32", "v_max_u32", "v_min_u32", }; return kVOP2; } diff --git a/waveasm/lib/Transforms/Liveness.cpp b/waveasm/lib/Transforms/Liveness.cpp index b4c3e987f8..f404c34693 100644 --- a/waveasm/lib/Transforms/Liveness.cpp +++ b/waveasm/lib/Transforms/Liveness.cpp @@ -31,44 +31,36 @@ static bool isSwapPatternIterArg(Value iterArg, Block &bodyBlock, unsigned i) { return false; } -/// Detect a write-after-read (WAR) hazard between a buffer_load iter_arg -/// and the block_arg it feeds back into. +/// Detect a write-after-read (WAR) hazard between an iter_arg and the +/// block_arg it feeds back into. 
/// -/// In pipelined schedules, next-iteration loads can be interleaved with -/// MFMAs that still consume the current iteration's block_arg. If the -/// allocator ties them to the same register, the MFMA silently reads the -/// new load value instead of the old one. +/// If the iter_arg is defined at a point where the block_arg still has +/// subsequent uses, tying them to the same physical register would let +/// the in-place update clobber a value that later instructions still need. /// -/// For single-element loads (buffer_load_ubyte/sbyte/ushort/sshort), the -/// block_arg is consumed indirectly through vector.bitcast / vector.extract -/// that share the same physical register. The direct use-point check -/// misses these transitive uses, so we unconditionally flag them. -static bool hasBufferLoadWARHazard(Value iterArg, Value blockArg, - const LivenessInfo &info) { +/// This is critical for unrolled loops where CSE can merge an +/// affine.apply result (e.g. arg8+2 for a bounds check) with the IV +/// increment (also arg8+2 for step=2). After merging, the single +/// S_ADD_U32 is placed in the middle of the body; if it shares a +/// register with the IV block_arg, the second unrolled copy's use of +/// the original IV reads the wrong value. +static bool hasWARHazard(Value iterArg, Value blockArg, + const LivenessInfo &info) { if (isa(iterArg)) return false; auto *defOp = iterArg.getDefiningOp(); if (!defOp) return false; - auto opName = defOp->getName().getStringRef(); - if (!opName.contains("buffer_load") || opName.contains("_lds")) - return false; - - // Single-element loads: unconditionally untie (transitive uses hidden). - if (opName.contains("_ubyte") || opName.contains("_sbyte") || - opName.contains("_ushort") || opName.contains("_sshort")) { - LLVM_DEBUG(llvm::dbgs() - << " WAR hazard (single-element load): " << opName << "\n"); - return true; - } - // Multi-register loads: check for def/use overlap. 
+ // Check for def/use overlap: if the iter_arg is defined at a point + // where block_arg still has subsequent uses, tying them creates a + // WAR hazard. auto iterDefIt = info.defPoints.find(iterArg); auto baUseIt = info.usePoints.find(blockArg); if (iterDefIt != info.defPoints.end() && baUseIt != info.usePoints.end()) { - int64_t loadDef = iterDefIt->second; + int64_t iterDef = iterDefIt->second; for (int64_t usePoint : baUseIt->second) { - if (usePoint >= loadDef) + if (usePoint > iterDef) return true; } } @@ -554,11 +546,11 @@ LivenessInfo computeLiveness(ProgramOp program) { // Condition iter_arg -> block arg. // Skip swap patterns and WAR hazards so the allocator keeps them - // in separate registers (see hasBufferLoadWARHazard). + // in separate registers (see hasWARHazard). if (i < condOp.getIterArgs().size()) { Value iterArg = condOp.getIterArgs()[i]; bool skip = isSwapPatternIterArg(iterArg, bodyBlock, i) || - hasBufferLoadWARHazard(iterArg, blockArg, info); + hasWARHazard(iterArg, blockArg, info); if (!skip && info.ranges.contains(iterArg)) members.push_back(iterArg); } @@ -610,7 +602,7 @@ LivenessInfo computeLiveness(ProgramOp program) { if (i < condOp.getIterArgs().size()) { Value iterArg = condOp.getIterArgs()[i]; bool skip = isSwapPatternIterArg(iterArg, bodyBlock, i) || - hasBufferLoadWARHazard(iterArg, blockArg, info); + hasWARHazard(iterArg, blockArg, info); if (!skip && info.ranges.contains(iterArg) && !tc.tiedPairs.contains(iterArg)) tc.tiedPairs[iterArg] = blockArg; @@ -634,11 +626,19 @@ LivenessInfo computeLiveness(ProgramOp program) { } } - // Sort by (start, end) for linear scan + // Sort by (start, -size, -alignment, -end) for linear scan. + // At the same start point, allocate larger and more constrained ranges + // first to reduce fragmentation — a 16-wide aligned range has very few + // valid slots, so giving it priority prevents smaller ranges from + // blocking its only valid position. 
auto sortByStart = [](const LiveRange &a, const LiveRange &b) { if (a.start != b.start) return a.start < b.start; - return a.end < b.end; + if (a.size != b.size) + return a.size > b.size; + if (a.alignment != b.alignment) + return a.alignment > b.alignment; + return a.end > b.end; }; llvm::sort(info.vregRanges, sortByStart); @@ -650,6 +650,11 @@ LivenessInfo computeLiveness(ProgramOp program) { info.maxSRegPressure = computeMaxPressure(info.sregRanges, info.tiedClasses); info.maxARegPressure = computeMaxPressure(info.aregRanges, info.tiedClasses); + // Dump detailed pressure breakdown for debugging. + dumpPeakPressureInfo(info, ops, RegClass::VGPR); + dumpPeakPressureInfo(info, ops, RegClass::SGPR); + dumpPeakPressureInfo(info, ops, RegClass::AGPR); + return info; } diff --git a/waveasm/lib/Transforms/SCCVerifier.cpp b/waveasm/lib/Transforms/SCCVerifier.cpp new file mode 100644 index 0000000000..b65fc45a42 --- /dev/null +++ b/waveasm/lib/Transforms/SCCVerifier.cpp @@ -0,0 +1,138 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// SCC Verifier Pass +// +// Verifies that no SCC-clobbering SALU instruction is placed between an +// SCC-producing op and its consumer. Uses isa<> checks instead of +// hasTrait() because adding traits to existing op classes changes +// ODS-generated C++ and causes MLIR passes to produce different IR. 
+//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMInterfaces.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "waveasm-scc-verifier" + +using namespace mlir; +using namespace waveasm; + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMSCCVERIFIER +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +namespace { + +/// Returns true if the operation writes the SCC flag on hardware. +/// Uses hasTrait for ops that carry the trait (carry ops, cmp ops), +/// and isa<> for ops still on SALUBinaryOp/SALUUnaryOp (pending migration). +static bool writesSCC(Operation *op) { + return op->hasTrait(); +} + +struct SCCVerifierPass + : public waveasm::impl::WAVEASMSCCVerifierBase { + using WAVEASMSCCVerifierBase::WAVEASMSCCVerifierBase; + + void runOnOperation() override { + Operation *module = getOperation(); + unsigned errorCount = 0; + module->walk([&](ProgramOp program) { + for (Block &block : program.getBody()) + errorCount += verifyBlock(block); + }); + if (errorCount > 0) { + LLVM_DEBUG(llvm::dbgs() << "SCC verifier: found " << errorCount + << " SCC hazard(s)\n"); + signalPassFailure(); + } + } + +private: + static SmallVector findSCCClobbersBetween(Operation *producer, + Operation *consumer) { + SmallVector clobbers; + if (!producer || !consumer || producer->getBlock() != consumer->getBlock()) + return clobbers; + bool inRange = false; + for (Operation &op : *producer->getBlock()) { + if (&op == producer) { inRange = true; continue; } + if (&op == consumer) break; + if (inRange && writesSCC(&op)) + clobbers.push_back(&op); + } + return clobbers; + } + + static void emitSCCClobberError(Operation *consumer, Operation *producer, + ArrayRef clobbers) { + auto diag = 
consumer->emitError() + << "SCC hazard: " << clobbers.size() + << " SCC-clobbering op(s) between SCC producer '" + << producer->getName() << "' and consumer '" + << consumer->getName() << "'"; + for (Operation *c : clobbers) + diag.attachNote(c->getLoc()) + << "SCC clobbered here by '" << c->getName() << "'"; + diag.attachNote(producer->getLoc()) << "SCC defined here"; + } + + unsigned verifyBlock(Block &block) { + unsigned errors = 0; + Operation *lastSCCWriter = nullptr; + for (Operation &op : block) { + if (auto condOp = dyn_cast(&op)) { + Value cond = condOp.getCondition(); + Operation *condDef = cond.getDefiningOp(); + if (condDef && lastSCCWriter && lastSCCWriter != condDef) { + auto clobbers = findSCCClobbersBetween(condDef, &op); + if (!clobbers.empty()) { + emitSCCClobberError(&op, condDef, clobbers); + ++errors; + } + } + } + if (auto ifOp = dyn_cast(&op)) { + Value cond = ifOp.getCondition(); + Operation *condDef = cond.getDefiningOp(); + if (condDef && lastSCCWriter && lastSCCWriter != condDef) { + auto clobbers = findSCCClobbersBetween(condDef, &op); + if (!clobbers.empty()) { + emitSCCClobberError(&op, condDef, clobbers); + ++errors; + } + } + } + if (isa(&op) && !lastSCCWriter) { + op.emitError() + << "SCC hazard: s_cselect_b32 has no preceding SCC writer"; + ++errors; + } + if (isa(&op) && !lastSCCWriter) { + op.emitError() + << "SCC hazard: s_addc_u32 has no preceding SCC writer"; + ++errors; + } + if (writesSCC(&op)) + lastSCCWriter = &op; + for (Region &region : op.getRegions()) + for (Block &nestedBlock : region) + errors += verifyBlock(nestedBlock); + } + return errors; + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/ScopedCSE.cpp b/waveasm/lib/Transforms/ScopedCSE.cpp index 4c589060fa..0db4f14860 100644 --- a/waveasm/lib/Transforms/ScopedCSE.cpp +++ b/waveasm/lib/Transforms/ScopedCSE.cpp @@ -155,6 +155,11 @@ bool isCSEEligible(Operation *op) { if (op->hasTrait()) return false; + // SCC-reading ops are NOT CSE-eligible: their result 
depends on implicit + // SCC state, so two ops with identical operands can produce different results. + if (op->hasTrait()) + return false; + // Precolored registers are not eligible (they're fixed) if (isa(op)) return false; diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 9b7cad9608..0d98072715 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -191,6 +191,12 @@ void TranslationContext::emitSRDPrologue() { int64_t afterOverflow = afterLastSrd + numOverflowScalars * 2; nextSwizzleSRDIndex = (afterOverflow + 3) & ~3; // Align to 4 + // Dedicated SGPR slots for scalar kernel args, placed after overflow-scalar + // slots so they don't collide with SRDs or preloads. + int64_t scalarArgSgprBase = nextSwizzleSRDIndex; + nextSwizzleSRDIndex = + (nextSwizzleSRDIndex + (int64_t)pendingScalarArgs.size() + 3) & ~3; + // Emit comment for prologue CommentOp::create(builder, loc, "SRD setup prologue"); @@ -322,14 +328,22 @@ void TranslationContext::emitSRDPrologue() { ", 0x" + llvm::utohexstr(kSRDStrideSwizzle); RawOp::create(builder, loc, movStrideStr); + // Prevent DCE from removing this PrecoloredSRegOp. The SRD registers + // are later referenced by RawOps (e.g., s_mov_b64 for epilogue SRD + // copies) that don't create SSA uses. Without DCEProtect, canonicalize + // removes the PrecoloredSRegOp, and the register allocator doesn't + // reserve s[srdBase:srdBase+3], allowing it to allocate temps there. + DCEProtectOp::create(builder, loc, srdReg); + mapper.mapValue(pending.memref, srdReg); } - // Move scalar args from preload SGPRs to VGPRs. + // Copy scalar args from preload/overflow SGPRs to dedicated SGPRs. // Lower 32 bits of the preload pair hold the value (little-endian). // Overflow args were loaded into overflowSgprBase positions above. 
int64_t ovfIdx = 0; - for (const auto &pending : pendingScalarArgs) { + for (size_t i = 0; i < pendingScalarArgs.size(); ++i) { + const auto &pending = pendingScalarArgs[i]; int64_t preloadBase = 2 + pending.argIndex * 2; int64_t sgprSrc; if (preloadBase >= 16) { @@ -338,13 +352,17 @@ void TranslationContext::emitSRDPrologue() { } else { sgprSrc = preloadBase; } - auto vregType = createVRegType(); - auto vreg = - PrecoloredVRegOp::create(builder, loc, vregType, pending.argIndex, 1); + int64_t sgprDst = scalarArgSgprBase + i; + + auto dstType = createSRegType(1, 1); + auto dstSreg = + PrecoloredSRegOp::create(builder, loc, dstType, sgprDst, 1); + RawOp::create(builder, loc, - "v_mov_b32 v" + std::to_string(pending.argIndex) + ", s" + + "s_mov_b32 s" + std::to_string(sgprDst) + ", s" + std::to_string(sgprSrc)); - mapper.mapValue(pending.blockArg, vreg); + + mapper.mapValue(pending.blockArg, dstSreg); } } else { // Non-GFX95* path (e.g., gfx942): Load directly into SRD positions @@ -419,18 +437,15 @@ void TranslationContext::emitSRDPrologue() { mapper.mapValue(pending.memref, srdReg); } - // Move scalar args from SGPRs to VGPRs + // Map scalar args to their dedicated SGPRs (already loaded above). 
for (size_t i = 0; i < pendingScalarArgs.size(); ++i) { const auto &pending = pendingScalarArgs[i]; int64_t sgprIdx = scalarSgprBase + (int64_t)i; - auto vregType = createVRegType(); - auto vreg = - PrecoloredVRegOp::create(builder, loc, vregType, pending.argIndex, 1); - RawOp::create(builder, loc, - "v_mov_b32 v" + std::to_string(pending.argIndex) + ", s" + - std::to_string(sgprIdx)); - mapper.mapValue(pending.blockArg, vreg); + auto dstType = createSRegType(1, 1); + auto dstSreg = + PrecoloredSRegOp::create(builder, loc, dstType, sgprIdx, 1); + mapper.mapValue(pending.blockArg, dstSreg); } } @@ -587,7 +602,7 @@ Value emitSRDBaseAdjustment(const TranslationContext::PendingSRDBaseAdjust &adj, auto *mlirCtx = builder.getContext(); int64_t N = ctx.getNextSwizzleSRDIndex(); - assert(N + 4 < 108 && "SRD allocation exceeds SGPR limit"); + assert(N + 4 <= 102 && "SRD allocation exceeds SGPR limit (s0-s101)"); // Copy source SRD base to new SRD. // Must use RawOp: S_MOV_B64 is Pure (SALUUnaryOp) and writes to a @@ -1004,12 +1019,20 @@ LogicalResult handleVectorTransferWrite(Operation *op, voffset = ConstantOp::create(builder, loc, immType, 0); } + Value storeData = *data; + if (isAGPRType(storeData.getType())) { + auto vregType = + ctx.createVRegType(numDwords, numDwords > 1 ? 
numDwords : 1); + storeData = + V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); + } + if (numDwords == 1) { - BUFFER_STORE_DWORD::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORD::create(builder, loc, storeData, srd, voffset); } else if (numDwords == 2) { - BUFFER_STORE_DWORDX2::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORDX2::create(builder, loc, storeData, srd, voffset); } else { - BUFFER_STORE_DWORDX4::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORDX4::create(builder, loc, storeData, srd, voffset); } } diff --git a/waveasm/lib/Transforms/VGPRCompaction.cpp b/waveasm/lib/Transforms/VGPRCompaction.cpp new file mode 100644 index 0000000000..43f77ca1b2 --- /dev/null +++ b/waveasm/lib/Transforms/VGPRCompaction.cpp @@ -0,0 +1,505 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// VGPR Compaction Pass +// +// Re-assigns physical VGPRs after register allocation to minimize the peak +// register number. The linear scan allocator assigns VGPRs in instruction +// order (lowest-first), which causes fragmentation when interleaved +// buffer_load (long-lived) and ds_read (short-lived) instructions get +// interleaved register numbers. This pass reassigns them using a +// shortest-first greedy strategy that packs short-lived values into low +// registers, pushing long-lived values to a contiguous high range. 
+//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Dialect/WaveASMTypes.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "waveasm-vgpr-compaction" + +using namespace mlir; +using namespace waveasm; + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMVGPRCOMPACTION +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +namespace { + +struct PhysVGPRRange { + int64_t physIdx; + int64_t size; + int64_t alignment; + int64_t defPoint; + int64_t lastUsePoint; + bool pinned = false; // Precolored — must keep original position + + int64_t length() const { return lastUsePoint - defPoint; } +}; + +static void collectOps(Block &block, + llvm::SmallVectorImpl &ops) { + for (Operation &op : block) { + ops.push_back(&op); + for (Region ®ion : op.getRegions()) + for (Block &nested : region) + collectOps(nested, ops); + } +} + +/// Record a PVRegType occurrence. Merges into existing entry if the base +/// index matches, taking max size/alignment and extending the time range. +static void recordPVReg(int64_t baseIdx, int64_t size, int64_t point, + bool isDef, + llvm::DenseMap &defPoints, + llvm::DenseMap &usePoints, + llvm::DenseMap &sizes, + llvm::DenseMap &alignments) { + int64_t align = (size >= 4) ? 4 : (size >= 2 ? 
2 : 1); + + auto it = sizes.find(baseIdx); + if (it != sizes.end()) { + it->second = std::max(it->second, size); + alignments[baseIdx] = std::max(alignments[baseIdx], align); + if (isDef) + defPoints[baseIdx] = std::min(defPoints[baseIdx], point); + else { + auto uit = usePoints.find(baseIdx); + if (uit != usePoints.end()) + uit->second = std::max(uit->second, point); + else + usePoints[baseIdx] = point; + } + } else { + sizes[baseIdx] = size; + alignments[baseIdx] = align; + if (isDef) + defPoints[baseIdx] = point; + else + usePoints[baseIdx] = point; + } +} + +/// For a PVRegType with index X and size S, find the "allocation base": +/// the base index of the multi-register range that contains X. +/// E.g., if the allocator assigned v[92:95] (base=92, size=4), then +/// v93 (index=93, size=1) has allocation base 92. +/// Returns (allocBase, allocSize) or (idx, size) if no containing range. +static std::pair +findAllocBase(int64_t idx, int64_t size, + const llvm::DenseMap &knownSizes) { + // Check if this index IS a known base + auto it = knownSizes.find(idx); + if (it != knownSizes.end() && it->second >= size) + return {idx, it->second}; + + // Check if this index falls within a known larger range. + // Multi-reg ranges are always aligned, so check aligned bases below idx. + for (int64_t align : {4, 2}) { + int64_t base = (idx / align) * align; + if (base == idx) + continue; + auto baseIt = knownSizes.find(base); + if (baseIt != knownSizes.end() && base + baseIt->second > idx) + return {base, baseIt->second}; + } + + return {idx, size}; +} + +static void buildPhysRanges( + ProgramOp program, + llvm::SmallVectorImpl &ranges, + llvm::DenseMap &rangeBaseToIdx) { + + llvm::SmallVector ops; + collectOps(program.getBodyBlock(), ops); + + llvm::DenseMap opToIdx; + for (int64_t i = 0; i < static_cast(ops.size()); ++i) + opToIdx[ops[i]] = i; + + llvm::DenseMap defPoints, usePoints, sizes, alignments; + + // Track precolored VGPR indices — these must not be remapped. 
+ llvm::DenseSet pinnedIndices; + + // First pass: collect all multi-register definitions to build the + // "known allocations" map. This lets us identify sub-element accesses. + // Also identify precolored VGPRs. + for (int64_t i = 0; i < static_cast(ops.size()); ++i) { + Operation *op = ops[i]; + + if (isa(op)) { + for (Value result : op->getResults()) { + if (auto pvreg = dyn_cast(result.getType())) { + pinnedIndices.insert(pvreg.getIndex()); + for (int64_t k = 0; k < pvreg.getSize(); ++k) + pinnedIndices.insert(pvreg.getIndex() + k); + } + } + } + + for (Value result : op->getResults()) { + if (auto pvreg = dyn_cast(result.getType())) { + if (pvreg.getSize() > 1) { + auto &sz = sizes[pvreg.getIndex()]; + sz = std::max(sz, pvreg.getSize()); + int64_t align = (pvreg.getSize() >= 4) ? 4 : 2; + auto &al = alignments[pvreg.getIndex()]; + al = std::max(al, align); + } + } + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (auto pvreg = dyn_cast(arg.getType())) { + if (pvreg.getSize() > 1) { + auto &sz = sizes[pvreg.getIndex()]; + sz = std::max(sz, pvreg.getSize()); + int64_t align = (pvreg.getSize() >= 4) ? 4 : 2; + auto &al = alignments[pvreg.getIndex()]; + al = std::max(al, align); + } + } + } + } + } + } + + // Second pass: collect def/use points, mapping sub-elements to their base. 
+ llvm::DenseMap knownSizes = sizes; + defPoints.clear(); + usePoints.clear(); + sizes.clear(); + alignments.clear(); + + auto processPVReg = [&](int64_t idx, int64_t size, int64_t point, + bool isDef) { + auto [base, allocSize] = findAllocBase(idx, size, knownSizes); + recordPVReg(base, allocSize, point, isDef, defPoints, usePoints, sizes, + alignments); + }; + + for (int64_t i = 0; i < static_cast(ops.size()); ++i) { + Operation *op = ops[i]; + + for (Value result : op->getResults()) { + if (auto pvreg = dyn_cast(result.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), i, true); + } + + for (Value operand : op->getOperands()) { + if (auto pvreg = dyn_cast(operand.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), i, false); + } + + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (auto pvreg = dyn_cast(arg.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), i, true); + } + } + } + + if (isa(op)) { + auto loopOp = cast(op); + Block &body = loopOp.getBodyBlock(); + Operation *terminator = body.getTerminator(); + int64_t termIdx = terminator ? opToIdx.lookup(terminator) : i; + for (BlockArgument arg : body.getArguments()) { + if (auto pvreg = dyn_cast(arg.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), termIdx, false); + } + } + } + + // Extend live ranges for values used inside loop bodies. + // Any value used inside a loop at some program point P is actually + // live from its definition until the end of the loop body (not just P), + // because the loop re-executes and the value must be available on + // every iteration. + for (int64_t i = 0; i < static_cast(ops.size()); ++i) { + Operation *op = ops[i]; + if (!isa(op)) + continue; + auto loopOp = cast(op); + Block &body = loopOp.getBodyBlock(); + Operation *terminator = body.getTerminator(); + int64_t loopStart = i; + int64_t loopEnd = terminator ? 
opToIdx.lookup(terminator) : i; + + // For each operand used inside the loop body, extend its range + // to cover the entire loop. + body.walk([&](Operation *innerOp) { + for (Value operand : innerOp->getOperands()) { + if (auto pvreg = dyn_cast(operand.getType())) { + auto [base, allocSize] = findAllocBase(pvreg.getIndex(), + pvreg.getSize(), knownSizes); + auto defIt = defPoints.find(base); + if (defIt != defPoints.end() && defIt->second < loopStart) { + // Value defined before the loop but used inside — extend to + // end of loop body. + auto &usePt = usePoints[base]; + usePt = std::max(usePt, loopEnd); + } + } + } + }); + } + + for (const auto &[base, defPt] : defPoints) { + auto useIt = usePoints.find(base); + int64_t usePt = (useIt != usePoints.end()) ? useIt->second : defPt; + PhysVGPRRange r; + r.physIdx = base; + r.size = sizes.lookup(base); + r.alignment = alignments.lookup(base); + if (r.size == 0) + r.size = 1; + if (r.alignment == 0) + r.alignment = 1; + r.defPoint = defPt; + r.lastUsePoint = usePt; + r.pinned = pinnedIndices.contains(base); + rangeBaseToIdx[base] = ranges.size(); + ranges.push_back(r); + } +} + +static bool overlaps(const PhysVGPRRange &a, const PhysVGPRRange &b) { + return a.defPoint <= b.lastUsePoint && b.defPoint <= a.lastUsePoint; +} + +static llvm::DenseMap +computeCompaction(llvm::SmallVectorImpl &ranges, + int64_t maxRegs) { + llvm::SmallVector order(ranges.size()); + std::iota(order.begin(), order.end(), 0); + llvm::sort(order, [&](int64_t a, int64_t b) { + if (ranges[a].length() != ranges[b].length()) + return ranges[a].length() < ranges[b].length(); + return ranges[a].defPoint < ranges[b].defPoint; + }); + + llvm::DenseMap oldToNew; + llvm::SmallVector newAssignment(ranges.size(), -1); + + // Pin precolored ranges to their original positions first. 
+ for (size_t i = 0; i < ranges.size(); ++i) { + if (ranges[i].pinned) { + newAssignment[i] = ranges[i].physIdx; + oldToNew[ranges[i].physIdx] = ranges[i].physIdx; + } + } + + for (int64_t orderIdx : order) { + if (ranges[orderIdx].pinned) + continue; + + PhysVGPRRange &r = ranges[orderIdx]; + int64_t sz = r.size; + int64_t align = r.alignment; + + // v15 is the scratch VGPR for literal materialization in the assembly + // emitter (AssemblyEmitter.h KernelGenerator::kScratchVGPR). Must stay + // excluded so compaction never places a live value there. + constexpr int64_t kScratchVGPR = 15; + llvm::BitVector occupied(maxRegs, false); + if (kScratchVGPR < maxRegs) + occupied.set(kScratchVGPR); + for (size_t j = 0; j < ranges.size(); ++j) { + if (newAssignment[j] < 0) + continue; + if (overlaps(r, ranges[j])) { + int64_t base = newAssignment[j]; + for (int64_t k = 0; k < ranges[j].size; ++k) + if (base + k < maxRegs) + occupied.set(base + k); + } + } + + int64_t chosen = -1; + for (int64_t c = 0; c + sz <= maxRegs; c += align) { + bool free = true; + for (int64_t k = 0; k < sz; ++k) { + if (occupied.test(c + k)) { + free = false; + break; + } + } + if (free) { + chosen = c; + break; + } + } + + if (chosen >= 0) { + newAssignment[orderIdx] = chosen; + oldToNew[r.physIdx] = chosen; + } else { + newAssignment[orderIdx] = r.physIdx; + oldToNew[r.physIdx] = r.physIdx; + } + } + + return oldToNew; +} + +/// Remap a PVRegType index using the base-to-base mapping. +/// For sub-elements (e.g., v93 within v[92:95]), computes the offset +/// from the old base and applies it to the new base. 
+static int64_t remapIndex(int64_t oldIdx, int64_t size, + const llvm::DenseMap &oldToNew, + const llvm::SmallVectorImpl &ranges) { + // Direct lookup: this index IS a base + auto it = oldToNew.find(oldIdx); + if (it != oldToNew.end()) + return it->second; + + // Sub-element: find the containing base range + for (const auto &r : ranges) { + if (oldIdx >= r.physIdx && oldIdx < r.physIdx + r.size) { + auto baseIt = oldToNew.find(r.physIdx); + if (baseIt != oldToNew.end()) + return baseIt->second + (oldIdx - r.physIdx); + } + } + + return oldIdx; +} + +static void applyRemapping(ProgramOp program, + const llvm::DenseMap &oldToNew, + const llvm::SmallVectorImpl &ranges) { + auto remap = [&](Type ty) -> Type { + auto pvreg = dyn_cast(ty); + if (!pvreg) + return ty; + int64_t newIdx = + remapIndex(pvreg.getIndex(), pvreg.getSize(), oldToNew, ranges); + if (newIdx == pvreg.getIndex()) + return ty; + return PVRegType::get(ty.getContext(), newIdx, pvreg.getSize()); + }; + + program.walk([&](Operation *op) { + if (isa(op)) + return; + + for (Value result : op->getResults()) { + Type newTy = remap(result.getType()); + if (newTy != result.getType()) + result.setType(newTy); + } + + if (auto condOp = dyn_cast(op)) { + if (auto attr = condOp->getAttrOfType( + "_iterArgPhysRegs")) { + auto vals = attr.asArrayRef(); + llvm::SmallVector newVals(vals.begin(), vals.end()); + bool anyChanged = false; + for (size_t i = 0; i < newVals.size(); ++i) { + if (newVals[i] < 0) + continue; + if (i < condOp.getIterArgs().size()) { + Type ty = condOp.getIterArgs()[i].getType(); + if (isa(ty)) { + int64_t newIdx = + remapIndex(newVals[i], 1, oldToNew, ranges); + if (newIdx != newVals[i]) { + newVals[i] = newIdx; + anyChanged = true; + } + } + } + } + if (anyChanged) { + condOp->setAttr("_iterArgPhysRegs", + DenseI64ArrayAttr::get(op->getContext(), newVals)); + } + } + } + }); + + program.walk([&](LoopOp loopOp) { + Block &body = loopOp.getBodyBlock(); + for (BlockArgument arg : 
body.getArguments()) { + Type newTy = remap(arg.getType()); + if (newTy != arg.getType()) + arg.setType(newTy); + } + for (Value result : loopOp->getResults()) { + Type newTy = remap(result.getType()); + if (newTy != result.getType()) + result.setType(newTy); + } + }); +} + +struct WAVEASMVGPRCompaction + : waveasm::impl::WAVEASMVGPRCompactionBase { + + void runOnOperation() override { + auto moduleOp = getOperation(); + + moduleOp->walk([&](ProgramOp program) { + llvm::SmallVector ranges; + llvm::DenseMap rangeBaseToIdx; + + buildPhysRanges(program, ranges, rangeBaseToIdx); + + if (ranges.empty()) + return; + + int64_t maxBefore = 0; + for (const auto &r : ranges) + maxBefore = std::max(maxBefore, r.physIdx + r.size); + + auto oldToNew = computeCompaction(ranges, /*maxRegs=*/512); + + int64_t maxAfter = 0; + for (const auto &r : ranges) { + auto it = oldToNew.find(r.physIdx); + int64_t newIdx = (it != oldToNew.end()) ? it->second : r.physIdx; + maxAfter = std::max(maxAfter, newIdx + r.size); + } + + bool anyChange = false; + for (const auto &[old, newIdx] : oldToNew) { + if (old != newIdx) { + anyChange = true; + break; + } + } + + if (!anyChange) + return; + + LLVM_DEBUG(llvm::dbgs() << "VGPR compaction: " << maxBefore << " -> " + << maxAfter << " (saved " + << (maxBefore - maxAfter) << ")\n"); + + applyRemapping(program, oldToNew, ranges); + }); + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp index 336db2c46a..5ccc8da9eb 100644 --- a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp @@ -351,8 +351,9 @@ LogicalResult handleAMDGPUScaledMfma(Operation *op, TranslationContext &ctx) { /// Emit the SRD NUM_RECORDS field (word 2) from the validBytes operand. /// Static constants emit s_mov_b32; dynamic values (e.g. 
from arith.select -/// for the branchless g2s guard) go through v_readfirstlane_b32 + s_add_u32 -/// (non-Pure to avoid DCE). Falls back to hardware maximum if absent. +/// for the branchless g2s guard) use s_cmp + s_cselect when all operands +/// are scalar, or v_readfirstlane_b32 for VGPR values. +/// Falls back to hardware maximum if absent. static void emitSrdNumRecords(OpBuilder &builder, Location loc, int64_t srdBase, Operation *op, TranslationContext &ctx) { auto castOp = cast(op); @@ -367,6 +368,41 @@ static void emitSrdNumRecords(OpBuilder &builder, Location loc, int64_t srdBase, return; } + // Detect arith.select(arith.cmpi(scalar, scalar), scalar, scalar) + // and emit s_cmp + s_cselect directly into SRD word 2. + if (auto selectOp = validBytesVal.getDefiningOp()) { + Value condVal = selectOp.getCondition(); + if (auto cmpOp = condVal.getDefiningOp()) { + Value cmpLhs = cmpOp.getLhs(); + Value cmpRhs = cmpOp.getRhs(); + Value trueVal = selectOp.getTrueValue(); + Value falseVal = selectOp.getFalseValue(); + + auto cmpLhsMapped = ctx.getMapper().getMapped(cmpLhs); + auto cmpRhsMapped = ctx.getMapper().getMapped(cmpRhs); + auto trueMapped = ctx.getMapper().getMapped(trueVal); + auto falseMapped = ctx.getMapper().getMapped(falseVal); + + if (cmpLhsMapped && cmpRhsMapped && trueMapped && falseMapped && + isScalarOrImm(*cmpLhsMapped) && isScalarOrImm(*cmpRhsMapped) && + isScalarOrImm(*trueMapped) && isScalarOrImm(*falseMapped)) { + emitScalarCmp(builder, loc, cmpOp.getPredicate(), + *cmpLhsMapped, *cmpRhsMapped, ctx); + + auto dstType = PSRegType::get(builder.getContext(), srdBase + 2, 1); + Value trueV = *trueMapped; + Value falseV = *falseMapped; + auto sregType = ctx.createSRegType(); + if (isImmType(trueV.getType())) + trueV = S_MOV_B32::create(builder, loc, sregType, trueV); + auto result = + S_CSELECT_B32::create(builder, loc, dstType, trueV, falseV); + DCEProtectOp::create(builder, loc, result); + return; + } + } + } + auto mapped = 
ctx.getMapper().getMapped(validBytesVal); if (mapped) { Value src = *mapped; diff --git a/waveasm/lib/Transforms/handlers/AffineHandlers.cpp b/waveasm/lib/Transforms/handlers/AffineHandlers.cpp index 83caf692fe..1a56cc0aea 100644 --- a/waveasm/lib/Transforms/handlers/AffineHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AffineHandlers.cpp @@ -259,9 +259,14 @@ static AffineExpr normalizeExpr(AffineExpr e, MLIRContext *ctx) { // 1. Float rcp seed → Newton-Raphson refinement → exact reciprocal // 2. q = mulhi(x, recip) // 3. Two remainder-based correction steps +// When both inputs are scalar, extracts the result back to SGPR via +// v_readfirstlane_b32 so downstream arithmetic stays in SALU. Value emitUnsignedFloordiv(Value x, Value d, OpBuilder &builder, Location loc, TranslationContext &ctx) { + bool bothScalar = isScalarOrImm(x) && isScalarOrImm(d); auto vregType = ctx.createVRegType(); + x = ensureVGPR(builder, x.getLoc(), ctx, x); + d = ensureVGPR(builder, d.getLoc(), ctx, d); Value zeroConst = createImmConst(0, builder, loc, ctx); // --- Step 1: float reciprocal seed --- @@ -305,6 +310,11 @@ Value emitUnsignedFloordiv(Value x, Value d, OpBuilder &builder, Location loc, zeroConst); q = V_ADD_U32::create(builder, loc, vregType, q, inc2); + if (bothScalar) { + auto sregType = ctx.createSRegType(); + q = V_READFIRSTLANE_B32::create(builder, loc, sregType, q); + } + return q; } @@ -330,7 +340,6 @@ Value emitConstantUnsignedFloordiv(Value x, int64_t divisor, OpBuilder &builder, assert(static_cast(divisor) <= 0xFFFFFFFFULL && "divisor exceeds 32 bits -- should have been normalized"); - auto vregType = ctx.createVRegType(); llvm::APInt divisorAPInt(32, static_cast(divisor)); auto mag = llvm::UnsignedDivisionByConstantInfo::get( divisorAPInt, /*LeadingZeros=*/0, @@ -345,23 +354,21 @@ Value emitConstantUnsignedFloordiv(Value x, int64_t divisor, OpBuilder &builder, int64_t magicVal = static_cast(mag.Magic.getZExtValue()); Value magicConst = createImmConst(magicVal, 
builder, loc, ctx); - Value q = V_MUL_HI_U32::create(builder, loc, vregType, x, magicConst); + Value q = emitMulHi(x, magicConst, builder, loc, ctx); auto emitShiftRight = [&](Value val, unsigned amount) -> Value { if (amount == 0) return val; Value shiftConst = createImmConst(static_cast(amount), builder, loc, ctx); - return V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, val); + return emitLshr(val, shiftConst, builder, loc, ctx); }; if (mag.IsAdd) { - // add form: (mulhi(x,m) + ((x - mulhi(x,m)) >> 1)) >> PostShift - Value xSubQ = V_SUB_U32::create(builder, loc, vregType, x, q); + Value xSubQ = emitSub(x, q, builder, loc, ctx); Value oneConst = createImmConst(1, builder, loc, ctx); - Value halfDiff = - V_LSHRREV_B32::create(builder, loc, vregType, oneConst, xSubQ); - Value sum = V_ADD_U32::create(builder, loc, vregType, q, halfDiff); + Value halfDiff = emitLshr(xSubQ, oneConst, builder, loc, ctx); + Value sum = emitAdd(q, halfDiff, builder, loc, ctx); return emitShiftRight(sum, mag.PostShift); } @@ -382,7 +389,23 @@ Value emitConstantUnsignedFloordiv(Value x, int64_t divisor, OpBuilder &builder, static Value emitCeilFromFloorQuotient(Value q, Value x, Value d, OpBuilder &builder, Location loc, TranslationContext &ctx) { + bool allScalar = isScalarOrImm(q) && isScalarOrImm(x) && isScalarOrImm(d); + if (allScalar) { + // Fully scalar path: rem = x - q*d; SCC = (rem != 0); q += SCC + Value qd = emitMul(q, d, builder, loc, ctx); + Value rem = emitSub(x, qd, builder, loc, ctx); + Value zeroConst = createImmConst(0, builder, loc, ctx); + // s_cmp_lg_u32 sets SCC = (rem != 0) + S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), rem, zeroConst); + // s_addc_u32: dst = q + 0 + SCC (carry-in from SCC) + auto sregType = ctx.createSRegType(); + return S_ADDC_U32::create(builder, loc, sregType, sregType, q, zeroConst) + .getDst(); + } auto vregType = ctx.createVRegType(); + q = ensureVGPR(builder, loc, ctx, q); + x = ensureVGPR(builder, loc, ctx, x); + d = 
ensureVGPR(builder, loc, ctx, d); Value qd = V_MUL_LO_U32::create(builder, loc, vregType, q, d); Value rem = V_SUB_U32::create(builder, loc, vregType, x, qd); Value zeroConst = createImmConst(0, builder, loc, ctx); @@ -553,6 +576,12 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { BitRange lhsRange = lhsResult.range; BitRange rhsRange = rhsResult.range; + // Helper: coerce both operands to VGPR for VALU binary ops. + auto ensureBothVGPR = [&]() { + lhs = ensureVGPR(builder, loc, ctx, lhs); + rhs = ensureVGPR(builder, loc, ctx, rhs); + }; + switch (binExpr.getKind()) { case AffineExprKind::Add: { if (!lhsRange.overlaps(rhsRange)) { @@ -577,8 +606,11 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftImm, shiftAmount); // v_lshl_or_b32: dst = (src << shift) | orend + Value base = ensureVGPR(builder, loc, ctx, + baseResult.value); + orend = ensureVGPR(builder, loc, ctx, orend); Value fusedResult = V_LSHL_OR_B32::create( - builder, loc, vregType, baseResult.value, shiftConst, + builder, loc, vregType, base, shiftConst, orend); BitRange shiftedRange = baseResult.range.shiftLeft(shiftAmount); @@ -597,8 +629,11 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftImm = ctx.createImmType(shiftAmount); auto shiftConst = ConstantOp::create(builder, loc, shiftImm, shiftAmount); + Value base2 = ensureVGPR(builder, loc, ctx, + baseResult.value); + orend = ensureVGPR(builder, loc, ctx, orend); Value fusedResult = V_LSHL_OR_B32::create( - builder, loc, vregType, baseResult.value, shiftConst, + builder, loc, vregType, base2, shiftConst, orend); BitRange shiftedRange = baseResult.range.shiftLeft(shiftAmount); @@ -621,14 +656,14 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { return *result; } - // No fusion possible, emit regular v_or_b32 - Value orResult = V_OR_B32::create(builder, loc, vregType, lhs, rhs); + 
// No fusion possible — emitOr picks S_OR_B32 when both scalar. + Value orResult = emitOr(lhs, rhs, builder, loc, ctx); BitRange resultRange = lhsRange.merge(rhsRange); ctx.setBitRange(orResult, resultRange); return ExprResult(orResult, resultRange); } // Overlapping ranges - must use ADD - Value addResult = V_ADD_U32::create(builder, loc, vregType, lhs, rhs); + Value addResult = emitAdd(lhs, rhs, builder, loc, ctx); BitRange resultRange = lhsRange.extendForAdd(rhsRange); ctx.setBitRange(addResult, resultRange); return ExprResult(addResult, resultRange); @@ -658,10 +693,8 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { int64_t mask = ~(N - 1) & 0xFFFFFFFF; auto maskImm = ctx.createImmType(mask); auto maskConst = ConstantOp::create(builder, loc, maskImm, mask); - // NOTE: constant must be src0 (first operand) for VOP2 encoding. - // src1 must be a VGPR on AMDGCN. - Value andResult = V_AND_B32::create(builder, loc, vregType, maskConst, - innerResult.value); + Value andResult = + emitAnd(maskConst, innerResult.value, builder, loc, ctx); // Result has same bit range as inner, but low bits cleared BitRange resultRange = innerResult.range; // Clear bits below log2(N) -- conservative: use inner range @@ -696,14 +729,12 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); Value shiftResult = - V_LSHLREV_B32::create(builder, loc, vregType, shiftConst, lhs); - // Shift the bit range left by shiftAmount + emitLshl(lhs, shiftConst, builder, loc, ctx); BitRange resultRange = lhsRange.shiftLeft(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); } } - // Also check LHS for power of 2 multiply if (auto constLhs = dyn_cast(binExpr.getLHS())) { int64_t val = constLhs.getValue(); if (isPowerOf2(val)) { @@ -712,14 +743,13 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto 
shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); Value shiftResult = - V_LSHLREV_B32::create(builder, loc, vregType, shiftConst, rhs); + emitLshl(rhs, shiftConst, builder, loc, ctx); BitRange resultRange = rhsRange.shiftLeft(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); } } - Value mulResult = - V_MUL_LO_U32::create(builder, loc, vregType, lhs, rhs); + Value mulResult = emitMul(lhs, rhs, builder, loc, ctx); return ExprResult(mulResult, BitRange()); // Conservative: full range } @@ -747,7 +777,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); Value shiftResult = - V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, lhs); + emitLshr(lhs, shiftConst, builder, loc, ctx); BitRange resultRange = lhsRange.shiftRight(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); @@ -776,13 +806,21 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { if (isPowerOf2(divisor)) { int64_t shiftAmount = log2(divisor); Value shiftConst = createImmConst(shiftAmount, builder, loc, ctx); - Value q = - V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, lhs); + Value q = emitLshr(lhs, shiftConst, builder, loc, ctx); int64_t mask = divisor - 1; Value maskConst = createImmConst(mask, builder, loc, ctx); - Value rem = - V_AND_B32::create(builder, loc, vregType, maskConst, lhs); + Value rem = emitAnd(lhs, maskConst, builder, loc, ctx); Value zeroConst = createImmConst(0, builder, loc, ctx); + if (isScalarOrImm(rem)) { + S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), rem, + zeroConst); + auto sregType = ctx.createSRegType(); + Value result = + S_ADDC_U32::create(builder, loc, sregType, sregType, q, + zeroConst) + .getDst(); + return ExprResult(result, BitRange()); + } V_CMP_NE_U32::create(builder, loc, rem, zeroConst); Value 
oneConst = createImmConst(1, builder, loc, ctx); Value oneVgpr = V_MOV_B32::create(builder, loc, vregType, oneConst); @@ -816,8 +854,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { int64_t val = constRhs.getValue(); if (isPowerOf2(val)) { Value maskConst = createImmConst(val - 1, builder, loc, ctx); - Value andResult = - V_AND_B32::create(builder, loc, vregType, lhs, maskConst); + Value andResult = emitAnd(lhs, maskConst, builder, loc, ctx); BitRange resultRange = BitRange(0, log2(val) - 1); ctx.setBitRange(andResult, resultRange); return ExprResult(andResult, resultRange); @@ -827,16 +864,16 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { if (val >= 2) { Value q = emitConstantUnsignedFloordiv(lhs, val, builder, loc, ctx); Value dConst = createImmConst(val, builder, loc, ctx); - Value qd = V_MUL_LO_U32::create(builder, loc, vregType, q, dConst); - Value rem = V_SUB_U32::create(builder, loc, vregType, lhs, qd); + Value qd = emitMul(q, dConst, builder, loc, ctx); + Value rem = emitSub(lhs, qd, builder, loc, ctx); return ExprResult(rem, BitRange()); } } // Symbolic divisor fallback: x mod d = x - floordiv(x, d) * d { Value q = emitUnsignedFloordiv(lhs, rhs, builder, loc, ctx); - Value qd = V_MUL_LO_U32::create(builder, loc, vregType, q, rhs); - Value rem = V_SUB_U32::create(builder, loc, vregType, lhs, qd); + Value qd = emitMul(q, rhs, builder, loc, ctx); + Value rem = emitSub(lhs, qd, builder, loc, ctx); return ExprResult(rem, BitRange()); } } diff --git a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp index fced3fda29..27a016d1ae 100644 --- a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp @@ -108,43 +108,135 @@ LogicalResult handleArithConstant(Operation *op, TranslationContext &ctx) { //===----------------------------------------------------------------------===// LogicalResult 
handleArithAddI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitAdd(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithSubI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitSub(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMulI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitMul(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMinSI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMinI32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMaxSI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if 
(failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMaxI32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMinUI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMinU32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMaxUI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMaxU32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithAndI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitAnd(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithOrI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitOr(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } 
LogicalResult handleArithXorI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitXor(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } //===----------------------------------------------------------------------===// @@ -152,11 +244,31 @@ LogicalResult handleArithXorI(Operation *op, TranslationContext &ctx) { //===----------------------------------------------------------------------===// LogicalResult handleArithShLI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALURev(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitLshl(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithShRUI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALURev(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitLshr(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithShRSI(Operation *op, TranslationContext &ctx) { @@ -197,11 +309,9 @@ LogicalResult handleArithDivUI(Operation *op, TranslationContext &ctx) { if (auto constOp = rhs->getDefiningOp()) { int64_t divisor = constOp.getValue(); if (isPowerOf2(divisor)) { - auto vregType = ctx.createVRegType(); int64_t shiftAmt = log2(divisor); Value shiftConst = createImmConst(shiftAmt, 
builder, loc, ctx); - auto result = - V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, *lhs); + auto result = emitLshr(*lhs, shiftConst, builder, loc, ctx); ctx.getMapper().mapValue(divOp.getResult(), result); return success(); } @@ -235,7 +345,7 @@ LogicalResult handleArithRemUI(Operation *op, TranslationContext &ctx) { int64_t modulus = constOp.getValue(); if (isPowerOf2(modulus)) { Value maskConst = createImmConst(modulus - 1, builder, loc, ctx); - auto result = V_AND_B32::create(builder, loc, vregType, *lhs, maskConst); + auto result = emitAnd(*lhs, maskConst, builder, loc, ctx); ctx.getMapper().mapValue(remOp.getResult(), result); return success(); } @@ -341,90 +451,95 @@ LogicalResult handleArithCmpI(Operation *op, TranslationContext &ctx) { return failure(); } - // When both operands are scalar (SGPR or immediate), use S_CMP which - // produces an SGPR result directly. This is required for scf.if/scf.for - // conditions that feed waveasm.if/waveasm.condition (which require SGPRs). + // When both operands are scalar AND the result feeds a ConditionOp (loop + // back-edge), use S_CMP which writes SCC directly. ConditionOp reads SCC + // via s_cbranch_scc1. + // For other scalar consumers (arith.select), the fusion happens in + // handleArithSelect which emits s_cmp + s_cselect as a pair. bool lhsScalar = isSGPRType(lhs->getType()) || isImmType(lhs->getType()); bool rhsScalar = isSGPRType(rhs->getType()) || isImmType(rhs->getType()); + bool usedByCondition = cmpOp.getResult().hasOneUse() && + isa(*cmpOp.getResult().getUsers().begin()); - if (lhsScalar && rhsScalar) { + if (lhsScalar && rhsScalar && usedByCondition) { auto sregType = ctx.createSRegType(); - // S_CMP requires SGPR operands; move immediates to SGPRs first. 
Value lhsOp = *lhs; Value rhsOp = *rhs; if (isImmType(lhsOp.getType())) lhsOp = S_MOV_B32::create(builder, loc, sregType, lhsOp); if (isImmType(rhsOp.getType())) rhsOp = S_MOV_B32::create(builder, loc, sregType, rhsOp); - Value result; - switch (cmpOp.getPredicate()) { - case arith::CmpIPredicate::eq: - result = S_CMP_EQ_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ne: - result = S_CMP_NE_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::slt: - result = S_CMP_LT_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::sle: - result = S_CMP_LE_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::sgt: - result = S_CMP_GT_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::sge: - result = S_CMP_GE_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ult: - result = S_CMP_LT_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ule: - result = S_CMP_LE_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ugt: - result = S_CMP_GT_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::uge: - result = S_CMP_GE_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - } + Value result = + emitScalarCmp(builder, loc, cmpOp.getPredicate(), lhsOp, rhsOp, ctx); ctx.getMapper().mapValue(cmpOp.getResult(), result); return success(); } - // Vector path: at least one operand is a VGPR, so use V_CMP which writes - // to VCC implicitly. + // When both operands are scalar and ALL users are arith.select ops with + // scalar true/false values, the fusion in handleArithSelect will emit + // s_cmp + s_cselect directly. Skip the VALU path entirely — map the + // result to a dummy so the mapper doesn't complain. 
+ if (lhsScalar && rhsScalar) { + bool allUsersFused = true; + for (auto &use : cmpOp.getResult().getUses()) { + auto *user = use.getOwner(); + if (auto selectUser = dyn_cast(user)) { + auto trueMap = ctx.getMapper().getMapped(selectUser.getTrueValue()); + auto falseMap = ctx.getMapper().getMapped(selectUser.getFalseValue()); + if (!trueMap || !falseMap || + !isScalarOrImm(*trueMap) || !isScalarOrImm(*falseMap)) { + allUsersFused = false; + break; + } + } else { + allUsersFused = false; + break; + } + } + if (allUsersFused) { + Value dummy = createImmConst(0, builder, loc, ctx); + ctx.getMapper().mapValue(cmpOp.getResult(), dummy); + return success(); + } + } + + // Vector path: use V_CMP which writes to VCC implicitly. + // V_CMP VOP3 can read one SGPR from the constant bus — only coerce + // to VGPR when BOTH operands are SGPR (two constant bus slots). + Value lhsV = *lhs; + Value rhsV = *rhs; + if (isSGPRType(lhsV.getType()) && isSGPRType(rhsV.getType())) + lhsV = ensureVGPR(builder, loc, ctx, lhsV); switch (cmpOp.getPredicate()) { case arith::CmpIPredicate::eq: - V_CMP_EQ_U32::create(builder, loc, *lhs, *rhs); + V_CMP_EQ_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::ne: - V_CMP_NE_U32::create(builder, loc, *lhs, *rhs); + V_CMP_NE_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::slt: - V_CMP_LT_I32::create(builder, loc, *lhs, *rhs); + V_CMP_LT_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::sle: - V_CMP_LE_I32::create(builder, loc, *lhs, *rhs); + V_CMP_LE_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::sgt: - V_CMP_GT_I32::create(builder, loc, *lhs, *rhs); + V_CMP_GT_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::sge: - V_CMP_GE_I32::create(builder, loc, *lhs, *rhs); + V_CMP_GE_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::ult: - V_CMP_LT_U32::create(builder, loc, *lhs, *rhs); + V_CMP_LT_U32::create(builder, 
loc, lhsV, rhsV); break; case arith::CmpIPredicate::ule: - V_CMP_LE_U32::create(builder, loc, *lhs, *rhs); + V_CMP_LE_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::ugt: - V_CMP_GT_U32::create(builder, loc, *lhs, *rhs); + V_CMP_GT_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::uge: - V_CMP_GE_U32::create(builder, loc, *lhs, *rhs); + V_CMP_GE_U32::create(builder, loc, lhsV, rhsV); break; } @@ -447,12 +562,68 @@ LogicalResult handleArithSelect(Operation *op, TranslationContext &ctx) { return op->emitError("operands not mapped"); } - // Restore the materialized boolean VGPR (0/1) back into VCC + // CmpI fusion: when condition is arith.cmpi with scalar operands and + // both select values are scalar, emit s_cmp + s_cselect directly. + // This bypasses the VGPR boolean from handleArithCmpI entirely. + if (isScalarOrImm(*trueVal) && isScalarOrImm(*falseVal)) { + Value condMLIR = selectOp.getCondition(); + if (auto cmpOp = condMLIR.getDefiningOp()) { + auto cmpLhs = ctx.getMapper().getMapped(cmpOp.getLhs()); + auto cmpRhs = ctx.getMapper().getMapped(cmpOp.getRhs()); + if (cmpLhs && cmpRhs && + isScalarOrImm(*cmpLhs) && isScalarOrImm(*cmpRhs)) { + auto sregType = ctx.createSRegType(); + Value lhsOp = *cmpLhs; + Value rhsOp = *cmpRhs; + if (isImmType(lhsOp.getType())) + lhsOp = S_MOV_B32::create(builder, loc, sregType, lhsOp); + if (isImmType(rhsOp.getType())) + rhsOp = S_MOV_B32::create(builder, loc, sregType, rhsOp); + emitScalarCmp(builder, loc, cmpOp.getPredicate(), lhsOp, rhsOp, ctx); + Value trueV = *trueVal; + Value falseV = *falseVal; + if (isImmType(trueV.getType())) + trueV = S_MOV_B32::create(builder, loc, sregType, trueV); + auto result = + S_CSELECT_B32::create(builder, loc, sregType, trueV, falseV); + ctx.getMapper().mapValue(selectOp.getResult(), result); + return success(); + } + } + } + + // Scalar path: when condition and both values are scalar, use + // s_cmp_lg_u32 + s_cselect_b32 (no VGPR needed). 
+ if (isScalarOrImm(*cond) && isScalarOrImm(*trueVal) && + isScalarOrImm(*falseVal)) { + Value zeroConst = createImmConst(0, builder, loc, ctx); + Value condV = *cond; + if (isImmType(condV.getType())) + condV = S_MOV_B32::create(builder, loc, ctx.createSRegType(), condV); + S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), condV, zeroConst); + auto sregType = ctx.createSRegType(); + Value trueV = *trueVal; + Value falseV = *falseVal; + if (isImmType(trueV.getType())) + trueV = S_MOV_B32::create(builder, loc, sregType, trueV); + auto result = S_CSELECT_B32::create(builder, loc, sregType, trueV, falseV); + ctx.getMapper().mapValue(selectOp.getResult(), result); + return success(); + } + + // Vector path: restore the materialized boolean VGPR (0/1) back into VCC. + // V_CMP can take one SGPR via the constant bus; zeroConst is an immediate + // (no bus slot), so an SGPR cond is fine without coercion. Value zeroConst = createImmConst(0, builder, loc, ctx); - V_CMP_NE_U32::create(builder, loc, *cond, zeroConst); + Value condV = *cond; + V_CMP_NE_U32::create(builder, loc, condV, zeroConst); + // v_cndmask_b32: coerce both to VGPR. VOP3 constant bus allows one SGPR + // but interacts with vcc slot; keeping both as VGPR is safest. 
+ Value trueVgpr = ensureVGPR(builder, loc, ctx, *trueVal); + Value falseVgpr = ensureVGPR(builder, loc, ctx, *falseVal); auto result = - V_CNDMASK_B32::create(builder, loc, vregType, *falseVal, *trueVal, *cond); + V_CNDMASK_B32::create(builder, loc, vregType, falseVgpr, trueVgpr, *cond); ctx.getMapper().mapValue(selectOp.getResult(), result); return success(); } diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index c3d2265f9c..bd6b8862fa 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -26,6 +26,7 @@ #include "waveasm/Transforms/TranslateFromMLIR.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LogicalResult.h" @@ -278,6 +279,263 @@ template inline To bitCast(From value) { return result; } +//===----------------------------------------------------------------------===// +// Scalar / VGPR Helpers +//===----------------------------------------------------------------------===// + +/// Check whether a value is scalar (SGPR) or an inline constant. +inline bool isScalarOrImm(mlir::Value v) { + return isSGPRType(v.getType()) || isImmType(v.getType()); +} + +/// If \p v is an SGPR, emit a v_mov_b32 to coerce it into a VGPR so it can +/// be used by VALU-only instructions (v_cvt_*, v_rcp_*, v_mul_f32, etc.). +/// Returns \p v unchanged when it is already a VGPR or immediate. 
+inline mlir::Value ensureVGPR(mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx, mlir::Value v) { + if (isSGPRType(v.getType())) { + auto vregType = ctx.createVRegType(); + return V_MOV_B32::create(builder, loc, vregType, v); + } + return v; +} + +//===----------------------------------------------------------------------===// +// Auto-select SALU/VALU Emit Helpers +//===----------------------------------------------------------------------===// + +/// Emit add: S_ADD_U32 when both operands are scalar, V_ADD_U32 otherwise. +/// Commutative: swaps to put immediate in src1 (SALU src0 must be SGPR). +inline mlir::Value emitAdd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_ADD_U32::create(builder, loc, sregType, sregType, a, b).getDst(); + } + auto vregType = ctx.createVRegType(); + return V_ADD_U32::create(builder, loc, vregType, a, b); +} + +/// Emit sub: S_SUB_U32 when both operands are scalar, V_SUB_U32 otherwise. +/// Not commutative: src0 (minuend) must be SGPR. +inline mlir::Value emitSub(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && isSGPRType(a.getType())) { + auto sregType = ctx.createSRegType(); + return S_SUB_U32::create(builder, loc, sregType, sregType, a, b).getDst(); + } + auto vregType = ctx.createVRegType(); + return V_SUB_U32::create(builder, loc, vregType, a, b); +} + +/// Emit mul: S_MUL_I32 when both operands are scalar, V_MUL_LO_U32 otherwise. +/// Commutative: swaps to put immediate in src1. 
+inline mlir::Value emitMul(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MUL_I32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MUL_LO_U32::create(builder, loc, vregType, a, b); +} + +/// Emit logical shift right: S_LSHR_B32 when scalar, V_LSHRREV_B32 otherwise. +/// Not commutative: src0 (value) must be SGPR. +inline mlir::Value emitLshr(mlir::Value value, mlir::Value shiftAmt, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(value) && isScalarOrImm(shiftAmt) && + isSGPRType(value.getType())) { + auto sregType = ctx.createSRegType(); + return S_LSHR_B32::create(builder, loc, sregType, value, shiftAmt); + } + auto vregType = ctx.createVRegType(); + value = ensureVGPR(builder, loc, ctx, value); + return V_LSHRREV_B32::create(builder, loc, vregType, shiftAmt, value); +} + +/// Emit logical shift left: S_LSHL_B32 when scalar, V_LSHLREV_B32 otherwise. +/// Not commutative: src0 (value) must be SGPR. +inline mlir::Value emitLshl(mlir::Value value, mlir::Value shiftAmt, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(value) && isScalarOrImm(shiftAmt) && + isSGPRType(value.getType())) { + auto sregType = ctx.createSRegType(); + return S_LSHL_B32::create(builder, loc, sregType, value, shiftAmt); + } + auto vregType = ctx.createVRegType(); + value = ensureVGPR(builder, loc, ctx, value); + return V_LSHLREV_B32::create(builder, loc, vregType, shiftAmt, value); +} + +/// Emit bitwise AND: S_AND_B32 when both scalar, V_AND_B32 otherwise. +/// Commutative: swaps to put immediate in src1. 
+inline mlir::Value emitAnd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_AND_B32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_AND_B32::create(builder, loc, vregType, a, b); +} + +/// Emit bitwise OR: S_OR_B32 when both scalar, V_OR_B32 otherwise. +/// Commutative: swaps to put immediate in src1. +inline mlir::Value emitOr(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_OR_B32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_OR_B32::create(builder, loc, vregType, a, b); +} + +/// Emit bitwise XOR: S_XOR_B32 when both scalar, V_XOR_B32 otherwise. +/// Commutative: swaps to put immediate in src1. +inline mlir::Value emitXor(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_XOR_B32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_XOR_B32::create(builder, loc, vregType, a, b); +} + +/// Emit mulhi: S_MUL_HI_U32 when both scalar, V_MUL_HI_U32 otherwise. +/// Commutative: swaps to put immediate in src1. 
+inline mlir::Value emitMulHi(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MUL_HI_U32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MUL_HI_U32::create(builder, loc, vregType, a, b); +} + +/// Emit signed max: S_MAX_I32 when both scalar, V_MAX_I32 otherwise. +/// Not commutative for SCC semantics (SCC = src0 was selected) but +/// the result value is commutative. +inline mlir::Value emitMaxI32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MAX_I32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MAX_I32::create(builder, loc, vregType, a, b); +} + +/// Emit signed min: S_MIN_I32 when both scalar, V_MIN_I32 otherwise. +inline mlir::Value emitMinI32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MIN_I32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MIN_I32::create(builder, loc, vregType, a, b); +} + +/// Emit unsigned max: S_MAX_U32 when both scalar, V_MAX_U32 otherwise. 
+inline mlir::Value emitMaxU32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MAX_U32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MAX_U32::create(builder, loc, vregType, a, b); +} + +/// Emit unsigned min: S_MIN_U32 when both scalar, V_MIN_U32 otherwise. +inline mlir::Value emitMinU32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MIN_U32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MIN_U32::create(builder, loc, vregType, a, b); +} + +//===----------------------------------------------------------------------===// +// Scalar Comparison Helper +//===----------------------------------------------------------------------===// + +/// Emit S_CMP_* for the given predicate with SGPR operands (sets SCC). +/// Both lhs and rhs must be SGPRs (not immediates) before calling. +/// Returns the S_CMP result value (phantom SCC). 
+inline mlir::Value emitScalarCmp(mlir::OpBuilder &builder, mlir::Location loc, + mlir::arith::CmpIPredicate pred, + mlir::Value lhs, mlir::Value rhs, + TranslationContext &ctx) { + auto sregType = ctx.createSRegType(); + switch (pred) { + case mlir::arith::CmpIPredicate::eq: + return S_CMP_EQ_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ne: + return S_CMP_NE_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::slt: + return S_CMP_LT_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::sle: + return S_CMP_LE_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::sgt: + return S_CMP_GT_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::sge: + return S_CMP_GE_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ult: + return S_CMP_LT_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ule: + return S_CMP_LE_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ugt: + return S_CMP_GT_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::uge: + return S_CMP_GE_U32::create(builder, loc, sregType, lhs, rhs); + } + llvm_unreachable("unhandled CmpIPredicate"); +} + //===----------------------------------------------------------------------===// // Error Handling Helpers //===----------------------------------------------------------------------===// diff --git a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp index 785808bdba..70cb50b248 100644 --- a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp @@ -301,7 +301,14 @@ LogicalResult handleMemRefStore(Operation *op, TranslationContext &ctx) { voffset = ConstantOp::create(builder, loc, immType, 0); } - BUFFER_STORE_DWORD::create(builder, loc, *data, 
srd, voffset); + Value storeData = *data; + if (isAGPRType(storeData.getType())) { + auto vregType = ctx.createVRegType(); + storeData = + V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); + } + + BUFFER_STORE_DWORD::create(builder, loc, storeData, srd, voffset); } return success(); diff --git a/waveasm/tools/waveasm-translate/waveasm-translate.cpp b/waveasm/tools/waveasm-translate/waveasm-translate.cpp index 0daa6a64d5..f9836f3959 100644 --- a/waveasm/tools/waveasm-translate/waveasm-translate.cpp +++ b/waveasm/tools/waveasm-translate/waveasm-translate.cpp @@ -113,6 +113,12 @@ int main(int argc, char **argv) { waveasm::registerWaveASMPasses(); mlir::registerTransformsPasses(); + // Enable standard MLIR pass manager CLI options: + // --mlir-print-ir-after-all, --mlir-print-ir-before-all, + // --mlir-print-ir-after-change, --mlir-print-ir-after-failure, etc. + // IR is printed to stderr; pipe with 2>file to capture. + mlir::registerPassManagerCLOptions(); + // Construct AFTER pass registration — PassNameParser::initialize() snapshots // the registry, so passes must already be registered at this point. static mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); @@ -206,6 +212,8 @@ int main(int argc, char **argv) { // Build pass pipeline from CLI flags. PassManager pm(&context); + // Apply CLI options for IR printing (--mlir-print-ir-after-all, etc.) + mlir::applyPassManagerCLOptions(pm); if (passPipeline.hasAnyOccurrences()) { auto errorHandler = [](const Twine &msg) { llvm::errs() << msg << "\n"; diff --git a/waveasm/waveasm_e2e.py b/waveasm/waveasm_e2e.py index 1a33e798f8..db7273dfa5 100644 --- a/waveasm/waveasm_e2e.py +++ b/waveasm/waveasm_e2e.py @@ -211,7 +211,9 @@ def compile_mlir_to_asm( "--canonicalize", # Clean up dead instructions from offset opt. "--waveasm-scoped-cse", # Re-deduplicate after offset folding. "--waveasm-loop-address-promotion", + "--waveasm-scc-verifier", # Verify no SCC hazards before regalloc. 
"--waveasm-linear-scan=max-vgprs=512 max-agprs=512", # Register allocation. + "--waveasm-vgpr-compaction", # Compact VGPR assignments. f"--waveasm-insert-waitcnt=ticketed-waitcnt={ticketed}", # Insert waits. f"--waveasm-hazard-mitigation=target={self.target}", # Handle hazards. "--emit-assembly", From e17b4019db52b417d0f6c5051e1dfe83c0ea0b71 Mon Sep 17 00:00:00 2001 From: xintin Date: Sat, 21 Mar 2026 01:49:00 +0000 Subject: [PATCH 2/5] Enable 64x64x256 dynamic MXFP4 GEMM with ee=True on 3-stage pipeline on waveasm Signed-off-by: xintin --- .../schedules/gemm_mxfp4_double_buffer.py | 9 +- wave_lang/kernel/wave/utils/mapping_utils.py | 88 +++---------------- waveasm/lib/Transforms/CMakeLists.txt | 9 +- 3 files changed, 29 insertions(+), 77 deletions(-) diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 82c271fc87..2428b47f8f 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1783,8 +1783,13 @@ def mxfp4_dbuf_schedule(): # Interleave MFMAs with memory ops (matching aiter f4gemm pattern). # Clamp start_offsets so they fit within each partition when the M # tile count is odd (e.g. 7 tiles split into 4+3). 
- base_offsets = [0, 3, 2, 0] - base_intervals = [4, 4, 2, 4] + n_mma = len(loop_scaled_mma_0) + if n_mma <= 4: + base_offsets = [0, 1, 1, 0] + base_intervals = [2, 2, 1, 2] + else: + base_offsets = [0, 3, 2, 0] + base_intervals = [4, 4, 2, 4] def _clamp_offsets(n, offsets): return [min(o, max(0, n - 1)) for o in offsets] diff --git a/wave_lang/kernel/wave/utils/mapping_utils.py b/wave_lang/kernel/wave/utils/mapping_utils.py index 870e3d2401..a6ce300437 100644 --- a/wave_lang/kernel/wave/utils/mapping_utils.py +++ b/wave_lang/kernel/wave/utils/mapping_utils.py @@ -284,6 +284,7 @@ def _eval_concrete_floor_mod(expr: sympy.Expr) -> sympy.Expr: collapses them bottom-up so the result contains no unnecessary wrappers around known integers. """ + def _try_eval(e): if isinstance(e, sympy.floor): inner = e.args[0] @@ -374,7 +375,6 @@ def _infer_floor_to_exact(mem_strides: list[IndexExpr]) -> dict: return subs_map - def compute_iv_stride_through_mapping( mapping: IndexMapping, symbolic_shape: tuple[IndexExpr, ...], @@ -428,59 +428,31 @@ def compute_iv_stride_through_mapping( if not iv_info: return None - print(f"=== compute_iv_stride_through_mapping is_read={is_read} ===") - print(f" iters: {dict(iters)}") - for iv_sym, (iv_iter, cc) in iv_info.items(): - print(f" IV {iv_sym} -> iter={iv_iter} coeff={cc}") - - map_dims = ( - mapping.input_shape if is_read else mapping.output_shape - ) + map_dims = mapping.input_shape if is_read else mapping.output_shape raw_exprs = ( - mapping.map_input_indices(map_dims) if is_read + mapping.map_input_indices(map_dims) + if is_read else mapping.map_output_indices(map_dims) ) idxc = IndexingContext.current() dim_exprs = [subs_idxc(e) for e in raw_exprs] - for i, (raw, resolved) in enumerate(zip(raw_exprs, dim_exprs)): - changed = str(raw) != str(resolved) - print( - f" dim[{i}] raw={raw} -> resolved={resolved}" - f"{' (CHANGED by subs_idxc)' if changed else ''}" - ) - if mem_strides is None: symbolic_shape_resolved = tuple(infer_dim(d) for d 
in symbolic_shape) mem_strides = strides_from_symbolic_shape( idxc, symbolic_shape_resolved, allow_mixed_shapes=True ) - stride_free = set() - for s in mem_strides: - stride_free |= sympy.sympify(s).free_symbols - print( - f" mem_strides={mem_strides}" - f" (symbolic={sorted(str(s) for s in stride_free) if stride_free else 'none'})" - ) - div_fwd, div_bwd = get_divisibility_subs(constraints) if div_fwd: fwd_dict = dict(div_fwd) dim_exprs = [sympy.sympify(e).subs(fwd_dict) for e in dim_exprs] mem_strides = [sympy.sympify(s).subs(fwd_dict) for s in mem_strides] - print(f" divisibility fwd subs: {fwd_dict}") - for i, e in enumerate(dim_exprs): - print(f" dim_after_div_subs[{i}] = {e}") - print(f" mem_strides_after_div_subs={mem_strides}") else: floor_subs = _infer_floor_to_exact(mem_strides) if floor_subs: dim_exprs = [sympy.sympify(e).subs(floor_subs) for e in dim_exprs] - print(f" floor_to_exact subs (fallback): {floor_subs}") - for i, e in enumerate(dim_exprs): - print(f" dim_after_subs[{i}] = {e}") result: dict[sympy.Symbol, IndexExpr | list[IndexExpr]] = {} @@ -489,23 +461,18 @@ def compute_iv_stride_through_mapping( dim_exprs, mem_strides, iters, iv_iter, concrete_coeff ) if stride_or_cycle is None: - print( - f" _probe_iv_stride returned None for IV {iv_sym}" - f" — no pattern detected, returning None for entire mapping" - ) return None result[iv_sym] = stride_or_cycle if div_bwd: bwd_dict = dict(div_bwd) + def _bwd(v): if isinstance(v, list): return [mem_simplify(sympy.sympify(x).subs(bwd_dict)) for x in v] return mem_simplify(sympy.sympify(v).subs(bwd_dict)) - result = {k: _bwd(v) for k, v in result.items()} - for iv_sym, val in result.items(): - print(f" RESULT {iv_sym} -> {val}") + result = {k: _bwd(v) for k, v in result.items()} return result @@ -581,30 +548,25 @@ def _probe_iv_stride( (repeating cycle), or ``None`` on failure. 
""" - print( - f"_probe_iv_stride iv_iter={iv_iter} coeff={concrete_coeff}" - ) - # Step 1: linearize symbolically, then compute probe depth from the # flat expression's divisors. Apply subs_idxc to iv_flat so the # probe-depth computation sees the same symbol resolution as the # concrete address evaluations (prevents under-probing when a # divisor is symbolic pre-subs but integer post-subs). flat_expr = mem_simplify(linearize_dims(dim_exprs, mem_strides)) - iv_flat = flat_expr.subs({ - it: (concrete_coeff * sympy.Symbol("_iv") if it == iv_iter else 0) - for it in iters.keys() - }) + iv_flat = flat_expr.subs( + { + it: (concrete_coeff * sympy.Symbol("_iv") if it == iv_iter else 0) + for it in iters.keys() + } + ) iv_flat = subs_idxc(iv_flat) probe_depth = _compute_probe_depth(iv_flat, concrete_coeff) - print(f" probe_depth={probe_depth}") - # Step 2: evaluate P+1 concrete addresses. def _linearized_addr(iv_val: int) -> IndexExpr: subs = { - it: (concrete_coeff * iv_val if it == iv_iter else 0) - for it in iters.keys() + it: (concrete_coeff * iv_val if it == iv_iter else 0) for it in iters.keys() } resolved = [subs_idxc(dim_expr.subs(subs)) for dim_expr in dim_exprs] return mem_simplify(subs_idxc(linearize_dims(resolved, mem_strides))) @@ -612,22 +574,12 @@ def _linearized_addr(iv_val: int) -> IndexExpr: addrs: list[int] = [] for iv in range(probe_depth + 1): a = _linearized_addr(iv) - if getattr(a, 'free_symbols', set()): - print( - f" addr[iv={iv}] = {a} (free={a.free_symbols})" - f"\n *** ERROR: address contains unresolved free symbols." - f" Fix chained symbolic dependencies upstream." 
- ) + if getattr(a, "free_symbols", set()): return None addrs.append(int(a)) diffs = [addrs[i + 1] - addrs[i] for i in range(probe_depth)] - for i, a in enumerate(addrs): - print(f" addr[iv={i}] = {a}") - for i, d in enumerate(diffs): - print(f" diff[{i}] = {d}") - if not diffs: return None @@ -638,21 +590,9 @@ def _linearized_addr(iv_val: int) -> IndexExpr: if all(diffs[i] == diffs[i % cycle_len] for i in range(probe_depth)): cycle = [sympy.Integer(diffs[i]) for i in range(cycle_len)] if cycle_len == 1: - print( - f" -> CONSTANT stride = {cycle[0]}" - f" (concrete=True, probe_depth={probe_depth})" - ) return cycle[0] - print( - f" -> CYCLIC stride (len={cycle_len}): {diffs[:cycle_len]}" - f" (concrete=True, probe_depth={probe_depth})" - ) return cycle - print( - f" -> FAILED: no constant or cyclic pattern in {probe_depth}" - f" diffs. diffs={diffs}" - ) return None diff --git a/waveasm/lib/Transforms/CMakeLists.txt b/waveasm/lib/Transforms/CMakeLists.txt index f1a1bdae5f..9770b31f0c 100644 --- a/waveasm/lib/Transforms/CMakeLists.txt +++ b/waveasm/lib/Transforms/CMakeLists.txt @@ -14,9 +14,11 @@ foreach(src ${HANDLERS_SRCS}) endforeach() add_mlir_dialect_library(MLIRWaveASMTransforms + ArithLegalization.cpp AssemblyEmitter.cpp BufferLoadStrengthReduction.cpp ExtractScalarization.cpp + GPUModuleToBinary.cpp HazardMitigation.cpp LinearScanPass.cpp LinearScanRegAlloc.cpp @@ -32,8 +34,9 @@ add_mlir_dialect_library(MLIRWaveASMTransforms SCCVerifier.cpp ScopedCSE.cpp Ticketing.cpp - VGPRCompaction.cpp + TranslateFromLLVMDialect.cpp TranslateFromMLIR.cpp + VGPRCompaction.cpp ${HANDLERS_FULL_PATHS} ADDITIONAL_HEADER_DIRS @@ -49,8 +52,12 @@ add_mlir_dialect_library(MLIRWaveASMTransforms MLIRFuncDialect MLIRGPUDialect MLIRIR + MLIRLLVMDialect MLIRMathDialect MLIRMemRefDialect + MLIRPass + MLIRROCDLDialect + MLIRROCDLTarget MLIRSCFDialect MLIRSupport MLIRVectorDialect From 11d3c8d3f1929671634fe6d06e6162a32eacb9bc Mon Sep 17 00:00:00 2001 From: xintin Date: Thu, 26 Mar 2026 
22:25:32 +0000 Subject: [PATCH 3/5] initial draft: store works Signed-off-by: xintin --- examples/python/7.1_schedule.py | 215 ++++++++++++++++-- examples/python/test_sympy_diff.py | 6 +- .../unittests/index_mapping_simplify_test.py | 22 +- wave_lang/kernel/_support/indexing.py | 1 + .../compiler/wave_codegen/read_write.py | 173 ++++++++++++-- .../wave/analysis/annotate_iv_strides.py | 11 +- wave_lang/kernel/wave/compile.py | 12 +- wave_lang/kernel/wave/compile_options.py | 2 + .../kernel/wave/index_mapping_simplify.py | 6 +- .../wave/templates/tagged_mxfp4_gemm.py | 65 +++++- wave_lang/kernel/wave/utils/symbol_utils.py | 16 +- waveasm/include/waveasm/Dialect/WaveASMOps.td | 19 +- .../waveasm/Transforms/TranslateFromMLIR.h | 14 ++ waveasm/lib/Transforms/AssemblyEmitter.cpp | 16 ++ waveasm/lib/Transforms/LinearScanPass.cpp | 7 + waveasm/lib/Transforms/LinearScanRegAlloc.cpp | 4 +- waveasm/lib/Transforms/SCCVerifier.cpp | 23 +- waveasm/lib/Transforms/ScopedCSE.cpp | 3 +- waveasm/lib/Transforms/TranslateFromMLIR.cpp | 63 +++-- waveasm/lib/Transforms/VGPRCompaction.cpp | 30 ++- .../Transforms/handlers/AMDGPUHandlers.cpp | 89 +++++++- .../Transforms/handlers/AffineHandlers.cpp | 30 +-- .../lib/Transforms/handlers/ArithHandlers.cpp | 26 ++- waveasm/lib/Transforms/handlers/Handlers.h | 53 ++++- .../Transforms/handlers/MemRefHandlers.cpp | 3 +- 25 files changed, 750 insertions(+), 159 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index d97102576c..aff372d35f 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -11,6 +11,7 @@ python 7.1_schedule.py --list_tests """ +import os import torch import wave_lang.kernel.lang as tkl @@ -60,7 +61,13 @@ def _run_mxfp_gemm(gemm, shape): def _run_mxfp_gemm_preshuffle( - gemm, shape, all=False, only_scale=False, only_b=False, output_dtype=torch.float32 + gemm, + shape, + all=False, + only_scale=False, + only_b=False, + output_dtype=torch.float32, + 
transpose_output=False, ): """Run compiled GEMM kernel with preshuffled B and B_scale, verify against reference. @@ -68,30 +75,82 @@ def _run_mxfp_gemm_preshuffle( all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t) only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only only_b - shuffle b_scale (w_scales) only + + When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N]. """ x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) w_t = w.T.contiguous() - # Apply b (w_t) preshuffle only when all=True w_t_ps = b_preshuffle(w_t) if all else w_t - # Apply a_scale shuffle when all=True or only_scale=True x_scales_ps = e8m0_shuffle(x_scales) if (all or only_scale) else x_scales - # Apply b_scale shuffle when all=True, only_scale=True, or only_b=True w_scales_ps = e8m0_shuffle(w_scales) if (all or only_scale or only_b) else w_scales x, w_t_ps = x.cuda(), w_t_ps.cuda() x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda() - out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda() + if transpose_output: + out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=output_dtype).cuda() + else: + out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda() gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) - torch.testing.assert_close( - torch_out, out.cpu(), check_dtype=False, check_device=False - ) + result = out.T.contiguous().cpu() if transpose_output else out.cpu() + + if os.environ.get("WAVE_DEBUG_COMPARE"): + ref = torch_out.to(torch.float32).cpu() + got = result.to(torch.float32).cpu() + mismatch = ~torch.isclose(ref, got, atol=1e-1, rtol=0.1) + M, N = ref.shape + wrong_idx = torch.nonzero(mismatch) + if wrong_idx.numel() > 0: + print(f" Mismatch count: {mismatch.sum().item()} / {mismatch.numel()}") + m_wrong = wrong_idx[:, 0] + m_mod8 = m_wrong % 8 + print( + f" Mismatch M%8 distribution: {torch.bincount(m_mod8, 
minlength=8).tolist()}" + ) + # Check if got values at wrong M match ref at shifted M + print(" === Shift analysis (checking if got[m,n] == ref[m+shift,n]) ===") + for shift in [-4, -3, -2, -1, 1, 2, 3, 4]: + match_count = 0 + total = 0 + for i in range(min(1000, wrong_idx.shape[0])): + mi, ni = wrong_idx[i].tolist() + ms = mi + shift + if 0 <= ms < M: + total += 1 + if torch.isclose( + got[mi, ni : ni + 1], + ref[ms, ni : ni + 1], + atol=1e-1, + rtol=0.1, + ).item(): + match_count += 1 + if total > 0: + print( + f" shift={shift:+d}: {match_count}/{total} ({100*match_count/total:.0f}%)" + ) + # Print samples with neighboring values + print(" === Sample values (M=0..7 at N=0) ===") + for m in range(min(8, M)): + r = ref[m, 0].item() + g = got[m, 0].item() + ok = ( + "OK" + if torch.isclose( + ref[m, 0:1], got[m, 0:1], atol=1e-1, rtol=0.1 + ).item() + else "WRONG" + ) + print(f" M={m} ref={r:.4f} got={g:.4f} {ok}") + else: + print(" No mismatches (within tolerance)") + + torch.testing.assert_close(torch_out, result, check_dtype=False, check_device=False) def _get_8wave_shape_from_block(block): @@ -375,7 +434,9 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( eliminate_epilogue=True, ): """Preshuffle-B MXFP4 GEMM using C++ WaveASM backend.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(2, 2), reorder_workgroups=True) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, block, wave_shape=(2, 2), reorder_workgroups=True + ) options.backend = "asm" options.use_buffer_ops = True options.wave_runtime = True @@ -394,14 +455,22 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( f"MXFP GEMM preshuffle-B 4-wave (WaveASM) epilogue elimination={eliminate_epilogue} PASSED" ) + def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( is_debug=False, - shape=(1024, 1024, 8192), - block=(128, 256, 256), + shape=(1024, 6144, 8192), + block=(256, 192, 256), eliminate_epilogue=True, ): """Preshuffle-B MXFP4 GEMM with dynamic M, N, K.""" - gemm, 
options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(1, 4), + reorder_workgroups=True, + output_dtype=tkl.bf16, + transpose_output=True, + ) # Make M, N, K dynamic so the compiler does not specialize on problem size. dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] for sym in dynamic_symbols: @@ -419,18 +488,22 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) - _run_mxfp_gemm_preshuffle(gemm, shape, all=True) + _run_mxfp_gemm_preshuffle( + gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + ) print("MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (LLVM backend) test passed!") def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( is_debug=False, - shape=(1024, 1024, 8192), - block=(128, 256, 256), + shape=(1024, 6144, 8192), + block=(256, 192, 256), eliminate_epilogue=True, ): """Preshuffle-B MXFP4 GEMM with dynamic M, N, K.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(2, 2), reorder_workgroups=True) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, block, wave_shape=(2, 2), reorder_workgroups=True + ) # Make M, N, K dynamic so the compiler does not specialize on problem size. 
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] for sym in dynamic_symbols: @@ -441,7 +514,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( options.use_wave_asm_backend = True options.wave_runtime = True options.eliminate_epilogue = eliminate_epilogue - options.dump_intermediates = "build/intermediates/" + options.dump_intermediates = "build/intermediates/waveasm_256x192x256_baseline/" options.print_mlir_file = "gemm_mxfp4_dbuf_4wave_asymmetric.mlir" options.print_mlir = True schedule = get_mxfp4_asymmetric_schedule( @@ -451,8 +524,114 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) + with open( + "build/intermediates/waveasm_256x192x256_baseline/gemm_mxfp4_dbuf_4wave_asymmetric.mlir", + "w", + ) as f: + f.write(gemm.asm) + _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print("MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!") + print( + "MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!" 
+ ) + + +def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16( + is_debug=False, + shape=(6400, 3072, 7168), + block=(256, 192, 256), + eliminate_epilogue=True, +): + """Preshuffle-B MXFP4 GEMM with dynamic M, N, K and bf16 output (WaveASM).""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = eliminate_epilogue + options.dump_intermediates = ( + "build/intermediates/waveasm_256x192x256_bf16_baseline/" + ) + options.print_mlir_file = "gemm_mxfp4_dbuf_4wave_asymmetric_bf16.mlir" + options.print_mlir = True + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + with open( + "build/intermediates/waveasm_256x192x256_bf16_baseline/gemm_mxfp4_dbuf_4wave_asymmetric_bf16.mlir", + "w", + ) as f: + f.write(gemm.asm) + + _run_mxfp_gemm_preshuffle(gemm, shape, all=True, output_dtype=torch.bfloat16) + print( + "MXFP GEMM preshuffle-B 4-wave dynamic M, N, K bf16 (WaveASM backend) test passed!" 
+ ) + + +def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_coalesced( + is_debug=False, + shape=(6400, 3072, 7168), + block=(256, 192, 256), + eliminate_epilogue=True, +): + """Preshuffle-B MXFP4 GEMM bf16 with coalesced epilogue stores via permlane swap (WaveASM).""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + transpose_output=True, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = eliminate_epilogue + options.coalesce_epilogue_stores = True + options.dump_intermediates = ( + "build/intermediates/waveasm_256x192x256_bf16_coalesced/" + ) + options.print_mlir_file = "gemm_mxfp4_dbuf_4wave_asymmetric_bf16_coalesced.mlir" + options.print_mlir = True + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + with open( + "build/intermediates/waveasm_256x192x256_bf16_coalesced/gemm_mxfp4_dbuf_4wave_asymmetric_bf16_coalesced.mlir", + "w", + ) as f: + f.write(gemm.asm) + + _run_mxfp_gemm_preshuffle( + gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + ) + print( + "MXFP GEMM preshuffle-B 4-wave bf16 coalesced epilogue (WaveASM backend) test passed!" 
+ ) if __name__ == "__main__": diff --git a/examples/python/test_sympy_diff.py b/examples/python/test_sympy_diff.py index a9efd6a6b2..294d7dc7ec 100644 --- a/examples/python/test_sympy_diff.py +++ b/examples/python/test_sympy_diff.py @@ -33,7 +33,11 @@ def check(name, got, expected): global PASS, FAIL - ok = sympy.expand(got - expected) == 0 if not isinstance(got, bool) else got == expected + ok = ( + sympy.expand(got - expected) == 0 + if not isinstance(got, bool) + else got == expected + ) if ok: PASS += 1 else: diff --git a/tests/unittests/index_mapping_simplify_test.py b/tests/unittests/index_mapping_simplify_test.py index b11d18753a..03db6cfa49 100644 --- a/tests/unittests/index_mapping_simplify_test.py +++ b/tests/unittests/index_mapping_simplify_test.py @@ -82,10 +82,7 @@ def test_no_simplification_b_data_preshuffle(self): k_it = IndexMapping.iterator(1) within_nblk = ( - (k_it // 32) * 512 - + ((k_it // 16) % 2) * 256 - + (n_it % 16) * 16 - + k_it % 16 + (k_it // 32) * 512 + ((k_it // 16) % 2) * 256 + (n_it % 16) * 16 + k_it % 16 ) K_PACKED = K // 2 @@ -106,8 +103,10 @@ class TestExprBoundsWithIters: def test_iterator_bounds(self): i0 = IndexMapping.iterator(0) i1 = IndexMapping.iterator(1) - bounds = {i0: (sympy.Integer(0), sympy.Integer(15)), - i1: (sympy.Integer(0), sympy.Integer(63))} + bounds = { + i0: (sympy.Integer(0), sympy.Integer(15)), + i1: (sympy.Integer(0), sympy.Integer(63)), + } assert _expr_bounds_with_iters(i0, bounds) == (0, 15) assert _expr_bounds_with_iters(i1, bounds) == (0, 63) @@ -116,14 +115,13 @@ def test_within_nblk_bounds(self): """within_nblk for tile [0,15]x[0,63] is bounded to [0,1023].""" n_it = IndexMapping.iterator(0) k_it = IndexMapping.iterator(1) - bounds = {n_it: (sympy.Integer(0), sympy.Integer(15)), - k_it: (sympy.Integer(0), sympy.Integer(63))} + bounds = { + n_it: (sympy.Integer(0), sympy.Integer(15)), + k_it: (sympy.Integer(0), sympy.Integer(63)), + } within_nblk = ( - (k_it // 32) * 512 - + ((k_it // 16) % 2) * 256 
- + (n_it % 16) * 16 - + k_it % 16 + (k_it // 32) * 512 + ((k_it // 16) % 2) * 256 + (n_it % 16) * 16 + k_it % 16 ) result = _expr_bounds_with_iters(within_nblk, bounds) assert result is not None diff --git a/wave_lang/kernel/_support/indexing.py b/wave_lang/kernel/_support/indexing.py index 61a02bacca..0966b7928f 100644 --- a/wave_lang/kernel/_support/indexing.py +++ b/wave_lang/kernel/_support/indexing.py @@ -124,6 +124,7 @@ def _resolve_chained_subs( if pending: import warnings + cycle_keys = sorted(str(k) for k in pending.keys()) warnings.warn( f"_resolve_chained_subs: circular dependency among" diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index ad657294fd..e01ddc8c20 100755 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -14,7 +14,9 @@ from wave_lang.kernel.wave.utils.graph_utils import propagate_loop_carried_vars from wave_lang.support.ir_imports import ( Attribute, + BF16Type, DenseElementsAttr, + F32Type, IndexType, InsertionPoint, IntegerAttr, @@ -31,6 +33,7 @@ gpu_d, llvm_d, memref_d, + rocdl_d, vector_d, ) from .ir_utils import ( @@ -460,9 +463,7 @@ def _compute_branchless_valid_bytes( real_valid_index = gen_sympy_index(subs_map, total_bytes_expr) real_valid = arith_d.index_cast(uint64, real_valid_index) else: - real_valid = arith_d.constant( - uint64, get_constant_attr(total_bytes, uint64) - ) + real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) @@ -501,9 +502,7 @@ def _compute_valid_bytes( real_valid_index = gen_sympy_index(subs_map, total_bytes_expr) total_val = arith_d.index_cast(uint64, real_valid_index) else: - total_val = arith_d.constant( - uint64, get_constant_attr(total_bytes, uint64) - ) + total_val = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) metadata = 
memref_d.extract_strided_metadata(ptr) offset_elements = metadata[1] offset_bytes = arith_d.index_cast(uint64, offset_elements) @@ -868,7 +867,9 @@ def _try_iv_split_offset( strides: list[int | IndexExpr], dynamic_vals: dict[IndexExpr, Any], use_subs_idxc: bool = True, - precomputed_iv_stride: dict[sympy.Symbol, IndexExpr | list[IndexExpr]] | None = None, + precomputed_iv_stride: ( + dict[sympy.Symbol, IndexExpr | list[IndexExpr]] | None + ) = None, **kwargs, ) -> Optional[Value]: """Compute a hoisted IV-split linearized offset for a loop-carried read. @@ -941,9 +942,7 @@ def _try_iv_split_offset( lin_offset = ( term if lin_offset is None - else arith_d.addi( - lin_offset, term, overflow_flags=overflow_flags - ) + else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) ) iv_mlir = subs_map.get(iv_sym) @@ -1203,9 +1202,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): subs_map = add_emitter_subs(emitter, dynamic_vals_map_start) with hoist_ip: strides_vals = [gen_sympy_index(subs_map, s) for s in iv_strides] - zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( - iv_strides - ) + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(iv_strides) lin_src, _ = _linearize_memref( kb_src, zero_indices, zero_indices, strides_vals ) @@ -1232,7 +1229,9 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): result = vector_d.load(vector_type, lin_src, [total_offset]) else: element_type = vector_type.element_type - zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) + zero = arith_d.constant( + element_type, get_constant_attr(0, element_type) + ) passthru = vector_d.broadcast(vector_type, zero) result = vector_d.maskedload( vector_type, lin_src, [total_offset], mask, passthru @@ -1335,6 +1334,24 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): ) use_llvm_store = flags != MemoryAccessFlags.NONE + + is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE + is_bf16 = isinstance(element_type, 
BF16Type) + + if not is_shared and is_bf16 and getattr(node, "_permlane_pack_global", False): + _write_permlane_pack_to_global( + emitter, + insert_vector, + kb_dest, + output_shape, + start_indices, + start_indices_wg, + start_indices_th, + get_custom(memory), + index, + ) + return + if use_llvm_store: _create_llvm_read_write( kb_dest, kb_ir_type, start_indices, insert_type, flags, insert_vector @@ -1356,6 +1373,130 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): ) +def _write_permlane_pack_to_global( + emitter: WaveEmitter, + insert_vector: Value, + kb_dest: Value, + output_shape: tuple, + start_indices: tuple, + start_indices_wg: tuple, + start_indices_th: tuple, + memory_custom, + index: dict, +): + """Pack two lanes' bf16 values via permlane16_swap for wide global stores. + + MMA accumulator layout (F32_16x16x128_F8F6F4) gives each thread 4 + consecutive M values. Lanes are grouped by 16: lanes 0-15 own M=0-3, + lanes 16-31 own M=4-7, etc. ``v_permlane16_swap_b32`` exchanges data + between paired groups, giving each lane 8 consecutive M values that + can be written as a single ``buffer_store_dwordx4`` (128 bits). + + Both lane halves produce identical data at the same address (benign + duplicate store): + + - Lower half (lanes 0-15 in each 32-lane group): + data = [own, partner], address = thread's original M index. + - Upper half (lanes 16-31): + data = [partner, own], address = original M index - 4. + + This dual-write avoids divergent control flow (no scf.if / exec + masking needed). The buffer descriptor's ``valid_bytes`` handles + out-of-bounds suppression for dynamic shapes. + + Precondition: M must be the innermost (last) memory dimension with + stride 1 (i.e. transpose_output=True, shape [N, M]). 
+ """ + f32_type = F32Type.get() + i32_type = IntegerType.get_signless(32) + idx_type = IndexType.get() + bf16_type = BF16Type.get() + + # waveasm defers vector arith.truncf (f32->bf16), so insert_vector + # is nominally vector<4xbf16> but the underlying data is still f32 + # in AGPRs. Get the f32 source directly from the defining truncf op. + truncf_op = insert_vector.owner + if truncf_op.name == "arith.truncf": + f32_vec = truncf_op.operands[0] + else: + f32_vec = insert_vector + + # Extract 4 f32 accumulator values. + e = [ + vector_d.extract(f32_vec, static_position=[i], dynamic_position=[]) + for i in range(4) + ] + + # Swap each f32 value with the partner lane (16 positions apart). + swap_type = llvm_d.StructType.get_literal([i32_type, i32_type]) + + p = [] + for ei in e: + ei_i32 = arith_d.bitcast(i32_type, ei) + swapped_i32 = llvm_d.extractvalue( + i32_type, + rocdl_d.permlane16_swap(swap_type, ei_i32, ei_i32, False, False), + [0], + ) + p.append(arith_d.bitcast(f32_type, swapped_i32)) + + # Determine lane position within each 32-lane half-wavefront. + lane_in_wave = arith_d.remui(emitter.thread_ids[0], arith_d.constant(idx_type, 64)) + half_pos = arith_d.remui(lane_in_wave, arith_d.constant(idx_type, 32)) + is_lower = arith_d.cmpi( + arith_d.CmpIPredicate.ult, half_pos, arith_d.constant(idx_type, 16) + ) + + four = arith_d.constant(idx_type, 4) + v2f32_type = VectorType.get([2], f32_type) + v2bf16_type = VectorType.get([2], bf16_type) + + adj_th = list(start_indices_th) + adj_th[-1] = arith_d.select(is_lower, adj_th[-1], arith_d.subi(adj_th[-1], four)) + + adj_full = list(start_indices) + adj_full[-1] = arith_d.select( + is_lower, adj_full[-1], arith_d.subi(adj_full[-1], four) + ) + + # Select values: own for lower lanes, partner for upper (and vice versa). + s_lo = [arith_d.select(is_lower, e[i], p[i]) for i in range(4)] + s_hi = [arith_d.select(is_lower, p[i], e[i]) for i in range(4)] + + # Emit 4 stores of vector<2xbf16> (= buffer_store_dword each). 
+ # Each pair of f32 values is packed into one bf16 dword by + # v_cvt_pk_bf16_f32. Using 2-element stores avoids the multi-dword + # PackOp, which the register allocator cannot handle (it does not + # insert copies for PackOp operands). + all_vals = s_lo + s_hi + for pair_idx in range(4): + pair_f32 = vector_d.from_elements( + v2f32_type, [all_vals[pair_idx * 2], all_vals[pair_idx * 2 + 1]] + ) + pair_bf16 = arith_d.truncf(v2bf16_type, pair_f32) + + elem_offset = arith_d.constant(idx_type, pair_idx * 2) + cur_th = list(adj_th) + cur_th[-1] = arith_d.addi(adj_th[-1], elem_offset) + cur_full = list(adj_full) + cur_full[-1] = arith_d.addi(adj_full[-1], elem_offset) + + _create_vec_read_write( + emitter, + output_shape, + kb_dest, + pair_bf16, + None, + tuple(cur_full), + start_indices_wg, + tuple(cur_th), + 2, + memory_custom, + None, + node_index=index, + ) + + def assume_index_subgroup_uniform(value: Value, element_type: IrType) -> Value: res = gpu_d.subgroup_broadcast(value, gpu_d.BroadcastType.first_active_lane) return res @@ -1680,9 +1821,7 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): dynamic_values=src_dynamic_vals_map_start, ) if mask: - mask = vector_d.extract( - mask, static_position=[0], dynamic_position=[] - ) + mask = vector_d.extract(mask, static_position=[0], dynamic_position=[]) oob_index_value = _get_out_of_bounds_index(element_type) oob_index = arith_d.constant(IndexType.get(), oob_index_value) src_offset = arith_d.select(mask, src_offset, oob_index) diff --git a/wave_lang/kernel/wave/analysis/annotate_iv_strides.py b/wave_lang/kernel/wave/analysis/annotate_iv_strides.py index bcf6ab1a82..c23ac3e02e 100644 --- a/wave_lang/kernel/wave/analysis/annotate_iv_strides.py +++ b/wave_lang/kernel/wave/analysis/annotate_iv_strides.py @@ -33,9 +33,7 @@ def annotate_iv_strides( """Annotate every mapped Read/GatherToLDS with ``meta["iv_stride"]``.""" idxc = IndexingContext.current() - for node in trace.walk( - lambda n: 
isinstance(get_custom(n), (Read, GatherToLDS)) - ): + for node in trace.walk(lambda n: isinstance(get_custom(n), (Read, GatherToLDS))): if node.meta.get("iv_stride") is not None: continue @@ -62,8 +60,11 @@ def annotate_iv_strides( ) iv_stride = compute_iv_stride_through_mapping( - mapping, symbolic_shape, index, - is_read=True, mem_strides=list(phys_strides), + mapping, + symbolic_shape, + index, + is_read=True, + mem_strides=list(phys_strides), constraints=constraints, ) if iv_stride is not None: diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 70c9abfb8f..d9da8257b4 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -573,6 +573,11 @@ def build_graph_passes( partial(guard_g2s_with_bounds_check, trace, launchable.constraints) ) + if options.coalesce_epilogue_stores: + from .coalesce_epilogue_stores import coalesce_epilogue_stores + + graph_passes.append(partial(coalesce_epilogue_stores, trace)) + if options.optimization_level: graph_passes += [ partial( @@ -1322,6 +1327,7 @@ def _generate_asm_code(mb, options): # Debug: save a copy of the MLIR input to waveasm-translate import shutil + shutil.copy(mlir_path, "/tmp/waveasm_input.mlir") try: @@ -1350,7 +1356,7 @@ def _generate_asm_code(mb, options): # TODO: improve Ticketing logic (better latency-covering heuristics, # smarter coalescing) so ticketed waitcnt can be always-on without # a performance hit, removing this wave-shape conditional. 
- use_ticketed_waitcnt = False + use_ticketed_waitcnt = False waitcnt_flag = ( "--waveasm-insert-waitcnt" if use_ticketed_waitcnt @@ -1381,7 +1387,9 @@ def _run_translate(extra_passes): ir_dump_path = os.environ.get("WAVEASM_DUMP_IR") if ir_dump_path: full_cmd.append("--mlir-print-ir-after-all") - result = subprocess.run(full_cmd, capture_output=True, text=True, timeout=120) + result = subprocess.run( + full_cmd, capture_output=True, text=True, timeout=120 + ) if ir_dump_path and result.stderr: os.makedirs(os.path.dirname(ir_dump_path) or ".", exist_ok=True) with open(ir_dump_path, "w") as f: diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index fafb01453b..4019e14494 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -104,6 +104,8 @@ class WaveCompileOptions: specialize: bool = False eliminate_epilogue: bool = False + coalesce_epilogue_stores: bool = False + # Cluster barrier signal/wait delay in number of loop iterations # None - no barriers inside the loop # 0 - signal and wait on same iteration diff --git a/wave_lang/kernel/wave/index_mapping_simplify.py b/wave_lang/kernel/wave/index_mapping_simplify.py index ee2335a499..7f568f3a7b 100644 --- a/wave_lang/kernel/wave/index_mapping_simplify.py +++ b/wave_lang/kernel/wave/index_mapping_simplify.py @@ -23,12 +23,10 @@ import sympy from collections.abc import Sequence -from functools import lru_cache from ..lang.wave_types import IndexMapping from .utils.symbol_utils import ( _split_sum_by_divisibility, - expr_bounds, IndexExpr, IndexSymbol, subs_idxc, @@ -290,9 +288,7 @@ def simplify_index_mapping( except (IndexError, KeyError): divisor_lb = divisor if sym_lower_bounds: - divisor_lb = divisor_lb.subs( - {s: lb for s, lb in sym_lower_bounds.items()} - ) + divisor_lb = divisor_lb.subs({s: lb for s, lb in sym_lower_bounds.items()}) # Evaluate floor/ceiling after substitution. 
try: divisor_lb = sympy.Integer(int(divisor_lb)) diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index b595bbbb8d..14381bf84c 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -149,6 +149,8 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( b_preshuffled: bool = False, reorder_workgroups: bool = False, group_size_n=32, + output_dtype=tkl.f32, + transpose_output: bool = False, ): """Shared implementation: preshuffle scales only, or scales + B data. @@ -159,6 +161,9 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( is controlled by the selected address spaces (`a_address_space` and `b_address_space`). + When transpose_output is True, the output memory is [N, M] instead of [M, N], + producing C^T in row-major layout. This makes per-lane MMA accumulator + elements contiguous in the M (fast) dimension of the output. """ M = tkl.sym.M N = tkl.sym.N @@ -179,6 +184,14 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( constraints += [tkw.WaveConstraint(N, BLOCK_N / wave_shape[1])] constraints += [tkw.HardwareConstraint(threads_per_wave=64, mma_type=mfma_variant)] + constraints += [tkw.Assumption(Eq(M % 32, 0))] + constraints += [tkw.Assumption(Eq(N % 32, 0))] + constraints += [tkw.Assumption(Eq(K % 256, 0))] + constraints += [tkw.Assumption(Eq(K % BLOCK_K, 0))] + constraints += [tkw.Assumption(Eq(M % BLOCK_M, 0))] + constraints += [tkw.Assumption(Eq(N % BLOCK_N, 0))] + constraints += [tkw.Assumption(K > BLOCK_K * 6)] + if reorder_workgroups: new_wg0, new_wg1 = _reorder_mxfp4_workgroups( M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_N @@ -243,13 +256,26 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( outputs={K: k_s, N: n_s}, ) + c_dim_0, c_dim_1 = (N, M) if transpose_output else (M, N) + + if transpose_output: + c_it_m = tkw.IndexMapping.iterator(0) + c_it_n = tkw.IndexMapping.iterator(1) + c_write_mapping = 
tkw.IndexMapping( + num_iterators=2, + inputs={M: c_it_m, N: c_it_n}, + outputs={N: c_it_n, M: c_it_m}, + ) + else: + c_write_mapping = None + @tkw.wave(constraints) def gemm( a: tkl.Memory[M, K / 2, A_ADDRESS_SPACE, tkl.i8], a_scale: tkl.Memory[M, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], b: tkl.Memory[N, K / 2, B_ADDRESS_SPACE, tkl.i8], b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], - c: tkl.Memory[M, N, C_ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[c_dim_0, c_dim_1, C_ADDRESS_SPACE, output_dtype], ): c_reg = tkl.Register[M, N, tkl.f32](0.0) @@ -273,7 +299,13 @@ def repeat( ) return acc - tkw.write(repeat, c) + if output_dtype == tkl.bf16: + repeat = tkw.cast(repeat, tkl.bf16) + + if c_write_mapping is not None: + tkw.write(repeat, c, mapping=c_write_mapping, elements_per_thread=4) + else: + tkw.write(repeat, c) hyperparams = { A_ADDRESS_SPACE: a_address_space, @@ -290,7 +322,7 @@ def repeat( M: shape[0], N: shape[1], K: shape[2], - K_SCALE_SHUFFLED: (((shape[2] // 32) + 7) // 8) * 8, + K_SCALE_SHUFFLED: (((K // 32) + 7) // 8) * 8, } if b_preshuffled: hyperparams[K_PACKED] = K // 2 @@ -348,6 +380,8 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, b_address_space: tkl.AddressSpace | None = None, + output_dtype=tkl.f32, + transpose_output: bool = False, ): """Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale. @@ -363,6 +397,7 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( mfma_variant: Scaled MMA instruction type. a_address_space: Address space for A. b_address_space: Address space for B. + transpose_output: If True, output memory is [N, M] instead of [M, N]. 
Returns: (kernel_function, WaveCompileOptions) """ @@ -374,6 +409,8 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( a_address_space, b_address_space, b_preshuffled=True, + output_dtype=output_dtype, + transpose_output=transpose_output, ) @@ -387,6 +424,7 @@ def get_tagged_mxfp4_gemm_preshuffle_b( reorder_workgroups=True, group_size_n=32, output_dtype=tkl.f32, + transpose_output: bool = False, ): """Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale. @@ -520,13 +558,26 @@ def get_tagged_mxfp4_gemm_preshuffle_b( outputs={K: k_s, N: n_s}, ) + c_dim_0, c_dim_1 = (N, M) if transpose_output else (M, N) + + if transpose_output: + c_it_m = tkw.IndexMapping.iterator(0) + c_it_n = tkw.IndexMapping.iterator(1) + c_write_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={M: c_it_m, N: c_it_n}, + outputs={N: c_it_n, M: c_it_m}, + ) + else: + c_write_mapping = None + @tkw.wave(constraints) def gemm( a: tkl.Memory[M, K / 2, A_ADDRESS_SPACE, tkl.i8], a_scale: tkl.Memory[M, K / 32, A_ADDRESS_SPACE, tkl.i8], b: tkl.Memory[N, K / 2, GLOBAL_ADDRESS_SPACE, tkl.i8], b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], - c: tkl.Memory[M, N, C_ADDRESS_SPACE, output_dtype], + c: tkl.Memory[c_dim_0, c_dim_1, C_ADDRESS_SPACE, output_dtype], ): c_reg = tkl.Register[M, N, tkl.f32](0.0) @@ -549,7 +600,11 @@ def repeat( if output_dtype == tkl.bf16: repeat = tkw.cast(repeat, tkl.bf16) - tkw.write(repeat, c) + + if c_write_mapping is not None: + tkw.write(repeat, c, mapping=c_write_mapping, elements_per_thread=4) + else: + tkw.write(repeat, c) hyperparams = { A_ADDRESS_SPACE: a_address_space, diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 13a5e993c0..93dbeb85a7 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -167,9 +167,7 @@ def _split_coeff(expr: sympy.Expr) -> tuple[sympy.Integer, sympy.Expr]: return (sympy.Integer(1), expr) -def 
_contains_factor( - factors: list[sympy.Expr], target: sympy.Expr -) -> bool: +def _contains_factor(factors: list[sympy.Expr], target: sympy.Expr) -> bool: """Check if *target* appears as a factor in *factors* (possibly nested).""" for f in factors: if f == target: @@ -365,7 +363,9 @@ def transform_mod_div(expr): pass # Symbolic comparison — can't determine. return sympy.Mod(remainder, q, evaluate=False) - expr = expr.replace(lambda e: transform_floor_div(e) is not None, transform_floor_div) + expr = expr.replace( + lambda e: transform_floor_div(e) is not None, transform_floor_div + ) expr = expr.replace(lambda e: transform_mod_div(e) is not None, transform_mod_div) expr = expr.replace(lambda e: transform_mod(e) is not None, transform_mod) expr = expr.replace(lambda e: transform_floor(e) is not None, transform_floor) @@ -428,11 +428,8 @@ def _rewrite_floor(arg): if iv_coeff == 0: return sympy.floor(arg) rest = numer - iv_coeff * iv - return ( - sympy.floor(iv_coeff / denom) * iv - + sympy.floor( - (sympy.Mod(iv_coeff, denom, evaluate=False) * iv + rest) / denom - ) + return sympy.floor(iv_coeff / denom) * iv + sympy.floor( + (sympy.Mod(iv_coeff, denom, evaluate=False) * iv + rest) / denom ) def _rewrite_mod(*args): @@ -459,6 +456,7 @@ def _rewrite_mod(*args): return None + def simplify_divisor_multiples(expr: sympy.Expr) -> sympy.Expr: """Factor out divisor-multiples from floor/Mod without expand/cancel. 
diff --git a/waveasm/include/waveasm/Dialect/WaveASMOps.td b/waveasm/include/waveasm/Dialect/WaveASMOps.td index e36602bf46..869792e7c1 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMOps.td +++ b/waveasm/include/waveasm/Dialect/WaveASMOps.td @@ -493,6 +493,19 @@ def WaveASM_V_LOG_F32 : VALUUnaryOp<"v_log_f32">; def WaveASM_V_SIN_F32 : VALUUnaryOp<"v_sin_f32">; def WaveASM_V_COS_F32 : VALUUnaryOp<"v_cos_f32">; +// Lane-swap operations +def WaveASM_V_PERMLANE16_SWAP_B32 : WAVEASMOp<"v_permlane16_swap_b32", [Pure]> { + let summary = "Swap data between lanes 16 positions apart"; + let description = [{ + Exchanges the value in each lane with the lane 16 positions apart + within a 32-lane group. Lane i (i < 16) swaps with lane i + 16. + Returns a struct {swapped, original} as two separate VGPR results. + }]; + let arguments = (ins WaveASM_AnyVGPR:$src0, WaveASM_AnyVGPR:$src1); + let results = (outs WaveASM_AnyVGPR:$dst0, WaveASM_AnyVGPR:$dst1); + let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst0) `,` type($dst1)"; +} + // Lane operations (VGPR -> SGPR) def WaveASM_V_READFIRSTLANE_B32 : WAVEASMOp<"v_readfirstlane_b32", [Pure]> { let summary = "Read first lane of VGPR to SGPR"; @@ -675,8 +688,12 @@ def WaveASM_V_WRITELANE_B32 : WAVEASMOp<"v_writelane_b32", [Pure]> { } // VGPR/AGPR cross-bank moves -def WaveASM_V_ACCVGPR_READ_B32 : WAVEASMOp<"v_accvgpr_read_b32", [Pure]> { +def WaveASM_V_ACCVGPR_READ_B32 : WAVEASMOp<"v_accvgpr_read_b32"> { let summary = "Read AGPR lane into VGPR"; + // NOT Pure: duplicate reads must not be CSE-folded. + // v_permlane16_swap_b32 modifies its src register in-place; if CSE + // merges two reads and the swap clobbers the shared result, later + // uses see the wrong value. 
let arguments = (ins WaveASM_AnyAGPR:$src); let results = (outs WaveASM_AnyVGPR:$dst); let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; diff --git a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h index 2fc46f39cd..8332977074 100644 --- a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h +++ b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h @@ -52,8 +52,22 @@ class ValueMapper { return valueMap.contains(mlirValue); } + /// Map a second result for multi-result ops (e.g. permlane16_swap) + void mapSecondResult(mlir::Value inputValue, mlir::Value asmValue) { + secondResultMap[inputValue] = asmValue; + } + + /// Get the second result for a multi-result op + std::optional getSecondResult(mlir::Value mlirValue) const { + auto it = secondResultMap.find(mlirValue); + if (it != secondResultMap.end()) + return it->second; + return std::nullopt; + } + private: llvm::DenseMap valueMap; + llvm::DenseMap secondResultMap; }; //===----------------------------------------------------------------------===// diff --git a/waveasm/lib/Transforms/AssemblyEmitter.cpp b/waveasm/lib/Transforms/AssemblyEmitter.cpp index f6ce8afb8b..3d2c5b6ae0 100644 --- a/waveasm/lib/Transforms/AssemblyEmitter.cpp +++ b/waveasm/lib/Transforms/AssemblyEmitter.cpp @@ -970,6 +970,22 @@ std::optional KernelGenerator::generateOp(Operation *op) { return formatter.format("v_cvt_pk_bf16_f32", operands); }) + // V_PERMLANE16_SWAP_B32: swap data between lanes 16 apart + // GCN ISA: v_permlane16_swap_b32 vDST, vSRC (VOP1, 2 operands) + // Hardware hazard: VALU writes to vSRC need 2 wait states before + // permlane can read it (VALUWritesVDstWaitStates=2). Insert s_nop 1 + // unconditionally; the preceding VALU provides 1 wait state, the + // nop provides the second. 
+ .Case( + [&](V_PERMLANE16_SWAP_B32 permOp) -> std::optional { + llvm::SmallVector operands; + operands.push_back(resolveValue(permOp.getDst0())); + operands.push_back(resolveValue(permOp.getSrc0())); + std::string swapStr = + formatter.format("v_permlane16_swap_b32", operands); + return std::string("s_nop 1\n ") + swapStr; + }) + // Carry ops: on GFX9, carry-out is implicit VCC. // v_add_co_u32: dst, vcc, src0, src1 // v_addc_co_u32: dst, vcc, src0, src1, vcc (carry-in). diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp index 2012a96d3e..e87dffe31a 100644 --- a/waveasm/lib/Transforms/LinearScanPass.cpp +++ b/waveasm/lib/Transforms/LinearScanPass.cpp @@ -312,6 +312,13 @@ struct LinearScanPass // Result should be allocated to same physical register as accumulator tiedPairs[op->getResult(0)] = acc; } + } else if (isa(op)) { + // v_permlane16_swap_b32 modifies vSRC in-place. The second result + // captures the new vSRC value and must share the same physical + // register so the register allocator keeps the register live. 
+ if (op->getNumResults() >= 2 && op->getNumOperands() >= 1) { + tiedPairs[op->getResult(1)] = op->getOperand(0); + } } }); diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp index 429c6cf3b6..7fae1b0c70 100644 --- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp +++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp @@ -195,8 +195,8 @@ allocateRegClass(ArrayRef ranges, RegPool &pool, int64_t maxEnd = ranges.back().end; int64_t threshold = (maxEnd * 3) / 4; if (rangeLength > threshold) - physReg = tryAllocateFromTop(pool, range.size, range.alignment, - maxPressure); + physReg = + tryAllocateFromTop(pool, range.size, range.alignment, maxPressure); else physReg = tryAllocate(pool, range.size, range.alignment); } diff --git a/waveasm/lib/Transforms/SCCVerifier.cpp b/waveasm/lib/Transforms/SCCVerifier.cpp index b65fc45a42..f1c83b8e0a 100644 --- a/waveasm/lib/Transforms/SCCVerifier.cpp +++ b/waveasm/lib/Transforms/SCCVerifier.cpp @@ -38,9 +38,7 @@ namespace { /// Returns true if the operation writes the SCC flag on hardware. /// Uses hasTrait for ops that carry the trait (carry ops, cmp ops), /// and isa<> for ops still on SALUBinaryOp/SALUUnaryOp (pending migration). 
-static bool writesSCC(Operation *op) { - return op->hasTrait(); -} +static bool writesSCC(Operation *op) { return op->hasTrait(); } struct SCCVerifierPass : public waveasm::impl::WAVEASMSCCVerifierBase { @@ -54,22 +52,26 @@ struct SCCVerifierPass errorCount += verifyBlock(block); }); if (errorCount > 0) { - LLVM_DEBUG(llvm::dbgs() << "SCC verifier: found " << errorCount - << " SCC hazard(s)\n"); + LLVM_DEBUG(llvm::dbgs() + << "SCC verifier: found " << errorCount << " SCC hazard(s)\n"); signalPassFailure(); } } private: static SmallVector findSCCClobbersBetween(Operation *producer, - Operation *consumer) { + Operation *consumer) { SmallVector clobbers; if (!producer || !consumer || producer->getBlock() != consumer->getBlock()) return clobbers; bool inRange = false; for (Operation &op : *producer->getBlock()) { - if (&op == producer) { inRange = true; continue; } - if (&op == consumer) break; + if (&op == producer) { + inRange = true; + continue; + } + if (&op == consumer) + break; if (inRange && writesSCC(&op)) clobbers.push_back(&op); } @@ -77,7 +79,7 @@ struct SCCVerifierPass } static void emitSCCClobberError(Operation *consumer, Operation *producer, - ArrayRef clobbers) { + ArrayRef clobbers) { auto diag = consumer->emitError() << "SCC hazard: " << clobbers.size() << " SCC-clobbering op(s) between SCC producer '" @@ -121,8 +123,7 @@ struct SCCVerifierPass ++errors; } if (isa(&op) && !lastSCCWriter) { - op.emitError() - << "SCC hazard: s_addc_u32 has no preceding SCC writer"; + op.emitError() << "SCC hazard: s_addc_u32 has no preceding SCC writer"; ++errors; } if (writesSCC(&op)) diff --git a/waveasm/lib/Transforms/ScopedCSE.cpp b/waveasm/lib/Transforms/ScopedCSE.cpp index 0db4f14860..b3120fcf73 100644 --- a/waveasm/lib/Transforms/ScopedCSE.cpp +++ b/waveasm/lib/Transforms/ScopedCSE.cpp @@ -156,7 +156,8 @@ bool isCSEEligible(Operation *op) { return false; // SCC-reading ops are NOT CSE-eligible: their result depends on implicit - // SCC state, so two ops with 
identical operands can produce different results. + // SCC state, so two ops with identical operands can produce different + // results. if (op->hasTrait()) return false; diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 0d98072715..8b8e7f9bc4 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -776,6 +777,7 @@ LogicalResult handleArithShRSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithExtUI(Operation *op, TranslationContext &ctx); LogicalResult handleArithExtSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithTruncI(Operation *op, TranslationContext &ctx); +LogicalResult handleArithBitcast(Operation *op, TranslationContext &ctx); LogicalResult handleArithMinSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithMaxSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithMinUI(Operation *op, TranslationContext &ctx); @@ -1023,8 +1025,7 @@ LogicalResult handleVectorTransferWrite(Operation *op, if (isAGPRType(storeData.getType())) { auto vregType = ctx.createVRegType(numDwords, numDwords > 1 ? numDwords : 1); - storeData = - V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); + storeData = V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); } if (numDwords == 1) { @@ -1303,6 +1304,15 @@ convertF32ToBF16ForStore(Value srcData, int64_t numElems, Location loc) { Type srcType = srcData.getType(); + // If the data register width already matches the packed bf16 size + // (numElems/2 dwords), the data is already in packed bf16 format + // (e.g. 
from the permlane16_swap pack path). Skip conversion. + int64_t srcDwords = getRegSize(srcType); + int64_t expectedPackedDwords = (numElems + 1) / 2; + if (srcDwords == expectedPackedDwords && !isAGPRType(srcType)) { + return {srcData, expectedPackedDwords * 4}; + } + // VALU conversion instructions cannot read from AGPR. Move to VGPR first. if (isAGPRType(srcType)) { int64_t size = getRegSize(srcType); @@ -1311,7 +1321,20 @@ convertF32ToBF16ForStore(Value srcData, int64_t numElems, srcType = srcData.getType(); } + // Look through PackOp to get the original operands directly. + // The register allocator does not insert copies for PackOp, so + // ExtractOp from a PackOp whose inputs are at different physical + // registers reads stale data (same issue documented in + // ArithLegalization.cpp::splitI64). + SmallVector packOperands; + if (auto packOp = srcData.getDefiningOp()) { + for (auto operand : packOp.getOperands()) + packOperands.push_back(operand); + } + auto extractF32Elem = [&](int64_t i) -> Value { + if (!packOperands.empty() && i < (int64_t)packOperands.size()) + return packOperands[i]; if (auto pvreg = dyn_cast(srcType)) { int64_t baseIdx = pvreg.getIndex() + i; auto elemType = PVRegType::get(builder.getContext(), baseIdx, 1); @@ -1511,14 +1534,27 @@ LogicalResult handleVectorStore(Operation *op, TranslationContext &ctx) { } else { // Global store - buffer_store_dwordx* with splitting for large vectors + Type elementType = memrefType.getElementType(); + int64_t elementBytes = (elementType.getIntOrFloatBitWidth() + 7) / 8; + + // BF16 store conversion MUST happen before address computation so that + // the PackOp inputs (from vector.from_elements / arith.select) are + // consumed immediately and don't need to survive across the address + // VALU instructions which can clobber their registers. 
+ if (elementType.isBF16() && data.has_value()) { + int64_t numElems = vectorType.getNumElements(); + auto [converted, newNumBytes] = + convertF32ToBF16ForStore(*data, numElems, ctx, builder, loc); + data = converted; + numBytes = newNumBytes; + } + // Compute voffset as byte offset from indices and strides // For 2D memrefs: offset = idx0 * stride0 * elementBytes + idx1 * stride1 * // elementBytes Value voffset; int64_t instOffset = 0; // Constant offset for buffer_store offset:N syntax auto indices = storeOp.getIndices(); - Type elementType = memrefType.getElementType(); - int64_t elementBytes = (elementType.getIntOrFloatBitWidth() + 7) / 8; // Get strides from the memref type SmallVector strides; @@ -1639,16 +1675,6 @@ LogicalResult handleVectorStore(Operation *op, TranslationContext &ctx) { // Check if the source value has split results from a corresponding load auto splitResults = ctx.getSplitResults(storeOp.getValueToStore()); - // BF16 store conversion: the arith.truncf handler defers vector f32->bf16 - // conversion, so data registers may still contain f32 values. 
- if (elementType.isBF16() && data.has_value()) { - int64_t numElems = vectorType.getNumElements(); - auto [converted, newNumBytes] = - convertF32ToBF16ForStore(*data, numElems, ctx, builder, loc); - data = converted; - numBytes = newNumBytes; - } - // Split large stores into multiple buffer_store_dwordx4 (16 bytes each) // Use the same voffset for all stores, with instOffset for subsequent // chunks Add any constant offset from affine expressions to the base offset @@ -1870,9 +1896,12 @@ LogicalResult handleRawBufferLoad(Operation *op, TranslationContext &ctx); LogicalResult handleRawBufferStore(Operation *op, TranslationContext &ctx); LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx); LogicalResult handleReadFirstLane(Operation *op, TranslationContext &ctx); +LogicalResult handlePermlane16Swap(Operation *op, TranslationContext &ctx); LogicalResult handleROCDLSBarrier(Operation *op, TranslationContext &ctx); LogicalResult handleROCDLSetPrio(Operation *op, TranslationContext &ctx); LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx); +LogicalResult handleLLVMExtractValue(Operation *op, TranslationContext &ctx); +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx); //===----------------------------------------------------------------------===// // OpHandlerRegistry Implementation @@ -1978,6 +2007,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(arith::ExtUIOp, handleArithExtUI); REGISTER_HANDLER(arith::ExtSIOp, handleArithExtSI); REGISTER_HANDLER(arith::TruncIOp, handleArithTruncI); + REGISTER_HANDLER(arith::BitcastOp, handleArithBitcast); REGISTER_HANDLER(arith::MinSIOp, handleArithMinSI); REGISTER_HANDLER(arith::MaxSIOp, handleArithMaxSI); REGISTER_HANDLER(arith::MinUIOp, handleArithMinUI); @@ -2019,6 +2049,10 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(vector::TransferWriteOp, 
handleVectorTransferWrite); REGISTER_HANDLER(vector::FMAOp, handleVectorFma); REGISTER_HANDLER(vector::ReductionOp, handleVectorReduction); + REGISTER_HANDLER(vector::FromElementsOp, handleVectorFromElements); + + // LLVM dialect + REGISTER_HANDLER(LLVM::ExtractValueOp, handleLLVMExtractValue); // AMDGPU dialect REGISTER_HANDLER(amdgpu::LDSBarrierOp, handleAMDGPULdsBarrier); @@ -2033,6 +2067,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { // ROCDL dialect REGISTER_HANDLER(ROCDL::ReadfirstlaneOp, handleReadFirstLane); + REGISTER_HANDLER(ROCDL::Permlane16SwapOp, handlePermlane16Swap); REGISTER_HANDLER(ROCDL::SBarrierOp, handleROCDLSBarrier); REGISTER_HANDLER(ROCDL::SetPrioOp, handleROCDLSetPrio); REGISTER_HANDLER(ROCDL::SWaitcntOp, handleSWaitcnt); diff --git a/waveasm/lib/Transforms/VGPRCompaction.cpp b/waveasm/lib/Transforms/VGPRCompaction.cpp index 43f77ca1b2..251a7a8c6c 100644 --- a/waveasm/lib/Transforms/VGPRCompaction.cpp +++ b/waveasm/lib/Transforms/VGPRCompaction.cpp @@ -52,8 +52,7 @@ struct PhysVGPRRange { int64_t length() const { return lastUsePoint - defPoint; } }; -static void collectOps(Block &block, - llvm::SmallVectorImpl &ops) { +static void collectOps(Block &block, llvm::SmallVectorImpl &ops) { for (Operation &op : block) { ops.push_back(&op); for (Region &region : op.getRegions()) @@ -65,8 +64,7 @@ static void collectOps(Block &block, /// Record a PVRegType occurrence. Merges into existing entry if the base /// index matches, taking max size/alignment and extending the time range.
static void recordPVReg(int64_t baseIdx, int64_t size, int64_t point, - bool isDef, - llvm::DenseMap &defPoints, + bool isDef, llvm::DenseMap &defPoints, llvm::DenseMap &usePoints, llvm::DenseMap &sizes, llvm::DenseMap &alignments) { @@ -122,10 +120,9 @@ findAllocBase(int64_t idx, int64_t size, return {idx, size}; } -static void buildPhysRanges( - ProgramOp program, - llvm::SmallVectorImpl &ranges, - llvm::DenseMap &rangeBaseToIdx) { +static void buildPhysRanges(ProgramOp program, + llvm::SmallVectorImpl &ranges, + llvm::DenseMap &rangeBaseToIdx) { llvm::SmallVector ops; collectOps(program.getBodyBlock(), ops); @@ -251,8 +248,8 @@ static void buildPhysRanges( body.walk([&](Operation *innerOp) { for (Value operand : innerOp->getOperands()) { if (auto pvreg = dyn_cast(operand.getType())) { - auto [base, allocSize] = findAllocBase(pvreg.getIndex(), - pvreg.getSize(), knownSizes); + auto [base, allocSize] = + findAllocBase(pvreg.getIndex(), pvreg.getSize(), knownSizes); auto defIt = defPoints.find(base); if (defIt != defPoints.end() && defIt->second < loopStart) { // Value defined before the loop but used inside — extend to @@ -411,8 +408,8 @@ static void applyRemapping(ProgramOp program, } if (auto condOp = dyn_cast(op)) { - if (auto attr = condOp->getAttrOfType( - "_iterArgPhysRegs")) { + if (auto attr = + condOp->getAttrOfType("_iterArgPhysRegs")) { auto vals = attr.asArrayRef(); llvm::SmallVector newVals(vals.begin(), vals.end()); bool anyChanged = false; @@ -422,8 +419,7 @@ static void applyRemapping(ProgramOp program, if (i < condOp.getIterArgs().size()) { Type ty = condOp.getIterArgs()[i].getType(); if (isa(ty)) { - int64_t newIdx = - remapIndex(newVals[i], 1, oldToNew, ranges); + int64_t newIdx = remapIndex(newVals[i], 1, oldToNew, ranges); if (newIdx != newVals[i]) { newVals[i] = newIdx; anyChanged = true; @@ -493,9 +489,9 @@ struct WAVEASMVGPRCompaction if (!anyChange) return; - LLVM_DEBUG(llvm::dbgs() << "VGPR compaction: " << maxBefore << " -> " - << 
maxAfter << " (saved " - << (maxBefore - maxAfter) << ")\n"); + LLVM_DEBUG(llvm::dbgs() + << "VGPR compaction: " << maxBefore << " -> " << maxAfter + << " (saved " << (maxBefore - maxAfter) << ")\n"); applyRemapping(program, oldToNew, ranges); }); diff --git a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp index 5ccc8da9eb..e48a9796dd 100644 --- a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp @@ -15,8 +15,10 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/Support/Debug.h" @@ -386,8 +388,8 @@ static void emitSrdNumRecords(OpBuilder &builder, Location loc, int64_t srdBase, if (cmpLhsMapped && cmpRhsMapped && trueMapped && falseMapped && isScalarOrImm(*cmpLhsMapped) && isScalarOrImm(*cmpRhsMapped) && isScalarOrImm(*trueMapped) && isScalarOrImm(*falseMapped)) { - emitScalarCmp(builder, loc, cmpOp.getPredicate(), - *cmpLhsMapped, *cmpRhsMapped, ctx); + emitScalarCmp(builder, loc, cmpOp.getPredicate(), *cmpLhsMapped, + *cmpRhsMapped, ctx); auto dstType = PSRegType::get(builder.getContext(), srdBase + 2, 1); Value trueV = *trueMapped; @@ -1250,6 +1252,89 @@ LogicalResult handleReadFirstLane(Operation *op, TranslationContext &ctx) { return success(); } +LogicalResult handlePermlane16Swap(Operation *op, TranslationContext &ctx) { + auto swapOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + auto vregType = ctx.createVRegType(); + + auto src0 = ctx.getMapper().getMapped(swapOp.getOperand(0)); + auto src1 = ctx.getMapper().getMapped(swapOp.getOperand(1)); + if (!src0 || !src1) + return op->emitError("permlane16_swap operand 
not mapped"); + + if (isAGPRType(src0->getType())) + src0 = V_ACCVGPR_READ_B32::create(builder, loc, vregType, *src0); + if (isAGPRType(src1->getType())) + src1 = V_ACCVGPR_READ_B32::create(builder, loc, vregType, *src1); + + auto permOp = V_PERMLANE16_SWAP_B32::create(builder, loc, vregType, vregType, + *src0, *src1); + + // rocdl.permlane16_swap returns !llvm.struct<(i32, i32)>. + // Map result 0 to the struct SSA value (used by llvm.extractvalue 0) + // and result 1 to the second output (extractvalue 1). + ctx.getMapper().mapValue(op->getResult(0), permOp.getResult(0)); + ctx.getMapper().mapSecondResult(op->getResult(0), permOp.getResult(1)); + + return success(); +} + +LogicalResult handleLLVMExtractValue(Operation *op, TranslationContext &ctx) { + auto extractOp = cast(op); + auto position = extractOp.getPosition(); + if (position.size() != 1) + return op->emitError("only single-level extractvalue supported"); + + int64_t idx = position[0]; + Value container = extractOp.getContainer(); + + if (idx == 0) { + auto mapped = ctx.getMapper().getMapped(container); + if (!mapped) + return op->emitError("extractvalue: container not mapped (index 0)"); + ctx.getMapper().mapValue(op->getResult(0), *mapped); + } else if (idx == 1) { + auto mapped = ctx.getMapper().getSecondResult(container); + if (!mapped) + return op->emitError("extractvalue: no second result (index 1)"); + ctx.getMapper().mapValue(op->getResult(0), *mapped); + } else { + return op->emitError("extractvalue index > 1 not supported"); + } + + return success(); +} + +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx) { + auto fromElOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + auto resultType = fromElOp.getResult().getType(); + auto vecType = cast(resultType); + int64_t numElems = vecType.getNumElements(); + int64_t elemBits = vecType.getElementType().getIntOrFloatBitWidth(); + int64_t totalBytes = numElems * ((elemBits + 7) / 8); + int64_t 
numDwords = (totalBytes + 3) / 4; + if (numDwords < 1) + numDwords = 1; + + auto vregType = ctx.createVRegType(numDwords, numDwords > 1 ? numDwords : 1); + + SmallVector mappedSrcs; + for (auto operand : fromElOp.getElements()) { + auto mapped = ctx.getMapper().getMapped(operand); + if (!mapped) + return op->emitError("from_elements operand not mapped"); + mappedSrcs.push_back(*mapped); + } + + auto packed = PackOp::create(builder, loc, vregType, mappedSrcs); + ctx.getMapper().mapValue(op->getResult(0), packed); + + return success(); +} + LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx) { auto atomicOp = cast(op); auto &builder = ctx.getBuilder(); diff --git a/waveasm/lib/Transforms/handlers/AffineHandlers.cpp b/waveasm/lib/Transforms/handlers/AffineHandlers.cpp index 1a56cc0aea..f4a22a5e9d 100644 --- a/waveasm/lib/Transforms/handlers/AffineHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AffineHandlers.cpp @@ -606,12 +606,11 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftImm, shiftAmount); // v_lshl_or_b32: dst = (src << shift) | orend - Value base = ensureVGPR(builder, loc, ctx, - baseResult.value); + Value base = + ensureVGPR(builder, loc, ctx, baseResult.value); orend = ensureVGPR(builder, loc, ctx, orend); Value fusedResult = V_LSHL_OR_B32::create( - builder, loc, vregType, base, shiftConst, - orend); + builder, loc, vregType, base, shiftConst, orend); BitRange shiftedRange = baseResult.range.shiftLeft(shiftAmount); BitRange resultRange = shiftedRange.merge(orendRange); @@ -629,12 +628,11 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftImm = ctx.createImmType(shiftAmount); auto shiftConst = ConstantOp::create(builder, loc, shiftImm, shiftAmount); - Value base2 = ensureVGPR(builder, loc, ctx, - baseResult.value); + Value base2 = + ensureVGPR(builder, loc, ctx, baseResult.value); orend = ensureVGPR(builder, 
loc, ctx, orend); Value fusedResult = V_LSHL_OR_B32::create( - builder, loc, vregType, base2, shiftConst, - orend); + builder, loc, vregType, base2, shiftConst, orend); BitRange shiftedRange = baseResult.range.shiftLeft(shiftAmount); BitRange resultRange = shiftedRange.merge(orendRange); @@ -728,8 +726,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftAmt = ctx.createImmType(shiftAmount); auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); - Value shiftResult = - emitLshl(lhs, shiftConst, builder, loc, ctx); + Value shiftResult = emitLshl(lhs, shiftConst, builder, loc, ctx); BitRange resultRange = lhsRange.shiftLeft(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); @@ -742,8 +739,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftAmt = ctx.createImmType(shiftAmount); auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); - Value shiftResult = - emitLshl(rhs, shiftConst, builder, loc, ctx); + Value shiftResult = emitLshl(rhs, shiftConst, builder, loc, ctx); BitRange resultRange = rhsRange.shiftLeft(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); @@ -776,8 +772,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftAmt = ctx.createImmType(shiftAmount); auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); - Value shiftResult = - emitLshr(lhs, shiftConst, builder, loc, ctx); + Value shiftResult = emitLshr(lhs, shiftConst, builder, loc, ctx); BitRange resultRange = lhsRange.shiftRight(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); @@ -815,10 +810,9 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), rem, zeroConst); auto sregType = ctx.createSRegType(); - 
Value result = - S_ADDC_U32::create(builder, loc, sregType, sregType, q, - zeroConst) - .getDst(); + Value result = S_ADDC_U32::create(builder, loc, sregType, + sregType, q, zeroConst) + .getDst(); return ExprResult(result, BitRange()); } V_CMP_NE_U32::create(builder, loc, rem, zeroConst); diff --git a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp index 27a016d1ae..3dddf8fbdd 100644 --- a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp @@ -437,6 +437,21 @@ LogicalResult handleArithTruncI(Operation *op, TranslationContext &ctx) { return success(); } +LogicalResult handleArithBitcast(Operation *op, TranslationContext &ctx) { + auto bitcastOp = cast(op); + auto src = ctx.getMapper().getMapped(bitcastOp.getIn()); + if (src) { + // Emit ensureVGPR to coerce AGPR/SGPR sources to VGPR. + // V_ACCVGPR_READ_B32 is no longer [Pure], so each call produces a + // distinct read that the register allocator keeps live independently. + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + Value vgpr = ensureVGPR(builder, loc, ctx, *src); + ctx.getMapper().mapValue(bitcastOp.getResult(), vgpr); + } + return success(); +} + //===----------------------------------------------------------------------===// // Comparison and Select Operations //===----------------------------------------------------------------------===// @@ -458,7 +473,8 @@ LogicalResult handleArithCmpI(Operation *op, TranslationContext &ctx) { // handleArithSelect which emits s_cmp + s_cselect as a pair. 
bool lhsScalar = isSGPRType(lhs->getType()) || isImmType(lhs->getType()); bool rhsScalar = isSGPRType(rhs->getType()) || isImmType(rhs->getType()); - bool usedByCondition = cmpOp.getResult().hasOneUse() && + bool usedByCondition = + cmpOp.getResult().hasOneUse() && isa(*cmpOp.getResult().getUsers().begin()); if (lhsScalar && rhsScalar && usedByCondition) { @@ -486,8 +502,8 @@ LogicalResult handleArithCmpI(Operation *op, TranslationContext &ctx) { if (auto selectUser = dyn_cast(user)) { auto trueMap = ctx.getMapper().getMapped(selectUser.getTrueValue()); auto falseMap = ctx.getMapper().getMapped(selectUser.getFalseValue()); - if (!trueMap || !falseMap || - !isScalarOrImm(*trueMap) || !isScalarOrImm(*falseMap)) { + if (!trueMap || !falseMap || !isScalarOrImm(*trueMap) || + !isScalarOrImm(*falseMap)) { allUsersFused = false; break; } @@ -570,8 +586,8 @@ LogicalResult handleArithSelect(Operation *op, TranslationContext &ctx) { if (auto cmpOp = condMLIR.getDefiningOp()) { auto cmpLhs = ctx.getMapper().getMapped(cmpOp.getLhs()); auto cmpRhs = ctx.getMapper().getMapped(cmpOp.getRhs()); - if (cmpLhs && cmpRhs && - isScalarOrImm(*cmpLhs) && isScalarOrImm(*cmpRhs)) { + if (cmpLhs && cmpRhs && isScalarOrImm(*cmpLhs) && + isScalarOrImm(*cmpRhs)) { auto sregType = ctx.createSRegType(); Value lhsOp = *cmpLhs; Value rhsOp = *cmpRhs; diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index bd6b8862fa..1a0aca3980 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -120,6 +120,8 @@ mlir::LogicalResult handleArithCmpF(mlir::Operation *op, TranslationContext &ctx); mlir::LogicalResult handleArithTruncF(mlir::Operation *op, TranslationContext &ctx); +mlir::LogicalResult handleArithBitcast(mlir::Operation *op, + TranslationContext &ctx); mlir::LogicalResult handleArithExtF(mlir::Operation *op, TranslationContext &ctx); @@ -184,9 +186,25 @@ mlir::LogicalResult 
handleMemRefAtomicRMW(mlir::Operation *op, mlir::LogicalResult handleReadFirstLane(mlir::Operation *op, TranslationContext &ctx); +mlir::LogicalResult handlePermlane16Swap(mlir::Operation *op, + TranslationContext &ctx); mlir::LogicalResult handleSWaitcnt(mlir::Operation *op, TranslationContext &ctx); +//===----------------------------------------------------------------------===// +// LLVM Dialect Handlers +//===----------------------------------------------------------------------===// + +mlir::LogicalResult handleLLVMExtractValue(mlir::Operation *op, + TranslationContext &ctx); + +//===----------------------------------------------------------------------===// +// Vector Dialect Handlers (additional) +//===----------------------------------------------------------------------===// + +mlir::LogicalResult handleVectorFromElements(mlir::Operation *op, + TranslationContext &ctx); + //===----------------------------------------------------------------------===// // IREE/Stream Dialect Handlers //===----------------------------------------------------------------------===// @@ -290,9 +308,14 @@ inline bool isScalarOrImm(mlir::Value v) { /// If \p v is an SGPR, emit a v_mov_b32 to coerce it into a VGPR so it can /// be used by VALU-only instructions (v_cvt_*, v_rcp_*, v_mul_f32, etc.). +/// If \p v is an AGPR (accumulator), emit v_accvgpr_read_b32 first. /// Returns \p v unchanged when it is already a VGPR or immediate. inline mlir::Value ensureVGPR(mlir::OpBuilder &builder, mlir::Location loc, TranslationContext &ctx, mlir::Value v) { + if (isAGPRType(v.getType())) { + auto vregType = ctx.createVRegType(); + return V_ACCVGPR_READ_B32::create(builder, loc, vregType, v); + } if (isSGPRType(v.getType())) { auto vregType = ctx.createVRegType(); return V_MOV_B32::create(builder, loc, vregType, v); @@ -306,8 +329,9 @@ inline mlir::Value ensureVGPR(mlir::OpBuilder &builder, mlir::Location loc, /// Emit add: S_ADD_U32 when both operands are scalar, V_ADD_U32 otherwise. 
/// Commutative: swaps to put immediate in src1 (SALU src0 must be SGPR). -inline mlir::Value emitAdd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, - mlir::Location loc, TranslationContext &ctx) { +inline mlir::Value emitAdd(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { if (isScalarOrImm(a) && isScalarOrImm(b) && !(isImmType(a.getType()) && isImmType(b.getType()))) { if (isImmType(a.getType())) @@ -321,8 +345,9 @@ inline mlir::Value emitAdd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builde /// Emit sub: S_SUB_U32 when both operands are scalar, V_SUB_U32 otherwise. /// Not commutative: src0 (minuend) must be SGPR. -inline mlir::Value emitSub(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, - mlir::Location loc, TranslationContext &ctx) { +inline mlir::Value emitSub(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { if (isScalarOrImm(a) && isScalarOrImm(b) && isSGPRType(a.getType())) { auto sregType = ctx.createSRegType(); return S_SUB_U32::create(builder, loc, sregType, sregType, a, b).getDst(); @@ -333,8 +358,9 @@ inline mlir::Value emitSub(mlir::Value a, mlir::Value b, mlir::OpBuilder &builde /// Emit mul: S_MUL_I32 when both operands are scalar, V_MUL_LO_U32 otherwise. /// Commutative: swaps to put immediate in src1. -inline mlir::Value emitMul(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, - mlir::Location loc, TranslationContext &ctx) { +inline mlir::Value emitMul(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { if (isScalarOrImm(a) && isScalarOrImm(b) && !(isImmType(a.getType()) && isImmType(b.getType()))) { if (isImmType(a.getType())) @@ -378,8 +404,9 @@ inline mlir::Value emitLshl(mlir::Value value, mlir::Value shiftAmt, /// Emit bitwise AND: S_AND_B32 when both scalar, V_AND_B32 otherwise. /// Commutative: swaps to put immediate in src1. 
-inline mlir::Value emitAnd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, - mlir::Location loc, TranslationContext &ctx) { +inline mlir::Value emitAnd(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { if (isScalarOrImm(a) && isScalarOrImm(b) && !(isImmType(a.getType()) && isImmType(b.getType()))) { if (isImmType(a.getType())) @@ -393,8 +420,9 @@ inline mlir::Value emitAnd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builde /// Emit bitwise OR: S_OR_B32 when both scalar, V_OR_B32 otherwise. /// Commutative: swaps to put immediate in src1. -inline mlir::Value emitOr(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, - mlir::Location loc, TranslationContext &ctx) { +inline mlir::Value emitOr(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { if (isScalarOrImm(a) && isScalarOrImm(b) && !(isImmType(a.getType()) && isImmType(b.getType()))) { if (isImmType(a.getType())) @@ -408,8 +436,9 @@ inline mlir::Value emitOr(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder /// Emit bitwise XOR: S_XOR_B32 when both scalar, V_XOR_B32 otherwise. /// Commutative: swaps to put immediate in src1. 
-inline mlir::Value emitXor(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, - mlir::Location loc, TranslationContext &ctx) { +inline mlir::Value emitXor(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { if (isScalarOrImm(a) && isScalarOrImm(b) && !(isImmType(a.getType()) && isImmType(b.getType()))) { if (isImmType(a.getType())) diff --git a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp index 70cb50b248..d6bb4426c5 100644 --- a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp @@ -304,8 +304,7 @@ LogicalResult handleMemRefStore(Operation *op, TranslationContext &ctx) { Value storeData = *data; if (isAGPRType(storeData.getType())) { auto vregType = ctx.createVRegType(); - storeData = - V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); + storeData = V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); } BUFFER_STORE_DWORD::create(builder, loc, storeData, srd, voffset); From 67b07a70aeea9f38cec49442c44ee1ac9310c935 Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 27 Mar 2026 02:07:28 +0000 Subject: [PATCH 4/5] dwrodx4 + cvt-first working Signed-off-by: xintin --- examples/python/7.1_schedule.py | 1048 +++++++++++++++++ .../compiler/wave_codegen/read_write.py | 6 +- .../kernel/wave/coalesce_epilogue_stores.py | 50 + wave_lang/kernel/wave/compile.py | 15 + waveasm/lib/Transforms/LinearScanPass.cpp | 30 + 5 files changed, 1146 insertions(+), 3 deletions(-) create mode 100644 wave_lang/kernel/wave/coalesce_epilogue_stores.py diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index aff372d35f..5d618a9d63 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -43,6 +43,728 @@ SHARED_ADDRESS_SPACE, ) from utils import parse_args, list_tests, run_test +import re + + +def coalesce_buffer_stores_dwordx4(asm_text): + 
"""Post-process assembly to merge 4 consecutive buffer_store_dword into buffer_store_dwordx4. + + Detects groups of 4 buffer_store_dword with same SRD+voffset and + offsets {X, X+4, X+8, X+12}. Replaces each store with a v_mov_b32 + copy to v0-v3, then emits one buffer_store_dwordx4 v[0:3]. + """ + store_re = re.compile( + r"^ buffer_store_dword (v\d+), (v\d+), (s\[\d+:\d+\]), 0 offen(?: offset:(\d+))?$" + ) + + lines = asm_text.split("\n") + out = [] + i = 0 + coalesced_count = 0 + while i < len(lines): + m0 = store_re.match(lines[i]) + if m0: + group = [(i, m0)] + j = i + 1 + while j < len(lines) and len(group) < 4: + mj = store_re.match(lines[j]) + if mj: + group.append((j, mj)) + j += 1 + + if len(group) == 4: + srds = [g.group(3) for _, g in group] + voffs = [g.group(2) for _, g in group] + offsets = [int(g.group(4)) if g.group(4) else 0 for _, g in group] + base = offsets[0] + if ( + all(s == srds[0] for s in srds) + and all(v == voffs[0] for v in voffs) + and offsets == [base, base + 4, base + 8, base + 12] + ): + store_line_indices = {idx for idx, _ in group} + slot = 0 + for k in range(i, group[-1][0] + 1): + if k in store_line_indices: + data_reg = group[slot][1].group(1) + out.append(f" v_mov_b32 v{slot}, {data_reg}") + slot += 1 + else: + out.append(lines[k]) + merged = ( + f" buffer_store_dwordx4 v[0:3], {voffs[0]}, {srds[0]}, 0 offen" + ) + if base > 0: + merged += f" offset:{base}" + out.append(merged) + coalesced_count += 1 + i = group[-1][0] + 1 + continue + + out.append(lines[i]) + i += 1 + + print( + f"[asm_transform] Coalesced {coalesced_count} groups of 4 stores -> buffer_store_dwordx4" + ) + return "\n".join(out) + + +def convert_first_eliminate_cndmask(asm_text): + """Replace per-tile swap+cndmask+cvt+store with convert-first+swap+dwordx4. + + The current epilogue swaps individual f32 values, then uses v_cndmask_b32 + to select own vs partner data (8 cndmask per tile = 384 total). 
This + transform converts to packed bf16 FIRST, swaps the packed dwords (1 swap + instead of 4), re-reads AGPRs for own data, and stores via dwordx4. + Eliminates all data-select cndmask and halves the swap count. + """ + read_re = re.compile(r"^\s*v_accvgpr_read_b32\s+(v\d+),\s+a(\d+)") + nop_re = re.compile(r"^s_nop 1\s*$") + swap_re = re.compile(r"^\s*v_permlane16_swap_b32\s+") + store_off12_re = re.compile( + r"^\s*buffer_store_dword\s+v\d+,\s+(v\d+),\s+(s\[\d+:\d+\]),\s+0 offen\s+offset:12\s*$" + ) + sub4_re = re.compile(r"^\s*v_sub_u32\s+v\d+,\s+v\d+,\s+4\s*$") + cmp_ne_v244_re = re.compile(r"^\s*v_cmp_ne_u32\s+vcc,\s+v244,\s+0") + cndmask_re = re.compile(r"^\s*v_cndmask_b32\s+(v\d+),") + lshlrev1_re = re.compile(r"^\s*v_lshlrev_b32\s+v\d+,\s+1,") + + lines = asm_text.split("\n") + out = [] + tile_count = 0 + i = 0 + + while i < len(lines): + if i + 11 < len(lines) and read_re.match(lines[i]): + agprs = [] + ok = True + for k in range(4): + base = i + k * 3 + m = read_re.match(lines[base]) + if not ( + m + and nop_re.match(lines[base + 1]) + and swap_re.match(lines[base + 2]) + ): + ok = False + break + agprs.append(int(m.group(2))) + if not ok: + out.append(lines[i]) + i += 1 + continue + + swap_end = i + 12 + + j = swap_end + srd = None + while j < len(lines): + sm = store_off12_re.match(lines[j]) + if sm: + srd = sm.group(2) + break + j += 1 + + if srd is None: + out.append(lines[i]) + i += 1 + continue + + tile_end = j + 1 + tile_count += 1 + + middle = lines[swap_end:tile_end] + preserved = [] + mi = 0 + while mi < len(middle): + s = middle[mi].strip() + + # Offset-select pattern: v_sub_u32 .., 4 -> v_cmp_ne -> v_cndmask + if sub4_re.match(s): + preserved.append(middle[mi]) + if mi + 2 < len(middle): + preserved.append(middle[mi + 1]) + preserved.append(middle[mi + 2]) + mi += 3 + else: + mi += 1 + continue + + # Lane-mask creation: v_cndmask_b32 v244, ... 
+ cm = cndmask_re.match(s) + if cm and cm.group(1) == "v244": + preserved.append(middle[mi]) + mi += 1 + continue + + if ( + cmp_ne_v244_re.match(s) + or (cm and cm.group(1) != "v244") + or s.startswith("v_accvgpr_read_b32") + or s.startswith("v_cvt_pk_bf16_f32") + or s.startswith("buffer_store_dword") + or lshlrev1_re.match(s) + ): + mi += 1 + continue + + preserved.append(middle[mi]) + mi += 1 + + a0, a1, a2, a3 = agprs + # Preserved lines (lane mask, offset select, addr comp, SRD) + # must come first so v244 and v253 are set before we use them. + out.extend(preserved) + out.extend( + [ + f" v_accvgpr_read_b32 v0, a{a0}", + f" v_accvgpr_read_b32 v1, a{a1}", + f" v_cvt_pk_bf16_f32 v2, v0, v1", + f" v_accvgpr_read_b32 v0, a{a2}", + f" v_accvgpr_read_b32 v1, a{a3}", + f" v_cvt_pk_bf16_f32 v3, v0, v1", + "s_nop 1", + " v_permlane16_swap_b32 v4, v2", + "s_nop 1", + " v_permlane16_swap_b32 v5, v3", + f" v_accvgpr_read_b32 v0, a{a0}", + f" v_accvgpr_read_b32 v1, a{a1}", + f" v_cvt_pk_bf16_f32 v2, v0, v1", + f" v_accvgpr_read_b32 v0, a{a2}", + f" v_accvgpr_read_b32 v1, a{a3}", + f" v_cvt_pk_bf16_f32 v3, v0, v1", + " v_cmp_ne_u32 vcc, v244, 0", + " v_cndmask_b32 v0, v4, v2", + " v_cndmask_b32 v1, v5, v3", + " v_cndmask_b32 v6, v2, v4", + " v_cndmask_b32 v7, v3, v5", + " v_mov_b32 v2, v6", + " v_mov_b32 v3, v7", + " v_lshlrev_b32 v245, 1, v253", + f" buffer_store_dwordx4 v[0:3], v245, {srd}, 0 offen", + ] + ) + + i = tile_end + continue + + out.append(lines[i]) + i += 1 + + print( + f"[convert_first] Transformed {tile_count} tiles: eliminated cndmask, merged to dwordx4" + ) + return "\n".join(out) + + +def lds_epilogue_transform(asm_text): + """Replace epilogue with ds_bpermute cross-lane exchange (no permlane_swap). + + Uses ds_bpermute_b32 to read packed bf16 data from partner lanes via the + LDS crossbar (no LDS memory access). Both halves of each swap-pair read + from the same LOW and HIGH partner lanes, producing identical v[0:3]. 
+ Eliminates all permlane_swap, cndmask, s_nop, and AGPR re-reads. + + Per-tile: 4 accvgpr_read + 2 cvt + 4 bpermute + 1 waitcnt + 1 lshlrev + + 1 store = 13 instructions (vs 25 cvt-first, 45 orig). + """ + read_re = re.compile(r"^\s*v_accvgpr_read_b32\s+(v\d+),\s+a(\d+)") + nop_re = re.compile(r"^s_nop 1\s*$") + swap_re = re.compile(r"^\s*v_permlane16_swap_b32\s+") + store_off12_re = re.compile( + r"^\s*buffer_store_dword\s+v\d+,\s+(v\d+),\s+(s\[\d+:\d+\]),\s+0 offen\s+offset:12\s*$" + ) + sub4_re = re.compile(r"^\s*v_sub_u32\s+v\d+,\s+v\d+,\s+4\s*$") + cmp_ne_v244_re = re.compile(r"^\s*v_cmp_ne_u32\s+vcc,\s+v244,\s+0") + cndmask_re = re.compile(r"^\s*v_cndmask_b32\s+(v\d+),") + lshlrev1_re = re.compile(r"^\s*v_lshlrev_b32\s+v\d+,\s+1,") + + lines = asm_text.split("\n") + + out = [] + tile_count = 0 + lds_setup_emitted = False + i = 0 + + while i < len(lines): + if i + 11 < len(lines) and read_re.match(lines[i]): + agprs = [] + ok = True + for k in range(4): + base = i + k * 3 + m = read_re.match(lines[base]) + if not ( + m + and nop_re.match(lines[base + 1]) + and swap_re.match(lines[base + 2]) + ): + ok = False + break + agprs.append(int(m.group(2))) + if not ok: + out.append(lines[i]) + i += 1 + continue + + swap_end = i + 12 + + j = swap_end + srd = None + while j < len(lines): + sm = store_off12_re.match(lines[j]) + if sm: + srd = sm.group(2) + break + j += 1 + + if srd is None: + out.append(lines[i]) + i += 1 + continue + + tile_end = j + 1 + tile_count += 1 + + middle = lines[swap_end:tile_end] + preserved = [] + mi = 0 + while mi < len(middle): + s = middle[mi].strip() + + if sub4_re.match(s): + preserved.append(middle[mi]) + if mi + 2 < len(middle): + preserved.append(middle[mi + 1]) + preserved.append(middle[mi + 2]) + mi += 3 + else: + mi += 1 + continue + + cm = cndmask_re.match(s) + if cm and cm.group(1) == "v244": + preserved.append(middle[mi]) + mi += 1 + continue + + if ( + cmp_ne_v244_re.match(s) + or (cm and cm.group(1) != "v244") + or 
s.startswith("v_accvgpr_read_b32") + or s.startswith("v_cvt_pk_bf16_f32") + or s.startswith("buffer_store_dword") + or lshlrev1_re.match(s) + ): + mi += 1 + continue + + preserved.append(middle[mi]) + mi += 1 + + if not lds_setup_emitted: + out.extend( + [ + " ;; LDS epilogue: bpermute-based cross-lane exchange", + " v_and_b32 v4, 4294967279, v238", + " v_lshlrev_b32 v4, 2, v4", + " v_or_b32 v5, v238, 16", + " v_lshlrev_b32 v5, 2, v5", + ] + ) + lds_setup_emitted = True + + out.extend(preserved) + + a0, a1, a2, a3 = agprs + out.extend( + [ + f" v_accvgpr_read_b32 v0, a{a0}", + f" v_accvgpr_read_b32 v1, a{a1}", + " v_cvt_pk_bf16_f32 v2, v0, v1", + f" v_accvgpr_read_b32 v0, a{a2}", + f" v_accvgpr_read_b32 v1, a{a3}", + " v_cvt_pk_bf16_f32 v3, v0, v1", + " ds_bpermute_b32 v0, v4, v2", + " ds_bpermute_b32 v1, v4, v3", + " ds_bpermute_b32 v2, v5, v2", + " ds_bpermute_b32 v3, v5, v3", + " s_waitcnt lgkmcnt(0)", + " v_lshlrev_b32 v245, 1, v253", + f" buffer_store_dwordx4 v[0:3], v245, {srd}, 0 offen", + ] + ) + + i = tile_end + continue + + out.append(lines[i]) + i += 1 + + print( + f"[lds_epilogue] Transformed {tile_count} tiles: bpermute exchange, " + f"no swap/cndmask/LDS-memory" + ) + return "\n".join(out) + + +def bpermute_masked_epilogue_transform(asm_text): + """Like lds_epilogue_transform but with exec masking to eliminate duplicate stores. + + In the transposed output, each pair of lanes (i, i^16) writes the same data to + the same address (benign duplicate). This transform masks the exec register so + only one lane per pair executes the buffer_store, halving memory traffic. + + Uses s[20:21] to save exec and s[22:23] for the masked exec (these SGPRs hold + input SRDs that are dead in the epilogue). 
+ """ + read_re = re.compile(r"^\s*v_accvgpr_read_b32\s+(v\d+),\s+a(\d+)") + nop_re = re.compile(r"^s_nop 1\s*$") + swap_re = re.compile(r"^\s*v_permlane16_swap_b32\s+") + store_off12_re = re.compile( + r"^\s*buffer_store_dword\s+v\d+,\s+(v\d+),\s+(s\[\d+:\d+\]),\s+0 offen\s+offset:12\s*$" + ) + sub4_re = re.compile(r"^\s*v_sub_u32\s+v\d+,\s+v\d+,\s+4\s*$") + cmp_ne_v244_re = re.compile(r"^\s*v_cmp_ne_u32\s+vcc,\s+v244,\s+0") + cndmask_re = re.compile(r"^\s*v_cndmask_b32\s+(v\d+),") + lshlrev1_re = re.compile(r"^\s*v_lshlrev_b32\s+v\d+,\s+1,") + + lines = asm_text.split("\n") + out = [] + tile_count = 0 + lds_setup_emitted = False + exec_mask_emitted = False + i = 0 + + while i < len(lines): + if i + 11 < len(lines) and read_re.match(lines[i]): + agprs = [] + ok = True + for k in range(4): + base = i + k * 3 + m = read_re.match(lines[base]) + if not ( + m + and nop_re.match(lines[base + 1]) + and swap_re.match(lines[base + 2]) + ): + ok = False + break + agprs.append(int(m.group(2))) + if not ok: + out.append(lines[i]) + i += 1 + continue + + swap_end = i + 12 + j = swap_end + srd = None + while j < len(lines): + sm = store_off12_re.match(lines[j]) + if sm: + srd = sm.group(2) + break + j += 1 + + if srd is None: + out.append(lines[i]) + i += 1 + continue + + tile_end = j + 1 + tile_count += 1 + + middle = lines[swap_end:tile_end] + preserved = [] + mi = 0 + while mi < len(middle): + s = middle[mi].strip() + + if sub4_re.match(s): + preserved.append(middle[mi]) + if mi + 2 < len(middle): + preserved.append(middle[mi + 1]) + preserved.append(middle[mi + 2]) + mi += 3 + else: + mi += 1 + continue + + cm = cndmask_re.match(s) + if cm and cm.group(1) == "v244": + preserved.append(middle[mi]) + mi += 1 + continue + + if ( + cmp_ne_v244_re.match(s) + or (cm and cm.group(1) != "v244") + or s.startswith("v_accvgpr_read_b32") + or s.startswith("v_cvt_pk_bf16_f32") + or s.startswith("buffer_store_dword") + or lshlrev1_re.match(s) + ): + mi += 1 + continue + + 
preserved.append(middle[mi]) + mi += 1 + + if not lds_setup_emitted: + out.extend( + [ + " ;; bpermute epilogue with exec-masked stores", + " v_and_b32 v4, 4294967279, v238", + " v_lshlrev_b32 v4, 2, v4", + " v_or_b32 v5, v238, 16", + " v_lshlrev_b32 v5, 2, v5", + ] + ) + lds_setup_emitted = True + + out.extend(preserved) + + if not exec_mask_emitted: + out.extend( + [ + " ;; exec mask: only low lanes (0-15, 32-47) store", + " s_mov_b64 s[20:21], exec", + " v_cmp_ne_u32 vcc, v244, 0", + " s_and_b64 s[22:23], s[20:21], vcc", + ] + ) + exec_mask_emitted = True + + a0, a1, a2, a3 = agprs + out.extend( + [ + f" v_accvgpr_read_b32 v0, a{a0}", + f" v_accvgpr_read_b32 v1, a{a1}", + " v_cvt_pk_bf16_f32 v2, v0, v1", + f" v_accvgpr_read_b32 v0, a{a2}", + f" v_accvgpr_read_b32 v1, a{a3}", + " v_cvt_pk_bf16_f32 v3, v0, v1", + " ds_bpermute_b32 v0, v4, v2", + " ds_bpermute_b32 v1, v4, v3", + " ds_bpermute_b32 v2, v5, v2", + " ds_bpermute_b32 v3, v5, v3", + " s_waitcnt lgkmcnt(0)", + " v_lshlrev_b32 v245, 1, v253", + " s_mov_b64 exec, s[22:23]", + f" buffer_store_dwordx4 v[0:3], v245, {srd}, 0 offen", + " s_mov_b64 exec, s[20:21]", + ] + ) + + i = tile_end + continue + + out.append(lines[i]) + i += 1 + + print( + f"[bpermute_masked] Transformed {tile_count} tiles: " + f"exec-masked stores eliminate duplicate writes" + ) + return "\n".join(out) + + +def bpermute_pipelined_epilogue_transform(asm_text): + """Bpermute epilogue with software pipelining to hide ds_bpermute latency. + + Uses two alternating register sets: even tiles in v[0:3], odd tiles in + v[6:9]. Tile N+1's AGPR read+cvt_pk overlaps with tile N's in-flight + ds_bpermute, hiding the LDS crossbar latency behind VALU work. + + Byte offsets pre-saved in v10 (even) / v11 (odd) so the deferred store + uses the correct address after v253 has been updated for the next tile. 
+ """ + read_re = re.compile(r"^\s*v_accvgpr_read_b32\s+(v\d+),\s+a(\d+)") + nop_re = re.compile(r"^s_nop 1\s*$") + swap_re = re.compile(r"^\s*v_permlane16_swap_b32\s+") + store_off12_re = re.compile( + r"^\s*buffer_store_dword\s+v\d+,\s+(v\d+),\s+(s\[\d+:\d+\]),\s+0 offen\s+offset:12\s*$" + ) + sub4_re = re.compile(r"^\s*v_sub_u32\s+v\d+,\s+v\d+,\s+4\s*$") + cmp_ne_v244_re = re.compile(r"^\s*v_cmp_ne_u32\s+vcc,\s+v244,\s+0") + cndmask_re = re.compile(r"^\s*v_cndmask_b32\s+(v\d+),") + lshlrev1_re = re.compile(r"^\s*v_lshlrev_b32\s+v\d+,\s+1,") + + lines = asm_text.split("\n") + out = [] + tile_count = 0 + lds_setup_emitted = False + exec_mask_emitted = False + first_srd = None + pending_store = None + i = 0 + + while i < len(lines): + if i + 11 < len(lines) and read_re.match(lines[i]): + agprs = [] + ok = True + for k in range(4): + base = i + k * 3 + m = read_re.match(lines[base]) + if not ( + m + and nop_re.match(lines[base + 1]) + and swap_re.match(lines[base + 2]) + ): + ok = False + break + agprs.append(int(m.group(2))) + if not ok: + out.append(lines[i]) + i += 1 + continue + + swap_end = i + 12 + j = swap_end + srd = None + while j < len(lines): + sm = store_off12_re.match(lines[j]) + if sm: + srd = sm.group(2) + break + j += 1 + + if srd is None: + out.append(lines[i]) + i += 1 + continue + + if first_srd is None: + first_srd = srd + + tile_end = j + 1 + tile_count += 1 + + middle = lines[swap_end:tile_end] + preserved = [] + mi = 0 + while mi < len(middle): + s = middle[mi].strip() + if sub4_re.match(s): + preserved.append(middle[mi]) + if mi + 2 < len(middle): + preserved.append(middle[mi + 1]) + preserved.append(middle[mi + 2]) + mi += 3 + else: + mi += 1 + continue + cm = cndmask_re.match(s) + if cm and cm.group(1) == "v244": + preserved.append(middle[mi]) + mi += 1 + continue + if ( + cmp_ne_v244_re.match(s) + or (cm and cm.group(1) != "v244") + or s.startswith("v_accvgpr_read_b32") + or s.startswith("v_cvt_pk_bf16_f32") + or 
s.startswith("buffer_store_dword") + or lshlrev1_re.match(s) + ): + mi += 1 + continue + preserved.append(middle[mi]) + mi += 1 + + is_even = (tile_count - 1) % 2 == 0 + if is_even: + rd0, rd1, cv0, cv1 = "v0", "v1", "v2", "v3" + store_reg, addr_reg = "v[0:3]", "v10" + else: + rd0, rd1, cv0, cv1 = "v6", "v7", "v8", "v9" + store_reg, addr_reg = "v[6:9]", "v11" + + if not lds_setup_emitted: + out.extend( + [ + " ;; pipelined bpermute epilogue with exec-masked stores", + " v_and_b32 v4, 4294967279, v238", + " v_lshlrev_b32 v4, 2, v4", + " v_or_b32 v5, v238, 16", + " v_lshlrev_b32 v5, 2, v5", + ] + ) + lds_setup_emitted = True + + out.extend(preserved) + + if not exec_mask_emitted: + out.extend( + [ + " ;; exec mask: only low lanes (0-15, 32-47) store", + " s_mov_b64 s[20:21], exec", + " v_cmp_ne_u32 vcc, v244, 0", + " s_and_b64 s[22:23], s[20:21], vcc", + ] + ) + exec_mask_emitted = True + + a0, a1, a2, a3 = agprs + + out.append(f" v_lshlrev_b32 {addr_reg}, 1, v253") + + out.extend( + [ + f" v_accvgpr_read_b32 {rd0}, a{a0}", + f" v_accvgpr_read_b32 {rd1}, a{a1}", + f" v_cvt_pk_bf16_f32 {cv0}, {rd0}, {rd1}", + f" v_accvgpr_read_b32 {rd0}, a{a2}", + f" v_accvgpr_read_b32 {rd1}, a{a3}", + f" v_cvt_pk_bf16_f32 {cv1}, {rd0}, {rd1}", + ] + ) + + if pending_store is not None: + ps_reg, pa_reg = pending_store + out.extend( + [ + " s_waitcnt lgkmcnt(0)", + " s_mov_b64 exec, s[22:23]", + f" buffer_store_dwordx4 {ps_reg}, {pa_reg}, {first_srd}, 0 offen", + " s_mov_b64 exec, s[20:21]", + ] + ) + + out.extend( + [ + f" ds_bpermute_b32 {rd0}, v4, {cv0}", + f" ds_bpermute_b32 {rd1}, v4, {cv1}", + f" ds_bpermute_b32 {cv0}, v5, {cv0}", + f" ds_bpermute_b32 {cv1}, v5, {cv1}", + ] + ) + + pending_store = (store_reg, addr_reg) + + i = tile_end + continue + + if pending_store is not None and lines[i].strip() == "s_endpgm": + ps_reg, pa_reg = pending_store + out.extend( + [ + " s_waitcnt lgkmcnt(0)", + " s_mov_b64 exec, s[22:23]", + f" buffer_store_dwordx4 {ps_reg}, {pa_reg}, {first_srd}, 
0 offen", + " s_mov_b64 exec, s[20:21]", + ] + ) + pending_store = None + + out.append(lines[i]) + i += 1 + + print( + f"[bpermute_pipelined] Transformed {tile_count} tiles: " + f"software-pipelined bpermute with exec-masked stores" + ) + return "\n".join(out) def _run_mxfp_gemm(gemm, shape): @@ -634,6 +1356,332 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_coalesced( ) +def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_coalesced_dwordx4( + is_debug=False, + shape=(6400, 3072, 7168), + block=(256, 192, 256), + eliminate_epilogue=True, +): + """Same as bf16_coalesced but with post-asm dwordx4 store coalescing.""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + transpose_output=True, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = eliminate_epilogue + options.coalesce_epilogue_stores = True + options.asm_transform = coalesce_buffer_stores_dwordx4 + options.dump_intermediates = ( + "build/intermediates/waveasm_256x192x256_bf16_coalesced_dwordx4/" + ) + options.print_mlir_file = ( + "gemm_mxfp4_dbuf_4wave_asymmetric_bf16_coalesced_dwordx4.mlir" + ) + options.print_mlir = True + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + _run_mxfp_gemm_preshuffle( + gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + ) + print("MXFP GEMM bf16 coalesced dwordx4 epilogue (WaveASM backend) test passed!") + + +def 
test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_cvt_first( + is_debug=False, + shape=(6400, 3072, 7168), + block=(256, 192, 256), + eliminate_epilogue=True, +): + """Same as bf16_coalesced but with convert-first epilogue (no cndmask).""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + transpose_output=True, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = eliminate_epilogue + options.coalesce_epilogue_stores = True + options.asm_transform = convert_first_eliminate_cndmask + options.dump_intermediates = ( + "build/intermediates/waveasm_256x192x256_bf16_cvt_first/" + ) + options.print_mlir_file = "gemm_mxfp4_dbuf_4wave_asymmetric_bf16_cvt_first.mlir" + options.print_mlir = True + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + _run_mxfp_gemm_preshuffle( + gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + ) + print("MXFP GEMM bf16 convert-first epilogue (WaveASM backend) test passed!") + + +def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_transpose_only( + is_debug=False, + shape=(6400, 3072, 7168), + block=(256, 192, 256), + eliminate_epilogue=True, +): + """bf16 with transpose_output=True but NO coalesce_epilogue_stores (simple stores).""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + transpose_output=True, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, 
tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = eliminate_epilogue + options.dump_intermediates = ( + "build/intermediates/waveasm_256x192x256_bf16_transpose_only/" + ) + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + _run_mxfp_gemm_preshuffle( + gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + ) + print("MXFP GEMM bf16 transpose-only (no coalesce) test passed!") + + +def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_lds_epilogue( + is_debug=False, + shape=(6400, 3072, 7168), + block=(256, 192, 256), + eliminate_epilogue=True, +): + """bf16 pipelined bpermute epilogue with exec-masked stores.""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + transpose_output=True, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = eliminate_epilogue + options.coalesce_epilogue_stores = True + options.asm_transform = bpermute_pipelined_epilogue_transform + options.dump_intermediates = ( + "build/intermediates/waveasm_256x192x256_bf16_lds_epilogue/" + ) + options.print_mlir_file = "gemm_mxfp4_dbuf_4wave_asymmetric_bf16_lds_epilogue.mlir" + options.print_mlir = True + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, 
is_bscale_shuffled=True + ) + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + _run_mxfp_gemm_preshuffle( + gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + ) + print("MXFP GEMM bf16 pipelined bpermute epilogue (WaveASM backend) test passed!") + + +def _compile_bf16_kernel( + block, + *, + coalesce=False, + dwordx4=False, + convert_first=False, + lds_epilogue=False, + bpermute_masked=False, + bpermute_pipelined=False, + transpose_only=False, +): + """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, transpose_output).""" + shape_placeholder = (block[0] * 4, block[1] * 4, block[2] * 4) + use_transpose = ( + coalesce + or lds_epilogue + or bpermute_masked + or bpermute_pipelined + or transpose_only + ) + kwargs = dict( + wave_shape=(2, 2), + reorder_workgroups=True, + output_dtype=tkl.bf16, + ) + if use_transpose: + kwargs["transpose_output"] = True + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape_placeholder, block, **kwargs + ) + for sym in [tkl.sym.M, tkl.sym.N, tkl.sym.K]: + del options.subs[sym] + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + options.use_buffer_ops = True + options.backend = "asm" + options.use_wave_asm_backend = True + options.wave_runtime = True + options.eliminate_epilogue = True + if coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined: + options.coalesce_epilogue_stores = True + if bpermute_pipelined: + options.asm_transform = bpermute_pipelined_epilogue_transform + elif bpermute_masked: + options.asm_transform = bpermute_masked_epilogue_transform + elif lds_epilogue: + options.asm_transform = lds_epilogue_transform + elif convert_first: + options.asm_transform = convert_first_eliminate_cndmask + elif dwordx4: + options.asm_transform = coalesce_buffer_stores_dwordx4 + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=True, + is_bscale_shuffled=True, + ) 
+ options = set_default_run_config(options) + return wave_compile(options, gemm, schedule), use_transpose + + +def _time_kernel(gemm, shape, transpose_output, warmup=2, iters=5): + """Time a compiled GEMM kernel on the given shape. Returns median us.""" + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + w_t = w.T.contiguous() + w_t_ps = b_preshuffle(w_t) + x_scales_ps = e8m0_shuffle(x_scales) + w_scales_ps = e8m0_shuffle(w_scales) + x, w_t_ps = x.cuda(), w_t_ps.cuda() + x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda() + if transpose_output: + out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=torch.bfloat16).cuda() + else: + out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.bfloat16).cuda() + + for _ in range(warmup): + gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + torch.cuda.synchronize() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters): + start_events[i].record() + gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + end_events[i].record() + torch.cuda.synchronize() + + times = sorted([s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events)]) + return times[len(times) // 2] + + +def test_benchmark_bf16_shapes(is_debug=False, **kwargs): + """Benchmark bf16 baseline vs masked vs pipelined (exec-masked stores).""" + shapes = [ + (6400, 3072, 7168), + (4608, 7680, 6656), + ] + block = (256, 192, 256) + warmup = 10 + iters = 200 + rounds = 50 + + print("Compiling bf16 baseline (no transpose)...") + baseline_kernel, _ = _compile_bf16_kernel(block) + print("Compiling bf16 bpermute-masked (exec-masked stores)...") + masked_kernel, _ = _compile_bf16_kernel(block, bpermute_masked=True) + print("Compiling bf16 bpermute-pipelined (SW pipelined + exec-masked)...") + pipelined_kernel, _ = _compile_bf16_kernel(block, bpermute_pipelined=True) + + hdr = ( + f"{'Shape (M,N,K)':<30} {'baseline':>10} 
{'masked':>10} " + f"{'pipelined':>10} {'pipe/base':>10}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + for shape in shapes: + bt, mt, pt = [], [], [] + for _ in range(rounds): + bt.append( + _time_kernel( + baseline_kernel, + shape, + transpose_output=False, + warmup=warmup, + iters=iters, + ) + ) + mt.append( + _time_kernel( + masked_kernel, + shape, + transpose_output=True, + warmup=warmup, + iters=iters, + ) + ) + pt.append( + _time_kernel( + pipelined_kernel, + shape, + transpose_output=True, + warmup=warmup, + iters=iters, + ) + ) + tb = sorted(bt)[len(bt) // 2] + tm = sorted(mt)[len(mt) // 2] + tp = sorted(pt)[len(pt) // 2] + print( + f"{str(shape):<30} {tb:>8.1f}us {tm:>8.1f}us " + f"{tp:>8.1f}us {tb/tp:>9.3f}x" + ) + print() + + if __name__ == "__main__": args = parse_args() diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index e01ddc8c20..16375a0be5 100755 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -1465,9 +1465,9 @@ def _write_permlane_pack_to_global( # Emit 4 stores of vector<2xbf16> (= buffer_store_dword each). # Each pair of f32 values is packed into one bf16 dword by - # v_cvt_pk_bf16_f32. Using 2-element stores avoids the multi-dword - # PackOp, which the register allocator cannot handle (it does not - # insert copies for PackOp operands). + # v_cvt_pk_bf16_f32. A peephole pass in the assembly emitter + # merges consecutive dword stores with sequential offsets into + # wider stores (dwordx2 / dwordx4). 
all_vals = s_lo + s_hi for pair_idx in range(4): pair_f32 = vector_d.from_elements( diff --git a/wave_lang/kernel/wave/coalesce_epilogue_stores.py b/wave_lang/kernel/wave/coalesce_epilogue_stores.py new file mode 100644 index 0000000000..880b5ab98f --- /dev/null +++ b/wave_lang/kernel/wave/coalesce_epilogue_stores.py @@ -0,0 +1,50 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Graph pass that coalesces epilogue bf16 stores via permlane16_swap. + +Marks eligible Write nodes so the codegen combines each thread's 4 bf16 +values with its partner lane's (16 lanes apart) via v_permlane16_swap_b32, +producing 8 consecutive bf16 written as a single buffer_store_dwordx4. +No LDS staging or barriers required. + +Precondition: the output memory must have M as the innermost (contiguous) +dimension (i.e. transpose_output=True producing [N, M] layout) so that 8 +consecutive bf16 elements span 8 adjacent M rows. +""" + +from .._support.tracing import CapturedTrace +from ..lang.global_symbols import GLOBAL_ADDRESS_SPACE +from ..ops.wave_ops import Write, get_custom +from .region_canonicalization import RegionFormat, requires_region_format +from .utils.symbol_utils import subs_idxc + + +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) +def coalesce_epilogue_stores(trace: CapturedTrace): + """Tag epilogue bf16 global writes for permlane16_swap packing. + + Walks the root graph and sets ``_permlane_pack_global = True`` on + every Write node that targets global memory with bf16 dtype. + The codegen in ``_write_permlane_pack_to_global`` handles the rest. 
+ """ + import wave_lang.kernel.lang as tkl + + root_graph = trace.get_root_graph() + + for node in root_graph.nodes: + if node.op != "call_function": + continue + custom = get_custom(node) + if not isinstance(custom, Write): + continue + mem_type = custom.memory_type + if ( + subs_idxc(mem_type.address_space) == GLOBAL_ADDRESS_SPACE + and mem_type.dtype == tkl.bf16 + ): + node._permlane_pack_global = True diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index d9da8257b4..fbc2ba90b5 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -1213,6 +1213,21 @@ def get_binary_path(): # ASM flow: generate AMDGCN assembly; optionally build a binary asm = _generate_asm_code(mb, options) + asm_transform = getattr(options, "asm_transform", None) + if callable(asm_transform): + asm = asm_transform(asm) + if options.dump_intermediates: + import os as _os + + _os.makedirs(options.dump_intermediates, exist_ok=True) + with open( + _os.path.join( + options.dump_intermediates, "gemm_transformed.rocmasm" + ), + "w", + ) as _f: + _f.write(asm) + if options.backend == "asm" and not options.compile_to_asm: _compile_asm_to_binary(asm, options) elif options.use_water_backend: diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp index e87dffe31a..bbca566660 100644 --- a/waveasm/lib/Transforms/LinearScanPass.cpp +++ b/waveasm/lib/Transforms/LinearScanPass.cpp @@ -433,6 +433,36 @@ struct LinearScanPass } }); + // Insert V_MOV_B32 copies for PackOp operands so that data lands in + // the consecutive physical registers expected by wide stores + // (buffer_store_dwordx4 etc.). PackOp is a no-op in assembly; the + // copies materialise the packing at runtime. 
+ program.walk([&](PackOp packOp) { + auto resultType = dyn_cast(packOp.getResult().getType()); + if (!resultType || resultType.getSize() <= 1) + return; + + int64_t baseReg = resultType.getIndex(); + int64_t width = resultType.getSize(); + OpBuilder copyBuilder(packOp); + auto loc = packOp.getLoc(); + auto *ctx = packOp->getContext(); + + for (int64_t i = 0; + i < width && i < static_cast(packOp.getNumOperands()); + ++i) { + auto srcType = dyn_cast(packOp.getOperand(i).getType()); + if (!srcType) + continue; + int64_t srcReg = srcType.getIndex(); + int64_t tgtReg = baseReg + i; + if (srcReg == tgtReg) + continue; + auto tgtPVReg = PVRegType::get(ctx, tgtReg, 1); + V_MOV_B32::create(copyBuilder, loc, tgtPVReg, packOp.getOperand(i)); + } + }); + // Update block arguments and result types for region-based control flow. // After the walk above, operation results inside loop/if bodies have // physical register types, but block arguments and the parent op's From 02f5f2f6ef49604238d8d53a36441a549b484f6f Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 27 Mar 2026 18:26:53 +0000 Subject: [PATCH 5/5] compute C^T=B.A^T; store C Signed-off-by: xintin --- examples/python/7.1_schedule.py | 136 ++++++++++++++++---------------- 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 5d618a9d63..4e3a6caf33 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -133,6 +133,7 @@ def convert_first_eliminate_cndmask(asm_text): lines = asm_text.split("\n") out = [] tile_count = 0 + vcc_emitted = False i = 0 while i < len(lines): @@ -216,6 +217,9 @@ def convert_first_eliminate_cndmask(asm_text): # Preserved lines (lane mask, offset select, addr comp, SRD) # must come first so v244 and v253 are set before we use them. 
out.extend(preserved) + if not vcc_emitted: + out.append(" v_cmp_ne_u32 vcc, v244, 0") + vcc_emitted = True out.extend( [ f" v_accvgpr_read_b32 v0, a{a0}", @@ -224,23 +228,16 @@ def convert_first_eliminate_cndmask(asm_text): f" v_accvgpr_read_b32 v0, a{a2}", f" v_accvgpr_read_b32 v1, a{a3}", f" v_cvt_pk_bf16_f32 v3, v0, v1", + " v_mov_b32 v8, v2", + " v_mov_b32 v9, v3", "s_nop 1", " v_permlane16_swap_b32 v4, v2", "s_nop 1", " v_permlane16_swap_b32 v5, v3", - f" v_accvgpr_read_b32 v0, a{a0}", - f" v_accvgpr_read_b32 v1, a{a1}", - f" v_cvt_pk_bf16_f32 v2, v0, v1", - f" v_accvgpr_read_b32 v0, a{a2}", - f" v_accvgpr_read_b32 v1, a{a3}", - f" v_cvt_pk_bf16_f32 v3, v0, v1", - " v_cmp_ne_u32 vcc, v244, 0", - " v_cndmask_b32 v0, v4, v2", - " v_cndmask_b32 v1, v5, v3", - " v_cndmask_b32 v6, v2, v4", - " v_cndmask_b32 v7, v3, v5", - " v_mov_b32 v2, v6", - " v_mov_b32 v3, v7", + " v_cndmask_b32 v0, v4, v8", + " v_cndmask_b32 v1, v5, v9", + " v_cndmask_b32 v2, v8, v4", + " v_cndmask_b32 v3, v9, v5", " v_lshlrev_b32 v245, 1, v253", f" buffer_store_dwordx4 v[0:3], v245, {srd}, 0 offen", ] @@ -252,9 +249,6 @@ def convert_first_eliminate_cndmask(asm_text): out.append(lines[i]) i += 1 - print( - f"[convert_first] Transformed {tile_count} tiles: eliminated cndmask, merged to dwordx4" - ) return "\n".join(out) @@ -562,10 +556,6 @@ def bpermute_masked_epilogue_transform(asm_text): out.append(lines[i]) i += 1 - print( - f"[bpermute_masked] Transformed {tile_count} tiles: " - f"exec-masked stores eliminate duplicate writes" - ) return "\n".join(out) @@ -760,10 +750,6 @@ def bpermute_pipelined_epilogue_transform(asm_text): out.append(lines[i]) i += 1 - print( - f"[bpermute_pipelined] Transformed {tile_count} tiles: " - f"software-pipelined bpermute with exec-masked stores" - ) return "\n".join(out) @@ -785,42 +771,37 @@ def _run_mxfp_gemm(gemm, shape): def _run_mxfp_gemm_preshuffle( gemm, shape, - all=False, - only_scale=False, - only_b=False, output_dtype=torch.float32, - 
transpose_output=False, + swap_inputs=False, + **kwargs, ): - """Run compiled GEMM kernel with preshuffled B and B_scale, verify against reference. + """Run compiled GEMM kernel, verify against reference. - Shuffling is applied based on the flags: - all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t) - only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only - only_b - shuffle b_scale (w_scales) only - - When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N]. + When swap_inputs is True, the kernel computes C^T = B x A^T (with A=X, B=W) + and writes C [M, N] directly via transpose_output + coalesced epilogue. + When swap_inputs is False (baseline), uses standard input order. """ x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) w_t = w.T.contiguous() - w_t_ps = b_preshuffle(w_t) if all else w_t - - x_scales_ps = e8m0_shuffle(x_scales) if (all or only_scale) else x_scales - - w_scales_ps = e8m0_shuffle(w_scales) if (all or only_scale or only_b) else w_scales - - x, w_t_ps = x.cuda(), w_t_ps.cuda() - x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda() - if transpose_output: - out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=output_dtype).cuda() + if swap_inputs: + kern_a = w_t.cuda() + kern_a_scale = e8m0_shuffle(w_scales).cuda() + kern_b = b_preshuffle(x).cuda() + kern_b_scale = e8m0_shuffle(x_scales).cuda() + out = torch.zeros(shape[0], shape[1], dtype=output_dtype).cuda() + gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out) + result = out.cpu() else: - out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda() - - gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) - - result = out.T.contiguous().cpu() if transpose_output else out.cpu() + kern_a = x.cuda() + kern_b = b_preshuffle(w_t).cuda() + kern_a_scale = e8m0_shuffle(x_scales).cuda() + kern_b_scale = e8m0_shuffle(w_scales).cuda() + out = torch.zeros(shape[0], 
shape[1], dtype=output_dtype).cuda() + gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out) + result = out.cpu() if os.environ.get("WAVE_DEBUG_COMPARE"): ref = torch_out.to(torch.float32).cpu() @@ -1441,7 +1422,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_cvt_first( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle( - gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + gemm, shape, output_dtype=torch.bfloat16, swap_inputs=True ) print("MXFP GEMM bf16 convert-first epilogue (WaveASM backend) test passed!") @@ -1525,7 +1506,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_lds_epilogue( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle( - gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True + gemm, shape, output_dtype=torch.bfloat16, swap_inputs=True ) print("MXFP GEMM bf16 pipelined bpermute epilogue (WaveASM backend) test passed!") @@ -1539,15 +1520,21 @@ def _compile_bf16_kernel( lds_epilogue=False, bpermute_masked=False, bpermute_pipelined=False, + swap_inputs=False, transpose_only=False, ): - """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, transpose_output).""" + """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, mode_str). + + mode_str is one of: False (baseline), True (transpose), "swap" (input swap). 
+ """ shape_placeholder = (block[0] * 4, block[1] * 4, block[2] * 4) use_transpose = ( coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined + or swap_inputs + or convert_first or transpose_only ) kwargs = dict( @@ -1568,16 +1555,25 @@ def _compile_bf16_kernel( options.use_wave_asm_backend = True options.wave_runtime = True options.eliminate_epilogue = True - if coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined: + if ( + coalesce + or lds_epilogue + or bpermute_masked + or bpermute_pipelined + or swap_inputs + or convert_first + ): options.coalesce_epilogue_stores = True if bpermute_pipelined: options.asm_transform = bpermute_pipelined_epilogue_transform + elif convert_first: + options.asm_transform = convert_first_eliminate_cndmask + elif swap_inputs: + options.asm_transform = bpermute_masked_epilogue_transform elif bpermute_masked: options.asm_transform = bpermute_masked_epilogue_transform elif lds_epilogue: options.asm_transform = lds_epilogue_transform - elif convert_first: - options.asm_transform = convert_first_eliminate_cndmask elif dwordx4: options.asm_transform = coalesce_buffer_stores_dwordx4 schedule = get_mxfp4_asymmetric_schedule( @@ -1585,32 +1581,36 @@ def _compile_bf16_kernel( is_bscale_shuffled=True, ) options = set_default_run_config(options) - return wave_compile(options, gemm, schedule), use_transpose + mode = "swap" if swap_inputs else use_transpose + return wave_compile(options, gemm, schedule), mode -def _time_kernel(gemm, shape, transpose_output, warmup=2, iters=5): +def _time_kernel(gemm, shape, warmup=2, iters=5, swap_inputs=False, **kwargs): """Time a compiled GEMM kernel on the given shape. 
Returns median us.""" x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) w_t = w.T.contiguous() - w_t_ps = b_preshuffle(w_t) - x_scales_ps = e8m0_shuffle(x_scales) - w_scales_ps = e8m0_shuffle(w_scales) - x, w_t_ps = x.cuda(), w_t_ps.cuda() - x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda() - if transpose_output: - out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=torch.bfloat16).cuda() + + if swap_inputs: + kern_a = w_t.cuda() + kern_a_scale = e8m0_shuffle(w_scales).cuda() + kern_b = b_preshuffle(x).cuda() + kern_b_scale = e8m0_shuffle(x_scales).cuda() else: - out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.bfloat16).cuda() + kern_a = x.cuda() + kern_b = b_preshuffle(w_t).cuda() + kern_a_scale = e8m0_shuffle(x_scales).cuda() + kern_b_scale = e8m0_shuffle(w_scales).cuda() + out = torch.zeros(shape[0], shape[1], dtype=torch.bfloat16).cuda() for _ in range(warmup): - gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out) torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] for i in range(iters): start_events[i].record() - gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out) end_events[i].record() torch.cuda.synchronize()