diff --git a/.github/workflows/ci-gpu.yaml b/.github/workflows/ci-gpu.yaml index 32a59c0cba..0f0204e3a5 100644 --- a/.github/workflows/ci-gpu.yaml +++ b/.github/workflows/ci-gpu.yaml @@ -219,6 +219,16 @@ jobs: pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface + - name: Run LIT tests + env: + WAVE_TEST_WATER: ${{ (env.IS_CDNA3 == 'true' || env.is_CDNA4 == 'true') && '1' || '0' }} + WAVE_TEST_DWARFDUMP: ${{ env.IS_RDNA4 == 'false' && '1' || '0' }} + run: | + # TODO: can't sudo to install dwarfdump on rdna4 + echo "WAVE_TEST_WATER=$WAVE_TEST_WATER" + echo "WAVE_TEST_DWARFDUMP=$WAVE_TEST_DWARFDUMP" + lit lit_tests/ -vv + - name: Test TKW runtime related stack on amdgpu if: ${{ env.HAS_GPU == 'true' }} run: | @@ -238,16 +248,6 @@ jobs: run: | WAVE_CACHE_ON=0 pytest -n 4 --timeout=600 --capture=tee-sys -vv --run-e2e --run-expensive-tests --durations=100 ./tests/kernel/ - - name: Run LIT tests - env: - WAVE_TEST_WATER: ${{ (env.IS_CDNA3 == 'true' || env.is_CDNA4 == 'true') && '1' || '0' }} - WAVE_TEST_DWARFDUMP: ${{ env.IS_RDNA4 == 'false' && '1' || '0' }} - run: | - # TODO: can't sudo to install dwarfdump on rdna4 - echo "WAVE_TEST_WATER=$WAVE_TEST_WATER" - echo "WAVE_TEST_DWARFDUMP=$WAVE_TEST_DWARFDUMP" - lit lit_tests/ -vv - - name: MyPy Type Checking run: | mypy diff --git a/README.md b/README.md index f8848f27b2..8122af5e6a 100644 --- a/README.md +++ b/README.md @@ -64,10 +64,18 @@ Before installing Wave, ensure you have the following prerequisites: Before installing Wave, ensure you have the appropriate ROCm-enabled PyTorch dependencies: + ```bash pip install -r pytorch-rocm-requirements.txt ``` + or, to auto-detect the installed ROCm version: + + ```sh + ./gen-pytorch-rocm-requirements.py > requirements-pytorch-rocm-generated.txt + pip install -r requirements-pytorch-rocm-generated.txt + ``` + 2. 
**Install Wave** You can then install Wave and its dependencies using pip: diff --git a/docs/scc-sgpr-promotion-investigation.md b/docs/scc-sgpr-promotion-investigation.md new file mode 100644 index 0000000000..b868032975 --- /dev/null +++ b/docs/scc-sgpr-promotion-investigation.md @@ -0,0 +1,180 @@ +# SCC & SGPR Promotion Investigation + +## Summary + +This document captures the findings from investigating SCC (Scalar Condition Code) tracking and VGPR→SGPR promotion in the WaveASM backend. The work spans SCC infrastructure, handler-level SALU promotions, and the SRD scalar select issue (now resolved — root cause was a register allocator bug where initial SRD precolored registers were DCE'd). + +## What Was Built + +### SCC Infrastructure (committed, working) + +- **SCCDef / SCCUse traits** on all SALU ops — every op now declares whether it writes or reads SCC +- **SCC verifier pass** (`--waveasm-scc-verifier`) — walks IR in emission order, catches SCC clobbers between producer and consumer +- Key insight: `Pure` in MLIR ODS = `NoMemoryEffect + AlwaysSpeculatableImplTrait`. Composing these with `SCCDef` gives identical MLIR pass behavior while enabling SCC verification. Adding bare `NativeOpTrait` to existing ODS classes changes tablegen output and alters MLIR pass behavior — must use the explicit composition. 
+- `S_CSELECT_B32` carries `SCCUse` trait (not Pure, not CSE-eligible — result depends on implicit SCC) + +### Handler SALU Promotions (committed, working) + +| Handler | Change | Impact | +|---------|--------|--------| +| `handleArithOrI` | `emitOr` helper (S_OR_B32 when both scalar) | No trigger in GEMM kernel | +| `handleArithXorI` | `emitXor` helper (S_XOR_B32) | No trigger | +| `handleArithDivUI` | Uses `emitLshr` (has scalar path) | No trigger | +| `handleArithRemUI` | Uses `emitAnd` (has scalar path) | No trigger | +| `handleArithMaxSI/UI` | `emitMaxI32/U32` (S_MAX_I32 when scalar) | **1 v_max → s_max** | +| `handleArithMinSI/UI` | `emitMinI32/U32` (S_MIN_I32 when scalar) | No trigger | +| `handleArithCmpI` | SGPR accepted in V_CMP (constant bus) | **2 v_mov_b32 eliminated** | +| `handleArithSelect` | SGPR cond accepted in V_CMP_NE_U32 | No trigger | +| AffineHandlers OR | `emitOr` instead of `ensureBothVGPR + V_OR_B32` | No trigger | + +### `emitScalarCmp` Helper (committed) + +Shared inline function in `Handlers.h` that emits `S_CMP_*` for any `arith::CmpIPredicate`. Used by `handleArithCmpI` and available for `AMDGPUHandlers.cpp`. + +## The VGPR Pressure Problem + +### Numbers (256x192 GEMM, our kernel vs aiter reference) + +| Category | Ours | Reference | Delta | +|----------|------|-----------|-------| +| VALU address/control in loop | 21 VGPRs | 0 | **+21** | +| MFMA scale operands | 24 | 10 | +14 | +| buffer_load voffset addresses | 24 | 10 | +14 | +| Peak arch VGPRs | ~276 | ~227 | **+49** | + +The biggest single win is eliminating VALU from the loop body (+21 VGPRs). + +### Root Cause: cmpi→select→fat_raw_buffer_cast Chain + +In the MLIR, the loop body has: + +```mlir +%next_K = affine.apply (s0 + 2)[%loop_iv] // scalar +%K_bound = affine.apply (s0 ceildiv 256)[%K] // scalar +%cond = arith.cmpi slt, %next_K, %K_bound // uniform comparison +%validBytes = arith.select %cond, %bufSize, 0 // branchless guard +%srd = amdgpu.fat_raw_buffer_cast ... 
validBytes(%validBytes) +``` + +All values are provably uniform (scalar). The aiter reference emits this as 2 SALU ops: + +```asm +s_cmp_lt_u32 0x200, s51 ; scalar comparison +s_cselect_b32 s61, s61, 0 ; scalar select → SRD soffset +``` + +Our code emits 7 VALU ops + v_readfirstlane because: +1. `handleArithCmpI` emits `V_CMP + materializeVCCToBoolVGPR` (VGPR boolean) +2. `handleArithSelect` emits `V_CMP_NE_U32 + V_CNDMASK_B32` (VGPR result) +3. `emitSrdNumRecords` does `v_readfirstlane_b32` to extract back to SGPR + +## What Was Tried (and Failed) + +### Approach 1: CmpI Fusion in handleArithCmpI + +**Idea:** When both cmpi operands are scalar, emit `s_cmp + s_cselect(1, 0)` to produce an SGPR boolean instead of VGPR. + +**Result:** Memory fault. The SGPR boolean propagates through `arith.extui` and `arith.addi` chains, changing the type of downstream values from VGPR to SGPR. This cascading type change alters register allocation for buffer descriptors, corrupting SRD addresses. + +**Why:** The SGPR result enters the mapper and changes every downstream consumer's type decision. Values that were VGPR (with v_readfirstlane extraction) become SGPR, shifting the entire allocation picture. + +### Approach 2: CmpI Fusion in handleArithSelect + +**Idea:** Add a fusion path that detects `arith.select(arith.cmpi(scalar, scalar), scalar, scalar)` and emits `s_cmp + s_cselect` directly. + +**Result:** Memory fault when the select result feeds non-SRD consumers (cascading SGPR changes). With user-safety guards (only fire for specific downstream patterns), the fusion never triggers for the loop-body pattern. + +### Approach 3: Targeted SRD Scalar Select in emitSrdNumRecords + +**Idea:** Only in `emitSrdNumRecords`, detect the `arith.select(arith.cmpi)` pattern feeding `fat_raw_buffer_cast`'s `validBytes` and emit `s_cmp + s_cselect` directly into the precolored SRD word 2 (`PSRegType`). + +**Result:** Memory fault ("Write access to a read-only page"). 
Tried four emission variants: +- Direct `S_CSELECT_B32` into `PSRegType(srdBase+2)` with `DCEProtectOp` +- `S_CSELECT_B32` into virtual SGPR → `S_MOV_B32` copy to PSReg +- `S_CSELECT_B32` into virtual SGPR → `S_ADD_U32` copy to PSReg (clobbers SCC — eliminated as cause) +- Fresh `S_MOV_B32` copy of trueOp inside loop body before `S_CSELECT_B32` + +All produce correct-looking assembly. The SCC verifier reports no hazards. Register allocation appears correct (validBytes SGPR kept alive from prologue through loop body, not clobbered by intermediates). + +### Root Cause (RESOLVED) + +The assembly emitted by all three approaches was semantically correct. +The real bug was in the register allocator's interaction with `RawOp`-based +SRD setup: + +**Bug chain:** + +1. `emitSRDPrologue` creates a `PrecoloredSRegOp` for each initial SRD + (e.g., arg4 output buffer at s[36:39]) and fills it via `RawOp`s + (`s_mov_b64`, `s_mov_b32`). + +2. The SSA result of the `PrecoloredSRegOp` is mapped to the memref via + `mapper.mapValue()`. For some SRDs (data/scale buffers used inside the + loop), downstream SSA uses keep the value alive. For others (e.g., the + output buffer SRD), the only downstream references are `RawOp`s that + copy the physical registers directly (e.g., `s_mov_b64 s[80:81], s[36:37]` + in the epilogue). These `RawOp`s create no SSA uses. + +3. `CanonicalizerPass` (DCE) removes the unused `PrecoloredSRegOp`. + +4. `LinearScanPass` walks `PrecoloredSRegOp` ops to build `reservedSGPRs`. + Since the op was DCE'd, s[36:39] is NOT reserved. + +5. The linear scan allocator freely assigns s36/s37 to loop body temp + values (soffset, next-K computation), **clobbering the SRD base address**. + +6. After the loop, the epilogue `RawOp` copies garbage from s[36:37] to + s[80:81], forming a corrupted buffer descriptor → GPU fault. 
+ +**Why it only manifested with s_cselect:** The s_cselect change referenced +the SGPR validBytes value inside the loop, extending its liveness from the +prologue through the loop body. This increased SGPR pressure forced the +linear scan allocator to reclaim s[36:39] (which it thought was free). +In the baseline, the allocator had enough free SGPRs and happened to never +assign anything to s[36:39]. + +**Fix:** Add `DCEProtectOp` after each initial SRD's `PrecoloredSRegOp` +in `emitSRDPrologue`. This prevents DCE from removing the op, so +`LinearScanPass` always reserves the physical SRD registers. + +```cpp +auto srdReg = PrecoloredSRegOp::create(builder, loc, srdType, srdBase, 4); +// ... RawOps for s_mov_b64, s_mov_b32 ... +DCEProtectOp::create(builder, loc, srdReg); // prevent DCE +``` + +**Impact:** The fix is correct for all kernels, not just the s_cselect case. +Any kernel where SGPR pressure pushes the allocator into the SRD register +range would have hit this bug. The SCC/s_cselect work merely exposed it. + +## V_CNDMASK_B32 Constant Bus Issue + +Attempted to allow one SGPR in `V_CNDMASK_B32` (VOP3 constant bus allows one SGPR + VCC). This caused test failures — `v_cndmask_b32` VOP3 encoding with VCC counts as one constant bus slot, and adding an SGPR as src0/src1 requires a SECOND slot, exceeding the limit on some instructions. 
+ +## Files Modified + +| File | Changes | +|------|---------| +| `WaveASMInterfaces.h/.td` | SCCDef, SCCUse trait definitions | +| `WaveASMOps.td` | SCC-aware op classes, trait assignments | +| `SCCVerifier.cpp` | SCC verification pass | +| `ScopedCSE.cpp` | SCCUse exclusion from CSE | +| `Handlers.h` | emitOr, emitXor, emitMinI32/U32, emitMaxI32/U32, emitScalarCmp | +| `ArithHandlers.cpp` | SALU paths for OR/XOR/Min/Max/DivUI/RemUI, V_CMP constant bus | +| `AffineHandlers.cpp` | emitOr for non-overlapping Add | +| `AMDGPUHandlers.cpp` | s_cmp + s_cselect path in emitSrdNumRecords | +| `TranslateFromMLIR.cpp` | DCEProtect for initial SRD PrecoloredSRegOps | +| `Passes.td` | SCC verifier pass registration | +| `CMakeLists.txt` | SCCVerifier.cpp | +| `compile.py` | `--waveasm-scc-verifier` in pipeline | +| `waveasm_e2e.py` | Same | + +## Commits + +1. `d54817dc` — SCC verifier pass and trait infrastructure +2. `bc9726f8` — Migrate all SALU ops to SCC-aware trait classes +3. `d3e69752` — SALU paths for OR, XOR, DivUI, RemUI handlers +4. `3bcf0b85` — Eliminate unnecessary SGPR-to-VGPR coercions before V_CMP +5. `c22ce840` — emitScalarCmp helper and emitMin/emitMax SALU promotion +6. `8c65e6e0` — Move emitScalarCmp to Handlers.h +7. `963444f3` — SRD scalar select TODO with failure analysis diff --git a/docs/wave/canonical_ir_format.md b/docs/wave/canonical_ir_format.md new file mode 100644 index 0000000000..58d2f4084a --- /dev/null +++ b/docs/wave/canonical_ir_format.md @@ -0,0 +1,173 @@ +# Canonical Region IR Format + +This document describes the canonical region structure used by the Wave FX pipeline. + +It is intentionally limited in scope: it only covers nested region interfaces +and capture structure. It does not attempt to define a full canonical form for +all FX graph details such as node ordering, `vector_shapes`, or write +dependencies. 
+ +## Goal + +Wave is moving toward a single canonical region form so that: + +- nested regions have one stable structural representation in Python +- FX <-> MLIR roundtrips preserve region structure in a form that later Python + passes can continue to process +- for migration purposes, passes may declare which temporary non-canonical region view they need + without forcing the whole pipeline to stay in that view + +The canonical form is represented by `RegionFormat.ISOLATED` in +`wave_lang/kernel/wave/region_canonicalization.py`. + +## Terms + +- Outer source: a node defined outside a nested region but used by that region +- Local capture: the region-local representative of an outer source +- Capture signature: the ordered list stored on the parent `NestedRegionOp` in + `implicit_captures` +- Direct outer reference: a region node operand that points straight to a node + in the outer graph + +## Canonical Form: `ISOLATED` + +`ISOLATED` is the canonical/default region form. + +Structural invariants: + +- nested region bodies do not directly reference outer graph nodes +- every captured outer value used inside the region is represented by a region-local + `Placeholder` +- these capture placeholders form the leading non-`IterArg` input prefix of the + subgraph +- the parent `NestedRegionOp.implicit_captures` list is the authoritative + ordered capture signature: it defines *which* outer values are captured and in + *what order* +- each local capture placeholder carries a `meta["lifted"]` link to its outer + source. This per-placeholder metadata is derived from `implicit_captures` and + must stay consistent with it (the verifier checks this) + +In other words, the region interface is explicit and isolated from above. + +Schematic shape: + +```text +root graph: + %outer_a + %outer_b + %region = iterate(..., implicit_captures=[%outer_a, %outer_b], ...) 
+ +region subgraph: + %iter_arg0 = placeholder(iter arg) + %outer_a = placeholder(lifted from outer) + %outer_b = placeholder(lifted from outer) + ... +``` + +## Temporary Non-Canonical Forms + +Not all existing passes operate on `ISOLATED` yet. To migrate incrementally, +passes may request one of several temporary region views. The pass boundary +adapts into that form before the pass runs and, in the normal pipeline, +canonicalizes back to `ISOLATED` afterwards. + +### `LEGACY_PLACEHOLDERS` + +This is the older placeholder-based capture form still expected by some +pre-existing passes. + +Structural properties: + +- captured outer values are represented by region-local placeholders +- the mapping from a local placeholder back to its outer source may still be + recovered with ad-hoc conventions that pre-existing passes relied on: name + matching and positional fallback within the capture prefix (codified in + `_try_resolve_legacy_capture_source` in `region_canonicalization.py`) +- unlike `ISOLATED`, this form does not require `implicit_captures` plus + `meta["lifted"]` to be the sole authoritative description of the capture + interface + +This is a weaker contract than `ISOLATED`. A pass marked +`LEGACY_PLACEHOLDERS` may still reason about captures through placeholder +layout or legacy lookup behavior instead of relying only on the canonical +capture interface. An already-canonical region may also satisfy this weaker +contract, so the adapter can be a no-op on some graphs. + +This mode exists to support passes that still expect legacy placeholder +structure while the pipeline as a whole moves toward explicit canonical +captures. + +### `DIRECT_OUTER_REF` + +This is a legacy form where region bodies directly reference outer graph nodes. 
+ +Structural properties: + +- operands inside the region may point directly to outer nodes +- capture placeholders are removed or bypassed where possible +- the parent capture signature may still track those outer values, but the body + itself is not isolated from above + +This form is convenient for passes that want to inspect or mutate the original +outer values directly, especially around captured memory operands. + +### `SCHEDULE_SIGNATURE_PLACEHOLDERS` + +This is a hybrid legacy form used by scheduling-related passes. + +Structural properties (schedule-signature sources are the outer values that +define the region boundary from the scheduler's point of view, namely +outer-graph `Placeholder`s, i.e. kernel arguments, and `NewRegister`s): + +- placeholders are kept only for those schedule-signature sources +- non-signature captures are rewritten back to direct outer references +- the region mixes explicit placeholders for signature-defining values with + direct outer references for everything else + +In practice, the schedule-signature sources are the outer values that define +the region boundary from the scheduler's point of view, namely values such as +root placeholders and `NewRegister`s. + +## Why There Are Multiple Forms + +The long-term goal is for passes to converge on `ISOLATED`. + +The intermediate forms exist because rewriting every pass at once would be too +large and too risky. Instead: + +1. the pipeline keeps one canonical form +2. each pass declares the temporary form it currently expects +3. pass-boundary adapters convert into that form just for the duration of the + pass +4. the normal pipeline returns to canonical form afterwards + +This makes the migration incremental while still establishing one structural +source of truth. + +## Pass Contract + +Passes declare their required region form with `@requires_region_format(...)`. 
+ +The important contract is: + +- if a pass does not declare a region form, it is assumed to operate on the + canonical `ISOLATED` form +- in the normal pipeline, pass outputs are canonicalized back to `ISOLATED` +- white-box tests that want to inspect a temporary legacy form must request + `canonicalize_output=False` explicitly at the call site + +This keeps the default pipeline principled while still allowing tests to inspect +the raw intermediate structure when needed. + +## Non-Goals + +This document does not define: + +- a canonical ordering for all FX nodes +- a canonical form for `GetResult` materialization beyond what is required by + region structure +- a canonical form for `vector_shapes`, write dependencies, or other downstream + analysis state + +Those are separate concerns that may later build on top of this region-level +structural baseline. diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 0896395cc4..d97102576c 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -375,7 +375,7 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( eliminate_epilogue=True, ): """Preshuffle-B MXFP4 GEMM using C++ WaveASM backend.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(2, 2), reorder_workgroups=True) options.backend = "asm" options.use_buffer_ops = True options.wave_runtime = True @@ -394,7 +394,6 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( f"MXFP GEMM preshuffle-B 4-wave (WaveASM) epilogue elimination={eliminate_epilogue} PASSED" ) - def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( is_debug=False, shape=(1024, 1024, 8192), @@ -421,7 +420,39 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print("MXFP GEMM preshuffle-B 4-wave (WaveASM backend) test passed!") + 
def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm(
    is_debug=False,
    shape=(1024, 1024, 8192),
    block=(128, 256, 256),
    eliminate_epilogue=True,
):
    """Compile and run the preshuffle-B MXFP4 GEMM with dynamic M, N, K on the "asm" backend.

    Args:
        is_debug: When true, request IR printing after every pass
            (``print_ir_after = "all"``); otherwise nothing is requested.
        shape: Problem shape forwarded to the kernel builder and to the runner.
        block: Block sizes forwarded to the kernel builder.
        eliminate_epilogue: Forwarded both to the compile options and to the
            asymmetric schedule.
    """
    gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
        shape, block, wave_shape=(2, 2), reorder_workgroups=True
    )

    # Remove the compile-time substitutions for M, N and K and register them as
    # dynamic symbols, so the compiler does not specialize on the problem size.
    runtime_dims = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
    for dim in runtime_dims:
        del options.subs[dim]
    options.dynamic_symbols = runtime_dims

    # Backend selection and MLIR dump settings, applied in a fixed order.
    backend_settings = (
        ("use_buffer_ops", True),
        ("backend", "asm"),
        ("use_wave_asm_backend", True),
        ("wave_runtime", True),
        ("eliminate_epilogue", eliminate_epilogue),
        ("dump_intermediates", "build/intermediates/"),
        ("print_mlir_file", "gemm_mxfp4_dbuf_4wave_asymmetric.mlir"),
        ("print_mlir", True),
    )
    for attr_name, attr_value in backend_settings:
        setattr(options, attr_name, attr_value)

    schedule = get_mxfp4_asymmetric_schedule(
        eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
    )

    if is_debug:
        options.print_ir_after = "all"
    else:
        options.print_ir_after = []

    options = set_default_run_config(options)
    gemm = wave_compile(options, gemm, schedule)

    _run_mxfp_gemm_preshuffle(gemm, shape, all=True)
    print("MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!")
# Running totals of checks that matched / did not match; updated by check().
PASS = 0
FAIL = 0


def check(name, got, expected):
    """Record one comparison in the PASS/FAIL counters and report a mismatch.

    A ``bool`` result is compared to ``expected`` with ``==``. Any other result
    counts as matching when ``sympy.expand(got - expected)`` equals zero.

    Args:
        name: Label printed when the check fails.
        got: Value produced by the code under test.
        expected: Reference value.
    """
    global PASS, FAIL

    if isinstance(got, bool):
        matches = got == expected
    else:
        matches = sympy.expand(got - expected) == 0

    if matches:
        PASS += 1
        return

    FAIL += 1
    print(f" [FAIL] {name}")
    print(f" got: {got}")
    print(f" expected: {expected}")
    print()
linearize_dims ──") + +Ks = 8 * K +for flat_val in [0, 256, 512, 768, 1024]: + dim0 = floor(flat_val / Ks) + dim1 = Mod(flat_val, Ks) + result = linearize_dims([dim0, dim1], [Ks, Integer(1)]) + check(f"linearize_dims(flat={flat_val})", result, flat_val) + +E = 64 * Mod(idx0, 4) + 4 * Mod(idx1, 16) + 256 * floor(idx0 / 8) +dim0_sym = floor(E / Ks) +dim1_sym = Mod(E, Ks) +result_sym = linearize_dims([dim0_sym, dim1_sym], [Ks, Integer(1)]) +check("linearize_dims(symbolic flat)", mem_simplify(result_sym - E), 0) + +print(f" Section 2: {PASS - p1} passed") +p2 = PASS + +# ── 3. _eval_concrete_floor_mod ──────────────────────────────────────── +print("\n── 3. Concrete floor/Mod evaluation ──") + +check("floor(32/8)", _eval_concrete_floor_mod(floor(Integer(32) / 8)), Integer(4)) +check("floor(7/3)", _eval_concrete_floor_mod(floor(Integer(7) / 3)), Integer(2)) +check("Mod(256, 64)", _eval_concrete_floor_mod(Mod(256, 64)), Integer(0)) +check("Mod(10, 3)", _eval_concrete_floor_mod(Mod(10, 3)), Integer(1)) +check("Mod(0, K)", _eval_concrete_floor_mod(Mod(0, K)), Integer(0)) +check("floor(sym) unchanged", _eval_concrete_floor_mod(floor(K / 8)), floor(K / 8)) + +print(f" Section 3: {PASS - p2} passed") +p3 = PASS + +# ── 4. mem_simplify preserves already-simple expressions ─────────────── +print("\n── 4. Passthrough cases ──") + +check("integer", mem_simplify(Integer(42)), Integer(42)) +check("symbol", mem_simplify(K), K) +check("linear expr", mem_simplify(3 * K + 7), 3 * K + 7) +check("floor(sym/n) unchanged", mem_simplify(floor(K / 8)), floor(K / 8)) + +print(f" Section 4: {PASS - p3} passed") +p4 = PASS + +# ── 5. _expand_mod ───────────────────────────────────────────────────── +print("\n── 5. 
_expand_mod ──") + +check("expand_mod(Mod(32,K))", _expand_mod(Mod(32, K)), 32 - K * floor(32 / K)) +expr5 = 8 * K * floor(32 / K) + 8 * Mod(32, K) +check("expand_mod + expand cancels", expand(_expand_mod(expr5)), Integer(256)) + +print(f" Section 5: {PASS - p4} passed") +p5 = PASS + +# ── 6. The full scale-mapping simulation ─────────────────────────────── +print("\n── 6. Scale mapping addr simulation ──") + +b_scale_flat = ( + 256 * K * floor(idx1 / 32) + + 64 * Mod(idx0, 4) + + 4 * Mod(idx1, 16) + + 256 * floor(idx0 / 8) + + 2 * floor(Mod(idx0, 8) / 4) + + floor(Mod(idx1, 32) / 16) +) + +for iv_val, expected_addr in [(0, 0), (1, 256), (2, 512), (3, 768), (4, 1024)]: + dim0 = floor(b_scale_flat / Ks) + dim1 = Mod(b_scale_flat, Ks) + addr_expr = linearize_dims([dim0, dim1], [Ks, Integer(1)]) + addr_concrete = mem_simplify(addr_expr.subs({idx0: 8 * iv_val, idx1: 0})) + check(f"scale_addr(iv={iv_val})", addr_concrete, expected_addr) + +print(f" Section 6: {PASS - p5} passed") +p6 = PASS + +# ── 7. Mod auto-factoring ───────────────────────────────────────────── +print("\n── 7. Mod factoring (SymPy built-in) ──") + +for b_val in [32, 64, 128]: + m1 = Mod(8 * b_val, 8 * K) + m2 = 8 * Mod(b_val, K) + check(f"Mod({8*b_val}, 8K) == 8*Mod({b_val},K)", simplify(m1 - m2), 0) + +print(f" Section 7: {PASS - p6} passed") + +# ── Summary ───────────────────────────────────────────────────────────── +print("\n" + "=" * 70) +print(f"Results: {PASS} PASS, {FAIL} FAIL") +if FAIL: + print("*** SOME TESTS FAILED ***") +else: + print("All tests passed.") +print("=" * 70) + +exit(1 if FAIL else 0) diff --git a/gen-pytorch-rocm-requirements.py b/gen-pytorch-rocm-requirements.py new file mode 100755 index 0000000000..abf2652f3b --- /dev/null +++ b/gen-pytorch-rocm-requirements.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +"""Detect the installed ROCm version and emit a pytorch-rocm-requirements.txt +that points pip at a wheel source carrying matching ROCm builds of PyTorch. 
+ +Two wheel sources are tried, in order: + + 1. PyTorch Foundation index --index-url https://download.pytorch.org/whl/rocmX.Y + 2. AMD repo (flat listing) --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-X.Y/ + +Source (1) is preferred because it is a proper PIP package index; source (2) +is a flat directory of wheels that AMD publishes for every ROCm release. + +The script probes each URL with an HTTP HEAD request and picks the first that +responds. Use --offline to skip probing and always emit the AMD --find-links +URL (it covers every ROCm release). +""" + +from __future__ import annotations + +import argparse +import os +import re +import subprocess +import sys +from pathlib import Path +from urllib.error import URLError +from urllib.request import Request, urlopen + +PYTORCH_INDEX = "https://download.pytorch.org/whl/rocm{ver}" +AMD_FIND_LINKS = "https://repo.radeon.com/rocm/manylinux/rocm-rel-{ver}/" + +TORCH_SPEC = "torch>=2.6,<2.9" + +# --------------------------------------------------------------------------- +# ROCm version detection +# --------------------------------------------------------------------------- + + +def _read_version_file(rocm_root: Path) -> str | None: + p = rocm_root / ".info" / "version" + try: + return p.read_text().strip() + except OSError: + return None + + +def _detect_rocm_version() -> str: + """Return the installed ROCm version string (e.g. '7.2.0'). + + Search order: + 1. $ROCM_PATH/.info/version (or $ROCM_HOME) + 2. /opt/rocm/.info/version + 3. 
``hipconfig --version`` output + """ + for env in ("ROCM_PATH", "ROCM_HOME"): + val = os.environ.get(env) + if val: + ver = _read_version_file(Path(val)) + if ver: + return ver + + ver = _read_version_file(Path("/opt/rocm")) + if ver: + return ver + + try: + out = subprocess.check_output( + ["hipconfig", "--version"], text=True, stderr=subprocess.DEVNULL + ).strip() + if out: + return out + except (FileNotFoundError, subprocess.CalledProcessError): + pass + + raise SystemExit( + "error: cannot detect ROCm version.\n" + "Set ROCM_PATH or ensure /opt/rocm/.info/version exists." + ) + + +# --------------------------------------------------------------------------- +# Version string candidates +# --------------------------------------------------------------------------- + +_VER_RE = re.compile(r"^(\d+)\.(\d+)(?:\.(\d+))?$") + + +def _version_candidates(raw: str) -> list[str]: + """Return version strings to try, most-specific first. + + '7.2.0' -> ['7.2'] (patch 0 is always omitted) + '6.2.4' -> ['6.2.4', '6.2'] + '7.0.2' -> ['7.0.2', '7.0'] + """ + m = _VER_RE.match(raw) + if not m: + raise SystemExit(f"error: cannot parse ROCm version {raw!r}") + major, minor, patch = m.group(1), m.group(2), m.group(3) or "0" + short = f"{major}.{minor}" + if patch == "0": + return [short] + full = f"{major}.{minor}.{patch}" + return [full, short] + + +# --------------------------------------------------------------------------- +# URL probing +# --------------------------------------------------------------------------- + + +def _url_exists(url: str, timeout: float = 10) -> bool: + try: + req = Request(url, method="HEAD") + resp = urlopen(req, timeout=timeout) + return resp.status < 400 + except (URLError, OSError, ValueError): + return False + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + ap = 
argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "-o", + "--output", + default="-", + help="Output file (default: stdout)", + ) + ap.add_argument( + "--offline", + action="store_true", + help="Skip URL probing; always emit AMD --find-links URL", + ) + ap.add_argument( + "--rocm-version", + default=None, + help="Override ROCm version instead of auto-detecting", + ) + args = ap.parse_args() + + raw_ver = args.rocm_version or _detect_rocm_version() + candidates = _version_candidates(raw_ver) + print( + f"Detected ROCm {raw_ver}, candidate version tags: {candidates}", + file=sys.stderr, + ) + + # Try each candidate against each source. + sources: list[tuple[str, str]] = [] + for ver in candidates: + sources.append(("--index-url", PYTORCH_INDEX.format(ver=ver))) + for ver in candidates: + sources.append(("--find-links", AMD_FIND_LINKS.format(ver=ver))) + + chosen_flag = chosen_url = None + + if args.offline: + # In offline mode, pick the first AMD --find-links URL without probing. 
+ for flag, url in sources: + if flag == "--find-links": + chosen_flag, chosen_url = flag, url + break + else: + for flag, url in sources: + probe = url.rstrip("/") + "/" + print(f" probing {probe} ...", end=" ", file=sys.stderr, flush=True) + if _url_exists(probe): + print("ok", file=sys.stderr) + chosen_flag, chosen_url = flag, url + break + print("not found", file=sys.stderr) + + if chosen_url is None: + raise SystemExit(f"error: no PyTorch wheel source found for ROCm {raw_ver}") + + lines = [f"{chosen_flag} {chosen_url}\n", f"{TORCH_SPEC}\n"] + + if args.output == "-": + sys.stdout.writelines(lines) + else: + Path(args.output).write_text("".join(lines)) + print(f"Wrote {args.output}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/lit_tests/kernel/wave/barrier_strategies.py b/lit_tests/kernel/wave/barrier_strategies.py index 046468458b..be62dc0b01 100644 --- a/lit_tests/kernel/wave/barrier_strategies.py +++ b/lit_tests/kernel/wave/barrier_strategies.py @@ -179,7 +179,7 @@ def test_gemm(): set_post_expansion_indices(trace, constraints) tweak_index(graph) hoist_loop_invariant_ops(trace, constraints) - add_shared_memory_barriers(trace) + add_shared_memory_barriers(trace, canonicalize_output=False) print_trace(trace, False) # CHECK-LABEL: test_gemm @@ -231,8 +231,14 @@ def test_split_barriers(): tweak_index(graph) hoist_loop_invariant_ops(trace, constraints) schedule_graph(trace, constraints, True, enable_scheduling) - schedule_reordering(trace, constraints, enable_scheduling, use_global_to_shared) - add_shared_memory_barriers(trace, target="gfx1201") + schedule_reordering( + trace, + constraints, + enable_scheduling, + use_global_to_shared, + canonicalize_output=False, + ) + add_shared_memory_barriers(trace, target="gfx1201", canonicalize_output=False) print_trace(trace, False) # CHECK-LABEL: test_split_barriers @@ -363,7 +369,7 @@ def test_existing_barrier_not_duplicated(): # Now run barrier placement - it should detect the existing 
barrier # and NOT insert duplicates - add_shared_memory_barriers(trace) + add_shared_memory_barriers(trace, canonicalize_output=False) # Count barriers after barriers_after = count_barriers(graph) @@ -570,7 +576,7 @@ def test_memory_counter_wait_barrier_prevents_redundant_barrier(): # Now run barrier placement - it should detect the existing MemoryCounterWaitBarrier # and NOT insert an additional SharedMemoryBarrier (since synchronization is already provided) - add_shared_memory_barriers(trace) + add_shared_memory_barriers(trace, canonicalize_output=False) # Count barriers after mcw_barriers_after = count_memory_counter_wait_barriers(graph) diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 6ec00cd3b3..1b15518b4a 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -114,7 +114,7 @@ def test_read_write_equal_sizes(): expand_graph(trace, constraints) set_post_expansion_indices(trace, constraints) tweak_index(graph) - add_shared_memory_barriers(trace) + add_shared_memory_barriers(trace, canonicalize_output=False) print_trace(trace, False) # CHECK: %a # CHECK-NEXT: %c @@ -205,7 +205,7 @@ def test_gemm(): set_post_expansion_indices(trace, constraints) tweak_index(graph) hoist_loop_invariant_ops(trace, constraints) - add_shared_memory_barriers(trace) + add_shared_memory_barriers(trace, canonicalize_output=False) print_trace(trace, False) # Root graph: # CHECK: %a @@ -330,8 +330,14 @@ def test_split_barriers(): tweak_index(graph) hoist_loop_invariant_ops(trace, constraints) schedule_graph(trace, constraints, True, enable_scheduling) - schedule_reordering(trace, constraints, enable_scheduling, use_global_to_shared) - add_shared_memory_barriers(trace, target="gfx1201") + schedule_reordering( + trace, + constraints, + enable_scheduling, + use_global_to_shared, + canonicalize_output=False, + ) + add_shared_memory_barriers(trace, target="gfx1201", canonicalize_output=False) print_trace(trace, False) # 
Note: In pipelined loops, signal/wait pairs may have operations between them diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 96271b727a..5aa12b7fe6 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -900,6 +900,7 @@ def test( ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, }, canonicalize=True, + compile_to_mlir=True, ) options = set_default_compile_config(options) test = wave_compile(options, test) diff --git a/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py b/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py index 96e321d39a..a87157a0a0 100644 --- a/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py +++ b/lit_tests/kernel/wave/dynamic_shapes_preshuffle_mxfp4.py @@ -4,20 +4,16 @@ Test dynamic-shapes support for preshuffle-B MXFP4 GEMM. When M, N, K are dynamic (not substituted at compile time), the compiler -must emit a runtime guard (scf.if) that checks whether K is large enough -for the pipelined kernel path (prologue + scf.for + epilogue). If K is -too small, the else branch yields zero accumulators and falls through to -a non-pipelined epilogue loop. +emits a pipelined kernel with gather_to_lds prefetch and software-pipelined +scf.for loop. Index simplification (factoring out divisor-multiples from +floor/Mod expressions) enables the scheduler to prove the pipeline guard +is always satisfied, eliminating the scf.if entirely. Key structural invariants verified: 1. Function signature accepts dynamic index arguments for M, N, K. - 2. scf.if guard selects between pipelined and fallback paths. - 3. The pipelined "then" branch contains: - - gather_to_lds prefetch (prologue) - - scf.for main loop with amdgpu.scaled_mfma - - epilogue scaled_mfma after the loop - 4. The "else" branch yields zero accumulators (scf.yield with %cst). - 5. A second scf.for after the scf.if handles remaining K iterations. + 2. Prologue gather_to_lds prefetch loads. + 3. 
scf.for main loop with amdgpu.scaled_mfma. + 4. Epilogue scaled_mfma after the loop. """ from wave_lang.kernel.wave.compile import wave_compile @@ -55,29 +51,17 @@ def test_dynamic_preshuffle_b_mxfp4(): # 1. Dynamic index arguments for M, N, K in function signature. # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index) - # 2. scf.if guard: pipelined path vs fallback. - # CHECK: scf.if %{{.*}} -> - - # 3a. Pipelined "then" branch: prologue gather_to_lds prefetch. + # 2. Prologue gather_to_lds prefetch (no scf.if guard — simplification + # proves the pipeline guard is always satisfied). # CHECK: amdgpu.gather_to_lds # CHECK: amdgpu.gather_to_lds - # 3b. Pipelined "then" branch: main loop with scaled_mfma. + # 3. Main pipelined loop with scaled_mfma. # CHECK: scf.for # CHECK: amdgpu.scaled_mfma # CHECK: scf.yield - # 3c. Pipelined "then" branch: epilogue scaled_mfma after loop. - # CHECK: amdgpu.scaled_mfma - - # 4. Else branch: fallback yields zero accumulators. - # CHECK: } else { - # CHECK-NEXT: scf.yield %cst - - # 5. Non-pipelined epilogue scf.for after the scf.if. - # CHECK: scf.for - # CHECK: amdgpu.lds_barrier - # CHECK: amdgpu.gather_to_lds + # 4. Epilogue scaled_mfma after the loop. # CHECK: amdgpu.scaled_mfma @@ -111,36 +95,23 @@ def test_dynamic_preshuffle_b_mxfp4_eliminate_epilogue(): # 1. Dynamic index arguments for M, N, K in function signature. # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index) - # 2. scf.if guard: pipelined path vs fallback. - # CHECK: scf.if %{{.*}} -> - - # 3. Pipelined loop steps by 1 (no epilogue to peel off). + # 2. Pipelined loop steps by 1 (no epilogue to peel off). + # No scf.if guard — simplification proves it always satisfied. # CHECK: scf.for %{{.*}} = %c0 to %{{.*}} step %c1 - # 4. 
Loop carries shared-memory buffers as iter_args (epilogue folded in). + # 3. Loop carries shared-memory buffers as iter_args (epilogue folded in). # CHECK-SAME: memref<{{.*}}, #gpu.address_space> - # 5. OOB guard: arith.select chooses real validBytes vs 0 for out-of-range + # 4. OOB guard: arith.select chooses real validBytes vs 0 for out-of-range # iterations, so the hardware returns zeros on OOB loads. # CHECK: arith.select %{{.*}}, %c2147483646_i64, %c0_i64 : i64 - # 6. fat_raw_buffer_cast uses the dynamically selected validBytes. + # 5. fat_raw_buffer_cast uses the dynamically selected validBytes. # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) - # 7. scaled_mfma inside the pipelined loop. + # 6. scaled_mfma inside the pipelined loop. # CHECK: amdgpu.scaled_mfma - # 8. Loop body ends; no epilogue mfma between loop end and scf.yield. + # 7. Loop body ends; no epilogue mfma between loop end and scf.yield. # CHECK: scf.yield # CHECK-NEXT: } - - # 9. Then-branch yields loop results directly (no epilogue ops). - # CHECK: scf.yield - - # 10. Else branch: fallback yields zero accumulators. - # CHECK: } else { - # CHECK-NEXT: scf.yield %cst - - # 11. Remainder scf.for after the scf.if. 
- # CHECK: scf.for - # CHECK: amdgpu.scaled_mfma diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 59a0e512d8..344871e9a8 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -78,10 +78,10 @@ def test_read_write_equal_sizes(): idxc.subs = subs graph = read_write_same_size() idxc.finalize() - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %a @@ -161,10 +161,10 @@ def test_read_write(): idxc.subs = subs graph = read_write_different_dims() idxc.finalize() - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %a @@ -243,11 +243,11 @@ def test_write_in_iterate(): idxc.subs = subs graph = write_in_iterate() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: @@ -326,11 +326,11 @@ def 
test_no_writes(): idxc.subs = subs graph = no_writes() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) @@ -367,11 +367,11 @@ def test_gemm(): idxc.subs = subs graph = gemm() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: @@ -432,6 +432,7 @@ def test_gemm(): # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a + # CHECK-NEXT: %b # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %read_M:0_N:0_K:1 @@ -441,7 +442,6 @@ def test_gemm(): # CHECK-NEXT: %read_M:1_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) - # CHECK-NEXT: %b # CHECK-NEXT: %read_1_M:0_N:0_K:0 # CHECK-SAME: (%b, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %read_1_M:0_N:0_K:1 @@ -475,11 +475,11 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=acc_M:1_N:0_K:0 # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0 # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: read(memory=a, elements_per_thread=4, 
mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) @@ -559,11 +559,11 @@ def test_batched_gemm(): idxc.subs = subs graph = batched_gemm() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, 
canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: @@ -621,6 +621,7 @@ def test_batched_gemm(): # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a + # CHECK-NEXT: %b # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %read_M:0_N:0_K:1 @@ -630,7 +631,6 @@ def test_batched_gemm(): # CHECK-NEXT: %read_M:1_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) - # CHECK-NEXT: %b # CHECK-NEXT: %read_1_M:0_N:0_K:0 # CHECK-SAME: (%b, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %read_1_M:0_N:0_K:1 @@ -664,11 +664,11 @@ def test_batched_gemm(): # CHECK-NEXT: placeholder(_name=acc_M:1_N:0_K:0 # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0 # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: 
ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) @@ -742,11 +742,11 @@ def test_gemm_non_direct_acc(): idxc.subs = subs graph = gemm_non_direct_acc() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %add_M:0_N:0_K:0 @@ -811,11 +811,11 @@ def test_tiled_max(): idxc.subs = subs graph = tiled_max() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, 
constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: max(arg=[read_M:0_K:0, read_M:0_K:1, read_M:0_K:2, read_M:0_K:3], init=acc_M:0_K:0 @@ -845,11 +845,11 @@ def test_gemm_iterate_expansion_only(): idxc.subs = subs graph = gemm() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: @@ -879,12 +879,12 @@ def test_gemm_iterate_expansion_only(): # CHECK: %acc_M:0_N:0_K:0 # CHECK-NEXT: %a + # CHECK-NEXT: %b # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) - # CHECK-NEXT: %b # CHECK-NEXT: %read_1_M:0_N:0_K:0 # CHECK-SAME: (%b, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %read_1_M:0_N:0_K:1 @@ -901,9 +901,9 @@ def test_gemm_iterate_expansion_only(): # CHECK-NEXT: placeholder(_name=acc_M:0_N:0_K:0 # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-NEXT: 
placeholder(_name=b # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) @@ -992,11 +992,11 @@ def test_attention(): idxc.subs = subs graph = attention() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) @@ -1089,10 +1089,10 @@ def py_arithmetic_different_dims(): idxc.subs = subs graph = py_arithmetic_different_dims() idxc.finalize() - add_get_results(graph) - infer_types(graph) - set_node_indices(graph, constraints) - expand_graph(graph, constraints) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %a @@ -1196,17 +1196,19 @@ def test_chained_gemm_32x32x8(): idxc.subs = subs graph = chained_gemm_32x32x8() idxc.finalize() - initialize_iter_args(graph) - add_get_results(graph) - infer_types(graph) - 
set_node_indices(graph, constraints) - expand_graph(graph, constraints) + initialize_iter_args(graph, canonicalize_output=False) + add_get_results(graph, canonicalize_output=False) + infer_types(graph, canonicalize_output=False) + set_node_indices(graph, constraints, canonicalize_output=False) + expand_graph(graph, constraints, canonicalize_output=False) set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %acc_M:0_N:0_K2:0 - # CHECK: %register # CHECK: %q + # CHECK: %k + # CHECK: %v + # CHECK: %register # CHECK: %read_M:0_K2:0_K1:0 # CHECK-SAME: (args = (%q, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK: %read_M:0_K2:0_K1:1 @@ -1215,7 +1217,6 @@ def test_chained_gemm_32x32x8(): # CHECK-SAME: (args = (%q, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK: %read_M:0_K2:0_K1:3 # CHECK-SAME: (args = (%q, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) - # CHECK: %k # CHECK: %read_1_shared_M:0_K2:0_K1:0 # CHECK-SAME: (args = (%k, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK: %read_1_shared_M:0_K2:0_K1:1 @@ -1236,7 +1237,6 @@ def test_chained_gemm_32x32x8(): # CHECK-SAME: (args = (%mma_M:0_K2:0_K1:3, [B, M, K2]) # CHECK: %cast_M:0_K2:0 # CHECK-SAME: (args = (%permute_M:0_K2:0, f16) - # CHECK: %v # CHECK: %read_2_shared_M:0_N:0_K2:0 # CHECK-SAME: (args = (%v, 4, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK: %read_2_shared_M:0_N:0_K2:1 diff --git a/lit_tests/kernel/wave/hoisting.py b/lit_tests/kernel/wave/hoisting.py index 0059c3733b..2ed05133b6 100644 --- a/lit_tests/kernel/wave/hoisting.py +++ b/lit_tests/kernel/wave/hoisting.py @@ -94,14 +94,14 @@ def loop(acc: tkl.Register[M, N, tkl.f32]): } trace = simple_kernel() idxc.finalize() - initialize_iter_args(trace) - add_get_results(trace) - infer_types(trace) - set_node_indices(trace, constraints) - expand_graph(trace, constraints) + initialize_iter_args(trace, canonicalize_output=False) + 
add_get_results(trace, canonicalize_output=False) + infer_types(trace, canonicalize_output=False) + set_node_indices(trace, constraints, canonicalize_output=False) + expand_graph(trace, constraints, canonicalize_output=False) set_post_expansion_indices(trace, constraints) remove_chained_getresult(trace) - hoist_loop_invariant_ops(trace, constraints) + hoist_loop_invariant_ops(trace, constraints, canonicalize_output=False) print("=== Root Graph ===") print_trace(trace) @@ -123,11 +123,12 @@ def loop(acc: tkl.Register[M, N, tkl.f32]): # CHECK: %write_M:0_N:0_K:0 # CHECK: === Iterate Subgraph === - # CHECK: %acc_M:0_N:0_K:0 - # CHECK: %a - # CHECK: %read_1_M:0_N:0_K:0 - # CHECK: %b - # CHECK: %read_2_M:0_N:0_K:0 - # CHECK: %add_M:0_N:0_K:0 - # CHECK: %mma_M:0_N:0_K:0 - # CHECK: %add_1_M:0_N:0_K:0 + # CHECK: %acc_M:0_N:0_K:0 : + # CHECK: %a : + # CHECK: %b : + # CHECK: %read_1_M:0_N:0_K:0 : + # CHECK: %read_2_M:0_N:0_K:0 : + # CHECK: %add_M:0_N:0_K:0 : + # CHECK: %mma_M:0_N:0_K:0 : + # CHECK: %add_1_M:0_N:0_K:0 : + # CHECK: return [add_1_M:0_N:0_K:0] diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 5afd85c1b3..4bcac7dc5b 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -93,12 +93,12 @@ def test_gemm(): idxc.subs = subs trace: CapturedTrace = gemm() idxc.finalize() - initialize_iter_args(trace) - add_get_results(trace) - infer_types(trace) + initialize_iter_args(trace, canonicalize_output=False) + add_get_results(trace, canonicalize_output=False) + infer_types(trace, canonicalize_output=False) promote_placeholders(trace, constraints) - set_node_indices(trace, constraints) - expand_graph(trace, constraints) + set_node_indices(trace, constraints, canonicalize_output=False) + expand_graph(trace, constraints, canonicalize_output=False) set_post_expansion_indices(trace, constraints) hoist_loop_invariant_ops(trace, constraints) 
minimize_global_loads(trace, constraints) @@ -268,6 +268,8 @@ def test_gemm(): # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a # CHECK-NEXT: %b + # CHECK-NEXT: %allocate + # CHECK-NEXT: %allocate_1 # CHECK-NEXT: %read_37 # CHECK-SAME: (%a, 8, None, (), None, MemoryAccessFlags.NONE, None, None, None) # CHECK-NEXT: %write_18 @@ -325,6 +327,8 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b, _type=Memory[N, K].of(f16)) + # CHECK-NEXT: placeholder(_name=allocate, _type=Memory[M, K].of(f16)) + # CHECK-NEXT: placeholder(_name=allocate_1, _type=Memory[N, K].of(f16)) # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) # CHECK-NEXT: write(register_=read_37, memory=allocate, elements_per_thread=8, @@ -340,7 +344,7 @@ def test_gemm(): # CHECK-NEXT: read(memory=b, elements_per_thread=8, # CHECK-SAME: index={N: $WG1*BLOCK_N + BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) # CHECK-NEXT: write(register_=read_40, memory=allocate_1, elements_per_thread=8, - # CHECK-SMAE: index={N: BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64), K: 8*(Mod($T0, 8)) : 8 : 1}) + # CHECK-SAME: index={N: BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1}) # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=allocate_1, 
elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) diff --git a/lit_tests/kernel/wave/iteration.py b/lit_tests/kernel/wave/iteration.py index 383759639a..97973a7a0b 100644 --- a/lit_tests/kernel/wave/iteration.py +++ b/lit_tests/kernel/wave/iteration.py @@ -393,16 +393,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: idxc.subs = subs trace: CapturedTrace = iterated_gemm() idxc.finalize() - initialize_iter_args(trace) - add_get_results(trace) - infer_types(trace) + initialize_iter_args(trace, canonicalize_output=False) + add_get_results(trace, canonicalize_output=False) + infer_types(trace, canonicalize_output=False) promote_placeholders(trace, constraints) - set_node_indices(trace, constraints) - expand_graph(trace, constraints) + set_node_indices(trace, constraints, canonicalize_output=False) + expand_graph(trace, constraints, canonicalize_output=False) set_post_expansion_indices(trace, constraints) hoist_loop_invariant_ops(trace, constraints) minimize_global_loads(trace, constraints) - apply_shared_memory_indexing_corrections(trace, constraints) + apply_shared_memory_indexing_corrections( + trace, constraints, canonicalize_output=False + ) # Check the graph before unrolling # Find iterate and unroll @@ -413,8 +415,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print_graph(trace.get_subgraph(iterate.subgraph_name)) # CHECK: placeholder - # CHECK-NEXT: placeholder - # CHECK-NEXT: placeholder # CHECK-NEXT: [read] # CHECK-NEXT: [write] # CHECK-NEXT: [read] @@ -434,8 +434,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # TODO: Check that the bounds are correct, and steps # CHECK: placeholder - # CHECK-NEXT: placeholder - # CHECK-NEXT: placeholder # CHECK-NEXT: [read] # CHECK-NEXT: [write] # CHECK-NEXT: [read] @@ 
-460,8 +458,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: assert iterate.step == 4 # step multiplied again (2 * 2 = 4) # CHECK: placeholder - # CHECK-NEXT: placeholder - # CHECK-NEXT: placeholder # CHECK-NEXT: [read] # CHECK-NEXT: [write] # CHECK-NEXT: [read] diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index cd20c8a715..b51a051807 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -108,7 +108,7 @@ def test_gemm(): apply_shared_memory_indexing_corrections(trace, constraints) if visualize: visualize_graph(trace.get_subgraph("region_0"), "after.png") - add_shared_memory_barriers(trace) + add_shared_memory_barriers(trace, canonicalize_output=False) print_trace(trace) # Root graph: # CHECK: %a diff --git a/lit_tests/kernel/wave/mlir_to_fx.py b/lit_tests/kernel/wave/mlir_to_fx.py index 6662569474..4961bf3780 100644 --- a/lit_tests/kernel/wave/mlir_to_fx.py +++ b/lit_tests/kernel/wave/mlir_to_fx.py @@ -405,11 +405,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: errors = error_diagnostics(fx_diags) assert errors == [], f"unexpected errors: {errors}" - # Collect address space symbols from Memory-typed placeholders (each - # corresponds to a distinct function argument). + # Collect address space symbols from root Memory placeholders only. + # Canonical MLIR import reconstructs lifted capture placeholders in + # subgraphs, and those should intentionally reuse the same unresolved + # address space symbol as their captured root argument. 
placeholder_addrs = [ node.type.address_space - for node in fx_trace.walk(lambda n: n) + for node in fx_trace.get_root_graph().nodes if isinstance(get_custom(node), Placeholder) and node.type is not None and issubclass(node.type, Memory) diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 4a30a3e11c..fbdb9a420a 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -181,7 +181,7 @@ def test_gemm(): infer_types(trace) for read_node in read_nodes: promote_node(read_node, None, SHARED_ADDRESS_SPACE, constraints) - hoist_loop_invariant_ops(trace, constraints) + hoist_loop_invariant_ops(trace, constraints, canonicalize_output=False) print_trace(trace, False) # Root graph: # CHECK: %a diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index 5f7cf55bba..db7ad8bbe0 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -108,10 +108,7 @@ def test_gemm_pipelined(): minimize_global_loads(trace, constraints) apply_shared_memory_indexing_corrections(trace, constraints) schedule_graph( - trace, - constraints, - True, - SchedulingType.MODULO, + trace, constraints, True, SchedulingType.MODULO, canonicalize_output=False ) print_subgraph(trace, "pipelined_iterate", False) diff --git a/lit_tests/kernel/wave/test_combine_indices_no_mutation.py b/lit_tests/kernel/wave/test_combine_indices_no_mutation.py new file mode 100644 index 0000000000..9c0917ce15 --- /dev/null +++ b/lit_tests/kernel/wave/test_combine_indices_no_mutation.py @@ -0,0 +1,57 @@ +# RUN: python %s +# Test that combine_indices does not mutate its first argument (thread_independent_index). +# Without deepcopy in combine_indices, the same IndexSequence instances can be shared +# across nodes during propagation, and in-place mutation would change other nodes' indices. 
+ +from wave_lang.kernel._support.indexing import IndexSequence, index_symbol +from wave_lang.kernel.wave.analysis.index_sequence_analysis import combine_indices + + +def test_combine_indices_does_not_mutate_input(): + # Use a symbol as dimension key (same type as in real traces). + dim = index_symbol("M") + thread_independent = {dim: IndexSequence(0, 1, 1)} + thread_dependent = {dim: IndexSequence(10, 5, 2)} + + result = combine_indices(thread_independent, thread_dependent) + + # Result should have combined values. + assert result[dim].start == 10 + assert result[dim].size == 5 + assert result[dim].stride == 2 + + # Input must be unchanged (deepcopy in combine_indices prevents in-place mutation). + assert thread_independent[dim].start == 0 + assert thread_independent[dim].size == 1 + assert thread_independent[dim].stride == 1 + + +def test_combine_indices_shared_input_unchanged_after_second_call(): + # Simulate two nodes sharing the same index dict (e.g. same reference passed + # via worklist). Second combine_indices must not mutate the first result. + dim = index_symbol("K") + shared_base = {dim: IndexSequence(0, 1, 1)} + + result1 = combine_indices(shared_base, {dim: IndexSequence(2, 3, 1)}) + result2 = combine_indices(shared_base, {dim: IndexSequence(20, 10, 2)}) + + # Each result is independent. + assert ( + result1[dim].start == 2 and result1[dim].size == 3 and result1[dim].stride == 1 + ) + assert ( + result2[dim].start == 20 + and result2[dim].size == 10 + and result2[dim].stride == 2 + ) + + # Shared input is still unchanged. 
+ assert shared_base[dim].start == 0 + assert shared_base[dim].size == 1 + assert shared_base[dim].stride == 1 + + +if __name__ == "__main__": + test_combine_indices_does_not_mutate_input() + test_combine_indices_shared_input_unchanged_after_second_call() + print("All tests passed.") diff --git a/requirements-iree-pinned.txt b/requirements-iree-pinned.txt index 282242daa5..e188106793 100644 --- a/requirements-iree-pinned.txt +++ b/requirements-iree-pinned.txt @@ -7,5 +7,5 @@ # Uncomment to skip versions from PyPI (so _only_ nightly versions). # --no-index -iree-base-compiler==3.11.0rc20260305 -iree-base-runtime==3.11.0rc20260305 +iree-base-compiler==3.11.0 +iree-base-runtime==3.11.0 diff --git a/tests/kernel/e2e/test_copy.py b/tests/kernel/e2e/test_copy.py index 42e9174347..fb0347124e 100644 --- a/tests/kernel/e2e/test_copy.py +++ b/tests/kernel/e2e/test_copy.py @@ -17,10 +17,31 @@ from wave_lang.kernel.wave.utils.run_utils import set_default_run_config from wave_lang.kernel.wave.utils.torch_utils import device_randn, device_zeros -from ..common.utils import param_bool, require_e2e, use_water_backend_bool +from ..common.utils import ( + param_bool, + require_cdna4, + require_e2e, + use_water_backend_bool, +) from ._test_util import get_test_shapes +def _get_waveasm_test_shapes(test_name: str, extra_xfails: set): + """Wrap get_test_shapes with xfail markers for unsupported waveasm shapes.""" + xfails = extra_xfails + shapes = get_test_shapes(test_name) + return [ + ( + pytest.param( + s, marks=pytest.mark.xfail(reason="not yet supported in waveasm") + ) + if s in xfails + else s + ) + for s in shapes + ] + + def get_copy_template( shape: tuple[int, int], use_dynamic_dims: bool = False, @@ -109,6 +130,69 @@ def test_copy( assert_close(a, b) +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape", + _get_waveasm_test_shapes( + "test_copy", + # (111, 813): requires vector constant scalarization for bounds checks. 
+ extra_xfails={(111, 813)}, + ), +) +def test_copy_water_waveasm( + shape: tuple[int, int], + run_bench: bool, +) -> None: + """Test copy kernel through the water+waveasm pipeline (LLVM dialect input).""" + options, test = get_copy_template( + shape, + run_bench=run_bench, + use_water_backend=True, + use_buffer_ops=True, + ) + options = set_default_run_config(options) + options.backend = "asm" + test = wave_compile(options, test) + + a = device_randn(shape, dtype=torch.float16) + b = device_zeros(shape, dtype=torch.float16) + test(a, b) + assert_close(a, b) + + +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape", + _get_waveasm_test_shapes( + "test_copy", + # Require vector constant scalarization for bounds checks. + extra_xfails={(111, 813), (1, 128), (256, 128), (256, 256), (256, 1024)}, + ), +) +def test_dynamic_copy_water_waveasm( + shape: tuple[int, int], + run_bench: bool, +) -> None: + """Test dynamic copy kernel through the water+waveasm pipeline.""" + options, test = get_copy_template( + shape, + run_bench=run_bench, + use_water_backend=True, + use_buffer_ops=True, + use_dynamic_dims=True, + ) + options = set_default_run_config(options) + options.backend = "asm" + test = wave_compile(options, test) + + a = device_randn(shape, dtype=torch.float16) + b = device_zeros(shape, dtype=torch.float16) + test(a, b) + assert_close(a, b) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") diff --git a/tests/kernel/runtime/cache_test.py b/tests/kernel/runtime/cache_test.py index 487ecef087..88caebd940 100644 --- a/tests/kernel/runtime/cache_test.py +++ b/tests/kernel/runtime/cache_test.py @@ -572,6 +572,7 @@ def testSameConfigDifferentFreeVar(tmp_path, mfma_variant): output = device_zeros(o_shape, dtype=torch.float32) # TODO: Add variant of non-transposed V attention kernel. 
non_causal_mb = base_attention(q, k, v.permute([0, 2, 1]), output) + torch.cuda.synchronize() assert ( cache_manager.cache_misses == 1 and cache_manager.cache_hits == 0 ), "Expected first call to not be cached." @@ -596,13 +597,13 @@ def testSameConfigDifferentFreeVar(tmp_path, mfma_variant): ) options = set_default_run_config(options) causal_attention = wave_compile(options, causal_attention) - q = device_randn(q_shape, dtype=torch.float16) k = device_randn(k_shape, dtype=torch.float16) v = device_randn(v_shape, dtype=torch.float16) output = device_zeros(o_shape, dtype=torch.float32) # TODO: Add variant of non-transposed V attention kernel. causal_mb = causal_attention(q, k, v.permute([0, 2, 1]), output) + torch.cuda.synchronize() assert ( cache_manager.cache_misses == 2 and cache_manager.cache_hits == 0 ), "Expected to be cached despite same config, since it has different values for is_causal." @@ -780,6 +781,7 @@ def double_kernel(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): @require_e2e @require_cache @require_cdna3 +@pytest.mark.skip(reason="Crashes and/or produces incorrect results.") def testAsmBackendCache(tmp_path): """Test that ASM backend caching works correctly.""" reset_cache_manager(tmp_path) @@ -835,6 +837,7 @@ def simple_copy( # First compilation - should be a cache miss kernel1 = wave_compile(options, simple_copy) kernel1(a, b) + assert_close(a, b) assert ( cache_manager.cache_misses == 1 and cache_manager.cache_hits == 0 @@ -851,6 +854,7 @@ def simple_copy( # Second compilation - should be a cache hit kernel2 = wave_compile(options, simple_copy) kernel2(a, b) + assert_close(a, b) assert ( cache_manager.cache_misses == 1 and cache_manager.cache_hits == 1 diff --git a/tests/kernel/test_water.py b/tests/kernel/test_water.py index 8d797977f4..21ab345ff2 100644 --- a/tests/kernel/test_water.py +++ b/tests/kernel/test_water.py @@ -46,7 +46,7 @@ def test_apply_water_middle_end_passes_unavailable(self): def 
test_apply_water_middle_end_passes_success(self): """Test apply_water_middle_end_passes with mocked subprocess.""" - mlir_input = "normalform.module [#wave.normal_form] {}" + mlir_input = "normalform.module [#wave.normal_form, #wave.normal_form] {}" expected_output = "normalform.module [] {}" with patch("wave_lang.kernel.wave.water.get_water_opt") as mock_get_water_opt: @@ -93,7 +93,7 @@ class TestWaterLoweringIntegration: def test_lowering_passes(self): # Test with simple Wave dialect operations - just register and add wave_mlir = """ - normalform.module [#wave.normal_form] { + normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form] { func.func @test_kernel() attributes {wave.hyperparameters = #wave.hyperparameters<{}>, wave.constraints = []} { %cst = arith.constant 0.0 : f32 %lhs = wave.register %cst : vector<4xf32> diff --git a/tests/kernel/wave/asm/test_waveasm_e2e.py b/tests/kernel/wave/asm/test_waveasm_e2e.py index 4e2dec946e..bdffdef00a 100644 --- a/tests/kernel/wave/asm/test_waveasm_e2e.py +++ b/tests/kernel/wave/asm/test_waveasm_e2e.py @@ -1238,10 +1238,9 @@ def _dbuf_mxfp4_helper( dynamic_values = {M: m, N: n} del options.subs[M] del options.subs[N] - if not use_schedule: - dynamic_symbols.append(K) - dynamic_values[K] = k - del options.subs[K] + dynamic_symbols.append(K) + dynamic_values[K] = k + del options.subs[K] options.dynamic_symbols = dynamic_symbols # Generate MXFP4 inputs and reference output @@ -1262,12 +1261,8 @@ def _dbuf_mxfp4_helper( "amdgpu.scaled_mfma" in kernel_info.mlir_text ), "Expected amdgpu.scaled_mfma operation in MLIR" if dynamic_dims: - if use_schedule: - expected_idx = r"function_type = \([^)]*index, index\) -> \(\)" - expected_msg = "M and N" - else: - expected_idx = r"function_type = \([^)]*index, index, index\) -> \(\)" - expected_msg = "M, N, and K" + expected_idx = r"function_type = \([^)]*index, index, index\) -> \(\)" + expected_msg = "M, N, and K" assert re.search(expected_idx, kernel_info.mlir_text), 
( f"Expected dynamic-dims MLIR signature to carry trailing dynamic " f"{expected_msg} index arguments, got:\n{kernel_info.mlir_text[:400]}" @@ -1389,20 +1384,34 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend( if block_id == "256x224x256" and use_schedule: pytest.xfail("C++ ASM backend exceeds VGPR limit with scheduled pipeline") - # VGPR overflow: 224x160x256 without epilogue elimination and with - # scheduled pipeline exceeds the 256 VGPR hardware limit. - if block_id == "224x160x256" and use_schedule and not eliminate_epilogue: - pytest.xfail( - "C++ ASM backend exceeds VGPR limit with scheduled pipeline " - "(ee=False) for 224x160x256" - ) - # VGPR overflow: 256x160x256 without epilogue elimination and with - # scheduled pipeline exceeds the 256 VGPR hardware limit. - if block_id == "256x160x256" and use_schedule and not eliminate_epilogue: - pytest.xfail( - "C++ ASM backend exceeds VGPR limit with scheduled pipeline " - "(ee=False) for 256x160x256" - ) + # VGPR overflow: 224x160x256 scheduled pipeline exceeds 256 VGPR limit + # without epilogue elimination; with ee=True it fits for static dims + # but dynamic dims adds enough extra VGPRs to overflow again. + if block_id == "224x160x256" and use_schedule: + if not eliminate_epilogue: + pytest.xfail( + "C++ ASM backend exceeds VGPR limit with scheduled pipeline " + "(ee=False) for 224x160x256" + ) + elif dynamic_dims: + pytest.xfail( + "C++ ASM backend exceeds VGPR limit with ee=True + dynamic " + "dims for 224x160x256" + ) + # VGPR overflow: 256x160x256 scheduled pipeline exceeds 256 VGPR limit + # without epilogue elimination; with ee=True it fits for static dims + # but dynamic dims adds enough extra VGPRs to overflow again. 
+ if block_id == "256x160x256" and use_schedule: + if not eliminate_epilogue: + pytest.xfail( + "C++ ASM backend exceeds VGPR limit with scheduled pipeline " + "(ee=False) for 256x160x256" + ) + elif dynamic_dims: + pytest.xfail( + "C++ ASM backend exceeds VGPR limit with ee=True + dynamic " + "dims for 256x160x256" + ) # VGPR overflow for 256x192x256: ee=True reduces register pressure # enough to pass with static dims; ee=False and dynamic dims still overflow. diff --git a/tests/unittests/index_mapping_simplify_test.py b/tests/unittests/index_mapping_simplify_test.py new file mode 100644 index 0000000000..b11d18753a --- /dev/null +++ b/tests/unittests/index_mapping_simplify_test.py @@ -0,0 +1,131 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Tests for IndexMapping flat//D and flat%D simplification.""" + +import sympy +import pytest + +import wave_lang.kernel.lang as tkl +from wave_lang.kernel.lang.wave_types import IndexMapping +from wave_lang.kernel.wave.index_mapping_simplify import ( + simplify_index_mapping, + _get_iterator_bounds, + _expr_bounds_with_iters, +) + +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K + + +class TestSimplifyIndexMapping: + def test_flat_div_mod_same_dim(self): + """flat = i0*K + i1, inputs={M: flat//K, K: flat%K} -> {M: i0, K: i1}.""" + i0 = IndexMapping.iterator(0) + i1 = IndexMapping.iterator(1) + + flat = i0 * K + i1 + m = IndexMapping( + num_iterators=2, + inputs={M: flat // K, K: flat % K}, + outputs={M: i0, K: i1}, + ) + + m_new, changed = simplify_index_mapping(m) + assert changed + assert m_new.input_mapping[M] == i0 + assert m_new.input_mapping[K] == i1 + + def test_flat_div_mod_with_addend(self): + """flat//K added to another term: N: (i0//16)*16 + flat//K.""" + i0 = IndexMapping.iterator(0) + i1 = IndexMapping.iterator(1) + + flat = i0 * K + i1 + m 
= IndexMapping( + num_iterators=2, + inputs={N: (i0 // 16) * 16 + flat // K, K: flat % K}, + outputs={N: i0, K: i1}, + ) + + m_new, changed = simplify_index_mapping(m) + assert changed + # N should be (i0//16)*16 + i0 (the flat//K simplified to i0) + expected_n = (i0 // 16) * 16 + i0 + assert sympy.simplify(m_new.input_mapping[N] - expected_n) == 0 + assert m_new.input_mapping[K] == i1 + + def test_no_simplification_when_bounds_unknown(self): + """When divisor doesn't match iteration dimension, no simplification.""" + i0 = IndexMapping.iterator(0) + i1 = IndexMapping.iterator(1) + D = sympy.Symbol("D", integer=True, positive=True) + + flat = i0 * D + i1 + m = IndexMapping( + num_iterators=2, + inputs={M: flat // D, K: flat % D}, + outputs={M: i0, K: i1}, + ) + + m_new, changed = simplify_index_mapping(m) + # D doesn't match any iteration dimension (M or K), so we can't + # prove i1 < D. + assert not changed + + def test_no_simplification_b_data_preshuffle(self): + """B-data preshuffle: within_nblk can exceed K_PACKED for general K.""" + n_it = IndexMapping.iterator(0) + k_it = IndexMapping.iterator(1) + + within_nblk = ( + (k_it // 32) * 512 + + ((k_it // 16) % 2) * 256 + + (n_it % 16) * 16 + + k_it % 16 + ) + K_PACKED = K // 2 + + m = IndexMapping( + num_iterators=2, + inputs={ + N: (n_it // 16) * 16 + within_nblk // K_PACKED, + K: within_nblk % K_PACKED, + }, + outputs={N: n_it, K: k_it}, + ) + + m_new, changed = simplify_index_mapping(m) + assert not changed + + +class TestExprBoundsWithIters: + def test_iterator_bounds(self): + i0 = IndexMapping.iterator(0) + i1 = IndexMapping.iterator(1) + bounds = {i0: (sympy.Integer(0), sympy.Integer(15)), + i1: (sympy.Integer(0), sympy.Integer(63))} + + assert _expr_bounds_with_iters(i0, bounds) == (0, 15) + assert _expr_bounds_with_iters(i1, bounds) == (0, 63) + + def test_within_nblk_bounds(self): + """within_nblk for tile [0,15]x[0,63] is bounded to [0,1023].""" + n_it = IndexMapping.iterator(0) + k_it = 
IndexMapping.iterator(1) + bounds = {n_it: (sympy.Integer(0), sympy.Integer(15)), + k_it: (sympy.Integer(0), sympy.Integer(63))} + + within_nblk = ( + (k_it // 32) * 512 + + ((k_it // 16) % 2) * 256 + + (n_it % 16) * 16 + + k_it % 16 + ) + result = _expr_bounds_with_iters(within_nblk, bounds) + assert result is not None + assert result[0] == 0 + assert result[1] == 1023 diff --git a/tests/unittests/simplify_floordiv_test.py b/tests/unittests/simplify_floordiv_test.py new file mode 100644 index 0000000000..c18524f36a --- /dev/null +++ b/tests/unittests/simplify_floordiv_test.py @@ -0,0 +1,147 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Tests for the floordiv/mod divisor-splitting simplification in symbol_utils.""" + +import pytest +import sympy + +from wave_lang.kernel.wave.utils.symbol_utils import ( + simplify, + _split_sum_by_divisibility, + _is_provably_divisible, +) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +D = sympy.Symbol("D", integer=True, positive=True) +x = sympy.Symbol("x", integer=True, nonnegative=True) +y = sympy.Symbol("y", integer=True, nonnegative=True) +z = sympy.Symbol("z", integer=True, nonnegative=True) + + +# ── _is_provably_divisible ─────────────────────────────────────────────────── + + +class TestIsProvablyDivisible: + def test_zero(self): + assert _is_provably_divisible(sympy.Integer(0), D) + + def test_constant_divisible(self): + assert _is_provably_divisible(sympy.Integer(12), sympy.Integer(4)) + + def test_constant_not_divisible(self): + assert not _is_provably_divisible(sympy.Integer(13), sympy.Integer(4)) + + def test_symbolic_mul_factor(self): + assert _is_provably_divisible(3 * D * x, D) + + def test_symbolic_single_factor(self): + assert _is_provably_divisible(D * x, D) + + def 
test_sum_not_divisible(self): + # x + D is a sum, not a product — not detected. + assert not _is_provably_divisible(x + D, D) + + def test_no_divisor_factor(self): + assert not _is_provably_divisible(x * y, D) + + def test_numeric_mul_factor(self): + # 256*x is divisible by 8 because 256 % 8 == 0. + assert _is_provably_divisible(256 * x, sympy.Integer(8)) + + def test_numeric_mul_not_divisible(self): + assert not _is_provably_divisible(7 * x, sympy.Integer(8)) + + def test_compound_symbolic_divisor(self): + """256*f(s)*g(s) is divisible by 8*f(s) because 256/8=32.""" + f = sympy.floor(y / 8) + assert _is_provably_divisible(256 * f * x, 8 * f) + + def test_compound_symbolic_divisor_not_divisible(self): + f = sympy.floor(y / 8) + assert not _is_provably_divisible(7 * f * x, 8 * f) + + +# ── _split_sum_by_divisibility ─────────────────────────────────────────────── + + +class TestSplitSumByDivisibility: + def test_basic_split(self): + q, r = _split_sum_by_divisibility(3 * D * x + y, D) + assert q == 3 * x + assert r == y + + def test_no_divisible_terms(self): + assert _split_sum_by_divisibility(x + y, D) is None + + def test_all_divisible(self): + q, r = _split_sum_by_divisibility(6 * x + 3, sympy.Integer(3)) + assert q == 2 * x + 1 + assert r == sympy.Integer(0) + + def test_multiple_divisible_terms(self): + q, r = _split_sum_by_divisibility(D * x + D * y + z, D) + assert simplify(q - (x + y)) == 0 + assert r == z + + +# ── simplify: floor/Mod with divisor splitting ────────────────────────────── + + +class TestSimplifyFloorDiv: + def test_basic_floordiv(self): + expr = sympy.floor((3 * D * x + y) / D) + result = simplify(expr) + assert result == 3 * x + sympy.floor(y / D) + + def test_floordiv_no_remainder(self): + expr = sympy.floor(6 * D / D) + assert simplify(expr) == 6 + + def test_floordiv_bounded_remainder(self): + # x in [0, oo) so floor(x/D) can't be eliminated without tighter + # bounds. But the D*y term is factored out. 
+ expr = sympy.floor((D * y + x) / D) + result = simplify(expr) + assert result == y + sympy.floor(x / D) + + def test_mod_basic(self): + expr = sympy.Mod(3 * D * x + y, D) + result = simplify(expr) + assert result == sympy.Mod(y, D, evaluate=False) or result == sympy.Mod(y, D) + + def test_mod_all_divisible(self): + expr = sympy.Mod(D * x + D * y, D) + result = simplify(expr) + assert result == 0 + + def test_scale_flat_pattern(self): + """The actual MXFP4 scale preshuffle pattern with D%8==0 assumption.""" + k_s = sympy.Symbol("k_s", integer=True, nonnegative=True) + n_s = sympy.Symbol("n_s", integer=True, nonnegative=True) + m = sympy.Symbol("m", integer=True, positive=True) + + # flat with D=8*m (after divisibility substitution D%8==0) + flat = ( + sympy.floor(n_s / 32) * m * 256 + + sympy.floor(k_s / 8) * 256 + + sympy.Mod(sympy.Mod(k_s, 8), 4) * 64 + + sympy.Mod(sympy.Mod(n_s, 32), 16) * 4 + + sympy.floor(sympy.Mod(k_s, 8) / 4) * 2 + + sympy.floor(sympy.Mod(n_s, 32) / 16) + ) + + div_expr = sympy.floor(flat / (8 * m)) + result = simplify(div_expr) + # The 256*m*floor(n_s/32) term should be factored out as 32*floor(n_s/32). + # The remaining floor(rest/(8*m)) should have no m in its numerator. + assert 32 * sympy.floor(n_s / 32) in result.as_ordered_terms() or str( + result + ).startswith("32*floor") + # At minimum, the result should contain 32*floor(n_s/32) as a summand. 
+ result_str = str(result) + assert "32*floor(n_s/32)" in result_str, f"Expected factored term in {result}" diff --git a/tests/unittests/test_graph_utils.py b/tests/unittests/test_graph_utils.py index a81dd2e22f..e9c48f1bd1 100644 --- a/tests/unittests/test_graph_utils.py +++ b/tests/unittests/test_graph_utils.py @@ -7,7 +7,7 @@ import pytest import sympy import torch.fx as fx -from wave_lang.kernel._support.dtype import DataType +from tests.unittests.test_utils import add_test_node from wave_lang.kernel.wave.utils.graph_utils import ( is_barrier_between, is_barrier_between_same_graph, @@ -18,7 +18,6 @@ ) from wave_lang.kernel.ops.wave_ops import ( SharedMemoryBarrier, - NewScalar, Iterate, Conditional, Output, @@ -34,25 +33,6 @@ def create_simple_graph(): return graph -def add_test_node(graph: fx.Graph, name: str) -> fx.Node: - """ - Add a test node to the graph. - - Args: - graph: The fx.Graph to add the node to - name: A name/identifier for the node (used as the value for NewScalar) - - Returns: - The created fx.Node - """ - # Create a NewScalar node with a unique float value based on name hash - # This ensures each node has a distinct value while being deterministic - value = float(hash(name) % 1000) - node = NewScalar(value=value, dtype=DataType("f32")) - node.add_to_graph(graph) - return node.fx_node - - def add_barrier_node(graph: fx.Graph) -> fx.Node: """Add a SharedMemoryBarrier node to the graph.""" barrier = SharedMemoryBarrier() @@ -444,5 +424,45 @@ def test_tuple_vs_scalar_mismatch(self): assert not result +class TestReplaceUsesWith: + """Tests for CustomOp.replace_uses_with.""" + + def test_single_node_replacement(self): + graph = create_simple_graph() + a = add_test_node(graph, "a") + b = add_test_node(graph, "b") + user = add_test_node(graph, "user") + user._update_args_kwargs((a,), {}) + + get_custom(a).replace_uses_with(b) + assert user.args == (b,) + assert user not in a.users + assert user in b.users + + def 
test_identity_replacement_is_noop(self): + graph = create_simple_graph() + a = add_test_node(graph, "a") + user = add_test_node(graph, "user") + user._update_args_kwargs((a,), {}) + + get_custom(a).replace_uses_with(a) + assert user.args == (a,) + assert user in a.users + + def test_graph_scoped_replacement(self): + graph1 = create_simple_graph() + graph2 = create_simple_graph() + a = add_test_node(graph1, "a") + b = add_test_node(graph1, "b") + user1 = add_test_node(graph1, "user1") + user1._update_args_kwargs((a,), {}) + user2 = add_test_node(graph2, "user2") + user2._update_args_kwargs((a,), {}) + + get_custom(a).replace_uses_with(b, graph=graph1) + assert user1.args == (b,) + assert user2.args == (a,) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/unittests/test_region_canonicalization.py b/tests/unittests/test_region_canonicalization.py new file mode 100644 index 0000000000..7c7b92c6d9 --- /dev/null +++ b/tests/unittests/test_region_canonicalization.py @@ -0,0 +1,348 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import torch.fx as fx + +from wave_lang.kernel._support.dtype import DataType +from wave_lang.kernel._support.indexing import IndexSymbol +from wave_lang.kernel._support.tracing import CapturedTrace, KernelRegionGraph +from wave_lang.kernel.ops.wave_ops import ( + Iterate, + NestedRegionOp, + Output, + Placeholder, + get_custom, +) +from tests.unittests.test_utils import add_test_node +from wave_lang.kernel.wave.region_canonicalization import ( + RegionFormat, + canonicalize_region_captures, + enable_direct_capture_refs, + enable_legacy_capture_placeholders, + prepare_region_captures, + requires_region_format, + verify_canonical_region_captures, +) + + +def _make_trace(main_graph: fx.Graph, subgraph: fx.Graph) -> CapturedTrace: + region_graph = KernelRegionGraph() + region_graph.subgraphs["root"] = main_graph + region_graph.subgraphs[subgraph._name] = subgraph + return CapturedTrace(region_graph, "root", None) + + +def _build_nested_region_with_lifted_capture_placeholder() -> ( + tuple[CapturedTrace, fx.Graph, fx.Node] +): + main_graph = fx.Graph() + main_graph.subgraphs = {} + + init_arg = add_test_node(main_graph, "init") + outer_value = add_test_node(main_graph, "captured") + + subgraph = fx.Graph() + subgraph._name = "iterate_subgraph" + + iterate = Iterate( + axis=IndexSymbol("M"), + init_args=[init_arg], + subgraph_name="iterate_subgraph", + implicit_captures=[outer_value], + ) + iterate.add_to_graph(main_graph) + subgraph.parent_op = iterate.fx_node + main_graph.subgraphs["iterate_subgraph"] = subgraph + + capture_placeholder = NestedRegionOp.materialize_capture_placeholder( + subgraph, outer_value + ) + Output(return_vals=([capture_placeholder],)).add_to_graph(subgraph) + return _make_trace(main_graph, subgraph), subgraph, outer_value + + +def _build_captured_trace_with_direct_capture_ref() -> ( + tuple[CapturedTrace, fx.Graph, fx.Node] +): + main_graph = fx.Graph() + 
main_graph._name = "root" + main_graph.subgraphs = {} + + init_arg = add_test_node(main_graph, "init") + outer_value = add_test_node(main_graph, "captured") + + subgraph = fx.Graph() + subgraph._name = "iterate_subgraph" + + iterate = Iterate( + axis=IndexSymbol("M"), + init_args=[init_arg], + subgraph_name="iterate_subgraph", + implicit_captures=[outer_value], + ) + iterate.add_to_graph(main_graph) + subgraph.parent_op = iterate.fx_node + + NestedRegionOp.materialize_capture_placeholder(subgraph, outer_value) + Output(return_vals=([outer_value],)).add_to_graph(subgraph) + + return _make_trace(main_graph, subgraph), subgraph, outer_value + + +def _build_nested_region_with_outer_placeholder_capture() -> ( + tuple[CapturedTrace, fx.Graph, fx.Node] +): + main_graph = fx.Graph() + main_graph.subgraphs = {} + + init_arg = add_test_node(main_graph, "init") + outer_value = Placeholder("captured", DataType("f32")).add_to_graph(main_graph) + + subgraph = fx.Graph() + subgraph._name = "iterate_subgraph" + + iterate = Iterate( + axis=IndexSymbol("M"), + init_args=[init_arg], + subgraph_name="iterate_subgraph", + implicit_captures=[outer_value], + ) + iterate.add_to_graph(main_graph) + subgraph.parent_op = iterate.fx_node + main_graph.subgraphs["iterate_subgraph"] = subgraph + + capture_placeholder = NestedRegionOp.materialize_capture_placeholder( + subgraph, outer_value + ) + Output(return_vals=([capture_placeholder],)).add_to_graph(subgraph) + return _make_trace(main_graph, subgraph), subgraph, outer_value + + +def test_enable_legacy_capture_placeholders_rebuilds_local_capture_placeholders(): + """The legacy adapter reintroduces local placeholders for exposed captures.""" + + trace, subgraph, outer_value = ( + _build_nested_region_with_lifted_capture_placeholder() + ) + + canonicalize_region_captures(trace) + verify_canonical_region_captures(trace) + + enable_direct_capture_refs(trace) + assert subgraph.output_node().args[0][0][0] is outer_value + assert 
_count_cross_graph_refs(subgraph) > 0 + with pytest.raises(ValueError, match="Direct outer capture references remain"): + verify_canonical_region_captures(trace) + + # The legacy-placeholder view restores a local Placeholder for the + # captured outer value instead of keeping the direct parent-graph reference. + enable_legacy_capture_placeholders(trace) + local_capture = subgraph.output_node().args[0][0][0] + assert local_capture.graph is subgraph + assert NestedRegionOp.capture_source(local_capture) is outer_value + assert _count_cross_graph_refs(subgraph) == 0 + verify_canonical_region_captures(trace) + + +def test_enable_direct_capture_refs_exposes_outer_placeholder_captures(): + """The direct-ref adapter rewrites lifted placeholders back to outer values.""" + + trace, subgraph, outer_value = _build_nested_region_with_outer_placeholder_capture() + + canonicalize_region_captures(trace) + verify_canonical_region_captures(trace) + + # The temporary DIRECT_OUTER_REF view replaces the local lifted placeholder + # with the actual parent-graph node, which breaks the canonical invariant. 
+ enable_direct_capture_refs(trace) + assert subgraph.output_node().args[0][0][0] is outer_value + assert _count_cross_graph_refs(subgraph) > 0 + with pytest.raises(ValueError, match="Direct outer capture references remain"): + verify_canonical_region_captures(trace) + + canonicalize_region_captures(trace) + verify_canonical_region_captures(trace) + + +def test_canonicalize_region_captures_rejects_missing_nested_subgraph(): + """Canonicalization fails fast when a nested region references no subgraph.""" + + trace, _, _ = _build_nested_region_with_lifted_capture_placeholder() + del trace.get_root_graph().subgraphs["iterate_subgraph"] + + with pytest.raises( + ValueError, match="references missing subgraph iterate_subgraph" + ): + canonicalize_region_captures(trace) + + +def test_canonicalize_region_captures_rejects_unresolvable_legacy_placeholder(): + """Canonicalization rejects legacy placeholders that cannot be resolved.""" + + main_graph = fx.Graph() + main_graph.subgraphs = {} + init_arg = add_test_node(main_graph, "init") + + subgraph = fx.Graph() + subgraph._name = "iterate_subgraph" + + iterate = Iterate( + axis=IndexSymbol("M"), + init_args=[init_arg], + subgraph_name="iterate_subgraph", + implicit_captures=[], + ) + iterate.add_to_graph(main_graph) + subgraph.parent_op = iterate.fx_node + main_graph.subgraphs["iterate_subgraph"] = subgraph + + bogus_capture = Placeholder("bogus", DataType("f32")).add_to_graph(subgraph) + Output(return_vals=([bogus_capture],)).add_to_graph(subgraph) + + trace = _make_trace(main_graph, subgraph) + with pytest.raises( + ValueError, match="Could not resolve legacy capture placeholder bogus" + ): + canonicalize_region_captures(trace) + + +def test_verify_canonical_region_captures_rejects_non_lifted_placeholders(): + main_graph = fx.Graph() + main_graph.subgraphs = {} + init_arg = add_test_node(main_graph, "init") + + subgraph = fx.Graph() + subgraph._name = "iterate_subgraph" + + iterate = Iterate( + axis=IndexSymbol("M"), + 
init_args=[init_arg], + subgraph_name="iterate_subgraph", + implicit_captures=[], + ) + iterate.add_to_graph(main_graph) + subgraph.parent_op = iterate.fx_node + main_graph.subgraphs["iterate_subgraph"] = subgraph + + bogus_capture = Placeholder("bogus", DataType("f32")).add_to_graph(subgraph) + Output(return_vals=([bogus_capture],)).add_to_graph(subgraph) + trace = _make_trace(main_graph, subgraph) + + with pytest.raises(ValueError, match="Non-lifted region placeholders remain"): + verify_canonical_region_captures(trace) + + +def test_verify_canonical_region_captures_rejects_misplaced_placeholders(): + trace, subgraph, _ = _build_nested_region_with_lifted_capture_placeholder() + capture_placeholder = subgraph.output_node().args[0][0][0] + local_node = add_test_node(subgraph, "local") + local_node.append(capture_placeholder) + + with pytest.raises( + ValueError, match="Canonical capture placeholders must be leading region inputs" + ): + verify_canonical_region_captures(trace) + + +def test_verify_canonical_region_captures_rejects_capture_count_mismatch(): + trace, subgraph, _ = _build_nested_region_with_lifted_capture_placeholder() + region = get_custom(subgraph.parent_op) + region.update_arg("implicit_captures", []) + + with pytest.raises(ValueError, match="Capture placeholder count mismatch"): + verify_canonical_region_captures(trace) + + +def test_verify_canonical_region_captures_rejects_capture_source_mismatch(): + trace, subgraph, _ = _build_nested_region_with_lifted_capture_placeholder() + other_outer = add_test_node(trace.get_root_graph(), "other") + capture_placeholder = subgraph.output_node().args[0][0][0] + capture_placeholder.meta["lifted"] = other_outer + + with pytest.raises(ValueError, match="Capture placeholder source mismatch"): + verify_canonical_region_captures(trace) + + +def _count_cross_graph_refs(graph: fx.Graph) -> int: + """Count operand references in `graph` that point to nodes from other graphs.""" + + count = 0 + + def visit(arg): + 
nonlocal count + if isinstance(arg, fx.Node) and arg.graph is not graph: + count += 1 + return arg + + for node in graph.nodes: + fx.map_arg(node.args, visit) + fx.map_arg(node.kwargs, visit) + return count + + +def test_requires_region_format_can_skip_post_pass_canonicalization(): + """`canonicalize_output=False` leaves the requested temporary region view in place.""" + + @requires_region_format(RegionFormat.DIRECT_OUTER_REF) + def no_op(trace: CapturedTrace) -> None: + return None + + trace, subgraph, _ = _build_captured_trace_with_direct_capture_ref() + + # The default wrapper behavior canonicalizes on return, so the temporary + # DIRECT_OUTER_REF view does not leak out of the decorated call. + no_op(trace) + assert _count_cross_graph_refs(subgraph) == 0 + verify_canonical_region_captures(trace) + + # `canonicalize_output=False` skips that final return to canonical form, so + # the temporary direct outer references remain visible after the call. + no_op(trace, canonicalize_output=False) + assert _count_cross_graph_refs(subgraph) > 0 + with pytest.raises(ValueError, match="Direct outer capture references remain"): + verify_canonical_region_captures(trace) + + +def test_requires_region_format_rejects_unbound_trace_at_call_time(): + @requires_region_format(RegionFormat.DIRECT_OUTER_REF) + def no_op(trace: CapturedTrace) -> None: + return None + + with pytest.raises(TypeError, match="without binding `trace`"): + no_op() + + +def test_requires_region_format_rejects_non_trace_argument_at_call_time(): + @requires_region_format(RegionFormat.DIRECT_OUTER_REF) + def no_op(trace: CapturedTrace) -> None: + return None + + with pytest.raises(TypeError, match="expected `trace` to be a CapturedTrace"): + no_op(object()) + + +def test_requires_region_format_rejects_missing_trace_parameter(): + with pytest.raises(TypeError, match="must expose a `trace` parameter"): + + @requires_region_format(RegionFormat.DIRECT_OUTER_REF) + def no_trace(foo: int) -> None: + return None + + +def 
test_requires_region_format_rejects_multiple_trace_candidates(): + with pytest.raises(TypeError, match="multiple possible trace parameters"): + + @requires_region_format(RegionFormat.DIRECT_OUTER_REF) + def ambiguous(lhs: CapturedTrace, rhs: CapturedTrace) -> None: + return None + + +def test_prepare_region_captures_rejects_unsupported_region_format(): + trace, _, _ = _build_nested_region_with_lifted_capture_placeholder() + + with pytest.raises(ValueError, match="Unsupported region format: invalid"): + prepare_region_captures(trace, "invalid") diff --git a/tests/unittests/test_utils.py b/tests/unittests/test_utils.py new file mode 100644 index 0000000000..85399f7af2 --- /dev/null +++ b/tests/unittests/test_utils.py @@ -0,0 +1,29 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch.fx as fx + +from wave_lang.kernel._support.dtype import DataType +from wave_lang.kernel.ops.wave_ops import NewScalar + + +def add_test_node(graph: fx.Graph, name: str) -> fx.Node: + """ + Add a test node to the graph. 
+ + Args: + graph: The fx.Graph to add the node to + name: A name/identifier for the node (used as the value for NewScalar) + + Returns: + The created fx.Node + """ + # Create a NewScalar node with a unique float value based on name hash + # This ensures each node has a distinct value while being deterministic + value = float(hash(name) % (2**31)) + node = NewScalar(value=value, dtype=DataType("f32")) + node.add_to_graph(graph) + return node.fx_node diff --git a/water/include/water/Dialect/NormalForm/IR/NormalFormOps.td b/water/include/water/Dialect/NormalForm/IR/NormalFormOps.td index 612b38815c..d6798234ba 100644 --- a/water/include/water/Dialect/NormalForm/IR/NormalFormOps.td +++ b/water/include/water/Dialect/NormalForm/IR/NormalFormOps.td @@ -59,14 +59,16 @@ def ModuleOp : NormalFormOp<"module", [ ```mlir // Enforce that all tensor types are fully specified. - normalform.module @my_kernel [#wave.normal_form] { + normalform.module @my_kernel [#wave.normal_form, + #wave.normal_form] { func.func @compute(%arg: !wave.tensor<[64, 128] of f32>) { return } } // Multiple normal form attributes from different dialects. 
- normalform.module @validated [#wave.normal_form, #other.normal_form] { + normalform.module @validated [#wave.normal_form, + #other.normal_form] { func.func @compute(%arg: !wave.tensor<[64, 128] of f32>) { return } diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index 948a61bbea..a459c94d2d 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -108,42 +108,18 @@ def WaveMmaKindAttr : EnumAttr { // Normal form enum and attribute //----------------------------------------------------------------------------- -def NORMAL_FORM_FUNC_BOUNDARY : I32BitEnumAttrCaseBit< - "FunctionBoundarySpecified", 0, "full_func_boundary">; -def NORMAL_FORM_OP_TYPES : I32BitEnumAttrCaseBit< - "OpTypesSpecified", 1, "full_op_types">; -def NORMAL_FORM_INDEX_EXPRS : I32BitEnumAttrCaseBit< - "IndexExprsSpecified", 2, "index_exprs">; -def NORMAL_FORM_MEMORY_ONLY_TYPES : I32BitEnumAttrCaseBit< - "MemoryOnlyTypes", 3, "memory_only_types">; -def NORMAL_FORM_RESOLVED_ALLOCATIONS : I32BitEnumAttrCaseBit< - "ResolvedAllocations", 4, "resolved_allocations">; -def NORMAL_FORM_ORDERED_SYMS : I32BitEnumAttrCaseBit< - "OrderedSymsSpecified", 5, "ordered_syms">; - -def NORMAL_FORM_FULL_TYPES : I32BitEnumAttrCaseGroup< - "AllTypesSpecified", [ - NORMAL_FORM_FUNC_BOUNDARY, NORMAL_FORM_OP_TYPES - ], "full_types">; - // When anything changes below, also update the C and Python API. -def WaveNormalFormEnum : I32BitEnumAttr<"WaveNormalForm", "", [ - I32BitEnumAttrCaseNone<"None", "none">, - // Bits. - NORMAL_FORM_FUNC_BOUNDARY, - NORMAL_FORM_OP_TYPES, - NORMAL_FORM_INDEX_EXPRS, - NORMAL_FORM_MEMORY_ONLY_TYPES, - NORMAL_FORM_RESOLVED_ALLOCATIONS, - NORMAL_FORM_ORDERED_SYMS, - - // Group aliases. 
- NORMAL_FORM_FULL_TYPES +def WaveNormalFormEnum + : I32EnumAttr<"WaveNormalForm", "", [ + I32EnumAttrCase<"FunctionBoundarySpecified", 0, "full_func_boundary">, + I32EnumAttrCase<"OpTypesSpecified", 1, "full_op_types">, + I32EnumAttrCase<"IndexExprsSpecified", 2, "index_exprs">, + I32EnumAttrCase<"MemoryOnlyTypes", 3, "memory_only_types">, + I32EnumAttrCase<"ResolvedAllocations", 4, "resolved_allocations">, + I32EnumAttrCase<"OrderedSymsSpecified", 5, "ordered_syms">, ]> { let cppNamespace = "::wave"; - let separator = ","; let genSpecializedAttr = 0; - let printBitEnumPrimaryGroups = 1; } def WaveNormalFormAttr @@ -152,9 +128,9 @@ def WaveNormalFormAttr ]> { let assemblyFormat = "`<` $value `>`"; let extraClassDeclaration = [{ - static unsigned getLastSetBit() { - return ::llvm::bit_width_constexpr(}] # !cast(WaveNormalFormEnum).validBits # [{) - 1; - } + static constexpr std::initializer_list<::wave::WaveNormalForm> AllCases = { }]# + !interleave(!foreach(case, WaveNormalFormEnum.enumerants, + "::wave::WaveNormalForm::" # case.symbol), ", ") #[{ }; }]; } diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 319cb1d0f8..51c6529dac 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -319,7 +319,7 @@ def IterateOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "Yields values from the current control flow context"; let arguments = (ins diff --git a/water/include/water/Dialect/Wave/Transforms/Utils.h b/water/include/water/Dialect/Wave/Transforms/Utils.h index 3e1a7f8b68..6404951ff3 100644 --- a/water/include/water/Dialect/Wave/Transforms/Utils.h +++ b/water/include/water/Dialect/Wave/Transforms/Utils.h @@ -8,6 +8,7 @@ #define WATER_DIALECT_WAVE_TRANSFORMS_UTILS_H #include "water/Dialect/Wave/IR/WaveAttrs.h" +#include "llvm/ADT/ArrayRef.h" namespace wave { @@ -35,11 +36,11 @@ llvm::LogicalResult 
collectWaveConstraints( // normal form every time a verifier runs on the operation, including by default // after every pass. // -// By default, preserves existing normal forms and adds the new form. Set -// preserve=false to replace all existing forms with the provided form. -llvm::LogicalResult setNormalFormPassPostcondition(wave::WaveNormalForm form, - mlir::Operation *root, - bool preserve = true); +// By default, preserves existing normal forms and adds the new ones. Set +// preserve=false to replace all existing forms with the provided forms. +llvm::LogicalResult +setNormalFormPassPostcondition(llvm::ArrayRef forms, + mlir::Operation *root, bool preserve = true); // Clears all normal form attributes from the operation, effectively setting // the normal form to None. @@ -50,9 +51,10 @@ llvm::LogicalResult clearNormalFormPassPostcondition(mlir::Operation *root); // attribute that enforces verification. Emits diagnostics and returns failures // when it is not the case. Does *NOT* actually run verification, this is // automated by the presence of the attribute. -llvm::LogicalResult verifyNormalFormPassPrecondition(wave::WaveNormalForm form, - mlir::Operation *root, - llvm::StringRef passName); +llvm::LogicalResult +verifyNormalFormPassPrecondition(llvm::ArrayRef forms, + mlir::Operation *root, + llvm::StringRef passName); } // namespace wave diff --git a/water/include/water/c/Dialects.h b/water/include/water/c/Dialects.h index eab66043c9..b9e495ddb7 100644 --- a/water/include/water/c/Dialects.h +++ b/water/include/water/c/Dialects.h @@ -378,14 +378,12 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirWaveMmaKindAttrGetTypeID(); /// Normal forms, this must remain consistent with WaveAttrs.td. 
enum WaveNormalForm { - WaveNormalFormNone = 0, - WaveNormalFormFunctionBoundarySpecified = 1, - WaveNormalFormOpTypesSpecified = 2, - WaveNormalFormIndexExprsSpecified = 4, - WaveNormalFormMemoryOnlyTypes = 8, - - WaveNormalFormAllTypesSPecified = - WaveNormalFormFunctionBoundarySpecified | WaveNormalFormOpTypesSpecified + WaveNormalFormFunctionBoundarySpecified = 0, + WaveNormalFormOpTypesSpecified = 1, + WaveNormalFormIndexExprsSpecified = 2, + WaveNormalFormMemoryOnlyTypes = 3, + WaveNormalFormResolvedAllocations = 4, + WaveNormalFormOrderedSymsSpecified = 5, }; /// Checks whether the given MLIR attribute is a WaveNormalFormAttr. diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 58ec7fec5a..6f08d6d78f 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -1032,88 +1032,81 @@ LogicalResult WaveNormalFormAttr::verifyOperation( function_ref emitError, Operation *op) const { WaveNormalForm form = getValue(); - // No normal form required. 
- if (form == wave::WaveNormalForm::None) + switch (form) { + case wave::WaveNormalForm::FunctionBoundarySpecified: { + auto func = llvm::dyn_cast(op); + if (!func) + return llvm::success(); + constexpr llvm::StringLiteral kMessage = + "normal form requires tensor types to be fully specified at " + "function boundaries"; + if (llvm::failed(verifyTypesFullySpecified( + /*loc*/ std::nullopt, func.getArgumentTypes(), kMessage))) + return emitError() << kMessage; + if (llvm::failed(verifyTypesFullySpecified( + /*loc*/ std::nullopt, func->getResultTypes(), kMessage))) + return emitError() << kMessage; return llvm::success(); - - if (auto func = llvm::dyn_cast(op)) { - if (wave::bitEnumContainsAll( - form, wave::WaveNormalForm::FunctionBoundarySpecified)) { - constexpr llvm::StringLiteral kMessage = - "normal form requires tensor types to be fully specified at " - "function boundaries"; - if (llvm::failed(verifyTypesFullySpecified( - /*loc*/ std::nullopt, func.getArgumentTypes(), kMessage))) - return emitError() << kMessage; - - if (llvm::failed(verifyTypesFullySpecified( - /*loc*/ std::nullopt, func->getResultTypes(), kMessage))) - return emitError() << kMessage; - } } - - if (wave::bitEnumContainsAll(form, wave::WaveNormalForm::OpTypesSpecified)) { + case wave::WaveNormalForm::OpTypesSpecified: { constexpr llvm::StringLiteral kMessage = "normal form requires tensor types to be fully specified"; if (llvm::failed(visitOpRelatedTypes(op, verifyTypesFullySpecified, kMessage, - /*emitDiagnostics*/ false))) { + /*emitDiagnostics*/ false))) return emitError() << kMessage; - } + return llvm::success(); } - - if (wave::bitEnumContainsAll(form, wave::WaveNormalForm::MemoryOnlyTypes)) { + case wave::WaveNormalForm::MemoryOnlyTypes: { constexpr llvm::StringLiteral kMessage = "normal form requires tensor types to have only memory address spaces " "(elements per thread propagation missing?)"; if (llvm::failed(visitOpRelatedTypes(op, verifyMemoryOnlyAddressSpaces, kMessage, - 
/*emitDiagnostics*/ false))) { + /*emitDiagnostics*/ false))) return emitError() << kMessage; - } + return llvm::success(); } - - if (wave::bitEnumContainsAll(form, - wave::WaveNormalForm::IndexExprsSpecified)) { - if (op->hasTrait() && - !op->getAttr(wave::WaveDialect::kIndexWaveExprListAttrName)) { - // Only require index expressions for read/write ops, or ops with - // WaveTensorType operands/results. Vector-only ops (after - // elements-per-thread propagation) don't need index expressions. - bool hasWaveTensor = llvm::any_of(op->getOperandTypes(), - llvm::IsaPred) || - llvm::any_of(op->getResultTypes(), - llvm::IsaPred); - bool isMemoryAccessOp = llvm::isa(op); - - // Parent allocations (byte buffers for combined shared memory) don't - // need index expressions. They are never accessed directly by read/write - // operations - only child AllocateOps reference them as a parent buffer. - // A parent allocation has no operands (no parent buffer to view into). - bool isParentAllocation = - llvm::isa(op) && op->getNumOperands() == 0; - - if ((!hasWaveTensor && !isMemoryAccessOp) || isParentAllocation) - return llvm::success(); - - if (isMemoryAccessOp) - return emitError() << "missing index expressions on memory access " - "operation, required by normal form"; - - return emitError() << "missing index expressions on operation with " - "WaveTensorType operand/result, required by " - "normal form"; - } + case wave::WaveNormalForm::IndexExprsSpecified: { + if (!op->hasTrait() || + op->getAttr(wave::WaveDialect::kIndexWaveExprListAttrName)) + return llvm::success(); + + bool hasWaveTensor = + llvm::any_of(op->getOperandTypes(), + llvm::IsaPred) || + llvm::any_of(op->getResultTypes(), llvm::IsaPred); + bool isMemoryAccessOp = llvm::isa(op); + + // Parent allocations (byte buffers for combined shared memory) don't + // need index expressions. They are never accessed directly by read/write + // operations - only child AllocateOps reference them as a parent buffer. 
+ // A parent allocation has no operands (no parent buffer to view into). + bool isParentAllocation = + llvm::isa(op) && op->getNumOperands() == 0; + + if ((!hasWaveTensor && !isMemoryAccessOp) || isParentAllocation) + return llvm::success(); + + if (isMemoryAccessOp) + return emitError() << "missing index expressions on memory access " + "operation, required by normal form"; + + return emitError() << "missing index expressions on operation with " + "WaveTensorType operand/result, required by " + "normal form"; } - - if (wave::bitEnumContainsAll(form, - wave::WaveNormalForm::ResolvedAllocations)) { - if (auto allocOp = llvm::dyn_cast(op)) { - if (!llvm::isa(allocOp.getResult().getType())) - return emitError() << "normal form requires all wave.allocate " - "operations to have memref result type"; - } + case wave::WaveNormalForm::ResolvedAllocations: { + auto allocOp = llvm::dyn_cast(op); + if (!allocOp) + return llvm::success(); + if (!llvm::isa(allocOp.getResult().getType())) + return emitError() << "normal form requires all wave.allocate " + "operations to have memref result type"; + return llvm::success(); } - - return llvm::success(); + case wave::WaveNormalForm::OrderedSymsSpecified: + return llvm::success(); + } + llvm_unreachable("unhandled normal form"); } diff --git a/water/lib/Dialect/Wave/Transforms/DetectNormalForms.cpp b/water/lib/Dialect/Wave/Transforms/DetectNormalForms.cpp index 95fb358dcf..5137f2fe79 100644 --- a/water/lib/Dialect/Wave/Transforms/DetectNormalForms.cpp +++ b/water/lib/Dialect/Wave/Transforms/DetectNormalForms.cpp @@ -34,12 +34,9 @@ namespace wave { static SmallVector collectWaveNormalForms(MLIRContext *ctx) { SmallVector normalForms; - for (unsigned bit = 0, lastBit = WaveNormalFormAttr::getLastSetBit(); - bit <= lastBit; ++bit) { - WaveNormalForm form = - static_cast(static_cast(1) << bit); + normalForms.reserve(WaveNormalFormAttr::AllCases.size()); + for (WaveNormalForm form : WaveNormalFormAttr::AllCases) 
normalForms.push_back(WaveNormalFormAttr::get(ctx, form)); - } return normalForms; } diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index badb3f83d5..bab9e18cad 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -661,8 +661,8 @@ class InferTypes : public wave::impl::WaterWaveInferTypesPassBase { using WaterWaveInferTypesPassBase::WaterWaveInferTypesPassBase; void runOnOperation() override { - if (llvm::failed(verifyNormalFormPassPrecondition( - wave::WaveNormalForm::FunctionBoundarySpecified, getOperation(), + if (llvm::failed(wave::verifyNormalFormPassPrecondition( + {wave::WaveNormalForm::FunctionBoundarySpecified}, getOperation(), getArgument()))) return signalPassFailure(); @@ -715,8 +715,10 @@ class InferTypes : public wave::impl::WaterWaveInferTypesPassBase { return signalPassFailure(); if (!partial) { - llvm::LogicalResult result = setNormalFormPassPostcondition( - wave::WaveNormalForm::AllTypesSpecified, getOperation()); + llvm::LogicalResult result = wave::setNormalFormPassPostcondition( + {wave::WaveNormalForm::FunctionBoundarySpecified, + wave::WaveNormalForm::OpTypesSpecified}, + getOperation()); if (llvm::failed(result) && !force) return signalPassFailure(); } @@ -1198,8 +1200,9 @@ class PropagateElementsPerThread void runOnOperation() override { if (failed(wave::verifyNormalFormPassPrecondition( - wave::WaveNormalForm::AllTypesSpecified, getOperation(), - getArgument()))) + {wave::WaveNormalForm::FunctionBoundarySpecified, + wave::WaveNormalForm::OpTypesSpecified}, + getOperation(), getArgument()))) return signalPassFailure(); llvm::DenseMap constraints; @@ -1271,7 +1274,7 @@ class PropagateElementsPerThread } if (llvm::failed(wave::setNormalFormPassPostcondition( - wave::WaveNormalForm::MemoryOnlyTypes, getOperation()))) + {wave::WaveNormalForm::MemoryOnlyTypes}, getOperation()))) return signalPassFailure(); } }; @@ 
-1902,9 +1905,10 @@ class InferIndexExprsPass using Base::Base; void runOnOperation() override { - if (llvm::failed(verifyNormalFormPassPrecondition( - wave::WaveNormalForm::AllTypesSpecified, getOperation(), - getArgument()))) + if (llvm::failed(wave::verifyNormalFormPassPrecondition( + {wave::WaveNormalForm::FunctionBoundarySpecified, + wave::WaveNormalForm::OpTypesSpecified}, + getOperation(), getArgument()))) return signalPassFailure(); IRRewriter rewriter(&getContext()); @@ -1934,7 +1938,7 @@ class InferIndexExprsPass }); if (llvm::failed(wave::setNormalFormPassPostcondition( - wave::WaveNormalForm::IndexExprsSpecified, getOperation()))) + {wave::WaveNormalForm::IndexExprsSpecified}, getOperation()))) return signalPassFailure(); } }; diff --git a/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp b/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp index abec6b58a4..8b5f09d3d9 100644 --- a/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp +++ b/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp @@ -54,11 +54,12 @@ struct LowerWaveToMLIRPass Operation *op = getOperation(); if (failed(wave::verifyNormalFormPassPrecondition( - wave::WaveNormalForm::AllTypesSpecified | - wave::WaveNormalForm::MemoryOnlyTypes | - wave::WaveNormalForm::ResolvedAllocations | - wave::WaveNormalForm::IndexExprsSpecified | - wave::WaveNormalForm::OrderedSymsSpecified, + {wave::WaveNormalForm::FunctionBoundarySpecified, + wave::WaveNormalForm::OpTypesSpecified, + wave::WaveNormalForm::MemoryOnlyTypes, + wave::WaveNormalForm::ResolvedAllocations, + wave::WaveNormalForm::IndexExprsSpecified, + wave::WaveNormalForm::OrderedSymsSpecified}, op, getPassName()))) return signalPassFailure(); diff --git a/water/lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp b/water/lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp index 47d9094f16..0a2fef60bb 100644 --- a/water/lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp +++ 
b/water/lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp @@ -150,8 +150,8 @@ struct ResolveDistributedAllocations return signalPassFailure(); if (llvm::failed(wave::setNormalFormPassPostcondition( - wave::WaveNormalForm::ResolvedAllocations | - wave::WaveNormalForm::OrderedSymsSpecified, + {wave::WaveNormalForm::ResolvedAllocations, + wave::WaveNormalForm::OrderedSymsSpecified}, getOperation()))) return signalPassFailure(); } diff --git a/water/lib/Dialect/Wave/Transforms/Utils.cpp b/water/lib/Dialect/Wave/Transforms/Utils.cpp index 7234ff5405..7295647506 100644 --- a/water/lib/Dialect/Wave/Transforms/Utils.cpp +++ b/water/lib/Dialect/Wave/Transforms/Utils.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Operation.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -43,13 +44,14 @@ llvm::LogicalResult wave::collectWaveConstraints( } llvm::LogicalResult -wave::setNormalFormPassPostcondition(wave::WaveNormalForm form, Operation *root, - bool preserve) { +wave::setNormalFormPassPostcondition(ArrayRef forms, + Operation *root, bool preserve) { auto module = llvm::dyn_cast(root); if (!module) return root->emitError() << "expected normalform.module"; - wave::WaveNormalForm finalForm = form; + llvm::DenseSet finalForms(forms.begin(), forms.end()); + auto normalforms = module.getNormalForms().getAsRange(); @@ -58,22 +60,26 @@ wave::setNormalFormPassPostcondition(wave::WaveNormalForm form, Operation *root, llvm::IsaPred); if (preserve) { - // Merge all existing normal forms with the new form. 
- for (auto nf : waveNormalForms) { - wave::WaveNormalForm currentForm = - cast(nf).getValue(); - finalForm = finalForm | currentForm; - } + for (auto nf : waveNormalForms) + finalForms.insert(cast(nf).getValue()); } if (!waveNormalForms.empty()) module.removeNormalForms(waveNormalForms); - module.addNormalForms( - {wave::WaveNormalFormAttr::get(root->getContext(), finalForm)}); + SmallVector sortedForms(finalForms.begin(), + finalForms.end()); + llvm::sort(sortedForms, [](wave::WaveNormalForm a, wave::WaveNormalForm b) { + return static_cast(a) < static_cast(b); + }); + + SmallVector newAttrs; + newAttrs.reserve(sortedForms.size()); + + for (wave::WaveNormalForm form : sortedForms) + newAttrs.push_back(wave::WaveNormalFormAttr::get(root->getContext(), form)); - // We rely on the pass manager to call verifyRegion on the normalform.module - // after the pass + module.addNormalForms(newAttrs); return llvm::success(); } @@ -95,26 +101,28 @@ llvm::LogicalResult wave::clearNormalFormPassPostcondition(Operation *root) { return llvm::success(); } -llvm::LogicalResult -wave::verifyNormalFormPassPrecondition(WaveNormalForm form, Operation *root, - llvm::StringRef passName) { +llvm::LogicalResult wave::verifyNormalFormPassPrecondition( + ArrayRef forms, Operation *root, llvm::StringRef passName) { auto module = llvm::dyn_cast(root); if (!module) return root->emitError() << "expected << " << normalform::ModuleOp::getOperationName(); ArrayRef normalforms = module.getNormalForms().getValue(); - WaveNormalForm expectedForm = WaveNormalForm::None; - for (Attribute form : llvm::make_filter_range( - normalforms, llvm::IsaPred)) { - expectedForm |= cast(form).getValue(); + llvm::DenseSet presentForms; + for (Attribute attr : + llvm::make_filter_range(normalforms, llvm::IsaPred)) + presentForms.insert(cast(attr).getValue()); + + for (WaveNormalForm form : forms) { + if (!presentForms.contains(form)) { + return root->emitError() + << passName + << " pass expects the root operation or 
its ancestor to " + "guarantee the " + << wave::stringifyWaveNormalForm(form) << " normal form"; + } } - if (wave::bitEnumContainsAll(expectedForm, form)) - return llvm::success(); - - return root->emitError() - << passName - << " pass expects the root operation or its ancestor to guarantee the " - << wave::stringifyEnum(form) << " normal form"; + return llvm::success(); } diff --git a/water/llvm-sha.txt b/water/llvm-sha.txt index e30a63f920..c8b2902007 100644 --- a/water/llvm-sha.txt +++ b/water/llvm-sha.txt @@ -1 +1 @@ -81a537e7081b76d41aa0422924d334a713fdd779 +d783723a584a1cab30c3a92ca247abb3401fc6da diff --git a/water/python/WaterExtensionNanobind.cpp b/water/python/WaterExtensionNanobind.cpp index c19b9cc00e..e4363cb166 100644 --- a/water/python/WaterExtensionNanobind.cpp +++ b/water/python/WaterExtensionNanobind.cpp @@ -1053,13 +1053,13 @@ NB_MODULE(_waterDialects, m) { .value("GPR_NUMBER", WaveIndexSymbol_GPR_NUMBER); nb::enum_(d, "WaveNormalForm") - .value("None_", WaveNormalFormNone) .value("FunctionBoundarySpecified", WaveNormalFormFunctionBoundarySpecified) .value("OpTypesSpecified", WaveNormalFormOpTypesSpecified) .value("IndexExprsSpecified", WaveNormalFormIndexExprsSpecified) .value("MemoryOnlyTypes", WaveNormalFormMemoryOnlyTypes) - .value("AllTypesSpecified", WaveNormalFormAllTypesSPecified); + .value("ResolvedAllocations", WaveNormalFormResolvedAllocations) + .value("OrderedSymsSpecified", WaveNormalFormOrderedSymsSpecified); nb::enum_(d, "WaveWorkgroupDim") .value("X", WaveWorkgroupDimX) diff --git a/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir b/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir index 03cb95bb4d..f590af036c 100644 --- a/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir @@ -8,7 +8,7 @@ // states. 
// -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_mma( %lhs: !wave.tensor<[@M, @K] of f16>, %rhs: !wave.tensor<[@N, @K] of f16>, @@ -35,7 +35,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_mma( %lhs: !wave.tensor<[@M, @K] of f16>, %rhs: !wave.tensor<[@N, @K] of f16>, @@ -61,7 +61,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_mma( %lhs: !wave.tensor<[@M, @K] of f16>, %rhs: !wave.tensor<[@N, @K] of f16>, @@ -88,7 +88,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_mma( %lhs: !wave.tensor<[@M, @K] of f16>, %rhs: !wave.tensor<[@N, @K] of f16>, @@ -116,7 +116,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @add_then_mul( %a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@M, @K] of f16>, @@ -144,7 +144,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] attributes { wave_test.disable_backward } { +normalform.module [#wave.normal_form, #wave.normal_form] attributes { wave_test.disable_backward } { func.func @add_then_mul( %a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@M, @K] of f16>, @@ -173,7 +173,7 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable // ----- -normalform.module [#wave.normal_form] attributes { wave_test.disable_backward } { +normalform.module [#wave.normal_form, #wave.normal_form] attributes { wave_test.disable_backward } { func.func @operand_conflict( %a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@M, @K] of f16>, @@ 
-202,7 +202,7 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable // Generic error message when reached top somehow without detecting the conflict before. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M, @N] of f32>, %b: !wave.tensor<[@M, @N] of f32> @@ -223,7 +223,7 @@ normalform.module [#wave.normal_form] { // Joining with the same expression results in that expression. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_with_same func.func @join_with_same( %a: !wave.tensor<[@M] of f32>, @@ -250,7 +250,7 @@ normalform.module [#wave.normal_form] { // Joining with null (uninitialized) doesn't crash and gives the other expression. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_with_null func.func @join_with_null( %a: !wave.tensor<[@M] of f32>, @@ -277,7 +277,7 @@ normalform.module [#wave.normal_form] { // Joining with bottom (denoted as unit) gives the other expression. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_with_bottom func.func @join_with_bottom( %a: !wave.tensor<[@M] of f32>, @@ -304,7 +304,7 @@ normalform.module [#wave.normal_form] { // Joining with zero is gives the other expression. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_with_zero func.func @join_with_zero( %a: !wave.tensor<[@M] of f32>, @@ -331,7 +331,7 @@ normalform.module [#wave.normal_form] { // Additional constant summand makes expressions join to top. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -357,7 +357,7 @@ normalform.module [#wave.normal_form] { // Different constant summands join to top. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -384,7 +384,7 @@ normalform.module [#wave.normal_form] { // Different constant values other than zero join to top. // Also, difference may be in the step. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -410,7 +410,7 @@ normalform.module [#wave.normal_form] { // Difference in stride joins to top. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -436,7 +436,7 @@ normalform.module [#wave.normal_form] { // Stride 1 joins with the other constant stride to become that stride. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -462,7 +462,7 @@ normalform.module [#wave.normal_form] { // Step 1 joins with the other non-constant step to become that step. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -488,7 +488,7 @@ normalform.module [#wave.normal_form] { // Different expressions in step join to top even if they would have resulted in a sum for start. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -516,7 +516,7 @@ normalform.module [#wave.normal_form] { // Note that here the underlying affine expression is the same, but symbols // are different, we should be able to catch that. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -543,7 +543,7 @@ normalform.module [#wave.normal_form] { // Different expressions involving workgroups join to top. // Note that there are unused symbols in mappings. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -570,7 +570,7 @@ normalform.module [#wave.normal_form] { // Joining thread and block components is fine. Note that some symbols are unused in mappings. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_threads_workgroups func.func @join_threads_workgroups( %a: !wave.tensor<[@M] of f32>, @@ -597,7 +597,7 @@ normalform.module [#wave.normal_form] { // Identical constant summands don't sum up when symbols do. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @same_constant_summands func.func @same_constant_summands( %a: !wave.tensor<[@M] of f32>, @@ -624,7 +624,7 @@ normalform.module [#wave.normal_form] { // Joining thread and block components is fine, this requires aligning symbols in mappings. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_threads_workgroups_align func.func @join_threads_workgroups_align( %a: !wave.tensor<[@M] of f32>, @@ -652,7 +652,7 @@ normalform.module [#wave.normal_form] { // Joining iter symbols and blocks is fine and results in an add. // TODO: Also check that iter symbols don't leak form the loop to results. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_iter_workgroups func.func @join_iter_workgroups( %a: !wave.tensor<[@M, @K] of f32>, @@ -684,7 +684,7 @@ normalform.module [#wave.normal_form] { // Joining iter symbols with themselves is fine. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_iter_same func.func @join_iter_same( %a: !wave.tensor<[@M, @K] of f32>, @@ -718,7 +718,7 @@ normalform.module [#wave.normal_form] { // Also check that we are not leaking iter symbols to operations after the loop // by checking that they are not used in expressions for loop results. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @join_iters func.func @join_iters( %a: !wave.tensor<[@M, @K] of f32>, @@ -762,7 +762,7 @@ normalform.module [#wave.normal_form] { // Otherwise iter symbols behave like any other component, e.g., different // expressions involving the same symbol join to top. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @join_iters( %a: !wave.tensor<[@M, @K] of f32>, %b: !wave.tensor<[@M, @K] of f32> @@ -793,7 +793,7 @@ normalform.module [#wave.normal_form] { // Check that we don't leak iter symbols to values before the loop. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @do_not_leak_above func.func @do_not_leak_above( %a: !wave.tensor<[@M, @K] of f32>, @@ -825,7 +825,7 @@ normalform.module [#wave.normal_form] { // Check that we propagate lattices between adjacent operands of a write. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @write_sideways_propagation func.func @write_sideways_propagation( %a: !wave.tensor<[@M] of f32>, @@ -855,7 +855,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @priority_join func.func @priority_join( %a: !wave.tensor<[@M] of f32>, @@ -886,7 +886,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] attributes { wave_test.disable_backward } { +normalform.module [#wave.normal_form, #wave.normal_form] attributes { wave_test.disable_backward } { func.func @same_priority_conflict( %a: !wave.tensor<[@M] of f32>, %b: !wave.tensor<[@M] of f32> @@ -913,7 +913,7 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable // Test that higher priority from write propagates backward through multiple operations. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @priority_backward_through_chain func.func @priority_backward_through_chain( %a: !wave.tensor<[@M] of f32>, @@ -961,7 +961,7 @@ normalform.module [#wave.normal_form] { // Check that sideways propagation between operands of a write that would // lead to a conflict is not happening. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @write_sideways_no_conflicting_propagation func.func @write_sideways_no_conflicting_propagation( %a: !wave.tensor<[@M] of f32>, @@ -1002,7 +1002,7 @@ normalform.module [#wave.normal_form] { // Check that this value overrides the one the write is initialized // with give its priority is only 1. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_priority_equal_values func.func @propagate_priority_equal_values( %a: !wave.tensor<[@M] of f32>, diff --git a/water/test/Dialect/Wave/infer-index-exprs.mlir b/water/test/Dialect/Wave/infer-index-exprs.mlir index 6d10725147..31c9b111fa 100644 --- a/water/test/Dialect/Wave/infer-index-exprs.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs.mlir @@ -1,6 +1,6 @@ // RUN: water-opt %s --water-wave-infer-index-exprs --allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s -// expected-error @below {{expects the root operation or its ancestor to guarantee the full_types normal form}} +// expected-error @below {{expects the root operation or its ancestor to guarantee the full_func_boundary normal form}} normalform.module [] { func.func @normal_form() { return @@ -9,7 +9,7 @@ normalform.module [] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_mma(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, %c: !wave.tensor<[@M, @N] of f32>) { @@ -22,7 +22,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // expected-error @below {{expected a hardware constraint}} func.func @simple_mma(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -36,7 +36,7 @@ normalform.module [#wave.normal_form] { // ----- 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // expected-error @below {{expected either waves_per_block in the hardware constraint or wave constraints on an ancestor op}} func.func @simple_mma(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -52,7 +52,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @simple_mma func.func @simple_mma(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -87,7 +87,7 @@ normalform.module [#wave.normal_form] { // ----- // Batched (3D) MMA: batch dimension B is leading; M, N, K indexing is as in 2D. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @batched_mma func.func @batched_mma(%a: !wave.tensor<[@B, @M, @K] of f16>, %b: !wave.tensor<[@B, @N, @K] of f16>, @@ -128,7 +128,7 @@ normalform.module [#wave.normal_form] { // Make sure the tiling constraints apply to batch dimensions. Note that there // are tests below for propagation across `wave.iterate`, here it is not the // point. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @batched_mma_in_a_loop func.func @batched_mma_in_a_loop(%a: !wave.tensor<[@B, @M, @K] of f16>, %b: !wave.tensor<[@B, @N, @K] of f16>, @@ -171,7 +171,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @simple_mma_with_reads_and_write func.func @simple_mma_with_reads_and_write(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -219,7 +219,7 @@ normalform.module [#wave.normal_form] { // MMA only, with wave constraints and workgroup constraints for M, N, K. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_with_workgroup_constraints func.func @mma_with_workgroup_constraints(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -255,7 +255,7 @@ normalform.module [#wave.normal_form] { // Two MMAs in a row. We need to store to the temporary storage and // load back because of the index (layout) change. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @mma_chain func.func @mma_chain(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -349,7 +349,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_32x32x8_f16 func.func @mma_32x32x8_f16(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -379,7 +379,7 @@ normalform.module [#wave.normal_form] { // ----- // Check that an unmapped dimension gets the default (0, 1, 1) index expression. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_32x32x8_f16_3d func.func @mma_32x32x8_f16_3d(%a: !wave.tensor<[@B, @M, @K] of f16>, %b: !wave.tensor<[@B, @N, @K] of f16>, @@ -412,7 +412,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_16x16x32_f16 func.func @mma_16x16x32_f16(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -441,7 +441,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_16x16x32_k4_f8 func.func @mma_16x16x32_k4_f8(%a: !wave.tensor<[@M, @K] of f8E5M2>, %b: !wave.tensor<[@N, @K] of f8E5M2>, @@ -470,7 +470,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_32x32x16_f16 func.func @mma_32x32x16_f16(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, @@ -499,7 +499,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK: @mma_32x32x16_k4_f8 func.func @mma_32x32x16_k4_f8(%a: !wave.tensor<[@M, @K] of f8E5M2>, %b: !wave.tensor<[@N, @K] of f8E5M2>, @@ -528,7 +528,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // Technically this is a matrix multiplication, but we really care about the iterators. 
func.func @iterate(%a: !wave.tensor<[@M, @K] of bf16, >, %b: !wave.tensor<[@N, @K] of bf16, >, @@ -585,7 +585,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @unregistered_noprop(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@N, @K] of f16>, %c: !wave.tensor<[@M, @N] of f32>) @@ -606,7 +606,7 @@ normalform.module [#wave.normal_form] { // Cannot propagate for only pure operations in absence of MMA/writes/reductions. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @simple_add(%a: !wave.tensor<[@M, @K] of f16>, %b: !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> @@ -626,7 +626,7 @@ normalform.module [#wave.normal_form] { // There is no inference source here so we can't infer. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @failed_to_infer_write(%src: !wave.tensor<[@M, @N] of f32>, %dst: !wave.tensor<[@M, @N] of f32, >) attributes { wave.constraints = [ #wave.hardware_constraint @@ -644,7 +644,7 @@ normalform.module [#wave.normal_form] { // There is no inference source here so we can't infer. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @failed_to_infer_binop(%a: !wave.tensor<[@M, @N] of f32>, %b: !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> attributes { wave.constraints = [ @@ -660,7 +660,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // expected-error @below {{unsupported constraint type: #wave.device_constraint}} func.func @empty() attributes { wave.constraints = [ @@ -675,7 +675,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @allocate_register_index( %a: !wave.tensor<[@M, @K] of f16> ) attributes { @@ -704,7 +704,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @tiling_constraints_in_iter( %x: !wave.tensor<[@M, @K] of f16> ) -> !wave.tensor<[@M, @N] of f32> attributes { @@ -756,7 +756,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @tiling_constraints_in_nested_iter( %x: !wave.tensor<[@M, @K] of f16> ) attributes { @@ -815,7 +815,7 @@ normalform.module [#wave.normal_form] { // ----- // Test broadcast propagates index expressions (identity propagation). -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @broadcast_index_exprs func.func @broadcast_index_exprs( %a: !wave.tensor<[@M, @K] of f16>, @@ -844,7 +844,7 @@ normalform.module [#wave.normal_form] { // ----- // Test broadcast gets index expression for broadcasted dimension via backward propagation. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @broadcast_index_exprs_backward func.func @broadcast_index_exprs_backward( %a: !wave.tensor<[@M, @K] of f16>, @@ -890,7 +890,7 @@ normalform.module [#wave.normal_form] { // Backward propagation from the MMA should give the register only M's index // expr, not N (the broadcast dimension). -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @broadcast_mma_back_to_register func.func @broadcast_mma_back_to_register( %a: !wave.tensor<[@M, @K] of f16>, @@ -925,7 +925,7 @@ normalform.module [#wave.normal_form] { // through permute to write. The permute swaps M and N dimensions, which should // swap their strides. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @permute_propagation func.func @permute_propagation( %a: !wave.tensor<[@M, @K] of f16>, @@ -967,7 +967,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_from_write func.func @propagate_from_write( %a: !wave.tensor<[@M, @N] of f32>, @@ -1007,7 +1007,7 @@ normalform.module [#wave.normal_form] { // ----- // Elements per thread provided on the op used instead of the value inferred from workgroup constraints. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_from_write_explicit_ept func.func @propagate_from_write_explicit_ept( %output: !wave.tensor<[@M, @N] of f32> @@ -1033,7 +1033,7 @@ normalform.module [#wave.normal_form] { // ----- // Elements per thread is used for the trailing dimension because its vector shape is no longer 1. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_from_write_vector_shape func.func @propagate_from_write_vector_shape( %output: !wave.tensor<[@M, @N] of f32> @@ -1061,7 +1061,7 @@ normalform.module [#wave.normal_form] { // Test that unmapped dimensions get default (0, 1, 1) index expressions // when there are no workgroup/wave/tiling constraints for them. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @unmapped_dimension_default func.func @unmapped_dimension_default( %a: !wave.tensor<[@B, @M, @N] of f32>, @@ -1099,7 +1099,7 @@ normalform.module [#wave.normal_form] { // All writes should establish index expressions with the same priority, // and the join should succeed since they agree. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @multiple_writes_consistent func.func @multiple_writes_consistent( %a: !wave.tensor<[@M, @N] of f32>, @@ -1147,7 +1147,7 @@ normalform.module [#wave.normal_form] { // ----- // Test write when all dimension symbols are absent from constraints. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @write_all_dimensions_unmapped func.func @write_all_dimensions_unmapped( %a: !wave.tensor<[@P, @Q] of f32>, @@ -1177,7 +1177,7 @@ normalform.module [#wave.normal_form] { // MMa index expression has higher priority than write. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @write_after_mma_priority func.func @write_after_mma_priority( %a: !wave.tensor<[@M, @K] of f16>, @@ -1212,7 +1212,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @batched_mma_with_reads_and_write func.func @batched_mma_with_reads_and_write(%a: !wave.tensor<[@B, @M, @K] of f16>, %b: !wave.tensor<[@B, @N, @K] of f16>, @@ -1273,7 +1273,7 @@ normalform.module [#wave.normal_form] { // Make sure we write index expr initialization doesn't crash // on rank-0 tensors. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @write_rank0_tensor func.func @write_rank0_tensor( %src: !wave.tensor<[] of f32>, diff --git a/water/test/Dialect/Wave/lower-wave-to-mlir-invalid.mlir b/water/test/Dialect/Wave/lower-wave-to-mlir-invalid.mlir index 7c99206b11..0757a70fce 100644 --- a/water/test/Dialect/Wave/lower-wave-to-mlir-invalid.mlir +++ b/water/test/Dialect/Wave/lower-wave-to-mlir-invalid.mlir @@ -1,6 +1,6 @@ // RUN: water-opt %s -allow-unregistered-dialect -lower-wave-to-mlir --split-input-file --verify-diagnostics -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { func.func @binary_ops_pattern_failure() { %cst = arith.constant 1.0 : f32 // expected-error @below {{wave dialect operation with no hyperparameters provided by any ancestor}} @@ -12,8 +12,8 @@ normalform.module [#wave.normal_form] { +// expected-error @below {{LowerWaveToMLIRPass pass expects the root operation or its ancestor to guarantee the index_exprs normal form}} +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { func.func 
@missing_index_exprs_normal_form(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128}>} { %result = wave.read %mem : (!wave.tensor<[@M, @N] of f16, >) -> vector<8xf16> @@ -23,7 +23,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // expected-error @+1 {{failed to convert starting at this operation}} func.func @write_pattern_failure(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128}>} { @@ -47,7 +47,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // expected-error @below {{failed to convert starting at this operation}} func.func @lower_read_non_innermost_dim(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128, N = 128}>} { diff --git a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir index 9ee121a633..7525ed85c1 100644 --- a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir +++ b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir @@ -1,6 +1,6 @@ // RUN: water-opt %s -allow-unregistered-dialect -lower-wave-to-mlir -lower-normalform-module --mlir-print-local-scope --split-input-file --verify-diagnostics | FileCheck %s -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { func.func @no_hyperparams() { %cst = arith.constant 0.0 : f32 // expected-error @below {{wave dialect operation with no hyperparameters provided by any ancestor}} @@ -11,7 +11,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, 
#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_exp2 func.func @lower_exp2() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.exp2 @@ -27,7 +27,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_reciprocal func.func @lower_reciprocal() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.reciprocal @@ -44,7 +44,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_mma_f16_f32 func.func @lower_mma_f16_f32() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst_f16 = arith.constant 0.0 : f16 @@ -71,7 +71,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_all_mmas func.func @lower_all_mmas() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // Common scalars @@ -200,7 +200,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_register func.func @lower_register() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.register @@ -221,7 +221,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_allocate_memref // Test lowering when wave.allocate already has MemRefType result // 
(after ResolveDistributedAllocations pass). @@ -238,7 +238,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_nested_register func.func @lower_nested_register(%cond: i1) attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.register @@ -259,7 +259,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_add func.func @lower_add() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.add @@ -287,7 +287,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_max func.func @lower_max() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.max @@ -305,7 +305,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_min func.func @lower_min() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.min @@ -323,7 +323,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_sub func.func @lower_sub() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.sub @@ -351,7 +351,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // 
CHECK-LABEL: func.func @lower_mul func.func @lower_mul() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.mul @@ -379,7 +379,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_div func.func @lower_div() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.div @@ -407,7 +407,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_apply_expr func.func @lower_apply_expr() -> vector<4xi32> attributes {wave.hyperparameters = #wave.hyperparameters<{N = 10}>} { %cst = arith.constant 42 : i32 @@ -434,7 +434,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_apply_comparisons func.func @lower_apply_comparisons() -> vector<4xi1> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst = arith.constant 42 : i32 @@ -463,7 +463,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_apply_expr_minmax func.func @lower_apply_expr_minmax() -> vector<4xi32> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst = arith.constant 42 : i32 @@ -482,7 +482,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_apply_expr_div func.func @lower_apply_expr_div() -> vector<4xi64> attributes {wave.hyperparameters = 
#wave.hyperparameters<{A = 15, B = 4}>} { // CHECK: %[[CST_A:.+]] = arith.constant dense<15> : vector<4xi64> @@ -503,7 +503,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_select func.func @lower_select() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.select @@ -523,7 +523,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_alloc_view func.func @lower_alloc_view() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 4, BLOCK_K = 28}>} { // CHECK: %[[BUFF:.*]] = memref.alloc() : memref<256xi8, #gpu.address_space> @@ -542,7 +542,7 @@ func.func @lower_alloc_view() attributes {wave.hyperparameters = #wave.hyperpara // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_alloc func.func @lower_alloc() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 4, BLOCK_K = 28}>} { // CHECK: memref.alloc() : memref<4x32xbf16, #gpu.address_space> @@ -555,7 +555,7 @@ func.func @lower_alloc() attributes {wave.hyperparameters = #wave.hyperparameter // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read func.func @lower_read(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128, N = 128}>} { %0 = wave.read %mem index [{ @@ -577,7 +577,7 @@ func.func @lower_read(%mem: !wave.tensor<[@M, @N] of f16, >) attributes // 
----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_non_innermost_dim func.func @lower_read_non_innermost_dim(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128, N = 128}>} { %0 = wave.read %mem index [{ @@ -599,7 +599,7 @@ func.func @lower_read_non_innermost_dim(%mem: !wave.tensor<[@M, @N] of f16, ] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_masked func.func @lower_read_masked(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 100, N = 50}>} { @@ -641,7 +641,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_masked_non_innermost_dim func.func @lower_read_masked_non_innermost_dim(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 100, N = 50}>} { @@ -663,7 +663,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_sparse_bounds func.func @lower_read_sparse_bounds(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 100, N = 64}>} { @@ -689,7 +689,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @read_with_vector_result func.func 
@read_with_vector_result(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128}>} { @@ -709,7 +709,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_write func.func @lower_write(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128, N = 128}>} { %cst = arith.constant 0.0 : f16 @@ -733,7 +733,7 @@ func.func @lower_write(%mem: !wave.tensor<[@M, @N] of f16, >) attributes // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_write_non_innermost func.func @lower_write_non_innermost(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128, N = 128}>} { %cst = arith.constant 0.0 : f16 @@ -757,7 +757,7 @@ func.func @lower_write_non_innermost(%mem: !wave.tensor<[@M, @N] of f16, ] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_float_extension func.func @lower_cast_float_extension() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.cast @@ -772,7 +772,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_float_truncation func.func @lower_cast_float_truncation() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.cast @@ -787,7 +787,7 @@ normalform.module [#wave.normal_form] { +normalform.module 
[#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_integer_extension func.func @lower_cast_integer_extension() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.cast @@ -802,7 +802,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_integer_truncation func.func @lower_cast_integer_truncation() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.cast @@ -817,7 +817,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_float_to_integer func.func @lower_cast_float_to_integer() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.cast @@ -832,7 +832,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_integer_to_float func.func @lower_cast_integer_to_float() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-NOT: wave.cast @@ -847,7 +847,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_cast_mixed_types func.func @lower_cast_mixed_types() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // Test f16 -> f32 extension @@ -871,7 +871,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, 
#wave.normal_form] { // CHECK-LABEL: func.func @lower_extract_static func.func @lower_extract_static() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK: %[[INPUT:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32> @@ -890,7 +890,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_extract_dynamic func.func @lower_extract_dynamic() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK-DAG: %[[INPUT:.*]] = arith.constant dense<2.000000e+00> : vector<8xf32> @@ -909,7 +909,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_extract_slice_constants func.func @lower_extract_slice_constants() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { // CHECK: %[[INPUT:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> @@ -930,7 +930,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_iterate func.func @lower_iterate(%init: vector<8xf32>) attributes { wave.hyperparameters = #wave.hyperparameters<{K = 128, BLOCK_K = 32}>, @@ -956,7 +956,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_iterate_with_operations func.func @lower_iterate_with_operations(%init: vector<4xf32>) attributes { wave.hyperparameters = #wave.hyperparameters<{K = 64, BLOCK_K = 16, M = 32}>, @@ -987,7 +987,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, 
#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @test_normalform_module_lowered func.func @test_normalform_module_lowered() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { return @@ -996,7 +996,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @func_with_wave_input // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf16, #gpu.address_space>) func.func @func_with_wave_input(%arg0: !wave.tensor<[@M, @N] of f16, >) @@ -1007,7 +1007,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @func_with_multiple_wave_inputs // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf16, #gpu.address_space>, %[[ARG1:.*]]: memref<32x32xf16, #gpu.address_space>, %[[ARG2:.*]]: memref<32x32xf16, #gpu.address_space>) func.func @func_with_multiple_wave_inputs( @@ -1021,7 +1021,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @func_with_mixed_input_types // CHECK-SAME: (%[[ARG0:.*]]: memref<64x64xf32, #gpu.address_space>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: memref<16x16xf16, #gpu.address_space>) func.func @func_with_mixed_input_types( @@ -1035,7 +1035,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @func_without_wave_tensors // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> memref<32x32xf32> func.func @func_without_wave_tensors(%arg0: f32, %arg1: i64) -> memref<32x32xf32> @@ -1047,7 +1047,7 @@ normalform.module [#wave.normal_form] { +normalform.module 
[#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_iterate_with_vector_iter_args func.func @lower_iterate_with_vector_iter_args() attributes { wave.hyperparameters = #wave.hyperparameters<{K = 128, BLOCK_K = 32}>, @@ -1076,7 +1076,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_iterate_multiple_vector_iter_args func.func @lower_iterate_multiple_vector_iter_args() attributes { wave.hyperparameters = #wave.hyperparameters<{I = 8}>, @@ -1105,7 +1105,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_iterate_with_vector_captures func.func @lower_iterate_with_vector_captures() attributes { wave.hyperparameters = #wave.hyperparameters<{I = 4}>, @@ -1135,7 +1135,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_permute_with_read_write func.func @lower_permute_with_read_write( %src: memref<64x64xf16, #gpu.address_space>, @@ -1167,7 +1167,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_write_memref func.func @lower_read_write_memref(%mem: memref<64x64xf16, #gpu.address_space>) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64}>} { @@ -1191,7 +1191,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: 
@lower_read_write_bounds_non_trailing_vectorized_dim func.func @lower_read_write_bounds_non_trailing_vectorized_dim( %mem: memref<100x64xf16, #gpu.address_space>) @@ -1234,7 +1234,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @ordered_syms_determines_dim_order func.func @ordered_syms_determines_dim_order(%mem: memref<64x32x128xf16, #gpu.address_space>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, K = 32, N = 128}>} { @@ -1263,7 +1263,7 @@ normalform.module [#wave.normal_form (d1, d0). -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_with_non_identity_mapping func.func @lower_read_with_non_identity_mapping(%mem: memref<64x32xf16, #gpu.address_space>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 64}>} { @@ -1291,7 +1291,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_write_with_non_identity_mapping func.func @lower_write_with_non_identity_mapping(%mem: memref<64x32xf16, #gpu.address_space>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 64}>} { @@ -1321,7 +1321,7 @@ normalform.module [#wave.normal_form (d1, d2, d0): value [K, M, N]. 
// Inverse is (d0, d1, d2) -> (d2, d0, d1), so memory order = [N, K, M] -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_read_with_non_self_inverse_mapping func.func @lower_read_with_non_self_inverse_mapping(%mem: memref<8x16x4xf16, #gpu.address_space>) attributes {wave.hyperparameters = #wave.hyperparameters<{K = 16, M = 4, N = 8}>} { @@ -1350,7 +1350,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_write_with_non_self_inverse_mapping func.func @lower_write_with_non_self_inverse_mapping(%mem: memref<8x16x4xf16, #gpu.address_space>) attributes {wave.hyperparameters = #wave.hyperparameters<{K = 16, M = 4, N = 8}>} { @@ -1382,7 +1382,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_symbolic_wave_tensor_read_with_non_self_inverse_mapping func.func @lower_symbolic_wave_tensor_read_with_non_self_inverse_mapping(%sym: !wave.tensor<[@N, @K, @M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{K = 16, M = 4, N = 8}>} { @@ -1410,7 +1410,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @lower_symbolic_wave_tensor_write_with_non_self_inverse_mapping func.func @lower_symbolic_wave_tensor_write_with_non_self_inverse_mapping(%sym: !wave.tensor<[@N, @K, @M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{K = 16, M = 4, N = 8}>} { @@ -1440,7 +1440,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, 
#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_shuffle_xor func.func @lower_shuffle_xor() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst = arith.constant 1.0 : f16 @@ -1457,7 +1457,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_shuffle_down func.func @lower_shuffle_down() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst = arith.constant 0.0 : f32 @@ -1474,7 +1474,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_shuffle_up func.func @lower_shuffle_up() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst = arith.constant 2.0 : bf16 @@ -1491,7 +1491,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_shuffle_idx func.func @lower_shuffle_idx() attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %cst = arith.constant 42 : i32 @@ -1510,7 +1510,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_sum // expected-warning @below {{unused hyperparameter: N}} func.func @lower_sum() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, N = 32}>} { @@ -1533,7 +1533,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func 
@lower_max_element // expected-warning @below {{unused hyperparameter: M}} func.func @lower_max_element() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, N = 32}>} { @@ -1555,7 +1555,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_sum_integer func.func @lower_sum_integer() attributes {wave.hyperparameters = #wave.hyperparameters<{K = 16}>} { %cst_input = arith.constant 1 : i32 @@ -1576,7 +1576,7 @@ normalform.module [#wave.normal_form - uses gpu.all_reduce instead of gpu.subgroup_reduce. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_sum_block // expected-warning @below {{unused hyperparameter: N}} func.func @lower_sum_block() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, N = 32}>} { @@ -1599,7 +1599,7 @@ normalform.module [#wave.normal_form - uses gpu.all_reduce instead of gpu.subgroup_reduce. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_max_element_block // expected-warning @below {{unused hyperparameter: M}} func.func @lower_max_element_block() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, N = 32}>} { @@ -1622,7 +1622,7 @@ normalform.module [#wave.normal_form but hardware constraint specifies only one wave per block. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @warn_block_reduction_single_wave func.func @warn_block_reduction_single_wave() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64}>, @@ -1640,7 +1640,7 @@ normalform.module [#wave.normal_form but hardware constraint specifies multiple waves per block. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @warn_wave_reduction_multiple_waves func.func @warn_wave_reduction_multiple_waves() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64}>, @@ -1658,7 +1658,7 @@ normalform.module [#wave.normal_form and single wave. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @warn_block_max_single_wave func.func @warn_block_max_single_wave() attributes {wave.hyperparameters = #wave.hyperparameters<{N = 32}>, @@ -1676,7 +1676,7 @@ normalform.module [#wave.normal_form and multiple waves. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @warn_wave_max_multiple_waves func.func @warn_wave_max_multiple_waves() attributes {wave.hyperparameters = #wave.hyperparameters<{N = 32}>, @@ -1694,7 +1694,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_self_index_unit_stride func.func @lower_self_index_unit_stride() -> vector<4xi32> attributes {wave.hyperparameters = #wave.hyperparameters<{N = 64}>} { // CHECK-NOT: wave.self_index @@ -1713,7 +1713,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_self_index_with_stride func.func @lower_self_index_with_stride() -> vector<4xi64> attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} { // CHECK-NOT: wave.self_index @@ -1736,7 +1736,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_self_index_complex_start func.func @lower_self_index_complex_start() -> vector<8xi32> attributes {wave.hyperparameters = #wave.hyperparameters<{M = 256, BLOCK_M = 64}>} { // CHECK-NOT: wave.self_index @@ -1756,7 +1756,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_reshape_single_elem_vectors_to_vector func.func @lower_reshape_single_elem_vectors_to_vector() -> vector<3xf32> attributes {wave.hyperparameters = 
#wave.hyperparameters<{}>} { %c0 = arith.constant 0.0 : f32 @@ -1778,7 +1778,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_reshape_concat_vectors func.func @lower_reshape_concat_vectors() -> vector<8xf32> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %c0 = arith.constant 0.0 : f32 @@ -1799,7 +1799,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_reshape_concat_vectors_i16 func.func @lower_reshape_concat_vectors_i16() -> vector<8xi16> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %c0 = arith.constant 0 : i16 @@ -1821,7 +1821,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_reshape_extract_slice func.func @lower_reshape_extract_slice() -> vector<4xf32> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %c0 = arith.constant 0.0 : f32 @@ -1836,7 +1836,7 @@ normalform.module [#wave.normal_form to vector<8> produces vector.broadcast. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_broadcast_vector1_to_vector8 func.func @lower_broadcast_vector1_to_vector8() -> vector<8xf32> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %c0 = arith.constant 0.0 : f32 @@ -1852,7 +1852,7 @@ normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @lower_broadcast_identity func.func @lower_broadcast_identity() -> vector<4xf32> attributes {wave.hyperparameters = #wave.hyperparameters<{}>} { %c0 = arith.constant 1.0 : f32 diff --git a/water/test/Dialect/Wave/normal-forms.mlir b/water/test/Dialect/Wave/normal-forms.mlir index 0e03c57be6..58d2de65da 100644 --- a/water/test/Dialect/Wave/normal-forms.mlir +++ b/water/test/Dialect/Wave/normal-forms.mlir @@ -1,6 +1,6 @@ // RUN: water-opt %s --water-wave-infer-types --water-wave-propagate-elements-per-thread | FileCheck %s -// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form] normalform.module [#wave.normal_form] { func.func @test_multiple_forms_in_sequence(%mem: !wave.tensor<[@M] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %0 = arith.constant 0.0 : f32 diff --git a/water/test/Dialect/Wave/op-modification.mlir b/water/test/Dialect/Wave/op-modification.mlir index 66126a18ae..2c73701baa 100644 --- a/water/test/Dialect/Wave/op-modification.mlir +++ b/water/test/Dialect/Wave/op-modification.mlir @@ -2,7 +2,7 @@ // Technically these are matrix multiplications, but we really care about the iterators. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @make_isolated // CHECK-SAME: %[[ARG_A:.+]]: !wave.tensor<[@M, @K] of bf16, > // CHECK-SAME: %[[ARG_B:.+]]: !wave.tensor<[@N, @K] of bf16, > @@ -38,7 +38,7 @@ normalform.module [#wave.normal_form] { } } -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @make_non_isolated // CHECK-SAME: %[[ARG_A:.+]]: !wave.tensor<[@M, @K] of bf16, > // CHECK-SAME: %[[ARG_B:.+]]: !wave.tensor<[@N, @K] of bf16, > diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index c4c0c9a456..f661d4ffe4 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -49,7 +49,7 @@ func.func @mma_1d(%lhs: !wave.tensor<[@A] of f16>, %rhs: !wave.tensor<[@B] of f1 // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @mma_3d_mismatch(%a: !wave.tensor<[@M, @K, @B] of f16>, %b: !wave.tensor<[@N, @K, @B] of f16>, %c: !wave.tensor<[@M, @N, @B] of f32>) { @@ -515,7 +515,7 @@ module attributes { wave.hyperparameters = #wave.hyperparameters<{A = 42, C = 43 // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @index_key_unspecified(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128}>} { // expected-error @below {{attribute "index" uses symbolic value "N" not provided as a hyperparameter}} @@ -538,7 +538,7 @@ func.func @read_element_type_mismatch(%mem: memref<64x64xf16, #gpu.address_space // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @index_value_unspecified(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 
256}>} { // expected-error @below {{attribute "index" uses symbolic value #wave.symbol<"BLOCK_M"> not provided as a hyperparameter}} @@ -553,7 +553,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @index_length_mismatch(%mem: !wave.tensor<[@M] of f16, >) { // expected-error @below {{index attribute length (0) does not match the number of index expression values (1)}} %0 = wave.read %mem index [] : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > @@ -563,7 +563,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @read_index_multiple_dicts(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} { // expected-error @below {{index attribute length (2) does not match the number of index expression values (1)}} @@ -575,7 +575,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @elements_per_thread_mismatch(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} { // expected-error @below {{expected result vector type to have the number of elements per thread matching the attribute (4), got 42}} diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index a8e97f0c57..d2068517ab 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -563,7 +563,7 @@ func.func @permute_with_index(%arg0: !wave.tensor<[@M, @N] of f16, >) // ----- // Test wave.iterate and wave.yield with vector types -normalform.module [#wave.normal_form] attributes {wave.hyperparameters = #wave.hyperparameters<{I = 4}>} { +normalform.module [#wave.normal_form, #wave.normal_form] attributes 
{wave.hyperparameters = #wave.hyperparameters<{I = 4}>} { // Test that wave.iterate supports vector types in both iter_args and captures // CHECK-LABEL: @iterate_vector_types diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index 0d92ab29cd..662750727d 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -1,6 +1,6 @@ // RUN: water-opt %s --water-wave-propagate-elements-per-thread --split-input-file --verify-diagnostics --allow-unregistered-dialect | FileCheck %s -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @register_alone() attributes {wave.hyperparameters = #wave.hyperparameters<{Y = 10, Z = 1}>, wave.constraints = []} { %cst = arith.constant 0.0 : f32 // expected-error @below {{couldn't identify elements per thread for result #0}} @@ -11,7 +11,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @register_add() attributes {wave.hyperparameters = #wave.hyperparameters<{Y = 10, Z = 1}>, wave.constraints = []} { %cst = arith.constant 0.0 : f32 // expected-error @below {{couldn't identify elements per thread for result #0}} @@ -23,8 +23,8 @@ normalform.module [#wave.normal_form] { // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_register_write func.func @propagate_register_write(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 @@ -42,8 +42,8 @@ func.func @propagate_register_write(%mem: !wave.tensor<[@M] of f16, >) a // Register per 
thread is the non-unit second element of the index map, // propagate that in absence of explicit elements_per_thread. // -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_register_write_index_expr func.func @propagate_register_write_index_expr(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 @@ -58,7 +58,7 @@ func.func @propagate_register_write_index_expr(%mem: !wave.tensor<[@M] of f16, < // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @propagate_register_write_index_expr_conflict(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 %reg = wave.register %cst index [{M : <[] -> (, 4, )>}] : !wave.tensor<[@M] of f16, > @@ -76,7 +76,7 @@ func.func @propagate_register_write_index_expr_conflict(%mem: !wave.tensor<[@M] // initialization process may have assigned an EPT value to an operand when initializing dataflow for // its defining operation, making it the only scenario in which the conflict error may be seen // during initialization and not at some later point. -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @mma_operands_from_reads( %mem_a: !wave.tensor<[@M, @K] of f16, >, %mem_b: !wave.tensor<[@N, @K] of f16, >, @@ -103,7 +103,7 @@ func.func @mma_operands_from_reads( // Null hyperparameters: step uses a symbol so it cannot be evaluated; pass must // not crash and should report that EPT could not be identified. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @null_hyperparams_symbol_step(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.constraints = []} { %cst = arith.constant 0.0 : f16 // expected-error @below {{couldn't identify elements per thread for result #0}} @@ -118,8 +118,8 @@ normalform.module [#wave.normal_form] { // Null hyperparameters but constant step: step has no symbols so it is // evaluated without hyperparams and EPT is inferred. -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @null_hyperparams_constant_step func.func @null_hyperparams_constant_step(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.constraints = []} { %cst = arith.constant 0.0 : f16 @@ -135,7 +135,7 @@ normalform.module [#wave.normal_form] { // ----- // Index missing dimension N for result type [M, N]; pass must report missing dimensions. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @index_missing_dimension(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 // expected-error @below {{expected index to contain entries for all result #0 dimensions}} @@ -147,8 +147,8 @@ normalform.module [#wave.normal_form] { // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_backward_from_write func.func @propagate_backward_from_write(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 @@ -173,8 +173,8 @@ func.func @propagate_backward_from_write(%mem: !wave.tensor<[@M] of f16, -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_forward_from_read func.func @propagate_forward_from_read(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, >) -> vector<4xf16> @@ -195,8 +195,8 @@ func.func @propagate_forward_from_read(%mem: !wave.tensor<[@M] of f16, > // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @propagate_via_identity_rhs func.func @propagate_via_identity_rhs(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { // CHECK: wave.read 
{{.*}} : (!wave.tensor<[@M] of f16, >) -> vector<4xf16> @@ -217,7 +217,7 @@ func.func @propagate_via_identity_rhs(%mem: !wave.tensor<[@M] of f16, >) // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @missing_elements_per_thread(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { // expected-error @below {{couldn't identify elements per thread for result #0}} %reg = wave.read %mem : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > @@ -227,7 +227,7 @@ func.func @missing_elements_per_thread(%mem: !wave.tensor<[@M] of f16, > // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @read_write_conflict(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > // expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}} @@ -238,7 +238,7 @@ func.func @read_write_conflict(%mem: !wave.tensor<[@M] of f16, >) attrib // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @read_write_conflict_indirect(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > %val = wave.exp2 %reg : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > @@ -250,7 +250,7 @@ func.func @read_write_conflict_indirect(%mem: !wave.tensor<[@M] of f16, // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // 
CHECK-LABEL: func.func @alloc_is_harmless func.func @alloc_is_harmless() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 4, BLOCK_K = 28, M = 128, N=128, K= 128}>, wave.constraints = []} { // CHECK: wave.allocate @@ -267,7 +267,7 @@ func.func @alloc_is_harmless() attributes {wave.hyperparameters = #wave.hyperpar // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @unsupported_op() attributes {wave.hyperparameters = #wave.hyperparameters<{Y = 100, Z = 200}>, wave.constraints = []} { %cst = arith.constant 42.0 : f32 %reg = wave.register %cst : !wave.tensor<[@Y, @Z] of f32, > @@ -279,8 +279,8 @@ func.func @unsupported_op() attributes {wave.hyperparameters = #wave.hyperparame // ----- -// CHECK: normalform.module [#wave.normal_form] -normalform.module [#wave.normal_form] { +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form] +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @test_normal_form_conditions(%mem: !wave.tensor<[@M] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %0 = arith.constant 0.0 : f32 %reg = wave.register %0 : !wave.tensor<[@M] of f32, > @@ -291,7 +291,7 @@ normalform.module [#wave.normal_form] { // ----- -// expected-error @below {{pass expects the root operation or its ancestor to guarantee the full_types normal form}} +// expected-error @below {{pass expects the root operation or its ancestor to guarantee the full_func_boundary normal form}} normalform.module [] { func.func @normal_form_missing() { return @@ -300,8 +300,8 @@ normalform.module [] { // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @memory_resharding_allowed func.func @memory_resharding_allowed(%mem: !wave.tensor<[@M] 
of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 @@ -323,8 +323,8 @@ func.func @memory_resharding_allowed(%mem: !wave.tensor<[@M] of f16, >) // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @write_backward_propagation func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { %cst = arith.constant 0.0 : f16 @@ -342,8 +342,8 @@ func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, >) // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @read_register_propagation func.func @read_register_propagation(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { // ReadOp should only propagate to its register result, not validate memory. 
@@ -360,8 +360,8 @@ func.func @read_register_propagation(%mem: !wave.tensor<[@M] of f16, >) // ----- -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @mma_compute_lhs_from_rhs func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // LHS without elements_per_thread - will be computed from RHS + MMA constraints. @@ -386,8 +386,8 @@ func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @mma_compute_rhs_from_lhs func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // LHS properly initialized through read operation. 
@@ -413,8 +413,8 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @mma_compute_both_lhs_rhs func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@N, @K] of f16, >, %mem3: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // Both LHS and RHS without elements_per_thread - can compute from MMA formulas. @@ -440,7 +440,7 @@ normalform.module [#wave.normal_form] { // ----- // Test MMA error when operand has wrong elements_per_thread -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // LHS with wrong elements_per_thread (should be 8, not 4). 
%lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, > @@ -460,7 +460,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: func.func @batched_mma func.func @batched_mma(%mem1: !wave.tensor<[@B, @M, @K] of f16, >, %mem2: !wave.tensor<[@B, @N, @K] of f16, >, %mem3: !wave.tensor<[@B, @M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{B = 2, M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {B = 1, M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { %lhs_init = arith.constant 0.0 : f16 @@ -480,7 +480,7 @@ normalform.module [#wave.normal_form] { // ----- // Test iterate working with vectors after PropagateElementsPerThread conversion -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @iterate_with_vectors_after_ept func.func @iterate_with_vectors_after_ept(%mem: !wave.tensor<[@M] of f32, >) @@ -512,7 +512,7 @@ normalform.module [#wave.normal_form] { // ----- // Test extract_slice propagates elements_per_thread as a no-op. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @extract_slice_propagates_ept func.func @extract_slice_propagates_ept(%mem: !wave.tensor<[@M, @N] of f32, >) @@ -539,7 +539,7 @@ normalform.module [#wave.normal_form] { // ----- // CHECK-LABEL: @reduction_propagation_forward -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @reduction_propagation_forward(%mem: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} { @@ -557,7 +557,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @reduction_propagation_backward func.func @reduction_propagation_backward( %mem: !wave.tensor<[@M, @N] of f32, >, @@ -581,7 +581,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @reduction_propagation_tx func.func @reduction_propagation_tx( %mem: !wave.tensor<[@M, @N] of f32, >, @@ -607,7 +607,7 @@ normalform.module [#wave.normal_form] { // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @reduction_propagation_tx_conflict( %mem: !wave.tensor<[@M, @N] of f32, >, %result_mem: !wave.tensor<[@M] of f32, >) @@ -628,7 +628,7 @@ normalform.module [#wave.normal_form] { // ----- // Test broadcast doesn't propagate EPT. 
-normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @broadcast_no_propagation(%mem: !wave.tensor<[@M] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} { @@ -644,8 +644,8 @@ normalform.module [#wave.normal_form] { // ----- // Reshape forward propagation: single operand, num_slices=2 -> result EPT = operand EPT / 2. -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @reshape_forward_single_operand_num_slices func.func @reshape_forward_single_operand_num_slices( %mem: !wave.tensor<[@M] of f32, >, @@ -668,8 +668,8 @@ func.func @reshape_forward_single_operand_num_slices( // ----- // Reshape backward propagation: write fixes result EPT, operand gets result EPT * num_slices. -// CHECK: #wave.normal_form -normalform.module [#wave.normal_form] { +// CHECK: #wave.normal_form, #wave.normal_form, #wave.normal_form +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @reshape_backward_single_result_num_slices func.func @reshape_backward_single_result_num_slices( %mem: !wave.tensor<[@M] of f32, >, @@ -695,7 +695,7 @@ func.func @reshape_backward_single_result_num_slices( // ----- -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @reshape_forward_multiple_operands func.func @reshape_forward_multiple_operands( %mem1: !wave.tensor<[@M] of f32, >, @@ -717,7 +717,7 @@ func.func @reshape_forward_multiple_operands( // ----- // Backward propagation: write fixes result EPT, reshape with two operands, each operand gets result EPT / 2 -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { // CHECK-LABEL: @reshape_backward_multiple_operands func.func 
@reshape_backward_multiple_operands( %mem1: !wave.tensor<[@M] of f32, >, diff --git a/water/test/Dialect/Wave/python_bindings.py b/water/test/Dialect/Wave/python_bindings.py index dfcba681b2..a95f1664ad 100644 --- a/water/test/Dialect/Wave/python_bindings.py +++ b/water/test/Dialect/Wave/python_bindings.py @@ -418,16 +418,15 @@ # CHECK: #wave.expr_list<[#wave.symbol<"M">, #wave.symbol<"BLOCK_M">] -> (M floordiv BLOCK_M)> print(tiling_constr.tile_size) - # CHECK: #wave.normal_form - normal_form_attr = wave.WaveNormalFormAttr.get(wave.WaveNormalForm.None_) + # CHECK: #wave.normal_form + normal_form_attr = wave.WaveNormalFormAttr.get( + wave.WaveNormalForm.FunctionBoundarySpecified + ) print(normal_form_attr) - # CHECK: WaveNormalForm.None_ + # CHECK: WaveNormalForm.FunctionBoundarySpecified print(normal_form_attr.value) - # CHECK: #wave.normal_form - print(wave.WaveNormalFormAttr.get(wave.WaveNormalForm.FunctionBoundarySpecified)) - # CHECK: #wave.normal_form print(wave.WaveNormalFormAttr.get(wave.WaveNormalForm.OpTypesSpecified)) @@ -437,8 +436,11 @@ # CHECK: #wave.normal_form print(wave.WaveNormalFormAttr.get(wave.WaveNormalForm.MemoryOnlyTypes)) - # CHECK: #wave.normal_form - print(wave.WaveNormalFormAttr.get(wave.WaveNormalForm.AllTypesSpecified)) + # CHECK: #wave.normal_form + print(wave.WaveNormalFormAttr.get(wave.WaveNormalForm.ResolvedAllocations)) + + # CHECK: #wave.normal_form + print(wave.WaveNormalFormAttr.get(wave.WaveNormalForm.OrderedSymsSpecified)) try: wave.WaveNormalFormAttr.get(100) diff --git a/water/test/Dialect/Wave/resolve-distributed-allocations.mlir b/water/test/Dialect/Wave/resolve-distributed-allocations.mlir index 2d9d8d1c01..aff9d2527e 100644 --- a/water/test/Dialect/Wave/resolve-distributed-allocations.mlir +++ b/water/test/Dialect/Wave/resolve-distributed-allocations.mlir @@ -1,9 +1,9 @@ // RUN: water-opt %s --water-wave-resolve-distributed-allocations --split-input-file | FileCheck %s // Test basic shared memory allocation resolution. 
-// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] // CHECK-LABEL: func.func @resolve_basic_alloc -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @resolve_basic_alloc() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_K = 32, M = 128, K = 64}>} { // CHECK: wave.allocate {distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K)>} // CHECK-SAME: memref<64x32xbf16, #gpu.address_space> @@ -16,9 +16,9 @@ normalform.module [#wave.normal_form] { // ----- // Test allocation with expression in distributed shape. -// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] // CHECK-LABEL: func.func @resolve_alloc_with_expr -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @resolve_alloc_with_expr() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_K = 28, M = 128, K = 64}>} { // CHECK: wave.allocate {distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K + 4)>} // CHECK-SAME: memref<64x32xbf16, #gpu.address_space> @@ -31,9 +31,9 @@ normalform.module [#wave.normal_form] { // ----- // Test that child allocation is also resolved. 
-// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] // CHECK-LABEL: func.func @resolve_child_alloc -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @resolve_child_alloc() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_K = 32, M = 128, K = 64, SIZE = 8192}>} { // CHECK: %[[PARENT:.*]] = wave.allocate {distributed_shape = #wave.expr_list<[] -> (8192)>} // CHECK-SAME: memref<8192xi8, #gpu.address_space> @@ -53,9 +53,9 @@ normalform.module [#wave.normal_form] { // Test that padding increases last dimension of resolved memref. // BLOCK_K = 32, padding = 4, so last dim should be 32 + 4 = 36. -// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] // CHECK-LABEL: func.func @resolve_alloc_with_padding -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @resolve_alloc_with_padding() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_K = 32, M = 128, K = 64}>} { // CHECK: wave.allocate {distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K)>, padding = 4 : i64} // CHECK-SAME: memref<64x36xbf16, #gpu.address_space> @@ -69,9 +69,9 @@ normalform.module [#wave.normal_form] { // Test that tail_padding does NOT affect the resolved memref shape. // BLOCK_K = 32, tail_padding = 128, shape should still be 64x32. 
-// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] // CHECK-LABEL: func.func @resolve_alloc_with_tail_padding -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @resolve_alloc_with_tail_padding() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_K = 32, M = 128, K = 64}>} { // CHECK: wave.allocate {distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K)>, tail_padding = 128 : i64} // CHECK-SAME: memref<64x32xbf16, #gpu.address_space> @@ -85,9 +85,9 @@ normalform.module [#wave.normal_form] { // Test both padding and tail_padding together. // BLOCK_K = 32, padding = 4 -> last dim = 36, tail_padding = 128 -> no shape change. -// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] // CHECK-LABEL: func.func @resolve_alloc_with_both_padding -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @resolve_alloc_with_both_padding() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_K = 32, M = 128, K = 64}>} { // CHECK: wave.allocate {distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K)>, padding = 4 : i64, tail_padding = 128 : i64} // CHECK-SAME: memref<64x36xbf16, #gpu.address_space> @@ -100,8 +100,8 @@ normalform.module [#wave.normal_form] { // ----- // Test that resolved_allocations is set even when there are no allocations. 
-// CHECK: normalform.module [#wave.normal_form] -normalform.module [#wave.normal_form] { +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form, #wave.normal_form, #wave.normal_form] +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @no_allocations() { return } @@ -110,7 +110,7 @@ normalform.module [#wave.normal_form] { // ----- // Test that resolved_allocations is set on a module without existing normal form. -// CHECK: normalform.module [#wave.normal_form] +// CHECK: normalform.module [#wave.normal_form, #wave.normal_form] normalform.module [] { func.func @resolve_without_existing_normal_form() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 32}>} { // CHECK: wave.allocate @@ -127,7 +127,7 @@ normalform.module [] { // The tensor has shape [@M, @K, @N] - ordered_syms should preserve this order, // NOT the alphabetical order (K, M, N) that DictionaryAttr would produce. // CHECK-LABEL: func.func @read_ordered_syms -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @read_ordered_syms(%mem: !wave.tensor<[@M, @K, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, K = 32, N = 128, T0 = 64}>} { // CHECK: wave.read @@ -145,7 +145,7 @@ normalform.module [#wave.normal_form] { // Test that ordered_syms is set on wave.write ops. 
// CHECK-LABEL: func.func @write_ordered_syms -normalform.module [#wave.normal_form] { +normalform.module [#wave.normal_form, #wave.normal_form] { func.func @write_ordered_syms(%val: vector<8xf16>, %mem: !wave.tensor<[@M, @K, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 64, K = 32, N = 128, T0 = 64}>} { // CHECK: wave.write diff --git a/wave_lang/kernel/_support/indexing.py b/wave_lang/kernel/_support/indexing.py index b740a6ff4c..61a02bacca 100644 --- a/wave_lang/kernel/_support/indexing.py +++ b/wave_lang/kernel/_support/indexing.py @@ -83,6 +83,61 @@ def subs_idxc( return IndexingContext.current().subs_expr(input) +def _resolve_chained_subs( + subs: dict[IndexSymbol, int | IndexSymbol], +) -> dict[IndexSymbol, int | IndexSymbol]: + """Resolve chained dependencies in a substitution dictionary. + + When a substitution dict has ``{K: 8192, K_SCALE: K // 32}``, a single + simultaneous substitution pass replaces ``K_SCALE`` with ``K // 32`` + but leaves the ``K`` inside the replacement unresolved. + + This function processes entries in topological (dependency) order: + entries whose values don't reference other keys are resolved first, + then their concrete values are substituted into the remaining entries. + + Only the values are updated; the keys (symbols being defined) are + never modified. 
+ """ + all_keys = set(subs.keys()) + resolved: dict = {} + pending = dict(subs) + + while pending: + ready = [] + for key, val in pending.items(): + if not isinstance(val, sympy.Basic): + ready.append(key) + continue + deps = (val.free_symbols & all_keys) - {key} - set(resolved.keys()) + if not deps: + ready.append(key) + + if not ready: + break + + for key in ready: + val = pending.pop(key) + if isinstance(val, sympy.Basic) and resolved: + val = piecewise_aware_subs(val, resolved) + resolved[key] = val + + if pending: + import warnings + cycle_keys = sorted(str(k) for k in pending.keys()) + warnings.warn( + f"_resolve_chained_subs: circular dependency among" + f" {cycle_keys} — substitution may be incomplete", + stacklevel=2, + ) + for key, val in pending.items(): + if isinstance(val, sympy.Basic) and resolved: + val = piecewise_aware_subs(val, resolved) + resolved[key] = val + + return resolved + + def is_literal(input: IndexSymbol | int) -> bool: """ Check if input is a literal number value. 
@@ -153,7 +208,7 @@ def __str__(self): ) def set_subs(self, subs: dict[IndexSymbol, int | IndexSymbol]): - self.subs = copy.deepcopy(subs) + self.subs = _resolve_chained_subs(copy.deepcopy(subs)) self.cached_subs = {} def subs_expr(self, expr: IndexExpr) -> IndexExpr: diff --git a/wave_lang/kernel/_support/tracing.py b/wave_lang/kernel/_support/tracing.py index b0bf0cf5f0..35b46f7100 100644 --- a/wave_lang/kernel/_support/tracing.py +++ b/wave_lang/kernel/_support/tracing.py @@ -186,6 +186,8 @@ def get_subgraph(self, name: str) -> fx.Graph: def add_subgraph(self, name: str, graph: fx.Graph): self.region_graph.subgraphs[name] = graph + if name != self.root_graph: + self.get_root_graph().subgraphs[name] = graph def get_root_graph(self) -> fx.Graph: return self.get_subgraph(self.root_graph) diff --git a/wave_lang/kernel/compiler/kernel_codegen.py b/wave_lang/kernel/compiler/kernel_codegen.py index 6321dc3491..8f307cf216 100644 --- a/wave_lang/kernel/compiler/kernel_codegen.py +++ b/wave_lang/kernel/compiler/kernel_codegen.py @@ -411,9 +411,13 @@ def get_users_recursive(node, parent=None): ret.append(user) continue - if custom.subgraph_name not in graph.subgraphs: + # All subgraphs (including nested ones) are registered in a + # flat dict on the root graph, so this lookup works at any + # nesting depth. 
+ root_subgraphs = custom.get_root_graph().subgraphs + if custom.subgraph_name not in root_subgraphs: raise KeyError(custom.subgraph_name) - subgraph = graph.subgraphs[custom.subgraph_name] + subgraph = root_subgraphs[custom.subgraph_name] nested_placeholders = filter_fx_graph(subgraph, is_placeholder) for nested in nested_placeholders: captured = get_custom(nested).get_captured_fx_node() diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index cb054af271..fd5bc47d36 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1476,8 +1476,7 @@ def handle_conditional(emitter: WaveEmitter, node: fx.Node): for i, iter_arg in enumerate(iter_args): emitter.bind_node_proxy(iter_arg, flat_else_return[i]) - captured_vars: list[fx.Node] = get_custom(node).captured_vars(subgraph) - for root_v, subgraph_v in zip(implicit_capture, captured_vars): + for root_v, subgraph_v in get_custom(node).get_capture_bindings(subgraph): emitter._node_values[subgraph_v] = emitter.lookup_node_values(root_v) # Emit the subgraph. @@ -1568,13 +1567,7 @@ def handle_iterate(emitter: WaveEmitter, node: fx.Node): ) for i, v in enumerate(forOp.inner_iter_args): emitter.bind_node_proxy(iter_args[i], IRProxyValue(v)) - captured_vars: list[fx.Node] = get_custom(node).captured_vars(subgraph) - for subgraph_v in captured_vars: - if "lifted" not in subgraph_v.meta: - raise ValueError( - "Cannot find subgraph_v's corresponding value in the root graph." - ) - root_v = subgraph_v.meta["lifted"] + for root_v, subgraph_v in get_custom(node).get_capture_bindings(subgraph): emitter._node_values[subgraph_v] = emitter.lookup_node_values(root_v) # Emit the subgraph. 
return_values = emitter._emit_graph(subgraph) @@ -1661,9 +1654,7 @@ def handle_iterate_while(emitter: WaveEmitter, node: fx.Node): with InsertionPoint(whileOp.after.blocks[0]): subgraph = emitter.trace.get_subgraph(subgraph) # Map the captured variables from the root graph to the subgraph - for root_v, subgraph_v in zip( - implicit_capture, get_custom(node).captured_vars(subgraph) - ): + for root_v, subgraph_v in get_custom(node).get_capture_bindings(subgraph): emitter._node_values[subgraph_v] = emitter.lookup_node_values(root_v) # Map the iteration variable diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py old mode 100644 new mode 100755 index 76386a8793..ad657294fd --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -61,8 +61,13 @@ MemoryAccessFlags, ) from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index -from ...wave.utils.mapping_utils import transform_index_on_mapping -from ...wave.utils.symbol_utils import safe_subs, simplify +from ...wave.utils.mapping_utils import ( + linearize_dims, + mem_simplify, + transform_index_on_mapping, +) +from ...wave.assumptions import get_divisibility_subs +from ...wave.utils.symbol_utils import safe_subs, simplify, extract_iv from .emitter import ( WaveEmitter, add_emitter_subs, @@ -429,18 +434,36 @@ def _compute_branchless_valid_bytes( We emit: cond = gen_sympy_index(guard_condition) # index type, nonzero=true - real_valid = compute_static_validBytes() + real_valid = actual_buffer_size_bytes # NOT 0x7FFFFFFE validBytes = select(cond != 0, real_valid, 0) When the condition is false (last iterations), validBytes=0 makes the SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op. + + When the condition is true, NUM_RECORDS=real buffer size so the + hardware clamps any OOB flat addresses to return 0 — no per-load + software bounds checking needed. 
""" uint64 = IntegerType.get_signless(64) total_bytes = _compute_total_valid_bytes( elem_type, symbolic_shape, use_real_bounds=True ) - real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) + hw_max = _valid_bytes_buffer(elem_type) + if total_bytes == hw_max and symbolic_shape is not None: + # Static path returned the hardware-max fallback — shape is dynamic. + # Compute actual buffer size at runtime so num_records provides + # real bounds clamping (matching AITER's approach). + elem_bytes = _elem_bytes(elem_type) + total_bytes_expr = sympy.prod(s for s in symbolic_shape) * elem_bytes + subs_map = add_emitter_subs(emitter) + real_valid_index = gen_sympy_index(subs_map, total_bytes_expr) + real_valid = arith_d.index_cast(uint64, real_valid_index) + else: + real_valid = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) + zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) cond_val = gen_sympy_index(add_emitter_subs(emitter), guard_condition) @@ -470,15 +493,24 @@ def _compute_valid_bytes( uint64 = IntegerType.get_signless(64) if use_real_bounds: + hw_max = _valid_bytes_buffer(elem_type) + if total_bytes == hw_max and symbolic_shape is not None: + elem_bytes = _elem_bytes(elem_type) + total_bytes_expr = sympy.prod(s for s in symbolic_shape) * elem_bytes + subs_map = add_emitter_subs(emitter) + real_valid_index = gen_sympy_index(subs_map, total_bytes_expr) + total_val = arith_d.index_cast(uint64, real_valid_index) + else: + total_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) metadata = memref_d.extract_strided_metadata(ptr) offset_elements = metadata[1] offset_bytes = arith_d.index_cast(uint64, offset_elements) - # Use _elem_bytes to avoid zero offset_bytes for sub-byte types (e.g. mxfp4). 
elem_bytes_val = arith_d.constant( uint64, get_constant_attr(_elem_bytes(elem_type), uint64) ) offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) - total_val = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) return arith_d.subi(total_val, offset_bytes) return arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) @@ -771,12 +803,73 @@ def extract(vec, ind): return +def _cancel_floordiv_mod_linearize( + dim_exprs: list[sympy.Expr], + strides: list[sympy.Expr], +) -> sympy.Expr: + """Compute ``sum(e_i * s_i)`` while cancelling floor/Mod pairs. + + Delegates to :func:`linearize_dims` which expands ``Mod(x, d)`` + into ``x - d*floor(x/d)`` so that the matching ``floor`` terms + cancel algebraically under ``expand()``. + """ + return linearize_dims(dim_exprs, strides) + + +def _emit_cycle_offset( + cycle: list[IndexExpr], + iv_mlir: Value, + subs_map: dict, + overflow_flags, +) -> Value: + """Emit MLIR for cycle-based IV offset. + + For a repeating stride cycle [s0, s1, ...] of length N: + offset = (IV // N) * macro_stride + cumulative[IV % N] + where macro_stride = sum(cycle) and cumulative = prefix sums of cycle. 
+ """ + n = len(cycle) + macro_stride = sum(cycle) + cumulative = [sympy.Integer(0)] + for s in cycle[:-1]: + cumulative.append(cumulative[-1] + s) + + idx_ty = IndexType.get() + + is_pow2 = (n & (n - 1)) == 0 and n > 0 + n_val = arith_d.constant(idx_ty, n) + + if is_pow2: + log2n = int(math.log2(n)) + shift = arith_d.constant(idx_ty, log2n) + iv_div = arith_d.shrui(iv_mlir, shift) + mask_val = arith_d.constant(idx_ty, n - 1) + iv_mod = arith_d.andi(iv_mlir, mask_val) + else: + iv_div = arith_d.divui(iv_mlir, n_val) + iv_mod = arith_d.remui(iv_mlir, n_val) + + macro_val = gen_sympy_index(subs_map, macro_stride) + macro_term = arith_d.muli(iv_div, macro_val, overflow_flags=overflow_flags) + + cum_offset = arith_d.constant(idx_ty, 0) + for i in range(n - 1, -1, -1): + c_val = gen_sympy_index(subs_map, cumulative[i]) + i_val = arith_d.constant(idx_ty, i) + cmp = arith_d.cmpi(arith_d.CmpIPredicate.eq, iv_mod, i_val) + cum_offset = arith_d.select(cmp, c_val, cum_offset) + + return arith_d.addi(macro_term, cum_offset, overflow_flags=overflow_flags) + + def _try_iv_split_offset( emitter: WaveEmitter, index: dict[IndexExpr, IndexSequence | IndexExpr], - strides: list[int], + strides: list[int | IndexExpr], dynamic_vals: dict[IndexExpr, Any], - use_subs_idxc: bool = False, + use_subs_idxc: bool = True, + precomputed_iv_stride: dict[sympy.Symbol, IndexExpr | list[IndexExpr]] | None = None, + **kwargs, ) -> Optional[Value]: """Compute a hoisted IV-split linearized offset for a loop-carried read. @@ -784,14 +877,11 @@ def _try_iv_split_offset( expressions are provably affine in the loop IV, or ``None`` to fall back to the default address path. - The caller is responsible for emitting the actual load/gather using the - returned offset. + When *precomputed_iv_stride* is supplied (from + ``compute_iv_stride_through_mapping``), the IV stride is known from the + pre-mapping index and the extraction phase is skipped entirely. 
- Parameters - ---------- - strides : per-dimension integer strides for linearisation. - use_subs_idxc : if True, apply ``subs_idxc`` before simplification - (needed when expressions contain residual shape symbols). + Otherwise falls back to the original Phase 1 / Phase 1b extraction. """ ip = InsertionPoint.current owner = ip.block.owner @@ -803,7 +893,6 @@ def _try_iv_split_offset( # Find the IV symbol for this scf.for directly from its block argument. current_iv = owner.induction_variable - # do a reverse lookup of the dimension/symbol that the current IV is associated with dim = next((d for d, v in emitter.induction_vars.items() if v == current_iv), None) if dim is None: return None @@ -826,63 +915,151 @@ def _try_iv_split_offset( if len(start_exprs) != len(strides): return None - # Phase 1: Symbolic linearity proof w.r.t. the current loop's IV only. - # substitute IV = step*_j and check - # that the linearized index is c*_j + remainder (no _j in remainder). + sym_strides = [sympy.sympify(s) for s in strides] + + # ------------------------------------------------------------------ + # Fast path: pre-computed IV stride from mapping analysis. 
+ # ------------------------------------------------------------------ + has_iv = any(iv_sym in sympy.sympify(e).free_symbols for e in start_exprs) + if not has_iv: + return None + if precomputed_iv_stride and iv_sym in precomputed_iv_stride: + k_stride_per_iv = precomputed_iv_stride[iv_sym] + + base_start_exprs = [safe_subs(e, {iv_sym: 0}) for e in start_exprs] + + hoist_ip = InsertionPoint(owner) + subs_map = add_emitter_subs(emitter, dynamic_vals) + overflow_flags = arith_d.IntegerOverflowFlags.nsw + + with hoist_ip: + lin_offset = None + for base_expr, stride in zip(base_start_exprs, sym_strides): + val = gen_sympy_index(subs_map, base_expr) + stride_val = gen_sympy_index(subs_map, stride) + term = arith_d.muli(val, stride_val, overflow_flags=overflow_flags) + lin_offset = ( + term + if lin_offset is None + else arith_d.addi( + lin_offset, term, overflow_flags=overflow_flags + ) + ) + + iv_mlir = subs_map.get(iv_sym) + if iv_mlir is None: + return None + + if isinstance(k_stride_per_iv, list): + iv_offset = _emit_cycle_offset( + k_stride_per_iv, iv_mlir, subs_map, overflow_flags + ) + else: + k_stride_val = gen_sympy_index(subs_map, k_stride_per_iv) + iv_offset = arith_d.muli( + iv_mlir, k_stride_val, overflow_flags=overflow_flags + ) + + total = arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) + return total + + # ------------------------------------------------------------------ + # Original extraction path (Phase 1 / Phase 1b). 
+ # ------------------------------------------------------------------ + div_fwd, div_bwd = get_divisibility_subs(emitter.constraints) + _j = sympy.Symbol("_j", integer=True, nonnegative=True) iv_as_j = step_int * _j - lin_sym = sympy.Integer(0) - for expr, ps in zip(start_exprs, strides): + + dims = list(index.keys()) + + dim_exprs = [] + for i, (expr, stride) in enumerate(zip(start_exprs, sym_strides)): e = safe_subs(expr, {iv_sym: iv_as_j}) if use_subs_idxc: e = subs_idxc(e) - e = simplify(e) - lin_sym += e * ps - lin_sym = simplify(lin_sym) - - coeff = lin_sym.coeff(_j) - remainder = simplify(lin_sym - coeff * _j) - if not coeff.is_Integer or coeff == 0 or _j in remainder.free_symbols: - return None - k_stride_per_iv, rem = divmod(int(coeff), step_int) - if rem != 0: + if div_fwd: + e = safe_subs(e, div_fwd) + e = mem_simplify(e) + dim_exprs.append(e) + + # Phase 1: per-dimension extract. + iv_stride_sym = sympy.Integer(0) + base_exprs = [] + split_first_ok = True + + for i, (e, stride) in enumerate(zip(dim_exprs, sym_strides)): + result = extract_iv(e, _j) + if result is None: + split_first_ok = False + break + j_coeff, remainder = result + + if div_bwd: + j_coeff = safe_subs(j_coeff, div_bwd) + remainder = safe_subs(remainder, div_bwd) + + iv_stride_sym += simplify(mem_simplify(j_coeff * stride)) + base_exprs.append(remainder) + + # Phase 1b: linearize-first fallback. 
+ if not split_first_ok: + fwd_strides = [] + for s in sym_strides: + fs = safe_subs(s, div_fwd) if div_fwd else s + fwd_strides.append(fs) + + lin_sym = _cancel_floordiv_mod_linearize(dim_exprs, fwd_strides) + lin_sym = mem_simplify(lin_sym) + + result = extract_iv(lin_sym, _j) + if result is None: + return None + j_coeff_lin, base_lin = result + + if div_bwd: + j_coeff_lin = safe_subs(j_coeff_lin, div_bwd) + base_lin = safe_subs(base_lin, div_bwd) + + iv_stride_sym = simplify(mem_simplify(j_coeff_lin)) + base_exprs = None + base_lin_expr = base_lin + + if iv_stride_sym == 0: return None - # Phase 2: Substitute IV=0 to get the loop-invariant base offset. - iv_zero_subs = {iv_sym: 0} - index_no_iv = {} - for dim, seq in index.items(): - start = _get_start_index(seq) - new_start = safe_subs(start, iv_zero_subs) - if isinstance(seq, IndexSequence): - index_no_iv[dim] = IndexSequence(new_start, seq.size) - else: - index_no_iv[dim] = new_start + if iv_stride_sym.is_Integer: + k_stride_per_iv_int, rem = divmod(int(iv_stride_sym), step_int) + if rem != 0: + return None + k_stride_per_iv = sympy.Integer(k_stride_per_iv_int) + else: + k_stride_per_iv = simplify(mem_simplify(iv_stride_sym / step_int)) - # Emit the hoisted linearized offset BEFORE the scf.for. + # Emit MLIR. 
hoist_ip = InsertionPoint(owner) subs_map = add_emitter_subs(emitter, dynamic_vals) overflow_flags = arith_d.IntegerOverflowFlags.nsw with hoist_ip: - iv0_exprs = _get_start_indices(index_no_iv) - lin_offset = None - for expr, ps in zip(iv0_exprs, strides): - val = gen_sympy_index(subs_map, expr) - stride_c = arith_d.constant(IndexType.get(), ps) - term = arith_d.muli(val, stride_c, overflow_flags=overflow_flags) - lin_offset = ( - term - if lin_offset is None - else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) - ) + if base_exprs is not None: + lin_offset = None + for base_expr, stride in zip(base_exprs, sym_strides): + val = gen_sympy_index(subs_map, base_expr) + stride_val = gen_sympy_index(subs_map, stride) + term = arith_d.muli(val, stride_val, overflow_flags=overflow_flags) + lin_offset = ( + term + if lin_offset is None + else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) + ) + else: + lin_offset = gen_sympy_index(subs_map, base_lin_expr) + k_stride_val = gen_sympy_index(subs_map, k_stride_per_iv) - # Back inside the loop: total = hoisted_base + IV * k_stride. iv_mlir = subs_map.get(iv_sym) if iv_mlir is None: return None - - k_stride_val = arith_d.constant(IndexType.get(), k_stride_per_iv) iv_offset = arith_d.muli(iv_mlir, k_stride_val, overflow_flags=overflow_flags) return arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) @@ -961,9 +1138,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_global_mem = kb_src.type.memory_space is None buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem - # Set by _emit_wide_read in partition_strided_operators.py when merging - # reads with per-element bounds. Buffer-ops rely on SRD bounds checking - # instead, so the precomputed mask is only emitted for non-buffer-ops. 
+ iv_stride_from_mapping = node.meta.get("iv_stride", None) precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None) if precomputed_mask_expr is not None and not buffer_ops_enabled: mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr) @@ -996,7 +1171,6 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): # IV-split fast path for global reads: hoist address before the loop. if ( is_global - and mask is None and not use_llvm_load and not read_meets_hw_transpose_requirements( get_custom(node), emitter.constraints, emitter.options.target @@ -1005,48 +1179,66 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): kb_type = MemRefType(kb_src.type) phys_strides, _ = kb_type.get_strides_and_offset() dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() - if not any(s == dyn_sentinel for s in phys_strides): - total_offset = _try_iv_split_offset( - emitter, - index, - list(phys_strides), - dynamic_vals_map_start, + if any(s == dyn_sentinel for s in phys_strides): + iv_strides = list( + strides_from_symbolic_shape( + IndexingContext.current(), + input_shape, + allow_mixed_shapes=True, + ) ) - if total_offset is not None: - ip = InsertionPoint.current - owner = ip.block.owner - hoist_ip = InsertionPoint(owner) - with hoist_ip: - strides_vals = [ - arith_d.constant(IndexType.get(), s) for s in phys_strides - ] - zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( - phys_strides + else: + iv_strides = [sympy.Integer(s) for s in phys_strides] + total_offset = _try_iv_split_offset( + emitter, + index, + iv_strides, + dynamic_vals_map_start, + precomputed_iv_stride=iv_stride_from_mapping, + ) + if total_offset is not None: + ip = InsertionPoint.current + owner = ip.block.owner + hoist_ip = InsertionPoint(owner) + subs_map = add_emitter_subs(emitter, dynamic_vals_map_start) + with hoist_ip: + strides_vals = [gen_sympy_index(subs_map, s) for s in iv_strides] + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( + iv_strides + 
) + lin_src, _ = _linearize_memref( + kb_src, zero_indices, zero_indices, strides_vals + ) + + # With epilogue elimination the loop runs extra iterations + # whose offsets can exceed the actual buffer. Wrap the + # linearised memref in a fat_raw_buffer_cast so that the + # SRD's NUM_RECORDS = real buffer size and OOB loads safely + # return zero instead of faulting. + if buffer_ops_enabled and emitter.options.eliminate_epilogue: + valid_bytes = _compute_valid_bytes( + lin_src, + element_type, + input_shape, + emitter, ) - lin_src, _ = _linearize_memref( - kb_src, zero_indices, zero_indices, strides_vals + lin_src = _cast_buffer_and_encode_stride( + lin_src, + strides_vals, + element_type, + valid_bytes, ) - # With epilogue elimination the loop runs extra iterations - # whose offsets can exceed the actual buffer. Wrap the - # linearised memref in a fat_raw_buffer_cast so that the - # SRD's NUM_RECORDS = real buffer size and OOB loads safely - # return zero instead of faulting. - if buffer_ops_enabled and emitter.options.eliminate_epilogue: - valid_bytes = _compute_valid_bytes( - lin_src, - element_type, - input_shape, - emitter, - ) - lin_src = _cast_buffer_and_encode_stride( - lin_src, - strides_vals, - element_type, - valid_bytes, - ) + if mask is None: result = vector_d.load(vector_type, lin_src, [total_offset]) - emitter.bind_node_proxy(node, IRProxyValue(result)) - return + else: + element_type = vector_type.element_type + zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) + passthru = vector_d.broadcast(vector_type, zero) + result = vector_d.maskedload( + vector_type, lin_src, [total_offset], mask, passthru + ) + emitter.bind_node_proxy(node, IRProxyValue(result)) + return start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, index, dynamic_vals_map_start @@ -1383,6 +1575,7 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src_dynamic_vals_map_start = {} dst_dynamic_vals_map_start = {} + 
iv_stride_from_mapping = node.meta.get("iv_stride", None) if src_mapping: dyn_vals = tuple( cast_vector(emitter, reg, element_type=IndexType.get()) @@ -1431,21 +1624,14 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start) strides = [gen_sympy_index(subs_map, s) for s in sym_stride_vals] - # IV-split: try hoisting the src offset before the loop. - try: - sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] - except (TypeError, ValueError): - sym_strides_int = [] - - src_offset = None - if sym_strides_int: - src_offset = _try_iv_split_offset( - emitter, - new_src_idx, - sym_strides_int, - src_dynamic_vals_map_start, - use_subs_idxc=True, - ) + src_offset = _try_iv_split_offset( + emitter, + new_src_idx, + list(sym_stride_vals), + src_dynamic_vals_map_start, + use_subs_idxc=True, + precomputed_iv_stride=iv_stride_from_mapping, + ) if src_offset is not None: # IV-split path: offset=0 reinterpret_cast, full address in src_offset. @@ -1480,18 +1666,26 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): ), ) - mask = _build_mask( - emitter, - src_idx, - elements_per_thread=1, - bounds=src_bounds, - dynamic_values=src_dynamic_vals_map_start, - ) - if mask: - mask = vector_d.extract(mask, static_position=[0], dynamic_position=[]) - oob_index_value = _get_out_of_bounds_index(element_type) - oob_index = arith_d.constant(IndexType.get(), oob_index_value) - src_offset = arith_d.select(mask, src_offset, oob_index) + # When the SRD validBytes encodes the real buffer size (via + # _compute_branchless_valid_bytes with dynamic shape), hardware + # num_records clamping returns 0 for OOB addresses. Skip the + # per-load software mask to eliminate ~4 VALU instructions and + # temporary VGPRs per gather_to_lds call. 
+ if valid_bytes_override is None: + mask = _build_mask( + emitter, + src_idx, + elements_per_thread=1, + bounds=src_bounds, + dynamic_values=src_dynamic_vals_map_start, + ) + if mask: + mask = vector_d.extract( + mask, static_position=[0], dynamic_position=[] + ) + oob_index_value = _get_out_of_bounds_index(element_type) + oob_index = arith_d.constant(IndexType.get(), oob_index_value) + src_offset = arith_d.select(mask, src_offset, oob_index) amdgpu_d.gather_to_lds( src=lin_src, diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py index d366a31600..91c6617b4d 100644 --- a/wave_lang/kernel/ops/wave_ops.py +++ b/wave_lang/kernel/ops/wave_ops.py @@ -869,15 +869,18 @@ def _add_proxy_to_graph(self, region_graph: RegionGraph): def update_arg(self, idx_or_name: int | str | fx.Node, value: CustomOp | fx.Node): """ - Update the value of an argument in the node while keeping the - underlying fx.Node consistent. + Update an operand or named field while keeping the underlying + `fx.Node` consistent for both positional arguments and keyword + arguments. 
""" inherited_field_count = len(CustomOp.__dataclass_fields__) field_names = [field.name for field in fields(self)[inherited_field_count:]] + field_name = None if isinstance(idx_or_name, str): if idx_or_name not in field_names: raise ValueError(f"Field {idx_or_name} not found") idx = field_names.index(idx_or_name) + field_name = idx_or_name elif isinstance(idx_or_name, fx.Node): idx = self.fx_node.args.index(idx_or_name) else: @@ -886,10 +889,19 @@ def update_arg(self, idx_or_name: int | str | fx.Node, value: CustomOp | fx.Node value = value.fx_node # Skip the fields defined by the abstract base class if 0 <= idx < len(field_names): - field_name = field_names[idx] + field_name = field_name or field_names[idx] # Set the new value for the field setattr(self, field_name, value) - self.fx_node.update_arg(idx, value) + if field_name in self.fx_node.kwargs: + kwargs = dict(self.fx_node.kwargs) + kwargs[field_name] = value + self.fx_node.kwargs = kwargs + elif idx < len(self.fx_node.args): + self.fx_node.update_arg(idx, value) + else: + raise IndexError( + f"Field '{field_name}' is not present in fx.Node args or kwargs" + ) else: raise IndexError("Index out of range") @@ -951,12 +963,70 @@ def replacement_location_propagate( else: new_node.location = self.location - def replace_all_uses_with(self, new_node: CustomOp | fx.Node): - """Replace all uses of the current node with the new node.""" + def replace_uses_with( + self, + new_node: CustomOp | fx.Node, + *, + graph: Optional[fx.Graph] = None, + propagate_location: bool = True, + ) -> None: + """Replace uses of this node, optionally restricted to one graph. + + When `graph` is `None`, this replaces every use of the current node. + When `graph` is provided, only uses in that specific `fx.Graph` are + replaced. + + This matters for nested regions: the same FX node may be referenced from + multiple graphs at the same time, for example from a nested subgraph and + from sibling or outer graphs. 
A graph-scoped replacement therefore does + not imply that the current node becomes globally dead. After the call, + the node may still have uses in other graphs that were intentionally left + untouched. + """ if isinstance(new_node, CustomOp): new_node = new_node.fx_node - self.replacement_location_propagate(new_node) - self.fx_node.replace_all_uses_with(new_node) + if self.fx_node is new_node: + return + if propagate_location: + self.replacement_location_propagate(new_node) + + def contains(node: fx.Node, user: fx.Node) -> bool: + found = False + + def visit(arg): + nonlocal found + if arg is node: + found = True + return arg + + fx.map_arg((user.args, user.kwargs), visit) + return found + + for user in list(self.fx_node.users): + if graph is not None and user.graph is not graph: + continue + user._update_args_kwargs( + fx.map_arg( + user.args, lambda arg: new_node if arg is self.fx_node else arg + ), + fx.map_arg( + user.kwargs, lambda arg: new_node if arg is self.fx_node else arg + ), + ) + if self.fx_node.graph is not user.graph and not contains( + self.fx_node, user + ): + self.fx_node.users.pop(user, None) + if ( + isinstance(new_node, fx.Node) + and new_node.graph is not user.graph + and contains(new_node, user) + ): + new_node.users[user] = None + + def replace_all_uses_with(self, new_node: CustomOp | fx.Node): + """Replace all uses of the current node with the new node.""" + self.replace_uses_with(new_node) def replace_all_uses_with_except( self, new_node: CustomOp | fx.Node, except_nodes: list[CustomOp] @@ -1510,14 +1580,8 @@ def erase(self): if not isinstance(custom, NestedRegionOp): return - # Cleanup dead captures subgraph = custom.get_root_graph().subgraphs[custom.subgraph_name] - live_captures = [] - for var in custom.implicit_captures: - if custom.get_captured_fx_node(subgraph, var): - live_captures.append(var) - - custom.update_arg("implicit_captures", live_captures) + custom.refresh_captures(subgraph) @property def indexing_dims(self) -> 
list[IndexSymbol]: @@ -2314,32 +2378,156 @@ def is_contiguous_vec(self, constraints, target: str) -> bool: class NestedRegionOp(CustomOp): def captured_vars(self, graph: fx.Graph) -> list[fx.Node]: - """ - Nodes that are placeholders and are not iter args are captured vars. - """ + """Return local Placeholder nodes that represent captured outer values.""" captured_vars = [] for nested_node in graph.nodes: custom = get_custom(nested_node) - if isinstance(custom, Placeholder) and not isinstance(custom, IterArg): - captured_vars.append(nested_node) + if isinstance(custom, IterArg): + continue + captured = self.capture_source(nested_node) + # Before canonicalization, malformed or legacy placeholders may still + # resolve to another local node instead of a true outer source. + if captured is nested_node or captured.graph is graph: + continue + captured_vars.append(nested_node) return captured_vars - def get_outer_node(self, outer_node: fx.Node) -> fx.Node: - while "lifted" in outer_node.meta: - outer_node = outer_node.meta["lifted"] - return outer_node - def get_captured_fx_node( - self, graph: fx.Graph, outer_node: fx.Node + self, + graph: fx.Graph, + outer_node: fx.Node, + lookup: tuple[dict[fx.Node, fx.Node], list[fx.Node]] | None = None, ) -> Optional[fx.Node]: - outer_node = self.get_outer_node(outer_node) + """Return the local representative for `outer_node` in `graph` if it exists.""" + outer_node = self.capture_source(outer_node) + if lookup is not None: + by_outer, _ = lookup + return by_outer.get(outer_node) for var in self.captured_vars(graph): - custom = get_custom(var) - if custom.get_captured_fx_node() == outer_node: + if self.capture_source(var) is outer_node: return var - return None + def get_capture_bindings( + self, + graph: fx.Graph, + lookup: tuple[dict[fx.Node, fx.Node], list[fx.Node]] | None = None, + ) -> list[tuple[fx.Node, fx.Node]]: + """Return `(outer_source, local_region_value)` pairs in signature order.""" + by_outer = ( + # Keep the first 
local representative for each outer source as the + # canonical binding when multiple legacy nodes still alias it. + {self.capture_source(var): var for var in self.captured_vars(graph)} + if lookup is None + else lookup[0] + ) + bindings = [] + for outer_node in self.implicit_captures: + outer_source = self.capture_source(outer_node) + captured = by_outer.get(outer_source) + if captured is not None: + bindings.append((outer_source, captured)) + return bindings + + def refresh_captures( + self, + graph: fx.Graph, + lookup: tuple[dict[fx.Node, fx.Node], list[fx.Node]] | None = None, + ) -> None: + """Refresh the capture signature from the current graph contents.""" + if lookup is None: + # Match `get_capture_bindings`: the first local representative + # becomes the canonical binding for each outer source. + by_outer = { + self.capture_source(var): var for var in self.captured_vars(graph) + } + direct_sources_in_order: list[fx.Node] = [] + else: + by_outer, direct_sources_in_order = lookup + # dict-keyed-by-None is used as an insertion-order set. 
+ captures: dict[fx.Node, None] = {} + + for outer_node in self.implicit_captures: + resolved = self.capture_source(outer_node) + if resolved in by_outer: + captures.setdefault(resolved, None) + for outer_node in direct_sources_in_order: + captures.setdefault(outer_node, None) + + self.update_arg("implicit_captures", list(captures)) + + @staticmethod + def capture_source(node: fx.Node | CustomOp) -> fx.Node: + """Return the defining outer value for a local Placeholder node.""" + if isinstance(node, CustomOp): + node = node.fx_node + seen: set[fx.Node] = set() + while "lifted" in node.meta: + if node in seen: + raise ValueError( + f"Cycle detected while resolving lifted capture source for {node}" + ) + seen.add(node) + node = node.meta["lifted"] + return node + + @staticmethod + def _last_region_input_or_root(graph: fx.Graph) -> fx.Node: + """Return the last leading region-input node in `graph`.""" + last = graph._root + for node in graph.nodes: + if isinstance(get_custom(node), Placeholder): + last = node + else: + break + return last + + @classmethod + def materialize_capture_placeholder( + cls, + graph: fx.Graph, + outer_node: fx.Node | CustomOp, + location: Optional[CapturedLocation] = None, + ) -> fx.Node: + """Return a lifted placeholder that represents `outer_node` in `graph`. + + If one already exists in the leading placeholder prefix, it is + returned directly. Otherwise a new one is created. 
+ """ + outer_node = cls.capture_source(outer_node) + for node in graph.nodes: + if not isinstance(get_custom(node), Placeholder): + break + if node.meta.get("lifted") is outer_node: + return node + placeholder = Placeholder(outer_node.name, outer_node.type) + with graph.inserting_after(cls._last_region_input_or_root(graph)): + placeholder_node = placeholder.add_to_graph(graph, loc=location) + if placeholder_node.name != outer_node.name: + # Preserve the outer source name when possible for compatibility + # with legacy code paths that still correlate capture placeholders + # with outer values by name. + # `add_to_graph` goes through `fx.Graph.create_node`, which may + # auto-rename to keep graph-local names unique. Restore the + # semantic capture name afterwards, renaming any conflicting node if + # we can still find one. + conflicting_node = next( + ( + node + for node in graph.nodes + if node is not placeholder_node and node.name == outer_node.name + ), + None, + ) + if conflicting_node is not None: + conflicting_node.name = placeholder_node.name + placeholder_node.name = outer_node.name + # `Placeholder.add_to_graph` only creates the FX node; keep the type on + # the FX node itself so later passes do not need to re-infer it. + placeholder_node.type = outer_node.type + placeholder_node.meta["lifted"] = outer_node + return placeholder_node + def get_root_graph(self): """ Return the "root"/outermost layer of our computation graph. 
diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index e355dbb9c7..9d00ad0d3d 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -137,6 +137,9 @@ def _insert_cond_barrier( barrier_graph.parent_op = cond_barrier trace.add_subgraph(barrier_graph_name, barrier_graph) + local_root = get_custom(cond_barrier).get_root_graph() + if barrier_graph_name not in local_root.subgraphs: + local_root.subgraphs[barrier_graph_name] = barrier_graph return cond_barrier diff --git a/wave_lang/kernel/wave/analysis/annotate_iv_strides.py b/wave_lang/kernel/wave/analysis/annotate_iv_strides.py new file mode 100644 index 0000000000..bcf6ab1a82 --- /dev/null +++ b/wave_lang/kernel/wave/analysis/annotate_iv_strides.py @@ -0,0 +1,70 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Annotate Read and GatherToLDS nodes with pre-computed IV strides. + +Walks all Read and GatherToLDS nodes that have a non-identity mapping and +computes the linearized IV stride through the mapping via numerical probing. +The result is stored in ``node.meta["iv_stride"]`` so that codegen can use +the PRECOMPUTED fast path in ``_try_iv_split_offset`` without performing +any symbolic analysis at MLIR emission time. + +This pass should run after ``generate_bounds_exprs`` (all indices are final) +and before ``merge_contiguous_reads`` (which drops mappings on merged reads +but propagates ``meta["iv_stride"]`` from the anchor). 
+""" + +from collections.abc import Sequence + +from ..._support.indexing import IndexingContext +from ..._support.tracing import CapturedTrace +from ...compiler.utils import strides_from_symbolic_shape +from ...ops.wave_ops import Read, GatherToLDS, get_custom +from ..constraints import Constraint +from ..utils.mapping_utils import compute_iv_stride_through_mapping + + +def annotate_iv_strides( + trace: CapturedTrace, + constraints: Sequence[Constraint] = (), +): + """Annotate every mapped Read/GatherToLDS with ``meta["iv_stride"]``.""" + idxc = IndexingContext.current() + + for node in trace.walk( + lambda n: isinstance(get_custom(n), (Read, GatherToLDS)) + ): + if node.meta.get("iv_stride") is not None: + continue + + custom = get_custom(node) + + if isinstance(custom, GatherToLDS): + mapping = custom.src_mapping + index = custom.src_index + mem_node = custom.src + else: + mapping = custom.mapping + index = custom.index + mem_node = custom.memory + + if mapping is None: + continue + if hasattr(custom, "has_identity_mapping") and custom.has_identity_mapping(): + continue + + symbolic_shape = custom.type.symbolic_shape + mem_sym_shape = get_custom(mem_node).type.symbolic_shape + phys_strides = strides_from_symbolic_shape( + idxc, mem_sym_shape, allow_mixed_shapes=True + ) + + iv_stride = compute_iv_stride_through_mapping( + mapping, symbolic_shape, index, + is_read=True, mem_strides=list(phys_strides), + constraints=constraints, + ) + if iv_stride is not None: + node.meta["iv_stride"] = iv_stride diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index b70ad62795..b65d95dc7e 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -19,6 +19,10 @@ PersistentEmitter, ) from wave_lang.kernel.wave.compile_options import WaveCompileOptions +from wave_lang.kernel.wave.region_canonicalization import ( + 
RegionFormat, + requires_region_format, +) from wave_lang.support.logging import get_logger from ..._support.indexing import IndexSequence, IndexSymbol, IndexExpr @@ -410,6 +414,7 @@ def ensure_symbols_positive( ) from e +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def set_node_indices_water_checked( trace: CapturedTrace, constraints: list[Constraint], @@ -438,6 +443,7 @@ def set_node_indices_water_checked( _reset_water_id(trace) +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def set_node_indices( trace: CapturedTrace, constraints: list[Constraint], @@ -812,7 +818,7 @@ def combine_indices( which make the index sequence (access pattern) thread specific. These are added to the thread independent index which is obtained from the constraints. """ - combined_index = {k: v for k, v in thread_independent_index.items()} + combined_index = {k: deepcopy(v) for k, v in thread_independent_index.items()} for k in combined_index: if k in thread_dependent_index: combined_index[k].start += thread_dependent_index[k].start diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 4ddac1abb2..a7a6f1f314 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -22,6 +22,7 @@ from ..._support.tracing import CapturedTrace from ...lang.global_symbols import * from ...lang.wave_types import IndexMapping +from ..index_mapping_simplify import simplify_index_mapping, get_tile_sizes_from_index from ...ops.wave_ops import ( CustomOp, ExtractSlice, @@ -800,6 +801,9 @@ def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source, mask_expr=N wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) if mask_expr is not None: wide_read.precomputed_mask_expr = mask_expr + anchor_iv_stride = anchor_custom.fx_node.meta.get("iv_stride") + if anchor_iv_stride is not None: + 
wide_read.meta["iv_stride"] = anchor_iv_stride propagate_tag(tag_source, wide_read) return wide_read @@ -1941,6 +1945,26 @@ def simplify_indices(trace: CapturedTrace, constraints: Sequence[Constraint] = ( ) if mapping_changed: custom.mapping = new_mapping + # Try to eliminate flat//D and flat%D patterns using + # tile-level iterator bounds from the node's index. + # Re-read from node to get the persisted mapping (not the + # transient wrapper modified by _simplify_mapping above). + custom = get_custom(node) + if custom.mapping is not None: + try: + node_index = custom.index + except (ValueError, AttributeError): + node_index = None + if isinstance(node_index, dict): + tile_sizes = get_tile_sizes_from_index( + custom.mapping, node_index + ) + if tile_sizes: + new_mapping2, mapping_changed2 = simplify_index_mapping( + custom.mapping, constraints, tile_sizes + ) + if mapping_changed2: + custom.update_arg("mapping", new_mapping2) # Simplify index sequences. try: index = custom.index diff --git a/wave_lang/kernel/wave/barriers.py b/wave_lang/kernel/wave/barriers.py index 2b31eef475..4ec4e89810 100644 --- a/wave_lang/kernel/wave/barriers.py +++ b/wave_lang/kernel/wave/barriers.py @@ -16,6 +16,7 @@ SharedMemoryBarrierWait, get_custom, ) +from .region_canonicalization import RegionFormat, requires_region_format from .utils.graph_utils import ( is_barrier_between, ) @@ -156,6 +157,7 @@ def place_barrier(self, region: SyncRegion) -> None: ) +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def add_shared_memory_barriers( trace: CapturedTrace, target: Optional[str] = None, diff --git a/wave_lang/kernel/wave/cluster_barriers.py b/wave_lang/kernel/wave/cluster_barriers.py index b8ca83593c..d9940e67a4 100644 --- a/wave_lang/kernel/wave/cluster_barriers.py +++ b/wave_lang/kernel/wave/cluster_barriers.py @@ -13,6 +13,7 @@ from .._support.tracing import CapturedTrace from .compile_options import WaveCompileOptions from .constraints import Constraint, 
TilingConstraint +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import ( get_custom, Iterate, @@ -97,9 +98,13 @@ def add_cluster_barriers_to_iterate( trace.add_subgraph(signal_subgraph_name, signal_subgraph) - Conditional(condition, signal_subgraph_name, []).add_to_graph( + signal_cond = Conditional(condition, signal_subgraph_name, []).add_to_graph( subgraph, loc=location ) + signal_subgraph.parent_op = signal_cond + local_root = get_custom(signal_cond).get_root_graph() + if signal_subgraph_name not in local_root.subgraphs: + local_root.subgraphs[signal_subgraph_name] = signal_subgraph # Add conditional barrier_wait at end of body output_node = next(n for n in subgraph.nodes if n.op == "output") @@ -116,9 +121,13 @@ def add_cluster_barriers_to_iterate( trace.add_subgraph(wait_subgraph_name, wait_subgraph) - Conditional(condition, wait_subgraph_name, []).add_to_graph( + wait_cond = Conditional(condition, wait_subgraph_name, []).add_to_graph( subgraph, loc=location ) + wait_subgraph.parent_op = wait_cond + local_root = get_custom(wait_cond).get_root_graph() + if wait_subgraph_name not in local_root.subgraphs: + local_root.subgraphs[wait_subgraph_name] = wait_subgraph def is_multicast_tensor_load(node: fx.Node) -> bool: @@ -131,6 +140,7 @@ def is_multicast_tensor_load(node: fx.Node) -> bool: return True +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def add_cluster_barriers( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions ): diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index df2bdc91d0..70c9abfb8f 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -37,6 +37,7 @@ set_node_indices_water_checked, set_post_expansion_indices, ) +from .analysis.annotate_iv_strides import annotate_iv_strides from .analysis.partition_strided_operators import ( merge_contiguous_reads, partition_gather_like_ops, @@ -72,6 +73,14 @@ 
from .preshuffle_scale_to_shared import preshuffle_scale_to_shared from .multicast import multicast from .promotion import compute_shared_memory_usage, promote_placeholders +from .region_canonicalization import ( + RegionFormat, + prepare_region_captures, + raw_graph_pass, + requires_region_format, + verify_canonical_region_captures, + wrap_graph_passes_with_region_adapters, +) from .schedule_reordering import schedule_reordering from .scheduling.loop_reconstruction import guard_g2s_with_bounds_check from .scheduling.schedule import schedule_graph @@ -121,7 +130,11 @@ is_cache_enabled, ) -from .water import water_leak_in_bounds_check, water_lowering_pipeline +from .water import ( + water_leak_in_bounds_check, + water_lowering_pipeline, + water_waveasm_lowering_pipeline, +) from wave_lang.runtime.launch import Launchable from wave_lang.runtime.multi_device_launch import MultiDeviceLaunchable from .wave import LaunchableWave @@ -458,7 +471,8 @@ def build_graph_passes( all compilation stages. Each pass is a zero-argument callable (typically a `partial`). Passes mutate the *trace* in place and must be executed in the returned order within the same `IndexingContext` that was active when - the trace was created. + the trace was created. The returned passes include canonical-region + adapters according to each pass's declared region-format requirements. 
""" if debug_arg_info is None: debug_arg_info = [] @@ -593,6 +607,7 @@ def build_graph_passes( launchable.constraints, launchable.reordering_constraints, ), + partial(annotate_iv_strides, trace, launchable.constraints), partial( merge_contiguous_reads, trace, @@ -616,7 +631,8 @@ def build_graph_passes( ) ) - return graph_passes + raw_graph_passes = [raw_graph_pass(graph_pass) for graph_pass in graph_passes] + return wrap_graph_passes_with_region_adapters(trace, raw_graph_passes) def _build_initial_pass_pipeline( @@ -630,10 +646,12 @@ def _build_initial_pass_pipeline( ) -> list[Callable]: idxc = IndexingContext.current() - def finalize_indices(): + @requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) + def finalize_indices(trace: CapturedTrace): idxc.finalize() - def substitute_vector_shapes(): + @requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) + def substitute_vector_shapes(trace: CapturedTrace): launchable.hardware_constraints[0].subs_vector_shapes(idxc.subs) return ( @@ -642,8 +660,8 @@ def substitute_vector_shapes(): partial(initialize_iter_args, trace), partial(launchable.create_induction_vars, trace), partial(launchable.initialize_reductions, trace), - finalize_indices, - substitute_vector_shapes, + partial(finalize_indices, trace), + partial(substitute_vector_shapes, trace), partial(add_get_results, trace), partial(infer_types, trace, launchable.constraints), partial(construct_index_mapping, trace, launchable.constraints), @@ -801,6 +819,9 @@ def compile_launchable_to_mlir( # Only emit MLIR if we don't have a module yet. if not module_op: + # The non-Water Wave emitter currently consumes the schedule-style + # hybrid region view rather than fully canonical captures. 
+ prepare_region_captures(trace, RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) emitter = WaveEmitter( dispatch_entrypoint, trace, @@ -914,9 +935,13 @@ def _trace_launchable_and_get_kernel_signature( pass_times = {} for p in graph_passes: + if options.verify_region_captures: + verify_canonical_region_captures(trace, f"before {p.__name__}") try_apply_pass( p, trace, print_ir_before, print_ir_after, profile_pass, pass_times ) + if options.verify_region_captures: + verify_canonical_region_captures(trace, f"after {p.__name__}") if options.print_pass_times: pass_times_list = sorted(pass_times.items(), key=lambda x: x[1], reverse=True) @@ -1092,7 +1117,7 @@ def get_binary_path(): ) = compile_launchable_to_mlir( launchable=kernel, trace=graph, - context=None, + context=overriding_module_op.context, module_op=overriding_module_op, options=options, ) @@ -1170,7 +1195,16 @@ def get_binary_path(): ] options.kernel_usages = kernel_usages - if options.compile_to_asm or options.backend == "asm": + if options.use_water_backend and options.backend == "asm": + # Water + WaveASM flow: lower to LLVM dialect, then waveasm. 
+ module = water_waveasm_lowering_pipeline(mb.module_op, options) + return WaveKernelExecutionEngine( + options, + module, + asm, + create_execution_engine=not options.compile_to_mlir, + ) + elif options.compile_to_asm or options.backend == "asm": # ASM flow: generate AMDGCN assembly; optionally build a binary asm = _generate_asm_code(mb, options) @@ -1286,6 +1320,10 @@ def _generate_asm_code(mb, options): mlir_file.write(kernel_mlir) mlir_path = mlir_file.name + # Debug: save a copy of the MLIR input to waveasm-translate + import shutil + shutil.copy(mlir_path, "/tmp/waveasm_input.mlir") + try: base_passes = [ "--mlir-cse", @@ -1312,14 +1350,16 @@ def _generate_asm_code(mb, options): # TODO: improve Ticketing logic (better latency-covering heuristics, # smarter coalescing) so ticketed waitcnt can be always-on without # a performance hit, removing this wave-shape conditional. - use_ticketed_waitcnt = waves_in_m >= 2 and waves_in_n >= 2 + use_ticketed_waitcnt = False waitcnt_flag = ( "--waveasm-insert-waitcnt" if use_ticketed_waitcnt else "--waveasm-insert-waitcnt=ticketed-waitcnt=false" ) tail_passes = [ + "--waveasm-scc-verifier", "--waveasm-linear-scan=max-vgprs=512 max-agprs=512", + "--waveasm-vgpr-compaction", waitcnt_flag, f"--waveasm-hazard-mitigation=target={options.target}", "--emit-assembly", @@ -1336,7 +1376,17 @@ def _run_translate(extra_passes): + extra_passes + tail_passes ) - return subprocess.run(full_cmd, capture_output=True, text=True, timeout=60) + # If WAVEASM_DUMP_IR is set, add --mlir-print-ir-after-all and + # save stderr to the specified file for debugging. 
+ ir_dump_path = os.environ.get("WAVEASM_DUMP_IR") + if ir_dump_path: + full_cmd.append("--mlir-print-ir-after-all") + result = subprocess.run(full_cmd, capture_output=True, text=True, timeout=120) + if ir_dump_path and result.stderr: + os.makedirs(os.path.dirname(ir_dump_path) or ".", exist_ok=True) + with open(ir_dump_path, "w") as f: + f.write(result.stderr) + return result import re @@ -1444,4 +1494,5 @@ def validate_options(options: WaveCompileOptions): ) if options.backend == "asm" and not options.wave_runtime: - raise ValueError("ASM backend requires wave_runtime=True") + if not options.use_water_backend: + raise ValueError("ASM backend requires wave_runtime=True") diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index 674d0a08a3..fafb01453b 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -78,6 +78,7 @@ class WaveCompileOptions: check_water_analysis: bool = False enforce_locations: bool = True drop_debug_info_before_mlir: bool = True + verify_region_captures: bool = True # === Performance options === optimization_level: bool = True diff --git a/wave_lang/kernel/wave/construct_index_mapping.py b/wave_lang/kernel/wave/construct_index_mapping.py index 55f96bfcf6..cf8c5dc558 100644 --- a/wave_lang/kernel/wave/construct_index_mapping.py +++ b/wave_lang/kernel/wave/construct_index_mapping.py @@ -8,9 +8,11 @@ from ..ops.wave_ops import Read, Write, get_custom from ..lang.wave_types import IndexMapping, index_symbol, IndexSymbol from ..wave.constraints import Constraint, IteratorBindings +from .region_canonicalization import RegionFormat, requires_region_format import sympy +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def construct_index_mapping( trace: CapturedTrace, constraints: list[Constraint] ) -> None: diff --git a/wave_lang/kernel/wave/debug_log_hoist.py b/wave_lang/kernel/wave/debug_log_hoist.py index 13afece0ad..8b216654ee 100644 --- 
a/wave_lang/kernel/wave/debug_log_hoist.py +++ b/wave_lang/kernel/wave/debug_log_hoist.py @@ -17,6 +17,7 @@ from ..lang.wave_types import IndexMapping, Memory from .compile_options import WaveCompileOptions from .constraints import Constraint, TilingConstraint +from .region_canonicalization import RegionFormat, requires_region_format from typing import TypedDict, Any import sympy @@ -34,6 +35,7 @@ def is_debug_log_transformer(node): return isinstance(get_custom(node), DebugLog) +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def debug_log_hoist(trace: CapturedTrace, debug_handlers: list[Any]): """ Finds debug log ops and hoists kernel inputs for them. @@ -64,6 +66,7 @@ def debug_log_hoist(trace: CapturedTrace, debug_handlers: list[Any]): ) +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def debug_log_write_replace( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/decompose_dot_mma.py b/wave_lang/kernel/wave/decompose_dot_mma.py index baca439fba..91e6fa2079 100644 --- a/wave_lang/kernel/wave/decompose_dot_mma.py +++ b/wave_lang/kernel/wave/decompose_dot_mma.py @@ -10,6 +10,7 @@ from .._support.tracing import CapturedTrace from ..ops.wave_ops import MMA, Add, CastOp, Mul, Sum, get_custom +from .region_canonicalization import RegionFormat, requires_region_format from .constraints import ( Constraint, GenericDot, @@ -17,6 +18,7 @@ from .utils.general_utils import get_hardware_constraint +@requires_region_format(RegionFormat.ISOLATED) def decompose_dot_mma(trace: CapturedTrace, constraints: list[Constraint]): """ Decomposes dot MMA operations into the dot products and cross-thread reductions. 
diff --git a/wave_lang/kernel/wave/decompose_reduce_ops.py b/wave_lang/kernel/wave/decompose_reduce_ops.py index 17ed85147e..a5743ff51f 100644 --- a/wave_lang/kernel/wave/decompose_reduce_ops.py +++ b/wave_lang/kernel/wave/decompose_reduce_ops.py @@ -15,6 +15,7 @@ from .._support.tracing import CapturedTrace from .._support.location import CapturedLocation from ..lang.global_symbols import * +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import ( Add, Allocate, @@ -313,6 +314,7 @@ def emit_interwave_reduction( return interwave_reduction +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def decompose_reduce_ops( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py index 7ca57ee86d..415acf5a7a 100644 --- a/wave_lang/kernel/wave/decompose_scan_ops.py +++ b/wave_lang/kernel/wave/decompose_scan_ops.py @@ -13,6 +13,7 @@ from .._support.dtype import i1 from .._support.location import CapturedLocation from .._support.tracing import CapturedTrace +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import ( Add, Allocate, @@ -30,13 +31,14 @@ ScanOp, SelectOp, ShuffleOp, + NestedRegionOp, Write, get_custom, ) from .constraints import HardwareConstraint, WaveConstraint, WorkgroupConstraint from .utils.classes import ShuffleMode from .utils.general_utils import all_equal, delinearize_index -from .utils.graph_utils import DCE, get_outer_node +from .utils.graph_utils import DCE def get_graph_node( @@ -348,7 +350,7 @@ def emit_interwave_scan( Conditional( is_last_lane, subgraph_name=sub_store, - implicit_captures=[get_outer_node(wave_total), sums_buf], + implicit_captures=[NestedRegionOp.capture_source(wave_total), sums_buf], ), graph, location, @@ -360,7 +362,9 @@ def emit_interwave_scan( # that the compiler can't find. 
exec_on_last.parent_op = cond_store trace.add_subgraph(sub_store, exec_on_last) - trace.get_root_graph().subgraphs[sub_store] = exec_on_last + local_root = get_custom(cond_store).get_root_graph() + if sub_store not in local_root.subgraphs: + local_root.subgraphs[sub_store] = exec_on_last # read all per-wave totals read_totals = Read( @@ -412,6 +416,7 @@ def emit_interwave_scan( return [out] +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def decompose_scan_ops( trace: CapturedTrace, constraints: list, diff --git a/wave_lang/kernel/wave/decompose_topk_ops.py b/wave_lang/kernel/wave/decompose_topk_ops.py index 9e912bef42..d6b0721613 100644 --- a/wave_lang/kernel/wave/decompose_topk_ops.py +++ b/wave_lang/kernel/wave/decompose_topk_ops.py @@ -15,6 +15,7 @@ from .._support.tracing import CapturedTrace from .._support.indexing import IndexSequence, IndexExpr from ..lang.global_symbols import * +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import ( Broadcast, Eq, @@ -533,6 +534,7 @@ def decompose_topk_op( # that target them, and remove them together. 
+@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def decompose_topk_ops( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/decompose_vmma_ops.py b/wave_lang/kernel/wave/decompose_vmma_ops.py index 7b5839e18a..144da314b5 100644 --- a/wave_lang/kernel/wave/decompose_vmma_ops.py +++ b/wave_lang/kernel/wave/decompose_vmma_ops.py @@ -15,6 +15,7 @@ Reshape, get_custom, ) + from ..wave.constraints import ( Constraint, HardwareConstraint, diff --git a/wave_lang/kernel/wave/expansion/expansion.py b/wave_lang/kernel/wave/expansion/expansion.py index f116ceb357..7b91b77329 100644 --- a/wave_lang/kernel/wave/expansion/expansion.py +++ b/wave_lang/kernel/wave/expansion/expansion.py @@ -37,6 +37,7 @@ Write, get_custom, ) +from ..region_canonicalization import RegionFormat, requires_region_format from ..constraints import ( Constraint, ) @@ -439,6 +440,7 @@ def get_mma_reduction_count(arg: MMA, dim_scaling: dict[IndexSymbol, int]) -> in return reduction_count +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def add_get_results(trace: CapturedTrace): iterate_ops = trace.walk(lambda x: isinstance(get_custom(x), Iterate)) conditional_ops = trace.walk(lambda x: isinstance(get_custom(x), Conditional)) @@ -972,6 +974,7 @@ def is_leaf_node(node): ) +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def expand_graph( trace: CapturedTrace, constraints: Sequence[Constraint], diff --git a/wave_lang/kernel/wave/fuse_tensor_loads.py b/wave_lang/kernel/wave/fuse_tensor_loads.py index 7eb614618e..094eb9eac8 100644 --- a/wave_lang/kernel/wave/fuse_tensor_loads.py +++ b/wave_lang/kernel/wave/fuse_tensor_loads.py @@ -18,6 +18,7 @@ TensorLoadToLDS, get_custom, ) +from .region_canonicalization import RegionFormat, requires_region_format from ..wave.constraints import Constraint from ..wave.compile_options import WaveCompileOptions from ..wave.utils.general_utils import get_hardware_constraint @@ -301,6 +302,7 @@ def find_adjacent_loads( 
return fusable_pairs +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def fuse_tensor_loads( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/gather_to_shared.py b/wave_lang/kernel/wave/gather_to_shared.py index 506d4fde57..23e3345ca4 100644 --- a/wave_lang/kernel/wave/gather_to_shared.py +++ b/wave_lang/kernel/wave/gather_to_shared.py @@ -18,6 +18,7 @@ from .._support.indexing import IndexExpr, IndexSequence, IndexSymbol, xor from .._support.tracing import CapturedTrace from ..lang.global_symbols import * +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import ( CustomOp, GatherToLDS, @@ -455,6 +456,7 @@ def get_load_width( return load_width +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def gather_to_shared( trace: CapturedTrace, constraints: list[Constraint], @@ -574,6 +576,7 @@ def gather_to_shared( DCE(trace) +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def gather_to_shared_swizzling( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/generate_bounds_exprs.py b/wave_lang/kernel/wave/generate_bounds_exprs.py index efc158b9d6..dcb6e9b027 100644 --- a/wave_lang/kernel/wave/generate_bounds_exprs.py +++ b/wave_lang/kernel/wave/generate_bounds_exprs.py @@ -8,6 +8,7 @@ import torch.fx as fx from ..ops.wave_ops import Read, Write +from .region_canonicalization import RegionFormat, requires_region_format from .assumptions import get_divisibility_subs from .constraints import Constraint, DistributionConstraint, ReorderingConstraint from .utils.general_utils import ( @@ -57,6 +58,7 @@ def is_divisible( return simplify(diff) == 0 +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def generate_bounds_exprs( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/global_to_shared_gathers.py b/wave_lang/kernel/wave/global_to_shared_gathers.py index 
545d8fdabd..af30fbed5f 100644 --- a/wave_lang/kernel/wave/global_to_shared_gathers.py +++ b/wave_lang/kernel/wave/global_to_shared_gathers.py @@ -15,6 +15,7 @@ from ..lang.global_symbols import * from ..lang.wave_types import IndexMapping from ..ops.wave_ops import Read, Write, get_custom +from .region_canonicalization import RegionFormat, requires_region_format from ..wave.constraints import ( Constraint, HardwareConstraint, @@ -384,6 +385,7 @@ def add_optimized_nodes( return optimized_writes, shared_read_metadata +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def global_to_shared_gathers(trace: CapturedTrace, constraints: list[Constraint]): """ This function converts global gathers to shared gathers. diff --git a/wave_lang/kernel/wave/hardware_transpose.py b/wave_lang/kernel/wave/hardware_transpose.py index 0adbef38a1..26ba7a49f0 100644 --- a/wave_lang/kernel/wave/hardware_transpose.py +++ b/wave_lang/kernel/wave/hardware_transpose.py @@ -13,6 +13,7 @@ from .utils.tag_utils import propagate_tag from .global_to_shared_gathers import update_read_mapping_dynamic_values +from .region_canonicalization import RegionFormat, requires_region_format from .._support.tracing import CapturedTrace from ..ops.wave_ops import ( Read, @@ -223,6 +224,7 @@ def rewrite_node( custom_node.replace_all_uses_with(concat) +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def mark_hardware_transpose_candidates( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions ): diff --git a/wave_lang/kernel/wave/hoisting.py b/wave_lang/kernel/wave/hoisting.py index 87c3046ef8..afab83c995 100644 --- a/wave_lang/kernel/wave/hoisting.py +++ b/wave_lang/kernel/wave/hoisting.py @@ -13,6 +13,7 @@ from ..lang.global_symbols import * from ..ops.wave_ops import * from .constraints import Constraint +from .region_canonicalization import RegionFormat, requires_region_format from .utils.general_utils import get_induction_variable logger = get_logger("wave.hoisting") @@ 
-118,8 +119,8 @@ def remove_unused_captured_vars(reduction: CustomOp, subgraph: fx.Graph): for captured_idx in reversed(range(len(captured_vars))): if len(captured_vars[captured_idx].users) == 0: get_custom(captured_vars[captured_idx]).erase() - # Order of captured_vars in subgraph do not necessarily match order of root - # implicit_capture. Especially if we introduce instruction reoderings. + # Order of captured vars in subgraph do not necessarily match order of root + # implicit_capture. Especially if we introduce instruction reorderings. root_capture_idx = new_implicit_captures.index( captured_vars[captured_idx].meta["lifted"] ) @@ -127,6 +128,7 @@ def remove_unused_captured_vars(reduction: CustomOp, subgraph: fx.Graph): reduction.update_arg("implicit_captures", new_implicit_captures) +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint]): """Hoists ops that are loop-invariant from reduction subgraphs to outer root graph.""" root_graph = trace.get_root_graph() @@ -139,7 +141,7 @@ def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint] custom_node, constraints ) subgraph = trace.get_subgraph(custom_node.subgraph_name) - # Captured variables from inside the loop. + # Region-local representatives of captured outer values. 
captured_vars = custom_node.captured_vars(subgraph) hoistable_ops = get_hoistable_ops( subgraph, captured_vars, induction_variable diff --git a/wave_lang/kernel/wave/in_thread_transpose.py b/wave_lang/kernel/wave/in_thread_transpose.py index 7e97134b0e..e159c94120 100644 --- a/wave_lang/kernel/wave/in_thread_transpose.py +++ b/wave_lang/kernel/wave/in_thread_transpose.py @@ -19,6 +19,7 @@ from ..lang.global_symbols import * from ..lang.wave_types import IndexMapping from ..ops.wave_ops import Extract, Read, Reshape, Write, get_custom +from .region_canonicalization import RegionFormat, requires_region_format from ..wave.utils.tag_utils import propagate_tag from ..wave.constraints import ( Constraint, @@ -367,6 +368,7 @@ def create_transpose_writes( return new_writes +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def in_thread_transpose( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions ): diff --git a/wave_lang/kernel/wave/index_mapping_simplify.py b/wave_lang/kernel/wave/index_mapping_simplify.py new file mode 100644 index 0000000000..ee2335a499 --- /dev/null +++ b/wave_lang/kernel/wave/index_mapping_simplify.py @@ -0,0 +1,324 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Simplify IndexMapping expressions by decomposing flat // D and flat % D. + +Given an IndexMapping whose input_mapping contains paired expressions: + dim1: flat_expr // D + dim2: flat_expr % D + +This pass tries to decompose flat_expr = quotient * D + remainder where +0 <= remainder < D, allowing the rewrite: + dim1: quotient + dim2: remainder + +The decomposition uses: + 1. Algebraic factoring: terms in flat_expr that are provably multiples of D + are separated into the quotient. + 2. 
Bounds analysis: if the remaining terms are provably bounded in [0, D), + the dynamic floordiv/mod is eliminated entirely. +""" + +import sympy +from collections.abc import Sequence +from functools import lru_cache + +from ..lang.wave_types import IndexMapping +from .utils.symbol_utils import ( + _split_sum_by_divisibility, + expr_bounds, + IndexExpr, + IndexSymbol, + subs_idxc, +) + + +def _get_symbol_lower_bounds( + constraints: Sequence, +) -> dict[sympy.Symbol, sympy.Expr]: + """Extract lower bounds on symbols from Assumption constraints. + + Handles ``Assumption(S >= c)`` and ``Assumption(S > c)`` patterns, + returning ``{S: lower_bound}`` where lower_bound is the tightest + bound found. + """ + from .assumptions import Assumption + + bounds: dict[sympy.Symbol, sympy.Expr] = {} + for c in constraints: + if not isinstance(c, Assumption): + continue + expr = c.expr + # Match S >= c (GreaterThan) or S > c (StrictGreaterThan). + if isinstance(expr, sympy.GreaterThan): # S >= c + lhs, rhs = expr.args + if lhs.is_Symbol and rhs.is_number: + cur = bounds.get(lhs) + if cur is None or rhs > cur: + bounds[lhs] = rhs + elif isinstance(expr, sympy.StrictGreaterThan): # S > c + lhs, rhs = expr.args + if lhs.is_Symbol and rhs.is_number: + lb = rhs + 1 # S > c implies S >= c+1 for integers. + cur = bounds.get(lhs) + if cur is None or lb > cur: + bounds[lhs] = lb + return bounds + + +def _get_iterator_bounds( + mapping: IndexMapping, + tile_sizes: dict[IndexSymbol, IndexExpr] | None = None, +) -> dict[sympy.Symbol, tuple[sympy.Expr, sympy.Expr]]: + """Extract iterator bounds from tile sizes or the iteration_shape. + + When *tile_sizes* is provided (from the node's index sequences), + the bounds use the concrete tile size for each iterator, which is + typically much tighter than the full dimension from iteration_shape. + + Returns {iterator_symbol: (0, upper_bound - 1)} for each iterator. + """ + bounds = {} + # Build reverse map: dim_symbol -> iterator_symbol. 
+ dim_to_iter = {} + for sym, idx in mapping.iters.items(): + dim = mapping.iteration_shape[idx] + if dim is not None: + dim_to_iter[dim] = sym + + for sym, idx in mapping.iters.items(): + dim = mapping.iteration_shape[idx] + if dim is None: + continue + # Prefer tile_sizes (concrete) over iteration_shape (symbolic). + if tile_sizes and dim in tile_sizes: + upper = tile_sizes[dim] - 1 + else: + upper = dim - 1 + bounds[sym] = (sympy.Integer(0), upper) + return bounds + + +def get_tile_sizes_from_index( + mapping: IndexMapping, + index: dict, +) -> dict[IndexSymbol, IndexExpr]: + """Extract tile sizes for each iterator dimension from the node's index. + + The node's index dict maps dimension symbols (M, N, K) to + IndexSequences with start/size/stride. The mapping's output_mapping + tells us which dimension each iterator corresponds to. The size of + that IndexSequence is the tile size for that iterator. + """ + from .utils.symbol_utils import IndexSequence + + tile_sizes = {} + # output_mapping: {dim_sym: iterator_sym} + # iteration_shape maps iterator ordinal -> dim_sym + for dim in mapping.iteration_shape: + if dim is None: + continue + seq = index.get(dim) + if isinstance(seq, IndexSequence): + tile_sizes[dim] = seq.size + return tile_sizes + + +def _expr_bounds_with_iters( + expr: sympy.Expr, + iter_bounds: dict[sympy.Symbol, tuple[sympy.Expr, sympy.Expr]], +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Compute expression bounds using iterator upper bounds. + + Extends expr_bounds by substituting known iterator ranges. + """ + if expr.is_Integer or expr.is_Rational: + return (expr, expr) + if expr.is_Symbol: + if expr in iter_bounds: + return iter_bounds[expr] + return (sympy.Integer(0), sympy.oo) if expr.is_nonnegative else None + + # For floor/Mod/Add/Mul, delegate to structural recursion. 
+ if isinstance(expr, sympy.Mod): + p, q = expr.args + if q.is_positive and q.is_number: + return (sympy.Integer(0), q - 1) + q_bounds = _expr_bounds_with_iters(q, iter_bounds) + if q_bounds and q_bounds[0].is_positive: + return (sympy.Integer(0), q_bounds[1] - 1) + return None + + if isinstance(expr, sympy.floor): + inner_bounds = _expr_bounds_with_iters(expr.args[0], iter_bounds) + if inner_bounds: + return (sympy.floor(inner_bounds[0]), sympy.floor(inner_bounds[1])) + return None + + if isinstance(expr, sympy.Add): + bounds = [_expr_bounds_with_iters(a, iter_bounds) for a in expr.args] + if all(b is not None for b in bounds): + return (sum(b[0] for b in bounds), sum(b[1] for b in bounds)) + return None + + if isinstance(expr, sympy.Mul): + if not expr.args: + return (sympy.Integer(1), sympy.Integer(1)) + bounds = [_expr_bounds_with_iters(a, iter_bounds) for a in expr.args] + if all(b is not None for b in bounds): + if any(sympy.oo in b or -sympy.oo in b for b in bounds): + return None + lo, hi = bounds[0] + for b in bounds[1:]: + corners = [lo * b[0], lo * b[1], hi * b[0], hi * b[1]] + try: + lo, hi = min(corners), max(corners) + except TypeError: + return None + return (lo, hi) + return None + + return None + + +def _find_floordiv_mod_pairs( + input_mapping: dict[IndexSymbol, IndexExpr], +) -> list[tuple]: + """Find paired floor/Mod expressions sharing the same divisor. + + Returns list of (dim_q, dim_r, numerator, divisor, addend) tuples where: + dim_q has expression: addend + floor(numerator / divisor) + dim_r has expression: Mod(something, divisor) + + Note: sympy auto-evaluates Mod(A*D + B, D) → Mod(B, D), so the Mod's + first arg may not match the floor's numerator exactly. We match on + the divisor and verify compatibility. + """ + floor_info: list[tuple] = [] # (dim, numerator, divisor, addend) + mod_info: list[tuple] = [] # (dim, arg, divisor) + + for dim, expr in input_mapping.items(): + # Top-level Mod(E, D). 
+ if isinstance(expr, sympy.Mod): + mod_info.append((dim, expr.args[0], expr.args[1])) + continue + + # Top-level floor(E / D). + if isinstance(expr, sympy.floor): + inner = expr.args[0] + numer, denom = inner.as_numer_denom() + if denom != 1: + floor_info.append((dim, numer, denom, sympy.Integer(0))) + continue + + # A + floor(E / D) pattern. + if isinstance(expr, sympy.Add): + for arg in expr.args: + if isinstance(arg, sympy.floor): + inner = arg.args[0] + numer, denom = inner.as_numer_denom() + if denom != 1: + addend = expr - arg + floor_info.append((dim, numer, denom, addend)) + break + + # Match on divisor. + pairs = [] + for dim_q, numer, divisor, addend in floor_info: + for dim_r, mod_arg, mod_divisor in mod_info: + if divisor == mod_divisor: + pairs.append((dim_q, dim_r, numer, divisor, addend)) + break + + return pairs + + +def simplify_index_mapping( + mapping: IndexMapping, + constraints=(), + tile_sizes: dict[IndexSymbol, IndexExpr] | None = None, +) -> tuple[IndexMapping, bool]: + """Simplify flat // D and flat % D patterns in an IndexMapping. + + When *tile_sizes* is provided (from a node's index sequences), uses + the concrete tile dimensions as iterator upper bounds. This enables + proving remainder < divisor for tile-level expressions that would be + unprovable with the full-dimension iteration_shape. + + Returns (new_mapping, changed). + """ + iter_bounds = _get_iterator_bounds(mapping, tile_sizes) + sym_lower_bounds = _get_symbol_lower_bounds(constraints) + input_mapping = dict(mapping.input_mapping) + changed = False + + pairs = _find_floordiv_mod_pairs(input_mapping) + for dim_q, dim_r, flat_expr, divisor, addend in pairs: + # Step 1: Factor out D-multiples from flat_expr. + split = _split_sum_by_divisibility(flat_expr, divisor) + if split is None: + quotient = sympy.Integer(0) + remainder = flat_expr + else: + quotient, remainder = split + + # Step 2: Check if remainder is bounded in [0, D). 
+ rem_bounds = _expr_bounds_with_iters(remainder, iter_bounds) + if rem_bounds is None: + continue + + lo, hi = rem_bounds + if hi == sympy.oo: + continue + + # Check lo >= 0. + lo_nonneg = lo.is_nonnegative if hasattr(lo, "is_nonnegative") else None + if lo_nonneg is None: + lo_simplified = sympy.simplify(lo) + lo_nonneg = lo_simplified.is_nonnegative + if not lo_nonneg: + continue + + # Check hi < divisor. When the divisor is symbolic, resolve + # it via IndexingContext.subs (e.g. K_PACKED -> K//2) and then + # substitute known lower bounds from constraints (e.g. K >= 2048 + # lets us prove K//2 >= 1024). + try: + divisor_lb = subs_idxc(divisor) + except (IndexError, KeyError): + divisor_lb = divisor + if sym_lower_bounds: + divisor_lb = divisor_lb.subs( + {s: lb for s, lb in sym_lower_bounds.items()} + ) + # Evaluate floor/ceiling after substitution. + try: + divisor_lb = sympy.Integer(int(divisor_lb)) + except (TypeError, ValueError): + pass + + diff = sympy.simplify(hi - divisor_lb) + if diff.is_negative is not True: + continue + + # Remainder < D proven! Eliminate the dynamic floordiv/mod. 
+ input_mapping[dim_q] = addend + quotient + input_mapping[dim_r] = remainder + changed = True + + if not changed: + return mapping, False + + return ( + IndexMapping( + mapping.num_iterators, + input_mapping, + dict(mapping.output_mapping), + dynamic_val_mappings=tuple( + dict(dvm) for dvm in (mapping.dynamic_val_mappings or ()) + ), + ), + True, + ) diff --git a/wave_lang/kernel/wave/location_check_pass.py b/wave_lang/kernel/wave/location_check_pass.py index 55d72bdbf4..3fc9eb5bb0 100644 --- a/wave_lang/kernel/wave/location_check_pass.py +++ b/wave_lang/kernel/wave/location_check_pass.py @@ -7,11 +7,13 @@ from wave_lang.support.logging import get_logger from .._support.tracing import CapturedTrace from ..ops.wave_ops import get_custom +from .region_canonicalization import RegionFormat, requires_region_format from typing import Optional logger = get_logger("wave.ops_location_check") +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def location_check_pass( trace: CapturedTrace, pass_name: str = "unnamed", diff --git a/wave_lang/kernel/wave/memory_analysis/minimize_shared_allocs.py b/wave_lang/kernel/wave/memory_analysis/minimize_shared_allocs.py index 814c6d5f75..e26e187ecf 100644 --- a/wave_lang/kernel/wave/memory_analysis/minimize_shared_allocs.py +++ b/wave_lang/kernel/wave/memory_analysis/minimize_shared_allocs.py @@ -10,6 +10,7 @@ import torch.fx as fx from wave_lang.kernel._support.tracing import CapturedTrace +from ..region_canonicalization import RegionFormat, requires_region_format from ..._support.dtype import i8 from ...lang.global_symbols import * @@ -45,20 +46,30 @@ def propagate_user(user: fx.Node) -> list[fx.Node]: return [user] -def compute_live_intervals(allocs: list[fx.Node]): - """ - Compute the live intervals for the allocs. 
- """ +def _get_propagated_users(node: fx.Node) -> list[fx.Node]: + """Return users of `node`, propagated through region entry/exit adapters.""" + + users, _ = get_users(node, None) + return flatten_list([propagate_user(u) for u in users]) + + +def compute_live_intervals( + allocs: list[fx.Node], +) -> dict[fx.Node, LiveInterval]: + """Compute live intervals for shared allocations that still have uses.""" + live_intervals = {} for alloc in allocs: - live_intervals[alloc] = LiveInterval() - users, _ = get_users(alloc, None) - users = flatten_list([propagate_user(u) for u in users]) + users = _get_propagated_users(alloc) + if not users: + continue + interval = LiveInterval() for user in users: - if user._sort_key < live_intervals[alloc].start: - live_intervals[alloc].start = user._sort_key - if user._sort_key > live_intervals[alloc].end: - live_intervals[alloc].end = user._sort_key + if user._sort_key < interval.start: + interval.start = user._sort_key + if user._sort_key > interval.end: + interval.end = user._sort_key + live_intervals[alloc] = interval return live_intervals @@ -69,8 +80,7 @@ def get_shared_memory_allocation_size(alloc: fx.Node) -> int: def get_use( alloc: fx.Node, live_interval: LiveInterval, match_sort_key: int ) -> fx.Node: - users, _ = get_users(alloc, None) - users = flatten_list([propagate_user(u) for u in users]) + users = _get_propagated_users(alloc) matches = [x for x in users if x._sort_key == live_interval.start] if len(matches) != 1: raise ValueError( @@ -111,7 +121,13 @@ def insert_barrier_if_needed(alloc: fx.Node, first_use: fx.Node, last_use: fx.No ) -def get_alloc_info(trace: CapturedTrace): +def get_alloc_info( + trace: CapturedTrace, +) -> tuple[ + list[fx.Node] | None, + dict[fx.Node, LiveInterval] | None, + list[tuple[int, tuple[int], tuple[int]]] | None, +]: def is_shared_alloc(alloc: fx.Node) -> bool: custom = get_custom(alloc) return ( @@ -123,7 +139,10 @@ def is_shared_alloc(alloc: fx.Node) -> bool: if not allocs: return None, 
None, None live_intervals = compute_live_intervals(allocs) + if not live_intervals: + return None, None, None + allocs = list(live_intervals) alloc_info = [ ( get_shared_memory_allocation_size(x), @@ -136,6 +155,7 @@ def is_shared_alloc(alloc: fx.Node) -> bool: return allocs, live_intervals, alloc_info +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def minimize_shared_allocs(trace: CapturedTrace, minimize_shared_allocs: bool): """ Minimize the number of shared allocs by reusing them. diff --git a/wave_lang/kernel/wave/minimize_global_loads.py b/wave_lang/kernel/wave/minimize_global_loads.py index b5694296b7..d2a1d36eb3 100644 --- a/wave_lang/kernel/wave/minimize_global_loads.py +++ b/wave_lang/kernel/wave/minimize_global_loads.py @@ -15,13 +15,21 @@ from .._support.tracing import CapturedTrace from ..lang.global_symbols import * from ..lang.wave_types import IndexMapping -from ..ops.wave_ops import Read, Write, GatherToLDS, TensorLoadToLDS, get_custom +from ..ops.wave_ops import ( + Read, + Write, + GatherToLDS, + TensorLoadToLDS, + CustomOp, + get_custom, +) from ..wave.constraints import ( Constraint, HardwareConstraint, TilingConstraint, WorkgroupConstraint, ) +from .region_canonicalization import RegionFormat, requires_region_format from .utils.general_utils import ( ceildiv, delinearize_index, @@ -134,7 +142,7 @@ def identify_optimizable_loads( num_global_loads > (M * N) / (T * L) where the memory has shape [M, N], there are T threads and each thread can load L elements. 
""" - optimizable_loads: dict[fx.Node, tuple[int, list[Read], set["Custom"]]] = {} + optimizable_loads: dict[fx.Node, tuple[int, list[Read], set[CustomOp]]] = {} processed_memories = set() for read_node in global_read_nodes: custom = get_custom(read_node) @@ -312,6 +320,7 @@ def update_shared_memory_read( if custom_memory_shape != metadata.memory_shape: permutation = [custom_memory_shape.index(k) for k in metadata.memory_shape] custom_memory.update_arg("shape", metadata.memory_shape) + custom_memory.fx_node.type = custom_memory.type new_distributed_shape = [] for i, perm in enumerate(permutation): offset = 0 @@ -367,6 +376,7 @@ def is_replaceable_write(node: fx.Node) -> bool: DCE(trace) +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def minimize_global_loads(trace: CapturedTrace, constraints: list[Constraint]): """ This function attempts to minimize the number of global loads in a graph. diff --git a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py index 32a6aa3996..a90f80c0f8 100644 --- a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py @@ -97,6 +97,7 @@ from wave_lang.kernel.wave.mlir_converter.diagnostics import MLIRDiagnostic, WaterError from wave_lang.kernel.wave.mlir_converter.mlir_converter import FxEmitterResponse from wave_lang.kernel.wave.mlir_converter.water_emitter import serialize_location +from wave_lang.kernel.wave.region_canonicalization import canonicalize_region_captures from wave_lang.kernel.wave.utils.symbol_utils import get_induction_symbol from wave_lang.support.indexing import index_symbol, IndexSequence from wave_lang.kernel._support.indexing import IndexExpr, IndexSymbol, safe_subs @@ -149,7 +150,6 @@ symbol_attr_to_name, ) - # Converted attribute value: The union of all types that may be produced # for a single MLIR attribute. 
AttrValue = ( @@ -733,6 +733,7 @@ def _handle_allocate_op(op: AllocateOp, parse_ctx: _OpParseContext) -> None: offset=offset, tail_padding=tail_padding, ) + allocate_op.fx_node.type = allocate_op.type _apply_mlir_attrs_to_fx_node(allocate_op.fx_node, converted_attrs) parse_ctx.add_mapping(op.result, allocate_op.fx_node) @@ -1310,7 +1311,6 @@ def _create_get_result_nodes( if isinstance(result_index, dict): get_result_op.fx_node.index = result_index - parse_ctx.add_mapping(result, get_result_op.fx_node) @@ -1322,7 +1322,7 @@ def _handle_iterate_op(op: IterateOp, parse_ctx: _OpParseContext) -> None: (operands and block arguments). Then creates a nested subgraph for the iterate body: - Iterator axis from the iterator attribute - Init args become IterArg placeholders in the subgraph - - Captures (explicit after makeIsolated) are mapped directly to outer values + - Captures (explicit after makeIsolated) become lifted placeholders in the subgraph - GetResult nodes for each iterate result """ axis = index_symbol(symbol_attr_to_name(op.iterator)) @@ -1349,7 +1349,7 @@ def _handle_iterate_op(op: IterateOp, parse_ctx: _OpParseContext) -> None: # Create a local scope for the iterate body. # - IterArg block arguments -> new IterArg placeholder nodes in subgraph - # - Capture block arguments -> mapped directly to outer values (no placeholders) + # - Capture block arguments -> new lifted placeholders in subgraph local_map: dict[ir.Value, fx.Node | int | float] = {} # Map iter args to new placeholder nodes in the subgraph @@ -1368,10 +1368,17 @@ def _handle_iterate_op(op: IterateOp, parse_ctx: _OpParseContext) -> None: arg_node.vector_shapes = dict(init_node.vector_shapes) local_map[block_arg] = arg_node - # Map capture block arguments directly to their outer values rather than - # creating lifted placeholders (the graph comparison handles both forms). + # Rebuild region captures as lifted placeholders instead of mapping block + # arguments directly to outer values. 
MLIR uses block arguments for captures + # (and admits both explicit and implicit capture forms), but the FX-side + # canonical region form represents them as explicit local placeholders with + # `meta["lifted"]` links so MLIR -> FX -> MLIR roundtrips preserve the + # canonical isolated interface. for block_arg, capture_node in zip(block_args[iter_count:], captures): - local_map[block_arg] = capture_node + capture_placeholder = Iterate.materialize_capture_placeholder( + subgraph, capture_node + ) + local_map[block_arg] = capture_placeholder # Parse the body operations. All values now resolve within local_map. _convert_ops( @@ -1427,6 +1434,16 @@ def _handle_iterate_op(op: IterateOp, parse_ctx: _OpParseContext) -> None: else: iterate_op.fx_node.type = result_types + converted_attrs = _convert_supported_attrs( + op, + ignore_attrs={ + AttrNames.INDEX.mlir_name, + "iterator", + "operandSegmentSizes", + }, + ) + _apply_mlir_attrs_to_fx_node(iterate_op.fx_node, converted_attrs) + # Create GetResult nodes for each iterate result _create_get_result_nodes( parse_ctx, @@ -1615,6 +1632,7 @@ def _diagnostics_handler(d: ir.Diagnostic) -> bool: _initialize_vector_shapes(trace, hw) _initialize_tiling_constraints(trace, constraints) + canonicalize_region_captures(trace) return trace, constraints, options, diagnostics diff --git a/wave_lang/kernel/wave/mlir_converter/mlir_converter.py b/wave_lang/kernel/wave/mlir_converter/mlir_converter.py index 5d1aa11d76..0e4985ee36 100644 --- a/wave_lang/kernel/wave/mlir_converter/mlir_converter.py +++ b/wave_lang/kernel/wave/mlir_converter/mlir_converter.py @@ -27,6 +27,7 @@ from wave_lang.kernel._support.tracing import CapturedTrace from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.kernel.wave.constraints import Constraint +from wave_lang.kernel.wave.region_canonicalization import canonicalize_region_captures from wave_lang.kernel.wave.mlir_converter.diagnostics import ( FileLocation, MLIRDiagnostic, @@ 
-262,6 +263,8 @@ def _prepare_water_request( Snapshots the trace's node state, expands the water-analysis pipeline if requested, and returns the dill-serialized request bytes. """ + canonicalize_region_captures(trace) + # Ensure additional node fields (like .type) are not lost during pickling. trace.snapshot_node_state() @@ -323,6 +326,7 @@ def _unpack_fx_response( f"fx_emitter trace has unexpected type: {type(response.trace)}" ) response.trace.restore_node_state() + canonicalize_region_captures(response.trace) return response.trace, response.constraints, response.options, response.diagnostics diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index 1cd444859c..8de2b81d43 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -74,7 +74,6 @@ NewRegister, Output, Placeholder, - Placeholder, Read, ReduceOp as Reduce, Reshape, diff --git a/wave_lang/kernel/wave/multicast.py b/wave_lang/kernel/wave/multicast.py index a7241b10f3..73a6b9edee 100644 --- a/wave_lang/kernel/wave/multicast.py +++ b/wave_lang/kernel/wave/multicast.py @@ -22,6 +22,7 @@ from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence, safe_subs from .._support.tracing import CapturedTrace from ..lang.global_symbols import WORKGROUP_0, WORKGROUP_1, WORKGROUP_2 +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import TensorLoadToLDS, get_custom from .compile_options import WaveCompileOptions from .constraints import Constraint, WorkgroupConstraint @@ -89,6 +90,7 @@ def compute_multicast_mask( return sympy.simplify(mask_expr) +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def multicast( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index eacebd963e..3db1031c5b 100644 --- 
a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -197,24 +197,48 @@ def _find_mergeable_groups( yield_operands = list(yield_op.operands) # Collect per-arg info: (iter_index, offset, init_source, yield_source). + # yield_source may be None when the yield value doesn't trace back to + # extract_strided_slice (e.g. pipelined sub-word loads from a swapped + # double-buffer). Such groups are still mergeable — we construct the + # yield vector from individual bytes later. eligible = [] for i, iter_arg in enumerate(for_view.inner_iter_args): if iter_arg.type != v1xi8: continue init_info = _trace_extract_strided_slice(init_args[i]) - yield_info = _trace_extract_strided_slice(yield_operands[i]) - if init_info is None or yield_info is None: + if init_info is None: continue init_src, init_off = init_info - yield_src, yield_off = yield_info - if init_off != yield_off: - continue + yield_info = _trace_extract_strided_slice(yield_operands[i]) + if yield_info is not None: + yield_src, yield_off = yield_info + if init_off != yield_off: + # Init and yield extract different bytes from their respective + # dword loads. This happens when the pipeliner shifts the LDS + # address by one scale element between init and yield (e.g. + # init loads at addr X and extracts byte 1, yield loads at + # addr X+1 and extracts byte 0 — same physical byte). + # Mark yield as untraceable so the coalescer constructs the + # yield vector from the yield source's parent vector<4xi8>. 
+ yield_src = None + logger.debug( + f"iter_arg {i}: init offset {init_off} != yield offset " + f"{yield_off} — treating yield as untraceable for " + f"dword coalescing" + ) + else: + yield_src = None + logger.debug( + f"iter_arg {i}: init traces to offset {init_off} but yield " + f"value does not trace to extract_strided_slice — will " + f"construct yield vector from individual bytes" + ) eligible.append((i, init_off, init_src, yield_src)) by_init_src = defaultdict(list) for entry in eligible: _, _, init_src, _ = entry - by_init_src[id(init_src.owner)].append(entry) + by_init_src[hash(init_src.owner)].append(entry) result = [] for entries in by_init_src.values(): @@ -232,13 +256,21 @@ def _find_mergeable_groups( init_source = None yield_owners = set() yield_source = None + has_untraceable_yield = False for o in present_offsets: idx, _, isrc, ysrc = by_offset[o].pop(0) members[o] = idx init_source = isrc - yield_source = ysrc - yield_owners.add(id(ysrc.owner)) - if len(yield_owners) == 1: + if ysrc is not None: + yield_source = ysrc + yield_owners.add(hash(ysrc.owner)) + else: + has_untraceable_yield = True + if has_untraceable_yield: + # Some yield values don't trace — the yield vector will be + # constructed from individual bytes during coalescing. + result.append((init_source, None, members)) + elif len(yield_owners) == 1: result.append((init_source, yield_source, members)) present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] @@ -381,6 +413,61 @@ def make_extract_slice(source: Value, offset: int): logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args") + # For groups whose yield values don't trace to extract_strided_slice + # (e.g. pipelined sub-word loads from a swapped double-buffer), + # create a single wide vector<4xi8> load at the same base address + # as the byte-0 yield load. This replaces the sub-word loads with + # one dword load that waveasm can map to a ds_read_b32. 
+ old_yield_operands = list(yield_op.operands) + for g_idx, (init_source, yield_source, members) in enumerate(groups): + if yield_source is not None: + continue + + # Try to find a yield value whose source vector<4xi8> already + # exists (from an extract_strided_slice of a wider load). + # This handles the address-shifted pattern where init and yield + # extract different offsets from their respective dword loads. + any_member_idx = next(iter(members.values())) + any_yield = old_yield_operands[any_member_idx] + any_yield_info = _trace_extract_strided_slice(any_yield) + if any_yield_info is not None: + # The yield value is an extract from a vector<4xi8> — use + # the source vector directly as the merged yield value. + yield_vec_src, _ = any_yield_info + groups[g_idx] = (init_source, yield_vec_src, members) + logger.debug( + f"Group {g_idx}: reusing existing vector<4xi8> yield " + f"source (offset-shifted pattern)" + ) + continue + + if 0 not in members: + logger.debug( + f"Group {g_idx}: no byte-0 member, cannot determine base " + f"address for wide load — skipping" + ) + continue + byte0_yield = old_yield_operands[members[0]] + byte0_op = byte0_yield.owner + if not _is_op_named(byte0_op, "vector.load"): + logger.debug( + f"Group {g_idx}: byte-0 yield is not a vector.load " + f"({byte0_op.name}) — skipping" + ) + continue + load_view = byte0_op.opview + memref = load_view.base + indices = list(load_view.indices) + v4xi8 = VectorType.get([SCALE_VECTOR_WIDTH], i8) + with InsertionPoint(byte0_op): + wide_load = vector_d.load(v4xi8, memref, indices) + groups[g_idx] = (init_source, wide_load, members) + logger.debug( + f"Created wide vector<{SCALE_VECTOR_WIDTH}xi8> load for " + f"group {g_idx} yield value (replaces {len(members)} " + f"sub-word loads)" + ) + plan = _build_coalesce_plan(groups, for_view, yield_op) old_iter_args = list(for_view.inner_iter_args) old_iv = for_view.induction_variable @@ -431,6 +518,191 @@ def make_extract_slice(source: Value, offset: int): 
for_op.erase() +def _get_affine_constant_offset(op) -> Optional[int]: + """Extract the trailing integer constant from an affine.apply's map. + + For ``affine_map<()[s0,s1] -> (expr + 2)>`` returns ``2``. + For ``affine_map<()[s0,s1] -> (expr)>`` (no trailing constant) returns ``0``. + Returns ``None`` if *op* is not an affine.apply or the map cannot be parsed. + """ + if not _is_op_named(op, "affine.apply"): + return None + import re + + map_str = str(op.attributes["map"]) + m = re.search(r"\+\s*(\d+)\)\s*>\s*$", map_str) + if m: + return int(m.group(1)) + if re.search(r"\)\s*>\s*$", map_str): + return 0 + return None + + +def _affine_base_key(op) -> Optional[tuple]: + """Return a key that identifies the non-constant base of an affine.apply. + + Two affine.apply ops with the same base key compute addresses that + differ only by a compile-time constant offset, so a single wide load + at the offset-0 address covers all of them. + + Returns ``None`` when *op* is not an affine.apply. + """ + if not _is_op_named(op, "affine.apply"): + return None + operand_hashes = tuple(hash(v) for v in op.operands) + import re + + map_str = str(op.attributes["map"]) + m = re.search(r"\+\s*\d+\)\s*>\s*$", map_str) + if m: + base_map = map_str[: m.start()].rstrip() + ")>" + else: + base_map = map_str + return (base_map, operand_hashes) + + +def _merge_scale_byte_loads(module: Module) -> None: + """Merge adjacent vector<1/2xi8> LDS loads into vector<4xi8> + extract. + + Handles the unrolled-iteration pattern where sub-word scale loads + aren't loop-carried and thus not covered by _coalesce_vector_iter_args. + For vector<2xi8> loads, replaces downstream extract_strided_slice(size=1) + users with direct byte extracts from the wide vector<4xi8>, so that + _trace_scale_chain sees the correct source type for opsel. 
+ """ + i8 = IntegerType.get_signless(8) + i64 = IntegerType.get_signless(64) + v1xi8 = VectorType.get([1], i8) + + byte_loads: list[Operation] = [] + for op in _walk_operations(module.operation): + if not _is_op_named(op, "vector.load"): + continue + rtype = op.results[0].type + if not isinstance(rtype, VectorType) or rtype.element_type != i8: + continue + if rtype.shape[0] > 2: + continue + view = op.opview + addr = view.indices[-1] + if _get_affine_constant_offset(addr.owner) is None: + continue + byte_loads.append(op) + + if not byte_loads: + return + + def _full_group_key(op: Operation): + """Group by memref, all non-last indices, AND the affine base expr.""" + view = op.opview + memref_h = hash(view.base) + prefix_hashes = tuple(hash(v) for v in list(view.indices)[:-1]) + addr = view.indices[-1] + base_k = _affine_base_key(addr.owner) + return (memref_h, prefix_hashes, base_k) + + by_group: dict = defaultdict(list) + for op in byte_loads: + by_group[_full_group_key(op)].append(op) + + for loads in by_group.values(): + if len(loads) < 2: + continue + + addr_to_offset: dict[int, int] = {} + base_op: Optional[Operation] = None + for load_op in loads: + addr_val = load_op.opview.indices[-1] + off = _get_affine_constant_offset(addr_val.owner) + if off is None: + continue + if off >= SCALE_VECTOR_WIDTH: + continue + addr_to_offset[id(load_op)] = off + if off == 0: + base_op = load_op + + if base_op is None: + continue + + base_view = base_op.opview + v4xi8 = VectorType.get([SCALE_VECTOR_WIDTH], i8) + + earliest = base_op + for load_op in loads: + if id(load_op) in addr_to_offset: + earliest = load_op + break + + with InsertionPoint(earliest): + wide = vector_d.load(v4xi8, base_view.base, list(base_view.indices)) + + replaced = 0 + for load_op in loads: + if id(load_op) not in addr_to_offset: + continue + off = addr_to_offset[id(load_op)] + rtype = load_op.results[0].type + n = rtype.shape[0] + if off + n > SCALE_VECTOR_WIDTH: + continue + + if n == 1: + with 
InsertionPoint(load_op): + offsets = ArrayAttr.get([IntegerAttr.get(i64, off)]) + sizes = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + strides = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + ext = vector_d.ExtractStridedSliceOp( + v1xi8, wide, offsets, sizes, strides + ) + load_op.results[0].replace_all_uses_with(ext.result) + load_op.erase() + replaced += 1 + else: + ess_users = [] + for use in list(load_op.results[0].uses): + user_op = use.owner + if ( + _is_op_named(user_op, "vector.extract_strided_slice") + and IntegerAttr(user_op.opview.sizes[0]).value == 1 + ): + inner_off = IntegerAttr(user_op.opview.offsets[0]).value + byte_off = off + inner_off + if byte_off < SCALE_VECTOR_WIDTH: + ess_users.append((user_op, byte_off)) + + for user_op, byte_off in ess_users: + with InsertionPoint(user_op): + offsets = ArrayAttr.get([IntegerAttr.get(i64, byte_off)]) + sizes = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + strides = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + ext = vector_d.ExtractStridedSliceOp( + v1xi8, wide, offsets, sizes, strides + ) + user_op.results[0].replace_all_uses_with(ext.result) + user_op.erase() + + has_remaining = any(True for _ in load_op.results[0].uses) + if has_remaining: + with InsertionPoint(load_op): + offsets = ArrayAttr.get([IntegerAttr.get(i64, off)]) + sizes = ArrayAttr.get([IntegerAttr.get(i64, n)]) + strides = ArrayAttr.get([IntegerAttr.get(i64, 1)]) + fallback = vector_d.ExtractStridedSliceOp( + rtype, wide, offsets, sizes, strides + ) + load_op.results[0].replace_all_uses_with(fallback.result) + + load_op.erase() + replaced += 1 + + if replaced: + logger.debug( + f"Merged {replaced} sub-word LDS loads into one " + f"vector<{SCALE_VECTOR_WIDTH}xi8> load" + ) + + def apply_opsel_scaled_mfma(module: Module): """Walk the MLIR module and apply the opsel optimization to scaled_mfma ops. 
@@ -444,6 +716,7 @@ def apply_opsel_scaled_mfma(module: Module): with mlir_ctx, Location.unknown(): _coalesce_vector_iter_args(module) + _merge_scale_byte_loads(module) f8e8m0 = Float8E8M0FNUType.get() diff --git a/wave_lang/kernel/wave/preshuffle_scale_to_shared.py b/wave_lang/kernel/wave/preshuffle_scale_to_shared.py index 6d99adf64d..83883393d3 100644 --- a/wave_lang/kernel/wave/preshuffle_scale_to_shared.py +++ b/wave_lang/kernel/wave/preshuffle_scale_to_shared.py @@ -38,6 +38,8 @@ import sympy import torch.fx as fx +from .region_canonicalization import RegionFormat, requires_region_format + from wave_lang.support.logging import get_logger from .._support.indexing import IndexSequence @@ -186,6 +188,7 @@ def _create_wide_read_1d( read.erase() +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def preshuffle_scale_to_shared(trace: CapturedTrace, constraints: list[Constraint]): """Transform shared memory layout for preshuffle scale buffers. @@ -396,11 +399,6 @@ def _transform_scale_memory( input_read.update_arg("mapping", None) get_custom(input_read.fx_node).erase() - # --- Transform reads --- - # LDS data is now in preshuffle physical order. Each MMA read needs - # one scale byte at logical (k, m). The preshuffle formula decomposes - # into constant_base + lane_id * 4 when k_offset is a multiple of 4 - # and m_offset is a multiple of 16 (guaranteed by MMA tiling). 
read_infos = [] for node in trace.walk(lambda n: isinstance(get_custom(n), Read)): read = get_custom(node) diff --git a/wave_lang/kernel/wave/promotion.py b/wave_lang/kernel/wave/promotion.py index f511bd75ed..66f97993e6 100644 --- a/wave_lang/kernel/wave/promotion.py +++ b/wave_lang/kernel/wave/promotion.py @@ -12,6 +12,7 @@ from .._support.tracing import CapturedTrace from ..lang.global_symbols import * from ..ops.wave_ops import * +from .region_canonicalization import RegionFormat, requires_region_format from .constraints import Constraint, get_constrained_shape from .utils.classes import KernelLaunchInfo from .utils.graph_utils import move_node_after @@ -255,6 +256,7 @@ def promote_placeholders( fix_manual_allocate_dependencies(graph) +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def compute_shared_memory_usage( graph: CapturedTrace, kernel_launch_info: KernelLaunchInfo ): diff --git a/wave_lang/kernel/wave/region_canonicalization.py b/wave_lang/kernel/wave/region_canonicalization.py new file mode 100644 index 0000000000..a92706b05a --- /dev/null +++ b/wave_lang/kernel/wave/region_canonicalization.py @@ -0,0 +1,603 @@ +# Copyright 2026 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import functools +import inspect +from enum import Enum +from typing import Callable + +import torch.fx as fx + +from .._support.tracing import CapturedTrace +from ..ops.wave_ops import ( + Conditional, + IterArg, + Iterate, + NestedRegionOp, + NewRegister, + Placeholder, + get_custom, +) + + +class RegionFormat(str, Enum): + """Region representation required by a pass. + + `ISOLATED` is the canonical/default format of the compilation pipeline, so + passes that operate on the canonical form do not need an annotation. 
See + `docs/wave/canonical_ir_format.md` for the structural definition of each + supported region form. + """ + + ISOLATED = "canonical" + LEGACY_PLACEHOLDERS = "legacy_capture_placeholders" + DIRECT_OUTER_REF = "direct_capture_refs" + SCHEDULE_SIGNATURE_PLACEHOLDERS = "schedule_signature_placeholders" + + +_REQUIRED_REGION_FORMAT_ATTR = "_wave_required_region_format" +_RAW_GRAPH_PASS_ATTR = "_wave_raw_graph_pass" +_CANONICALIZE_OUTPUT_KWARG = "canonicalize_output" + + +def _iter_nested_regions(graph: fx.Graph): + """Yield each nested region op together with its subgraph.""" + for node in graph.nodes: + custom = get_custom(node) + if not isinstance(custom, NestedRegionOp): + continue + root_subgraphs = custom.get_root_graph().subgraphs + if custom.subgraph_name not in root_subgraphs: + raise ValueError( + f"Nested region {custom.fx_node.name} references missing subgraph " + f"{custom.subgraph_name}" + ) + subgraph = root_subgraphs[custom.subgraph_name] + yield custom, subgraph + yield from _iter_nested_regions(subgraph) + + +def _replace_direct_capture_aliases( + graph: fx.Graph, source: fx.Node, replacement: fx.Node +) -> None: + """Rewrite direct outer capture references to a local replacement node.""" + # Some region-local users may still point at the outer value directly. + # Rewrite just those aliases to the new local placeholder. + for user in graph.nodes: + + def rewrite(arg): + if not isinstance(arg, fx.Node): + return arg + if arg.graph is graph: + return arg + if NestedRegionOp.capture_source(arg) is source: + return replacement + return arg + + user._update_args_kwargs( + fx.map_arg(user.args, rewrite), + fx.map_arg(user.kwargs, rewrite), + ) + + +def _collect_capture_sources( + region: NestedRegionOp, subgraph: fx.Graph +) -> list[fx.Node]: + """Collect outer capture sources in stable signature order.""" + # dict-keyed-by-None is used as an insertion-order set. 
+ captures: dict[fx.Node, None] = {} + + def remember(source: fx.Node): + source = region.capture_source(source) + captures.setdefault(source, None) + + for source in region.implicit_captures: + remember(source) + ordered_placeholders = [ + node + for node in subgraph.nodes + if isinstance(get_custom(node), Placeholder) + and not isinstance(get_custom(node), IterArg) + ] + for local_capture_idx, local_capture in enumerate(ordered_placeholders): + source = NestedRegionOp.capture_source(local_capture) + if source is local_capture or source.graph is subgraph: + continue + resolved = _try_resolve_legacy_capture_source( + region, subgraph, local_capture, local_capture_idx + ) + if resolved is not None: + remember(resolved) + for node in subgraph.nodes: + custom = get_custom(node) + if isinstance(custom, (IterArg, Placeholder)): + continue + source = NestedRegionOp.capture_source(node) + if source is node or source.graph is subgraph: + continue + remember(source) + for source in _collect_direct_capture_uses(subgraph): + remember(source) + return list(captures) + + +def _collect_direct_capture_uses(graph: fx.Graph) -> list[fx.Node]: + """Return outer values referenced directly in `graph`.""" + captures: list[fx.Node] = [] + seen: set[fx.Node] = set() + + def visit(arg): + if not isinstance(arg, fx.Node): + return arg + if arg.graph is graph: + return arg + source = NestedRegionOp.capture_source(arg) + if source.graph is graph or source in seen: + return arg + seen.add(source) + captures.append(source) + return arg + + for nested_node in graph.nodes: + custom = get_custom(nested_node) + if isinstance(custom, Iterate): + fx.map_arg((custom.init_args, custom.start, custom.condition), visit) + continue + if isinstance(custom, Conditional): + fx.map_arg((custom.condition, custom.else_return), visit) + continue + fx.map_arg((nested_node.args, nested_node.kwargs), visit) + return captures + + +def _build_capture_lookup( + region: NestedRegionOp, subgraph: fx.Graph +) -> 
tuple[dict[fx.Node, fx.Node], list[fx.Node]]: + """Build direct-ref-aware capture lookup data for one region subgraph.""" + by_outer: dict[fx.Node, fx.Node] = {} + for var in region.captured_vars(subgraph): + # Keep the first local representative we encounter for each outer + # source. Canonicalization collapses duplicates later, so lookup only + # needs one stable local binding per outer capture. + by_outer.setdefault(region.capture_source(var), var) + direct_sources_in_order = [ + region.capture_source(source) + for source in _collect_direct_capture_uses(subgraph) + ] + for source in direct_sources_in_order: + by_outer.setdefault(source, source) + return by_outer, direct_sources_in_order + + +def _try_resolve_legacy_capture_source( + region: NestedRegionOp, subgraph: fx.Graph, local_capture: fx.Node, capture_idx: int +) -> fx.Node | None: + """Try to map a legacy local capture node back to its outer source.""" + source = NestedRegionOp.capture_source(local_capture) + if source.graph is not subgraph: + return region.capture_source(source) + + for implicit_capture in region.implicit_captures: + outer_source = region.capture_source(implicit_capture) + if outer_source.name == local_capture.name: + return outer_source + + # Legacy placeholder-only regions may lose explicit outer-source links. + # In that representation the placeholder prefix still follows capture + # signature order, so placeholder position is a safe final fallback. Keep + # this path narrow to legacy placeholders rather than applying it to all + # lifted local nodes. 
+ if isinstance(get_custom(local_capture), Placeholder) and capture_idx < len( + region.implicit_captures + ): + return region.capture_source(region.implicit_captures[capture_idx]) + + return None + + +def _ensure_capture_placeholder( + region: NestedRegionOp, subgraph: fx.Graph, source: fx.Node +) -> fx.Node: + """Ensure `source` is represented by a lifted placeholder in `subgraph`.""" + outer_source = region.capture_source(source) + local_capture = region.get_captured_fx_node( + subgraph, outer_source, lookup=_build_capture_lookup(region, subgraph) + ) + if local_capture is None: + return region.materialize_capture_placeholder(subgraph, outer_source) + if local_capture is outer_source: + placeholder = region.materialize_capture_placeholder( + subgraph, outer_source, getattr(outer_source, "location", None) + ) + _replace_direct_capture_aliases(subgraph, outer_source, placeholder) + return placeholder + local_capture.meta["lifted"] = outer_source + return local_capture + + +def _canonicalize_nested_region(region: NestedRegionOp, subgraph: fx.Graph) -> None: + """Rewrite one nested region into the canonical isolated form.""" + sources = _collect_capture_sources(region, subgraph) + region.update_arg("implicit_captures", sources) + + # Gather every legacy local representative for each outer capture source so + # they can be collapsed to a single canonical placeholder. 
+ legacy_capture_nodes: dict[fx.Node, list[fx.Node]] = {} + ordered_placeholders = [ + node + for node in subgraph.nodes + if isinstance(get_custom(node), Placeholder) + and not isinstance(get_custom(node), IterArg) + ] + for idx, local_capture in enumerate(ordered_placeholders): + source = _try_resolve_legacy_capture_source( + region, subgraph, local_capture, idx + ) + if source is None: + raise ValueError( + f"Could not resolve legacy capture placeholder {local_capture.name} " + f"in {region.subgraph_name} to an outer source" + ) + legacy_capture_nodes.setdefault(source, []).append(local_capture) + for local_capture in subgraph.nodes: + custom = get_custom(local_capture) + if isinstance(custom, IterArg): + continue + source = NestedRegionOp.capture_source(local_capture) + if source is local_capture or source.graph is subgraph: + continue + if isinstance(custom, Placeholder): + continue + legacy_capture_nodes.setdefault(source, []).append(local_capture) + + anchor = subgraph._root + for node in subgraph.nodes: + if isinstance(get_custom(node), IterArg): + anchor = node + continue + break + for source in sources: + legacy = legacy_capture_nodes.get(source, []) + template = legacy[0] if legacy else None + location = getattr(template, "location", None) if template else None + canonical = next( + (node for node in legacy if isinstance(get_custom(node), Placeholder)), + None, + ) + if canonical is None: + canonical = region.materialize_capture_placeholder( + subgraph, source, location + ) + + # Canonical capture placeholders must form the leading non-IterArg + # region input prefix in signature order. In `torch.fx`, `append` + # moves an existing node in place instead of duplicating it, so this + # handles both newly created and reused placeholders. 
+ anchor.append(canonical) + anchor = canonical + canonical.meta["lifted"] = source + canonical.type = source.type + + _replace_direct_capture_aliases(subgraph, source, canonical) + for legacy_capture in legacy: + if legacy_capture is canonical: + continue + get_custom(legacy_capture).replace_uses_with( + canonical, graph=subgraph, propagate_location=False + ) + get_custom(legacy_capture).erase() + + region.refresh_captures(subgraph) + + +def canonicalize_region_captures(trace: CapturedTrace) -> None: + """Canonicalize capture handling for every nested region in `trace`.""" + root_graph = trace.get_root_graph() + for region, subgraph in _iter_nested_regions(root_graph): + _canonicalize_nested_region(region, subgraph) + + +def verify_canonical_region_captures(trace: CapturedTrace, where: str = "") -> None: + """Check that all nested regions satisfy the canonical capture invariant.""" + root_graph = trace.get_root_graph() + context = f" ({where})" if where else "" + for region, subgraph in _iter_nested_regions(root_graph): + direct_uses = _collect_direct_capture_uses(subgraph) + if direct_uses: + raise ValueError( + f"Direct outer capture references remain in {region.subgraph_name}{context}: " + + ", ".join(node.name for node in direct_uses) + ) + + capture_placeholders = [ + node + for node in subgraph.nodes + if isinstance(get_custom(node), Placeholder) + and not isinstance(get_custom(node), IterArg) + ] + non_lifted_placeholders = [ + node for node in capture_placeholders if "lifted" not in node.meta + ] + if non_lifted_placeholders: + raise ValueError( + f"Non-lifted region placeholders remain in {region.subgraph_name}{context}: " + + ", ".join(node.name for node in non_lifted_placeholders) + ) + + expected_prefix = [] + for node in subgraph.nodes: + custom = get_custom(node) + if isinstance(custom, IterArg): + continue + if isinstance(custom, Placeholder): + expected_prefix.append(node) + continue + break + if len(expected_prefix) != len(capture_placeholders): + 
misplaced = [ + node for node in capture_placeholders if node not in expected_prefix + ] + raise ValueError( + f"Canonical capture placeholders must be leading region inputs in " + f"{region.subgraph_name}{context}: " + + ", ".join(node.name for node in misplaced) + ) + + if len(expected_prefix) != len(region.implicit_captures): + raise ValueError( + f"Capture placeholder count mismatch in {region.subgraph_name}{context}: " + f"{len(expected_prefix)} local placeholders vs " + f"{len(region.implicit_captures)} implicit captures" + ) + + for source, capture_placeholder in zip( + region.implicit_captures, expected_prefix + ): + if NestedRegionOp.capture_source( + capture_placeholder + ) is not region.capture_source(source): + raise ValueError( + f"Capture placeholder source mismatch in {region.subgraph_name}{context}: " + f"{capture_placeholder.name} does not match the parent capture signature" + ) + + +def enable_legacy_capture_placeholders(trace: CapturedTrace) -> None: + """Convert canonical captures to the legacy placeholder-based view.""" + root_graph = trace.get_root_graph() + for region, subgraph in _iter_nested_regions(root_graph): + sources = _collect_capture_sources(region, subgraph) + region.update_arg("implicit_captures", sources) + for source in sources: + _ensure_capture_placeholder(region, subgraph, source) + + +def enable_direct_capture_refs(trace: CapturedTrace) -> None: + """Convert lifted placeholders back into direct outer references.""" + root_graph = trace.get_root_graph() + for region, subgraph in _iter_nested_regions(root_graph): + sources = _collect_capture_sources(region, subgraph) + replacements = { + local_capture: region.capture_source(source) + for source in sources + if (local_capture := region.get_captured_fx_node(subgraph, source)) + is not None + and local_capture is not region.capture_source(source) + } + for local_capture, outer_source in replacements.items(): + get_custom(local_capture).replace_uses_with( + outer_source, 
graph=subgraph, propagate_location=False + ) + get_custom(local_capture).erase() + region.refresh_captures( + subgraph, lookup=_build_capture_lookup(region, subgraph) + ) + + +def enable_schedule_signature_placeholders(trace: CapturedTrace) -> None: + """Keep placeholders only for schedule-signature sources.""" + root_graph = trace.get_root_graph() + for region, subgraph in _iter_nested_regions(root_graph): + all_sources = _collect_capture_sources(region, subgraph) + by_outer, _ = _build_capture_lookup(region, subgraph) + signature_sources = [ + source + for source in all_sources + if isinstance(get_custom(source), (Placeholder, NewRegister)) + ] + signature_source_set = set(signature_sources) + for source in all_sources: + outer_source = region.capture_source(source) + if outer_source in signature_source_set: + by_outer[outer_source] = _ensure_capture_placeholder( + region, subgraph, outer_source + ) + continue + local_capture = by_outer.get(outer_source) + if local_capture is None or local_capture is outer_source: + continue + get_custom(local_capture).replace_uses_with( + outer_source, graph=subgraph, propagate_location=False + ) + get_custom(local_capture).erase() + by_outer[outer_source] = outer_source + region.update_arg("implicit_captures", signature_sources) + + +def graph_pass_region_mode(graph_pass: Callable) -> RegionFormat: + """Return the declared region format for a graph pass.""" + graph_pass = raw_graph_pass(graph_pass) + graph_pass_fn = ( + graph_pass.func if isinstance(graph_pass, functools.partial) else graph_pass + ) + if inspect.ismethod(graph_pass_fn): + graph_pass_fn = graph_pass_fn.__func__ + return getattr(graph_pass_fn, _REQUIRED_REGION_FORMAT_ATTR, RegionFormat.ISOLATED) + + +def raw_graph_pass(graph_pass: Callable) -> Callable: + """Unwrap a decorated, partial, or bound pass to its raw implementation.""" + if isinstance(graph_pass, functools.partial): + raw_partial = functools.partial( + raw_graph_pass(graph_pass.func), + 
*graph_pass.args, + **(graph_pass.keywords or {}), + ) + setattr( + raw_partial, + "__name__", + getattr(raw_partial.func, "__name__", type(raw_partial.func).__name__), + ) + return raw_partial + if inspect.ismethod(graph_pass): + raw_func = getattr( + graph_pass.__func__, _RAW_GRAPH_PASS_ATTR, graph_pass.__func__ + ) + return raw_func.__get__(graph_pass.__self__, type(graph_pass.__self__)) + return getattr(graph_pass, _RAW_GRAPH_PASS_ATTR, graph_pass) + + +def requires_region_format( + region_format: RegionFormat, +) -> Callable[[Callable], Callable]: + """Wrap a graph pass so it runs in the requested region format. + + The decorated pass must expose a uniquely identifiable `trace` parameter, + either by name or by annotating exactly one parameter as `CapturedTrace`. + The wrapped pass enforces the full region-format boundary contract: + it canonicalizes the input trace, enables the requested temporary region + format for the pass body, and then canonicalizes the trace again after the + pass returns so the rest of the pipeline always sees canonical + `RegionFormat.ISOLATED` regions. + + By default, the wrapper canonicalizes the pass output before returning. The + wrapper also accepts a reserved keyword argument, `canonicalize_output=False`, + which disables the return to canonical form for one specific invocation. + This is useful for compatibility with legacy FileCheck tests. 
+ """ + + def decorator(graph_pass: Callable) -> Callable: + trace_parameter_name = _resolve_trace_parameter_name(graph_pass) + graph_pass_signature = inspect.signature(graph_pass) + setattr(graph_pass, _REQUIRED_REGION_FORMAT_ATTR, region_format) + + @functools.wraps(graph_pass) + def wrapped(*args, **kwargs): + canonicalize_output_enabled = kwargs.pop(_CANONICALIZE_OUTPUT_KWARG, True) + if not isinstance(canonicalize_output_enabled, bool): + raise TypeError( + f"{_CANONICALIZE_OUTPUT_KWARG} must be a bool, got " + f"{type(canonicalize_output_enabled).__name__}" + ) + bound_args = graph_pass_signature.bind_partial(*args, **kwargs) + if trace_parameter_name not in bound_args.arguments: + raise TypeError( + f"Decorated pass {graph_pass.__name__} was called " + f"without binding `{trace_parameter_name}`" + ) + trace = bound_args.arguments[trace_parameter_name] + if not isinstance(trace, CapturedTrace): + raise TypeError( + f"Decorated pass {graph_pass.__name__} expected " + f"`{trace_parameter_name}` to be a CapturedTrace, got " + f"{type(trace).__name__}" + ) + prepare_region_captures_for_pass(trace, graph_pass, region_format) + result = graph_pass(*args, **kwargs) + if canonicalize_output_enabled: + canonicalize_region_captures(trace) + return result + + setattr(wrapped, _REQUIRED_REGION_FORMAT_ATTR, region_format) + setattr(wrapped, _RAW_GRAPH_PASS_ATTR, graph_pass) + return wrapped + + return decorator + + +def prepare_region_captures_for_pass( + trace: CapturedTrace, + graph_pass: Callable, + mode: RegionFormat | None = None, +) -> RegionFormat: + """Canonicalize `trace` and enable the region view required by a pass.""" + mode = graph_pass_region_mode(graph_pass) if mode is None else mode + prepare_region_captures(trace, mode) + return mode + + +def prepare_region_captures(trace: CapturedTrace, mode: RegionFormat) -> None: + """Canonicalize `trace` and enable a specific temporary region view.""" + canonicalize_region_captures(trace) + _rewrite_region(trace, 
mode) + + +def _rewrite_region(trace: CapturedTrace, mode: RegionFormat) -> None: + """Rewrite an already-canonical trace into a temporary region format.""" + if mode == RegionFormat.ISOLATED: + return + if mode == RegionFormat.LEGACY_PLACEHOLDERS: + enable_legacy_capture_placeholders(trace) + elif mode == RegionFormat.DIRECT_OUTER_REF: + enable_direct_capture_refs(trace) + elif mode == RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS: + enable_schedule_signature_placeholders(trace) + else: + raise ValueError(f"Unsupported region format: {mode}") + + +def wrap_graph_passes_with_region_adapters( + trace: CapturedTrace, graph_passes: list[Callable] +) -> list[Callable]: + """Wrap raw graph passes so each pass sees its declared region view.""" + canonicalize_region_captures(trace) + wrapped_passes = [] + for graph_pass in graph_passes: + mode = graph_pass_region_mode(graph_pass) + + def wrapped(graph_pass=graph_pass, mode=mode): + _rewrite_region(trace, mode) + result = graph_pass() + canonicalize_region_captures(trace) + return result + + wrapped.__name__ = getattr(graph_pass, "__name__", type(graph_pass).__name__) + setattr(wrapped, _RAW_GRAPH_PASS_ATTR, graph_pass) + wrapped_passes.append(wrapped) + return wrapped_passes + + +def _resolve_trace_parameter_name(graph_pass: Callable) -> str: + """Find the parameter that should receive the `CapturedTrace`.""" + try: + signature = inspect.signature(graph_pass) + except (TypeError, ValueError) as exc: + raise TypeError( + f"Cannot inspect pass signature for {graph_pass.__name__}" + ) from exc + + parameters = tuple(signature.parameters.values()) + candidates = [param.name for param in parameters if param.name == "trace"] + if not candidates: + candidates = [ + param.name + for param in parameters + if param.annotation is CapturedTrace + or param.annotation == "CapturedTrace" + or getattr(param.annotation, "__forward_arg__", None) == "CapturedTrace" + ] + + if len(candidates) == 1: + return candidates[0] + if not candidates: 
+ raise TypeError( + f"Decorated pass {graph_pass.__name__} must expose a `trace` " + "parameter or annotate exactly one parameter as CapturedTrace" + ) + raise TypeError( + f"Decorated pass {graph_pass.__name__} has multiple possible " + f"trace parameters: {', '.join(candidates)}" + ) diff --git a/wave_lang/kernel/wave/schedule_reordering.py b/wave_lang/kernel/wave/schedule_reordering.py index 4f414ad1b9..96af5db9ce 100644 --- a/wave_lang/kernel/wave/schedule_reordering.py +++ b/wave_lang/kernel/wave/schedule_reordering.py @@ -19,6 +19,7 @@ from .._support.location import CapturedLocation from .._support.fx import filter_fx_graph from ..lang.global_symbols import * +from .region_canonicalization import RegionFormat, requires_region_format from ..ops.wave_ops import ( MMA, Conditional, @@ -456,6 +457,9 @@ def insert_cond_barrier(cond_reg, trace, graph, location: Optional[CapturedLocat cond_barrier.location = location barrier_graph.parent_op = cond_barrier trace.add_subgraph(barrier_graph_name, barrier_graph) + local_root = get_custom(cond_barrier).get_root_graph() + if barrier_graph_name not in local_root.subgraphs: + local_root.subgraphs[barrier_graph_name] = barrier_graph return cond_barrier @@ -854,6 +858,7 @@ def get_ops_of_type(graph, operation_type): ] +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def schedule_reordering( trace: CapturedTrace, constraints: list[Constraint], @@ -1048,11 +1053,21 @@ def schedule_reordering( original_subgraph_name = custom_iterate.subgraph_name reordered_subgraph_name = f"reoredered_{original_subgraph_name}" trace.add_subgraph(reordered_subgraph_name, reordered_graph) - trace.get_root_graph().subgraphs[reordered_subgraph_name] = reordered_graph + # When the iterate lives inside a nested region that carries its + # own `subgraphs` dict (e.g. a pipeline-guard conditional created + # by schedule.py), `get_root_graph()` resolves to that intermediate + # graph rather than the trace-level root. 
Register the reordered + # subgraph there so that `_iter_nested_regions` can find it. + local_root = get_custom(iterate_node).get_root_graph() + if reordered_subgraph_name not in local_root.subgraphs: + local_root.subgraphs[reordered_subgraph_name] = reordered_graph custom_iterate.update_arg("subgraph_name", reordered_subgraph_name) del trace.region_graph.subgraphs[original_subgraph_name] - del trace.get_root_graph().subgraphs[original_subgraph_name] + if original_subgraph_name in trace.get_root_graph().subgraphs: + del trace.get_root_graph().subgraphs[original_subgraph_name] + if original_subgraph_name in local_root.subgraphs: + del local_root.subgraphs[original_subgraph_name] if is_pingpong_strategy(reorder_strategy): add_conditional_barriers_to_loop(custom_iterate, trace, hardware_constraint) @@ -1381,7 +1396,9 @@ def _update_trace_with_reordered_graph( # Create new subgraph name and add to trace reordered_subgraph_name = f"reordered_{custom_iterate.subgraph_name}" trace.add_subgraph(reordered_subgraph_name, reordered_graph) - trace.get_root_graph().subgraphs[reordered_subgraph_name] = reordered_graph + local_root = custom_iterate.get_root_graph() + if reordered_subgraph_name not in local_root.subgraphs: + local_root.subgraphs[reordered_subgraph_name] = reordered_graph # Update the iterate operation to use the new subgraph custom_iterate.update_arg("subgraph_name", reordered_subgraph_name) diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 6069941284..2428b47f8f 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1783,8 +1783,13 @@ def mxfp4_dbuf_schedule(): # Interleave MFMAs with memory ops (matching aiter f4gemm pattern). # Clamp start_offsets so they fit within each partition when the M # tile count is odd (e.g. 7 tiles split into 4+3). 
- base_offsets = [0, 3, 2, 0] - base_intervals = [4, 4, 2, 4] + n_mma = len(loop_scaled_mma_0) + if n_mma <= 4: + base_offsets = [0, 1, 1, 0] + base_intervals = [2, 2, 1, 2] + else: + base_offsets = [0, 3, 2, 0] + base_intervals = [4, 4, 2, 4] def _clamp_offsets(n, offsets): return [min(o, max(0, n - 1)) for o in offsets] @@ -1992,8 +1997,8 @@ def split_by_iteration(nodes, key="name"): tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - unroll_factor = 2 - tkw.unroll(pipeline_loop.KERNEL, unroll_factor) + unroll_factor = 2 + tkw.unroll(pipeline_loop.KERNEL, unroll_factor) tkw.insert_at_start( pipeline_loop.KERNEL, diff --git a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py index 70e7c7b7b7..b5aa885075 100644 --- a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py +++ b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py @@ -851,12 +851,14 @@ def guard_g2s_with_bounds_check( max_iv = iterate.count induction_variable = get_induction_variable(iterate, constraints) - prefetch_offset = (num_stages - 1) * step - guard_condition = sympy.StrictLessThan( - induction_variable + prefetch_offset, max_iv - ) - + pipeline_depth = num_stages - 1 for g2s in g2s_nodes: + custom = get_custom(g2s) + unroll_iter = getattr(custom, "unroll_iteration", 0) or 0 + prefetch_offset = pipeline_depth + unroll_iter + guard_condition = sympy.StrictLessThan( + induction_variable + prefetch_offset, max_iv + ) g2s.meta["g2s_guard"] = guard_condition @@ -943,9 +945,11 @@ def construct_pipelined_loop( ) pipelined_reduction_graph.parent_op = pipelined_reduction - trace.add_subgraph( - get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph - ) + subgraph_name = get_custom(pipelined_reduction).subgraph_name + trace.add_subgraph(subgraph_name, pipelined_reduction_graph) + local_root = get_custom(pipelined_reduction).get_root_graph() + if subgraph_name 
not in local_root.subgraphs: + local_root.subgraphs[subgraph_name] = pipelined_reduction_graph if eliminate_epilogue: pipelined_reduction.meta["eliminate_epilogue"] = True diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 469ef3f5e0..4548c0f929 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -203,6 +203,56 @@ def schedule_reduction( ) +def _can_skip_pipeline_guard( + max_induction_variable, + rounding_stride: int, + constraints: list[Constraint], +) -> bool: + """Check if assumptions prove max_induction_variable >= rounding_stride. + + Uses divisibility forward substitutions to relate inequality + assumptions (e.g. ``K > BLOCK_K * 6``) to the simplified + max_induction_variable (e.g. ``K / 256``). Returns True when a + static lower bound can be derived that satisfies the guard. + """ + from ..assumptions import get_divisibility_subs + from ..utils.general_utils import get_assumptions + + assumptions = get_assumptions(constraints) + if not assumptions: + return False + + fwd, _ = get_divisibility_subs(constraints) + if not fwd: + return False + + max_iv_fwd = ( + max_induction_variable.subs(fwd) + if isinstance(max_induction_variable, sympy.Basic) + else max_induction_variable + ) + if not isinstance(max_iv_fwd, sympy.Symbol): + return False + + for assumption in assumptions: + expr = subs_idxc(assumption.expr) + if not isinstance(expr, sympy.core.relational.StrictGreaterThan): + continue + lhs, rhs = expr.args + if not rhs.is_number: + continue + lhs_fwd = lhs.subs(fwd) if isinstance(lhs, sympy.Basic) else lhs + coeff = lhs_fwd.as_coefficient(max_iv_fwd) + if coeff is not None and coeff.is_number and coeff > 0: + # lhs = coeff * max_iv_fwd > rhs + # => max_iv_fwd > rhs / coeff + # Since max_iv_fwd is a positive integer: max_iv_fwd >= floor(rhs/coeff) + 1 + lower_bound = int(rhs / coeff) + 1 + if lower_bound >= rounding_stride: + return True + 
return False + + def build_guarded_pipeline_with_remainder( trace: CapturedTrace, reduction: Iterate, @@ -488,8 +538,20 @@ def construct_pipelined_loop_adaptive( ) ) - if not is_dynamic: - # For static shapes, use the old implementation + # When assumptions prove the guard is always satisfied, treat the + # symbolic max_induction_variable like a static value: emit + # prologue + pipelined loop + epilogue with no conditional guard + # and no remainder loop — exactly the same structure as the static + # path but with symbolic address computation. + from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) + if not is_dynamic or _can_skip_pipeline_guard( + max_induction_variable, rounding_stride, constraints + ): + concrete_max_iv = ( + int(max_induction_variable) if not is_dynamic else max_induction_variable + ) new_reduction, node_mapping, _ = construct_pipelined_loop( trace, reduction, @@ -497,7 +559,7 @@ def construct_pipelined_loop_adaptive( constraints, num_stages, initiation_interval, - int(max_induction_variable), + concrete_max_iv, visualize, use_scheduling_barriers, multi_buffer_count, @@ -513,8 +575,8 @@ def construct_pipelined_loop_adaptive( ) return new_reduction, node_mapping - # For dynamic shapes, emit conditional + pipelined loop + remainder loop - # Call helper function to build the conditional structure + # Fallback: dynamic shapes without sufficient assumptions — + # emit conditional + pipelined loop + remainder loop. 
return build_guarded_pipeline_with_remainder( trace, reduction, diff --git a/wave_lang/kernel/wave/shared_memory_indexing.py b/wave_lang/kernel/wave/shared_memory_indexing.py index fd99879f8d..b56c23f965 100644 --- a/wave_lang/kernel/wave/shared_memory_indexing.py +++ b/wave_lang/kernel/wave/shared_memory_indexing.py @@ -10,9 +10,11 @@ from ..lang.global_symbols import * from ..ops.wave_ops import AtomicOp, Read, Write, get_custom from .constraints import Constraint +from .region_canonicalization import RegionFormat, requires_region_format from .utils.general_utils import is_shared_mem_access, remove_global_indexing +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def apply_shared_memory_indexing_corrections( trace: CapturedTrace, constraints: list[Constraint] ): diff --git a/wave_lang/kernel/wave/specialize.py b/wave_lang/kernel/wave/specialize.py index f40e2fc444..afe5202668 100644 --- a/wave_lang/kernel/wave/specialize.py +++ b/wave_lang/kernel/wave/specialize.py @@ -36,6 +36,8 @@ - load partition waiting on compute to signal empty. """ +from .region_canonicalization import RegionFormat, requires_region_format + import math from typing import Optional, List from collections import defaultdict @@ -171,6 +173,7 @@ def set_specialized_conditions( return (is_load_wave, is_compute_wave, wave_id) +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def specialize_kernel( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index 86b7f39049..b595bbbb8d 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -433,6 +433,13 @@ def get_tagged_mxfp4_gemm_preshuffle_b( constraints += [tkw.Assumption(Eq(N % 32, 0))] constraints += [tkw.Assumption(Eq(K % 256, 0))] + # K is always large enough for software pipelining. 
+ constraints += [tkw.Assumption(K > BLOCK_K * 6)] + + # # K >= 2048 ensures the B-data preshuffle within_nblk (max 1023) + # # is always < K_PACKED (= K/2 >= 1024), eliminating dynamic floordiv. + # constraints += [tkw.Assumption(K >= 2048)] + if reorder_workgroups: new_wg0, new_wg1 = _reorder_mxfp4_workgroups( M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_N diff --git a/wave_lang/kernel/wave/tensor_load_to_shared.py b/wave_lang/kernel/wave/tensor_load_to_shared.py index 72c22c936c..90df133aeb 100644 --- a/wave_lang/kernel/wave/tensor_load_to_shared.py +++ b/wave_lang/kernel/wave/tensor_load_to_shared.py @@ -43,6 +43,8 @@ - shared offset preserves tile-level index: similar structure to global offset """ +from .region_canonicalization import RegionFormat, requires_region_format + import logging import math from collections import defaultdict @@ -290,6 +292,7 @@ def clear_padding(write: Write): custom_memory.update_arg("distributed_shape", tuple(new_distributed_shape)) +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) def tensor_load_to_shared( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/type_inference.py b/wave_lang/kernel/wave/type_inference.py index fba06abc6d..a91b0eaf27 100644 --- a/wave_lang/kernel/wave/type_inference.py +++ b/wave_lang/kernel/wave/type_inference.py @@ -8,12 +8,14 @@ from wave_lang.support.logging import get_logger from .constraints import Constraint +from .region_canonicalization import RegionFormat, requires_region_format from .._support.tracing import CapturedTrace from ..ops.wave_ops import * logger = get_logger("wave.type_inference") +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def infer_types( trace: CapturedTrace, constraints: Optional[list[Constraint]] = None, diff --git a/wave_lang/kernel/wave/utils/graph_utils.py b/wave_lang/kernel/wave/utils/graph_utils.py index edcf838cae..650498309f 100644 --- a/wave_lang/kernel/wave/utils/graph_utils.py +++ 
b/wave_lang/kernel/wave/utils/graph_utils.py @@ -50,6 +50,7 @@ Write, get_custom, ) +from ..region_canonicalization import RegionFormat, requires_region_format from .classes import Failure, Result, Success from .symbol_utils import ( collect_allowed_induction_symbols, @@ -865,6 +866,16 @@ def assert_traces_equivalent( raise AssertionError(f"Traces are not equivalent: {check_result.error}") +def assert_traces_mutually_equivalent( + lhs: CapturedTrace, + rhs: CapturedTrace, + subs: Optional[dict[IndexSymbol, int]] = None, +) -> None: + """Assert structural equivalence in both directions.""" + assert_traces_equivalent(lhs, rhs, subs=subs) + assert_traces_equivalent(rhs, lhs, subs=subs) + + def assert_constraints_equivalent( lhs_constraints: Sequence[Any], rhs_constraints: Sequence[Any], @@ -1011,6 +1022,7 @@ def is_chained_getresult(node: fx.Node) -> bool: get_custom(node).graph.erase_node(node) +@requires_region_format(RegionFormat.DIRECT_OUTER_REF) def remove_chained_extractslice(trace: CapturedTrace): def is_chained_extractslice(node: fx.Node) -> bool: custom = get_custom(node) @@ -1362,6 +1374,7 @@ def is_iterate_subgraph(graph: fx.Graph): return isinstance(get_custom(graph.parent_op), Iterate) +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def initialize_iter_args(trace: CapturedTrace) -> None: """ Initializes the IterArgs in each reduction with an index @@ -1382,12 +1395,6 @@ def initialize_iter_args(trace: CapturedTrace) -> None: count += 1 -def get_outer_node(outer_node: fx.Node) -> fx.Node: - while "lifted" in outer_node.meta: - outer_node = outer_node.meta["lifted"] - return outer_node - - def is_barrier_between_same_graph( src: fx.Node, dst: fx.Node, barId: int = -1, barrier_check: set = None ) -> Optional[fx.Node]: @@ -1674,7 +1681,7 @@ def prepare_subgraph_for_conditional( placeholder.meta["lifted"] = node placeholders[node] = placeholder - implicit_captures.append(get_outer_node(node)) + 
implicit_captures.append(NestedRegionOp.capture_source(node)) return subgraph, implicit_captures, placeholders @@ -1712,6 +1719,8 @@ def finish_conditional_subgraph( # Register subgraph with trace subgraph.parent_op = conditional trace.add_subgraph(subgraph._name, subgraph) - trace.get_root_graph().subgraphs[subgraph._name] = subgraph + local_root = get_custom(conditional).get_root_graph() + if subgraph._name not in local_root.subgraphs: + local_root.subgraphs[subgraph._name] = subgraph return conditional diff --git a/wave_lang/kernel/wave/utils/mapping_utils.py b/wave_lang/kernel/wave/utils/mapping_utils.py index 2248da418d..a6ce300437 100644 --- a/wave_lang/kernel/wave/utils/mapping_utils.py +++ b/wave_lang/kernel/wave/utils/mapping_utils.py @@ -5,11 +5,16 @@ from typing import TypeVar from copy import deepcopy +from math import gcd, lcm + import sympy import torch.fx as fx +from collections.abc import Sequence + from ..._support.indexing import IndexingContext from ...lang.wave_types import IndexMapping +from ..assumptions import get_divisibility_subs from .general_utils import infer_dim, get_fastest_index from .symbol_utils import IndexExpr, IndexSequence, IndexSymbol, simplify, subs_idxc from ....support.indexing import piecewise_aware_subs @@ -253,6 +258,344 @@ def _make_aligned_index(offset_val: int) -> dict[IndexExpr, IndexSequence]: return True +_INDUCTION_PREFIX = "$ARG" + + +def _expand_mod(expr: sympy.Expr) -> sympy.Expr: + """Rewrite ``Mod(x, d)`` as ``x - d*floor(x/d)``. + + This lets ``floor(x/d)*d + Mod(x, d)`` cancel to ``x`` via normal + algebraic simplification, which SymPy cannot do when ``Mod`` is + kept as an opaque node. + """ + if not expr.has(sympy.Mod): + return expr + return expr.replace( + lambda e: isinstance(e, sympy.Mod), + lambda e: e.args[0] - e.args[1] * sympy.floor(e.args[0] / e.args[1]), + ) + + +def _eval_concrete_floor_mod(expr: sympy.Expr) -> sympy.Expr: + """Evaluate ``floor`` and ``Mod`` nodes whose arguments are concrete. 
+ + SymPy sometimes leaves ``floor(4)`` or ``Mod(0, K)`` un-evaluated + when they were built symbolically and then substituted. This pass + collapses them bottom-up so the result contains no unnecessary + wrappers around known integers. + """ + + def _try_eval(e): + if isinstance(e, sympy.floor): + inner = e.args[0] + if inner.is_number: + return sympy.Integer(int(sympy.floor(inner))) + if isinstance(e, sympy.Mod): + a, m = e.args + if a.is_Integer and m.is_Integer and int(m) != 0: + return sympy.Integer(int(a) % int(m)) + if a.is_Integer and int(a) == 0: + return sympy.Integer(0) + return e + + if not expr.has(sympy.floor) and not expr.has(sympy.Mod): + return expr + return expr.replace( + lambda e: isinstance(e, (sympy.floor, sympy.Mod)), + _try_eval, + ) + + +def mem_simplify(expr: sympy.Expr) -> sympy.Expr: + """Simplify an expression representing a tensor memory index. + + Domain-specific alternative to ``sympy.simplify`` that applies only + the algebraic identities relevant to integer memory addressing: + + 1. Rewrite ``Mod(x, d)`` as ``x - d*floor(x/d)`` so that paired + floor terms cancel algebraically. + 2. ``expand()`` to distribute products and cancel matching terms — + this collapses the fundamental round-trip + ``floor(E/D)*D + Mod(E, D) → E``. + 3. Evaluate any ``floor`` or ``Mod`` of concrete integers left over. + """ + expr = sympy.sympify(expr) + if expr.is_number: + return expr + if expr.has(sympy.Mod): + expr = _expand_mod(expr) + expr = sympy.expand(expr) + expr = _eval_concrete_floor_mod(expr) + return expr + + +def linearize_dims( + dim_exprs: list[sympy.Expr], + strides: list[sympy.Expr], +) -> sympy.Expr: + """Compute ``sum(dim_i * stride_i)`` with floor/Mod cancellation. + + Preshuffle mappings produce ``floor(flat/D)`` for one dimension and + ``Mod(flat, D)`` for another, with the first dimension's stride + equal to ``D``. The naive sum keeps SymPy's opaque ``Mod`` node + and ``simplify`` cannot prove the cancellation. 
+ + By expanding ``Mod`` to its floor definition first and then calling + ``expand``, the matching ``floor`` terms cancel algebraically:: + + floor(E/D)*D + Mod(E,D) + → floor(E/D)*D + E - D*floor(E/D) + → E + """ + total = sum(d * s for d, s in zip(dim_exprs, strides)) + return mem_simplify(total) + + +def _infer_floor_to_exact(mem_strides: list[IndexExpr]) -> dict: + """Infer ``floor(sym/n) → sym/n`` substitutions from memory strides. + + When a memory stride has the form ``sym/n`` (symbolic numerator over + a positive integer denominator), the stride must be a non-negative + integer for the layout to be physically valid. That forces the + numerator to be a multiple of the denominator, so + ``floor(sym/n) == sym/n``. + + Substituting this into dimension expressions collapses + ``floor(E / floor(sym/n)) * (sym/n) + Mod(E, floor(sym/n))`` back + into ``E`` via the standard floor/Mod identity, eliminating the + symbolic divisor that the probing cannot handle. + """ + subs_map: dict[sympy.Expr, sympy.Expr] = {} + for stride in mem_strides: + stride = sympy.sympify(stride) + numer, denom = stride.as_numer_denom() + if denom.is_Integer and int(denom) > 1 and numer.free_symbols: + exact_quot = numer / denom + subs_map[sympy.floor(exact_quot)] = exact_quot + return subs_map + + +def compute_iv_stride_through_mapping( + mapping: IndexMapping, + symbolic_shape: tuple[IndexExpr, ...], + index: dict[IndexExpr, IndexSequence], + is_read: bool = True, + mem_strides: list[IndexExpr] | None = None, + constraints: Sequence = (), +) -> dict[sympy.Symbol, IndexExpr | list[IndexExpr]] | None: + """Compute each IV symbol's linearized stride through a mapping. + + Uses numerical probing: evaluates the linearized address at iv=0,1,...,P, + takes consecutive differences, and detects constant or cyclic stride + patterns. This handles symbolic divisors (e.g. ``floor(K/32)``) that + defeat the old ``sympy.coeff(_iv)`` approach. 
+ + Parameters + ---------- + mem_strides : optional physical memory strides. When provided these are + used for linearization instead of strides derived from + ``symbolic_shape``. Callers with access to the physical memory + layout should pass these to ensure floor/Mod cancellation. + constraints : constraint sequence (may include ``Assumption`` objects). + Divisibility assumptions (e.g. ``Assumption(Eq(K % 256, 0))``) + are used to simplify floor/Mod expressions before probing. + + Returns ``{iv_sym: stride}`` (constant) or ``{iv_sym: [s0, s1, ...]}`` + (repeating cycle), or ``None`` when no IV is found. + """ + iters = mapping.iters + + iv_info: dict[sympy.Symbol, tuple[sympy.Symbol, int]] = {} + for dim_sym, iter_sym in mapping.output_mapping.items(): + seq = index.get(dim_sym) + if seq is None: + continue + start = sympy.sympify(seq.start if isinstance(seq, IndexSequence) else seq) + for sym in start.free_symbols: + if not str(sym).startswith(_INDUCTION_PREFIX): + continue + coeff = sympy.expand(start).coeff(sym) + concrete = subs_idxc(coeff) + if not isinstance(concrete, (int, sympy.Integer)): + print( + f"IV coeff for {sym} is non-concrete: coeff={coeff}" + f" resolved={concrete} (type={type(concrete).__name__})" + f" — skipping this IV" + ) + continue + iv_info[sym] = (iter_sym, int(concrete)) + + if not iv_info: + return None + + map_dims = mapping.input_shape if is_read else mapping.output_shape + raw_exprs = ( + mapping.map_input_indices(map_dims) + if is_read + else mapping.map_output_indices(map_dims) + ) + + idxc = IndexingContext.current() + dim_exprs = [subs_idxc(e) for e in raw_exprs] + + if mem_strides is None: + symbolic_shape_resolved = tuple(infer_dim(d) for d in symbolic_shape) + mem_strides = strides_from_symbolic_shape( + idxc, symbolic_shape_resolved, allow_mixed_shapes=True + ) + + div_fwd, div_bwd = get_divisibility_subs(constraints) + if div_fwd: + fwd_dict = dict(div_fwd) + dim_exprs = [sympy.sympify(e).subs(fwd_dict) for e in 
dim_exprs] + mem_strides = [sympy.sympify(s).subs(fwd_dict) for s in mem_strides] + else: + floor_subs = _infer_floor_to_exact(mem_strides) + if floor_subs: + dim_exprs = [sympy.sympify(e).subs(floor_subs) for e in dim_exprs] + + result: dict[sympy.Symbol, IndexExpr | list[IndexExpr]] = {} + + for iv_sym, (iv_iter, concrete_coeff) in iv_info.items(): + stride_or_cycle = _probe_iv_stride( + dim_exprs, mem_strides, iters, iv_iter, concrete_coeff + ) + if stride_or_cycle is None: + return None + result[iv_sym] = stride_or_cycle + + if div_bwd: + bwd_dict = dict(div_bwd) + + def _bwd(v): + if isinstance(v, list): + return [mem_simplify(sympy.sympify(x).subs(bwd_dict)) for x in v] + return mem_simplify(sympy.sympify(v).subs(bwd_dict)) + + result = {k: _bwd(v) for k, v in result.items()} + + return result + + +def _extract_integer_divisors(expr: sympy.Expr) -> set[int]: + """Collect positive integer divisors from ``floor`` and ``Mod`` nodes.""" + divisors: set[int] = set() + for sub in sympy.preorder_traversal(expr): + if isinstance(sub, sympy.Mod): + d = sub.args[1] + if d.is_Integer and int(d) > 0: + divisors.add(int(d)) + elif isinstance(sub, sympy.floor): + _, denom = sub.args[0].as_numer_denom() + if denom.is_Integer and int(denom) > 1: + divisors.add(int(denom)) + return divisors + + +_MAX_PROBE_DEPTH = 1024 + + +def _compute_probe_depth(flat_expr: sympy.Expr, concrete_coeff: int) -> int: + """Compute the exact probe depth from divisors in *flat_expr*. + + For each ``floor(expr/D)`` or ``Mod(expr, D)`` with integer D, the IV + contribution has period ``D / gcd(C, D)`` where ``C`` is the IV step. + The overall period ``P = LCM(all periods)`` is the exact number of + diffs needed to capture the full stride pattern. + + Raises ``ValueError`` if the computed depth exceeds ``_MAX_PROBE_DEPTH``, + which indicates a pathological mapping with too many coprime divisors. 
+ """ + if concrete_coeff == 0: + return 1 + divisors = _extract_integer_divisors(flat_expr) + if not divisors: + return 1 + C = abs(concrete_coeff) + periods = [d // gcd(C, d) for d in divisors] + depth = lcm(*periods) + if depth > _MAX_PROBE_DEPTH: + raise ValueError( + f"Probe depth {depth} exceeds maximum {_MAX_PROBE_DEPTH}" + f" (divisors={sorted(divisors)}, C={C})." + f" The mapping has too many coprime floor/Mod divisors" + f" for exact stride analysis." + ) + return depth + + +def _probe_iv_stride( + dim_exprs: list[IndexExpr], + mem_strides: list[IndexExpr], + iters: dict, + iv_iter: sympy.Symbol, + concrete_coeff: int, +) -> IndexExpr | list[IndexExpr] | None: + """Compute the IV stride through the linearized address. + + 1. Linearize ``dim_exprs * mem_strides`` symbolically via + ``linearize_dims`` + ``mem_simplify`` to obtain the flat address + expression (cancelling floor/Mod round-trip pairs). + 2. Compute the exact probe depth ``P`` from the remaining integer + divisors in the flat expression: ``P = LCM(D_i / gcd(C, D_i))``. + 3. Evaluate ``P + 1`` concrete addresses and detect constant or + cyclic stride patterns from the integer diffs. + + If the addresses contain free symbols, this is an error — the + substitution context has unresolved chained dependencies. + + Returns a single IndexExpr (constant stride), a list of IndexExpr + (repeating cycle), or ``None`` on failure. + """ + + # Step 1: linearize symbolically, then compute probe depth from the + # flat expression's divisors. Apply subs_idxc to iv_flat so the + # probe-depth computation sees the same symbol resolution as the + # concrete address evaluations (prevents under-probing when a + # divisor is symbolic pre-subs but integer post-subs). 
+ flat_expr = mem_simplify(linearize_dims(dim_exprs, mem_strides)) + iv_flat = flat_expr.subs( + { + it: (concrete_coeff * sympy.Symbol("_iv") if it == iv_iter else 0) + for it in iters.keys() + } + ) + iv_flat = subs_idxc(iv_flat) + probe_depth = _compute_probe_depth(iv_flat, concrete_coeff) + + # Step 2: evaluate P+1 concrete addresses. + def _linearized_addr(iv_val: int) -> IndexExpr: + subs = { + it: (concrete_coeff * iv_val if it == iv_iter else 0) for it in iters.keys() + } + resolved = [subs_idxc(dim_expr.subs(subs)) for dim_expr in dim_exprs] + return mem_simplify(subs_idxc(linearize_dims(resolved, mem_strides))) + + addrs: list[int] = [] + for iv in range(probe_depth + 1): + a = _linearized_addr(iv) + if getattr(a, "free_symbols", set()): + return None + addrs.append(int(a)) + + diffs = [addrs[i + 1] - addrs[i] for i in range(probe_depth)] + + if not diffs: + return None + + # Step 3: detect constant or shortest repeating cycle. + # The probe depth is the analytically-proven period, so one full + # repetition is sufficient — use range(1, probe_depth + 1). + for cycle_len in range(1, probe_depth + 1): + if all(diffs[i] == diffs[i % cycle_len] for i in range(probe_depth)): + cycle = [sympy.Integer(diffs[i]) for i in range(cycle_len)] + if cycle_len == 1: + return cycle[0] + return cycle + + return None + + def transform_index_on_mapping( mapping: IndexMapping, symbolic_shape: tuple[IndexExpr, ...], diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 98540ea53d..13a5e993c0 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -109,10 +109,115 @@ def expr_bounds(expr: sympy.Expr) -> tuple[sympy.Expr, sympy.Expr] | None: return None +def _is_provably_divisible(term: sympy.Expr, divisor: sympy.Expr) -> bool: + """Check if *term* is provably an integer multiple of *divisor*. + + Works for both constant and symbolic divisors. 
For a compound divisor + like ``c * D`` (numeric ``c``, symbolic ``D``), the check decomposes: + term must contain ``D`` as a factor and its numeric coefficient must be + divisible by ``c``. + """ + if term.is_zero: + return True + if divisor.is_number and divisor.is_nonzero: + # Constant divisor: check if term has divisor as a factor. + if term.is_number: + return term.is_integer and (term % divisor == 0) + # term = coeff * rest: check if coeff is divisible. + if isinstance(term, sympy.Mul): + for arg in term.args: + if arg.is_number and arg.is_integer and (arg % divisor == 0): + return True + return False + # Decompose the divisor into numeric and symbolic parts. + # E.g. 8*floor(...) -> (8, floor(...)) + div_coeff, div_sym = _split_coeff(divisor) + + if isinstance(term, sympy.Mul): + # Check if term contains div_sym as a multiplicative factor + # (possibly nested inside a sub-Mul), and the remaining numeric + # coefficient is divisible by div_coeff. + term_coeff, term_sym_factors = _split_coeff(term) + # Flatten symbolic factors. 
+ sym_factors = ( + list(term_sym_factors.args) + if isinstance(term_sym_factors, sympy.Mul) + else [term_sym_factors] + ) + if _contains_factor(sym_factors, div_sym): + if div_coeff == 1 or (term_coeff % div_coeff == 0): + return True + return term == divisor + + +def _split_coeff(expr: sympy.Expr) -> tuple[sympy.Integer, sympy.Expr]: + """Split *expr* into ``(numeric_coeff, symbolic_rest)``.""" + if expr.is_number: + return (expr, sympy.Integer(1)) + if isinstance(expr, sympy.Mul): + coeff = sympy.Integer(1) + sym_parts: list[sympy.Expr] = [] + for arg in expr.args: + if arg.is_number and arg.is_integer: + coeff *= arg + else: + sym_parts.append(arg) + sym = sympy.Mul(*sym_parts) if sym_parts else sympy.Integer(1) + return (coeff, sym) + return (sympy.Integer(1), expr) + + +def _contains_factor( + factors: list[sympy.Expr], target: sympy.Expr +) -> bool: + """Check if *target* appears as a factor in *factors* (possibly nested).""" + for f in factors: + if f == target: + return True + # target might be inside a nested Mul. + if isinstance(f, sympy.Mul) and target in f.args: + return True + return False + + +def _split_sum_by_divisibility( + expr: sympy.Expr, divisor: sympy.Expr +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Split *expr* into ``(quotient, remainder)`` such that + ``expr == quotient * divisor + remainder`` and every additive term in + ``remainder`` is NOT a proven multiple of *divisor*. + + Returns ``None`` if no term is provably divisible (nothing to split). + """ + terms = expr.as_ordered_terms() if isinstance(expr, sympy.Add) else [expr] + quot_terms: list[sympy.Expr] = [] + rem_terms: list[sympy.Expr] = [] + for t in terms: + if _is_provably_divisible(t, divisor): + # t = something * divisor, extract the quotient contribution. 
+ q = sympy.cancel(t / divisor) + quot_terms.append(q) + else: + rem_terms.append(t) + if not quot_terms: + return None + quotient = sympy.Add(*quot_terms) if quot_terms else sympy.Integer(0) + remainder = sympy.Add(*rem_terms) if rem_terms else sympy.Integer(0) + return (quotient, remainder) + + def _custom_simplify_once(expr: sympy.Expr) -> sympy.Expr: """Apply custom algebraic simplifications that sympy misses. - Two rewrites that sympy.simplify does not perform: + Four rewrites that sympy.simplify does not perform: + + ``transform_floor_div``: factor out D-multiples from floor division. + ``floor((A*D + B) / D) -> A + floor(B/D)`` + (and ``-> A`` when B is bounded in ``[0, D)``) + + ``transform_mod_div``: drop D-multiple terms from Mod. + ``Mod(A*D + B, D) -> Mod(B, D)`` + (and ``-> B`` when B is bounded in ``[0, D)``) ``transform_mod``: pull a small constant addend out of Mod when every other term is a multiple of the modulus divisor. @@ -213,6 +318,55 @@ def transform_floor(expr): terms.append(arg) return sympy.floor(sum(terms)) + def transform_floor_div(expr): + """``floor((A*D + B) / D) -> A + floor(B/D)``.""" + if not isinstance(expr, sympy.floor): + return None + inner = expr.args[0] + # Match pattern: inner = numerator / divisor (sympy represents as + # numerator * divisor^(-1), i.e. a Mul with a Pow(..., -1) factor). + numer, denom = inner.as_numer_denom() + if denom == 1: + return None + result = _split_sum_by_divisibility(numer, denom) + if result is None: + return None + quotient, remainder = result + if remainder.is_zero: + return quotient + # Check if remainder is bounded in [0, denom). + rem_bounds = expr_bounds(remainder) + if rem_bounds and rem_bounds[0] >= 0 and rem_bounds[1] != sympy.oo: + try: + if rem_bounds[1] < denom: + return quotient + except TypeError: + pass # Symbolic comparison — can't determine.
+ return quotient + sympy.floor(remainder / denom) + + def transform_mod_div(expr): + """``Mod(A*D + B, D) -> Mod(B, D)``.""" + if not isinstance(expr, sympy.Mod): + return None + p, q = expr.args + result = _split_sum_by_divisibility(p, q) + if result is None: + return None + _quotient, remainder = result + if remainder.is_zero: + return sympy.Integer(0) + # Check if remainder is bounded in [0, q) — then Mod is identity. + rem_bounds = expr_bounds(remainder) + if rem_bounds and rem_bounds[0] >= 0 and rem_bounds[1] != sympy.oo: + try: + if rem_bounds[1] < q: + return remainder + except TypeError: + pass # Symbolic comparison — can't determine. + return sympy.Mod(remainder, q, evaluate=False) + + expr = expr.replace(lambda e: transform_floor_div(e) is not None, transform_floor_div) + expr = expr.replace(lambda e: transform_mod_div(e) is not None, transform_mod_div) expr = expr.replace(lambda e: transform_mod(e) is not None, transform_mod) expr = expr.replace(lambda e: transform_floor(e) is not None, transform_floor) return expr @@ -221,6 +375,103 @@ def transform_floor(expr): _simplify_cache: dict[sympy.Basic, sympy.Expr] = {} +def extract_iv( + expr: sympy.Expr, + iv: sympy.Symbol, +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Split *expr* into ``(iv_coeff, base)`` such that + ``expr == iv_coeff * iv + base`` and ``iv`` does not appear in ``base``. + + Uses ``sympy.Expr.coeff`` for the linear case. When *iv* survives in the + remainder (e.g. because ``simplify`` folded it into floor/Mod), applies the + general integer-division identity to pull *iv* out: + + * ``floor((c*iv + r) / d) = floor(c/d)*iv + floor((Mod(c,d)*iv + r)/d)`` + * ``Mod(c*iv + r, m) = Mod(Mod(c,m)*iv + r, m)`` + + Returns ``None`` if *iv* cannot be fully separated. 
+ """ + expanded = sympy.expand(expr) + coeff = expanded.coeff(iv) + base = simplify(expanded - coeff * iv) + if iv not in base.free_symbols: + return (coeff, base) + + # iv is stuck inside floor/Mod — apply decomposition identities. + return _extract_iv_from_floor_mod(expr, iv) + + +def _extract_iv_from_floor_mod( + expr: sympy.Expr, + iv: sympy.Symbol, +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Apply floor/Mod integer-division identities to separate *iv*. + + Walks the expression tree. For each ``floor(numer/denom)`` or + ``Mod(value, modulus)`` that contains *iv*, decomposes the *iv*-coefficient + using the identities: + + * ``floor((c*iv + r) / d) = floor(c/d)*iv + floor((Mod(c,d)*iv + r)/d)`` + * ``Mod(c*iv + r, m) = Mod(Mod(c,m)*iv + r, m)`` + + After rewriting, the result is expanded and the linear *iv* coefficient + is split off again. If *iv* no longer appears in the remaining base, + the simplified ``(coeff, base)`` pair is returned. If *iv* still + remains in the base (e.g. inside ``floor(Mod(c,d)*iv/d + ...)``), the + separation has failed and ``None`` is returned; nothing is probed.
+ """ + + def _rewrite_floor(arg): + numer, denom = arg.as_numer_denom() + numer = sympy.expand(numer) + iv_coeff = numer.coeff(iv) + if iv_coeff == 0: + return sympy.floor(arg) + rest = numer - iv_coeff * iv + return ( + sympy.floor(iv_coeff / denom) * iv + + sympy.floor( + (sympy.Mod(iv_coeff, denom, evaluate=False) * iv + rest) / denom + ) + ) + + def _rewrite_mod(*args): + value, modulus = args + value_expanded = sympy.expand(value) + iv_coeff = value_expanded.coeff(iv) + if iv_coeff == 0: + return sympy.Mod(value, modulus, evaluate=False) + rest = value_expanded - iv_coeff * iv + return sympy.Mod( + sympy.Mod(iv_coeff, modulus, evaluate=False) * iv + rest, + modulus, + evaluate=False, + ) + + rewritten = expr.replace(sympy.floor, _rewrite_floor) + rewritten = rewritten.replace(sympy.Mod, _rewrite_mod) + + expanded = sympy.expand(rewritten) + coeff = expanded.coeff(iv) + base = expanded - coeff * iv + if iv not in base.free_symbols: + return (simplify(coeff), simplify(base)) + + return None + +def simplify_divisor_multiples(expr: sympy.Expr) -> sympy.Expr: + """Factor out divisor-multiples from floor/Mod without expand/cancel. + + Runs one pass of ``_custom_simplify_once``, i.e. every rewrite defined + there (``transform_floor_div`` and ``transform_mod_div`` among them). + Intended for complex post-substitution expressions where + ``sympy.expand`` / ``sympy.cancel`` would destroy the expression structure. + """ + if not isinstance(expr, sympy.Basic) or expr.is_Atom: + return expr + return _custom_simplify_once(expr) + + def simplify(expr: sympy.Expr) -> sympy.Expr: """Simplify a sympy expression using interval arithmetic and cancel.
diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 4aa249b9f0..1c31e61d3d 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -16,6 +16,7 @@ from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.detect_water import get_water_mlir_pkg_path, get_water_opt +from wave_lang.support.detect_waveasm import get_waveasm_translate from wave_lang.support.ir_imports import ( Attribute, BlockArgument, @@ -370,6 +371,127 @@ def diagnostic_from_json( print("[info] No out-of-bounds accesses detected.") +def water_waveasm_lowering_pipeline( + module: Module, options: WaveCompileOptions +) -> Module: + """Lower via water-opt -> waveasm-translate -> water-opt. + + Step 1 (water-opt): lower to LLVM dialect on both host and device sides. + Step 2 (waveasm-translate): LLVM -> WaveASM -> regalloc -> HSACO -> gpu.binary. + Step 3 (water-opt): host runtime wrapping (gpu.binary -> runtime calls). + """ + water_opt = get_water_opt() + mlir_asm = module.operation.get_asm() + target_chip = options.target + lld_path = get_water_mlir_pkg_path() / "llvm" / "bin" / "ld.lld" + + def add_opt(pipeline): + if options.optimization_level: + return [pipeline] + return [] + + canonicalize_cse = "composite-fixed-point-pass", { + "name": "canonicalize_cse", + "pipeline": "any(canonicalize,cse)", + } + + int_range_optimizations = "composite-fixed-point-pass", { + "name": "int-range-optimizations", + "pipeline": 'any(int-range-optimizations,arith-int-range-narrowing{int-bitwidths-supported="32"},canonicalize,cse)', + } + + # Step 1: water-opt lowers host + device to LLVM dialect. 
+ lowering_pipeline = [ + "water-memref-decomposition", + *add_opt(canonicalize_cse), + "lower-affine", + *add_opt(int_range_optimizations), + *add_opt("loop-invariant-code-motion"), + "convert-scf-to-cf", + ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), + ( + "convert-gpu-to-rocdl", + {"use-bare-ptr-memref-call-conv": "1"}, + "gpu.module", + ), + ("gpu-to-llvm", {"use-bare-pointers-for-kernels": "1"}), + "convert-vector-to-llvm", + "reconcile-unrealized-casts", + *add_opt(canonicalize_cse), + ] + + def run_subprocess(args, input_text, tool_name): + try: + result = subprocess.run( + args, input=input_text, text=True, capture_output=True + ) + if result.returncode != 0: + raise RuntimeError( + f"{tool_name} failed (rc={result.returncode}):\n{result.stderr}" + ) + return result.stdout + except RuntimeError: + raise + except Exception as e: + raise RuntimeError(f"{tool_name} failed: {e}") from e + + water_args = [water_opt, make_linear_pass_pipeline(lowering_pipeline)] + if options.mlir_print_ir_after_all: + water_args.append("--mlir-print-ir-after-all") + lowered_mlir = run_subprocess(water_args, mlir_asm, "water-opt (lowering)") + + if options.print_mlir: + print("=== After water-opt lowering ===") + print(lowered_mlir) + + if options.compile_to_mlir: + with module.context: + return Module.parse(lowered_mlir) + + # Step 2: waveasm-translate -- LLVM -> WaveASM -> regalloc -> gpu.binary. 
+ + waveasm_translate = get_waveasm_translate() + waveasm_args = [ + waveasm_translate, + f"--waveasm-translate-from-llvm=target={target_chip}", + "--waveasm-arith-legalization", + "--waveasm-scoped-cse", + "--waveasm-peephole", + "--waveasm-memory-offset-opt", + "--canonicalize", + "--waveasm-scoped-cse", + "--waveasm-linear-scan=max-vgprs=512 max-agprs=512", + "--waveasm-insert-waitcnt=ticketed-waitcnt=true", + f"--waveasm-hazard-mitigation=target={target_chip}", + f"--waveasm-gpu-module-to-binary=target={target_chip} lld-path={lld_path}", + ] + if options.mlir_print_ir_after_all: + waveasm_args.append("--mlir-print-ir-after-all") + binary_mlir = run_subprocess(waveasm_args, lowered_mlir, "waveasm-translate") + + if options.print_mlir: + print("=== After waveasm-translate ===") + print(binary_mlir[:500] + ("..." if len(binary_mlir) > 500 else "")) + + # Step 3: water-opt host runtime wrapping. + host_pipeline = [ + "water-gpu-to-gpu-runtime", + "symbol-dce", + *add_opt(canonicalize_cse), + ] + host_args = [water_opt, make_linear_pass_pipeline(host_pipeline)] + if options.mlir_print_ir_after_all: + host_args.append("--mlir-print-ir-after-all") + final_mlir = run_subprocess(host_args, binary_mlir, "water-opt (host)") + + if options.print_mlir: + print("=== After water-opt host ===") + print(final_mlir) + + with module.context: + return Module.parse(final_mlir) + + def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Module: binary = get_water_opt() mlir_asm = module.operation.get_asm() diff --git a/wave_lang/kernel/wave/wave.py b/wave_lang/kernel/wave/wave.py index 85a2df27b9..7387df32b1 100644 --- a/wave_lang/kernel/wave/wave.py +++ b/wave_lang/kernel/wave/wave.py @@ -41,6 +41,7 @@ get_grid_shape, get_device_layout, ) +from .region_canonicalization import RegionFormat, requires_region_format from .symbolic_constraints import SymbolicAlias from .utils.general_utils import ( @@ -314,6 +315,7 @@ def _trace( return trace + 
@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def create_induction_vars(self, trace: CapturedTrace) -> None: """ Creates induction variables for all the reductions in the graph @@ -359,6 +361,7 @@ def initialize_wave_constraints(self) -> None: hardware_constraint.waves_per_block = tuple(waves_per_block) + @requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def initialize_reductions(self, trace: CapturedTrace) -> None: """ For each reduction, initializes the reduction count by looking at the @@ -500,6 +503,7 @@ def eager_execute(self, args, kwargs): def __repr__(self): return f"tk.wave @{self._name}[{self.grid_type}]" + @requires_region_format(RegionFormat.DIRECT_OUTER_REF) def run_manual_schedule( self, trace: CapturedTrace, diff --git a/wave_lang/kernel/wave/workgroup_reordering.py b/wave_lang/kernel/wave/workgroup_reordering.py index 85163ef086..b2273988f2 100644 --- a/wave_lang/kernel/wave/workgroup_reordering.py +++ b/wave_lang/kernel/wave/workgroup_reordering.py @@ -8,10 +8,12 @@ from .._support.tracing import CapturedTrace from ..lang.global_symbols import * from ..ops.wave_ops import * +from .region_canonicalization import RegionFormat, requires_region_format from .constraints import * from .utils.symbol_utils import * +@requires_region_format(RegionFormat.LEGACY_PLACEHOLDERS) def reorder_workgroups(graph: CapturedTrace, reordering_constraints): if len(reordering_constraints) == 0: return diff --git a/waveasm/docs/architecture.md b/waveasm/docs/architecture.md index 9a0d64ec2f..1bc1b1934e 100644 --- a/waveasm/docs/architecture.md +++ b/waveasm/docs/architecture.md @@ -26,7 +26,12 @@ This document describes the architecture of the WaveASM C++ backend, which trans ║ ▼ ║ ║ ┌─────────────────────────────────────────────────────┐ ║ ║ │ WaveASM Dialect IR │ ║ - ║ │ (Pure SSA with virtual registers) │ ║ + ║ │ (Generic pseudo-ops + concrete machine ops) │ ║ + ║ └─────────────────────────┬───────────────────────────┘ ║ + ║ ▼ ║ + ║ 
┌─────────────────────────────────────────────────────┐ ║ + ║ │ Legalization Passes │ ║ + ║ │ (Type legalization, register placement, ISel) │ ║ ║ └─────────────────────────┬───────────────────────────┘ ║ ║ ▼ ║ ║ ┌─────────────────────────────────────────────────────┐ ║ @@ -134,9 +139,24 @@ The WaveASM dialect is a pure SSA representation close to AMDGCN assembly. └────────────────────────────────────────────────────────────────────────────────┘ ``` -### 3. Optimization Passes +### 3. Legalization (see [legalization.md](legalization.md)) + +Translation emits generic pseudo-ops (`waveasm.arith.add`, `.mul`, `.cmp_*`) +for arithmetic, deferring type width and register file decisions. Dedicated +legalization passes then: + +1. **Type legalization** — narrows i64 to i32 (or expands to carry chains). +2. **Register legalization** — assigns SGPR vs VGPR, enforces constant bus + limits, inserts `v_mov_b32` copies as needed. +3. **Instruction selection** — lowers generic ops to concrete machine ops + (`s_add_u32` / `v_add_u32` / etc.). + +Operations with unambiguous lowering (buffer ops, MFMA, control flow, LDS) +emit concrete machine ops directly during translation and skip legalization. + +### 4. Optimization Passes -The pass pipeline optimizes the WaveASM IR before register allocation. +The pass pipeline optimizes the legalized WaveASM IR before register allocation. ``` ┌────────────────────────────────────────────────────────────────────────────────┐ @@ -183,7 +203,7 @@ The pass pipeline optimizes the WaveASM IR before register allocation. └────────────────────────────────────────────────────────────────────────────────┘ ``` -### 4. Register Allocation +### 5. Register Allocation Linear scan register allocation converts virtual registers to physical registers. @@ -229,7 +249,7 @@ Linear scan register allocation converts virtual registers to physical registers └────────────────────────────────────────────────────────────────────────────────┘ ``` -### 5. 
Assembly Emission +### 6. Assembly Emission The assembly emitter generates AMDGCN assembly with full HSA metadata. diff --git a/waveasm/docs/legalization.md b/waveasm/docs/legalization.md new file mode 100644 index 0000000000..b9dbb1f201 --- /dev/null +++ b/waveasm/docs/legalization.md @@ -0,0 +1,236 @@ +# WaveASM Legalization Design + +This document describes the legalization strategy for the WaveASM backend: +how pseudo-ops are lowered to legal AMDGCN machine ops with correct +register placement and operand widths. + +## Problem Statement + +The WaveASM backend accepts MLIR input from two frontends: + +1. **TranslateFromMLIR** — translates `gpu`, `arith`, `vector`, `scf`, + `amdgpu` dialect ops. +2. **TranslateFromLLVM** — translates LLVM dialect ops (used when the Water + frontend lowers through LLVM IR). + +Both frontends face the same set of legalization problems: + +- **Type width**: LLVM IR uses i64 for pointer arithmetic and index math. + AMDGCN VALU instructions are 32-bit. i64 operations must be narrowed or + expanded to i32 pairs. +- **Register placement**: AMDGCN has two register files (SGPR and VGPR) with + strict placement rules. Kernel arguments live in SGPRs; lane-varying values + must be in VGPRs. +- **Constant bus**: VALU instructions can read at most one SGPR operand per + instruction (pre-GFX10). Violations require inserting SGPR-to-VGPR moves. +- **Instruction selection**: A pseudo-op `arith.add` can lower to `s_add_u32` (scalar), + `v_add_u32` (vector), or a carry chain for i64. + +## Architecture + +The legalization strategy follows the real LLVM AMDGPU backend: emit +structurally correct but "unlegalized" IR during translation, then fix it up +in dedicated post-translation passes. + +### Pipeline + +``` + Translation (TranslateFromMLIR / TranslateFromLLVM) + │ + │ Emits pseudo-ops for arithmetic. + │ Does NOT make register placement or width decisions. 
+ │ + ▼ + ┌──────────────────────────────────────┐ + │ Type Legalization Pass │ + │ │ + │ - i64 → i32 narrowing or expansion │ + │ - i64 add → i32 carry chain │ + │ - i64 mul → i32 partial products │ + │ - i64 icmp → hi/lo comparison │ + │ - trunc/sext/zext lowering │ + └──────────────┬───────────────────────┘ + │ + ▼ + ┌──────────────────────────────────────┐ + │ Register Legalization Pass │ + │ │ + │ - Assigns SGPR vs VGPR per value │ + │ - Inserts v_mov_b32 for SGPR→VGPR │ + │ - Enforces constant bus limits │ + └──────────────┬───────────────────────┘ + │ + ▼ + ┌──────────────────────────────────────┐ + │ Instruction Selection Pass │ + │ │ + │ - Lowers pseudo-ops to concrete │ + │ machine ops (s_add/v_add/etc.) │ + │ - Selects SALU vs VALU encoding │ + └──────────────┬───────────────────────┘ + │ + ▼ + Existing pipeline (CSE, Peephole, Liveness, RegAlloc, Emit) +``` + +Note: these three passes may be combined into fewer passes if the separation +proves unnecessary. The logical ordering matters more than the physical pass +count. + +## Pseudo-Operations + +Translation emits a targeted set of **pseudo-ops** for cases where the +register file and/or width decision is non-trivial. These ops carry the +original type information and are lowered by legalization. + +### Scope + +Pseudo-ops are introduced only for arithmetic and comparison -- operations +where the SGPR/VGPR and width choice depends on context. Operations with +unambiguous lowering (SRD setup, buffer loads, MFMA, barriers, branch, endpgm) +continue to emit concrete machine ops directly during translation. 
+ +### Implemented pseudo-ops + +| Pseudo-op | Lowers to (SALU) | Lowers to (VALU) | +|-----------|-----------------|-------------------| +| `waveasm.arith.add : iN` | `s_add_u32` (+ `s_addc_u32` for i64) | `v_add_u32` (+ `v_addc_co_u32` for i64) | +| `waveasm.arith.mul : iN` | `s_mul_i32` (+ `s_mul_hi_u32` for i64) | `v_mul_lo_u32` (+ partial products for i64) | +| `waveasm.arith.cmp_XX : iN` | `s_cmp_XX_u32/i32` | `v_cmp_XX_u32/i32` (+ hi/lo for i64) | +| `waveasm.arith.sext : i32 -> i64` | `s_ashr_i32` + pack | `v_ashrrev_i32` + pack | +| `waveasm.arith.zext : i32 -> i64` | `s_mov_b32(0)` + pack | `v_mov_b32(0)` + pack | +| `waveasm.arith.trunc : i64 -> i32` | extract sub-register | extract sub-register | +| `waveasm.arith.select` | `s_cselect_b32` | `v_cndmask_b32` | + +The pseudo-ops use a register-file-agnostic type (e.g., a plain integer type +or a new `!waveasm.greg` virtual type) to defer the SGPR/VGPR decision. + +### What stays concrete + +These operations emit concrete machine ops directly during translation: + +- **SRD/buffer ops**: `s_load_dwordx2`, `s_mov_b32` for descriptor fields, + `buffer_load_*`, `buffer_store_*`. +- **LDS ops**: `ds_read_*`, `ds_write_*`. +- **MFMA**: `v_mfma_*` — always VALU, always concrete widths. +- **Lane ID**: `v_mbcnt_lo/hi` — always VGPR. +- **Control flow**: `s_branch`, `s_cbranch_*`, `s_endpgm`, `s_barrier`. +- **Waitcnts**: `s_waitcnt` — always scalar. +- **Comments, labels, raw asm**: pass-through, no legalization needed. 
+ +## Type Legalization Details + +### i64 narrowing (short-term) + +For dynamic dimension arguments that are known to fit in 32 bits, narrowing +is sufficient: + +``` +waveasm.arith.add %a, %b : i64 + → %a32 = waveasm.arith.trunc %a : i64 → i32 + %b32 = waveasm.arith.trunc %b : i64 → i32 + %r32 = waveasm.arith.add %a32, %b32 : i32 + %r = waveasm.arith.sext %r32 : i32 → i64 // if result used as i64 +``` + +### i64 expansion (long-term) + +For true 64-bit arithmetic, expand to carry chains: + +``` +waveasm.arith.add %a, %b : i64 + → %a_lo, %a_hi = split %a + %b_lo, %b_hi = split %b + %r_lo, %carry = v_add_co_u32 %a_lo, %b_lo + %r_hi = v_addc_co_u32 %a_hi, %b_hi, %carry + %r = merge %r_lo, %r_hi +``` + +The decision between narrowing and expansion can be driven by value range +analysis or simply by op type (e.g., always narrow kernel args, always expand +address arithmetic). + +## Register Legalization Details + +### Divergence model + +Values are classified as **uniform** (same across all lanes) or **divergent** +(varies per lane): + +| Source | Classification | +|--------|---------------| +| Kernel arguments (preloaded SGPRs) | Uniform | +| `workitem.id.x` / `v_mbcnt` | Divergent | +| Constants / immediates | Uniform | +| Derived from any divergent value | Divergent | +| Loop induction variable (scalar) | Uniform | +| `v_readfirstlane_b32` result | Uniform | + +Uniform values prefer SGPRs. Divergent values must be in VGPRs. + +For the initial implementation, divergence is inferred demand-driven: when an +instruction requires a VGPR operand and the source is an SGPR, insert a move. +No explicit divergence analysis pass is needed yet. + +### Constant bus enforcement + +After register assignment, scan each VALU instruction. If it reads more than +one distinct SGPR: +1. Pick one SGPR operand (prefer the one with fewer uses, to minimize moves). +2. Insert `v_mov_b32` to copy it to a VGPR. +3. Rely on CSE to deduplicate identical moves. 
+ +### VGPR → SGPR (future) + +`v_readfirstlane_b32` extracts lane 0 from a VGPR into an SGPR. Needed when a +divergent value must feed a scalar-only instruction (e.g., `s_load`, SRD +field). Not needed for current use cases but the legalization framework should +accommodate it. + +## Comparison With LLVM AMDGPU Backend + +| Aspect | LLVM AMDGPU | WaveASM | +|--------|------------|---------| +| Pseudo-ops | G_ADD, G_MUL (GlobalISel) | waveasm.arith.add, .mul | +| Type legalization | LegalizerInfo rules | Dedicated pass | +| Register placement | RegBankSelect + UniformityAnalysis | Demand-driven + constant bus pass | +| Constant bus | SIInstrInfo::legalizeOperands (post-ISel) | Register legalization pass | +| i64 lowering | RegBankLegalizeHelper | Type legalization pass | +| VGPR↔SGPR copies | SIFixSGPRCopies (cost heuristic) | Simple demand-driven insertion | + +WaveASM intentionally simplifies the LLVM model. The full LLVM backend has +~15 post-ISel fixup passes; WaveASM targets 2-3 legalization passes covering +the same ground for the subset of IR patterns we generate. + +## Implementation Plan + +### Phase 1: Pseudo-ops + type legalization + +1. Define arithmetic pseudo-ops in `WaveASMOps.td`. +2. Update `TranslateFromLLVM` to emit pseudo-ops instead of concrete ops with + inline `ensure32Bit`/`ensureVGPR` calls. +3. Implement type legalization pass (i64 narrowing only). + +### Phase 2: Register legalization + +1. Implement register legalization pass: SGPR/VGPR assignment + constant bus. +2. Update `TranslateFromMLIR` to also emit pseudo-ops where appropriate. +3. Remove all inline register fixup code from both translators. + +### Phase 3: Full i64 + instruction selection + +1. Add i64 carry-chain expansion to type legalization. +2. Add instruction selection pass (pseudo-op → concrete, SALU vs VALU). +3. Add VGPR→SGPR support (`v_readfirstlane_b32`) if needed. 
+ +## Resolved Questions + +- **Pseudo-op type representation**: pseudo-ops use polymorphic result types + (register-file-agnostic `!waveasm.sreg`/`!waveasm.vreg`/`!waveasm.imm`). + Width is inferred from operand register sizes. +- **Pass count**: type legalization, register legalization, and instruction + selection are currently combined in a single `ArithLegalization` pass. + Splitting into separate passes is tracked as future work. +- **Value range analysis**: not needed. The LLVM frontend already emits + explicit `trunc` ops where narrowing is safe; legalization lowers those + and uses i64 carry-chain expansion for true 64-bit arithmetic. diff --git a/waveasm/include/waveasm/Dialect/WaveASMArithOps.td b/waveasm/include/waveasm/Dialect/WaveASMArithOps.td new file mode 100644 index 0000000000..e8a1fe4f6d --- /dev/null +++ b/waveasm/include/waveasm/Dialect/WaveASMArithOps.td @@ -0,0 +1,148 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef WAVEASM_DIALECT_WAVEASMARITHOPS +#define WAVEASM_DIALECT_WAVEASMARITHOPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "waveasm/Dialect/WaveASMDialect.td" +include "waveasm/Dialect/WaveASMTypes.td" +include "waveasm/Dialect/WaveASMInterfaces.td" + +//===----------------------------------------------------------------------===// +// Generic Arithmetic Pseudo-Ops +// +// These ops are emitted during translation from MLIR/LLVM dialect and are +// lowered to concrete SALU/VALU machine ops by legalization passes. They +// defer register file (SGPR vs VGPR) and type width (i32 vs i64) decisions. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Base Class +//===----------------------------------------------------------------------===// + +// Generic arithmetic pseudo-op. Register-file-agnostic: accepts any operand +// type (SGPR, VGPR, immediate) and produces any register type. Legalization +// replaces these with concrete SALU/VALU ops. +class ArithPseudoOp traits = []> + : WAVEASMOp<"arith." # mnemonic, !listconcat([Pure, WaveASM_ArithmeticOp], traits)>; + +//===----------------------------------------------------------------------===// +// Binary Arithmetic +//===----------------------------------------------------------------------===// + +def WaveASM_ArithAddOp : ArithPseudoOp<"add", [Commutative]> { + let summary = "Generic integer add (legalized to s_add_u32 or v_add_u32)"; + let arguments = (ins WaveASM_AnyOperand:$lhs, WaveASM_AnyOperand:$rhs); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; +} + +def WaveASM_ArithMulOp : ArithPseudoOp<"mul", [Commutative]> { + let summary = "Generic integer multiply (legalized to s_mul_i32 or v_mul_lo_u32)"; + let arguments = (ins WaveASM_AnyOperand:$lhs, WaveASM_AnyOperand:$rhs); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// Bitwise +//===----------------------------------------------------------------------===// + +def WaveASM_ArithOrOp : ArithPseudoOp<"or", [Commutative]> { + let summary = "Generic bitwise OR (legalized to s_or_b32/b64 or v_or_b32/b64)"; + let arguments = (ins WaveASM_AnyOperand:$lhs, WaveASM_AnyOperand:$rhs); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` 
functional-type(operands, results)"; +} + +def WaveASM_ArithAndOp : ArithPseudoOp<"and", [Commutative]> { + let summary = "Generic bitwise AND (legalized to s_and_b32/b64 or v_and_b32/b64)"; + let arguments = (ins WaveASM_AnyOperand:$lhs, WaveASM_AnyOperand:$rhs); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// Comparison +//===----------------------------------------------------------------------===// + +def WaveASM_ArithCmpOp : ArithPseudoOp<"cmp", []> { + let summary = "Generic integer compare (legalized to s_cmp_* or v_cmp_*)"; + let description = [{ + Compares two integer operands using the given predicate. Produces a + condition result that can be consumed by arith.select or control flow. + + Legalization lowers this to: + - SALU: s_cmp_*_i32/u32 (result is SCC in an SGPR). + - VALU: v_cmp_*_i32/u32 (sets VCC implicitly, no explicit result). + }]; + let arguments = (ins WaveASM_CmpPred:$predicate, + WaveASM_AnyOperand:$lhs, WaveASM_AnyOperand:$rhs); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// Select +//===----------------------------------------------------------------------===// + +def WaveASM_ArithSelectOp : ArithPseudoOp<"select"> { + let summary = "Generic conditional select (legalized to v_cndmask_b32 or s_cselect_b32)"; + let description = [{ + Selects between two values based on a condition. The condition is + typically produced by arith.cmp. + + Legalization lowers this to: + - VALU: v_cndmask_b32 (condition via VCC from preceding v_cmp). + - SALU: s_cselect_b32 (condition via SCC from preceding s_cmp). 
+ }]; + let arguments = (ins WaveASM_AnyOperand:$falseVal, + WaveASM_AnyOperand:$trueVal, + WaveASM_AnyReg:$condition); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$condition `,` $trueVal `,` $falseVal attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// Type Conversion +//===----------------------------------------------------------------------===// + +def WaveASM_ArithTruncOp : ArithPseudoOp<"trunc"> { + let summary = "Generic integer truncation (e.g., i64 -> i32)"; + let description = [{ + Truncates a wide integer to a narrower width by extracting the low + sub-register. For i64->i32 on SGPR pairs, extracts the low SGPR. + }]; + let arguments = (ins WaveASM_AnyOperand:$src); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; +} + +def WaveASM_ArithSExtOp : ArithPseudoOp<"sext"> { + let summary = "Generic sign extension (e.g., i32 -> i64)"; + let description = [{ + Sign-extends a narrow integer to a wider width. For i32->i64, the + high word is filled with the sign bit (arithmetic shift right by 31). + }]; + let arguments = (ins WaveASM_AnyOperand:$src); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; +} + +def WaveASM_ArithZExtOp : ArithPseudoOp<"zext"> { + let summary = "Generic zero extension (e.g., i32 -> i64)"; + let description = [{ + Zero-extends a narrow integer to a wider width. For i32->i64, the + high word is set to zero. 
+ }]; + let arguments = (ins WaveASM_AnyOperand:$src); + let results = (outs WaveASM_AnyReg:$dst); + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; +} + +#endif // WAVEASM_DIALECT_WAVEASMARITHOPS diff --git a/waveasm/include/waveasm/Dialect/WaveASMControlFlowOps.td b/waveasm/include/waveasm/Dialect/WaveASMControlFlowOps.td index 323e6ef365..576a410af3 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMControlFlowOps.td +++ b/waveasm/include/waveasm/Dialect/WaveASMControlFlowOps.td @@ -104,7 +104,7 @@ def WaveASM_ConditionOp : WAVEASMOp<"condition", [ Terminator, HasParent<"LoopOp">, DeclareOpInterfaceMethods + ["getSuccessorRegions", "getMutableSuccessorOperands"]> ]> { let summary = "Loop condition terminator"; let description = [{ diff --git a/waveasm/include/waveasm/Dialect/WaveASMDialect.td b/waveasm/include/waveasm/Dialect/WaveASMDialect.td index b8b1b94e4c..564a8a37a8 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMDialect.td +++ b/waveasm/include/waveasm/Dialect/WaveASMDialect.td @@ -70,6 +70,11 @@ def WaveASMDialect : Dialect { // Name of the ABI attribute on kernel programs static constexpr ::llvm::StringLiteral kABIAttrName = "waveasm.abi"; + /// Original kernel name before symbol mangling. + static ::llvm::StringRef getKernelNameAttrName() { + return "kernel_name"; + } + // Registration functions void registerAttributes(); void registerTypes(); diff --git a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h index b401f5ee8c..0167f3e175 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h +++ b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.h @@ -47,6 +47,20 @@ class ControlFlowOp : public TraitBase {}; template class SpecialRegOp : public TraitBase {}; +/// Trait for operations that implicitly write the SCC (Scalar Condition Code) +/// flag as a hardware side effect. 
Ops with this trait must not be placed +/// between an SCC-producing op and its consumer (ConditionOp, s_cselect, +/// s_addc_u32). The SCC verifier pass checks this invariant. +template +class SCCDef : public TraitBase {}; + +/// Trait for operations that implicitly read the SCC flag. +/// The result depends on the current SCC value, so these ops are NOT +/// eligible for CSE (two identical operands can produce different results +/// if SCC differs). +template +class SCCUse : public TraitBase {}; + } // namespace OpTrait } // namespace mlir diff --git a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td index c8c143b84d..6a68ea10eb 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td +++ b/waveasm/include/waveasm/Dialect/WaveASMInterfaces.td @@ -264,4 +264,10 @@ def WaveASM_ControlFlowOp : NativeOpTrait<"ControlFlowOp">; // Trait for special register operations (M0, VCC, EXEC) def WaveASM_SpecialRegOp : NativeOpTrait<"SpecialRegOp">; +// Trait for operations that implicitly write the SCC flag. +def WaveASM_SCCDef : NativeOpTrait<"SCCDef">; + +// Trait for operations that implicitly read the SCC flag. +def WaveASM_SCCUse : NativeOpTrait<"SCCUse">; + #endif // WaveASM_DIALECT_WAVEASMINTERFACES diff --git a/waveasm/include/waveasm/Dialect/WaveASMOps.td b/waveasm/include/waveasm/Dialect/WaveASMOps.td index c0fc219dda..e36602bf46 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMOps.td +++ b/waveasm/include/waveasm/Dialect/WaveASMOps.td @@ -83,6 +83,10 @@ class VALUCmpOp traits = []> } // SALU Unary: dst = op(src) +// Pure — does NOT clobber SCC on hardware. +// Used for: s_mov_b32, s_mov_b64, s_sext_i32_i8, s_sext_i32_i16. +// NOTE: SCC-clobbering ops (s_not, s_brev, etc.) use +// SALUUnaryWithSCCOp instead.
class SALUUnaryOp traits = []> : WAVEASMOp { let arguments = (ins WaveASM_SRegOrImm:$src); @@ -90,7 +94,25 @@ class SALUUnaryOp traits = []> let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; } +// SALU Unary with SCC clobber: dst = op(src), SCC = f(result) +// Writes SCC on hardware but result depends only on operands. +// Uses NoMemoryEffect + AlwaysSpeculatableImplTrait (equivalent to Pure) +// so MLIR's DCE and LICM still work, plus SCCDef for the SCC verifier. +class SALUUnaryWithSCCOp traits = []> + : WAVEASMOp { + let arguments = (ins WaveASM_SRegOrImm:$src); + let results = (outs WaveASM_AnySGPR:$dst); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; +} + // SALU Binary: dst = op(src0, src1) +// Pure — does NOT clobber SCC on hardware. +// Used for: s_mul_i32, s_mul_hi_*, s_bfm_*, s_pack_*. +// NOTE: SCC-clobbering ops (s_and, s_or, s_lshl, etc.) use +// SALUBinaryWithSCCOp; s_cselect_b32 (reads SCC) is a standalone op. class SALUBinaryOp traits = []> : WAVEASMOp { let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); @@ -98,6 +120,21 @@ class SALUBinaryOp traits = []> let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst)"; } +// SALU Binary with SCC clobber: dst = op(src0, src1), SCC = f(result) +// Writes SCC on hardware but result depends only on operands. +// Uses NoMemoryEffect + AlwaysSpeculatableImplTrait (equivalent to Pure) +// so MLIR's DCE and LICM still work, plus SCCDef so the SCC verifier +// can identify these as SCC-clobbering. +class SALUBinaryWithSCCOp traits = []> + : WAVEASMOp { + let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); + let results = (outs WaveASM_AnySGPR:$dst); + let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst)"; +} + // SALU Binary with carry/SCC: dst, scc = op(src0, src1) // s_add_u32 and s_addc_u32 set SCC (carry flag) as a side effect.
// We model SCC as an explicit sreg result so the verifier can prevent @@ -107,7 +144,7 @@ class SALUBinaryOp traits = []> // a true SSA dependency. Currently the chain is only enforced by emission // ordering, not by the SSA graph. class SALUBinaryWithCarryOp traits = []> - : WAVEASMOp { + : WAVEASMOp { let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); let results = (outs WaveASM_AnySGPR:$dst, WaveASM_AnySGPR:$scc); let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst) `,` type($scc)"; @@ -123,7 +160,7 @@ class SALUBinaryWithCarryOp traits = []> // since the live range is very short (s_cmp -> condition terminator). // TODO: Consider a dedicated SCCType to avoid the wasted register slot. class SALUCmpOp traits = []> - : WAVEASMOp { + : WAVEASMOp { let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); let results = (outs WaveASM_AnySGPR:$result); // SCC result (see note above) let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result)"; @@ -217,6 +254,7 @@ class LDSStoreOp traits = []> //===----------------------------------------------------------------------===// include "waveasm/Dialect/WaveASMControlFlowOps.td" +include "waveasm/Dialect/WaveASMArithOps.td" //===----------------------------------------------------------------------===// // Program Container @@ -697,26 +735,29 @@ def WaveASM_V_CMP_GE_U64 : VALUCmpOp<"v_cmp_ge_u64">; // SALU Unary Instructions //===----------------------------------------------------------------------===// +// No SCC clobber def WaveASM_S_MOV_B32 : SALUUnaryOp<"s_mov_b32">; def WaveASM_S_MOV_B64 : SALUUnaryOp<"s_mov_b64">; -def WaveASM_S_NOT_B32 : SALUUnaryOp<"s_not_b32">; -def WaveASM_S_NOT_B64 : SALUUnaryOp<"s_not_b64">; -def WaveASM_S_BREV_B32 : SALUUnaryOp<"s_brev_b32">; -def WaveASM_S_BREV_B64 : SALUUnaryOp<"s_brev_b64">; -def WaveASM_S_BCNT0_I32_B32 : SALUUnaryOp<"s_bcnt0_i32_b32">; -def 
WaveASM_S_BCNT0_I32_B64 : SALUUnaryOp<"s_bcnt0_i32_b64">; -def WaveASM_S_BCNT1_I32_B32 : SALUUnaryOp<"s_bcnt1_i32_b32">; -def WaveASM_S_BCNT1_I32_B64 : SALUUnaryOp<"s_bcnt1_i32_b64">; -def WaveASM_S_FF0_I32_B32 : SALUUnaryOp<"s_ff0_i32_b32">; -def WaveASM_S_FF0_I32_B64 : SALUUnaryOp<"s_ff0_i32_b64">; -def WaveASM_S_FF1_I32_B32 : SALUUnaryOp<"s_ff1_i32_b32">; -def WaveASM_S_FF1_I32_B64 : SALUUnaryOp<"s_ff1_i32_b64">; -def WaveASM_S_FLBIT_I32_B32 : SALUUnaryOp<"s_flbit_i32_b32">; -def WaveASM_S_FLBIT_I32_B64 : SALUUnaryOp<"s_flbit_i32_b64">; -def WaveASM_S_ABS_I32 : SALUUnaryOp<"s_abs_i32">; def WaveASM_S_SEXT_I32_I8 : SALUUnaryOp<"s_sext_i32_i8">; def WaveASM_S_SEXT_I32_I16 : SALUUnaryOp<"s_sext_i32_i16">; +// Clobber SCC (SCC = f(result)) +def WaveASM_S_NOT_B32 : SALUUnaryWithSCCOp<"s_not_b32">; +def WaveASM_S_NOT_B64 : SALUUnaryWithSCCOp<"s_not_b64">; +def WaveASM_S_BREV_B32 : SALUUnaryWithSCCOp<"s_brev_b32">; +def WaveASM_S_BREV_B64 : SALUUnaryWithSCCOp<"s_brev_b64">; +def WaveASM_S_BCNT0_I32_B32 : SALUUnaryWithSCCOp<"s_bcnt0_i32_b32">; +def WaveASM_S_BCNT0_I32_B64 : SALUUnaryWithSCCOp<"s_bcnt0_i32_b64">; +def WaveASM_S_BCNT1_I32_B32 : SALUUnaryWithSCCOp<"s_bcnt1_i32_b32">; +def WaveASM_S_BCNT1_I32_B64 : SALUUnaryWithSCCOp<"s_bcnt1_i32_b64">; +def WaveASM_S_FF0_I32_B32 : SALUUnaryWithSCCOp<"s_ff0_i32_b32">; +def WaveASM_S_FF0_I32_B64 : SALUUnaryWithSCCOp<"s_ff0_i32_b64">; +def WaveASM_S_FF1_I32_B32 : SALUUnaryWithSCCOp<"s_ff1_i32_b32">; +def WaveASM_S_FF1_I32_B64 : SALUUnaryWithSCCOp<"s_ff1_i32_b64">; +def WaveASM_S_FLBIT_I32_B32 : SALUUnaryWithSCCOp<"s_flbit_i32_b32">; +def WaveASM_S_FLBIT_I32_B64 : SALUUnaryWithSCCOp<"s_flbit_i32_b64">; +def WaveASM_S_ABS_I32 : SALUUnaryWithSCCOp<"s_abs_i32">; + //===----------------------------------------------------------------------===// // SALU Binary Instructions //===----------------------------------------------------------------------===// @@ -727,48 +768,58 @@ def WaveASM_S_ADDC_U32 : 
SALUBinaryWithCarryOp<"s_addc_u32">; def WaveASM_S_ADD_I32 : SALUBinaryWithCarryOp<"s_add_i32">; def WaveASM_S_SUB_U32 : SALUBinaryWithCarryOp<"s_sub_u32">; def WaveASM_S_SUB_I32 : SALUBinaryWithCarryOp<"s_sub_i32">; +// Conditional select: dst = SCC ? src0 : src1 (reads SCC, no SCC write). +// NOT Pure/ArithmeticOp — result depends on implicit SCC state, so NOT +// CSE-eligible (two s_cselect_b32 with same operands can give different +// results if SCC differs). NoMemoryEffect allows DCE of unused results. +def WaveASM_S_CSELECT_B32 : WAVEASMOp<"s_cselect_b32", + [WaveASM_SCCUse, NoMemoryEffect]> { + let arguments = (ins WaveASM_AnySGPR:$src0, WaveASM_SRegOrImm:$src1); + let results = (outs WaveASM_AnySGPR:$dst); + let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($dst)"; +} // Multiplication does not set SCC def WaveASM_S_MUL_I32 : SALUBinaryOp<"s_mul_i32">; def WaveASM_S_MUL_HI_U32 : SALUBinaryOp<"s_mul_hi_u32">; def WaveASM_S_MUL_HI_I32 : SALUBinaryOp<"s_mul_hi_i32">; -// Bitwise -def WaveASM_S_AND_B32 : SALUBinaryOp<"s_and_b32">; -def WaveASM_S_AND_B64 : SALUBinaryOp<"s_and_b64">; -def WaveASM_S_OR_B32 : SALUBinaryOp<"s_or_b32">; -def WaveASM_S_OR_B64 : SALUBinaryOp<"s_or_b64">; -def WaveASM_S_XOR_B32 : SALUBinaryOp<"s_xor_b32">; -def WaveASM_S_XOR_B64 : SALUBinaryOp<"s_xor_b64">; -def WaveASM_S_ANDN2_B32 : SALUBinaryOp<"s_andn2_b32">; -def WaveASM_S_ANDN2_B64 : SALUBinaryOp<"s_andn2_b64">; -def WaveASM_S_ORN2_B32 : SALUBinaryOp<"s_orn2_b32">; -def WaveASM_S_ORN2_B64 : SALUBinaryOp<"s_orn2_b64">; -def WaveASM_S_NAND_B32 : SALUBinaryOp<"s_nand_b32">; -def WaveASM_S_NAND_B64 : SALUBinaryOp<"s_nand_b64">; -def WaveASM_S_NOR_B32 : SALUBinaryOp<"s_nor_b32">; -def WaveASM_S_NOR_B64 : SALUBinaryOp<"s_nor_b64">; -def WaveASM_S_XNOR_B32 : SALUBinaryOp<"s_xnor_b32">; -def WaveASM_S_XNOR_B64 : SALUBinaryOp<"s_xnor_b64">; +// Bitwise — all clobber SCC (SCC = result != 0) +def WaveASM_S_AND_B32 : SALUBinaryWithSCCOp<"s_and_b32">; 
+def WaveASM_S_AND_B64 : SALUBinaryWithSCCOp<"s_and_b64">; +def WaveASM_S_OR_B32 : SALUBinaryWithSCCOp<"s_or_b32">; +def WaveASM_S_OR_B64 : SALUBinaryWithSCCOp<"s_or_b64">; +def WaveASM_S_XOR_B32 : SALUBinaryWithSCCOp<"s_xor_b32">; +def WaveASM_S_XOR_B64 : SALUBinaryWithSCCOp<"s_xor_b64">; +def WaveASM_S_ANDN2_B32 : SALUBinaryWithSCCOp<"s_andn2_b32">; +def WaveASM_S_ANDN2_B64 : SALUBinaryWithSCCOp<"s_andn2_b64">; +def WaveASM_S_ORN2_B32 : SALUBinaryWithSCCOp<"s_orn2_b32">; +def WaveASM_S_ORN2_B64 : SALUBinaryWithSCCOp<"s_orn2_b64">; +def WaveASM_S_NAND_B32 : SALUBinaryWithSCCOp<"s_nand_b32">; +def WaveASM_S_NAND_B64 : SALUBinaryWithSCCOp<"s_nand_b64">; +def WaveASM_S_NOR_B32 : SALUBinaryWithSCCOp<"s_nor_b32">; +def WaveASM_S_NOR_B64 : SALUBinaryWithSCCOp<"s_nor_b64">; +def WaveASM_S_XNOR_B32 : SALUBinaryWithSCCOp<"s_xnor_b32">; +def WaveASM_S_XNOR_B64 : SALUBinaryWithSCCOp<"s_xnor_b64">; // Shifts -def WaveASM_S_LSHL_B32 : SALUBinaryOp<"s_lshl_b32">; -def WaveASM_S_LSHL_B64 : SALUBinaryOp<"s_lshl_b64">; -def WaveASM_S_LSHR_B32 : SALUBinaryOp<"s_lshr_b32">; -def WaveASM_S_LSHR_B64 : SALUBinaryOp<"s_lshr_b64">; -def WaveASM_S_ASHR_I32 : SALUBinaryOp<"s_ashr_i32">; -def WaveASM_S_ASHR_I64 : SALUBinaryOp<"s_ashr_i64">; +def WaveASM_S_LSHL_B32 : SALUBinaryWithSCCOp<"s_lshl_b32">; +def WaveASM_S_LSHL_B64 : SALUBinaryWithSCCOp<"s_lshl_b64">; +def WaveASM_S_LSHR_B32 : SALUBinaryWithSCCOp<"s_lshr_b32">; +def WaveASM_S_LSHR_B64 : SALUBinaryWithSCCOp<"s_lshr_b64">; +def WaveASM_S_ASHR_I32 : SALUBinaryWithSCCOp<"s_ashr_i32">; +def WaveASM_S_ASHR_I64 : SALUBinaryWithSCCOp<"s_ashr_i64">; // Min/Max -def WaveASM_S_MIN_I32 : SALUBinaryOp<"s_min_i32">; -def WaveASM_S_MIN_U32 : SALUBinaryOp<"s_min_u32">; -def WaveASM_S_MAX_I32 : SALUBinaryOp<"s_max_i32">; -def WaveASM_S_MAX_U32 : SALUBinaryOp<"s_max_u32">; +def WaveASM_S_MIN_I32 : SALUBinaryWithSCCOp<"s_min_i32">; +def WaveASM_S_MIN_U32 : SALUBinaryWithSCCOp<"s_min_u32">; +def WaveASM_S_MAX_I32 : SALUBinaryWithSCCOp<"s_max_i32">; 
+def WaveASM_S_MAX_U32 : SALUBinaryWithSCCOp<"s_max_u32">; // Misc -def WaveASM_S_BFE_U32 : SALUBinaryOp<"s_bfe_u32">; -def WaveASM_S_BFE_I32 : SALUBinaryOp<"s_bfe_i32">; -def WaveASM_S_BFE_U64 : SALUBinaryOp<"s_bfe_u64">; -def WaveASM_S_BFE_I64 : SALUBinaryOp<"s_bfe_i64">; +def WaveASM_S_BFE_U32 : SALUBinaryWithSCCOp<"s_bfe_u32">; +def WaveASM_S_BFE_I32 : SALUBinaryWithSCCOp<"s_bfe_i32">; +def WaveASM_S_BFE_U64 : SALUBinaryWithSCCOp<"s_bfe_u64">; +def WaveASM_S_BFE_I64 : SALUBinaryWithSCCOp<"s_bfe_i64">; def WaveASM_S_BFM_B32 : SALUBinaryOp<"s_bfm_b32">; def WaveASM_S_BFM_B64 : SALUBinaryOp<"s_bfm_b64">; def WaveASM_S_PACK_LL_B32_B16 : SALUBinaryOp<"s_pack_ll_b32_b16">; diff --git a/waveasm/include/waveasm/Dialect/WaveASMTypes.td b/waveasm/include/waveasm/Dialect/WaveASMTypes.td index 5315c4e052..988005a154 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMTypes.td +++ b/waveasm/include/waveasm/Dialect/WaveASMTypes.td @@ -29,6 +29,25 @@ def WaveASM_RegClass : I32EnumAttr<"RegClass", "Register class", [ let cppNamespace = "::waveasm"; } +//===----------------------------------------------------------------------===// +// Comparison Predicate Enum +//===----------------------------------------------------------------------===// + +def WaveASM_CmpPred : I64EnumAttr<"CmpPredicate", "Integer comparison predicate", [ + I64EnumAttrCase<"eq", 0, "eq">, + I64EnumAttrCase<"ne", 1, "ne">, + I64EnumAttrCase<"slt", 2, "slt">, + I64EnumAttrCase<"sle", 3, "sle">, + I64EnumAttrCase<"sgt", 4, "sgt">, + I64EnumAttrCase<"sge", 5, "sge">, + I64EnumAttrCase<"ult", 6, "ult">, + I64EnumAttrCase<"ule", 7, "ule">, + I64EnumAttrCase<"ugt", 8, "ugt">, + I64EnumAttrCase<"uge", 9, "uge">, +]> { + let cppNamespace = "::waveasm"; +} + //===----------------------------------------------------------------------===// // Virtual Register Types (Pre-Allocation) //===----------------------------------------------------------------------===// diff --git 
a/waveasm/include/waveasm/Transforms/AssemblyEmitter.h b/waveasm/include/waveasm/Transforms/AssemblyEmitter.h index 87eaff24d8..2e5f2d774a 100644 --- a/waveasm/include/waveasm/Transforms/AssemblyEmitter.h +++ b/waveasm/include/waveasm/Transforms/AssemblyEmitter.h @@ -19,6 +19,17 @@ namespace waveasm { +/// Return the kernel name for assembly emission. +/// Uses the "kernel_name" attribute if present, otherwise falls back to +/// sym_name. The kernel_name attribute is set when the program is given a +/// mangled sym_name to avoid symbol collisions with the original llvm.func. +inline mlir::StringRef getKernelName(ProgramOp program) { + auto attrName = WaveASMDialect::getKernelNameAttrName(); + if (auto attr = program->getAttrOfType(attrName)) + return attr.getValue(); + return program.getSymName(); +} + //===----------------------------------------------------------------------===// // Instruction Formatter //===----------------------------------------------------------------------===// diff --git a/waveasm/include/waveasm/Transforms/Passes.td b/waveasm/include/waveasm/Transforms/Passes.td index 8364b206ad..674bd21193 100644 --- a/waveasm/include/waveasm/Transforms/Passes.td +++ b/waveasm/include/waveasm/Transforms/Passes.td @@ -38,8 +38,8 @@ def WAVEASMLinearScan : Pass<"waveasm-linear-scan"> { let options = [ Option<"maxVGPRs", "max-vgprs", "int64_t", "256", "Maximum VGPRs available">, - Option<"maxSGPRs", "max-sgprs", "int64_t", "104", - "Maximum SGPRs available">, + Option<"maxSGPRs", "max-sgprs", "int64_t", "102", + "Maximum user-addressable SGPRs (s0-s101 on GFX9+)">, Option<"maxAGPRs", "max-agprs", "int64_t", "256", "Maximum AGPRs available"> ]; @@ -246,6 +246,36 @@ def WAVEASMExtractScalarization ]; } +//===----------------------------------------------------------------------===// +// SCC Verifier Pass +//===----------------------------------------------------------------------===// + +def WAVEASMSCCVerifier : Pass<"waveasm-scc-verifier"> { + let summary = 
"Verify SCC (Scalar Condition Code) liveness invariants"; + let description = [{ + Checks that no SCC-clobbering SALU instruction sits between an + SCC-producing op and its consumer. Uses isa<> checks (not traits) + to avoid changing ODS-generated code for existing op classes. + }]; + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + +//===----------------------------------------------------------------------===// +// VGPR Compaction Pass +//===----------------------------------------------------------------------===// + +def WAVEASMVGPRCompaction : Pass<"waveasm-vgpr-compaction"> { + let summary = "Compact physical VGPR assignments to reduce peak register count"; + let description = [{ + Re-assigns physical VGPRs after linear scan allocation to eliminate + fragmentation caused by interleaved buffer_load and ds_read destinations. + Uses a shortest-first greedy strategy: short-lived values get low + registers, long-lived values get high registers, eliminating gaps + from interleaving. + }]; + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + //===----------------------------------------------------------------------===// // Memory Offset Optimization Pass //===----------------------------------------------------------------------===// @@ -262,4 +292,80 @@ def WAVEASMMemoryOffsetOpt : Pass<"waveasm-memory-offset-opt"> { let dependentDialects = ["::waveasm::WaveASMDialect"]; } +//===----------------------------------------------------------------------===// +// Translate from LLVM Dialect Pass +//===----------------------------------------------------------------------===// + +def WAVEASMGPUModuleToBinary : Pass<"waveasm-gpu-module-to-binary"> { + let summary = "Emit assembly and link gpu.module { waveasm.program } to gpu.binary"; + let description = [{ + Final pass in the WaveASM pipeline. Expects gpu.module ops containing + register-allocated waveasm.program ops.
Emits AMDGCN assembly, assembles + + links to HSACO via LLVM MC, and replaces each gpu.module with a + gpu.binary holding the code object. + }]; + + let options = [ + Option<"targetArch", "target", "std::string", "\"gfx942\"", + "Target GPU architecture">, + Option<"lldPath", "lld-path", "std::string", "\"\"", + "Path to ld.lld for linking HSACO"> + ]; + + let dependentDialects = [ + "::waveasm::WaveASMDialect", + "::mlir::gpu::GPUDialect" + ]; +} + +def WAVEASMArithLegalization : Pass<"waveasm-arith-legalization"> { + let summary = "Lower generic arithmetic pseudo-ops to concrete SALU/VALU ops"; + let description = [{ + Lowers register-file-agnostic arith pseudo-ops (arith.add, arith.mul, + arith.cmp, arith.select, arith.trunc, arith.sext, arith.zext) to + concrete SALU or VALU machine ops. + + Register file assignment (demand-driven divergence): + - If any operand is a VGPR → use VALU, result is VGPR. + - If all operands are SGPR/immediate → use SALU, result is SGPR. + + Constant bus enforcement: + - VALU ops can read at most one SGPR per instruction. + - If both operands are SGPRs, one is moved to VGPR via v_mov_b32. + + i64 legalization (split into i32 halves): + - add: carry chain (s_add_u32 + s_addc_u32 / v_add_co_u32 + v_addc_co_u32). + - mul: schoolbook decomposition (lo*lo, hi*lo, lo*hi). + - cmp eq/ne: XOR each half, OR, compare to zero. + - cmp ordered: hi/lo decomposition with signed hi / unsigned lo. + - select: split both operands, v_cndmask_b32 each half, merge. + - sext/zext: extend to {lo, sign/zero} pair. + - trunc: extract lo half. + + Reports errors on unsupported widths (anything other than i32/i64). 
+ }]; + + let statistics = [ + Statistic<"numOpsLegalized", "Pseudo-ops legalized", "count">, + Statistic<"numSGPRToVGPRMoves", "SGPR-to-VGPR moves inserted", "count"> + ]; + + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + +def WAVEASMTranslateFromLLVM : Pass<"waveasm-translate-from-llvm"> { + let summary = "Translate LLVM dialect kernels to WaveASM IR"; + let description = [{ + Translates gpu.module { llvm.func } with rocdl intrinsics into + waveasm.program ops. Fails on any unhandled op (strict mode). + }]; + + let options = [ + Option<"targetArch", "target", "std::string", "\"gfx950\"", + "Target GPU architecture"> + ]; + + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + #endif // WaveASM_TRANSFORMS_PASSES diff --git a/waveasm/include/waveasm/Transforms/RegAlloc.h b/waveasm/include/waveasm/Transforms/RegAlloc.h index 523fec57ae..d3d5dee371 100644 --- a/waveasm/include/waveasm/Transforms/RegAlloc.h +++ b/waveasm/include/waveasm/Transforms/RegAlloc.h @@ -11,6 +11,7 @@ #include "waveasm/Dialect/WaveASMOps.h" #include "waveasm/Dialect/WaveASMTypes.h" #include "waveasm/Transforms/Liveness.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" @@ -66,126 +67,148 @@ class PhysicalMapping { // Register Pool //===----------------------------------------------------------------------===// -/// Pool of available physical registers +/// Pool of available physical registers. +/// Uses a BitVector for O(1) per-register operations and cache-friendly +/// scanning. Typical GPU register files (256 VGPRs, 102 SGPRs, 256 AGPRs) +/// fit in a few 64-bit words. 
class RegPool { public: RegPool(RegClass regClass, int64_t maxRegs, const llvm::DenseSet &reserved) - : regClass(regClass), maxRegs(maxRegs) { - // Initialize free list with all non-reserved registers - for (int64_t i = 0; i < maxRegs; ++i) { - if (!reserved.contains(i)) { - freeList.push_back(i); - } + : regClass(regClass), maxRegs(maxRegs), free(maxRegs, true) { + for (int64_t r : reserved) { + if (r >= 0 && r < maxRegs) + free.reset(r); } } - /// Check if a register is currently in the free list + /// O(1) free-check via bit test. bool isFree(int64_t reg) const { - return std::find(freeList.begin(), freeList.end(), reg) != freeList.end(); + return reg >= 0 && reg < maxRegs && free.test(reg); } - /// Reserve a specific register (for precoloring) + /// Reserve a specific register range (for precoloring / re-reservation). void reserve(int64_t reg, int64_t size) { for (int64_t i = 0; i < size; ++i) { int64_t r = reg + i; - auto it = std::find(freeList.begin(), freeList.end(), r); - if (it != freeList.end()) { - freeList.erase(it); - allocated.insert(r); + if (r < maxRegs && free.test(r)) { + free.reset(r); + ++currentUsage; } } updatePeak(); } - /// Allocate a single register - /// Returns -1 if allocation fails + /// Allocate the lowest-numbered free register. + /// Returns -1 if no register is available. int64_t allocSingle() { - if (freeList.empty()) + int idx = free.find_first(); + if (idx < 0) return -1; - - int64_t reg = freeList.front(); - freeList.erase(freeList.begin()); - allocated.insert(reg); + free.reset(idx); + ++currentUsage; updatePeak(); - return reg; + return idx; } - /// Allocate a contiguous range of registers with alignment - /// Returns base register index, or -1 if allocation fails + /// Allocate a contiguous range of registers with the given alignment. + /// Scans by alignment stride for O(maxRegs/alignment) candidate checks. + /// Returns base register index, or -1 if allocation fails. 
int64_t allocRange(int64_t size, int64_t alignment) { if (size <= 0) return -1; - llvm::DenseSet freeSet(freeList.begin(), freeList.end()); + for (int64_t c = 0; c + size <= maxRegs; c += alignment) { + bool allFree = true; + for (int64_t o = 0; o < size; ++o) { + if (!free.test(c + o)) { + allFree = false; + break; + } + } + if (allFree) { + for (int64_t o = 0; o < size; ++o) + free.reset(c + o); + currentUsage += size; + updatePeak(); + return c; + } + } + return -1; + } - for (int64_t candidate : freeList) { - // Check alignment - if (candidate % alignment != 0) - continue; + /// Allocate the highest-numbered free register below ceiling. + int64_t allocSingleFromTop(int64_t ceiling = -1) { + int64_t cap = (ceiling > 0 && ceiling <= maxRegs) ? ceiling : maxRegs; + // Scan from cap-1 downward for the first free register. + for (int64_t i = cap - 1; i >= 0; --i) { + if (free.test(i)) { + free.reset(i); + ++currentUsage; + updatePeak(); + return i; + } + } + return -1; + } + + /// Allocate a contiguous range from the top of a capped region. + /// Scans from `ceiling` downward to pack long-lived values at the top + /// of the USED range, not the top of the entire register file. + int64_t allocRangeFromTop(int64_t size, int64_t alignment, + int64_t ceiling = -1) { + if (size <= 0) + return -1; - // Check if all registers in range are free + int64_t cap = (ceiling > 0 && ceiling <= maxRegs) ? 
ceiling : maxRegs; + int64_t highestBase = ((cap - size) / alignment) * alignment; + for (int64_t c = highestBase; c >= 0; c -= alignment) { bool allFree = true; - for (int64_t offset = 0; offset < size; ++offset) { - int64_t reg = candidate + offset; - if (reg >= maxRegs || !freeSet.contains(reg)) { + for (int64_t o = 0; o < size; ++o) { + if (!free.test(c + o)) { allFree = false; break; } } - if (allFree) { - // Allocate the range - for (int64_t offset = 0; offset < size; ++offset) { - int64_t reg = candidate + offset; - freeList.erase(std::find(freeList.begin(), freeList.end(), reg)); - allocated.insert(reg); - } + for (int64_t o = 0; o < size; ++o) + free.reset(c + o); + currentUsage += size; updatePeak(); - return candidate; + return c; } } - - return -1; // Allocation failed + return -1; } - /// Free a single register + /// Free a single register back to the pool. void freeSingle(int64_t reg) { - if (!allocated.contains(reg)) + if (reg < 0 || reg >= maxRegs || free.test(reg)) return; - - allocated.erase(reg); - - // Insert in sorted order - auto it = std::lower_bound(freeList.begin(), freeList.end(), reg); - freeList.insert(it, reg); + free.set(reg); + --currentUsage; } - /// Free a range of registers + /// Free a contiguous range of registers. 
void freeRange(int64_t base, int64_t size) { for (int64_t offset = 0; offset < size; ++offset) { freeSingle(base + offset); } } - /// Get peak usage int64_t getPeakUsage() const { return peak; } - /// Get current usage - int64_t getCurrentUsage() const { return allocated.size(); } + int64_t getCurrentUsage() const { return currentUsage; } - /// Get register class RegClass getRegClass() const { return regClass; } private: - void updatePeak() { - peak = std::max(peak, static_cast(allocated.size())); - } + void updatePeak() { peak = std::max(peak, currentUsage); } RegClass regClass; int64_t maxRegs; - llvm::SmallVector freeList; // Sorted - llvm::DenseSet allocated; + llvm::BitVector free; + int64_t currentUsage = 0; int64_t peak = 0; }; diff --git a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h index 96b6c0c11b..2fc46f39cd 100644 --- a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h +++ b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h @@ -133,17 +133,20 @@ struct BitRange { std::max(int64_t(0), highBit - n)); } - /// Create range from a constant value + /// Create range from a constant value. + /// Negative constants use their unsigned 32-bit representation so the + /// OR-for-add optimisation in AffineHandlers sees the actually-set bits. static BitRange fromConstant(int64_t value) { if (value == 0) return BitRange(0, 0); - int64_t bits = 0; - int64_t v = value; - while (v > 0) { - bits++; - v >>= 1; - } - return BitRange(0, bits - 1); + // Work with the 32-bit unsigned representation so negative values + // (e.g. -16 = 0xFFFFFFF0) are handled correctly. 
+ uint32_t uval = static_cast(value); + if (uval == 0) + return BitRange(0, 0); + int low = llvm::countr_zero(uval); + int high = 31 - llvm::countl_zero(uval); + return BitRange(low, high); } /// Create range for a value with known max @@ -386,8 +389,7 @@ class TranslationContext { /// Get next available swizzle SRD index (for cache swizzle SRDs and /// per-workgroup SRD base adjustments). /// These are allocated after all regular SRDs, computed in emitSRDPrologue(). - /// Each allocation reserves 5 SGPRs: s[N..N+3] for the SRD quad plus - /// s[N+4] as a temporary for the multiply low-half result. + /// Each allocation reserves 4 SGPRs: s[N..N+3] for the SRD quad. int64_t getNextSwizzleSRDIndex() { if (nextSwizzleSRDIndex < 0) { int64_t maxSrdEnd = 24; @@ -398,16 +400,15 @@ class TranslationContext { nextSwizzleSRDIndex = (maxSrdEnd + 3) & ~3; // Align to 4 } int64_t idx = nextSwizzleSRDIndex; - // 4 SRD SGPRs + 1 temp for byteOffLo, padded to next 4-aligned index - // (SRD buffer descriptors require 4-SGPR alignment on AMDGCN). - nextSwizzleSRDIndex = (idx + 5 + 3) & ~3; + // SRD buffer descriptors require 4-SGPR alignment on AMDGCN. + nextSwizzleSRDIndex = idx + 4; return idx; } /// Update buffer size for a pending SRD (called when we see reinterpret_cast) void updateSRDBufferSize(mlir::Value memref, int64_t bufferSize); - /// Get the number of kernel arguments (bindings + scalar args) + /// Get the number of kernel arguments (bindings + scalar args). size_t getNumKernelArgs() const { return pendingSRDs.size() + pendingScalarArgs.size(); } @@ -642,11 +643,29 @@ class TranslationContext { return count; } - /// Get SGPR index for workgroup ID in the given dimension (0=x, 1=y, 2=z) - /// System SGPRs (workgroup IDs) come after user SGPRs + /// Get SGPR index for workgroup ID in the given dimension (0=x, 1=y, 2=z). + /// System SGPRs are only allocated for enabled dimensions, so we count + /// only the enabled IDs before the requested dimension. 
int64_t getWorkgroupIdSgprIndex(int dimension) const { int64_t baseIndex = getUserSgprCount(); - return baseIndex + dimension; + // When all three IDs are enabled (e.g., via enableAllWorkgroupIds()), + // the layout is simply base + dimension. + if (usesWorkgroupIdX && usesWorkgroupIdY && usesWorkgroupIdZ) + return baseIndex + dimension; + int64_t offset = 0; + if (dimension > 0 && usesWorkgroupIdX) + offset++; + if (dimension > 1 && usesWorkgroupIdY) + offset++; + return baseIndex + offset; + } + + /// Enable all three workgroup IDs so that the SGPR layout is predictable. + /// Call this before translating any ops to avoid ordering dependencies. + void enableAllWorkgroupIds() { + usesWorkgroupIdX = true; + usesWorkgroupIdY = true; + usesWorkgroupIdZ = true; } /// Check if this is a multi-wave kernel (more than 64 threads per workgroup) diff --git a/waveasm/lib/Dialect/WaveASMOps.cpp b/waveasm/lib/Dialect/WaveASMOps.cpp index ed9c3b98a8..f85c56b04f 100644 --- a/waveasm/lib/Dialect/WaveASMOps.cpp +++ b/waveasm/lib/Dialect/WaveASMOps.cpp @@ -137,8 +137,13 @@ LogicalResult ExtractOp::verify() { return emitOpError() << "index must be non-negative, got " << index; } - // Preserve vector register bank (VGPR <-> VGPR, AGPR <-> AGPR). - if (isAGPRType(vectorType) != isAGPRType(resultType)) { + // Preserve register bank (VGPR <-> VGPR, AGPR <-> AGPR, SGPR <-> SGPR). 
+ bool bankMismatch = false; + if (isSGPRType(vectorType) != isSGPRType(resultType)) + bankMismatch = true; + else if (isAGPRType(vectorType) != isAGPRType(resultType)) + bankMismatch = true; + if (bankMismatch) { return emitOpError() << "result register class must match source register class: source " << vectorType << ", result " << resultType; diff --git a/waveasm/lib/Transforms/ArithLegalization.cpp b/waveasm/lib/Transforms/ArithLegalization.cpp new file mode 100644 index 0000000000..6fecc0638b --- /dev/null +++ b/waveasm/lib/Transforms/ArithLegalization.cpp @@ -0,0 +1,795 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// Arithmetic Legalization Pass +// +// Lowers generic arithmetic pseudo-ops (arith.add, arith.mul, arith.cmp, +// arith.select, arith.trunc, arith.sext, arith.zext) to concrete SALU or +// VALU machine ops based on operand register files and widths (i32/i64). +//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Dialect/WaveASMTypes.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMARITHLEGALIZATION +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +using namespace mlir; +using namespace waveasm; + +//===----------------------------------------------------------------------===// +// Width and register-file helpers +//===----------------------------------------------------------------------===// + +/// Return register width in 32-bit units (1 = i32, 2 = i64). 
+/// Returns 0 for unsupported types. +static int64_t getRegWidth(Value v) { + return TypeSwitch(v.getType()) + .Case( + [](auto t) { return t.getSize(); }) + .Case([](ImmType) { return int64_t(1); }) + .Default([](Type) { return int64_t(0); }); +} + +/// Return true if any operand is a VGPR (divergent context). +static bool anyVGPR(ValueRange operands) { + return llvm::any_of(operands, + [](Value v) { return isVGPRType(v.getType()); }); +} + +/// Validate that width is exactly 1 (i32) or 2 (i64). +/// Returns failure and emits an error on the op otherwise. +static LogicalResult checkWidth(Operation *op, int64_t width) { + if (width == 1 || width == 2) + return success(); + op->emitError("unsupported operand width (expected i32 or i64, got ") + << width << " dwords)"; + return failure(); +} + +//===----------------------------------------------------------------------===// +// i64 split/merge helpers +//===----------------------------------------------------------------------===// + +/// Split an i64 (size-2) register into {lo, hi} i32 halves. +static std::pair splitI64(Value v, OpBuilder &builder, + Location loc) { + // Look through pack ops to avoid extract/pack round-trips. + // The register allocator does not insert copies for pack, so extracting + // from a pack whose inputs are at different physical registers would + // read stale data. + if (auto pack = v.getDefiningOp()) { + auto operands = pack.getOperands(); + assert(operands.size() == 2 && "expected 2-element pack for i64"); + return {operands[0], operands[1]}; + } + if (isSGPRType(v.getType())) { + // For precolored SGPRs, create precolored extracts at known indices. 
+ if (auto psreg = v.getDefiningOp()) { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + Value lo = PrecoloredSRegOp::create(builder, loc, sregTy, + psreg.getIndex(), /*size=*/1); + Value hi = PrecoloredSRegOp::create(builder, loc, sregTy, + psreg.getIndex() + 1, /*size=*/1); + return {lo, hi}; + } + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + Value lo = ExtractOp::create(builder, loc, sregTy, v, 0); + Value hi = ExtractOp::create(builder, loc, sregTy, v, 1); + return {lo, hi}; + } + auto vregTy = VRegType::get(builder.getContext()); + Value lo = ExtractOp::create(builder, loc, vregTy, v, 0); + Value hi = ExtractOp::create(builder, loc, vregTy, v, 1); + return {lo, hi}; +} + +/// Merge {lo, hi} i32 values into an i64 (size-2) register. +static Value mergeI64(Value lo, Value hi, OpBuilder &builder, Location loc) { + if (isSGPRType(lo.getType())) { + auto sregTy = SRegType::get(builder.getContext(), 2, 2); + return PackOp::create(builder, loc, sregTy, ValueRange{lo, hi}); + } + auto vregTy = VRegType::get(builder.getContext(), 2); + return PackOp::create(builder, loc, vregTy, ValueRange{lo, hi}); +} + +//===----------------------------------------------------------------------===// +// Register file conversion helpers +//===----------------------------------------------------------------------===// + +/// Move an SGPR i32 value to a VGPR via v_mov_b32. 
+static Value sgprToVgpr(Value v, OpBuilder &builder, Location loc) { + if (!isSGPRType(v.getType())) + return v; + auto vregTy = VRegType::get(builder.getContext()); + return V_MOV_B32::create(builder, loc, vregTy, v); +} + +//===----------------------------------------------------------------------===// +// i32 legalization +//===----------------------------------------------------------------------===// + +static void legalizeAddI32(Value lhs, Value rhs, ArithAddOp op, + OpBuilder &builder) { + Location loc = op.getLoc(); + Value result; + if (anyVGPR({lhs, rhs})) { + lhs = sgprToVgpr(lhs, builder, loc); + auto vregTy = VRegType::get(builder.getContext()); + result = V_ADD_U32::create(builder, loc, vregTy, lhs, rhs); + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + result = S_ADD_U32::create(builder, loc, sregTy, sregTy, lhs, rhs).getDst(); + } + op.replaceAllUsesWith(result); +} + +static void legalizeMulI32(Value lhs, Value rhs, ArithMulOp op, + OpBuilder &builder) { + Location loc = op.getLoc(); + Value result; + if (anyVGPR({lhs, rhs})) { + lhs = sgprToVgpr(lhs, builder, loc); + auto vregTy = VRegType::get(builder.getContext()); + result = V_MUL_LO_U32::create(builder, loc, vregTy, lhs, rhs); + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + result = S_MUL_I32::create(builder, loc, sregTy, lhs, rhs); + } + op.replaceAllUsesWith(result); +} + +//===----------------------------------------------------------------------===// +// i64 legalization +//===----------------------------------------------------------------------===// + +/// i64 add via carry chain: s_add_u32 + s_addc_u32 (SALU) or +/// v_add_co_u32 + v_addc_co_u32 (VALU). +/// NOTE: The carry between the two ops is implicit (SCC/VCC). They must +/// remain adjacent -- do not schedule or insert ops between them. 
+static void legalizeAddI64(Value lhs, Value rhs, ArithAddOp op, + OpBuilder &builder) { + Location loc = op.getLoc(); + auto [lhsLo, lhsHi] = splitI64(lhs, builder, loc); + auto [rhsLo, rhsHi] = splitI64(rhs, builder, loc); + + Value loResult, hiResult; + if (anyVGPR({lhs, rhs})) { + lhsLo = sgprToVgpr(lhsLo, builder, loc); + lhsHi = sgprToVgpr(lhsHi, builder, loc); + auto vregTy = VRegType::get(builder.getContext()); + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + // v_add_co_u32: lo + lo, carry out to VCC. + auto addLo = + V_ADD_CO_U32::create(builder, loc, vregTy, sregTy, lhsLo, rhsLo); + loResult = addLo.getDst(); + // v_addc_co_u32: hi + hi + carry in from VCC. + auto addHi = + V_ADDC_CO_U32::create(builder, loc, vregTy, sregTy, lhsHi, rhsHi); + hiResult = addHi.getDst(); + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + // s_add_u32: lo + lo, carry out to SCC. + auto addLo = S_ADD_U32::create(builder, loc, sregTy, sregTy, lhsLo, rhsLo); + loResult = addLo.getDst(); + // s_addc_u32: hi + hi + carry in from SCC. + auto addHi = S_ADDC_U32::create(builder, loc, sregTy, sregTy, lhsHi, rhsHi); + hiResult = addHi.getDst(); + } + + Value result = mergeI64(loResult, hiResult, builder, loc); + op.replaceAllUsesWith(result); +} + +/// i64 multiply via schoolbook decomposition: +/// result_lo = mul_lo(a_lo, b_lo) +/// result_hi = mul_hi(a_lo, b_lo) + mul_lo(a_lo, b_hi) + mul_lo(a_hi, b_lo) +static void legalizeMulI64(Value lhs, Value rhs, ArithMulOp op, + OpBuilder &builder) { + Location loc = op.getLoc(); + auto [aLo, aHi] = splitI64(lhs, builder, loc); + auto [bLo, bHi] = splitI64(rhs, builder, loc); + + Value loResult, hiResult; + if (anyVGPR({lhs, rhs})) { + aLo = sgprToVgpr(aLo, builder, loc); + aHi = sgprToVgpr(aHi, builder, loc); + auto vregTy = VRegType::get(builder.getContext()); + // lo = mul_lo(a_lo, b_lo). 
+ loResult = V_MUL_LO_U32::create(builder, loc, vregTy, aLo, bLo); + // hi = mul_hi(a_lo, b_lo) + mul_lo(a_lo, b_hi) + mul_lo(a_hi, b_lo). + Value hiPartial = V_MUL_HI_U32::create(builder, loc, vregTy, aLo, bLo); + Value cross1 = V_MUL_LO_U32::create(builder, loc, vregTy, aLo, bHi); + Value cross2 = V_MUL_LO_U32::create(builder, loc, vregTy, aHi, bLo); + // Accumulate with v_add3_u32 (3-input add, no carry needed since we + // discard bits above 64). + hiResult = + V_ADD3_U32::create(builder, loc, vregTy, hiPartial, cross1, cross2); + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + // lo = mul_lo(a_lo, b_lo). + loResult = S_MUL_I32::create(builder, loc, sregTy, aLo, bLo); + // hi = mul_hi(a_lo, b_lo) + mul_lo(a_lo, b_hi) + mul_lo(a_hi, b_lo). + Value hiPartial = S_MUL_HI_U32::create(builder, loc, sregTy, aLo, bLo); + Value cross1 = S_MUL_I32::create(builder, loc, sregTy, aLo, bHi); + Value cross2 = S_MUL_I32::create(builder, loc, sregTy, aHi, bLo); + // Accumulate (carry discarded -- computing mod 2^64). + Value hiTemp = + S_ADD_U32::create(builder, loc, sregTy, sregTy, hiPartial, cross1) + .getDst(); + hiResult = S_ADD_U32::create(builder, loc, sregTy, sregTy, hiTemp, cross2) + .getDst(); + } + + Value result = mergeI64(loResult, hiResult, builder, loc); + op.replaceAllUsesWith(result); +} + +/// Emit a VCC-setting compare and return a VCC placeholder. 
+static Value emitVCmp(CmpPredicate pred, Value lhs, Value rhs, + OpBuilder &builder, Location loc) { + switch (pred) { + case CmpPredicate::eq: + V_CMP_EQ_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ne: + V_CMP_NE_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::slt: + V_CMP_LT_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::sle: + V_CMP_LE_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::sgt: + V_CMP_GT_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::sge: + V_CMP_GE_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ult: + V_CMP_LT_U32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ule: + V_CMP_LE_U32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ugt: + V_CMP_GT_U32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::uge: + V_CMP_GE_U32::create(builder, loc, lhs, rhs); + break; + } + auto ty = ImmType::get(builder.getContext(), 1); + return ConstantOp::create(builder, loc, ty, 1); +} + +/// Emit an SCC-setting compare and return the sreg result. 
+static Value emitSCmp(CmpPredicate pred, Value lhs, Value rhs, + OpBuilder &builder, Location loc) { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + switch (pred) { + case CmpPredicate::eq: + return S_CMP_EQ_I32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::ne: + return S_CMP_NE_I32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::slt: + return S_CMP_LT_I32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::sle: + return S_CMP_LE_I32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::sgt: + return S_CMP_GT_I32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::sge: + return S_CMP_GE_I32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::ult: + return S_CMP_LT_U32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::ule: + return S_CMP_LE_U32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::ugt: + return S_CMP_GT_U32::create(builder, loc, sregTy, lhs, rhs); + case CmpPredicate::uge: + return S_CMP_GE_U32::create(builder, loc, sregTy, lhs, rhs); + } + llvm_unreachable("unhandled CmpPredicate"); +} + +/// Get hi/lo predicates for ordered i64 comparison. +/// Hi uses the strict less/greater variant with same signedness. +/// Lo always uses unsigned (lo halves have no sign meaning). 
+static std::pair +getOrderedI64Preds(CmpPredicate pred) { + switch (pred) { + case CmpPredicate::slt: + return {CmpPredicate::slt, CmpPredicate::ult}; + case CmpPredicate::sle: + return {CmpPredicate::slt, CmpPredicate::ule}; + case CmpPredicate::sgt: + return {CmpPredicate::sgt, CmpPredicate::ugt}; + case CmpPredicate::sge: + return {CmpPredicate::sgt, CmpPredicate::uge}; + case CmpPredicate::ult: + return {CmpPredicate::ult, CmpPredicate::ult}; + case CmpPredicate::ule: + return {CmpPredicate::ult, CmpPredicate::ule}; + case CmpPredicate::ugt: + return {CmpPredicate::ugt, CmpPredicate::ugt}; + case CmpPredicate::uge: + return {CmpPredicate::ugt, CmpPredicate::uge}; + default: + llvm_unreachable("not an ordered predicate"); + } +} + +/// i64 compare: eq/ne via XOR+OR, ordered via hi/lo decomposition. +/// +/// Ordered strategy: result = hiPred(hi_a, hi_b) || +/// (hi_a == hi_b && loPred(lo_a, lo_b)) +/// Hi comparison uses same signedness as original predicate (strict variant). +/// Lo comparison always uses unsigned (lo halves have no sign meaning). +static LogicalResult legalizeCmpI64(Value lhs, Value rhs, CmpPredicate pred, + ArithCmpOp op, OpBuilder &builder) { + Location loc = op.getLoc(); + auto [lhsLo, lhsHi] = splitI64(lhs, builder, loc); + auto [rhsLo, rhsHi] = splitI64(rhs, builder, loc); + + bool isEqNe = pred == CmpPredicate::eq || pred == CmpPredicate::ne; + + if (anyVGPR({lhs, rhs})) { + lhsLo = sgprToVgpr(lhsLo, builder, loc); + lhsHi = sgprToVgpr(lhsHi, builder, loc); + auto vregTy = VRegType::get(builder.getContext()); + + if (isEqNe) { + // eq/ne: XOR each half, OR, compare to zero. + // Materialize to vreg so VCC can be re-established by the consumer. 
+ Value xorLo = V_XOR_B32::create(builder, loc, vregTy, lhsLo, rhsLo); + Value xorHi = V_XOR_B32::create(builder, loc, vregTy, lhsHi, rhsHi); + Value combined = V_OR_B32::create(builder, loc, vregTy, xorLo, xorHi); + auto zeroTy = ImmType::get(builder.getContext(), 0); + Value zero = ConstantOp::create(builder, loc, zeroTy, 0); + auto oneTy = ImmType::get(builder.getContext(), 1); + Value one = ConstantOp::create(builder, loc, oneTy, 1); + if (pred == CmpPredicate::eq) + V_CMP_EQ_I32::create(builder, loc, combined, zero); + else + V_CMP_NE_I32::create(builder, loc, combined, zero); + auto vccPlaceholder = ConstantOp::create(builder, loc, oneTy, 1); + Value boolResult = V_CNDMASK_B32::create(builder, loc, vregTy, zero, one, + vccPlaceholder); + op.replaceAllUsesWith(boolResult); + op.erase(); + return success(); + } else { + // Ordered: materialize hi/lo results, select based on hi equality. + auto [hiPred, loPred] = getOrderedI64Preds(pred); + auto zeroTy = ImmType::get(builder.getContext(), 0); + Value zero = ConstantOp::create(builder, loc, zeroTy, 0); + auto oneTy = ImmType::get(builder.getContext(), 1); + Value one = ConstantOp::create(builder, loc, oneTy, 1); + + // Materialize hi comparison to vreg. + Value hiVcc = emitVCmp(hiPred, lhsHi, rhsHi, builder, loc); + Value hiRes = + V_CNDMASK_B32::create(builder, loc, vregTy, zero, one, hiVcc); + + // Materialize lo comparison to vreg. + Value loVcc = emitVCmp(loPred, lhsLo, rhsLo, builder, loc); + Value loRes = + V_CNDMASK_B32::create(builder, loc, vregTy, zero, one, loVcc); + + // Select: if hi equal, use lo result; else use hi result. + Value eqVcc = emitVCmp(CmpPredicate::eq, lhsHi, rhsHi, builder, loc); + Value finalBool = + V_CNDMASK_B32::create(builder, loc, vregTy, hiRes, loRes, eqVcc); + + // Return the materialized boolean vreg so the consumer (select) can + // re-establish VCC right before its v_cndmask_b32. Setting VCC here + // is unsafe because intervening i64 ops (v_add_co_u32) may clobber it. 
+ op.replaceAllUsesWith(finalBool); + op.erase(); + return success(); + } + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + + if (isEqNe) { + // eq/ne: XOR each half, OR, compare to zero. + Value xorLo = S_XOR_B32::create(builder, loc, sregTy, lhsLo, rhsLo); + Value xorHi = S_XOR_B32::create(builder, loc, sregTy, lhsHi, rhsHi); + Value combined = S_OR_B32::create(builder, loc, sregTy, xorLo, xorHi); + auto immTy = ImmType::get(builder.getContext(), 0); + Value zero = ConstantOp::create(builder, loc, immTy, 0); + Value result; + if (pred == CmpPredicate::eq) + result = S_CMP_EQ_I32::create(builder, loc, sregTy, combined, zero); + else + result = S_CMP_NE_I32::create(builder, loc, sregTy, combined, zero); + op.replaceAllUsesWith(result); + } else { + // Ordered: result = hiPred(hi) | (hiEq(hi) & loPred(lo)). + auto [hiPred, loPred] = getOrderedI64Preds(pred); + Value hiCmp = emitSCmp(hiPred, lhsHi, rhsHi, builder, loc); + Value hiEq = emitSCmp(CmpPredicate::eq, lhsHi, rhsHi, builder, loc); + Value loCmp = emitSCmp(loPred, lhsLo, rhsLo, builder, loc); + Value eqAndLo = S_AND_B32::create(builder, loc, sregTy, hiEq, loCmp); + Value result = S_OR_B32::create(builder, loc, sregTy, hiCmp, eqAndLo); + op.replaceAllUsesWith(result); + } + } + op.erase(); + return success(); +} + +/// If the condition is a materialized boolean VGPR (from an i64 compare), +/// re-establish VCC right before the v_cndmask_b32 that consumes it. +/// Returns a VCC placeholder suitable for v_cndmask_b32's condition operand. 
+static Value ensureVCC(Value cond, OpBuilder &builder, Location loc) { + if (isSGPRType(cond.getType())) + cond = sgprToVgpr(cond, builder, loc); + if (isVGPRType(cond.getType())) { + auto zeroTy = ImmType::get(builder.getContext(), 0); + Value zero = ConstantOp::create(builder, loc, zeroTy, 0); + V_CMP_NE_I32::create(builder, loc, cond, zero); + } + auto placeholderTy = ImmType::get(builder.getContext(), 1); + return ConstantOp::create(builder, loc, placeholderTy, 1); +} + +/// i64 select: split both operands, select each half, merge. +static void legalizeSelectI64(Value trueVal, Value falseVal, Value cond, + ArithSelectOp op, OpBuilder &builder) { + Location loc = op.getLoc(); + auto [trueLo, trueHi] = splitI64(trueVal, builder, loc); + auto [falseLo, falseHi] = splitI64(falseVal, builder, loc); + + auto vregTy = VRegType::get(builder.getContext()); + falseLo = sgprToVgpr(falseLo, builder, loc); + trueLo = sgprToVgpr(trueLo, builder, loc); + falseHi = sgprToVgpr(falseHi, builder, loc); + trueHi = sgprToVgpr(trueHi, builder, loc); + + Value vccCond = ensureVCC(cond, builder, loc); + Value selLo = + V_CNDMASK_B32::create(builder, loc, vregTy, falseLo, trueLo, vccCond); + Value selHi = + V_CNDMASK_B32::create(builder, loc, vregTy, falseHi, trueHi, vccCond); + Value result = mergeI64(selLo, selHi, builder, loc); + op.replaceAllUsesWith(result); +} + +//===----------------------------------------------------------------------===// +// Dispatch functions (i32 vs i64) +//===----------------------------------------------------------------------===// + +static LogicalResult legalizeAdd(ArithAddOp op, OpBuilder &builder) { + int64_t width = getRegWidth(op.getLhs()); + if (failed(checkWidth(op, width))) + return failure(); + if (width == 2) + legalizeAddI64(op.getLhs(), op.getRhs(), op, builder); + else + legalizeAddI32(op.getLhs(), op.getRhs(), op, builder); + op.erase(); + return success(); +} + +static LogicalResult legalizeMul(ArithMulOp op, OpBuilder &builder) { + 
int64_t width = getRegWidth(op.getLhs()); + if (failed(checkWidth(op, width))) + return failure(); + if (width == 2) + legalizeMulI64(op.getLhs(), op.getRhs(), op, builder); + else + legalizeMulI32(op.getLhs(), op.getRhs(), op, builder); + op.erase(); + return success(); +} + +/// Legalize a generic bitwise op (or, and) to SALU/VALU concrete ops. +/// Template parameters: the ArithPseudoOp, and the four machine ops +/// (SALU i32, SALU i64, VALU i32, VALU i64). +template +static LogicalResult legalizeBitwiseOp(ArithOp op, OpBuilder &builder) { + int64_t width = getRegWidth(op.getLhs()); + if (failed(checkWidth(op, width))) + return failure(); + + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value result; + + if (width == 2) { + // i64: bitwise ops have native s_*_b64 / v_*_b64 instructions. + if (anyVGPR({lhs, rhs})) { + lhs = sgprToVgpr(lhs, builder, loc); + auto vregTy = VRegType::get(builder.getContext(), 2); + result = VALUOp64::create(builder, loc, vregTy, lhs, rhs); + } else { + auto sregTy = SRegType::get(builder.getContext(), 2, 2); + result = SALUOp64::create(builder, loc, sregTy, lhs, rhs); + } + } else { + if (anyVGPR({lhs, rhs})) { + lhs = sgprToVgpr(lhs, builder, loc); + auto vregTy = VRegType::get(builder.getContext()); + result = VALUOp32::create(builder, loc, vregTy, lhs, rhs); + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + result = SALUOp32::create(builder, loc, sregTy, lhs, rhs); + } + } + op.replaceAllUsesWith(result); + op.erase(); + return success(); +} + +static LogicalResult legalizeCmp(ArithCmpOp op, OpBuilder &builder) { + Location loc = op.getLoc(); + int64_t width = getRegWidth(op.getLhs()); + if (failed(checkWidth(op, width))) + return failure(); + + if (width == 2) + return legalizeCmpI64(op.getLhs(), op.getRhs(), op.getPredicate(), op, + builder); + + // i32 path. 
+ Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + auto pred = op.getPredicate(); + + if (anyVGPR({lhs, rhs})) { + if (isSGPRType(lhs.getType())) + lhs = sgprToVgpr(lhs, builder, loc); + + switch (pred) { + case CmpPredicate::eq: + V_CMP_EQ_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ne: + V_CMP_NE_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::slt: + V_CMP_LT_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::sle: + V_CMP_LE_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::sgt: + V_CMP_GT_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::sge: + V_CMP_GE_I32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ult: + V_CMP_LT_U32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ule: + V_CMP_LE_U32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::ugt: + V_CMP_GT_U32::create(builder, loc, lhs, rhs); + break; + case CmpPredicate::uge: + V_CMP_GE_U32::create(builder, loc, lhs, rhs); + break; + } + // VCC-setting compares have no explicit result. Create a placeholder + // constant for uses (v_cndmask_b32 reads VCC implicitly). 
+ auto immTy = ImmType::get(builder.getContext(), 1); + auto placeholder = ConstantOp::create(builder, loc, immTy, 1); + op.replaceAllUsesWith(placeholder.getResult()); + } else { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + Value result; + switch (pred) { + case CmpPredicate::eq: + result = S_CMP_EQ_I32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::ne: + result = S_CMP_NE_I32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::slt: + result = S_CMP_LT_I32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::sle: + result = S_CMP_LE_I32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::sgt: + result = S_CMP_GT_I32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::sge: + result = S_CMP_GE_I32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::ult: + result = S_CMP_LT_U32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::ule: + result = S_CMP_LE_U32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::ugt: + result = S_CMP_GT_U32::create(builder, loc, sregTy, lhs, rhs); + break; + case CmpPredicate::uge: + result = S_CMP_GE_U32::create(builder, loc, sregTy, lhs, rhs); + break; + } + op.replaceAllUsesWith(result); + } + op.erase(); + return success(); +} + +static LogicalResult legalizeSelect(ArithSelectOp op, OpBuilder &builder) { + Location loc = op.getLoc(); + Value falseVal = op.getFalseVal(); + Value trueVal = op.getTrueVal(); + Value cond = op.getCondition(); + int64_t width = getRegWidth(trueVal); + if (failed(checkWidth(op, width))) + return failure(); + + if (width == 2) { + legalizeSelectI64(trueVal, falseVal, cond, op, builder); + } else { + auto vregTy = VRegType::get(builder.getContext()); + falseVal = sgprToVgpr(falseVal, builder, loc); + trueVal = sgprToVgpr(trueVal, builder, loc); + Value vccCond = ensureVCC(cond, builder, loc); + auto sel = + V_CNDMASK_B32::create(builder, 
loc, vregTy, falseVal, trueVal, vccCond); + op.replaceAllUsesWith(sel.getResult()); + } + op.erase(); + return success(); +} + +static LogicalResult legalizeTrunc(ArithTruncOp op, OpBuilder &builder) { + Value src = op.getSrc(); + int64_t width = getRegWidth(src); + if (width < 2) { + // Already i32 or narrower -- pass through. + op.replaceAllUsesWith(src); + op.erase(); + return success(); + } + if (failed(checkWidth(op, width))) + return failure(); + + Location loc = op.getLoc(); + // For precolored SGPRs, create a precolored reference to the lo half. + if (auto psreg = src.getDefiningOp()) { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + Value lo = PrecoloredSRegOp::create(builder, loc, sregTy, psreg.getIndex(), + /*size=*/1); + op.replaceAllUsesWith(lo); + op.erase(); + return success(); + } + auto [lo, hi] = splitI64(src, builder, loc); + (void)hi; + op.replaceAllUsesWith(lo); + op.erase(); + return success(); +} + +static LogicalResult legalizeSExt(ArithSExtOp op, OpBuilder &builder) { + Value src = op.getSrc(); + int64_t srcWidth = getRegWidth(src); + if (srcWidth != 1) { + op.emitError("sext source must be i32 (got ") << srcWidth << " dwords)"; + return failure(); + } + + Location loc = op.getLoc(); + // hi = arithmetic shift right by 31 (sign-fill). + if (isSGPRType(src.getType())) { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + auto immTy = ImmType::get(builder.getContext(), 31); + Value shift = ConstantOp::create(builder, loc, immTy, 31); + Value hi = S_ASHR_I32::create(builder, loc, sregTy, src, shift); + Value result = mergeI64(src, hi, builder, loc); + op.replaceAllUsesWith(result); + } else { + auto vregTy = VRegType::get(builder.getContext()); + auto immTy = ImmType::get(builder.getContext(), 31); + Value shift = ConstantOp::create(builder, loc, immTy, 31); + // v_ashrrev_i32: dst = src >> shift (reversed operand order). 
+ Value hi = V_ASHRREV_I32::create(builder, loc, vregTy, shift, src); + Value result = mergeI64(src, hi, builder, loc); + op.replaceAllUsesWith(result); + } + op.erase(); + return success(); +} + +static LogicalResult legalizeZExt(ArithZExtOp op, OpBuilder &builder) { + Value src = op.getSrc(); + int64_t srcWidth = getRegWidth(src); + if (srcWidth != 1) { + op.emitError("zext source must be i32 (got ") << srcWidth << " dwords)"; + return failure(); + } + + Location loc = op.getLoc(); + // hi = 0. + if (isSGPRType(src.getType())) { + auto sregTy = SRegType::get(builder.getContext(), 1, 1); + auto immTy = ImmType::get(builder.getContext(), 0); + Value zero = ConstantOp::create(builder, loc, immTy, 0); + Value hi = S_MOV_B32::create(builder, loc, sregTy, zero); + Value result = mergeI64(src, hi, builder, loc); + op.replaceAllUsesWith(result); + } else { + auto vregTy = VRegType::get(builder.getContext()); + auto immTy = ImmType::get(builder.getContext(), 0); + Value zero = ConstantOp::create(builder, loc, immTy, 0); + Value hi = V_MOV_B32::create(builder, loc, vregTy, zero); + Value result = mergeI64(src, hi, builder, loc); + op.replaceAllUsesWith(result); + } + op.erase(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { + +struct ArithLegalizationPass + : waveasm::impl::WAVEASMArithLegalizationBase { + + void runOnOperation() override { + // Post-order walk is safe for in-place erasure of the current op. 
+ WalkResult walkResult = getOperation()->walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + + OpBuilder builder(op); + LogicalResult result = + TypeSwitch(op) + .Case([&](ArithAddOp o) { return legalizeAdd(o, builder); }) + .Case([&](ArithMulOp o) { return legalizeMul(o, builder); }) + .Case([&](ArithOrOp o) { + return legalizeBitwiseOp(o, builder); + }) + .Case([&](ArithAndOp o) { + return legalizeBitwiseOp(o, builder); + }) + .Case([&](ArithCmpOp o) { return legalizeCmp(o, builder); }) + .Case([&](ArithSelectOp o) { return legalizeSelect(o, builder); }) + .Case([&](ArithTruncOp o) { return legalizeTrunc(o, builder); }) + .Case([&](ArithSExtOp o) { return legalizeSExt(o, builder); }) + .Case([&](ArithZExtOp o) { return legalizeZExt(o, builder); }) + .Default([](Operation *) { return success(); }); + if (mlir::failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/AssemblyEmitter.cpp b/waveasm/lib/Transforms/AssemblyEmitter.cpp index 860f4c1273..f6ce8afb8b 100644 --- a/waveasm/lib/Transforms/AssemblyEmitter.cpp +++ b/waveasm/lib/Transforms/AssemblyEmitter.cpp @@ -166,9 +166,16 @@ std::string KernelGenerator::emitBufferStore(Operation *op, // Use VMEMStoreOpInterface to access operands by name if (auto storeOp = dyn_cast(op)) { std::string vdata = resolveValue(storeOp.getData()); - std::string voffset = resolveValue(storeOp.getVoffset()); + Value voffsetVal = storeOp.getVoffset(); std::string srd = resolveValue(storeOp.getSaddr()); - result += " " + vdata + ", " + voffset + ", " + srd + ", 0 offen"; + + if (isa(voffsetVal.getType())) { + result += " " + vdata + ", off, " + srd + ", 0"; + } else { + std::string voffset = resolveValue(voffsetVal); + result += " " + vdata + ", " + voffset + ", " + srd + ", 0 offen"; + } + if (auto instOffsetAttr = op->getAttrOfType("instOffset")) { 
int64_t offset = instOffsetAttr.getInt(); if (offset > 0) { @@ -533,16 +540,18 @@ std::optional KernelGenerator::generateOp(Operation *op) { std::string src = resolveValue(srcVal); std::string lines; if (isAGPR) { - // v_accvgpr_write_b32 requires a VGPR source in this backend. - // Materialize immediate sources into the reserved scratch VGPR. + auto [isLit, litVal] = getLiteralValue(srcVal); std::string writeSrc = src; - if (srcIsImm) { + if (srcIsImm && !(isLit && isInlineConstant(litVal))) { + // Non-inline literal: materialize into scratch VGPR first. lines += " v_mov_b32 " + formatVGPRRange(kScratchVGPR, 1) + ", " + src; writeSrc = formatVGPRRange(kScratchVGPR, 1); peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 1); invalidateScratchCache(); } + // Inline constants (e.g. 0) go directly into + // v_accvgpr_write_b32 without a scratch VGPR. for (int64_t i = 0; i < size; ++i) { if (!lines.empty()) lines += "\n"; @@ -563,6 +572,12 @@ std::optional KernelGenerator::generateOp(Operation *op) { if (isAGPR) { if (srcIsImm) { std::string src = resolveValue(srcVal); + auto [isLit, litVal] = getLiteralValue(srcVal); + if (isLit && isInlineConstant(litVal)) { + // Inline constant: emit directly without scratch VGPR. + return " v_accvgpr_write_b32 " + resolveValue(result) + ", " + + src; + } std::string scratch = formatVGPRRange(kScratchVGPR, 1); peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 1); invalidateScratchCache(); @@ -695,10 +710,7 @@ std::optional KernelGenerator::generateOp(Operation *op) { SmallVector handled(pendingCopies.size(), false); - // Allocate swap temps once and reuse across all swaps. - // Swaps are emitted sequentially so the temp is dead after - // each 3-instruction sequence and can be reused. - int64_t vSwapTemp = -1; + // SGPR swaps still need a temp; VGPR swaps use v_swap_b32. 
int64_t sSwapTemp = -1; for (size_t i = 0; i < pendingCopies.size(); ++i) { @@ -728,13 +740,7 @@ std::optional KernelGenerator::generateOp(Operation *op) { } int64_t regA = pendingCopies[i].dst; int64_t regB = pendingCopies[j].dst; - if (vSwapTemp < 0) { - vSwapTemp = peakVGPRs; - peakVGPRs = std::max(peakVGPRs, vSwapTemp + 1); - } - os << " v_mov_b32 v" << vSwapTemp << ", v" << regA << "\n"; - os << " v_mov_b32 v" << regA << ", v" << regB << "\n"; - os << " v_mov_b32 v" << regB << ", v" << vSwapTemp << "\n"; + os << " v_swap_b32 v" << regA << ", v" << regB << "\n"; handled[i] = true; handled[j] = true; break; @@ -842,14 +848,22 @@ std::optional KernelGenerator::generateOp(Operation *op) { .Case( [&](YieldOp) -> std::optional { return std::nullopt; }) .Case( + S_CMP_GE_U32, S_CMP_NE_U32, S_CMP_LT_I32, S_CMP_EQ_I32, + S_CMP_LE_I32, S_CMP_GT_I32, S_CMP_GE_I32, S_CMP_NE_I32>( [&](auto cmpOp) -> std::optional { llvm::StringRef opName = cmpOp->getName().getStringRef(); llvm::StringRef mnemonic = opName; if (opName.starts_with("waveasm.")) { mnemonic = opName.drop_front(8); } + // s_cmp_ne_* → s_cmp_lg_* (ISA mnemonic) + std::string mnemStr; + if (mnemonic.contains("_ne_")) { + mnemStr = mnemonic.str(); + size_t pos = mnemStr.find("_ne_"); + mnemStr.replace(pos, 4, "_lg_"); + mnemonic = mnemStr; + } llvm::SmallVector operands; for (Value operand : cmpOp->getOperands()) { operands.push_back(resolveValue(operand)); diff --git a/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp b/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp index e2672df944..185fc93b76 100644 --- a/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp +++ b/waveasm/lib/Transforms/BufferLoadStrengthReduction.cpp @@ -616,6 +616,53 @@ static void applyStrengthReduction(LoopOp loopOp) { loopOp.erase(); } +// Peephole: when a buffer_load has voffset = V_ADD_U32(vgpr, sgpr) and +// soffset = 0, fold the SGPR addend into soffset. 
This avoids a VALU +// instruction per load by using the hardware scalar offset field. +static void peepholeSoffsetFold(Operation *root) { + root->walk([&](Operation *op) { + if (!isBufferLoad(op) && !isBufferLoadLDS(op)) + return; + if (op->getNumOperands() < 3) + return; + + unsigned soffsetIdx = 2; + auto soffsetConst = getConstantValue(op->getOperand(soffsetIdx)); + if (!soffsetConst || *soffsetConst != 0) + return; + + unsigned voffsetIdx = getVoffsetIdx(op); + Value voffset = op->getOperand(voffsetIdx); + auto addOp = voffset.getDefiningOp(); + if (!addOp) + return; + + Value src0 = addOp.getSrc0(); + Value src1 = addOp.getSrc1(); + Value vgprPart = nullptr; + Value sgprPart = nullptr; + + if (isVGPRType(src0.getType()) && isSGPRType(src1.getType())) { + vgprPart = src0; + sgprPart = src1; + } else if (isSGPRType(src0.getType()) && isVGPRType(src1.getType())) { + vgprPart = src1; + sgprPart = src0; + } + if (!vgprPart || !sgprPart) + return; + + // Only fold if the V_ADD_U32 is used exclusively by buffer_loads. 
+ for (Operation *user : addOp->getUsers()) { + if (!isBufferLoad(user) && !isBufferLoadLDS(user)) + return; + } + + op->setOperand(voffsetIdx, vgprPart); + op->setOperand(soffsetIdx, sgprPart); + }); +} + struct BufferLoadStrengthReductionPass : public waveasm::impl::WAVEASMBufferLoadStrengthReductionBase< BufferLoadStrengthReductionPass> { @@ -628,6 +675,7 @@ struct BufferLoadStrengthReductionPass module->walk([&](LoopOp loopOp) { loops.push_back(loopOp); }); for (auto loopOp : loops) applyStrengthReduction(loopOp); + peepholeSoffsetFold(module); } }; diff --git a/waveasm/lib/Transforms/CMakeLists.txt b/waveasm/lib/Transforms/CMakeLists.txt index 9960bcdd6a..9770b31f0c 100644 --- a/waveasm/lib/Transforms/CMakeLists.txt +++ b/waveasm/lib/Transforms/CMakeLists.txt @@ -14,9 +14,11 @@ foreach(src ${HANDLERS_SRCS}) endforeach() add_mlir_dialect_library(MLIRWaveASMTransforms + ArithLegalization.cpp AssemblyEmitter.cpp BufferLoadStrengthReduction.cpp ExtractScalarization.cpp + GPUModuleToBinary.cpp HazardMitigation.cpp LinearScanPass.cpp LinearScanRegAlloc.cpp @@ -29,9 +31,12 @@ add_mlir_dialect_library(MLIRWaveASMTransforms Peephole.cpp RegionBuilder.cpp ScalePackElimination.cpp + SCCVerifier.cpp ScopedCSE.cpp Ticketing.cpp + TranslateFromLLVMDialect.cpp TranslateFromMLIR.cpp + VGPRCompaction.cpp ${HANDLERS_FULL_PATHS} ADDITIONAL_HEADER_DIRS @@ -47,8 +52,12 @@ add_mlir_dialect_library(MLIRWaveASMTransforms MLIRFuncDialect MLIRGPUDialect MLIRIR + MLIRLLVMDialect MLIRMathDialect MLIRMemRefDialect + MLIRPass + MLIRROCDLDialect + MLIRROCDLTarget MLIRSCFDialect MLIRSupport MLIRVectorDialect diff --git a/waveasm/lib/Transforms/GPUModuleToBinary.cpp b/waveasm/lib/Transforms/GPUModuleToBinary.cpp new file mode 100644 index 0000000000..f3c3afb062 --- /dev/null +++ b/waveasm/lib/Transforms/GPUModuleToBinary.cpp @@ -0,0 +1,141 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// WAVEASMGPUModuleToBinary: emit assembly, assemble, link, gpu.binary. +// +// Final pass in the WaveASM pipeline. Expects gpu.module ops containing +// waveasm.program ops (already register-allocated and scheduled). Emits +// AMDGCN assembly, assembles + links to HSACO, and replaces each gpu.module +// with a gpu.binary holding the code object. +//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Transforms/AssemblyEmitter.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/Target/LLVM/ROCDL/Utils.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" + +#define DEBUG_TYPE "waveasm-gpu-module-to-binary" + +using namespace mlir; + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMGPUMODULETOBINARY +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +namespace { + +static constexpr StringLiteral kTriple = "amdgcn-amd-amdhsa"; + +struct WAVEASMGPUModuleToBinaryPass + : waveasm::impl::WAVEASMGPUModuleToBinaryBase< + WAVEASMGPUModuleToBinaryPass> { + using WAVEASMGPUModuleToBinaryBase::WAVEASMGPUModuleToBinaryBase; + + void runOnOperation() override { + Operation *rootOp = getOperation(); + + // Collect gpu.modules that contain waveasm.program ops. + SmallVector gpuModules; + rootOp->walk([&](gpu::GPUModuleOp m) -> WalkResult { + if (!m.getOps().empty()) + gpuModules.push_back(m); + return WalkResult::skip(); + }); + + if (gpuModules.empty()) + return; + + // Step 1: Emit assembly for all programs. 
+ waveasm::PhysicalMapping mapping; + std::string asmText; + llvm::raw_string_ostream asmStream(asmText); + + WalkResult walkResult = + rootOp->walk([&](waveasm::ProgramOp program) { + size_t posBefore = asmText.size(); + if (failed(waveasm::writeAssembly(program, mapping, asmStream))) + return WalkResult::interrupt(); + if (asmText.size() == posBefore) { + program->emitError("assembly emission produced no output"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return signalPassFailure(); + + // Step 2: Assemble + link to HSACO. + ROCDL::SerializeGPUModuleBase::init(); + + auto emitError = [&]() -> InFlightDiagnostic { + return rootOp->emitError(); + }; + + std::string gpuArch = targetArch.getValue(); + FailureOr> objectCode = ROCDL::assembleIsa( + asmText, kTriple, gpuArch, /*features=*/"", emitError); + if (failed(objectCode)) + return signalPassFailure(); + + // Resolve lld path. + SmallString<128> actualLldPath(lldPath.getValue()); + if (actualLldPath.empty() || !llvm::sys::fs::exists(actualLldPath)) { + actualLldPath = ROCDL::getROCMPath(); + llvm::sys::path::append(actualLldPath, "llvm", "bin", "ld.lld"); + } + if (!llvm::sys::fs::exists(actualLldPath)) { + rootOp->emitError() + << "ld.lld not found (set --lld-path or ROCM_PATH). Tried: " + << actualLldPath; + return signalPassFailure(); + } + + FailureOr> hsaco = + ROCDL::linkObjectCode(*objectCode, actualLldPath, emitError); + if (failed(hsaco)) + return signalPassFailure(); + + // Step 3: Replace each gpu.module with gpu.binary. 
+ OpBuilder builder(rootOp->getContext()); + StringAttr binaryAttr = + builder.getStringAttr(StringRef(hsaco->data(), hsaco->size())); + + for (gpu::GPUModuleOp gpuModule : gpuModules) { + builder.setInsertionPointAfter(gpuModule); + + Attribute target; + if (gpuModule.getTargetsAttr() && !gpuModule.getTargetsAttr().empty()) { + target = gpuModule.getTargetsAttr()[0]; + if (gpuModule.getTargetsAttr().size() > 1) + gpuModule.emitWarning("multiple targets specified, only the first " + "is used; remaining targets are ignored"); + } + if (!target) + target = ROCDL::ROCDLTargetAttr::get(rootOp->getContext(), + /*optLevel=*/2, kTriple, gpuArch); + + auto objectAttr = builder.getAttr( + target, gpu::CompilationTarget::Binary, binaryAttr, + /*properties=*/DictionaryAttr{}, /*kernels=*/gpu::KernelTableAttr{}); + + gpu::BinaryOp::create(builder, gpuModule.getLoc(), gpuModule.getName(), + /*offloadingHandler=*/nullptr, + builder.getArrayAttr({objectAttr})); + gpuModule->erase(); + } + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp index ac9d79b9b2..2012a96d3e 100644 --- a/waveasm/lib/Transforms/LinearScanPass.cpp +++ b/waveasm/lib/Transforms/LinearScanPass.cpp @@ -34,7 +34,8 @@ namespace waveasm { } // namespace waveasm /// Convert a virtual register type to a physical register type. -/// Returns the original type unchanged if it's not a virtual register type +/// Also handles re-indexing an already-physical type to a new physReg. +/// Returns the original type unchanged if it's not a register type /// or if physReg < 0. 
static Type makePhysicalType(MLIRContext *ctx, Type virtualType, int64_t physReg) { @@ -46,9 +47,140 @@ static Type makePhysicalType(MLIRContext *ctx, Type virtualType, return PSRegType::get(ctx, physReg, sreg.getSize()); if (auto areg = dyn_cast(virtualType)) return PARegType::get(ctx, physReg, areg.getSize()); + if (auto pvreg = dyn_cast(virtualType)) + return PVRegType::get(ctx, physReg, pvreg.getSize()); + if (auto psreg = dyn_cast(virtualType)) + return PSRegType::get(ctx, physReg, psreg.getSize()); + if (auto pareg = dyn_cast(virtualType)) + return PARegType::get(ctx, physReg, pareg.getSize()); return virtualType; } +//===----------------------------------------------------------------------===// +// Rematerialization +//===----------------------------------------------------------------------===// + +/// Return true when \p op sits inside a LoopOp body at any nesting depth. +static bool isInsideLoopBody(Operation *op) { + for (Operation *p = op->getParentOp(); p; p = p->getParentOp()) { + if (isa(p)) + return true; + } + return false; +} + +/// Check if an operation can be cheaply rematerialized (cloned near each +/// use to shorten its live range). +/// +/// Accepted patterns: +/// - V_MOV_B32 with immediate operands (accumulator zero-init) +/// - V_MOV_B32 with SGPR source (scalar-to-vector address copy) +/// - S_MOV_B32 with immediate operands (scalar constant materialisation) +static bool isRematerializableOp(Operation *op) { + if (op->getNumResults() != 1) + return false; + + if (isa(op)) { + for (Value operand : op->getOperands()) { + auto *defOp = operand.getDefiningOp(); + if (!defOp) + return false; + if (isa(defOp)) + continue; + // SGPR source: the scalar value dominates the original def, so it + // also dominates every use site where we might place a clone. 
+ if (isSGPRType(operand.getType())) + continue; + return false; + } + return true; + } + + if (isa(op)) { + for (Value operand : op->getOperands()) { + auto *defOp = operand.getDefiningOp(); + if (!defOp || !isa(defOp)) + return false; + } + return true; + } + + return false; +} + +/// Rematerialize cheap-to-compute VGPR and SGPR values by cloning their defining ops +/// near each use site, shortening live ranges and reducing peak register +/// pressure at the cost of slightly increased code size. +/// +/// A `v_mov_b32 %v, 0` defined at instruction 5 and used at instruction 100 +/// holds a VGPR for 95 instructions. After rematerialization the clone +/// appears at instruction 99 with a 1-instruction live range, freeing the +/// VGPR(s) for the other 94 instructions. For multi-register results +/// (e.g. 4-wide accumulators) this frees 4 VGPRs across that span. +static void rematerializeCheapOps(ProgramOp program) { + constexpr int64_t kMinRematRangeLength = 10; + + LivenessInfo liveness = computeLiveness(program); + + llvm::SmallVector candidates; + program.walk([&](Operation *op) { + if (!isRematerializableOp(op)) + return; + Value result = op->getResult(0); + // Accept both VGPR and SGPR results (V_MOV_B32 -> VGPR, S_MOV_B32 -> SGPR). + if (!isVGPRType(result.getType()) && !isSGPRType(result.getType())) + return; + const LiveRange *range = liveness.getRange(result); + if (!range || range->length() <= kMinRematRangeLength) + return; + candidates.push_back(op); + }); + + for (Operation *op : candidates) { + Value result = op->getResult(0); + bool defOutsideLoop = !isInsideLoopBody(op); + bool isVGPR = isVGPRType(result.getType()); + + llvm::SmallVector uses; + for (OpOperand &use : result.getUses()) + uses.push_back(&use); + + if (uses.empty()) + continue; + + // Clone once per unique user operation to avoid redundant copies when + // the same op references the value in multiple operand slots. 
+ llvm::DenseMap cloneCache; + + for (OpOperand *use : uses) { + Operation *user = use->getOwner(); + + // Preserve VALU-free loop bodies: when a VGPR-producing op is defined + // outside a loop but a use is inside, skip that use. The value's live + // range already spans the entire loop body (Pass 2b extends it), so + // cloning inside wouldn't reduce in-loop pressure — it would only add + // a VALU instruction to the critical loop path. + // SALU ops (S_MOV_B32) don't use the VALU pipeline, so they're fine. + if (defOutsideLoop && isVGPR && isInsideLoopBody(user)) + continue; + + auto it = cloneCache.find(user); + if (it != cloneCache.end()) { + use->set(it->second); + } else { + OpBuilder builder(user); + Operation *clone = builder.clone(*op); + Value cloned = clone->getResult(0); + use->set(cloned); + cloneCache[user] = cloned; + } + } + + if (result.use_empty()) + op->erase(); + } +} + namespace { //===----------------------------------------------------------------------===// @@ -208,6 +340,12 @@ struct LinearScanPass } }); + // Rematerialize cheap VGPR ops (v_mov_b32 from immediates) near their + // use sites to shorten live ranges and reduce peak register pressure. + // This must run after duplicate-init-arg handling (which creates new + // V_MOV_B32 ops) and before allocation (which consumes the IR). + rematerializeCheapOps(program); + // Create allocator with precolored values and tied operands. // MFMA ties come from the local tiedPairs map; loop ties come from // the TiedValueClasses built during liveness analysis (see below). @@ -428,14 +566,27 @@ struct LinearScanPass } }); - // Also update if op result types from yield operand types + // Also update if op result types. + // Prefer the allocation mapping (which respects loop ties) over the + // then-yield operand type. When an if result feeds a loop init arg, + // the allocator ties it to the loop block arg and both receive the + // same physical register. 
The then-yield operand may carry a + // *different* physical register (from the inner loop), so copying it + // blindly would break the LoopLikeOpInterface verifier which requires + // exact type equality between init args and region iter_args. program.walk([&](IfOp ifOp) { auto &thenBlock = ifOp.getThenBlock(); - if (auto yieldOp = dyn_cast(thenBlock.getTerminator())) { - for (unsigned i = 0; i < ifOp->getNumResults(); ++i) { - if (i < yieldOp.getResults().size()) { - ifOp->getResult(i).setType(yieldOp.getResults()[i].getType()); - } + auto yieldOp = dyn_cast(thenBlock.getTerminator()); + if (!yieldOp) + return; + for (unsigned i = 0; i < ifOp->getNumResults(); ++i) { + Value res = ifOp->getResult(i); + int64_t physReg = mapping.getPhysReg(res); + if (physReg >= 0) { + res.setType( + makePhysicalType(ifOp->getContext(), res.getType(), physReg)); + } else if (i < yieldOp.getResults().size()) { + res.setType(yieldOp.getResults()[i].getType()); } } }); diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp index 88e38a98bd..429c6cf3b6 100644 --- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp +++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp @@ -70,15 +70,25 @@ static void insertActiveRange(llvm::SmallVectorImpl &active, active.insert(insertPos, newRange); } -/// Try to allocate a physical register from the pool. -/// Returns the allocated register index, or std::nullopt on failure. +/// Try to allocate a physical register from the pool (lowest-first). static std::optional tryAllocate(RegPool &pool, int64_t size, int64_t alignment) { int64_t physReg = (size == 1) ? pool.allocSingle() : pool.allocRange(size, alignment); - if (physReg < 0) { + if (physReg < 0) + return std::nullopt; + return physReg; +} + +/// Try to allocate from the top of a capped region (highest-first). 
+static std::optional tryAllocateFromTop(RegPool &pool, int64_t size, + int64_t alignment, + int64_t ceiling) { + int64_t physReg = (size == 1) + ? pool.allocSingleFromTop(ceiling) + : pool.allocRangeFromTop(size, alignment, ceiling); + if (physReg < 0) return std::nullopt; - } return physReg; } @@ -167,8 +177,31 @@ allocateRegClass(ArrayRef ranges, RegPool &pool, } } - // Allocate physical register(s) from the pool - physReg = tryAllocate(pool, range.size, range.alignment); + // For VGPRs with size > 1 (dwordx2/x4), allocate long-lived ranges + // from the top of the expected register usage and short-lived ranges + // from the bottom. This separates interleaved buffer_load (long-lived + // prefetch) and ds_read (short-lived consumed) destinations into + // contiguous regions, reducing fragmentation and peak VGPR count. + // The ceiling is maxPressure (peak simultaneous VGPRs from liveness), + // not maxRegs (512), to avoid allocating into the AGPR region. + bool useBidirectional = + pool.getRegClass() == RegClass::VGPR && range.size > 1; + if (useBidirectional) { + int64_t rangeLength = range.end - range.start; + // Ranges longer than 3/4 of ranges.back().end get allocated from the top. + // This targets buffer_load prefetch values (which span almost the + // entire loop body) while leaving ds_read values (consumed within + // one half) at the bottom. 
+ int64_t maxEnd = ranges.back().end; + int64_t threshold = (maxEnd * 3) / 4; + if (rangeLength > threshold) + physReg = tryAllocateFromTop(pool, range.size, range.alignment, + maxPressure); + else + physReg = tryAllocate(pool, range.size, range.alignment); + } + if (!physReg) + physReg = tryAllocate(pool, range.size, range.alignment); if (!physReg) { return program.emitOpError() diff --git a/waveasm/lib/Transforms/LiteralMaterialization.cpp b/waveasm/lib/Transforms/LiteralMaterialization.cpp index 8adea968cf..c090272a23 100644 --- a/waveasm/lib/Transforms/LiteralMaterialization.cpp +++ b/waveasm/lib/Transforms/LiteralMaterialization.cpp @@ -33,11 +33,12 @@ static bool isInlineConstant(int64_t val) { } static const llvm::StringSet<> &getVOP2Instructions() { + // v_add_u32, v_sub_u32, v_subrev_u32 are VOP3-only on GFX9+ (the VOP2 + // carry-producing variants are v_add_co_u32 / v_sub_co_u32 / + // v_subrev_co_u32). VOP3 does not support literal operands. static const llvm::StringSet<> kVOP2 = { - "v_add_u32", "v_sub_u32", "v_subrev_u32", "v_and_b32", - "v_or_b32", "v_xor_b32", "v_lshlrev_b32", "v_lshrrev_b32", - "v_ashrrev_i32", "v_max_u32", "v_min_u32", "v_add_i32", - "v_sub_i32", + "v_and_b32", "v_or_b32", "v_xor_b32", "v_lshlrev_b32", + "v_lshrrev_b32", "v_ashrrev_i32", "v_max_u32", "v_min_u32", }; return kVOP2; } diff --git a/waveasm/lib/Transforms/Liveness.cpp b/waveasm/lib/Transforms/Liveness.cpp index b4c3e987f8..f404c34693 100644 --- a/waveasm/lib/Transforms/Liveness.cpp +++ b/waveasm/lib/Transforms/Liveness.cpp @@ -31,44 +31,36 @@ static bool isSwapPatternIterArg(Value iterArg, Block &bodyBlock, unsigned i) { return false; } -/// Detect a write-after-read (WAR) hazard between a buffer_load iter_arg -/// and the block_arg it feeds back into. +/// Detect a write-after-read (WAR) hazard between an iter_arg and the +/// block_arg it feeds back into. 
/// -/// In pipelined schedules, next-iteration loads can be interleaved with -/// MFMAs that still consume the current iteration's block_arg. If the -/// allocator ties them to the same register, the MFMA silently reads the -/// new load value instead of the old one. +/// If the iter_arg is defined at a point where the block_arg still has +/// subsequent uses, tying them to the same physical register would let +/// the in-place update clobber a value that later instructions still need. /// -/// For single-element loads (buffer_load_ubyte/sbyte/ushort/sshort), the -/// block_arg is consumed indirectly through vector.bitcast / vector.extract -/// that share the same physical register. The direct use-point check -/// misses these transitive uses, so we unconditionally flag them. -static bool hasBufferLoadWARHazard(Value iterArg, Value blockArg, - const LivenessInfo &info) { +/// This is critical for unrolled loops where CSE can merge an +/// affine.apply result (e.g. arg8+2 for a bounds check) with the IV +/// increment (also arg8+2 for step=2). After merging, the single +/// S_ADD_U32 is placed in the middle of the body; if it shares a +/// register with the IV block_arg, the second unrolled copy's use of +/// the original IV reads the wrong value. +static bool hasWARHazard(Value iterArg, Value blockArg, + const LivenessInfo &info) { if (isa(iterArg)) return false; auto *defOp = iterArg.getDefiningOp(); if (!defOp) return false; - auto opName = defOp->getName().getStringRef(); - if (!opName.contains("buffer_load") || opName.contains("_lds")) - return false; - - // Single-element loads: unconditionally untie (transitive uses hidden). - if (opName.contains("_ubyte") || opName.contains("_sbyte") || - opName.contains("_ushort") || opName.contains("_sshort")) { - LLVM_DEBUG(llvm::dbgs() - << " WAR hazard (single-element load): " << opName << "\n"); - return true; - } - // Multi-register loads: check for def/use overlap. 
+ // Check for def/use overlap: if the iter_arg is defined at a point + // where block_arg still has subsequent uses, tying them creates a + // WAR hazard. auto iterDefIt = info.defPoints.find(iterArg); auto baUseIt = info.usePoints.find(blockArg); if (iterDefIt != info.defPoints.end() && baUseIt != info.usePoints.end()) { - int64_t loadDef = iterDefIt->second; + int64_t iterDef = iterDefIt->second; for (int64_t usePoint : baUseIt->second) { - if (usePoint >= loadDef) + if (usePoint > iterDef) return true; } } @@ -554,11 +546,11 @@ LivenessInfo computeLiveness(ProgramOp program) { // Condition iter_arg -> block arg. // Skip swap patterns and WAR hazards so the allocator keeps them - // in separate registers (see hasBufferLoadWARHazard). + // in separate registers (see hasWARHazard). if (i < condOp.getIterArgs().size()) { Value iterArg = condOp.getIterArgs()[i]; bool skip = isSwapPatternIterArg(iterArg, bodyBlock, i) || - hasBufferLoadWARHazard(iterArg, blockArg, info); + hasWARHazard(iterArg, blockArg, info); if (!skip && info.ranges.contains(iterArg)) members.push_back(iterArg); } @@ -610,7 +602,7 @@ LivenessInfo computeLiveness(ProgramOp program) { if (i < condOp.getIterArgs().size()) { Value iterArg = condOp.getIterArgs()[i]; bool skip = isSwapPatternIterArg(iterArg, bodyBlock, i) || - hasBufferLoadWARHazard(iterArg, blockArg, info); + hasWARHazard(iterArg, blockArg, info); if (!skip && info.ranges.contains(iterArg) && !tc.tiedPairs.contains(iterArg)) tc.tiedPairs[iterArg] = blockArg; @@ -634,11 +626,19 @@ LivenessInfo computeLiveness(ProgramOp program) { } } - // Sort by (start, end) for linear scan + // Sort by (start, -size, -alignment, -end) for linear scan. + // At the same start point, allocate larger and more constrained ranges + // first to reduce fragmentation — a 16-wide aligned range has very few + // valid slots, so giving it priority prevents smaller ranges from + // blocking its only valid position. 
auto sortByStart = [](const LiveRange &a, const LiveRange &b) { if (a.start != b.start) return a.start < b.start; - return a.end < b.end; + if (a.size != b.size) + return a.size > b.size; + if (a.alignment != b.alignment) + return a.alignment > b.alignment; + return a.end > b.end; }; llvm::sort(info.vregRanges, sortByStart); @@ -650,6 +650,11 @@ LivenessInfo computeLiveness(ProgramOp program) { info.maxSRegPressure = computeMaxPressure(info.sregRanges, info.tiedClasses); info.maxARegPressure = computeMaxPressure(info.aregRanges, info.tiedClasses); + // Dump detailed pressure breakdown for debugging. + dumpPeakPressureInfo(info, ops, RegClass::VGPR); + dumpPeakPressureInfo(info, ops, RegClass::SGPR); + dumpPeakPressureInfo(info, ops, RegClass::AGPR); + return info; } diff --git a/waveasm/lib/Transforms/MetadataEmitter.cpp b/waveasm/lib/Transforms/MetadataEmitter.cpp index 5e21940f28..db3ff8de72 100644 --- a/waveasm/lib/Transforms/MetadataEmitter.cpp +++ b/waveasm/lib/Transforms/MetadataEmitter.cpp @@ -99,7 +99,7 @@ llvm::SmallVector MetadataEmitter::emitPrologue() { lines.push_back(".text"); lines.push_back(""); - std::string symName = program.getSymName().str(); + std::string symName = getKernelName(program).str(); lines.push_back(".protected " + symName); lines.push_back(".globl " + symName); lines.push_back(".p2align 8"); @@ -176,7 +176,7 @@ llvm::SmallVector MetadataEmitter::emitKernelDescriptor(int64_t peakVGPRs, int64_t peakSGPRs, int64_t peakAGPRs, int64_t ldsSize) { llvm::SmallVector lines; - std::string symName = program.getSymName().str(); + std::string symName = getKernelName(program).str(); bool usesWorkgroupIdX, usesWorkgroupIdY, usesWorkgroupIdZ, usesWorkitemId; scanSystemRegisterUsage(program, usesWorkgroupIdX, usesWorkgroupIdY, @@ -271,12 +271,16 @@ MetadataEmitter::emitKernelDescriptor(int64_t peakVGPRs, int64_t peakSGPRs, lines.push_back(" .amdhsa_next_free_vgpr " + std::to_string(nextFreeVGPR)); lines.push_back(" .amdhsa_next_free_sgpr " + 
std::to_string(nextFreeSGPR)); + // Always enable all workgroup IDs when any is used. + // This matches the real LLVM backend and ensures the SGPR layout + // is predictable (base+0=x, base+1=y, base+2=z) without gaps. + bool anyWgId = usesWorkgroupIdX || usesWorkgroupIdY || usesWorkgroupIdZ; lines.push_back(" .amdhsa_system_sgpr_workgroup_id_x " + - std::to_string(usesWorkgroupIdX ? 1 : 0)); + std::to_string(anyWgId ? 1 : 0)); lines.push_back(" .amdhsa_system_sgpr_workgroup_id_y " + - std::to_string(usesWorkgroupIdY ? 1 : 0)); + std::to_string(anyWgId ? 1 : 0)); lines.push_back(" .amdhsa_system_sgpr_workgroup_id_z " + - std::to_string(usesWorkgroupIdZ ? 1 : 0)); + std::to_string(anyWgId ? 1 : 0)); // Derive system_vgpr_workitem_id from workgroup dimensions. // When wgZ > 1, hardware provides thread IDs in v0 (x), v1 (y), v2 (z). @@ -312,7 +316,7 @@ llvm::SmallVector MetadataEmitter::emitMetadataYAML(int64_t peakVGPRs, int64_t peakSGPRs, int64_t peakAGPRs, int64_t ldsSize) { llvm::SmallVector lines; - std::string symName = program.getSymName().str(); + std::string symName = getKernelName(program).str(); lines.push_back(".amdgpu_metadata"); lines.push_back("---"); diff --git a/waveasm/lib/Transforms/SCCVerifier.cpp b/waveasm/lib/Transforms/SCCVerifier.cpp new file mode 100644 index 0000000000..b65fc45a42 --- /dev/null +++ b/waveasm/lib/Transforms/SCCVerifier.cpp @@ -0,0 +1,138 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// SCC Verifier Pass +// +// Verifies that no SCC-clobbering SALU instruction is placed between an +// SCC-producing op and its consumer. 
Uses isa<> checks instead of +// hasTrait() because adding traits to existing op classes changes +// ODS-generated C++ and causes MLIR passes to produce different IR. +//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMInterfaces.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "waveasm-scc-verifier" + +using namespace mlir; +using namespace waveasm; + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMSCCVERIFIER +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +namespace { + +/// Returns true if the operation writes the SCC flag on hardware. +/// Uses hasTrait for ops that carry the trait (carry ops, cmp ops), +/// and isa<> for ops still on SALUBinaryOp/SALUUnaryOp (pending migration). 
+static bool writesSCC(Operation *op) { + return op->hasTrait(); +} + +struct SCCVerifierPass + : public waveasm::impl::WAVEASMSCCVerifierBase { + using WAVEASMSCCVerifierBase::WAVEASMSCCVerifierBase; + + void runOnOperation() override { + Operation *module = getOperation(); + unsigned errorCount = 0; + module->walk([&](ProgramOp program) { + for (Block &block : program.getBody()) + errorCount += verifyBlock(block); + }); + if (errorCount > 0) { + LLVM_DEBUG(llvm::dbgs() << "SCC verifier: found " << errorCount + << " SCC hazard(s)\n"); + signalPassFailure(); + } + } + +private: + static SmallVector findSCCClobbersBetween(Operation *producer, + Operation *consumer) { + SmallVector clobbers; + if (!producer || !consumer || producer->getBlock() != consumer->getBlock()) + return clobbers; + bool inRange = false; + for (Operation &op : *producer->getBlock()) { + if (&op == producer) { inRange = true; continue; } + if (&op == consumer) break; + if (inRange && writesSCC(&op)) + clobbers.push_back(&op); + } + return clobbers; + } + + static void emitSCCClobberError(Operation *consumer, Operation *producer, + ArrayRef clobbers) { + auto diag = consumer->emitError() + << "SCC hazard: " << clobbers.size() + << " SCC-clobbering op(s) between SCC producer '" + << producer->getName() << "' and consumer '" + << consumer->getName() << "'"; + for (Operation *c : clobbers) + diag.attachNote(c->getLoc()) + << "SCC clobbered here by '" << c->getName() << "'"; + diag.attachNote(producer->getLoc()) << "SCC defined here"; + } + + unsigned verifyBlock(Block &block) { + unsigned errors = 0; + Operation *lastSCCWriter = nullptr; + for (Operation &op : block) { + if (auto condOp = dyn_cast(&op)) { + Value cond = condOp.getCondition(); + Operation *condDef = cond.getDefiningOp(); + if (condDef && lastSCCWriter && lastSCCWriter != condDef) { + auto clobbers = findSCCClobbersBetween(condDef, &op); + if (!clobbers.empty()) { + emitSCCClobberError(&op, condDef, clobbers); + ++errors; + } + } + } 
+ if (auto ifOp = dyn_cast(&op)) { + Value cond = ifOp.getCondition(); + Operation *condDef = cond.getDefiningOp(); + if (condDef && lastSCCWriter && lastSCCWriter != condDef) { + auto clobbers = findSCCClobbersBetween(condDef, &op); + if (!clobbers.empty()) { + emitSCCClobberError(&op, condDef, clobbers); + ++errors; + } + } + } + if (isa(&op) && !lastSCCWriter) { + op.emitError() + << "SCC hazard: s_cselect_b32 has no preceding SCC writer"; + ++errors; + } + if (isa(&op) && !lastSCCWriter) { + op.emitError() + << "SCC hazard: s_addc_u32 has no preceding SCC writer"; + ++errors; + } + if (writesSCC(&op)) + lastSCCWriter = &op; + for (Region ®ion : op.getRegions()) + for (Block &nestedBlock : region) + errors += verifyBlock(nestedBlock); + } + return errors; + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/ScopedCSE.cpp b/waveasm/lib/Transforms/ScopedCSE.cpp index 4c589060fa..0db4f14860 100644 --- a/waveasm/lib/Transforms/ScopedCSE.cpp +++ b/waveasm/lib/Transforms/ScopedCSE.cpp @@ -155,6 +155,11 @@ bool isCSEEligible(Operation *op) { if (op->hasTrait()) return false; + // SCC-reading ops are NOT CSE-eligible: their result depends on implicit + // SCC state, so two ops with identical operands can produce different results. + if (op->hasTrait()) + return false; + // Precolored registers are not eligible (they're fixed) if (isa(op)) return false; diff --git a/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp b/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp new file mode 100644 index 0000000000..c02dc634c5 --- /dev/null +++ b/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp @@ -0,0 +1,776 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// TranslateFromLLVM: Strict LLVM dialect -> WaveASM translation. +// +// Consumes gpu.module { llvm.func @kernel ... } with rocdl intrinsics. +// Fails on any unhandled op — no silent fallthrough. +//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Dialect/WaveASMTypes.h" +#include "waveasm/Transforms/AssemblyEmitter.h" +#include "waveasm/Transforms/Passes.h" +#include "waveasm/Transforms/TranslateFromMLIR.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "waveasm-translate-llvm" + +using namespace mlir; + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMTRANSLATEFROMLLVM +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +namespace waveasm { + +//===----------------------------------------------------------------------===// +// LLVM Translation State +//===----------------------------------------------------------------------===// + +// AMDGPU SRD (Shader Resource Descriptor) constants. +// SRD is 4 consecutive SGPRs: [base_lo, base_hi|stride, num_records, flags]. + +/// Mask for SRD word 1 to keep only base_addr[47:32] (lower 16 bits). +static constexpr int64_t kSRDWord1BaseMask = 0xFFFF; +/// Default SRD word 3 flags set by the prologue (OOB_SELECT=2). +static constexpr int64_t kSRDDefaultFlags = 0x20000; +/// Default num_records when buffer size is unknown (max 4-byte-aligned value). 
+static constexpr int64_t kSRDDefaultNumRecords = 0x7FFFFFFC; + +/// Tracks decomposed buffer pointer info from GEP operations. +/// A GEP on ptr<7> decomposes into (SRD, byte-offset-vgpr). +/// TODO: consider a separate decomposition pass for ptr<7>. +struct BufferPtrInfo { + Value srd; // The SRD (4×SGPR) from rocdl.make.buffer.rsrc. + Value voffset; // Byte offset VGPR. +}; + +/// State for LLVM->WaveASM translation, layered on top of TranslationContext. +class LLVMTranslationState { +public: + explicit LLVMTranslationState(TranslationContext &ctx) : ctx(ctx) {} + + TranslationContext &ctx; + + /// Map rocdl.make.buffer.rsrc result -> SRD SGPR value from prologue. + void mapBufferRsrc(Value rsrc, Value srd) { rsrcToSRD[rsrc] = srd; } + Value lookupSRD(Value rsrc) const { return rsrcToSRD.lookup(rsrc); } + + /// Map GEP result -> decomposed (SRD, voffset). + void mapGEP(Value gep, BufferPtrInfo info) { gepMap[gep] = info; } + std::optional lookupGEP(Value gep) const { + auto it = gepMap.find(gep); + if (it != gepMap.end()) + return it->second; + return std::nullopt; + } + + /// Track base-pointer byte offset from bare-pointer GEPs. + /// These offsets accumulate and get added to voffset when the pointer + /// is used via make.buffer.rsrc + buffer GEP. + void setBaseOffset(Value ptr, Value offset) { baseOffsets[ptr] = offset; } + Value lookupBaseOffset(Value ptr) const { return baseOffsets.lookup(ptr); } + +private: + DenseMap rsrcToSRD; + DenseMap gepMap; + DenseMap baseOffsets; +}; + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +/// Extract workgroup size from llvm.func attributes. +/// Returns failure if both gpu.known_block_size and +/// rocdl.reqd_work_group_size are present and disagree. 
+static FailureOr> +getWorkgroupSize(LLVM::LLVMFuncOp func) { + auto gpuAttr = func->getAttrOfType("gpu.known_block_size"); + auto rocdlAttr = + func->getAttrOfType("rocdl.reqd_work_group_size"); + + if (gpuAttr && rocdlAttr && gpuAttr.asArrayRef() != rocdlAttr.asArrayRef()) + return func->emitOpError("contradicting workgroup size attributes: " + "gpu.known_block_size and " + "rocdl.reqd_work_group_size disagree"); + + DenseI32ArrayAttr attr = gpuAttr ? gpuAttr : rocdlAttr; + if (!attr) + return std::tuple{64, 1, 1}; + + auto vals = attr.asArrayRef(); + int64_t x = vals.size() > 0 ? vals[0] : 64; + int64_t y = vals.size() > 1 ? vals[1] : 1; + int64_t z = vals.size() > 2 ? vals[2] : 1; + return std::tuple{x, y, z}; +} + +/// Create a waveasm.program from an llvm.func kernel. +static ProgramOp createProgramFromLLVMFunc(LLVM::LLVMFuncOp func, + OpBuilder &builder, + StringRef targetId) { + auto *mlirCtx = builder.getContext(); + auto loc = func.getLoc(); + + // Code object version 5: supports kernel argument preloading. + auto targetAttr = + TargetAttr::get(mlirCtx, getTargetKindAttr(mlirCtx, targetId), + /*code_object_version=*/5); + auto abiAttr = KernelABIAttr::get(mlirCtx, /*tid=*/0, /*kernarg=*/0, + /*wg_id_x=*/std::nullopt, + /*wg_id_y=*/std::nullopt, + /*wg_id_z=*/std::nullopt); + + FailureOr> wgSize = + getWorkgroupSize(func); + if (failed(wgSize)) + return {}; + auto [wgX, wgY, wgZ] = *wgSize; + std::array sizes = {builder.getI64IntegerAttr(wgX), + builder.getI64IntegerAttr(wgY), + builder.getI64IntegerAttr(wgZ)}; + + // Mangle the program name to avoid symbol collision with the original + // llvm.func (which we keep alive for gpu.launch_func verification). + // Store the original kernel name for assembly emission. 
+ std::string programName = (func.getName() + "__waveasm").str(); + auto program = + ProgramOp::create(builder, loc, programName, targetAttr, abiAttr, + /*vgprs=*/int64_t{256}, + /*sgprs=*/int64_t{104}, + /*workgroup_size=*/builder.getArrayAttr(sizes), + /*lds_size=*/IntegerAttr{}); + + program->setAttr(WaveASMDialect::getKernelNameAttrName(), + builder.getStringAttr(func.getName())); + assert(!program.getBody().empty() && "ProgramOp builder must create a block"); + return program; +} + +/// Look up the WaveASM value that an LLVM value was translated to. +/// Returns failure if the value was never mapped -- silently returning +/// the original (soon-to-be-erased) LLVM value is a use-after-free bug. +static FailureOr resolve(Value v, TranslationContext &ctx) { + if (auto mapped = ctx.getMapper().getMapped(v)) + return *mapped; + // Block arguments (func params) are mapped during prologue setup. + // If we get here, an LLVM op was skipped or handled incorrectly. + return failure(); +} + +/// Truncate an i64 WaveASM value to i32 via an arith.trunc pseudo-op. +/// Returns the value unchanged if the LLVM source type is already <= 32 bits. +static Value truncToI32(Value v, Type llvmType, OpBuilder &builder, + Location loc, TranslationContext &ctx) { + auto intTy = dyn_cast(llvmType); + if (!intTy || intTy.getWidth() <= 32) + return v; + Type resTy = isVGPRType(v.getType()) ? (Type)ctx.createVRegType() + : (Type)ctx.createSRegType(); + return ArithTruncOp::create(builder, loc, resTy, v); +} + +/// Infer the pseudo-op result type from operand types. +/// If any operand is VGPR -> VReg; otherwise SReg. 
+static Type inferResultType(ValueRange operands, TranslationContext &ctx) { + for (Value v : operands) + if (isVGPRType(v.getType())) + return ctx.createVRegType(); + return ctx.createSRegType(); +} + +//===----------------------------------------------------------------------===// +// Op handlers +//===----------------------------------------------------------------------===// + +static LogicalResult handlePoison(LLVM::PoisonOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + // Poison is undefined -- materialize as zero. Must be mapped because + // downstream ops (e.g. GEP, add) may reference the poison result via + // resolve(), which now requires every LLVM value to have a mapping. + auto intType = dyn_cast(op.getResult().getType()); + if (!intType) + return op->emitOpError("expected integer poison"); + + auto immTy = ctx.createImmType(0); + auto zeroImm = ConstantOp::create(builder, loc, immTy, int64_t{0}); + auto vregTy = ctx.createVRegType(); + + if (intType.getWidth() <= 32) { + Value mov = V_MOV_B32::create(builder, loc, vregTy, zeroImm); + ctx.getMapper().mapValue(op.getResult(), mov); + return success(); + } + + if (intType.getWidth() <= 64) { + Value loMov = V_MOV_B32::create(builder, loc, vregTy, zeroImm); + Value hiMov = V_MOV_B32::create(builder, loc, vregTy, zeroImm); + auto wideTy = ctx.createVRegType(2, 2); + Value packed = + PackOp::create(builder, loc, wideTy, ValueRange{loMov, hiMov}); + ctx.getMapper().mapValue(op.getResult(), packed); + return success(); + } + + return op->emitOpError("unsupported poison width (expected i32 or i64)"); +} + +static LogicalResult handleConstant(LLVM::ConstantOp op, + LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + auto valAttr = op.getValue(); + int64_t intVal = 0; + if (auto intAttr = dyn_cast(valAttr)) + intVal = intAttr.getValue().getSExtValue(); + else + return 
op->emitOpError("unsupported constant type"); + + auto intType = dyn_cast(op.getResult().getType()); + if (!intType) + return op->emitOpError("expected integer constant"); + + if (intType.getWidth() <= 32) { + auto immTy = ctx.createImmType(intVal); + auto immOp = ConstantOp::create(builder, loc, immTy, intVal); + auto vregTy = ctx.createVRegType(); + auto mov = V_MOV_B32::create(builder, loc, vregTy, immOp); + ctx.getMapper().mapValue(op.getResult(), mov); + return success(); + } + + if (intType.getWidth() <= 64) { + // Split i64 constant into lo/hi halves and pack into vreg<2>. + int32_t lo = static_cast(intVal & 0xFFFFFFFF); + int32_t hi = static_cast(static_cast(intVal) >> 32); + auto vregTy = ctx.createVRegType(); + auto loImm = ConstantOp::create(builder, loc, ctx.createImmType(lo), lo); + auto hiImm = ConstantOp::create(builder, loc, ctx.createImmType(hi), hi); + Value loMov = V_MOV_B32::create(builder, loc, vregTy, loImm); + Value hiMov = V_MOV_B32::create(builder, loc, vregTy, hiImm); + auto wideTy = ctx.createVRegType(2, 2); + Value packed = + PackOp::create(builder, loc, wideTy, ValueRange{loMov, hiMov}); + ctx.getMapper().mapValue(op.getResult(), packed); + return success(); + } + + return op->emitOpError("unsupported constant width (expected i32 or i64)"); +} + +static LogicalResult handleThreadIdX(ROCDL::ThreadIdXOp op, + LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + // rocdl.workitem.id.x -> hardware v0 (flat workitem ID). + ctx.setUsesWorkitemId(true); + auto vregTy = ctx.createVRegType(); + auto v0 = PrecoloredVRegOp::create(builder, loc, vregTy, /*regIndex=*/0, + /*size=*/1); + ctx.getMapper().mapValue(op.getResult(), v0); + return success(); +} + +// rocdl.workgroup.id.{x,y,z} -> system SGPRs (set by hardware dispatch). 
+template +static LogicalResult handleWorkgroupId(OpTy op, LLVMTranslationState &st, + int dimIndex) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + if (dimIndex == 0) + ctx.setUsesWorkgroupIdX(true); + else if (dimIndex == 1) + ctx.setUsesWorkgroupIdY(true); + else + ctx.setUsesWorkgroupIdZ(true); + + int64_t sgprIndex = ctx.getWorkgroupIdSgprIndex(dimIndex); + auto sregType = ctx.createSRegType(); + auto blockId = PrecoloredSRegOp::create(builder, loc, sregType, sgprIndex, 1); + ctx.getMapper().mapValue(op.getResult(), blockId); + return success(); +} + +// Emit arith pseudo-ops for i32/i64 casts -- legalization pass handles width. +/// Translate an LLVM cast op to a WaveASM arithmetic pseudo-op. +template +static LogicalResult handleCastOp(LLVMOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + FailureOr src = resolve(op.getOperand(), ctx); + if (failed(src)) + return op->emitOpError("unmapped operand in cast"); + Type resTy = isVGPRType(src->getType()) ? (Type)ctx.createVRegType() + : (Type)ctx.createSRegType(); + Value pseudo = WaveASMOp::create(builder, op.getLoc(), resTy, *src); + ctx.getMapper().mapValue(op.getResult(), pseudo); + return success(); +} + +/// Map LLVM ICmpPredicate to WaveASM CmpPredicate. 
+static CmpPredicate mapLLVMPredicate(LLVM::ICmpPredicate pred) { + using LP = LLVM::ICmpPredicate; + switch (pred) { + case LP::eq: + return CmpPredicate::eq; + case LP::ne: + return CmpPredicate::ne; + case LP::slt: + return CmpPredicate::slt; + case LP::sle: + return CmpPredicate::sle; + case LP::sgt: + return CmpPredicate::sgt; + case LP::sge: + return CmpPredicate::sge; + case LP::ult: + return CmpPredicate::ult; + case LP::ule: + return CmpPredicate::ule; + case LP::ugt: + return CmpPredicate::ugt; + case LP::uge: + return CmpPredicate::uge; + } + llvm_unreachable("unhandled LLVM ICmpPredicate"); +} + +static LogicalResult handleICmp(LLVM::ICmpOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + FailureOr lhs = resolve(op.getLhs(), ctx); + FailureOr rhs = resolve(op.getRhs(), ctx); + if (failed(lhs) || failed(rhs)) + return op->emitOpError("unmapped operand in icmp"); + Type resTy = inferResultType({*lhs, *rhs}, ctx); + auto pred = mapLLVMPredicate(op.getPredicate()); + Value cmp = ArithCmpOp::create(builder, op.getLoc(), resTy, pred, *lhs, *rhs); + ctx.getMapper().mapValue(op.getResult(), cmp); + return success(); +} + +static LogicalResult handleSelect(LLVM::SelectOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + FailureOr cond = resolve(op.getCondition(), ctx); + FailureOr trueVal = resolve(op.getTrueValue(), ctx); + FailureOr falseVal = resolve(op.getFalseValue(), ctx); + if (failed(cond) || failed(trueVal) || failed(falseVal)) + return op->emitOpError("unmapped operand in select"); + Type resTy = inferResultType({*trueVal, *falseVal}, ctx); + // ODS declaration order: (falseVal, trueVal, condition). + Value sel = ArithSelectOp::create(builder, op.getLoc(), resTy, *falseVal, + *trueVal, *cond); + ctx.getMapper().mapValue(op.getResult(), sel); + return success(); +} + +/// Translate an LLVM binary op to a WaveASM arithmetic pseudo-op. 
+/// Width validation is deferred to ArithLegalization. +template +static LogicalResult handleBinaryOp(LLVMOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + FailureOr lhs = resolve(op.getLhs(), ctx); + FailureOr rhs = resolve(op.getRhs(), ctx); + if (failed(lhs) || failed(rhs)) + return op->emitOpError("unmapped operand in binary op"); + Type resTy = inferResultType({*lhs, *rhs}, ctx); + Value result = WaveASMOp::create(builder, op.getLoc(), resTy, *lhs, *rhs); + ctx.getMapper().mapValue(op.getResult(), result); + return success(); +} + +static LogicalResult handleMakeBufferRsrc(ROCDL::MakeBufferRsrcOp op, + LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + // The base pointer was set up as an SRD in the prologue via queueSRDSetup. + Value basePtr = op.getBase(); + auto srdVal = ctx.getMapper().getMapped(basePtr); + if (!srdVal) + return op->emitOpError("SRD not found for base pointer"); + + // The prologue used s_mov_b64 to copy the 64-bit pointer into SRD[0:1]. + // This corrupts SRD word 1 bits [31:16] (stride/swizzle) with pointer bits. + // Also, the prologue hardcodes SRD[3]=0x20000 but make.buffer.rsrc may + // want different flags. Patch both now that we know the actual values. + auto srdOp = dyn_cast(srdVal->getDefiningOp()); + if (srdOp) { + int64_t srdBase = srdOp.getIndex(); + + // TODO: Replace RawOp with typed S_AND_B32/S_MOV_B32 ops. Blocked on + // regalloc supporting contiguous allocation constraints for PackOp + // inputs, so SRD sub-registers can be addressed without hardcoded + // register strings. + + // Clear stride/swizzle bits in SRD word 1 (keep only base_addr[47:32]). + std::string andStr = "s_and_b32 s" + std::to_string(srdBase + 1) + ", s" + + std::to_string(srdBase + 1) + ", 0x" + + llvm::utohexstr(kSRDWord1BaseMask); + RawOp::create(builder, loc, andStr); + + // Patch SRD[3] with the actual flags from make.buffer.rsrc. 
+ auto flags = getConstantIntValue(op.getFlags()); + if (flags && *flags != kSRDDefaultFlags) { + std::string movFlags = "s_mov_b32 s" + std::to_string(srdBase + 3) + + ", 0x" + llvm::utohexstr(*flags); + RawOp::create(builder, loc, movFlags); + } + } + + st.mapBufferRsrc(op.getResult(), *srdVal); + + // Propagate any base offset from bare-pointer GEPs so buffer GEPs + // can add it to their voffset. + Value baseOff = st.lookupBaseOffset(basePtr); + if (baseOff) + st.setBaseOffset(op.getResult(), baseOff); + + return success(); +} + +static LogicalResult handleGEP(LLVM::GEPOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + Value base = op.getBase(); + + // GEP index is a dynamic Value (not a constant attr). + auto indices = op.getIndices(); + if (indices.size() != 1) + return op->emitOpError("GEP with multiple indices not yet supported"); + auto idx = indices[0].dyn_cast(); + if (!idx) + return op->emitOpError("GEP with constant index attr not yet supported"); + + FailureOr resolved = resolve(idx, ctx); + if (failed(resolved)) + return op->emitOpError("unmapped GEP index"); + Value newOffset = *resolved; + + // Buffer voffsets are 32-bit. Truncate i64 GEP indices. + newOffset = truncToI32(newOffset, idx.getType(), builder, loc, ctx); + + // Bare-pointer GEP (!llvm.ptr, not <7>): pointer arithmetic before + // make.buffer.rsrc. Propagate the mapper entry and accumulate + // the byte offset so it can be added to voffset at load/store time. + auto baseTy = op.getBase().getType(); + unsigned addrSpace = 0; + if (auto ptrTy = dyn_cast(baseTy)) + addrSpace = ptrTy.getAddressSpace(); + + if (addrSpace != 0 && addrSpace != 7) + return op->emitOpError("unsupported address space ") << addrSpace; + + if (addrSpace == 0) { + // Forward mapper entry so make.buffer.rsrc can find the SRD. 
+ if (auto mapped = ctx.getMapper().getMapped(base)) + ctx.getMapper().mapValue(op.getResult(), *mapped); + + // Accumulate base offset. + Value prevOffset = st.lookupBaseOffset(base); + if (prevOffset) { + auto vregTy = ctx.createVRegType(); + newOffset = + V_ADD_U32::create(builder, loc, vregTy, prevOffset, newOffset); + } + st.setBaseOffset(op.getResult(), newOffset); + return success(); + } + + // Buffer GEP (ptr<7>): decompose into (SRD, voffset). + auto srd = st.lookupSRD(base); + if (srd) { + // Check if the make.buffer.rsrc had a base offset from bare-pointer GEPs. + Value baseOff = st.lookupBaseOffset(base); + if (baseOff) { + auto vregTy = ctx.createVRegType(); + newOffset = V_ADD_U32::create(builder, loc, vregTy, baseOff, newOffset); + } + st.mapGEP(op.getResult(), {srd, newOffset}); + return success(); + } + + std::optional baseGEP = st.lookupGEP(base); + if (!baseGEP) + return op->emitOpError("GEP base is not a tracked buffer resource"); + + // Chain: add this offset to the base GEP's offset. + auto vregTy = ctx.createVRegType(); + auto sum = + V_ADD_U32::create(builder, loc, vregTy, baseGEP->voffset, newOffset); + st.mapGEP(op.getResult(), {baseGEP->srd, sum}); + return success(); +} + +/// Compute buffer load/store size from the LLVM element type. 
+static int64_t getBufferAccessBytes(Type ty) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.getNumElements() * + vecTy.getElementType().getIntOrFloatBitWidth() / 8; + if (ty.isIntOrFloat()) + return ty.getIntOrFloatBitWidth() / 8; + return 0; +} + +static LogicalResult handleLoad(LLVM::LoadOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + std::optional ptr = st.lookupGEP(op.getAddr()); + if (!ptr) + return op->emitOpError("load address not from a tracked GEP"); + + int64_t numBytes = getBufferAccessBytes(op.getResult().getType()); + + auto soffsetTy = ctx.createImmType(0); + auto zeroSoffset = ConstantOp::create(builder, loc, soffsetTy, 0); + auto vregTy = ctx.createVRegType(); + + Operation *loadOp = nullptr; + if (numBytes == 2) + loadOp = BUFFER_LOAD_USHORT::create(builder, loc, TypeRange{vregTy}, + ptr->srd, ptr->voffset, zeroSoffset, + /*instOffset=*/0); + else if (numBytes == 4) + loadOp = + BUFFER_LOAD_DWORD::create(builder, loc, TypeRange{vregTy}, ptr->srd, + ptr->voffset, zeroSoffset, /*instOffset=*/0); + else if (numBytes == 8) { + auto wideTy = ctx.createVRegType(2, 2); + loadOp = BUFFER_LOAD_DWORDX2::create(builder, loc, TypeRange{wideTy}, + ptr->srd, ptr->voffset, zeroSoffset, + /*instOffset=*/0); + } else if (numBytes == 12) { + auto wideTy = ctx.createVRegType(3, 3); + loadOp = BUFFER_LOAD_DWORDX3::create(builder, loc, TypeRange{wideTy}, + ptr->srd, ptr->voffset, zeroSoffset, + /*instOffset=*/0); + } else if (numBytes == 16) { + auto wideTy = ctx.createVRegType(4, 4); + loadOp = BUFFER_LOAD_DWORDX4::create(builder, loc, TypeRange{wideTy}, + ptr->srd, ptr->voffset, zeroSoffset, + /*instOffset=*/0); + } else + return op->emitOpError("unsupported load size: ") << numBytes << " bytes"; + + ctx.getMapper().mapValue(op.getResult(), loadOp->getResult(0)); + return success(); +} + +static LogicalResult handleStore(LLVM::StoreOp op, LLVMTranslationState &st) { + auto &ctx = st.ctx; + 
auto &builder = ctx.getBuilder(); + auto loc = op.getLoc(); + + std::optional ptr = st.lookupGEP(op.getAddr()); + if (!ptr) + return op->emitOpError("store address not from a tracked GEP"); + + FailureOr data = resolve(op.getValue(), ctx); + if (failed(data)) + return op->emitOpError("unmapped store value"); + int64_t numBytes = getBufferAccessBytes(op.getValue().getType()); + + if (numBytes == 2) + BUFFER_STORE_SHORT::create(builder, loc, *data, ptr->srd, ptr->voffset, + /*instOffset=*/0); + else if (numBytes == 4) + BUFFER_STORE_DWORD::create(builder, loc, *data, ptr->srd, ptr->voffset, + /*instOffset=*/0); + else if (numBytes == 8) + BUFFER_STORE_DWORDX2::create(builder, loc, *data, ptr->srd, ptr->voffset, + /*instOffset=*/0); + else if (numBytes == 12) + BUFFER_STORE_DWORDX3::create(builder, loc, *data, ptr->srd, ptr->voffset, + /*instOffset=*/0); + else if (numBytes == 16) + BUFFER_STORE_DWORDX4::create(builder, loc, *data, ptr->srd, ptr->voffset, + /*instOffset=*/0); + else + return op->emitOpError("unsupported store size: ") << numBytes << " bytes"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// Op dispatch +//===----------------------------------------------------------------------===// + +static LogicalResult translateOp(Operation *op, LLVMTranslationState &st) { + return llvm::TypeSwitch(op) + .Case([&](LLVM::ConstantOp o) { return handleConstant(o, st); }) + .Case([&](LLVM::PoisonOp o) { return handlePoison(o, st); }) + .Case([&](ROCDL::ThreadIdXOp o) { return handleThreadIdX(o, st); }) + .Case([&](ROCDL::BlockIdXOp o) { return handleWorkgroupId(o, st, 0); }) + .Case([&](ROCDL::BlockIdYOp o) { return handleWorkgroupId(o, st, 1); }) + .Case([&](ROCDL::BlockIdZOp o) { return handleWorkgroupId(o, st, 2); }) + .Case([&](LLVM::SExtOp o) { + return handleCastOp(o, st); + }) + .Case([&](LLVM::ZExtOp o) { + return handleCastOp(o, st); + }) + .Case([&](LLVM::TruncOp o) { + return handleCastOp(o, 
st); + }) + .Case([&](LLVM::ICmpOp o) { return handleICmp(o, st); }) + .Case([&](LLVM::SelectOp o) { return handleSelect(o, st); }) + .Case([&](LLVM::MulOp o) { + return handleBinaryOp(o, st); + }) + .Case([&](LLVM::AddOp o) { + return handleBinaryOp(o, st); + }) + .Case([&](LLVM::OrOp o) { + return handleBinaryOp(o, st); + }) + .Case([&](LLVM::AndOp o) { + return handleBinaryOp(o, st); + }) + .Case([&](ROCDL::MakeBufferRsrcOp o) { + return handleMakeBufferRsrc(o, st); + }) + .Case([&](LLVM::GEPOp o) { return handleGEP(o, st); }) + .Case([&](LLVM::LoadOp o) { return handleLoad(o, st); }) + .Case([&](LLVM::StoreOp o) { return handleStore(o, st); }) + .Default([](Operation *op) { + return op->emitOpError("unhandled op in LLVM->WaveASM translation"); + }); +} + +//===----------------------------------------------------------------------===// +// Core translation logic +//===----------------------------------------------------------------------===// + +static LogicalResult translateLLVMModule(Operation *rootOp, + StringRef targetId) { + auto target = getTargetKindAttr(rootOp->getContext(), targetId); + if (!target) + return rootOp->emitError() << "unknown target: " << targetId; + + if (!isa(target)) + return rootOp->emitError() + << "LLVM->WaveASM translation only supports gfx950, got " + << targetId; + + SmallVector kernels; + rootOp->walk([&](LLVM::LLVMFuncOp func) { + if (func->hasAttr("gpu.kernel") || func->hasAttr("rocdl.kernel")) + kernels.push_back(func); + }); + + if (kernels.empty()) + return success(); + + for (LLVM::LLVMFuncOp func : kernels) { + OpBuilder builder(rootOp->getContext()); + builder.setInsertionPointAfter(func); + + ProgramOp program = createProgramFromLLVMFunc(func, builder, targetId); + if (!program) + return failure(); + builder.setInsertionPointToStart(&program.getBodyBlock()); + TranslationContext ctx(builder, program, target); + LLVMTranslationState st(ctx); + + // Map llvm.func arguments: pointers get SRD setup, scalars get mapped + // 
to their preloaded SGPR positions directly. + SmallVector scalarArgs; + for (auto arg : func.getBody().getArguments()) { + if (isa(arg.getType())) { + int64_t argIdx = arg.getArgNumber(); + ctx.queueSRDSetup(arg, argIdx, /*bufferSize=*/kSRDDefaultNumRecords); + } else { + scalarArgs.push_back(arg); + ctx.queueScalarArgLoad(arg, arg.getArgNumber()); + } + } + + ctx.emitSRDPrologue(); + + // Map scalar (non-pointer) args to their SGPR positions. + // gfx950 hardware preloads arg N into s[2+N*2 : 2+N*2+1] (64-bit each). + // Assumes all scalar args fit in the preload window (no overflow). + for (auto arg : scalarArgs) { + int64_t argIdx = arg.getArgNumber(); + int64_t preloadBase = 2 + argIdx * 2; + auto sregTy = ctx.createSRegType(2, 2); + auto sreg = PrecoloredSRegOp::create(builder, arg.getLoc(), sregTy, + preloadBase, /*size=*/2); + ctx.getMapper().mapValue(arg, sreg); + } + + // Enable all workgroup IDs so the SGPR layout is predictable. + // Note: LLVM enables them selectively via amdgpu-no-workgroup-id-{y,z} + // attributes. We enable all three unconditionally for simplicity. 
+ ctx.enableAllWorkgroupIds(); + + for (Operation &op : func.getBody().front()) { + if (isa(op)) + continue; + if (failed(translateOp(&op, st))) + return failure(); + } + + S_ENDPGM::create(builder, func.getLoc()); + + program->setAttr("num_kernel_args", + builder.getI64IntegerAttr(func.getNumArguments())); + + int64_t ldsSize = ctx.getTotalLDSSize(); + if (ldsSize > 0) + program->setAttr("lds_size", builder.getI64IntegerAttr(ldsSize)); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +namespace { + +struct WAVEASMTranslateFromLLVMPass + : impl::WAVEASMTranslateFromLLVMBase { + using WAVEASMTranslateFromLLVMBase::WAVEASMTranslateFromLLVMBase; + + void runOnOperation() override { + if (failed(translateLLVMModule(getOperation(), targetArch))) + return signalPassFailure(); + } +}; + +} // namespace + +} // namespace waveasm diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 0b0e849a5c..0d98072715 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -9,6 +9,7 @@ #include "waveasm/Dialect/WaveASMDialect.h" #include "waveasm/Dialect/WaveASMOps.h" #include "waveasm/Dialect/WaveASMTypes.h" +#include "waveasm/Transforms/AssemblyEmitter.h" #include "waveasm/Transforms/Utils.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" @@ -161,8 +162,9 @@ void TranslationContext::emitSRDPrologue() { // branch+alignment) bool isGFX95 = llvm::isa(target); - // Recompute SRD base indices now that we know the total number of args - // SRDs must start after: user SGPRs + system SGPRs (workgroup IDs) + // Recompute SRD base indices now that we know the total number of args. + // SRDs must start after: user SGPRs + system SGPRs (workgroup IDs). 
+ size_t numPreloadedArgs = getNumKernelArgs(); int64_t userSgprCount = 2; // kernarg ptr if (isGFX95) { userSgprCount += std::min(int64_t(14), (int64_t)getNumKernelArgs() * 2); @@ -189,6 +191,12 @@ void TranslationContext::emitSRDPrologue() { int64_t afterOverflow = afterLastSrd + numOverflowScalars * 2; nextSwizzleSRDIndex = (afterOverflow + 3) & ~3; // Align to 4 + // Dedicated SGPR slots for scalar kernel args, placed after overflow-scalar + // slots so they don't collide with SRDs or preloads. + int64_t scalarArgSgprBase = nextSwizzleSRDIndex; + nextSwizzleSRDIndex = + (nextSwizzleSRDIndex + (int64_t)pendingScalarArgs.size() + 3) & ~3; + // Emit comment for prologue CommentOp::create(builder, loc, "SRD setup prologue"); @@ -207,26 +215,14 @@ void TranslationContext::emitSRDPrologue() { // Hardware limits user SGPRs to 16 (s[0:15]), so only reserve preload // slots for args that fit within the limit. Overflow args are loaded // via explicit s_load from the kernarg buffer at runtime. - llvm::DenseSet reservedPreloadBases; - for (const auto &pending : pendingSRDs) { - int64_t preloadBase = 2 + pending.argIndex * 2; - if (preloadBase >= 16) - continue; - if (reservedPreloadBases.insert(preloadBase).second) { - auto preloadType = createSRegType(2, 2); - PrecoloredSRegOp::create(builder, loc, preloadType, preloadBase, - /*size=*/2); - } - } - for (const auto &pending : pendingScalarArgs) { - int64_t preloadBase = 2 + pending.argIndex * 2; + // Reserve all arg positions, not just pointer args with SRDs. 
+ for (size_t i = 0; i < numPreloadedArgs; ++i) { + int64_t preloadBase = 2 + i * 2; if (preloadBase >= 16) continue; - if (reservedPreloadBases.insert(preloadBase).second) { - auto preloadType = createSRegType(2, 2); - PrecoloredSRegOp::create(builder, loc, preloadType, preloadBase, - /*size=*/2); - } + auto preloadType = createSRegType(2, 2); + PrecoloredSRegOp::create(builder, loc, preloadType, preloadBase, + /*size=*/2); } } @@ -240,11 +236,12 @@ void TranslationContext::emitSRDPrologue() { auto kernargBase = PrecoloredSRegOp::create(builder, loc, kernargSRegType, 0, 2); - for (const auto &pending : pendingSRDs) { - int64_t loadBase = 2 + pending.argIndex * 2; + // Load all kernel args (pointers and scalars) into preload positions. + for (size_t i = 0; i < numPreloadedArgs; ++i) { + int64_t loadBase = 2 + i * 2; if (loadBase >= 16) continue; // Overflow arg: loaded via s_load_dword path below. - int64_t kernargOffset = pending.argIndex * 8; + int64_t kernargOffset = i * 8; auto loadDstType = createSRegType(2, loadBase); auto offsetImm = builder.getType(kernargOffset); @@ -254,31 +251,16 @@ void TranslationContext::emitSRDPrologue() { offsetConst); } - // Also load scalar kernel arguments (index types) from kernarg buffer. - // Scalar args that still fit in the preload SGPR window can be loaded - // before the aligned main entry. + // Overflow scalar args: loaded after the aligned entry point below. 
int64_t overflowSgprBase = (srdStartIndex + (int64_t)pendingSRDs.size() * 4 + 3) & ~3; - for (const auto &pending : pendingScalarArgs) { - int64_t loadBase = 2 + pending.argIndex * 2; - if (loadBase >= 16) - continue; - int64_t kernargOffset = pending.argIndex * 8; - - auto loadDstType = createSRegType(2, loadBase); - auto offsetImm = builder.getType(kernargOffset); - auto offsetConst = - ConstantOp::create(builder, loc, offsetImm, kernargOffset); - S_LOAD_DWORDX2::create(builder, loc, TypeRange{loadDstType}, kernargBase, - offsetConst); - } // Step 2: Branch to aligned entry point (gfx95* requirement). // Keep any high-SGPR overflow loads after the aligned entry; LLVM does the // same, and loading them before the branch leaves the overflow arg stale // on gfx95 hardware. // NOTE: Labels/branches are control flow and must remain as RawOp for now. - std::string kernelName = program.getSymName().str(); + std::string kernelName = getKernelName(program).str(); std::string mainLabel = ".L_" + kernelName + "_main"; RawOp::create(builder, loc, "s_branch " + mainLabel); @@ -320,6 +302,8 @@ void TranslationContext::emitSRDPrologue() { // size/stride. Must use RawOp: S_MOV_B64/S_MOV_B32 are Pure (SALUUnaryOp) // and write to physical registers with no SSA consumer, so CSE/DCE // eliminates them. + // TODO: Replace with typed ops once regalloc supports contiguous + // allocation constraints for PackOp inputs. for (size_t i = 0; i < pendingSRDs.size(); ++i) { const auto &pending = pendingSRDs[i]; int64_t srdBase = pending.srdBaseIndex; @@ -344,14 +328,22 @@ void TranslationContext::emitSRDPrologue() { ", 0x" + llvm::utohexstr(kSRDStrideSwizzle); RawOp::create(builder, loc, movStrideStr); + // Prevent DCE from removing this PrecoloredSRegOp. The SRD registers + // are later referenced by RawOps (e.g., s_mov_b64 for epilogue SRD + // copies) that don't create SSA uses. 
Without DCEProtect, canonicalize + // removes the PrecoloredSRegOp, and the register allocator doesn't + // reserve s[srdBase:srdBase+3], allowing it to allocate temps there. + DCEProtectOp::create(builder, loc, srdReg); + mapper.mapValue(pending.memref, srdReg); } - // Move scalar args from preload SGPRs to VGPRs. + // Copy scalar args from preload/overflow SGPRs to dedicated SGPRs. // Lower 32 bits of the preload pair hold the value (little-endian). // Overflow args were loaded into overflowSgprBase positions above. int64_t ovfIdx = 0; - for (const auto &pending : pendingScalarArgs) { + for (size_t i = 0; i < pendingScalarArgs.size(); ++i) { + const auto &pending = pendingScalarArgs[i]; int64_t preloadBase = 2 + pending.argIndex * 2; int64_t sgprSrc; if (preloadBase >= 16) { @@ -360,13 +352,17 @@ void TranslationContext::emitSRDPrologue() { } else { sgprSrc = preloadBase; } - auto vregType = createVRegType(); - auto vreg = - PrecoloredVRegOp::create(builder, loc, vregType, pending.argIndex, 1); + int64_t sgprDst = scalarArgSgprBase + i; + + auto dstType = createSRegType(1, 1); + auto dstSreg = + PrecoloredSRegOp::create(builder, loc, dstType, sgprDst, 1); + RawOp::create(builder, loc, - "v_mov_b32 v" + std::to_string(pending.argIndex) + ", s" + + "s_mov_b32 s" + std::to_string(sgprDst) + ", s" + std::to_string(sgprSrc)); - mapper.mapValue(pending.blockArg, vreg); + + mapper.mapValue(pending.blockArg, dstSreg); } } else { // Non-GFX95* path (e.g., gfx942): Load directly into SRD positions @@ -441,18 +437,15 @@ void TranslationContext::emitSRDPrologue() { mapper.mapValue(pending.memref, srdReg); } - // Move scalar args from SGPRs to VGPRs + // Map scalar args to their dedicated SGPRs (already loaded above). 
for (size_t i = 0; i < pendingScalarArgs.size(); ++i) { const auto &pending = pendingScalarArgs[i]; int64_t sgprIdx = scalarSgprBase + (int64_t)i; - auto vregType = createVRegType(); - auto vreg = - PrecoloredVRegOp::create(builder, loc, vregType, pending.argIndex, 1); - RawOp::create(builder, loc, - "v_mov_b32 v" + std::to_string(pending.argIndex) + ", s" + - std::to_string(sgprIdx)); - mapper.mapValue(pending.blockArg, vreg); + auto dstType = createSRegType(1, 1); + auto dstSreg = + PrecoloredSRegOp::create(builder, loc, dstType, sgprIdx, 1); + mapper.mapValue(pending.blockArg, dstSreg); } } @@ -609,7 +602,7 @@ Value emitSRDBaseAdjustment(const TranslationContext::PendingSRDBaseAdjust &adj, auto *mlirCtx = builder.getContext(); int64_t N = ctx.getNextSwizzleSRDIndex(); - assert(N + 4 < 108 && "SRD allocation exceeds SGPR limit"); + assert(N + 4 <= 102 && "SRD allocation exceeds SGPR limit (s0-s101)"); // Copy source SRD base to new SRD. // Must use RawOp: S_MOV_B64 is Pure (SALUUnaryOp) and writes to a @@ -1026,12 +1019,20 @@ LogicalResult handleVectorTransferWrite(Operation *op, voffset = ConstantOp::create(builder, loc, immType, 0); } + Value storeData = *data; + if (isAGPRType(storeData.getType())) { + auto vregType = + ctx.createVRegType(numDwords, numDwords > 1 ? 
numDwords : 1); + storeData = + V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); + } + if (numDwords == 1) { - BUFFER_STORE_DWORD::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORD::create(builder, loc, storeData, srd, voffset); } else if (numDwords == 2) { - BUFFER_STORE_DWORDX2::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORDX2::create(builder, loc, storeData, srd, voffset); } else { - BUFFER_STORE_DWORDX4::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORDX4::create(builder, loc, storeData, srd, voffset); } } diff --git a/waveasm/lib/Transforms/VGPRCompaction.cpp b/waveasm/lib/Transforms/VGPRCompaction.cpp new file mode 100644 index 0000000000..43f77ca1b2 --- /dev/null +++ b/waveasm/lib/Transforms/VGPRCompaction.cpp @@ -0,0 +1,505 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// VGPR Compaction Pass +// +// Re-assigns physical VGPRs after register allocation to minimize the peak +// register number. The linear scan allocator assigns VGPRs in instruction +// order (lowest-first), which causes fragmentation when interleaved +// buffer_load (long-lived) and ds_read (short-lived) instructions get +// interleaved register numbers. This pass reassigns them using a +// shortest-first greedy strategy that packs short-lived values into low +// registers, pushing long-lived values to a contiguous high range. 
+//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Dialect/WaveASMTypes.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "waveasm-vgpr-compaction" + +using namespace mlir; +using namespace waveasm; + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMVGPRCOMPACTION +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +namespace { + +struct PhysVGPRRange { + int64_t physIdx; + int64_t size; + int64_t alignment; + int64_t defPoint; + int64_t lastUsePoint; + bool pinned = false; // Precolored — must keep original position + + int64_t length() const { return lastUsePoint - defPoint; } +}; + +static void collectOps(Block &block, + llvm::SmallVectorImpl &ops) { + for (Operation &op : block) { + ops.push_back(&op); + for (Region &region : op.getRegions()) + for (Block &nested : region) + collectOps(nested, ops); + } +} + +/// Record a PVRegType occurrence. Merges into existing entry if the base +/// index matches, taking max size/alignment and extending the time range. +static void recordPVReg(int64_t baseIdx, int64_t size, int64_t point, + bool isDef, + llvm::DenseMap &defPoints, + llvm::DenseMap &usePoints, + llvm::DenseMap &sizes, + llvm::DenseMap &alignments) { + int64_t align = (size >= 4) ? 4 : (size >= 2 ? 
2 : 1); + + auto it = sizes.find(baseIdx); + if (it != sizes.end()) { + it->second = std::max(it->second, size); + alignments[baseIdx] = std::max(alignments[baseIdx], align); + if (isDef) + defPoints[baseIdx] = std::min(defPoints[baseIdx], point); + else { + auto uit = usePoints.find(baseIdx); + if (uit != usePoints.end()) + uit->second = std::max(uit->second, point); + else + usePoints[baseIdx] = point; + } + } else { + sizes[baseIdx] = size; + alignments[baseIdx] = align; + if (isDef) + defPoints[baseIdx] = point; + else + usePoints[baseIdx] = point; + } +} + +/// For a PVRegType with index X and size S, find the "allocation base": +/// the base index of the multi-register range that contains X. +/// E.g., if the allocator assigned v[92:95] (base=92, size=4), then +/// v93 (index=93, size=1) has allocation base 92. +/// Returns (allocBase, allocSize) or (idx, size) if no containing range. +static std::pair +findAllocBase(int64_t idx, int64_t size, + const llvm::DenseMap &knownSizes) { + // Check if this index IS a known base + auto it = knownSizes.find(idx); + if (it != knownSizes.end() && it->second >= size) + return {idx, it->second}; + + // Check if this index falls within a known larger range. + // Multi-reg ranges are always aligned, so check aligned bases below idx. + for (int64_t align : {4, 2}) { + int64_t base = (idx / align) * align; + if (base == idx) + continue; + auto baseIt = knownSizes.find(base); + if (baseIt != knownSizes.end() && base + baseIt->second > idx) + return {base, baseIt->second}; + } + + return {idx, size}; +} + +static void buildPhysRanges( + ProgramOp program, + llvm::SmallVectorImpl &ranges, + llvm::DenseMap &rangeBaseToIdx) { + + llvm::SmallVector ops; + collectOps(program.getBodyBlock(), ops); + + llvm::DenseMap opToIdx; + for (int64_t i = 0; i < static_cast(ops.size()); ++i) + opToIdx[ops[i]] = i; + + llvm::DenseMap defPoints, usePoints, sizes, alignments; + + // Track precolored VGPR indices — these must not be remapped. 
+ llvm::DenseSet pinnedIndices; + + // First pass: collect all multi-register definitions to build the + // "known allocations" map. This lets us identify sub-element accesses. + // Also identify precolored VGPRs. + for (int64_t i = 0; i < static_cast(ops.size()); ++i) { + Operation *op = ops[i]; + + if (isa(op)) { + for (Value result : op->getResults()) { + if (auto pvreg = dyn_cast(result.getType())) { + pinnedIndices.insert(pvreg.getIndex()); + for (int64_t k = 0; k < pvreg.getSize(); ++k) + pinnedIndices.insert(pvreg.getIndex() + k); + } + } + } + + for (Value result : op->getResults()) { + if (auto pvreg = dyn_cast(result.getType())) { + if (pvreg.getSize() > 1) { + auto &sz = sizes[pvreg.getIndex()]; + sz = std::max(sz, pvreg.getSize()); + int64_t align = (pvreg.getSize() >= 4) ? 4 : 2; + auto &al = alignments[pvreg.getIndex()]; + al = std::max(al, align); + } + } + } + for (Region &region : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (auto pvreg = dyn_cast(arg.getType())) { + if (pvreg.getSize() > 1) { + auto &sz = sizes[pvreg.getIndex()]; + sz = std::max(sz, pvreg.getSize()); + int64_t align = (pvreg.getSize() >= 4) ? 4 : 2; + auto &al = alignments[pvreg.getIndex()]; + al = std::max(al, align); + } + } + } + } + } + } + + // Second pass: collect def/use points, mapping sub-elements to their base. 
+ llvm::DenseMap knownSizes = sizes; + defPoints.clear(); + usePoints.clear(); + sizes.clear(); + alignments.clear(); + + auto processPVReg = [&](int64_t idx, int64_t size, int64_t point, + bool isDef) { + auto [base, allocSize] = findAllocBase(idx, size, knownSizes); + recordPVReg(base, allocSize, point, isDef, defPoints, usePoints, sizes, + alignments); + }; + + for (int64_t i = 0; i < static_cast(ops.size()); ++i) { + Operation *op = ops[i]; + + for (Value result : op->getResults()) { + if (auto pvreg = dyn_cast(result.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), i, true); + } + + for (Value operand : op->getOperands()) { + if (auto pvreg = dyn_cast(operand.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), i, false); + } + + for (Region &region : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (auto pvreg = dyn_cast(arg.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), i, true); + } + } + } + + if (isa(op)) { + auto loopOp = cast(op); + Block &body = loopOp.getBodyBlock(); + Operation *terminator = body.getTerminator(); + int64_t termIdx = terminator ? opToIdx.lookup(terminator) : i; + for (BlockArgument arg : body.getArguments()) { + if (auto pvreg = dyn_cast(arg.getType())) + processPVReg(pvreg.getIndex(), pvreg.getSize(), termIdx, false); + } + } + } + + // Extend live ranges for values used inside loop bodies. + // Any value used inside a loop at some program point P is actually + // live from its definition until the end of the loop body (not just P), + // because the loop re-executes and the value must be available on + // every iteration. + for (int64_t i = 0; i < static_cast(ops.size()); ++i) { + Operation *op = ops[i]; + if (!isa(op)) + continue; + auto loopOp = cast(op); + Block &body = loopOp.getBodyBlock(); + Operation *terminator = body.getTerminator(); + int64_t loopStart = i; + int64_t loopEnd = terminator ? 
opToIdx.lookup(terminator) : i; + + // For each operand used inside the loop body, extend its range + // to cover the entire loop. + body.walk([&](Operation *innerOp) { + for (Value operand : innerOp->getOperands()) { + if (auto pvreg = dyn_cast(operand.getType())) { + auto [base, allocSize] = findAllocBase(pvreg.getIndex(), + pvreg.getSize(), knownSizes); + auto defIt = defPoints.find(base); + if (defIt != defPoints.end() && defIt->second < loopStart) { + // Value defined before the loop but used inside — extend to + // end of loop body. + auto &usePt = usePoints[base]; + usePt = std::max(usePt, loopEnd); + } + } + } + }); + } + + for (const auto &[base, defPt] : defPoints) { + auto useIt = usePoints.find(base); + int64_t usePt = (useIt != usePoints.end()) ? useIt->second : defPt; + PhysVGPRRange r; + r.physIdx = base; + r.size = sizes.lookup(base); + r.alignment = alignments.lookup(base); + if (r.size == 0) + r.size = 1; + if (r.alignment == 0) + r.alignment = 1; + r.defPoint = defPt; + r.lastUsePoint = usePt; + r.pinned = pinnedIndices.contains(base); + rangeBaseToIdx[base] = ranges.size(); + ranges.push_back(r); + } +} + +static bool overlaps(const PhysVGPRRange &a, const PhysVGPRRange &b) { + return a.defPoint <= b.lastUsePoint && b.defPoint <= a.lastUsePoint; +} + +static llvm::DenseMap +computeCompaction(llvm::SmallVectorImpl &ranges, + int64_t maxRegs) { + llvm::SmallVector order(ranges.size()); + std::iota(order.begin(), order.end(), 0); + llvm::sort(order, [&](int64_t a, int64_t b) { + if (ranges[a].length() != ranges[b].length()) + return ranges[a].length() < ranges[b].length(); + return ranges[a].defPoint < ranges[b].defPoint; + }); + + llvm::DenseMap oldToNew; + llvm::SmallVector newAssignment(ranges.size(), -1); + + // Pin precolored ranges to their original positions first. 
+ for (size_t i = 0; i < ranges.size(); ++i) { + if (ranges[i].pinned) { + newAssignment[i] = ranges[i].physIdx; + oldToNew[ranges[i].physIdx] = ranges[i].physIdx; + } + } + + for (int64_t orderIdx : order) { + if (ranges[orderIdx].pinned) + continue; + + PhysVGPRRange &r = ranges[orderIdx]; + int64_t sz = r.size; + int64_t align = r.alignment; + + // v15 is the scratch VGPR for literal materialization in the assembly + // emitter (AssemblyEmitter.h KernelGenerator::kScratchVGPR). Must stay + // excluded so compaction never places a live value there. + constexpr int64_t kScratchVGPR = 15; + llvm::BitVector occupied(maxRegs, false); + if (kScratchVGPR < maxRegs) + occupied.set(kScratchVGPR); + for (size_t j = 0; j < ranges.size(); ++j) { + if (newAssignment[j] < 0) + continue; + if (overlaps(r, ranges[j])) { + int64_t base = newAssignment[j]; + for (int64_t k = 0; k < ranges[j].size; ++k) + if (base + k < maxRegs) + occupied.set(base + k); + } + } + + int64_t chosen = -1; + for (int64_t c = 0; c + sz <= maxRegs; c += align) { + bool free = true; + for (int64_t k = 0; k < sz; ++k) { + if (occupied.test(c + k)) { + free = false; + break; + } + } + if (free) { + chosen = c; + break; + } + } + + if (chosen >= 0) { + newAssignment[orderIdx] = chosen; + oldToNew[r.physIdx] = chosen; + } else { + newAssignment[orderIdx] = r.physIdx; + oldToNew[r.physIdx] = r.physIdx; + } + } + + return oldToNew; +} + +/// Remap a PVRegType index using the base-to-base mapping. +/// For sub-elements (e.g., v93 within v[92:95]), computes the offset +/// from the old base and applies it to the new base. 
+static int64_t remapIndex(int64_t oldIdx, int64_t size, + const llvm::DenseMap &oldToNew, + const llvm::SmallVectorImpl &ranges) { + // Direct lookup: this index IS a base + auto it = oldToNew.find(oldIdx); + if (it != oldToNew.end()) + return it->second; + + // Sub-element: find the containing base range + for (const auto &r : ranges) { + if (oldIdx >= r.physIdx && oldIdx < r.physIdx + r.size) { + auto baseIt = oldToNew.find(r.physIdx); + if (baseIt != oldToNew.end()) + return baseIt->second + (oldIdx - r.physIdx); + } + } + + return oldIdx; +} + +static void applyRemapping(ProgramOp program, + const llvm::DenseMap &oldToNew, + const llvm::SmallVectorImpl &ranges) { + auto remap = [&](Type ty) -> Type { + auto pvreg = dyn_cast(ty); + if (!pvreg) + return ty; + int64_t newIdx = + remapIndex(pvreg.getIndex(), pvreg.getSize(), oldToNew, ranges); + if (newIdx == pvreg.getIndex()) + return ty; + return PVRegType::get(ty.getContext(), newIdx, pvreg.getSize()); + }; + + program.walk([&](Operation *op) { + if (isa(op)) + return; + + for (Value result : op->getResults()) { + Type newTy = remap(result.getType()); + if (newTy != result.getType()) + result.setType(newTy); + } + + if (auto condOp = dyn_cast(op)) { + if (auto attr = condOp->getAttrOfType( + "_iterArgPhysRegs")) { + auto vals = attr.asArrayRef(); + llvm::SmallVector newVals(vals.begin(), vals.end()); + bool anyChanged = false; + for (size_t i = 0; i < newVals.size(); ++i) { + if (newVals[i] < 0) + continue; + if (i < condOp.getIterArgs().size()) { + Type ty = condOp.getIterArgs()[i].getType(); + if (isa(ty)) { + int64_t newIdx = + remapIndex(newVals[i], 1, oldToNew, ranges); + if (newIdx != newVals[i]) { + newVals[i] = newIdx; + anyChanged = true; + } + } + } + } + if (anyChanged) { + condOp->setAttr("_iterArgPhysRegs", + DenseI64ArrayAttr::get(op->getContext(), newVals)); + } + } + } + }); + + program.walk([&](LoopOp loopOp) { + Block &body = loopOp.getBodyBlock(); + for (BlockArgument arg : 
body.getArguments()) { + Type newTy = remap(arg.getType()); + if (newTy != arg.getType()) + arg.setType(newTy); + } + for (Value result : loopOp->getResults()) { + Type newTy = remap(result.getType()); + if (newTy != result.getType()) + result.setType(newTy); + } + }); +} + +struct WAVEASMVGPRCompaction + : waveasm::impl::WAVEASMVGPRCompactionBase { + + void runOnOperation() override { + auto moduleOp = getOperation(); + + moduleOp->walk([&](ProgramOp program) { + llvm::SmallVector ranges; + llvm::DenseMap rangeBaseToIdx; + + buildPhysRanges(program, ranges, rangeBaseToIdx); + + if (ranges.empty()) + return; + + int64_t maxBefore = 0; + for (const auto &r : ranges) + maxBefore = std::max(maxBefore, r.physIdx + r.size); + + auto oldToNew = computeCompaction(ranges, /*maxRegs=*/512); + + int64_t maxAfter = 0; + for (const auto &r : ranges) { + auto it = oldToNew.find(r.physIdx); + int64_t newIdx = (it != oldToNew.end()) ? it->second : r.physIdx; + maxAfter = std::max(maxAfter, newIdx + r.size); + } + + bool anyChange = false; + for (const auto &[old, newIdx] : oldToNew) { + if (old != newIdx) { + anyChange = true; + break; + } + } + + if (!anyChange) + return; + + LLVM_DEBUG(llvm::dbgs() << "VGPR compaction: " << maxBefore << " -> " + << maxAfter << " (saved " + << (maxBefore - maxAfter) << ")\n"); + + applyRemapping(program, oldToNew, ranges); + }); + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp index 336db2c46a..5ccc8da9eb 100644 --- a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp @@ -351,8 +351,9 @@ LogicalResult handleAMDGPUScaledMfma(Operation *op, TranslationContext &ctx) { /// Emit the SRD NUM_RECORDS field (word 2) from the validBytes operand. /// Static constants emit s_mov_b32; dynamic values (e.g. 
from arith.select -/// for the branchless g2s guard) go through v_readfirstlane_b32 + s_add_u32 -/// (non-Pure to avoid DCE). Falls back to hardware maximum if absent. +/// for the branchless g2s guard) use s_cmp + s_cselect when all operands +/// are scalar, or v_readfirstlane_b32 for VGPR values. +/// Falls back to hardware maximum if absent. static void emitSrdNumRecords(OpBuilder &builder, Location loc, int64_t srdBase, Operation *op, TranslationContext &ctx) { auto castOp = cast(op); @@ -367,6 +368,41 @@ static void emitSrdNumRecords(OpBuilder &builder, Location loc, int64_t srdBase, return; } + // Detect arith.select(arith.cmpi(scalar, scalar), scalar, scalar) + // and emit s_cmp + s_cselect directly into SRD word 2. + if (auto selectOp = validBytesVal.getDefiningOp()) { + Value condVal = selectOp.getCondition(); + if (auto cmpOp = condVal.getDefiningOp()) { + Value cmpLhs = cmpOp.getLhs(); + Value cmpRhs = cmpOp.getRhs(); + Value trueVal = selectOp.getTrueValue(); + Value falseVal = selectOp.getFalseValue(); + + auto cmpLhsMapped = ctx.getMapper().getMapped(cmpLhs); + auto cmpRhsMapped = ctx.getMapper().getMapped(cmpRhs); + auto trueMapped = ctx.getMapper().getMapped(trueVal); + auto falseMapped = ctx.getMapper().getMapped(falseVal); + + if (cmpLhsMapped && cmpRhsMapped && trueMapped && falseMapped && + isScalarOrImm(*cmpLhsMapped) && isScalarOrImm(*cmpRhsMapped) && + isScalarOrImm(*trueMapped) && isScalarOrImm(*falseMapped)) { + emitScalarCmp(builder, loc, cmpOp.getPredicate(), + *cmpLhsMapped, *cmpRhsMapped, ctx); + + auto dstType = PSRegType::get(builder.getContext(), srdBase + 2, 1); + Value trueV = *trueMapped; + Value falseV = *falseMapped; + auto sregType = ctx.createSRegType(); + if (isImmType(trueV.getType())) + trueV = S_MOV_B32::create(builder, loc, sregType, trueV); + auto result = + S_CSELECT_B32::create(builder, loc, dstType, trueV, falseV); + DCEProtectOp::create(builder, loc, result); + return; + } + } + } + auto mapped = 
ctx.getMapper().getMapped(validBytesVal); if (mapped) { Value src = *mapped; diff --git a/waveasm/lib/Transforms/handlers/AffineHandlers.cpp b/waveasm/lib/Transforms/handlers/AffineHandlers.cpp index a18c35e165..1a56cc0aea 100644 --- a/waveasm/lib/Transforms/handlers/AffineHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AffineHandlers.cpp @@ -31,6 +31,7 @@ #include #include +#include #define DEBUG_TYPE "waveasm-affine-handlers" @@ -38,14 +39,234 @@ using namespace mlir; namespace waveasm { +//===----------------------------------------------------------------------===// +// Affine expression normalization for 32-bit backends +//===----------------------------------------------------------------------===// +// MLIR's affine canonicalizer can compose nested affine.apply ops into a +// single expression, multiplying all coefficients by a large LCM factor. +// For example: +// (s0*4 + s1*256 + ...) floordiv 2000 +// may become: +// (s0*15258789062500 + s1*976562500000000 + ...) floordiv 7629394531250000 +// +// These 50+ bit constants overflow 32-bit GPU arithmetic. We normalize by +// dividing all *variable coefficients* and the divisor by their GCD, then +// absorbing the constant term's quotient into the reduced expression. +// +// Math: FloorDiv(G*(A+q) + r, G*D') = FloorDiv(A+q, D') +// when 0 <= r < G and D' >= 2 (always true here). +//===----------------------------------------------------------------------===// + +static constexpr int64_t kMaxConst32 = + static_cast(std::numeric_limits::max()); + +/// Return true if any constant in \p e exceeds 32-bit range. +static bool hasLargeConstants(AffineExpr e) { + if (auto c = dyn_cast(e)) + return std::abs(c.getValue()) > kMaxConst32; + if (auto bin = dyn_cast(e)) + return hasLargeConstants(bin.getLHS()) || hasLargeConstants(bin.getRHS()); + return false; +} + +/// Collect all variable-coefficient constants from a sum expression that is +/// the numerator of a FloorDiv. 
Walks Add-chains and Mul(const, expr) +/// nodes, returning each constant multiplier. The free constant term is +/// returned via \p freeConst. Returns false if the expression shape is too +/// complex to normalize (e.g. nested FloorDiv products). +static bool collectSumCoefficients(AffineExpr e, + SmallVectorImpl &coeffs, + int64_t &freeConst) { + if (auto c = dyn_cast(e)) { + freeConst += c.getValue(); + return true; + } + if (isa(e) || isa(e)) { + // A bare dimension or symbol has an implicit coefficient of 1. + coeffs.push_back(1); + return true; + } + auto bin = dyn_cast(e); + if (!bin) + return false; + + if (bin.getKind() == AffineExprKind::Add) + return collectSumCoefficients(bin.getLHS(), coeffs, freeConst) && + collectSumCoefficients(bin.getRHS(), coeffs, freeConst); + + if (bin.getKind() == AffineExprKind::Mul) { + auto tryConstMul = [&](AffineExpr constSide, AffineExpr exprSide) -> bool { + auto c = dyn_cast(constSide); + if (!c) + return false; + SmallVector inner; + int64_t innerConst = 0; + if (!collectSumCoefficients(exprSide, inner, innerConst)) + return false; + for (int64_t ic : inner) + coeffs.push_back(ic * c.getValue()); + freeConst += innerConst * c.getValue(); + return true; + }; + if (tryConstMul(bin.getRHS(), bin.getLHS()) || + tryConstMul(bin.getLHS(), bin.getRHS())) + return true; + // Both sides are non-constant -- treat as opaque + coeffs.push_back(1); + return true; + } + + // FloorDiv, CeilDiv, Mod sub-expressions are opaque for coefficient + // collection — they contribute an implicit coefficient of 1. + if (bin.getKind() == AffineExprKind::FloorDiv || + bin.getKind() == AffineExprKind::CeilDiv || + bin.getKind() == AffineExprKind::Mod) { + coeffs.push_back(1); + return true; + } + + return false; +} + +/// Divide every variable coefficient and the divisor by \p g, and adjust +/// the free constant term using floordiv semantics. 
+static AffineExpr divideExprByGCD(AffineExpr e, int64_t g, int64_t &constAdj, + MLIRContext *ctx) { + if (auto c = dyn_cast(e)) { + // Free constant: absorb floor(C/g) into constAdj, drop remainder + int64_t val = c.getValue(); + int64_t q = (val >= 0) ? val / g : -(-val + g - 1) / g; // floor division + constAdj += q; + return getAffineConstantExpr(0, ctx); + } + if (isa(e) || isa(e)) + return e; + + auto bin = dyn_cast(e); + if (!bin) + return e; + + if (bin.getKind() == AffineExprKind::Add) { + auto l = divideExprByGCD(bin.getLHS(), g, constAdj, ctx); + auto r = divideExprByGCD(bin.getRHS(), g, constAdj, ctx); + return l + r; + } + if (bin.getKind() == AffineExprKind::Mul) { + // Divide whichever side is the constant factor by g. + // The non-constant side is recursed with g=1 so that any nested + // free constants still get their floor-quotient absorbed into constAdj. + for (auto [constSide, exprSide] : {std::pair{bin.getRHS(), bin.getLHS()}, + std::pair{bin.getLHS(), bin.getRHS()}}) { + if (auto c = dyn_cast(constSide)) { + int64_t newVal = c.getValue() / g; + auto reduced = divideExprByGCD(exprSide, 1, constAdj, ctx); + return reduced * getAffineConstantExpr(newVal, ctx); + } + } + return e; + } + // FloorDiv, Mod, etc. -- leave as-is (opaque sub-expression) + return e; +} + +/// Try to normalize a FloorDiv expression whose divisor exceeds 32 bits. +/// Returns a normalized expression or std::nullopt on failure. +static std::optional +tryNormalizeFloorDiv(AffineExpr numerator, int64_t divisor, MLIRContext *ctx) { + SmallVector coeffs; + int64_t freeConst = 0; + if (!collectSumCoefficients(numerator, coeffs, freeConst)) + return std::nullopt; + + // Compute GCD of all variable coefficients and the divisor. + // Exclude the free constant -- it may not share the factor, but the + // floordiv identity still holds (see header comment). 
+ int64_t g = std::abs(divisor); + for (int64_t c : coeffs) + g = std::gcd(g, std::abs(c)); + if (g <= 1) + return std::nullopt; + + int64_t newDiv = divisor / g; + if (std::abs(newDiv) > kMaxConst32) + return std::nullopt; + + // Verify all reduced coefficients fit in 32 bits + for (int64_t c : coeffs) { + if (std::abs(c / g) > kMaxConst32) + return std::nullopt; + } + + // Build the reduced numerator + int64_t constAdj = 0; + AffineExpr reduced = divideExprByGCD(numerator, g, constAdj, ctx); + // Add the absorbed constant quotient + if (constAdj != 0) + reduced = reduced + getAffineConstantExpr(constAdj, ctx); + + LLVM_DEBUG(llvm::dbgs() << "Normalized FloorDiv: divisor " << divisor + << " -> " << newDiv << " (GCD=" << g << ")\n"); + + return reduced.floorDiv(getAffineConstantExpr(newDiv, ctx)); +} + +/// Walk the expression tree and normalize any FloorDiv whose constant +/// divisor exceeds 32 bits. +static AffineExpr normalizeExpr(AffineExpr e, MLIRContext *ctx) { + if (!hasLargeConstants(e)) + return e; + + auto bin = dyn_cast(e); + if (!bin) + return e; + + // Recursively normalize sub-expressions first + AffineExpr lhs = normalizeExpr(bin.getLHS(), ctx); + AffineExpr rhs = normalizeExpr(bin.getRHS(), ctx); + + if (bin.getKind() == AffineExprKind::FloorDiv) { + if (auto constRhs = dyn_cast(rhs)) { + int64_t divisor = constRhs.getValue(); + if (static_cast(std::abs(divisor)) > kMaxConst32) { + if (auto norm = tryNormalizeFloorDiv(lhs, divisor, ctx)) + return *norm; + } + } + } + + // Only rebuild if a child actually changed + if (lhs == bin.getLHS() && rhs == bin.getRHS()) + return e; + + switch (bin.getKind()) { + case AffineExprKind::Add: + return lhs + rhs; + case AffineExprKind::Mul: + return lhs * rhs; + case AffineExprKind::FloorDiv: + return lhs.floorDiv(rhs); + case AffineExprKind::CeilDiv: + return lhs.ceilDiv(rhs); + case AffineExprKind::Mod: + return lhs % rhs; + default: + return e; + } +} + // 32-bit unsigned Barrett reduction: floordiv(x, d) 
exact for all uint32. // Matches LLVM's __udivsi3 lowering: // 1. Float rcp seed → Newton-Raphson refinement → exact reciprocal // 2. q = mulhi(x, recip) // 3. Two remainder-based correction steps +// When both inputs are scalar, extracts the result back to SGPR via +// v_readfirstlane_b32 so downstream arithmetic stays in SALU. Value emitUnsignedFloordiv(Value x, Value d, OpBuilder &builder, Location loc, TranslationContext &ctx) { + bool bothScalar = isScalarOrImm(x) && isScalarOrImm(d); auto vregType = ctx.createVRegType(); + x = ensureVGPR(builder, x.getLoc(), ctx, x); + d = ensureVGPR(builder, d.getLoc(), ctx, d); Value zeroConst = createImmConst(0, builder, loc, ctx); // --- Step 1: float reciprocal seed --- @@ -89,6 +310,11 @@ Value emitUnsignedFloordiv(Value x, Value d, OpBuilder &builder, Location loc, zeroConst); q = V_ADD_U32::create(builder, loc, vregType, q, inc2); + if (bothScalar) { + auto sregType = ctx.createSRegType(); + q = V_READFIRSTLANE_B32::create(builder, loc, sregType, q); + } + return q; } @@ -111,14 +337,9 @@ Value emitUnsignedFloordiv(Value x, Value d, OpBuilder &builder, Location loc, Value emitConstantUnsignedFloordiv(Value x, int64_t divisor, OpBuilder &builder, Location loc, TranslationContext &ctx) { assert(divisor >= 2 && "divisor must be >= 2"); + assert(static_cast(divisor) <= 0xFFFFFFFFULL && + "divisor exceeds 32 bits -- should have been normalized"); - if (static_cast(divisor) > 0xFFFFFFFFULL) { - llvm::errs() << "ERROR: divisor " << divisor << " (0x" - << llvm::utohexstr(static_cast(divisor)) - << ") exceeds 32 bits in emitConstantUnsignedFloordiv\n"; - } - - auto vregType = ctx.createVRegType(); llvm::APInt divisorAPInt(32, static_cast(divisor)); auto mag = llvm::UnsignedDivisionByConstantInfo::get( divisorAPInt, /*LeadingZeros=*/0, @@ -133,23 +354,21 @@ Value emitConstantUnsignedFloordiv(Value x, int64_t divisor, OpBuilder &builder, int64_t magicVal = static_cast(mag.Magic.getZExtValue()); Value magicConst = 
createImmConst(magicVal, builder, loc, ctx); - Value q = V_MUL_HI_U32::create(builder, loc, vregType, x, magicConst); + Value q = emitMulHi(x, magicConst, builder, loc, ctx); auto emitShiftRight = [&](Value val, unsigned amount) -> Value { if (amount == 0) return val; Value shiftConst = createImmConst(static_cast(amount), builder, loc, ctx); - return V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, val); + return emitLshr(val, shiftConst, builder, loc, ctx); }; if (mag.IsAdd) { - // add form: (mulhi(x,m) + ((x - mulhi(x,m)) >> 1)) >> PostShift - Value xSubQ = V_SUB_U32::create(builder, loc, vregType, x, q); + Value xSubQ = emitSub(x, q, builder, loc, ctx); Value oneConst = createImmConst(1, builder, loc, ctx); - Value halfDiff = - V_LSHRREV_B32::create(builder, loc, vregType, oneConst, xSubQ); - Value sum = V_ADD_U32::create(builder, loc, vregType, q, halfDiff); + Value halfDiff = emitLshr(xSubQ, oneConst, builder, loc, ctx); + Value sum = emitAdd(q, halfDiff, builder, loc, ctx); return emitShiftRight(sum, mag.PostShift); } @@ -170,7 +389,23 @@ Value emitConstantUnsignedFloordiv(Value x, int64_t divisor, OpBuilder &builder, static Value emitCeilFromFloorQuotient(Value q, Value x, Value d, OpBuilder &builder, Location loc, TranslationContext &ctx) { + bool allScalar = isScalarOrImm(q) && isScalarOrImm(x) && isScalarOrImm(d); + if (allScalar) { + // Fully scalar path: rem = x - q*d; SCC = (rem != 0); q += SCC + Value qd = emitMul(q, d, builder, loc, ctx); + Value rem = emitSub(x, qd, builder, loc, ctx); + Value zeroConst = createImmConst(0, builder, loc, ctx); + // s_cmp_lg_u32 sets SCC = (rem != 0) + S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), rem, zeroConst); + // s_addc_u32: dst = q + 0 + SCC (carry-in from SCC) + auto sregType = ctx.createSRegType(); + return S_ADDC_U32::create(builder, loc, sregType, sregType, q, zeroConst) + .getDst(); + } auto vregType = ctx.createVRegType(); + q = ensureVGPR(builder, loc, ctx, q); + x = 
ensureVGPR(builder, loc, ctx, x); + d = ensureVGPR(builder, loc, ctx, d); Value qd = V_MUL_LO_U32::create(builder, loc, vregType, q, d); Value rem = V_SUB_U32::create(builder, loc, vregType, x, qd); Value zeroConst = createImmConst(0, builder, loc, ctx); @@ -273,7 +508,10 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { // TODO: Re-enable constant extraction only for values used directly in memory // ops int64_t constAddend = 0; - AffineExpr exprToCompile = expr; + // Reduce any 50+ bit constants introduced by MLIR's affine canonicalizer + // back to 32-bit range before emitting GPU instructions (see normalization + // section at the top of this file). + AffineExpr exprToCompile = normalizeExpr(expr, applyOp.getContext()); // Simple pattern matching for common affine expressions // Pattern: d0 mod N -> v_and_b32 (when N is power of 2) @@ -338,6 +576,12 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { BitRange lhsRange = lhsResult.range; BitRange rhsRange = rhsResult.range; + // Helper: coerce both operands to VGPR for VALU binary ops. 
+ auto ensureBothVGPR = [&]() { + lhs = ensureVGPR(builder, loc, ctx, lhs); + rhs = ensureVGPR(builder, loc, ctx, rhs); + }; + switch (binExpr.getKind()) { case AffineExprKind::Add: { if (!lhsRange.overlaps(rhsRange)) { @@ -362,8 +606,11 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftImm, shiftAmount); // v_lshl_or_b32: dst = (src << shift) | orend + Value base = ensureVGPR(builder, loc, ctx, + baseResult.value); + orend = ensureVGPR(builder, loc, ctx, orend); Value fusedResult = V_LSHL_OR_B32::create( - builder, loc, vregType, baseResult.value, shiftConst, + builder, loc, vregType, base, shiftConst, orend); BitRange shiftedRange = baseResult.range.shiftLeft(shiftAmount); @@ -382,8 +629,11 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftImm = ctx.createImmType(shiftAmount); auto shiftConst = ConstantOp::create(builder, loc, shiftImm, shiftAmount); + Value base2 = ensureVGPR(builder, loc, ctx, + baseResult.value); + orend = ensureVGPR(builder, loc, ctx, orend); Value fusedResult = V_LSHL_OR_B32::create( - builder, loc, vregType, baseResult.value, shiftConst, + builder, loc, vregType, base2, shiftConst, orend); BitRange shiftedRange = baseResult.range.shiftLeft(shiftAmount); @@ -406,14 +656,14 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { return *result; } - // No fusion possible, emit regular v_or_b32 - Value orResult = V_OR_B32::create(builder, loc, vregType, lhs, rhs); + // No fusion possible — emitOr picks S_OR_B32 when both scalar. 
+ Value orResult = emitOr(lhs, rhs, builder, loc, ctx); BitRange resultRange = lhsRange.merge(rhsRange); ctx.setBitRange(orResult, resultRange); return ExprResult(orResult, resultRange); } // Overlapping ranges - must use ADD - Value addResult = V_ADD_U32::create(builder, loc, vregType, lhs, rhs); + Value addResult = emitAdd(lhs, rhs, builder, loc, ctx); BitRange resultRange = lhsRange.extendForAdd(rhsRange); ctx.setBitRange(addResult, resultRange); return ExprResult(addResult, resultRange); @@ -443,10 +693,8 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { int64_t mask = ~(N - 1) & 0xFFFFFFFF; auto maskImm = ctx.createImmType(mask); auto maskConst = ConstantOp::create(builder, loc, maskImm, mask); - // NOTE: constant must be src0 (first operand) for VOP2 encoding. - // src1 must be a VGPR on AMDGCN. - Value andResult = V_AND_B32::create(builder, loc, vregType, maskConst, - innerResult.value); + Value andResult = + emitAnd(maskConst, innerResult.value, builder, loc, ctx); // Result has same bit range as inner, but low bits cleared BitRange resultRange = innerResult.range; // Clear bits below log2(N) -- conservative: use inner range @@ -481,14 +729,12 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); Value shiftResult = - V_LSHLREV_B32::create(builder, loc, vregType, shiftConst, lhs); - // Shift the bit range left by shiftAmount + emitLshl(lhs, shiftConst, builder, loc, ctx); BitRange resultRange = lhsRange.shiftLeft(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); } } - // Also check LHS for power of 2 multiply if (auto constLhs = dyn_cast(binExpr.getLHS())) { int64_t val = constLhs.getValue(); if (isPowerOf2(val)) { @@ -497,14 +743,13 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); 
Value shiftResult = - V_LSHLREV_B32::create(builder, loc, vregType, shiftConst, rhs); + emitLshl(rhs, shiftConst, builder, loc, ctx); BitRange resultRange = rhsRange.shiftLeft(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); } } - Value mulResult = - V_MUL_LO_U32::create(builder, loc, vregType, lhs, rhs); + Value mulResult = emitMul(lhs, rhs, builder, loc, ctx); return ExprResult(mulResult, BitRange()); // Conservative: full range } @@ -532,7 +777,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { auto shiftConst = ConstantOp::create(builder, loc, shiftAmt, shiftAmount); Value shiftResult = - V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, lhs); + emitLshr(lhs, shiftConst, builder, loc, ctx); BitRange resultRange = lhsRange.shiftRight(shiftAmount); ctx.setBitRange(shiftResult, resultRange); return ExprResult(shiftResult, resultRange); @@ -561,13 +806,21 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { if (isPowerOf2(divisor)) { int64_t shiftAmount = log2(divisor); Value shiftConst = createImmConst(shiftAmount, builder, loc, ctx); - Value q = - V_LSHRREV_B32::create(builder, loc, vregType, shiftConst, lhs); + Value q = emitLshr(lhs, shiftConst, builder, loc, ctx); int64_t mask = divisor - 1; Value maskConst = createImmConst(mask, builder, loc, ctx); - Value rem = - V_AND_B32::create(builder, loc, vregType, maskConst, lhs); + Value rem = emitAnd(lhs, maskConst, builder, loc, ctx); Value zeroConst = createImmConst(0, builder, loc, ctx); + if (isScalarOrImm(rem)) { + S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), rem, + zeroConst); + auto sregType = ctx.createSRegType(); + Value result = + S_ADDC_U32::create(builder, loc, sregType, sregType, q, + zeroConst) + .getDst(); + return ExprResult(result, BitRange()); + } V_CMP_NE_U32::create(builder, loc, rem, zeroConst); Value oneConst = createImmConst(1, builder, loc, ctx); Value oneVgpr = 
V_MOV_B32::create(builder, loc, vregType, oneConst); @@ -601,8 +854,7 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { int64_t val = constRhs.getValue(); if (isPowerOf2(val)) { Value maskConst = createImmConst(val - 1, builder, loc, ctx); - Value andResult = - V_AND_B32::create(builder, loc, vregType, lhs, maskConst); + Value andResult = emitAnd(lhs, maskConst, builder, loc, ctx); BitRange resultRange = BitRange(0, log2(val) - 1); ctx.setBitRange(andResult, resultRange); return ExprResult(andResult, resultRange); @@ -612,16 +864,16 @@ LogicalResult handleAffineApply(Operation *op, TranslationContext &ctx) { if (val >= 2) { Value q = emitConstantUnsignedFloordiv(lhs, val, builder, loc, ctx); Value dConst = createImmConst(val, builder, loc, ctx); - Value qd = V_MUL_LO_U32::create(builder, loc, vregType, q, dConst); - Value rem = V_SUB_U32::create(builder, loc, vregType, lhs, qd); + Value qd = emitMul(q, dConst, builder, loc, ctx); + Value rem = emitSub(lhs, qd, builder, loc, ctx); return ExprResult(rem, BitRange()); } } // Symbolic divisor fallback: x mod d = x - floordiv(x, d) * d { Value q = emitUnsignedFloordiv(lhs, rhs, builder, loc, ctx); - Value qd = V_MUL_LO_U32::create(builder, loc, vregType, q, rhs); - Value rem = V_SUB_U32::create(builder, loc, vregType, lhs, qd); + Value qd = emitMul(q, rhs, builder, loc, ctx); + Value rem = emitSub(lhs, qd, builder, loc, ctx); return ExprResult(rem, BitRange()); } } diff --git a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp index fced3fda29..27a016d1ae 100644 --- a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp @@ -108,43 +108,135 @@ LogicalResult handleArithConstant(Operation *op, TranslationContext &ctx) { //===----------------------------------------------------------------------===// LogicalResult handleArithAddI(Operation *op, TranslationContext &ctx) { - return 
handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitAdd(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithSubI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitSub(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMulI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitMul(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMinSI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMinI32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMaxSI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return 
failure(); + auto result = emitMaxI32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMinUI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMinU32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithMaxUI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + auto result = emitMaxU32(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithAndI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitAnd(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithOrI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitOr(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithXorI(Operation *op, TranslationContext &ctx) { 
- return handleBinaryVALU(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitXor(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } //===----------------------------------------------------------------------===// @@ -152,11 +244,31 @@ LogicalResult handleArithXorI(Operation *op, TranslationContext &ctx) { //===----------------------------------------------------------------------===// LogicalResult handleArithShLI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALURev(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitLshl(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithShRUI(Operation *op, TranslationContext &ctx) { - return handleBinaryVALURev(op, ctx); + auto typedOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + std::optional lhs, rhs; + if (failed(validateBinaryOperands(typedOp, ctx, lhs, rhs))) + return failure(); + + auto result = emitLshr(*lhs, *rhs, builder, loc, ctx); + ctx.getMapper().mapValue(typedOp.getResult(), result); + return success(); } LogicalResult handleArithShRSI(Operation *op, TranslationContext &ctx) { @@ -197,11 +309,9 @@ LogicalResult handleArithDivUI(Operation *op, TranslationContext &ctx) { if (auto constOp = rhs->getDefiningOp()) { int64_t divisor = constOp.getValue(); if (isPowerOf2(divisor)) { - auto vregType = ctx.createVRegType(); int64_t shiftAmt = log2(divisor); Value shiftConst = createImmConst(shiftAmt, builder, loc, ctx); - auto result = - V_LSHRREV_B32::create(builder, loc, 
vregType, shiftConst, *lhs); + auto result = emitLshr(*lhs, shiftConst, builder, loc, ctx); ctx.getMapper().mapValue(divOp.getResult(), result); return success(); } @@ -235,7 +345,7 @@ LogicalResult handleArithRemUI(Operation *op, TranslationContext &ctx) { int64_t modulus = constOp.getValue(); if (isPowerOf2(modulus)) { Value maskConst = createImmConst(modulus - 1, builder, loc, ctx); - auto result = V_AND_B32::create(builder, loc, vregType, *lhs, maskConst); + auto result = emitAnd(*lhs, maskConst, builder, loc, ctx); ctx.getMapper().mapValue(remOp.getResult(), result); return success(); } @@ -341,90 +451,95 @@ LogicalResult handleArithCmpI(Operation *op, TranslationContext &ctx) { return failure(); } - // When both operands are scalar (SGPR or immediate), use S_CMP which - // produces an SGPR result directly. This is required for scf.if/scf.for - // conditions that feed waveasm.if/waveasm.condition (which require SGPRs). + // When both operands are scalar AND the result feeds a ConditionOp (loop + // back-edge), use S_CMP which writes SCC directly. ConditionOp reads SCC + // via s_cbranch_scc1. + // For other scalar consumers (arith.select), the fusion happens in + // handleArithSelect which emits s_cmp + s_cselect as a pair. bool lhsScalar = isSGPRType(lhs->getType()) || isImmType(lhs->getType()); bool rhsScalar = isSGPRType(rhs->getType()) || isImmType(rhs->getType()); + bool usedByCondition = cmpOp.getResult().hasOneUse() && + isa(*cmpOp.getResult().getUsers().begin()); - if (lhsScalar && rhsScalar) { + if (lhsScalar && rhsScalar && usedByCondition) { auto sregType = ctx.createSRegType(); - // S_CMP requires SGPR operands; move immediates to SGPRs first. 
Value lhsOp = *lhs; Value rhsOp = *rhs; if (isImmType(lhsOp.getType())) lhsOp = S_MOV_B32::create(builder, loc, sregType, lhsOp); if (isImmType(rhsOp.getType())) rhsOp = S_MOV_B32::create(builder, loc, sregType, rhsOp); - Value result; - switch (cmpOp.getPredicate()) { - case arith::CmpIPredicate::eq: - result = S_CMP_EQ_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ne: - result = S_CMP_NE_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::slt: - result = S_CMP_LT_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::sle: - result = S_CMP_LE_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::sgt: - result = S_CMP_GT_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::sge: - result = S_CMP_GE_I32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ult: - result = S_CMP_LT_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ule: - result = S_CMP_LE_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::ugt: - result = S_CMP_GT_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - case arith::CmpIPredicate::uge: - result = S_CMP_GE_U32::create(builder, loc, sregType, lhsOp, rhsOp); - break; - } + Value result = + emitScalarCmp(builder, loc, cmpOp.getPredicate(), lhsOp, rhsOp, ctx); ctx.getMapper().mapValue(cmpOp.getResult(), result); return success(); } - // Vector path: at least one operand is a VGPR, so use V_CMP which writes - // to VCC implicitly. + // When both operands are scalar and ALL users are arith.select ops with + // scalar true/false values, the fusion in handleArithSelect will emit + // s_cmp + s_cselect directly. Skip the VALU path entirely — map the + // result to a dummy so the mapper doesn't complain. 
+ if (lhsScalar && rhsScalar) { + bool allUsersFused = true; + for (auto &use : cmpOp.getResult().getUses()) { + auto *user = use.getOwner(); + if (auto selectUser = dyn_cast(user)) { + auto trueMap = ctx.getMapper().getMapped(selectUser.getTrueValue()); + auto falseMap = ctx.getMapper().getMapped(selectUser.getFalseValue()); + if (!trueMap || !falseMap || + !isScalarOrImm(*trueMap) || !isScalarOrImm(*falseMap)) { + allUsersFused = false; + break; + } + } else { + allUsersFused = false; + break; + } + } + if (allUsersFused) { + Value dummy = createImmConst(0, builder, loc, ctx); + ctx.getMapper().mapValue(cmpOp.getResult(), dummy); + return success(); + } + } + + // Vector path: use V_CMP which writes to VCC implicitly. + // V_CMP VOP3 can read one SGPR from the constant bus — only coerce + // to VGPR when BOTH operands are SGPR (two constant bus slots). + Value lhsV = *lhs; + Value rhsV = *rhs; + if (isSGPRType(lhsV.getType()) && isSGPRType(rhsV.getType())) + lhsV = ensureVGPR(builder, loc, ctx, lhsV); switch (cmpOp.getPredicate()) { case arith::CmpIPredicate::eq: - V_CMP_EQ_U32::create(builder, loc, *lhs, *rhs); + V_CMP_EQ_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::ne: - V_CMP_NE_U32::create(builder, loc, *lhs, *rhs); + V_CMP_NE_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::slt: - V_CMP_LT_I32::create(builder, loc, *lhs, *rhs); + V_CMP_LT_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::sle: - V_CMP_LE_I32::create(builder, loc, *lhs, *rhs); + V_CMP_LE_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::sgt: - V_CMP_GT_I32::create(builder, loc, *lhs, *rhs); + V_CMP_GT_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::sge: - V_CMP_GE_I32::create(builder, loc, *lhs, *rhs); + V_CMP_GE_I32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::ult: - V_CMP_LT_U32::create(builder, loc, *lhs, *rhs); + V_CMP_LT_U32::create(builder, 
loc, lhsV, rhsV); break; case arith::CmpIPredicate::ule: - V_CMP_LE_U32::create(builder, loc, *lhs, *rhs); + V_CMP_LE_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::ugt: - V_CMP_GT_U32::create(builder, loc, *lhs, *rhs); + V_CMP_GT_U32::create(builder, loc, lhsV, rhsV); break; case arith::CmpIPredicate::uge: - V_CMP_GE_U32::create(builder, loc, *lhs, *rhs); + V_CMP_GE_U32::create(builder, loc, lhsV, rhsV); break; } @@ -447,12 +562,68 @@ LogicalResult handleArithSelect(Operation *op, TranslationContext &ctx) { return op->emitError("operands not mapped"); } - // Restore the materialized boolean VGPR (0/1) back into VCC + // CmpI fusion: when condition is arith.cmpi with scalar operands and + // both select values are scalar, emit s_cmp + s_cselect directly. + // This bypasses the VGPR boolean from handleArithCmpI entirely. + if (isScalarOrImm(*trueVal) && isScalarOrImm(*falseVal)) { + Value condMLIR = selectOp.getCondition(); + if (auto cmpOp = condMLIR.getDefiningOp()) { + auto cmpLhs = ctx.getMapper().getMapped(cmpOp.getLhs()); + auto cmpRhs = ctx.getMapper().getMapped(cmpOp.getRhs()); + if (cmpLhs && cmpRhs && + isScalarOrImm(*cmpLhs) && isScalarOrImm(*cmpRhs)) { + auto sregType = ctx.createSRegType(); + Value lhsOp = *cmpLhs; + Value rhsOp = *cmpRhs; + if (isImmType(lhsOp.getType())) + lhsOp = S_MOV_B32::create(builder, loc, sregType, lhsOp); + if (isImmType(rhsOp.getType())) + rhsOp = S_MOV_B32::create(builder, loc, sregType, rhsOp); + emitScalarCmp(builder, loc, cmpOp.getPredicate(), lhsOp, rhsOp, ctx); + Value trueV = *trueVal; + Value falseV = *falseVal; + if (isImmType(trueV.getType())) + trueV = S_MOV_B32::create(builder, loc, sregType, trueV); + auto result = + S_CSELECT_B32::create(builder, loc, sregType, trueV, falseV); + ctx.getMapper().mapValue(selectOp.getResult(), result); + return success(); + } + } + } + + // Scalar path: when condition and both values are scalar, use + // s_cmp_lg_u32 + s_cselect_b32 (no VGPR needed). 
+ if (isScalarOrImm(*cond) && isScalarOrImm(*trueVal) && + isScalarOrImm(*falseVal)) { + Value zeroConst = createImmConst(0, builder, loc, ctx); + Value condV = *cond; + if (isImmType(condV.getType())) + condV = S_MOV_B32::create(builder, loc, ctx.createSRegType(), condV); + S_CMP_NE_U32::create(builder, loc, ctx.createSRegType(), condV, zeroConst); + auto sregType = ctx.createSRegType(); + Value trueV = *trueVal; + Value falseV = *falseVal; + if (isImmType(trueV.getType())) + trueV = S_MOV_B32::create(builder, loc, sregType, trueV); + auto result = S_CSELECT_B32::create(builder, loc, sregType, trueV, falseV); + ctx.getMapper().mapValue(selectOp.getResult(), result); + return success(); + } + + // Vector path: restore the materialized boolean VGPR (0/1) back into VCC. + // V_CMP can take one SGPR via the constant bus; zeroConst is an immediate + // (no bus slot), so an SGPR cond is fine without coercion. Value zeroConst = createImmConst(0, builder, loc, ctx); - V_CMP_NE_U32::create(builder, loc, *cond, zeroConst); + Value condV = *cond; + V_CMP_NE_U32::create(builder, loc, condV, zeroConst); + // v_cndmask_b32: coerce both to VGPR. VOP3 constant bus allows one SGPR + // but interacts with vcc slot; keeping both as VGPR is safest. 
+ Value trueVgpr = ensureVGPR(builder, loc, ctx, *trueVal); + Value falseVgpr = ensureVGPR(builder, loc, ctx, *falseVal); auto result = - V_CNDMASK_B32::create(builder, loc, vregType, *falseVal, *trueVal, *cond); + V_CNDMASK_B32::create(builder, loc, vregType, falseVgpr, trueVgpr, *cond); ctx.getMapper().mapValue(selectOp.getResult(), result); return success(); } diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index c3d2265f9c..bd6b8862fa 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -26,6 +26,7 @@ #include "waveasm/Transforms/TranslateFromMLIR.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LogicalResult.h" @@ -278,6 +279,263 @@ template inline To bitCast(From value) { return result; } +//===----------------------------------------------------------------------===// +// Scalar / VGPR Helpers +//===----------------------------------------------------------------------===// + +/// Check whether a value is scalar (SGPR) or an inline constant. +inline bool isScalarOrImm(mlir::Value v) { + return isSGPRType(v.getType()) || isImmType(v.getType()); +} + +/// If \p v is an SGPR, emit a v_mov_b32 to coerce it into a VGPR so it can +/// be used by VALU-only instructions (v_cvt_*, v_rcp_*, v_mul_f32, etc.). +/// Returns \p v unchanged when it is already a VGPR or immediate. 
+inline mlir::Value ensureVGPR(mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx, mlir::Value v) { + if (isSGPRType(v.getType())) { + auto vregType = ctx.createVRegType(); + return V_MOV_B32::create(builder, loc, vregType, v); + } + return v; +} + +//===----------------------------------------------------------------------===// +// Auto-select SALU/VALU Emit Helpers +//===----------------------------------------------------------------------===// + +/// Emit add: S_ADD_U32 when both operands are scalar, V_ADD_U32 otherwise. +/// Commutative: swaps to put immediate in src1 (SALU src0 must be SGPR). +inline mlir::Value emitAdd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_ADD_U32::create(builder, loc, sregType, sregType, a, b).getDst(); + } + auto vregType = ctx.createVRegType(); + return V_ADD_U32::create(builder, loc, vregType, a, b); +} + +/// Emit sub: S_SUB_U32 when both operands are scalar, V_SUB_U32 otherwise. +/// Not commutative: src0 (minuend) must be SGPR. +inline mlir::Value emitSub(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && isSGPRType(a.getType())) { + auto sregType = ctx.createSRegType(); + return S_SUB_U32::create(builder, loc, sregType, sregType, a, b).getDst(); + } + auto vregType = ctx.createVRegType(); + return V_SUB_U32::create(builder, loc, vregType, a, b); +} + +/// Emit mul: S_MUL_I32 when both operands are scalar, V_MUL_LO_U32 otherwise. +/// Commutative: swaps to put immediate in src1. 
+inline mlir::Value emitMul(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MUL_I32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MUL_LO_U32::create(builder, loc, vregType, a, b); +} + +/// Emit logical shift right: S_LSHR_B32 when scalar, V_LSHRREV_B32 otherwise. +/// Not commutative: src0 (value) must be SGPR. +inline mlir::Value emitLshr(mlir::Value value, mlir::Value shiftAmt, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(value) && isScalarOrImm(shiftAmt) && + isSGPRType(value.getType())) { + auto sregType = ctx.createSRegType(); + return S_LSHR_B32::create(builder, loc, sregType, value, shiftAmt); + } + auto vregType = ctx.createVRegType(); + value = ensureVGPR(builder, loc, ctx, value); + return V_LSHRREV_B32::create(builder, loc, vregType, shiftAmt, value); +} + +/// Emit logical shift left: S_LSHL_B32 when scalar, V_LSHLREV_B32 otherwise. +/// Not commutative: src0 (value) must be SGPR. +inline mlir::Value emitLshl(mlir::Value value, mlir::Value shiftAmt, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(value) && isScalarOrImm(shiftAmt) && + isSGPRType(value.getType())) { + auto sregType = ctx.createSRegType(); + return S_LSHL_B32::create(builder, loc, sregType, value, shiftAmt); + } + auto vregType = ctx.createVRegType(); + value = ensureVGPR(builder, loc, ctx, value); + return V_LSHLREV_B32::create(builder, loc, vregType, shiftAmt, value); +} + +/// Emit bitwise AND: S_AND_B32 when both scalar, V_AND_B32 otherwise. +/// Commutative: swaps to put immediate in src1. 
+inline mlir::Value emitAnd(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_AND_B32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_AND_B32::create(builder, loc, vregType, a, b); +} + +/// Emit bitwise OR: S_OR_B32 when both scalar, V_OR_B32 otherwise. +/// Commutative: swaps to put immediate in src1. +inline mlir::Value emitOr(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_OR_B32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_OR_B32::create(builder, loc, vregType, a, b); +} + +/// Emit bitwise XOR: S_XOR_B32 when both scalar, V_XOR_B32 otherwise. +/// Commutative: swaps to put immediate in src1. +inline mlir::Value emitXor(mlir::Value a, mlir::Value b, mlir::OpBuilder &builder, + mlir::Location loc, TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_XOR_B32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_XOR_B32::create(builder, loc, vregType, a, b); +} + +/// Emit mulhi: S_MUL_HI_U32 when both scalar, V_MUL_HI_U32 otherwise. +/// Commutative: swaps to put immediate in src1. 
+inline mlir::Value emitMulHi(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MUL_HI_U32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MUL_HI_U32::create(builder, loc, vregType, a, b); +} + +/// Emit signed max: S_MAX_I32 when both scalar, V_MAX_I32 otherwise. +/// Not commutative for SCC semantics (SCC = src0 was selected) but +/// the result value is commutative. +inline mlir::Value emitMaxI32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MAX_I32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MAX_I32::create(builder, loc, vregType, a, b); +} + +/// Emit signed min: S_MIN_I32 when both scalar, V_MIN_I32 otherwise. +inline mlir::Value emitMinI32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MIN_I32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MIN_I32::create(builder, loc, vregType, a, b); +} + +/// Emit unsigned max: S_MAX_U32 when both scalar, V_MAX_U32 otherwise. 
+inline mlir::Value emitMaxU32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MAX_U32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MAX_U32::create(builder, loc, vregType, a, b); +} + +/// Emit unsigned min: S_MIN_U32 when both scalar, V_MIN_U32 otherwise. +inline mlir::Value emitMinU32(mlir::Value a, mlir::Value b, + mlir::OpBuilder &builder, mlir::Location loc, + TranslationContext &ctx) { + if (isScalarOrImm(a) && isScalarOrImm(b) && + !(isImmType(a.getType()) && isImmType(b.getType()))) { + if (isImmType(a.getType())) + std::swap(a, b); + auto sregType = ctx.createSRegType(); + return S_MIN_U32::create(builder, loc, sregType, a, b); + } + auto vregType = ctx.createVRegType(); + return V_MIN_U32::create(builder, loc, vregType, a, b); +} + +//===----------------------------------------------------------------------===// +// Scalar Comparison Helper +//===----------------------------------------------------------------------===// + +/// Emit S_CMP_* for the given predicate with SGPR operands (sets SCC). +/// Both lhs and rhs must be SGPRs (not immediates) before calling. +/// Returns the S_CMP result value (phantom SCC). 
+inline mlir::Value emitScalarCmp(mlir::OpBuilder &builder, mlir::Location loc, + mlir::arith::CmpIPredicate pred, + mlir::Value lhs, mlir::Value rhs, + TranslationContext &ctx) { + auto sregType = ctx.createSRegType(); + switch (pred) { + case mlir::arith::CmpIPredicate::eq: + return S_CMP_EQ_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ne: + return S_CMP_NE_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::slt: + return S_CMP_LT_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::sle: + return S_CMP_LE_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::sgt: + return S_CMP_GT_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::sge: + return S_CMP_GE_I32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ult: + return S_CMP_LT_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ule: + return S_CMP_LE_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::ugt: + return S_CMP_GT_U32::create(builder, loc, sregType, lhs, rhs); + case mlir::arith::CmpIPredicate::uge: + return S_CMP_GE_U32::create(builder, loc, sregType, lhs, rhs); + } + llvm_unreachable("unhandled CmpIPredicate"); +} + //===----------------------------------------------------------------------===// // Error Handling Helpers //===----------------------------------------------------------------------===// diff --git a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp index 785808bdba..70cb50b248 100644 --- a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp @@ -301,7 +301,14 @@ LogicalResult handleMemRefStore(Operation *op, TranslationContext &ctx) { voffset = ConstantOp::create(builder, loc, immType, 0); } - BUFFER_STORE_DWORD::create(builder, loc, *data, 
srd, voffset); + Value storeData = *data; + if (isAGPRType(storeData.getType())) { + auto vregType = ctx.createVRegType(); + storeData = + V_ACCVGPR_READ_B32::create(builder, loc, vregType, storeData); + } + + BUFFER_STORE_DWORD::create(builder, loc, storeData, srd, voffset); } return success(); diff --git a/waveasm/test/Dialect/arith-pseudo-ops.mlir b/waveasm/test/Dialect/arith-pseudo-ops.mlir new file mode 100644 index 0000000000..e2b3450e98 --- /dev/null +++ b/waveasm/test/Dialect/arith-pseudo-ops.mlir @@ -0,0 +1,49 @@ +// RUN: waveasm-translate %s | waveasm-translate | FileCheck %s +// Verify that generic arithmetic pseudo-ops roundtrip through parsing and printing. + +// CHECK: waveasm.program @test_arith_pseudo_ops +waveasm.program @test_arith_pseudo_ops + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %c42 = waveasm.constant 42 : !waveasm.imm<42> + + // CHECK: waveasm.arith.add + %add_vv = waveasm.arith.add %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK: waveasm.arith.add + %add_sv = waveasm.arith.add %s0, %v0 : (!waveasm.sreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK: waveasm.arith.add + %add_ss = waveasm.arith.add %s0, %s0 : (!waveasm.sreg, !waveasm.sreg) -> !waveasm.sreg + + // CHECK: waveasm.arith.mul + %mul = waveasm.arith.mul %v0, %c42 : (!waveasm.vreg, !waveasm.imm<42>) -> !waveasm.vreg + + // CHECK: waveasm.arith.cmp eq + %cmp_eq = waveasm.arith.cmp eq, %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK: waveasm.arith.cmp slt + %cmp_slt = waveasm.arith.cmp slt, %s0, %c42 : (!waveasm.sreg, !waveasm.imm<42>) -> !waveasm.sreg + + // CHECK: waveasm.arith.cmp ult + %cmp_ult = waveasm.arith.cmp ult, %v0, %s0 : (!waveasm.vreg, !waveasm.sreg) -> !waveasm.vreg + + // CHECK: waveasm.arith.select + %sel = waveasm.arith.select %cmp_eq, %v0, 
%add_vv : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK: waveasm.arith.trunc + %s_wide = waveasm.precolored.sreg 4, 2 : !waveasm.sreg<2, 2> + %trunc = waveasm.arith.trunc %s_wide : (!waveasm.sreg<2, 2>) -> !waveasm.sreg + + // CHECK: waveasm.arith.sext + %sext = waveasm.arith.sext %s0 : (!waveasm.sreg) -> !waveasm.sreg<2, 2> + + // CHECK: waveasm.arith.zext + %zext = waveasm.arith.zext %v0 : (!waveasm.vreg) -> !waveasm.vreg<2> + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-bitwise.mlir b/waveasm/test/Transforms/arith-legalization-bitwise.mlir new file mode 100644 index 0000000000..da3f9fda7b --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-bitwise.mlir @@ -0,0 +1,130 @@ +// RUN: waveasm-translate %s --waveasm-arith-legalization | FileCheck %s +// Verify that arith.or and arith.and are lowered to concrete SALU/VALU ops. + +//===----------------------------------------------------------------------===// +// i32 OR +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_or_i32 +waveasm.program @test_or_i32 + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %s1 = waveasm.precolored.sreg 1 : !waveasm.sreg + %c42 = waveasm.constant 42 : !waveasm.imm<42> + + // VGPR | VGPR -> v_or_b32. + // CHECK: waveasm.v_or_b32 %{{.*}}, %{{.*}} : !waveasm.vreg, !waveasm.vreg + %or_vv = waveasm.arith.or %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // SGPR | SGPR -> s_or_b32. + // CHECK: waveasm.s_or_b32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.sreg + %or_ss = waveasm.arith.or %s0, %s1 : (!waveasm.sreg, !waveasm.sreg) -> !waveasm.sreg + + // SGPR | VGPR -> v_mov_b32 + v_or_b32. 
+ // CHECK: waveasm.v_mov_b32 + // CHECK: waveasm.v_or_b32 + %or_sv = waveasm.arith.or %s0, %v0 : (!waveasm.sreg, !waveasm.vreg) -> !waveasm.vreg + + // SGPR | imm -> s_or_b32. + // CHECK: waveasm.s_or_b32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.imm<42> + %or_si = waveasm.arith.or %s0, %c42 : (!waveasm.sreg, !waveasm.imm<42>) -> !waveasm.sreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// i32 AND +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_and_i32 +waveasm.program @test_and_i32 + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %s1 = waveasm.precolored.sreg 1 : !waveasm.sreg + + // VGPR & VGPR -> v_and_b32. + // CHECK: waveasm.v_and_b32 %{{.*}}, %{{.*}} : !waveasm.vreg, !waveasm.vreg + %and_vv = waveasm.arith.and %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // SGPR & SGPR -> s_and_b32. + // CHECK: waveasm.s_and_b32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.sreg + %and_ss = waveasm.arith.and %s0, %s1 : (!waveasm.sreg, !waveasm.sreg) -> !waveasm.sreg + + // SGPR & VGPR -> v_mov_b32 + v_and_b32. + // CHECK: waveasm.v_mov_b32 + // CHECK: waveasm.v_and_b32 + %and_sv = waveasm.arith.and %s0, %v0 : (!waveasm.sreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK-NOT: waveasm.arith. 
+ waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// i64 OR +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_or_i64_salu +waveasm.program @test_or_i64_salu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + %b = waveasm.precolored.sreg 2, 2 : !waveasm.sreg<2, 2> + + // Native s_or_b64 for i64 SALU. + // CHECK: waveasm.s_or_b64 %{{.*}}, %{{.*}} : !waveasm.sreg<2, 2>, !waveasm.sreg<2, 2> + %or = waveasm.arith.or %a, %b : (!waveasm.sreg<2, 2>, !waveasm.sreg<2, 2>) -> !waveasm.sreg<2, 2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_or_i64_valu +waveasm.program @test_or_i64_valu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %v1 = waveasm.precolored.vreg 1 : !waveasm.vreg + %va = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + %vb = waveasm.pack %v1, %v1 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + + // Native v_or_b64 for i64 VALU. + // CHECK: waveasm.v_or_b64 %{{.*}}, %{{.*}} : !waveasm.vreg<2>, !waveasm.vreg<2> + %or = waveasm.arith.or %va, %vb : (!waveasm.vreg<2>, !waveasm.vreg<2>) -> !waveasm.vreg<2> + + // CHECK-NOT: waveasm.arith. 
+ waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// i64 AND +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_and_i64_salu +waveasm.program @test_and_i64_salu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + %b = waveasm.precolored.sreg 2, 2 : !waveasm.sreg<2, 2> + + // Native s_and_b64 for i64 SALU. + // CHECK: waveasm.s_and_b64 %{{.*}}, %{{.*}} : !waveasm.sreg<2, 2>, !waveasm.sreg<2, 2> + %and = waveasm.arith.and %a, %b : (!waveasm.sreg<2, 2>, !waveasm.sreg<2, 2>) -> !waveasm.sreg<2, 2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-error-cmp.mlir b/waveasm/test/Transforms/arith-legalization-error-cmp.mlir new file mode 100644 index 0000000000..26619de7fb --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-error-cmp.mlir @@ -0,0 +1,16 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that cmp with unsupported width produces a diagnostic. 
+ +waveasm.program @test_cmp_unsupported_width + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 4 : !waveasm.sreg<4, 4> + %s_wide2 = waveasm.precolored.sreg 4, 4 : !waveasm.sreg<4, 4> + + // CHECK: unsupported operand width (expected i32 or i64, got 4 dwords) + %cmp = waveasm.arith.cmp eq, %s_wide, %s_wide2 : (!waveasm.sreg<4, 4>, !waveasm.sreg<4, 4>) -> !waveasm.sreg + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-error-mul.mlir b/waveasm/test/Transforms/arith-legalization-error-mul.mlir new file mode 100644 index 0000000000..5b1ed7eb1d --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-error-mul.mlir @@ -0,0 +1,16 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that mul with unsupported width produces a diagnostic. + +waveasm.program @test_mul_unsupported_width + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 4 : !waveasm.sreg<4, 4> + %s_wide2 = waveasm.precolored.sreg 4, 4 : !waveasm.sreg<4, 4> + + // CHECK: unsupported operand width (expected i32 or i64, got 4 dwords) + %mul = waveasm.arith.mul %s_wide, %s_wide2 : (!waveasm.sreg<4, 4>, !waveasm.sreg<4, 4>) -> !waveasm.sreg<4, 4> + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-error-select.mlir b/waveasm/test/Transforms/arith-legalization-error-select.mlir new file mode 100644 index 0000000000..30ccd9468a --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-error-select.mlir @@ -0,0 +1,17 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that select with unsupported width produces a diagnostic. 
+ +waveasm.program @test_select_unsupported_width + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 4 : !waveasm.sreg<4, 4> + %s_wide2 = waveasm.precolored.sreg 4, 4 : !waveasm.sreg<4, 4> + %cond = waveasm.precolored.sreg 8 : !waveasm.sreg + + // CHECK: unsupported operand width (expected i32 or i64, got 4 dwords) + %sel = waveasm.arith.select %cond, %s_wide, %s_wide2 : (!waveasm.sreg<4, 4>, !waveasm.sreg<4, 4>, !waveasm.sreg) -> !waveasm.sreg<4, 4> + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-error-sext.mlir b/waveasm/test/Transforms/arith-legalization-error-sext.mlir new file mode 100644 index 0000000000..90698ca1f4 --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-error-sext.mlir @@ -0,0 +1,15 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that sext from non-i32 source produces a diagnostic. + +waveasm.program @test_sext_bad_source + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + + // CHECK: sext source must be i32 (got 2 dwords) + %sext = waveasm.arith.sext %s_wide : (!waveasm.sreg<2, 2>) -> !waveasm.sreg<4, 4> + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-error-trunc.mlir b/waveasm/test/Transforms/arith-legalization-error-trunc.mlir new file mode 100644 index 0000000000..981a045c7a --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-error-trunc.mlir @@ -0,0 +1,15 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that trunc from unsupported width produces a diagnostic. 
+ +waveasm.program @test_trunc_bad_width + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 4 : !waveasm.sreg<4, 4> + + // CHECK: unsupported operand width (expected i32 or i64, got 4 dwords) + %trunc = waveasm.arith.trunc %s_wide : (!waveasm.sreg<4, 4>) -> !waveasm.sreg + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-error-zext.mlir b/waveasm/test/Transforms/arith-legalization-error-zext.mlir new file mode 100644 index 0000000000..eb101d1e97 --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-error-zext.mlir @@ -0,0 +1,15 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that zext from non-i32 source produces a diagnostic. + +waveasm.program @test_zext_bad_source + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + + // CHECK: zext source must be i32 (got 2 dwords) + %zext = waveasm.arith.zext %s_wide : (!waveasm.sreg<2, 2>) -> !waveasm.sreg<4, 4> + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-errors.mlir b/waveasm/test/Transforms/arith-legalization-errors.mlir new file mode 100644 index 0000000000..721a1f378b --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-errors.mlir @@ -0,0 +1,17 @@ +// RUN: not waveasm-translate %s --waveasm-arith-legalization 2>&1 | FileCheck %s +// Verify that unsupported operand widths produce diagnostics. + +// Unsupported width: sreg<4> is not i32 or i64. 
+waveasm.program @test_add_unsupported_width + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s_wide = waveasm.precolored.sreg 0, 4 : !waveasm.sreg<4, 4> + %s_wide2 = waveasm.precolored.sreg 4, 4 : !waveasm.sreg<4, 4> + + // CHECK: unsupported operand width (expected i32 or i64, got 4 dwords) + %add = waveasm.arith.add %s_wide, %s_wide2 : (!waveasm.sreg<4, 4>, !waveasm.sreg<4, 4>) -> !waveasm.sreg<4, 4> + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/arith-legalization-i64-cmp-vcc-clobber.mlir b/waveasm/test/Transforms/arith-legalization-i64-cmp-vcc-clobber.mlir new file mode 100644 index 0000000000..dc8f3e804e --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization-i64-cmp-vcc-clobber.mlir @@ -0,0 +1,51 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm --waveasm-arith-legalization | FileCheck %s +// Verify that i64 compare results survive VCC clobbering by intervening +// i64 add operations (v_add_co_u32). The compare must materialize its +// boolean result to a VGPR, and the select must re-establish VCC from +// that VGPR right before v_cndmask_b32. + +// CHECK-LABEL: waveasm.program @test_i64_cmp_vcc_clobber__waveasm + +// The i64 ordered compare decomposes into hi/lo comparisons. The final +// step selects between hi and lo results based on hi-equality. This +// v_cmp_eq + v_cndmask_b32 produces the materialized boolean. +// CHECK: waveasm.v_cmp_eq_i32 +// CHECK-NEXT: %[[VCC_PH:.*]] = waveasm.constant +// CHECK-NEXT: %[[BOOL:.*]] = waveasm.v_cndmask_b32 %{{.*}}, %{{.*}}, %[[VCC_PH]] + +// Intervening i64 add legalizes to v_add_co_u32 which clobbers VCC. +// CHECK: waveasm.v_add_co_u32 +// CHECK: waveasm.v_addc_co_u32 + +// The select re-establishes VCC from the materialized boolean +// right before v_cndmask_b32. 
+// CHECK: waveasm.v_cmp_ne_i32 %[[BOOL]], %{{.*}} +// CHECK: waveasm.v_cndmask_b32 +// CHECK: waveasm.v_cndmask_b32 + +// No arith pseudo-ops should remain. +// CHECK-NOT: waveasm.arith. +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test_i64_cmp_vcc_clobber(%arg0: !llvm.ptr, %arg1: i64) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %c100 = llvm.mlir.constant(100 : i64) : i64 + %c2 = llvm.mlir.constant(2 : i64) : i64 + %tid = rocdl.workitem.id.x range : i32 + %tid64 = llvm.sext %tid : i32 to i64 + + // i64 compare — result must survive across the add below. + %cmp = llvm.icmp "slt" %tid64, %arg1 : i64 + + // i64 add — legalizes to v_add_co_u32 + v_addc_co_u32, clobbering VCC. + %sum = llvm.add %tid64, %c2 : i64 + + // i64 select — must see the compare result, not the stale carry from add. + %sel = llvm.select %cmp, %tid64, %c100 : i1, i64 + %trunc = llvm.trunc %sel : i64 to i32 + + // Use %sum to prevent DCE. + %sum32 = llvm.trunc %sum : i64 to i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/arith-legalization.mlir b/waveasm/test/Transforms/arith-legalization.mlir new file mode 100644 index 0000000000..a488c52e14 --- /dev/null +++ b/waveasm/test/Transforms/arith-legalization.mlir @@ -0,0 +1,413 @@ +// RUN: waveasm-translate %s --waveasm-arith-legalization | FileCheck %s +// Verify that generic arithmetic pseudo-ops are lowered to concrete SALU/VALU ops. + +// CHECK-LABEL: waveasm.program @test_add +waveasm.program @test_add + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %s1 = waveasm.precolored.sreg 1 : !waveasm.sreg + %c42 = waveasm.constant 42 : !waveasm.imm<42> + + // VGPR + VGPR -> v_add_u32. 
+ // CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.vreg, !waveasm.vreg + %add_vv = waveasm.arith.add %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // SGPR + SGPR -> s_add_u32. + // CHECK: waveasm.s_add_u32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.sreg + %add_ss = waveasm.arith.add %s0, %s1 : (!waveasm.sreg, !waveasm.sreg) -> !waveasm.sreg + + // SGPR + VGPR -> v_add_u32 (SGPR broadcast via v_mov_b32). + // CHECK: waveasm.v_mov_b32 + // CHECK: waveasm.v_add_u32 + %add_sv = waveasm.arith.add %s0, %v0 : (!waveasm.sreg, !waveasm.vreg) -> !waveasm.vreg + + // SGPR + imm -> s_add_u32. + // CHECK: waveasm.s_add_u32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.imm<42> + %add_si = waveasm.arith.add %s0, %c42 : (!waveasm.sreg, !waveasm.imm<42>) -> !waveasm.sreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_mul +waveasm.program @test_mul + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %s1 = waveasm.precolored.sreg 1 : !waveasm.sreg + %c42 = waveasm.constant 42 : !waveasm.imm<42> + + // VGPR * imm -> v_mul_lo_u32. + // CHECK: waveasm.v_mul_lo_u32 + %mul_vi = waveasm.arith.mul %v0, %c42 : (!waveasm.vreg, !waveasm.imm<42>) -> !waveasm.vreg + + // SGPR * SGPR -> s_mul_i32. + // CHECK: waveasm.s_mul_i32 + %mul_ss = waveasm.arith.mul %s0, %s1 : (!waveasm.sreg, !waveasm.sreg) -> !waveasm.sreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_cmp +waveasm.program @test_cmp + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %c10 = waveasm.constant 10 : !waveasm.imm<10> + + // VGPR, VGPR -> v_cmp_lt_i32 (sets VCC). 
+ // CHECK: waveasm.v_cmp_lt_i32 + %cmp_vv = waveasm.arith.cmp slt, %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // SGPR, imm -> s_cmp_lt_u32 (sets SCC). + // CHECK: waveasm.s_cmp_lt_u32 + %cmp_si = waveasm.arith.cmp ult, %s0, %c10 : (!waveasm.sreg, !waveasm.imm<10>) -> !waveasm.sreg + + // SGPR, VGPR -> v_mov_b32 (constant bus) + v_cmp_eq_i32. + // CHECK: waveasm.v_mov_b32 + // CHECK: waveasm.v_cmp_eq_i32 + %cmp_sv = waveasm.arith.cmp eq, %s0, %v0 : (!waveasm.sreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_select +waveasm.program @test_select + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %v1 = waveasm.precolored.vreg 1 : !waveasm.vreg + + // CHECK: waveasm.v_cmp_lt_i32 + %cmp = waveasm.arith.cmp slt, %v0, %v1 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK: waveasm.v_cndmask_b32 + %sel = waveasm.arith.select %cmp, %v0, %v1 : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_select_sgpr_cond +waveasm.program @test_select_sgpr_cond + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %v1 = waveasm.precolored.vreg 1 : !waveasm.vreg + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + + // SGPR condition: must be moved to VGPR and compared to set VCC. + // CHECK: waveasm.v_mov_b32 + // CHECK: waveasm.v_cmp_ne_i32 + // CHECK: waveasm.v_cndmask_b32 + %sel = waveasm.arith.select %s0, %v0, %v1 : (!waveasm.vreg, !waveasm.vreg, !waveasm.sreg) -> !waveasm.vreg + + // CHECK-NOT: waveasm.arith. 
+ waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_wide_narrowing +waveasm.program @test_wide_narrowing + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %s_wide = waveasm.precolored.sreg 4, 2 : !waveasm.sreg<2, 2> + + // Trunc of precolored wide SGPR -> precolored lo half. + // CHECK: waveasm.precolored.sreg 4 : !waveasm.sreg + %trunc = waveasm.arith.trunc %s_wide : (!waveasm.sreg<2, 2>) -> !waveasm.sreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// i64 tests +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_add_i64_salu +waveasm.program @test_add_i64_salu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + %b = waveasm.precolored.sreg 2, 2 : !waveasm.sreg<2, 2> + + // Precolored SGPR pairs: split into lo/hi via precolored.sreg at known indices. + // CHECK-DAG: [[A_LO:%.*]] = waveasm.precolored.sreg 0 : !waveasm.sreg + // CHECK-DAG: [[A_HI:%.*]] = waveasm.precolored.sreg 1 : !waveasm.sreg + // CHECK-DAG: [[B_LO:%.*]] = waveasm.precolored.sreg 2 : !waveasm.sreg + // CHECK-DAG: [[B_HI:%.*]] = waveasm.precolored.sreg 3 : !waveasm.sreg + // Carry chain: s_add_u32 (lo) then s_addc_u32 (hi). + // CHECK: [[LO:%.*]], %{{.*}} = waveasm.s_add_u32 [[A_LO]], [[B_LO]] + // CHECK-NEXT: [[HI:%.*]], %{{.*}} = waveasm.s_addc_u32 [[A_HI]], [[B_HI]] + // CHECK: waveasm.pack [[LO]], [[HI]] + %add = waveasm.arith.add %a, %b : (!waveasm.sreg<2, 2>, !waveasm.sreg<2, 2>) -> !waveasm.sreg<2, 2> + + // CHECK-NOT: waveasm.arith. 
+ waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_add_i64_valu +waveasm.program @test_add_i64_valu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.vreg 0 : !waveasm.vreg + %b = waveasm.precolored.vreg 1 : !waveasm.vreg + %va = waveasm.pack %a, %a : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + %vb = waveasm.pack %b, %b : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + + // VGPR i64 add: splitI64 looks through pack, uses originals directly. + // CHECK: [[A:%.*]] = waveasm.precolored.vreg 0 + // CHECK: [[B:%.*]] = waveasm.precolored.vreg 1 + // CHECK: [[ADD_LO:%.*]], %{{.*}} = waveasm.v_add_co_u32 [[A]], [[B]] + // CHECK-NEXT: [[ADD_HI:%.*]], %{{.*}} = waveasm.v_addc_co_u32 [[A]], [[B]] + // CHECK: waveasm.pack [[ADD_LO]], [[ADD_HI]] + %add = waveasm.arith.add %va, %vb : (!waveasm.vreg<2>, !waveasm.vreg<2>) -> !waveasm.vreg<2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_mul_i64_salu +waveasm.program @test_mul_i64_salu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + %b = waveasm.precolored.sreg 2, 2 : !waveasm.sreg<2, 2> + + // Schoolbook i64 multiply: lo = mul_lo(a_lo, b_lo). + // hi = mul_hi(a_lo, b_lo) + mul_lo(a_lo, b_hi) + mul_lo(a_hi, b_lo). 
+ // CHECK-DAG: [[A_LO:%.*]] = waveasm.precolored.sreg 0 : !waveasm.sreg + // CHECK-DAG: [[A_HI:%.*]] = waveasm.precolored.sreg 1 : !waveasm.sreg + // CHECK-DAG: [[B_LO:%.*]] = waveasm.precolored.sreg 2 : !waveasm.sreg + // CHECK-DAG: [[B_HI:%.*]] = waveasm.precolored.sreg 3 : !waveasm.sreg + // CHECK: [[RES_LO:%.*]] = waveasm.s_mul_i32 [[A_LO]], [[B_LO]] + // CHECK: [[HI_PART:%.*]] = waveasm.s_mul_hi_u32 [[A_LO]], [[B_LO]] + // CHECK: [[CROSS1:%.*]] = waveasm.s_mul_i32 [[A_LO]], [[B_HI]] + // CHECK: [[CROSS2:%.*]] = waveasm.s_mul_i32 [[A_HI]], [[B_LO]] + // CHECK: [[HI_TEMP:%.*]], %{{.*}} = waveasm.s_add_u32 [[HI_PART]], [[CROSS1]] + // CHECK: [[RES_HI:%.*]], %{{.*}} = waveasm.s_add_u32 [[HI_TEMP]], [[CROSS2]] + // CHECK: waveasm.pack [[RES_LO]], [[RES_HI]] + %mul = waveasm.arith.mul %a, %b : (!waveasm.sreg<2, 2>, !waveasm.sreg<2, 2>) -> !waveasm.sreg<2, 2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_mul_i64_valu +waveasm.program @test_mul_i64_valu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.vreg 0 : !waveasm.vreg + %b = waveasm.precolored.vreg 1 : !waveasm.vreg + %va = waveasm.pack %a, %a : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + %vb = waveasm.pack %b, %b : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + + // VALU schoolbook: lo = mul_lo(a_lo, b_lo). + // hi = mul_hi(a_lo, b_lo) + mul_lo(a_lo, b_hi) + mul_lo(a_hi, b_lo). + // splitI64 looks through pack, so all operands trace back to the two vregs. 
+ // CHECK: [[A:%.*]] = waveasm.precolored.vreg 0 + // CHECK: [[B:%.*]] = waveasm.precolored.vreg 1 + // CHECK: [[RES_LO:%.*]] = waveasm.v_mul_lo_u32 [[A]], [[B]] + // CHECK: [[HI_PART:%.*]] = waveasm.v_mul_hi_u32 [[A]], [[B]] + // CHECK: [[CROSS1:%.*]] = waveasm.v_mul_lo_u32 [[A]], [[B]] + // CHECK: [[CROSS2:%.*]] = waveasm.v_mul_lo_u32 [[A]], [[B]] + // CHECK: [[RES_HI:%.*]] = waveasm.v_add3_u32 [[HI_PART]], [[CROSS1]], [[CROSS2]] + // CHECK: waveasm.pack [[RES_LO]], [[RES_HI]] + %mul = waveasm.arith.mul %va, %vb : (!waveasm.vreg<2>, !waveasm.vreg<2>) -> !waveasm.vreg<2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_sext +waveasm.program @test_sext + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + + // SGPR sext: hi = s_ashr_i32(lo, 31), then merge. + // CHECK: waveasm.s_ashr_i32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.imm<31> + // CHECK: waveasm.pack + %sext_s = waveasm.arith.sext %s0 : (!waveasm.sreg) -> !waveasm.sreg<2, 2> + + // VGPR sext: hi = v_ashrrev_i32(31, lo), then pack. + // CHECK: waveasm.v_ashrrev_i32 %{{.*}}, %{{.*}} : !waveasm.imm<31>, !waveasm.vreg + // CHECK: waveasm.pack + %sext_v = waveasm.arith.sext %v0 : (!waveasm.vreg) -> !waveasm.vreg<2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_zext +waveasm.program @test_zext + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + + // SGPR zext: hi = s_mov_b32(0), then merge. 
+ // CHECK: waveasm.s_mov_b32 %{{.*}} : !waveasm.imm<0> + // CHECK: waveasm.pack + %zext_s = waveasm.arith.zext %s0 : (!waveasm.sreg) -> !waveasm.sreg<2, 2> + + // VGPR zext: hi = v_mov_b32(0), then pack. + // CHECK: waveasm.v_mov_b32 %{{.*}} : !waveasm.imm<0> + // CHECK: waveasm.pack + %zext_v = waveasm.arith.zext %v0 : (!waveasm.vreg) -> !waveasm.vreg<2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_cmp_i64_eq +waveasm.program @test_cmp_i64_eq + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + %b = waveasm.precolored.sreg 2, 2 : !waveasm.sreg<2, 2> + + // i64 eq: XOR each half, OR, compare to 0. + // CHECK-DAG: [[A_LO:%.*]] = waveasm.precolored.sreg 0 : !waveasm.sreg + // CHECK-DAG: [[A_HI:%.*]] = waveasm.precolored.sreg 1 : !waveasm.sreg + // CHECK-DAG: [[B_LO:%.*]] = waveasm.precolored.sreg 2 : !waveasm.sreg + // CHECK-DAG: [[B_HI:%.*]] = waveasm.precolored.sreg 3 : !waveasm.sreg + // CHECK: [[XOR_LO:%.*]] = waveasm.s_xor_b32 [[A_LO]], [[B_LO]] + // CHECK: [[XOR_HI:%.*]] = waveasm.s_xor_b32 [[A_HI]], [[B_HI]] + // CHECK: [[COMBINED:%.*]] = waveasm.s_or_b32 [[XOR_LO]], [[XOR_HI]] + // CHECK: waveasm.s_cmp_eq_i32 [[COMBINED]], %{{.*}} + %cmp = waveasm.arith.cmp eq, %a, %b : (!waveasm.sreg<2, 2>, !waveasm.sreg<2, 2>) -> !waveasm.sreg + + // CHECK-NOT: waveasm.arith. 
+ waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_cmp_i64_slt_valu +waveasm.program @test_cmp_i64_slt_valu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %v1 = waveasm.precolored.vreg 1 : !waveasm.vreg + %va = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + %vb = waveasm.pack %v1, %v1 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + + // Ordered i64 slt (VALU): hi signed + lo unsigned + select on hi equality. + // The compare materializes the boolean result to a VGPR (no final + // v_cmp_ne_i32 here — the consumer re-establishes VCC when needed). + // CHECK: waveasm.v_cmp_lt_i32 + // CHECK: waveasm.v_cndmask_b32 + // CHECK: waveasm.v_cmp_lt_u32 + // CHECK: waveasm.v_cndmask_b32 + // CHECK: waveasm.v_cmp_eq_i32 + // CHECK: waveasm.v_cndmask_b32 + %cmp = waveasm.arith.cmp slt, %va, %vb : (!waveasm.vreg<2>, !waveasm.vreg<2>) -> !waveasm.vreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_cmp_i64_slt_salu +waveasm.program @test_cmp_i64_slt_salu + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %a = waveasm.precolored.sreg 0, 2 : !waveasm.sreg<2, 2> + %b = waveasm.precolored.sreg 2, 2 : !waveasm.sreg<2, 2> + + // Ordered i64 slt (SALU): hiLt | (hiEq & loLt). 
+ // CHECK-DAG: [[A_LO:%.*]] = waveasm.precolored.sreg 0 : !waveasm.sreg + // CHECK-DAG: [[A_HI:%.*]] = waveasm.precolored.sreg 1 : !waveasm.sreg + // CHECK-DAG: [[B_LO:%.*]] = waveasm.precolored.sreg 2 : !waveasm.sreg + // CHECK-DAG: [[B_HI:%.*]] = waveasm.precolored.sreg 3 : !waveasm.sreg + // CHECK: [[HI_LT:%.*]] = waveasm.s_cmp_lt_i32 [[A_HI]], [[B_HI]] + // CHECK: [[HI_EQ:%.*]] = waveasm.s_cmp_eq_i32 [[A_HI]], [[B_HI]] + // CHECK: [[LO_LT:%.*]] = waveasm.s_cmp_lt_u32 [[A_LO]], [[B_LO]] + // CHECK: [[EQ_AND_LO:%.*]] = waveasm.s_and_b32 [[HI_EQ]], [[LO_LT]] + // CHECK: waveasm.s_or_b32 [[HI_LT]], [[EQ_AND_LO]] + %cmp = waveasm.arith.cmp slt, %a, %b : (!waveasm.sreg<2, 2>, !waveasm.sreg<2, 2>) -> !waveasm.sreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_select_i64 +waveasm.program @test_select_i64 + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %v1 = waveasm.precolored.vreg 1 : !waveasm.vreg + %va = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + %vb = waveasm.pack %v1, %v1 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + + // Set up VCC with a compare first. + // CHECK: waveasm.v_cmp_lt_i32 + %cmp = waveasm.arith.cmp slt, %v0, %v1 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg + + // i64 select: split, select each half, pack. + // splitI64 looks through pack. v_cndmask_b32 arg order: (false, true, vcc). + // (No v_cmp_ne_i32 here because the cond is from an i32 compare, + // which is already a VCC placeholder, not a materialized VGPR boolean.) 
+ // CHECK: [[SEL_LO:%.*]] = waveasm.v_cndmask_b32 [[VB:%.*]], [[VA:%.*]], %{{.*}} + // CHECK: [[SEL_HI:%.*]] = waveasm.v_cndmask_b32 [[VB]], [[VA]], %{{.*}} + // CHECK: waveasm.pack [[SEL_LO]], [[SEL_HI]] + %sel = waveasm.arith.select %cmp, %va, %vb : (!waveasm.vreg<2>, !waveasm.vreg<2>, !waveasm.vreg) -> !waveasm.vreg<2> + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} + +// CHECK-LABEL: waveasm.program @test_trunc_i64 +waveasm.program @test_trunc_i64 + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi + attributes {vgprs = 32 : i64, sgprs = 16 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg + %va = waveasm.pack %v0, %v0 : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2> + + // Trunc i64 VGPR from pack -> returns pack input directly (no extract). + // CHECK-NOT: waveasm.extract + %trunc = waveasm.arith.trunc %va : (!waveasm.vreg<2>) -> !waveasm.vreg + + // CHECK-NOT: waveasm.arith. + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/linear-scan-if-feeds-loop.mlir b/waveasm/test/Transforms/linear-scan-if-feeds-loop.mlir new file mode 100644 index 0000000000..b066fc3a9e --- /dev/null +++ b/waveasm/test/Transforms/linear-scan-if-feeds-loop.mlir @@ -0,0 +1,68 @@ +// RUN: waveasm-translate --waveasm-linear-scan %s 2>&1 | FileCheck %s +// +// Test: waveasm.if results feeding waveasm.loop init args. +// +// When an if-op result is used as a loop init arg, the register allocator +// ties the if result and the loop block arg to the same physical register. +// The LinearScanPass must set the if-op result type from the allocation +// mapping (not from the then-yield operand type), because the then-yield +// operand may have been allocated to a different physical register (e.g. +// from an inner loop). Without the fix, the LoopLikeOpInterface verifier +// rejects the mismatch between init arg and region iter_arg types. 
+ +// CHECK-LABEL: waveasm.program @if_result_feeds_loop + +// The if-op result must get its physical register from the allocation +// mapping, which ties it to the loop block arg. Capture the if-result +// physical register index as IF_REG. +// CHECK: waveasm.if {{.*}} -> !waveasm.pareg<[[IF_REG:[0-9]+]], 4> + +// The then-yield (from MFMA) carries a *different* physical register. +// CHECK: waveasm.v_mfma_scale_f32_16x16x128_f8f6f4 {{.*}} -> !waveasm.pareg<[[MFMA_REG:[0-9]+]], 4> +// CHECK-NEXT: waveasm.yield {{.*}} : !waveasm.pareg<[[MFMA_REG]], 4> + +// The loop init arg type must exactly match the if-result type (IF_REG). +// CHECK: waveasm.loop ({{.*}}!waveasm.pareg<[[IF_REG]], 4>) -> ({{.*}}!waveasm.pareg<[[IF_REG]], 4>) + +// Inside the loop body, the block arg fed to MFMA must also be pareg. +// CHECK: waveasm.v_mfma_scale_f32_16x16x128_f8f6f4 {{.*}}!waveasm.pareg<[[IF_REG]], 4>{{.*}} -> !waveasm.pareg<[[IF_REG]], 4> + +waveasm.program @if_result_feeds_loop + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi<> { + + %c0 = waveasm.constant 0 : !waveasm.imm<0> + %c1 = waveasm.constant 1 : !waveasm.imm<1> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %v0 = waveasm.precolored.vreg 0, 4 : !waveasm.pvreg<0, 4> + %v4 = waveasm.precolored.vreg 4, 4 : !waveasm.pvreg<4, 4> + %vs = waveasm.precolored.vreg 8 : !waveasm.pvreg<8> + + %s_zero = waveasm.s_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.sreg + %cmp = waveasm.s_cmp_lt_u32 %s_zero, %c1 : !waveasm.sreg, !waveasm.imm<1> -> !waveasm.sreg + + %if_result = waveasm.if %cmp : !waveasm.sreg -> !waveasm.areg<4, 4> { + %acc_init = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.areg<4, 4> + %mfma = waveasm.v_mfma_scale_f32_16x16x128_f8f6f4 %v0, %v4, %acc_init, %vs, %vs + : !waveasm.pvreg<0, 4>, !waveasm.pvreg<4, 4>, !waveasm.areg<4, 4>, !waveasm.pvreg<8>, !waveasm.pvreg<8> -> !waveasm.areg<4, 4> + waveasm.yield %mfma : !waveasm.areg<4, 4> + } else { + %zero_acc = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> 
-> !waveasm.areg<4, 4> + waveasm.yield %zero_acc : !waveasm.areg<4, 4> + } + + %init_i = waveasm.s_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.sreg + + %i_out, %acc_out = waveasm.loop(%i = %init_i, %acc = %if_result) + : (!waveasm.sreg, !waveasm.areg<4, 4>) -> (!waveasm.sreg, !waveasm.areg<4, 4>) { + + %new_mfma = waveasm.v_mfma_scale_f32_16x16x128_f8f6f4 %v0, %v4, %acc, %vs, %vs + : !waveasm.pvreg<0, 4>, !waveasm.pvreg<4, 4>, !waveasm.areg<4, 4>, !waveasm.pvreg<8>, !waveasm.pvreg<8> -> !waveasm.areg<4, 4> + + %next_i:2 = waveasm.s_add_u32 %i, %c1 : !waveasm.sreg, !waveasm.imm<1> -> !waveasm.sreg, !waveasm.sreg + %loop_cond = waveasm.s_cmp_lt_u32 %next_i#0, %c4 : !waveasm.sreg, !waveasm.imm<4> -> !waveasm.sreg + waveasm.condition %loop_cond : !waveasm.sreg iter_args(%next_i#0, %new_mfma) : !waveasm.sreg, !waveasm.areg<4, 4> + } + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/translate-from-llvm-add.mlir b/waveasm/test/Transforms/translate-from-llvm-add.mlir new file mode 100644 index 0000000000..9bb0f099f0 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-add.mlir @@ -0,0 +1,15 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify llvm.add is translated to waveasm.arith.add pseudo-op. 
+ +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.arith.add +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(42 : i32) : i32 + %1 = llvm.mlir.constant(7 : i32) : i32 + %2 = llvm.add %0, %1 : i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-bare-ptr-gep.mlir b/waveasm/test/Transforms/translate-from-llvm-bare-ptr-gep.mlir new file mode 100644 index 0000000000..aaa33fab83 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-bare-ptr-gep.mlir @@ -0,0 +1,57 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify bare-pointer GEPs before make.buffer.rsrc are handled. +// The byte offset from bare-pointer GEPs must be added to the buffer voffset. + +// CHECK: waveasm.program @test__waveasm +// Base offset (bare-ptr GEP) + element offset combined into voffset. +// CHECK: waveasm.v_add_u32 +// SRD word 1 stride bits cleared, flags patched from make.buffer.rsrc. 
+// CHECK: waveasm.raw "s_and_b32 +// CHECK: waveasm.raw "s_mov_b32 s{{[0-9]+}}, 0x27000" +// CHECK: waveasm.v_add_u32 +// CHECK: waveasm.buffer_load_ushort +// CHECK: waveasm.v_add_u32 +// CHECK: waveasm.raw "s_and_b32 +// CHECK: waveasm.raw "s_mov_b32 s{{[0-9]+}}, 0x27000" +// CHECK: waveasm.v_add_u32 +// CHECK: waveasm.buffer_store_short +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(2147483645 : i64) : i64 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(128 : i32) : i32 + %3 = llvm.mlir.constant(159744 : i32) : i32 + %4 = llvm.mlir.constant(16448 : i16) : i16 + %5 = llvm.mlir.constant(0 : index) : i64 + %6 = llvm.mlir.constant(2 : index) : i64 + %7 = rocdl.workgroup.id.y range : i32 + %8 = llvm.sext %7 : i32 to i64 + %9 = rocdl.workitem.id.x range : i32 + %10 = llvm.sext %9 : i32 to i64 + %11 = llvm.trunc %8 : i64 to i32 + %12 = llvm.mul %11, %2 overflow : i32 + %13 = llvm.zext %12 : i32 to i64 + // Bare-pointer GEP on arg — computes row offset. + %14 = llvm.getelementptr nusw %arg0[%13] : (!llvm.ptr, i64) -> !llvm.ptr, i8 + %15 = llvm.mul %5, %6 overflow : i64 + // Chained bare-pointer GEP. + %16 = llvm.getelementptr nusw %14[%15] : (!llvm.ptr, i64) -> !llvm.ptr, i8 + // make.buffer.rsrc on the offset pointer. + %17 = rocdl.make.buffer.rsrc %16, %4, %0, %3 : !llvm.ptr to <7> + %18 = llvm.trunc %10 : i64 to i32 + %19 = llvm.mul %18, %1 overflow : i32 + %20 = llvm.zext %19 : i32 to i64 + // Buffer GEP — base offset from bare GEPs should be added here. + %21 = llvm.getelementptr nusw %17[%20] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + %22 = llvm.load %21 : !llvm.ptr<7> -> vector<1xf16> + // Same pattern for store side. 
+ %23 = llvm.getelementptr nusw %arg1[%13] : (!llvm.ptr, i64) -> !llvm.ptr, i8 + %24 = llvm.getelementptr nusw %23[%15] : (!llvm.ptr, i64) -> !llvm.ptr, i8 + %25 = rocdl.make.buffer.rsrc %24, %4, %0, %3 : !llvm.ptr to <7> + %26 = llvm.getelementptr nusw %25[%20] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + llvm.store %22, %26 : vector<1xf16>, !llvm.ptr<7> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-bitwise.mlir b/waveasm/test/Transforms/translate-from-llvm-bitwise.mlir new file mode 100644 index 0000000000..bbf7356b96 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-bitwise.mlir @@ -0,0 +1,17 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify llvm.or and llvm.and are translated to waveasm.arith.or/and pseudo-ops. + +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.arith.or +// CHECK: waveasm.arith.and +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(42 : i32) : i32 + %1 = llvm.mlir.constant(7 : i32) : i32 + %2 = llvm.or %0, %1 : i32 + %3 = llvm.and %0, %1 : i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir b/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir new file mode 100644 index 0000000000..aa926f03e7 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir @@ -0,0 +1,17 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that multi-index GEPs produce a diagnostic instead of crashing. 
+ +// CHECK: GEP with multiple indices not yet supported + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(0 : i16) : i16 + %1 = llvm.mlir.constant(2147483645 : i64) : i64 + %2 = llvm.mlir.constant(822243328 : i32) : i32 + %3 = llvm.mlir.constant(0 : i32) : i32 + %4 = llvm.mlir.constant(1 : i32) : i32 + %5 = rocdl.make.buffer.rsrc %arg0, %0, %1, %2 : !llvm.ptr to <7> + %6 = llvm.getelementptr %5[%3, %4] : (!llvm.ptr<7>, i32, i32) -> !llvm.ptr<7>, !llvm.array<4 x i32> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-unsupported-target.mlir b/waveasm/test/Transforms/translate-from-llvm-error-unsupported-target.mlir new file mode 100644 index 0000000000..9226c76296 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-unsupported-target.mlir @@ -0,0 +1,15 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm='target=gfx942' 2>&1 | FileCheck %s +// Verify that non-gfx950 targets are rejected. + +// CHECK: LLVM->WaveASM translation only supports gfx950 + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes { + gpu.kernel, + gpu.known_block_size = array, + rocdl.kernel, + rocdl.reqd_work_group_size = array + } { + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-workgroup-size.mlir b/waveasm/test/Transforms/translate-from-llvm-error-workgroup-size.mlir new file mode 100644 index 0000000000..96d46b77ea --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-workgroup-size.mlir @@ -0,0 +1,15 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that contradicting workgroup size attributes produce a diagnostic. 
+ +// CHECK: contradicting workgroup size attributes + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes { + gpu.kernel, + gpu.known_block_size = array, + rocdl.kernel, + rocdl.reqd_work_group_size = array + } { + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-gep-chain.mlir b/waveasm/test/Transforms/translate-from-llvm-gep-chain.mlir new file mode 100644 index 0000000000..6591d5376b --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-gep-chain.mlir @@ -0,0 +1,25 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify chained GEPs (GEP on GEP) are handled by adding offsets. + +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.v_add_u32 +// CHECK: waveasm.buffer_load_dword +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(0 : i16) : i16 + %1 = llvm.mlir.constant(2147483645 : i64) : i64 + %2 = llvm.mlir.constant(822243328 : i32) : i32 + %3 = rocdl.workitem.id.x range : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = rocdl.make.buffer.rsrc %arg0, %0, %1, %2 : !llvm.ptr to <7> + // First GEP from buffer resource. + %6 = llvm.getelementptr nusw %5[%4] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + // Chained GEP on previous GEP result. 
+ %7 = llvm.mlir.constant(256 : i64) : i64 + %8 = llvm.getelementptr nusw %6[%7] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + %9 = llvm.load %8 : !llvm.ptr<7> -> i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-i64-gep.mlir b/waveasm/test/Transforms/translate-from-llvm-i64-gep.mlir new file mode 100644 index 0000000000..b42b8a6c4a --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-i64-gep.mlir @@ -0,0 +1,31 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm --waveasm-arith-legalization | FileCheck %s +// Verify that i64 GEP indices are truncated to i32 for buffer voffset. + +// CHECK-LABEL: waveasm.program @test_i64_gep_offset__waveasm + +// zext i32 -> i64: pack {mul_result, 0} into vreg<2>. +// CHECK: %[[MUL:.*]] = waveasm.v_mul_lo_u32 +// CHECK: waveasm.pack %[[MUL]], %{{.*}} : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2, 2> + +// Trunc looks through pack, uses mul result directly as buffer voffset. +// CHECK: waveasm.buffer_load_ushort %{{.*}}, %[[MUL]], %{{.*}} : !waveasm.sreg<4, 4>, !waveasm.vreg, !waveasm.imm<0> + +// No arith pseudo-ops should remain. +// CHECK-NOT: waveasm.arith. 
+// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test_i64_gep_offset(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(2147483645 : i64) : i64 + %1 = llvm.mlir.constant(822243328 : i32) : i32 + %2 = llvm.mlir.constant(0 : i16) : i16 + %3 = llvm.mlir.constant(2 : i32) : i32 + %tid = rocdl.workitem.id.x range : i32 + %mul = llvm.mul %tid, %3 : i32 + %ext = llvm.zext %mul : i32 to i64 + %rsrc = rocdl.make.buffer.rsrc %arg0, %2, %0, %1 : !llvm.ptr to <7> + %gep = llvm.getelementptr nusw %rsrc[%ext] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + %val = llvm.load %gep : !llvm.ptr<7> -> vector<1xf16> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-i64-select.mlir b/waveasm/test/Transforms/translate-from-llvm-i64-select.mlir new file mode 100644 index 0000000000..fdd5aa88c2 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-i64-select.mlir @@ -0,0 +1,42 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm --waveasm-arith-legalization | FileCheck %s +// Verify that i64 constants are split into lo/hi halves and that i64 select +// legalizes correctly through the full translation + legalization pipeline. + +// CHECK-LABEL: waveasm.program @test_i64_select__waveasm + +// i64 constant 100 split into v_mov_b32(100) + v_mov_b32(0) + pack. +// CHECK-DAG: %[[C100_LO:.*]] = waveasm.v_mov_b32 %{{.*}} : !waveasm.imm<100> -> !waveasm.vreg +// CHECK-DAG: %[[C100_HI:.*]] = waveasm.v_mov_b32 %{{.*}} : !waveasm.imm<0> -> !waveasm.vreg +// CHECK: waveasm.pack %[[C100_LO]], %[[C100_HI]] +// CHECK-SAME: -> !waveasm.vreg<2, 2> + +// sext i32 -> i64: v_ashrrev_i32 for sign extension + pack. 
+// CHECK: %[[TID:.*]] = waveasm.precolored.vreg 0 : !waveasm.vreg +// CHECK: %[[SIGN:.*]] = waveasm.v_ashrrev_i32 %{{.*}}, %[[TID]] +// CHECK: waveasm.pack %[[TID]], %[[SIGN]] +// CHECK-SAME: -> !waveasm.vreg<2, 2> + +// i32 compare sets VCC. +// CHECK: waveasm.v_cmp_lt_i32 + +// i64 select: splitI64 looks through pack, uses original lo/hi directly. +// CHECK: waveasm.v_cndmask_b32 %[[C100_LO]], %[[TID]], %{{.*}} +// CHECK: waveasm.v_cndmask_b32 %[[C100_HI]], %[[SIGN]], %{{.*}} +// CHECK: waveasm.pack %{{.*}}, %{{.*}} : (!waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<2, 2> + +// No arith pseudo-ops should remain. +// CHECK-NOT: waveasm.arith. +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test_i64_select() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %c100 = llvm.mlir.constant(100 : i64) : i64 + %c27 = llvm.mlir.constant(27 : i32) : i32 + %tid = rocdl.workitem.id.x range : i32 + %ext = llvm.sext %tid : i32 to i64 + %cmp = llvm.icmp "slt" %tid, %c27 : i32 + %sel = llvm.select %cmp, %ext, %c100 : i1, i64 + %trunc = llvm.trunc %sel : i64 to i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-poison.mlir b/waveasm/test/Transforms/translate-from-llvm-poison.mlir new file mode 100644 index 0000000000..aa92a96031 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-poison.mlir @@ -0,0 +1,30 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify llvm.mlir.poison is translated as zero for both i32 and i64. 
+ +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.constant 0 +// CHECK: waveasm.v_mov_b32 +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.poison : i32 + llvm.return + } +} + +// ----- + +// CHECK: waveasm.program @test_i64__waveasm +// CHECK: [[ZERO:%.*]] = waveasm.constant 0 +// CHECK: [[LO:%.*]] = waveasm.v_mov_b32 [[ZERO]] +// CHECK: [[HI:%.*]] = waveasm.v_mov_b32 [[ZERO]] +// CHECK: waveasm.pack [[LO]], [[HI]] +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module_i64 { + llvm.func @test_i64() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.poison : i64 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-scalar-args.mlir b/waveasm/test/Transforms/translate-from-llvm-scalar-args.mlir new file mode 100644 index 0000000000..b1ba9f2f3f --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-scalar-args.mlir @@ -0,0 +1,24 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify scalar (non-pointer) kernel arguments are mapped to preloaded SGPRs, +// not treated as buffer pointers. i64 args produce arith pseudo-ops that the +// legalization pass will lower to concrete SALU/VALU ops. + +// CHECK: waveasm.program @test__waveasm +// Pointer arg gets SRD setup. +// CHECK: waveasm.precolored.sreg [[SRD:[0-9]+]], 4 +// CHECK: waveasm.raw "s_mov_b32 +// CHECK: waveasm.raw "s_mov_b32 +// Scalar arg mapped to preloaded SGPR pair, sext/cmp emitted as pseudo-ops. 
+// CHECK: waveasm.precolored.sreg [[PAIR:[0-9]+]], 2 +// CHECK: waveasm.arith.sext +// CHECK: waveasm.arith.cmp slt +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr, %arg1: i64) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = rocdl.workitem.id.x range : i32 + %1 = llvm.sext %0 : i32 to i64 + %2 = llvm.icmp "slt" %1, %arg1 : i64 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-select.mlir b/waveasm/test/Transforms/translate-from-llvm-select.mlir new file mode 100644 index 0000000000..beeba3ed4e --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-select.mlir @@ -0,0 +1,28 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify llvm.select is translated to waveasm.arith.select pseudo-op +// with correct operand wiring from icmp condition. + +// CHECK: waveasm.program @test__waveasm +// CHECK: [[C10:%.*]] = waveasm.constant 10 +// CHECK: [[A:%.*]] = waveasm.v_mov_b32 [[C10]] +// CHECK: [[C20:%.*]] = waveasm.constant 20 +// CHECK: [[B:%.*]] = waveasm.v_mov_b32 [[C20]] +// CHECK: [[C1:%.*]] = waveasm.constant 1 +// CHECK: [[X:%.*]] = waveasm.v_mov_b32 [[C1]] +// CHECK: [[C2:%.*]] = waveasm.constant 2 +// CHECK: [[Y:%.*]] = waveasm.v_mov_b32 [[C2]] +// CHECK: [[CMP:%.*]] = waveasm.arith.cmp slt, [[A]], [[B]] +// CHECK: waveasm.arith.select [[CMP]], [[X]], [[Y]] +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %a = llvm.mlir.constant(10 : i32) : i32 + %b = llvm.mlir.constant(20 : i32) : i32 + %x = llvm.mlir.constant(1 : i32) : i32 + %y = llvm.mlir.constant(2 : i32) : i32 + %cmp = llvm.icmp "slt" %a, %b : i32 + %sel = llvm.select %cmp, %x, %y : i1, i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-wide-load-store.mlir 
b/waveasm/test/Transforms/translate-from-llvm-wide-load-store.mlir new file mode 100644 index 0000000000..794712698b --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-wide-load-store.mlir @@ -0,0 +1,32 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify wide (dwordx2/x4) buffer loads and stores. + +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.buffer_load_dwordx2 +// CHECK: waveasm.buffer_load_dwordx4 +// CHECK: waveasm.buffer_store_dwordx2 +// CHECK: waveasm.buffer_store_dwordx4 +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(0 : i16) : i16 + %1 = llvm.mlir.constant(2147483645 : i64) : i64 + %2 = llvm.mlir.constant(822243328 : i32) : i32 + %3 = rocdl.workitem.id.x range : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = rocdl.make.buffer.rsrc %arg0, %0, %1, %2 : !llvm.ptr to <7> + %6 = llvm.getelementptr nusw %5[%4] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + // 8-byte load (dwordx2). + %7 = llvm.load %6 : !llvm.ptr<7> -> vector<4xf16> + // 16-byte load (dwordx4). + %8 = llvm.load %6 : !llvm.ptr<7> -> vector<4xf32> + %9 = rocdl.make.buffer.rsrc %arg1, %0, %1, %2 : !llvm.ptr to <7> + %10 = llvm.getelementptr nusw %9[%4] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + // 8-byte store (dwordx2). + llvm.store %7, %10 : vector<4xf16>, !llvm.ptr<7> + // 16-byte store (dwordx4). 
+ llvm.store %8, %10 : vector<4xf32>, !llvm.ptr<7> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-workgroup-id.mlir b/waveasm/test/Transforms/translate-from-llvm-workgroup-id.mlir new file mode 100644 index 0000000000..429a818f14 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-workgroup-id.mlir @@ -0,0 +1,17 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify workgroup ID intrinsics are translated to precolored SGPRs. + +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.precolored.sreg +// CHECK: waveasm.precolored.sreg +// CHECK: waveasm.precolored.sreg +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = rocdl.workgroup.id.x : i32 + %1 = rocdl.workgroup.id.y : i32 + %2 = rocdl.workgroup.id.z : i32 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm.mlir b/waveasm/test/Transforms/translate-from-llvm.mlir new file mode 100644 index 0000000000..bb51b56531 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm.mlir @@ -0,0 +1,40 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify the LLVM→WaveASM translation pass handles a copy kernel. 
+ +// CHECK: gpu.module @gpu_module +// CHECK: llvm.func @test +// CHECK: waveasm.program @test__waveasm +// CHECK-SAME: kernel_name = "test" +// CHECK: waveasm.precolored.vreg +// CHECK: waveasm.arith.cmp slt +// CHECK: waveasm.arith.select +// CHECK: waveasm.arith.mul +// CHECK: waveasm.buffer_load_ushort +// CHECK: waveasm.buffer_store_short +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(1073741823 : index) : i64 + %1 = llvm.mlir.constant(2147483645 : i64) : i64 + %2 = llvm.mlir.constant(27 : i32) : i32 + %3 = llvm.mlir.constant(2 : i32) : i32 + %4 = llvm.mlir.constant(0 : i16) : i16 + %5 = llvm.mlir.constant(822243328 : i32) : i32 + %6 = rocdl.workitem.id.x range : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.trunc %7 : i64 to i32 + %9 = llvm.icmp "slt" %8, %2 : i32 + %10 = rocdl.make.buffer.rsrc %arg0, %4, %1, %5 : !llvm.ptr to <7> + %11 = llvm.select %9, %7, %0 : i1, i64 + %12 = llvm.trunc %11 : i64 to i32 + %13 = llvm.mul %12, %3 overflow : i32 + %14 = llvm.zext %13 : i32 to i64 + %15 = llvm.getelementptr nusw %10[%14] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + %16 = llvm.load %15 : !llvm.ptr<7> -> vector<1xf16> + %17 = rocdl.make.buffer.rsrc %arg1, %4, %1, %5 : !llvm.ptr to <7> + %18 = llvm.getelementptr nusw %17[%14] : (!llvm.ptr<7>, i64) -> !llvm.ptr<7>, i8 + llvm.store %16, %18 : vector<1xf16>, !llvm.ptr<7> + llvm.return + } +} diff --git a/waveasm/test/Translate/affine-negative-const-and-normalization.mlir b/waveasm/test/Translate/affine-negative-const-and-normalization.mlir new file mode 100644 index 0000000000..1039e8319e --- /dev/null +++ b/waveasm/test/Translate/affine-negative-const-and-normalization.mlir @@ -0,0 +1,96 @@ +// RUN: waveasm-translate %s 2>&1 | FileCheck %s +// +// Test: correct handling of negative constants in affine expressions and +// 
normalization of >32-bit constants produced by MLIR's affine canonicalizer. +// +// Bug 1 (negative constants): BitRange::fromConstant(-16) previously returned +// an empty range because the old while(v > 0) loop never executed for negative +// values. This caused the OR-for-add optimisation to fire incorrectly: the +// compiler emitted v_or_b32 instead of v_add_u32 for (tid floordiv 8) - 16, +// producing wrong results when tid >= 128 (bit 4 of the quotient overlaps +// with bit 4 of 0xFFFFFFF0). +// +// Bug 2 (large constants): MLIR's affine canonicalizer can compose nested +// affine.apply ops, multiplying coefficients by a large LCM factor that +// exceeds 32 bits. The normalization pass divides all variable coefficients +// and the divisor by their GCD to bring them back into 32-bit range. + +module { + +// --- Negative constant: must use v_add_u32, not v_or_b32 --- + +func.func @negative_const_add(%binding: !stream.binding) { + + %c0 = arith.constant 0 : index + %flat = "stream.binding.subspan"(%binding, %c0) : (!stream.binding, index) -> memref + %tid = gpu.thread_id x upper_bound 256 + + // CHECK-LABEL: waveasm.program @negative_const_add + + // tid floordiv 8 = shift right by 3 + // CHECK: waveasm.v_lshrrev_b32 + // With upper_bound 256, floordiv 8 gives [0,31] (bits 0..4). + // -16 = 0xFFFFFFF0 occupies bits 4..31 -- bit 4 overlaps. 
+ // CHECK: waveasm.constant -16 + // CHECK: waveasm.v_add_u32 + // CHECK-NOT: waveasm.v_or_b32 + %r = affine.apply affine_map<()[s0] -> (s0 floordiv 8 - 16)>()[%tid] + + return +} + +// --- Large constant normalization (power-of-2 reduced divisor) --- + +func.func @large_const_norm_po2(%binding: !stream.binding) { + + %c0 = arith.constant 0 : index + %flat = "stream.binding.subspan"(%binding, %c0) : (!stream.binding, index) -> memref + %tid = gpu.thread_id x + %tid_y = gpu.thread_id y + + // CHECK-LABEL: waveasm.program @large_const_norm_po2 + + // Input: (s0 * 6000000000 + s1 * 9000000000) floordiv 12000000000 + // GCD of {6e9, 9e9, 12e9} = 3e9 + // Normalized: (s0 * 2 + s1 * 3) floordiv 4 + // s0*2 becomes lshlrev by 1; floordiv 4 becomes lshrrev by 2 + // All constants fit in 32 bits after normalization. + // CHECK: waveasm.constant 1 + // CHECK: waveasm.v_lshlrev_b32 + // CHECK: waveasm.constant 3 + // CHECK: waveasm.v_mul_lo_u32 + // CHECK: waveasm.v_add_u32 + // CHECK: waveasm.constant 2 + // CHECK: waveasm.v_lshrrev_b32 + %r = affine.apply affine_map<()[s0, s1] -> ((s0 * 6000000000 + s1 * 9000000000) floordiv 12000000000)>()[%tid, %tid_y] + + return +} + +// --- Large constant normalization (non-power-of-2 reduced divisor) --- + +func.func @large_const_norm_magic(%binding: !stream.binding) { + + %c0 = arith.constant 0 : index + %flat = "stream.binding.subspan"(%binding, %c0) : (!stream.binding, index) -> memref + %tid = gpu.thread_id x + %tid_y = gpu.thread_id y + + // CHECK-LABEL: waveasm.program @large_const_norm_magic + + // Input: (s0 * 21000000000 + s1 * 14000000000) floordiv 49000000000 + // GCD of {21e9, 14e9, 49e9} = 7e9 + // Normalized: (s0 * 3 + s1 * 2) floordiv 7 + // floordiv 7 uses magic number multiplication (add form) + // All constants fit in 32 bits after normalization. 
+ // CHECK: waveasm.constant 3 + // CHECK: waveasm.v_mul_lo_u32 + // CHECK: waveasm.v_add_u32 + // floordiv 7 via magic number multiplication + // CHECK: waveasm.v_mul_hi_u32 + %r = affine.apply affine_map<()[s0, s1] -> ((s0 * 21000000000 + s1 * 14000000000) floordiv 49000000000)>()[%tid, %tid_y] + + return +} + +} diff --git a/waveasm/tools/waveasm-translate/CMakeLists.txt b/waveasm/tools/waveasm-translate/CMakeLists.txt index 9208b0173f..55591306e5 100644 --- a/waveasm/tools/waveasm-translate/CMakeLists.txt +++ b/waveasm/tools/waveasm-translate/CMakeLists.txt @@ -20,6 +20,8 @@ target_link_libraries(waveasm-translate MLIRFuncDialect MLIRGPUDialect MLIRIR + MLIRLLVMDialect + MLIRROCDLDialect MLIRMemRefDialect MLIRParser MLIRPass diff --git a/waveasm/tools/waveasm-translate/waveasm-translate.cpp b/waveasm/tools/waveasm-translate/waveasm-translate.cpp index 4af6f98773..f9836f3959 100644 --- a/waveasm/tools/waveasm-translate/waveasm-translate.cpp +++ b/waveasm/tools/waveasm-translate/waveasm-translate.cpp @@ -25,6 +25,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -111,6 +113,12 @@ int main(int argc, char **argv) { waveasm::registerWaveASMPasses(); mlir::registerTransformsPasses(); + // Enable standard MLIR pass manager CLI options: + // --mlir-print-ir-after-all, --mlir-print-ir-before-all, + // --mlir-print-ir-after-change, --mlir-print-ir-after-failure, etc. + // IR is printed to stderr; redirect with 2>file to capture. + mlir::registerPassManagerCLOptions(); + // Construct AFTER pass registration — PassNameParser::initialize() snapshots // the registry, so passes must already be registered at this point.
static mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); @@ -128,6 +136,8 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); MLIRContext context(registry); @@ -155,8 +165,17 @@ int main(int argc, char **argv) { bool hasWaveASMPrograms = false; module->walk([&](waveasm::ProgramOp) { hasWaveASMPrograms = true; }); - // If not already WAVEASM IR, translate from MLIR - if (!hasWaveASMPrograms) { + // Check if input contains LLVM dialect kernels (handled by the + // --waveasm-translate-from-llvm pass flag, not auto-translation). + bool hasLLVMKernels = false; + module->walk([&](LLVM::LLVMFuncOp func) { + if (func->hasAttr("gpu.kernel") || func->hasAttr("rocdl.kernel")) + hasLLVMKernels = true; + }); + + // Auto-translate high-level MLIR dialects (not LLVM dialect, not already + // WaveASM). + if (!hasWaveASMPrograms && !hasLLVMKernels) { // Run pre-translation MLIR passes. { PassManager prePm(&context); @@ -193,6 +212,8 @@ int main(int argc, char **argv) { // Build pass pipeline from CLI flags. PassManager pm(&context); + // Apply CLI options for IR printing (--mlir-print-ir-after-all, etc.) + mlir::applyPassManagerCLOptions(pm); if (passPipeline.hasAnyOccurrences()) { auto errorHandler = [](const Twine &msg) { llvm::errs() << msg << "\n"; diff --git a/waveasm/waveasm_e2e.py b/waveasm/waveasm_e2e.py index 1a33e798f8..db7273dfa5 100644 --- a/waveasm/waveasm_e2e.py +++ b/waveasm/waveasm_e2e.py @@ -211,7 +211,9 @@ def compile_mlir_to_asm( "--canonicalize", # Clean up dead instructions from offset opt. "--waveasm-scoped-cse", # Re-deduplicate after offset folding. "--waveasm-loop-address-promotion", + "--waveasm-scc-verifier", # Verify no SCC hazards before regalloc. "--waveasm-linear-scan=max-vgprs=512 max-agprs=512", # Register allocation. + "--waveasm-vgpr-compaction", # Compact VGPR assignments. 
f"--waveasm-insert-waitcnt=ticketed-waitcnt={ticketed}", # Insert waits. f"--waveasm-hazard-mitigation=target={self.target}", # Handle hazards. "--emit-assembly",