diff --git a/lit_tests/kernel/wave/water_host_wrapper.py b/lit_tests/kernel/wave/water_host_wrapper.py index bc047f7a10..5e6e7a552b 100644 --- a/lit_tests/kernel/wave/water_host_wrapper.py +++ b/lit_tests/kernel/wave/water_host_wrapper.py @@ -31,6 +31,7 @@ def get_wave_compile_options( canonicalize: bool = False, dynamic_symbols=[], additional_symbols={}, + wave_runtime: bool = False, location_capture_config=LocationCaptureConfig( level=LocationCaptureLevel.FILE_LINE_COL ), @@ -61,6 +62,7 @@ def get_wave_compile_options( location_capture_config=location_capture_config, drop_debug_info_before_mlir=drop_debug_info_before_mlir, use_water_backend=True, + wave_runtime=wave_runtime, ) @@ -227,3 +229,52 @@ def read_write( # CHECK-SAME: blocks in (%[[C1]], %[[C1]], %[[C1]]) # CHECK-SAME: threads in (%[[C64]], %[[C1]], %[[C1]]) # CHECK: return + + +@run_test +def test_dynamic_strides_output_placeholder_first(): + constraints = get_constraints() + + @tkw.wave(constraints) + def output_then_input( + out: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + inp: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + res = tkw.read(inp) + tkw.write(res, out) + + output_then_input = wave_compile( + get_wave_compile_options( + canonicalize=True, + wave_runtime=True, + drop_debug_info_before_mlir=True, + ), + output_then_input, + ) + print(output_then_input.asm) + + # Even though the Python placeholder order is (out, inp), kernel ABI order + # is linearized to input first, output second. 
+ # CHECK-LABEL: test_dynamic_strides_output_placeholder_first + # CHECK: gpu.func @output_then_input + # CHECK-SAME: (%[[IN:.*]]: memref {llvm.inreg}, %[[OUT:.*]]: memref {llvm.inreg}, %[[IN_STRIDE:.*]]: index {llvm.inreg}, %[[OUT_STRIDE:.*]]: index {llvm.inreg}) + # CHECK: %[[IN_VIEW:.*]] = memref.reinterpret_cast %[[IN]] to offset: [0], sizes: [16, 16], strides: [%[[IN_STRIDE]], 1] + # CHECK: %[[OUT_VIEW:.*]] = memref.reinterpret_cast %[[OUT]] to offset: [0], sizes: [16, 16], strides: [%[[OUT_STRIDE]], 1] + + # Host wrapper must mirror that same ABI order for both buffers and strides. + # CHECK-LABEL: func.func @isolated_benchmark + # CHECK-SAME: (%[[STREAM:.*]]: !llvm.ptr, %[[OUT_ARG:.*]]: !llvm.ptr, %[[IN_ARG:.*]]: !llvm.ptr) + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + # CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + # CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 + # CHECK: %[[IN_BUF:.*]] = call @wave_get_buffer(%[[IN_ARG]]) : (!llvm.ptr) -> memref + # CHECK: %[[IN_VIEW0:.*]] = memref.view %[[IN_BUF]][%[[C0]]][] : memref to memref + # CHECK: %[[OUT_BUF:.*]] = call @wave_get_buffer(%[[OUT_ARG]]) : (!llvm.ptr) -> memref + # CHECK: %[[OUT_VIEW0:.*]] = memref.view %[[OUT_BUF]][%[[C0]]][] : memref to memref + # CHECK: %[[IN_STRIDE_I64:.*]] = call @wave_get_stride(%[[IN_ARG]], %[[C0_I32]]) : (!llvm.ptr, i32) -> i64 + # CHECK: %[[IN_STRIDE_IDX:.*]] = arith.index_cast %[[IN_STRIDE_I64]] : i64 to index + # CHECK: %[[OUT_STRIDE_I64:.*]] = call @wave_get_stride(%[[OUT_ARG]], %[[C0_I32]]) : (!llvm.ptr, i32) -> i64 + # CHECK: %[[OUT_STRIDE_IDX:.*]] = arith.index_cast %[[OUT_STRIDE_I64]] : i64 to index + # CHECK: gpu.launch_func @gpu_module::@output_then_input blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C64]], %[[C1]], %[[C1]]) args(%[[IN_VIEW0]] : memref, %[[OUT_VIEW0]] : memref, %[[IN_STRIDE_IDX]] : index, %[[OUT_STRIDE_IDX]] : index) + # CHECK: return diff --git 
a/tests/kernel/e2e/test_gemm_waveasm.py b/tests/kernel/e2e/test_gemm_waveasm.py new file mode 100644 index 0000000000..8b9fa86634 --- /dev/null +++ b/tests/kernel/e2e/test_gemm_waveasm.py @@ -0,0 +1,64 @@ +# Copyright 2026 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""GEMM through the water+waveasm pipeline (LLVM dialect -> WaveASM -> binary).""" + +import pytest +import torch +from torch.testing import assert_close + +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import MMAType +from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config +from wave_lang.kernel.wave.utils.torch_utils import device_randn, device_zeros + +from ..common.utils import require_cdna4, require_e2e + + +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape,block_shape,waves_per_block", + [ + ((64, 64, 64), (32, 32, 16), (1, 1)), + ], + ids=["64x64x64"], +) +def test_gemm_water_waveasm( + shape: tuple[int, int, int], + block_shape: tuple[int, int, int], + waves_per_block: tuple[int, int], +) -> None: + """Test GEMM through the water+waveasm pipeline.""" + m, n, k = shape + + gemm, hyperparams, _ = get_gemm_kernel( + shape=shape, + dynamic_dims=False, + mfma_variant=MMAType.F32_16x16x16_F16, + block_shape=block_shape, + waves_per_block=waves_per_block, + ) + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + use_water_backend=True, + use_buffer_ops=True, + backend="asm", + wave_runtime=True, + ) + options = set_default_run_config(options) + compiled = wave_compile(options, gemm) + + a = device_randn((m, k), dtype=torch.float16) + b = device_randn((n, k), dtype=torch.float16) + c = device_zeros((m, n), dtype=torch.float32) + compiled(a, b, c) + + 
expected = torch.matmul(a, b.T).float() + assert_close(c, expected, rtol=1e-3, atol=1e-3) diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 36cb2ce912..b17bc3adab 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -71,6 +71,7 @@ BindingDesc, BoundKernelSignature, create_argument_locations, + get_dynamic_stride_arg_count, ) from ...lang.wave_types import IndexSymbol from ...wave.compile_options import WaveCompileOptions @@ -137,40 +138,70 @@ def emit_program_invariants(self): ), ] - def emit_func(self) -> Operation: - bindings = self.root_sig.sig.linear_bindings + def _abi_type(self, binding: BindingDesc) -> IrType: + if binding.binding_type == BindingType.KERNEL_BUFFER: + # Buffer passed to kernel as 0D memrefs to simplify ABI. + element_type = IrType.parse(binding.kernel_buffer_type.dtype.ir_type_asm()) + return MemRefType.get([], element_type=element_type) - def abi_type(binding: BindingDesc): - if binding.binding_type == BindingType.KERNEL_BUFFER: - # Buffer passed to kernel as 0D memrefs to simplify ABI. - element_type = IrType.parse( - binding.kernel_buffer_type.dtype.ir_type_asm() - ) - return MemRefType.get([], element_type=element_type) + # Scalars are passed as is. + return binding.as_mlir_type() - # Scalars are passed as is. 
- return binding.as_mlir_type() + def _create_stride_argument_locations( + self, bindings: list[BindingDesc], base_locations: list[Location] + ) -> list[Location]: + if not self.options.dynamic_strides: + return [] - arg_types = [abi_type(b) for b in bindings] + stride_locations = [] + for binding, base_location in zip(bindings, base_locations): + if binding.binding_type != BindingType.KERNEL_BUFFER: + continue + leading_count = max(0, len(binding.kernel_buffer_type.symbolic_shape) - 1) + argument_name = binding.name or "argument" + for dim_idx in range(leading_count): + stride_locations.append( + Location.name( + f"{argument_name}.stride.{dim_idx}", + base_location, + ) + ) + return stride_locations + + def _ordered_kernel_buffer_bindings( + self, bindings: list[BindingDesc] + ) -> list[BindingDesc]: + return [ + binding + for binding in bindings + if binding.binding_type == BindingType.KERNEL_BUFFER + ] - # Dynamic strides only with Wave runtime and LLVM backend (not ASM). - # Stride args only for leading dimensions (innermost always has unit stride). 
- stride_arg_count = 0 - if self.options.dynamic_strides: - stride_arg_count = sum( - max(0, len(b.kernel_buffer_type.symbolic_shape) - 1) - for b in self.root_sig.sig.kernel_buffer_bindings - ) - if stride_arg_count > 0: - arg_types += [IndexType.get()] * stride_arg_count + def _kernel_argument_abi( + self, bindings: list[BindingDesc] + ) -> tuple[list[IrType], list[Location], int]: + arg_types = [self._abi_type(binding) for binding in bindings] + arg_locations = create_argument_locations(bindings) + kernel_buffer_bindings = self._ordered_kernel_buffer_bindings(bindings) + stride_arg_count = get_dynamic_stride_arg_count( + self.options.dynamic_strides, + kernel_buffer_bindings, + ) + stride_locations = self._create_stride_argument_locations( + bindings, arg_locations + ) + assert len(stride_locations) == stride_arg_count + if stride_arg_count > 0: + arg_types += [IndexType.get()] * stride_arg_count + arg_locations += stride_locations + return arg_types, arg_locations, stride_arg_count + def emit_func(self) -> Operation: + bindings = self.root_sig.sig.linear_bindings + arg_types, arg_locations, stride_arg_count = self._kernel_argument_abi(bindings) ftype = FunctionType.get(arg_types, []) func_op = func_d.FuncOp(self.kernel_name, ftype, visibility="private") - - locs = create_argument_locations(bindings) - if stride_arg_count > 0: - locs += [Location.unknown()] * stride_arg_count - entry_block = func_op.add_entry_block(locs) + entry_block = func_op.add_entry_block(arg_locations) # Map dynamic symbols to buffer argument indices and dimensions. for bind, arg in zip(bindings, entry_block.arguments): @@ -285,8 +316,9 @@ def _declare_runtime_func( return func_op, symbol def emit_host_func(self, kernel_func: Operation) -> Operation: - # TODO: kernel bindings order may not be the same as the kernel function - # arguments order, so map kernel order to host function arguments order. 
+ # Host placeholders follow the original Python signature, while kernel ABI + # arguments use linear_bindings order. Track both so launch args can be + # materialized in kernel ABI order from host-visible inputs. binding_map = {} symbol_map = {} @@ -308,22 +340,10 @@ def emit_host_func(self, kernel_func: Operation) -> Operation: bindings = self.root_sig.sig.linear_bindings ptr = llvm_d.PointerType.get() - - def abi_type(binding: BindingDesc): - if binding.binding_type == BindingType.KERNEL_BUFFER: - # Buffer passed to kernel as 0D memrefs to simplify ABI. - element_type = IrType.parse( - binding.kernel_buffer_type.dtype.ir_type_asm() - ) - return MemRefType.get([], element_type=element_type) - - # Scalars are passed as is. - return binding.as_mlir_type() - - arg_types = [abi_type(b) for b in bindings] - + arg_types, kernel_arg_locations, stride_arg_count = self._kernel_argument_abi( + bindings + ) ftype = FunctionType.get(arg_types, []) - locs = [a.location for a in kernel_func.body.blocks[0].arguments] gpu_module = gpu_d.module("gpu_module") gpu_module.parent.operation.attributes["gpu.container_module"] = UnitAttr.get() @@ -345,7 +365,7 @@ def abi_type(binding: BindingDesc): new_kernel_entry_block = kernel_func_wrapper.body.blocks.append( *arg_types, - arg_locs=locs, + arg_locs=kernel_arg_locations, ) # Inline the kernel function into the gpu module function body and erase the original function @@ -396,6 +416,11 @@ def abi_type(binding: BindingDesc): "wave_get_float64", [ptr], [f64], emit_c_interface=True ) + # Get tensor stride function from PyObject*. 
+ get_stride_func, get_stride_func_symbol = self._declare_runtime_func( + "wave_get_stride", [ptr, i32], [i64], emit_c_interface=True + ) + # Declare host function # First argument is stream pointer # Rest are kernel arguments as PyObject* @@ -490,6 +515,24 @@ def abi_type(binding: BindingDesc): else: raise CodegenError(f"Unsupported binding type: {binding}") + # Append stride arguments matching the trailing index args + # added by emit_func. One stride per leading dimension (rank-1) + # for each kernel buffer. + if self.options.dynamic_strides: + for binding in self._ordered_kernel_buffer_bindings(bindings): + rank = len(binding.kernel_buffer_type.symbolic_shape) + leading_count = max(0, rank - 1) + arg = func_args[binding_map[id(binding)]] + for dim_idx in range(leading_count): + dim = arith_d.constant(i32, dim_idx) + stride = func_d.call( + get_stride_func.type.results, + get_stride_func_symbol, + [arg, dim], + ) + stride = arith_d.index_cast(IndexType.get(), stride) + launch_args.append(stride) + gpu_d.launch_func( kernel=[gpu_module.sym_name.value, self.kernel_name], grid_size=grid, diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index 7a82700f96..afa4d0c7c7 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -121,3 +121,38 @@ extern "C" int64_t _mlir_ciface_wave_get_dim(PyObject *obj_ptr, return dim_size; } + +extern "C" int64_t _mlir_ciface_wave_get_stride(PyObject *obj_ptr, + int32_t dim_idx) { + GILState gil_state; + + PyObjectPtr stride_method(PyObject_GetAttrString(obj_ptr, "stride")); + if (!stride_method) { + PyErr_Clear(); + throw std::runtime_error( + "wave_get_stride: object does not have 'stride' attribute"); + } + + PyObjectPtr dim_arg(PyLong_FromLong(dim_idx)); + if (!dim_arg) { + PyErr_Clear(); + throw std::runtime_error( + "wave_get_stride: failed to create dimension argument"); + } + + 
PyObjectPtr stride_result( + PyObject_CallOneArg(stride_method.get(), dim_arg.get())); + if (!stride_result) { + PyErr_Clear(); + throw std::runtime_error("wave_get_stride: failed to call stride()"); + } + + int64_t stride_val = PyLong_AsLongLong(stride_result.get()); + if (PyErr_Occurred()) { + PyErr_Clear(); + throw std::runtime_error( + "wave_get_stride: stride() returned invalid value"); + } + + return stride_val; +} diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index 25938c6d47..54c89dfb7e 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -60,4 +60,18 @@ double _mlir_ciface_wave_get_float64(PyObject *obj); /// std::runtime_error if the object doesn't have a size() method or /// if the dimension index is invalid int64_t _mlir_ciface_wave_get_dim(PyObject *obj, int32_t dim_idx); + +/// Extract the stride of a specific dimension from a PyObject (PyTorch tensor). +/// +/// Args: +/// obj: PyObject* pointing to a PyTorch tensor +/// dim_idx: Dimension index to query (0-based) +/// +/// Returns: +/// Stride of the specified dimension in elements as int64_t. +/// +/// Throws: +/// std::runtime_error if the object doesn't have a stride() method or +/// if the dimension index is invalid. 
+int64_t _mlir_ciface_wave_get_stride(PyObject *obj, int32_t dim_idx); } diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index ff2c5dd8ba..93ca05bd3e 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -104,6 +104,7 @@ def _load_runtime_helpers(): _add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_int64") _add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_float64") _add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_dim") + _add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_stride") return symbol_map diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 2af389e294..b1e242c764 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -14,6 +14,13 @@ import math from typing import Any, Sequence +from iree.compiler.dialects import ( + _structured_transform_ops_gen as structured_transform_ops, +) +from iree.compiler.dialects.transform import ( + any_op_t, +) + from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.detect_water import get_water_mlir_pkg_path, get_water_opt from wave_lang.support.detect_waveasm import get_waveasm_translate @@ -21,17 +28,21 @@ Attribute, BlockArgument, FunctionType, + IrType, InsertionPoint, IntegerType, + Location, MemRefType, Module, Operation, TypeAttr, + UnitAttr, WalkResult, gpu_d, llvm_d, memref_d, stream_d, + transform_d, ) @@ -216,6 +227,41 @@ def make_pass_arguments( ) +_ALLOCA_TO_GLOBAL_ENTRYPOINT = "__transform_alloca_to_global" +_TRANSFORM_MEMREF_ALLOCA_TYPE = '!transform.op<"memref.alloca">' + + +def _module_asm_with_alloca_to_global_transform(module: Module) -> str: + """Return a cloned module with a named alloca-to-global transform attached.""" + + with module.context, Location.unknown(): + transformed_module = Module.parse( + module.operation.get_asm(), 
context=module.context + ) + transformed_module.operation.attributes["transform.with_named_sequence"] = ( + UnitAttr.get() + ) + with InsertionPoint(transformed_module.body): + named_sequence = transform_d.NamedSequenceOp( + _ALLOCA_TO_GLOBAL_ENTRYPOINT, [any_op_t()], [] + ) + with InsertionPoint(named_sequence.body): + target = named_sequence.body.arguments[0] + alloca_handle_type = IrType.parse(_TRANSFORM_MEMREF_ALLOCA_TYPE) + alloca = structured_transform_ops.structured_match( + alloca_handle_type, + target, + ops=["memref.alloca"], + ) + Operation.create( + "transform.memref.alloca_to_global", + results=[any_op_t(), any_op_t()], + operands=[alloca], + ) + transform_d.YieldOp([]) + return transformed_module.operation.get_asm() + + def water_leak_in_bounds_check(module: Module, override_ir: str = ""): binary = get_water_opt() generic_mlir = _deiree(module) if override_ir == "" else override_ir @@ -381,7 +427,7 @@ def water_waveasm_lowering_pipeline( Step 3 (water-opt): host runtime wrapping (gpu.binary -> runtime calls). """ water_opt = get_water_opt() - mlir_asm = module.operation.get_asm() + mlir_asm = _module_asm_with_alloca_to_global_transform(module) target_chip = options.target lld_path = get_water_mlir_pkg_path() / "llvm" / "bin" / "ld.lld" @@ -401,13 +447,16 @@ def add_opt(pipeline): } # Step 1: water-opt lowers host + device to LLVM dialect. + # Note: convert-scf-to-cf is NOT used here -- waveasm-translate handles + # scf.for directly via waveasm.loop structured ops. 
lowering_pipeline = [ "water-memref-decomposition", *add_opt(canonicalize_cse), "lower-affine", *add_opt(int_range_optimizations), *add_opt("loop-invariant-code-motion"), - "convert-scf-to-cf", + ("water-alloc-to-alloca", {}, "gpu.module"), + ("transform-interpreter", {"entry-point": _ALLOCA_TO_GLOBAL_ENTRYPOINT}), ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), ( "convert-gpu-to-rocdl", @@ -417,6 +466,8 @@ def add_opt(pipeline): ("gpu-to-llvm", {"use-bare-pointers-for-kernels": "1"}), "convert-vector-to-llvm", "reconcile-unrealized-casts", + "water-drop-transform-ops", + "symbol-dce", *add_opt(canonicalize_cse), ] @@ -453,6 +504,7 @@ def run_subprocess(args, input_text, tool_name): waveasm_translate = get_waveasm_translate() waveasm_args = [ waveasm_translate, + "--waveasm-llvm-sdiv-srem-legalization", f"--waveasm-translate-from-llvm=target={target_chip}", "--waveasm-arith-legalization", "--waveasm-scoped-cse", @@ -494,7 +546,7 @@ def run_subprocess(args, input_text, tool_name): def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Module: binary = get_water_opt() - mlir_asm = module.operation.get_asm() + mlir_asm = _module_asm_with_alloca_to_global_transform(module) target_chip = options.target def add_opt(pipeline): @@ -503,38 +555,6 @@ def add_opt(pipeline): return [] - def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any]]: - nonlocal mlir_asm - # Erase the last occurrence of '}' from mlir_asm which closes the module operation - last_close = mlir_asm.rfind("}") - if last_close != -1: - mlir_asm = mlir_asm[:last_close] - mlir_asm += transform - mlir_asm += "}\n" - return ("transform-interpreter", {"entry-point": entry_point}) - - # TODO: this transform refuses to work. 
- alloc_to_alloca = """ - transform.named_sequence @__transform_alloc_to_alloca(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["gpu.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %0 { - transform.apply_patterns.memref.alloc_to_alloca - } : !transform.any_op - transform.yield - } -""" - - alloca_to_global = """ - transform.named_sequence @__transform_alloca_to_global(%arg0: !transform.any_op {transform.readonly}) { - %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0 - : (!transform.any_op) -> !transform.op<"memref.alloca"> - %get_global, %global = transform.memref.alloca_to_global %alloca - : (!transform.op<"memref.alloca">) - -> (!transform.any_op, !transform.any_op) - transform.yield - } -""" - canonicalize_cse = "composite-fixed-point-pass", { "name": "canonicalize_cse", "pipeline": "any(canonicalize,cse)", @@ -556,8 +576,7 @@ def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any] *add_opt(int_range_optimizations), *add_opt("loop-invariant-code-motion"), ("water-alloc-to-alloca", {}, "gpu.module"), - # add_transform(alloc_to_alloca, "__transform_alloc_to_alloca"), - add_transform(alloca_to_global, "__transform_alloca_to_global"), + ("transform-interpreter", {"entry-point": _ALLOCA_TO_GLOBAL_ENTRYPOINT}), "convert-scf-to-cf", ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), ("convert-gpu-to-rocdl", {"use-bare-ptr-memref-call-conv": "1"}, "gpu.module"), diff --git a/waveasm/include/waveasm/Transforms/Passes.td b/waveasm/include/waveasm/Transforms/Passes.td index 8178213319..d4dd32df97 100644 --- a/waveasm/include/waveasm/Transforms/Passes.td +++ b/waveasm/include/waveasm/Transforms/Passes.td @@ -373,6 +373,20 @@ def WAVEASMArithLegalization : Pass<"waveasm-arith-legalization"> { let dependentDialects = ["::waveasm::WaveASMDialect"]; } +def WAVEASMLLVMSDivSRemLegalization + : Pass<"waveasm-llvm-sdiv-srem-legalization"> { + let 
summary = "Rewrite LLVM signed div/rem by positive power-of-two constants"; + let description = [{ + Rewrites `llvm.sdiv` and `llvm.srem` on i32 values with positive + power-of-two constant divisors into equivalent LLVM dialect compare/select/ + add/and/ashr sequences. + + This keeps the strength reduction in LLVM dialect, before LLVM->WaveASM + translation changes the abstraction level, and makes the rewrite + independently testable. + }]; +} + def WAVEASMTranslateFromLLVM : Pass<"waveasm-translate-from-llvm"> { let summary = "Translate LLVM dialect kernels to WaveASM IR"; let description = [{ @@ -384,8 +398,6 @@ def WAVEASMTranslateFromLLVM : Pass<"waveasm-translate-from-llvm"> { Option<"targetArch", "target", "std::string", "\"gfx950\"", "Target GPU architecture"> ]; - - let dependentDialects = ["::waveasm::WaveASMDialect"]; } #endif // WaveASM_TRANSFORMS_PASSES diff --git a/waveasm/lib/Transforms/ArithLegalization.cpp b/waveasm/lib/Transforms/ArithLegalization.cpp index 2ca093f076..75ef74be90 100644 --- a/waveasm/lib/Transforms/ArithLegalization.cpp +++ b/waveasm/lib/Transforms/ArithLegalization.cpp @@ -549,6 +549,8 @@ static LogicalResult legalizeBitwiseOp(ArithOp op, OpBuilder &builder) { auto vregTy = VRegType::get(builder.getContext(), 2); result = VALUOp64::create(builder, loc, vregTy, lhs, rhs); } else { + if (isa(lhs.getType()) && isSGPRType(rhs.getType())) + std::swap(lhs, rhs); auto sregTy = SRegType::get(builder.getContext(), 2, 2); auto sccTy = SCCType::get(builder.getContext()); result = SALUOp64::create(builder, loc, sregTy, sccTy, lhs, rhs).getDst(); @@ -559,6 +561,8 @@ static LogicalResult legalizeBitwiseOp(ArithOp op, OpBuilder &builder) { auto vregTy = VRegType::get(builder.getContext()); result = VALUOp32::create(builder, loc, vregTy, lhs, rhs); } else { + if (isa(lhs.getType()) && isSGPRType(rhs.getType())) + std::swap(lhs, rhs); auto sregTy = SRegType::get(builder.getContext(), 1, 1); auto sccTy = SCCType::get(builder.getContext()); result = 
SALUOp32::create(builder, loc, sregTy, sccTy, lhs, rhs).getDst(); diff --git a/waveasm/lib/Transforms/CMakeLists.txt b/waveasm/lib/Transforms/CMakeLists.txt index b226c536a8..314c57e1c7 100644 --- a/waveasm/lib/Transforms/CMakeLists.txt +++ b/waveasm/lib/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRWaveASMTransforms LinearScanRegAlloc.cpp LiteralMaterialization.cpp Liveness.cpp + LLVMSDivSRemLegalization.cpp LoopAddressPromotion.cpp M0RedundancyElimination.cpp MemoryOffsetOptimization.cpp diff --git a/waveasm/lib/Transforms/LLVMSDivSRemLegalization.cpp b/waveasm/lib/Transforms/LLVMSDivSRemLegalization.cpp new file mode 100644 index 0000000000..a669e277de --- /dev/null +++ b/waveasm/lib/Transforms/LLVMSDivSRemLegalization.cpp @@ -0,0 +1,139 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// LLVM sdiv/srem legalization +// +// Rewrites signed division and remainder by positive power-of-two constants +// into equivalent LLVM dialect compare/select/add/and/ashr sequences before +// LLVM->WaveASM translation changes the abstraction level. 
+//===----------------------------------------------------------------------===// + +#include "waveasm/Transforms/Passes.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/MathExtras.h" + +namespace waveasm { +#define GEN_PASS_DEF_WAVEASMLLVMSDIVSREMLEGALIZATION +#include "waveasm/Transforms/Passes.h.inc" +} // namespace waveasm + +using namespace mlir; + +namespace { + +static std::optional getConstantI32Value(Value value) { + LLVM::ConstantOp constOp = value.getDefiningOp(); + if (!constOp) + return std::nullopt; + IntegerAttr intAttr = dyn_cast(constOp.getValue()); + if (!intAttr) + return std::nullopt; + return intAttr.getInt(); +} + +static std::optional> +matchPositivePowerOfTwoI32Divisor(Value rhs) { + std::optional constVal = getConstantI32Value(rhs); + if (!constVal || *constVal <= 0 || !llvm::isPowerOf2_64(*constVal)) + return std::nullopt; + int64_t divisor = *constVal; + int64_t shiftAmt = llvm::Log2_64(static_cast(divisor)); + return std::pair{divisor, shiftAmt}; +} + +static Value createI32Constant(PatternRewriter &rewriter, Location loc, + int64_t value) { + Type i32 = rewriter.getI32Type(); + return LLVM::ConstantOp::create(rewriter, loc, i32, + rewriter.getIntegerAttr(i32, value)); +} + +struct LegalizePowerOfTwoSDivPattern : OpRewritePattern { + using Base::Base; + + LogicalResult matchAndRewrite(LLVM::SDivOp op, + PatternRewriter &rewriter) const override { + if (!op.getType().isSignlessInteger(32)) + return failure(); + std::optional> divisorAndShift = + matchPositivePowerOfTwoI32Divisor(op.getRhs()); + if (!divisorAndShift) + return failure(); + + int64_t divisor = divisorAndShift->first; + int64_t shiftAmt = divisorAndShift->second; + Location loc = op.getLoc(); + + Value zero = createI32Constant(rewriter, loc, 0); + Value biasImm = createI32Constant(rewriter, loc, 
divisor - 1); + Value shiftConst = createI32Constant(rewriter, loc, shiftAmt); + Value isNegative = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::slt, op.getLhs(), zero); + Value bias = LLVM::SelectOp::create(rewriter, loc, isNegative, biasImm, + zero, LLVM::FastmathFlags::none); + Value biased = LLVM::AddOp::create(rewriter, loc, op.getLhs(), bias, + LLVM::IntegerOverflowFlags::none); + Value result = LLVM::AShrOp::create(rewriter, loc, biased, shiftConst, + /*isExact=*/false); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct LegalizePowerOfTwoSRemPattern : OpRewritePattern { + using Base::Base; + + LogicalResult matchAndRewrite(LLVM::SRemOp op, + PatternRewriter &rewriter) const override { + if (!op.getType().isSignlessInteger(32)) + return failure(); + std::optional> divisorAndShift = + matchPositivePowerOfTwoI32Divisor(op.getRhs()); + if (!divisorAndShift) + return failure(); + + int64_t divisor = divisorAndShift->first; + Location loc = op.getLoc(); + + Value zero = createI32Constant(rewriter, loc, 0); + Value maskConst = createI32Constant(rewriter, loc, divisor - 1); + Value negDivisor = createI32Constant(rewriter, loc, -divisor); + Value rawRem = LLVM::AndOp::create(rewriter, loc, op.getLhs(), maskConst); + Value isNegative = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::slt, op.getLhs(), zero); + Value isNonZero = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::ne, rawRem, zero); + Value needsAdjust = + LLVM::AndOp::create(rewriter, loc, isNegative, isNonZero); + Value adjust = + LLVM::SelectOp::create(rewriter, loc, needsAdjust, negDivisor, zero, + LLVM::FastmathFlags::none); + Value result = LLVM::AddOp::create(rewriter, loc, rawRem, adjust, + LLVM::IntegerOverflowFlags::none); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct LLVMSDivSRemLegalizationPass + : public waveasm::impl::WAVEASMLLVMSDivSRemLegalizationBase< + LLVMSDivSRemLegalizationPass> { + void 
runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add( + &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp b/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp index 409cba963b..3655337cc9 100644 --- a/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp +++ b/waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp @@ -21,11 +21,16 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include #define DEBUG_TYPE "waveasm-translate-llvm" @@ -63,7 +68,14 @@ struct BufferPtrInfo { /// State for LLVM->WaveASM translation, layered on top of TranslationContext. class LLVMTranslationState { public: - explicit LLVMTranslationState(TranslationContext &ctx) : ctx(ctx) {} + explicit LLVMTranslationState(TranslationContext &ctx, + Operation *symbolTableOp) + : ctx(ctx) { + assert(symbolTableOp && "expected symbol table operation"); + symbolTableOp->walk([&](LLVM::GlobalOp global) { + globalsByName[global.getSymName()] = global; + }); + } TranslationContext &ctx; @@ -86,10 +98,27 @@ class LLVMTranslationState { void setBaseOffset(Value ptr, Value offset) { baseOffsets[ptr] = offset; } Value lookupBaseOffset(Value ptr) const { return baseOffsets.lookup(ptr); } + /// Track the assigned LDS byte offset for each referenced llvm.mlir.global. 
+ void setLDSGlobalOffset(LLVM::GlobalOp global, int64_t offset) { + ldsGlobalOffsets[global.getOperation()] = offset; + } + std::optional lookupLDSGlobalOffset(LLVM::GlobalOp global) const { + auto it = ldsGlobalOffsets.find(global.getOperation()); + if (it != ldsGlobalOffsets.end()) + return it->second; + return std::nullopt; + } + + LLVM::GlobalOp lookupGlobal(StringRef name) const { + return globalsByName.lookup(name); + } + private: DenseMap rsrcToSRD; DenseMap gepMap; DenseMap baseOffsets; + DenseMap ldsGlobalOffsets; + llvm::StringMap globalsByName; }; //===----------------------------------------------------------------------===// @@ -164,13 +193,17 @@ static ProgramOp createProgramFromLLVMFunc(LLVM::LLVMFuncOp func, } /// Look up the WaveASM value that an LLVM value was translated to. -/// Returns failure if the value was never mapped -- silently returning -/// the original (soon-to-be-erased) LLVM value is a use-after-free bug. +/// Returns failure if the value was never mapped. Never fall back to the +/// original LLVM SSA value: translation erases the LLVM ops after building +/// WaveASM, so reusing the source SSA value here would leave a dangling +/// reference. static FailureOr resolve(Value v, TranslationContext &ctx) { if (auto mapped = ctx.getMapper().getMapped(v)) return *mapped; - // Block arguments (func params) are mapped during prologue setup. - // If we get here, an LLVM op was skipped or handled incorrectly. + // Block arguments (func params) are mapped during prologue setup. Reaching + // this point means either malformed IR referenced an unmapped SSA value or + // translation skipped a producer; surface that as a hard failure in the + // caller instead of guessing through it. return failure(); } @@ -187,12 +220,123 @@ static Value truncToI32(Value v, Type llvmType, OpBuilder &builder, } /// Infer the pseudo-op result type from operand types. -/// If any operand is VGPR -> VReg; otherwise SReg. 
+/// Register file: VGPR if any operand is VGPR, else SGPR. +/// Width: max operand width (in 32-bit dwords). static Type inferResultType(ValueRange operands, TranslationContext &ctx) { + int64_t width = 1; + for (Value v : operands) + width = std::max(width, getRegSize(v.getType())); for (Value v : operands) if (isVGPRType(v.getType())) - return ctx.createVRegType(); - return ctx.createSRegType(); + return ctx.createVRegType(width, width); + return ctx.createSRegType(width, width); +} + +/// Return the LLVM pointer address space, or 0 for non-pointer types. +static unsigned getLLVMAddrSpace(Value v) { + if (LLVM::LLVMPointerType ptrTy = + dyn_cast(v.getType())) + return ptrTy.getAddressSpace(); + return 0; +} + +/// WaveASM register sizes are tracked in 32-bit dwords. +static int64_t getWaveASMDwordCount(int64_t bitWidth) { + return llvm::divideCeil(bitWidth, int64_t{32}); +} + +/// Return the byte size for an LLVM type whose layout is obvious without a +/// DataLayout query. +/// +/// We intentionally support only scalars, fixed vectors, and arrays thereof. +/// Nested aggregates such as structs require real DataLayout-based layout +/// reasoning. +static FailureOr getLLVMTypeBytes(Type ty) { + if (ty.isIntOrFloat()) + return ty.getIntOrFloatBitWidth() / 8; + if (VectorType vecTy = dyn_cast(ty)) { + FailureOr elemBytes = getLLVMTypeBytes(vecTy.getElementType()); + if (failed(elemBytes)) + return failure(); + return *elemBytes * vecTy.getNumElements(); + } + if (LLVM::LLVMArrayType arrTy = dyn_cast(ty)) { + FailureOr elemBytes = getLLVMTypeBytes(arrTy.getElementType()); + if (failed(elemBytes)) + return failure(); + return *elemBytes * arrTy.getNumElements(); + } + return failure(); +} + +/// Return true iff a GEP index is statically known to be zero. 
+static bool isZeroGEPIndex(llvm::PointerUnion idx) { + if (Value value = dyn_cast(idx)) + return isConstantIntValue(value, 0); + + return isConstantIntValue(cast(idx), 0); +} + +/// Structural GEPs like [0, 0] are a no-op and can be forwarded even though we +/// do not model nested aggregate layouts yet. +static bool isAllZeroIndexGEP(LLVM::GEPOp op) { + return llvm::all_of(op.getIndices(), isZeroGEPIndex); +} + +/// Compute the non-zero byte offset for a supported single-index GEP. +/// Emits a diagnostic and returns failure for unsupported element types or +/// unmapped dynamic indices. +static FailureOr computeGEPByteOffset(LLVM::GEPOp op, + TranslationContext &ctx) { + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + LLVM::GEPIndicesAdaptor indices = op.getIndices(); + + // Only handle single-index GEPs. Multi-index GEPs walk nested types + // and require full DataLayout support. + if (indices.size() != 1) + return op->emitOpError("GEP byte offset requires a single index"); + + FailureOr elemBytes = getLLVMTypeBytes(op.getElemType()); + if (failed(elemBytes)) + return op->emitOpError( + "unsupported GEP element type for byte offset computation " + "without DataLayout: ") + << op.getElemType(); + + llvm::PointerUnion idx = indices[0]; + if (Value dynIdx = dyn_cast(idx)) { + FailureOr resolved = resolve(dynIdx, ctx); + if (failed(resolved)) + return op->emitOpError("unmapped GEP index"); + if (*elemBytes == 1) + return *resolved; + // Scale by element size: offset = index * elemBytes. 
+ ImmType scaleTy = ctx.createImmType(*elemBytes); + Value scale = ConstantOp::create(builder, loc, scaleTy, *elemBytes); + Type resTy = inferResultType({*resolved, scale}, ctx); + return ArithMulOp::create(builder, loc, resTy, *resolved, scale) + .getResult(); + } + + IntegerAttr constIdxAttr = dyn_cast(idx); + assert(constIdxAttr && "GEP index must be a Value or IntegerAttr"); + int64_t constIdx = constIdxAttr.getInt(); + int64_t byteOffset = constIdx * *elemBytes; + ImmType immTy = ctx.createImmType(byteOffset); + return ConstantOp::create(builder, loc, immTy, byteOffset).getResult(); +} + +/// DS instructions address LDS by a 32-bit byte offset, not a generic pointer. +/// Valid in-range LDS accesses are bounded by the kernel's per-workgroup LDS +/// allocation, so well-formed LLVM IR may represent the offset as i64 but still +/// fits in the hardware vaddr field after truncation. +static Value materializeLDSVAddr(Value offset, OpBuilder &builder, Location loc, + TranslationContext &ctx) { + if (getRegSize(offset.getType()) <= 1) + return offset; + VRegType vregTy = ctx.createVRegType(); + return ArithTruncOp::create(builder, loc, vregTy, offset); } //===----------------------------------------------------------------------===// @@ -236,18 +380,34 @@ static LogicalResult handlePoison(LLVM::PoisonOp op, LLVMTranslationState &st) { static LogicalResult handleConstant(LLVM::ConstantOp op, LLVMTranslationState &st) { - auto &ctx = st.ctx; - auto &builder = ctx.getBuilder(); - auto loc = op.getLoc(); + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + + Attribute valAttr = op.getValue(); + + // Dense vector constant (e.g. MFMA accumulator init). 
+ if (DenseElementsAttr denseAttr = dyn_cast(valAttr)) { + if (!denseAttr.isSplat()) + return op->emitOpError("non-splat dense constant not yet supported"); + int64_t numElems = denseAttr.getNumElements(); + APFloat splatVal = denseAttr.getSplatValue(); + int64_t rawBits = splatVal.bitcastToAPInt().getZExtValue(); + ImmType immTy = ctx.createImmType(rawBits); + Value immOp = ConstantOp::create(builder, loc, immTy, rawBits); + VRegType vregTy = ctx.createVRegType(numElems, numElems); + Value mov = V_MOV_B32::create(builder, loc, vregTy, immOp); + ctx.getMapper().mapValue(op.getResult(), mov); + return success(); + } - auto valAttr = op.getValue(); int64_t intVal = 0; if (auto intAttr = dyn_cast(valAttr)) intVal = intAttr.getValue().getSExtValue(); else return op->emitOpError("unsupported constant type"); - auto intType = dyn_cast(op.getResult().getType()); + IntegerType intType = dyn_cast(op.getResult().getType()); if (!intType) return op->emitOpError("expected integer constant"); @@ -316,8 +476,10 @@ static LogicalResult handleWorkgroupId(OpTy op, LLVMTranslationState &st, return success(); } -// Emit arith pseudo-ops for i32/i64 casts -- legalization pass handles width. /// Translate an LLVM cast op to a WaveASM arithmetic pseudo-op. +/// Result width is derived from the LLVM type in 32-bit dwords, so sext +/// i32->i64 produces a 2-wide register and trunc i64->i32 produces a 1-wide +/// one. template static LogicalResult handleCastOp(LLVMOp op, LLVMTranslationState &st) { auto &ctx = st.ctx; @@ -325,8 +487,12 @@ static LogicalResult handleCastOp(LLVMOp op, LLVMTranslationState &st) { FailureOr src = resolve(op.getOperand(), ctx); if (failed(src)) return op->emitOpError("unmapped operand in cast"); - Type resTy = isVGPRType(src->getType()) ? 
(Type)ctx.createVRegType() - : (Type)ctx.createSRegType(); + int64_t width = 1; + if (IntegerType intTy = dyn_cast(op.getResult().getType())) + width = getWaveASMDwordCount(intTy.getWidth()); + Type resTy = isVGPRType(src->getType()) + ? (Type)ctx.createVRegType(width, width) + : (Type)ctx.createSRegType(width, width); Value pseudo = WaveASMOp::create(builder, op.getLoc(), resTy, *src); ctx.getMapper().mapValue(op.getResult(), pseudo); return success(); @@ -406,6 +572,36 @@ static LogicalResult handleBinaryOp(LLVMOp op, LLVMTranslationState &st) { return success(); } +static LogicalResult handleAShr(LLVM::AShrOp op, LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + FailureOr lhs = resolve(op.getLhs(), ctx); + FailureOr rhs = resolve(op.getRhs(), ctx); + if (failed(lhs) || failed(rhs)) + return op->emitOpError("unmapped operand in ashr"); + if (getRegSize(lhs->getType()) != 1 || getRegSize(rhs->getType()) != 1) + return op->emitOpError( + "arithmetic shift currently supports only i32 operands"); + + Value result; + if (isSGPRType(lhs->getType())) { + if (isVGPRType(rhs->getType())) + return op->emitOpError( + "SGPR arithmetic shift requires an SGPR or immediate shift amount"); + SRegType sregTy = ctx.createSRegType(); + SCCType sccTy = ctx.createSCCType(); + result = + S_ASHR_I32::create(builder, loc, sregTy, sccTy, *lhs, *rhs).getDst(); + } else { + VRegType vregTy = ctx.createVRegType(); + result = V_ASHRREV_I32::create(builder, loc, vregTy, *rhs, *lhs); + } + + ctx.getMapper().mapValue(op.getResult(), result); + return success(); +} + static LogicalResult handleMakeBufferRsrc(ROCDL::MakeBufferRsrcOp op, LLVMTranslationState &st) { auto &ctx = st.ctx; @@ -418,10 +614,10 @@ static LogicalResult handleMakeBufferRsrc(ROCDL::MakeBufferRsrcOp op, if (!srdVal) return op->emitOpError("SRD not found for base pointer"); - // The prologue used s_mov_b64 to copy the 64-bit pointer into 
SRD[0:1]. - // This corrupts SRD word 1 bits [31:16] (stride/swizzle) with pointer bits. - // Also, the prologue hardcodes SRD[3]=0x20000 but make.buffer.rsrc may - // want different flags. Rebuild the SRD with corrected words via PackOp. + // The prologue seeds SRD[0:1] from the raw 64-bit pointer. That leaves SRD + // word 1 bits [31:16] (stride/swizzle) populated with pointer bits, and + // initializes SRD[3] to the default flag value. Rebuild the descriptor here + // so make.buffer.rsrc controls both fields explicitly. bool needsPatch = (srdVal->getDefiningOp() != nullptr); if (needsPatch) { SRegType sregTy = ctx.createSRegType(); @@ -433,17 +629,18 @@ static LogicalResult handleMakeBufferRsrc(ROCDL::MakeBufferRsrcOp op, Value word3 = ExtractOp::create(builder, loc, sregTy, *srdVal, 3); // Clear stride/swizzle bits in SRD word 1 (keep only base_addr[47:32]). - auto maskImm = ctx.createImmType(kSRDWord1BaseMask); - auto maskVal = ConstantOp::create(builder, loc, maskImm, kSRDWord1BaseMask); + ImmType maskImm = ctx.createImmType(kSRDWord1BaseMask); + Value maskVal = + ConstantOp::create(builder, loc, maskImm, kSRDWord1BaseMask); word1 = S_AND_B32::create(builder, loc, sregTy, ctx.createSCCType(), word1, maskVal) .getDst(); - // Patch SRD[3] with the actual flags from make.buffer.rsrc. - auto flags = getConstantIntValue(op.getFlags()); + // Patch SRD[3] with the requested descriptor flags. + std::optional flags = getConstantIntValue(op.getFlags()); if (flags && *flags != kSRDDefaultFlags) { - auto flagsImm = ctx.createImmType(*flags); - auto flagsVal = ConstantOp::create(builder, loc, flagsImm, *flags); + ImmType flagsImm = ctx.createImmType(*flags); + Value flagsVal = ConstantOp::create(builder, loc, flagsImm, *flags); word3 = S_MOV_B32::create(builder, loc, sregTy, flagsVal); } @@ -467,7 +664,8 @@ static LogicalResult handleMakeBufferRsrc(ROCDL::MakeBufferRsrcOp op, // Adjust base: S_ADD_U32 sets SCC, S_ADDC_U32 reads it. 
SCCType sccTy = ctx.createSCCType(); - auto addLo = S_ADD_U32::create(builder, loc, sregTy, sccTy, word0, offLo); + S_ADD_U32 addLo = + S_ADD_U32::create(builder, loc, sregTy, sccTy, word0, offLo); word0 = addLo.getDst(); word1 = S_ADDC_U32::create(builder, loc, sregTy, sccTy, addLo.getScc(), word1, offHi) @@ -475,7 +673,7 @@ static LogicalResult handleMakeBufferRsrc(ROCDL::MakeBufferRsrcOp op, } // Pack into a 4-wide SGPR SRD. - auto srdType = ctx.createSRegType(4, 4); + SRegType srdType = ctx.createSRegType(4, 4); Value newSrd = PackOp::create(builder, loc, srdType, ValueRange{word0, word1, word2, word3}); st.mapBufferRsrc(op.getResult(), newSrd); @@ -491,37 +689,61 @@ static LogicalResult handleGEP(LLVM::GEPOp op, LLVMTranslationState &st) { auto &builder = ctx.getBuilder(); auto loc = op.getLoc(); Value base = op.getBase(); - - // GEP index is a dynamic Value (not a constant attr). auto indices = op.getIndices(); - if (indices.size() != 1) - return op->emitOpError("GEP with multiple indices not yet supported"); - auto idx = indices[0].dyn_cast(); - if (!idx) - return op->emitOpError("GEP with constant index attr not yet supported"); + bool isSingleIndex = indices.size() == 1; + bool isAllZeroGEP = isAllZeroIndexGEP(op); + + unsigned addrSpace = getLLVMAddrSpace(op.getBase()); + + // LDS GEP (ptr<3>): compute byte offset for ds_read/ds_write. + // DS instructions only accept a 32-bit vaddr, so truncate wide offsets. 
+ if (addrSpace == 3) { + if (!isSingleIndex && !isAllZeroGEP) + return op->emitOpError( + "LDS GEP must have a single index or all-zero structural indices"); + FailureOr baseOff = resolve(base, ctx); + if (failed(baseOff)) + return op->emitOpError("unmapped LDS GEP base"); + *baseOff = materializeLDSVAddr(*baseOff, builder, loc, ctx); + if (isAllZeroGEP) { + ctx.getMapper().mapValue(op.getResult(), *baseOff); + return success(); + } + FailureOr maybeOffset = computeGEPByteOffset(op, ctx); + if (failed(maybeOffset)) + return failure(); + Value off = materializeLDSVAddr(*maybeOffset, builder, loc, ctx); + VRegType vregTy = ctx.createVRegType(); + Value sum = ArithAddOp::create(builder, loc, vregTy, *baseOff, off); + ctx.getMapper().mapValue(op.getResult(), sum); + return success(); + } - FailureOr resolved = resolve(idx, ctx); - if (failed(resolved)) - return op->emitOpError("unmapped GEP index"); - Value newOffset = *resolved; + if (addrSpace != 0 && addrSpace != 7) + return op->emitOpError("unsupported address space ") << addrSpace; // Bare-pointer GEP (!llvm.ptr, not <7>): 64-bit pointer arithmetic before // make.buffer.rsrc. Propagate the mapper entry and accumulate the byte // offset so it can be folded into the SRD base later. - auto baseTy = op.getBase().getType(); - unsigned addrSpace = 0; - if (auto ptrTy = dyn_cast(baseTy)) - addrSpace = ptrTy.getAddressSpace(); - - if (addrSpace != 0 && addrSpace != 7) - return op->emitOpError("unsupported address space ") << addrSpace; - if (addrSpace == 0) { + if (!isSingleIndex && !isAllZeroGEP) + return op->emitOpError("bare-pointer GEP must have a single index or " + "all-zero structural indices"); // Forward mapper entry so make.buffer.rsrc can find the SRD. - if (auto mapped = ctx.getMapper().getMapped(base)) + if (std::optional mapped = ctx.getMapper().getMapped(base)) ctx.getMapper().mapValue(op.getResult(), *mapped); - // Accumulate base offset with 64-bit add (pointer arithmetic). 
+ if (isAllZeroGEP) { + Value prevOffset = st.lookupBaseOffset(base); + if (prevOffset) + st.setBaseOffset(op.getResult(), prevOffset); + return success(); + } + + FailureOr maybeOffset = computeGEPByteOffset(op, ctx); + if (failed(maybeOffset)) + return failure(); + Value newOffset = *maybeOffset; Value prevOffset = st.lookupBaseOffset(base); if (prevOffset) { Type resTy = inferResultType({prevOffset, newOffset}, ctx); @@ -532,10 +754,18 @@ static LogicalResult handleGEP(LLVM::GEPOp op, LLVMTranslationState &st) { return success(); } - // Buffer voffsets are 32-bit. Truncate i64 GEP indices. - newOffset = truncToI32(newOffset, idx.getType(), builder, loc, ctx); + // Buffer GEP (ptr<7>): single dynamic index. + if (indices.size() != 1) + return op->emitOpError("buffer GEP must have a single index"); + Value idx = indices[0].template dyn_cast(); + if (!idx) + return op->emitOpError("buffer GEP with constant index not yet supported"); + + FailureOr resolved = resolve(idx, ctx); + if (failed(resolved)) + return op->emitOpError("unmapped GEP index"); + Value newOffset = truncToI32(*resolved, idx.getType(), builder, loc, ctx); - // Buffer GEP (ptr<7>): decompose into (SRD, voffset). // Check gepMap first -- covers both chained buffer GEPs and // make.buffer.rsrc entries seeded with a bare-pointer base offset. if (std::optional baseGEP = st.lookupGEP(base)) { @@ -566,11 +796,21 @@ static int64_t getBufferAccessBytes(Type ty) { return 0; } +/// Forward declaration for LDS handlers. +static LogicalResult handleLDSLoad(LLVM::LoadOp op, Value addr, + LLVMTranslationState &st); +static LogicalResult handleLDSStore(LLVM::StoreOp op, Value addr, + LLVMTranslationState &st); + static LogicalResult handleLoad(LLVM::LoadOp op, LLVMTranslationState &st) { auto &ctx = st.ctx; auto &builder = ctx.getBuilder(); auto loc = op.getLoc(); + // LDS load (ptr<3>). 
+ if (getLLVMAddrSpace(op.getAddr()) == 3) + return handleLDSLoad(op, op.getAddr(), st); + std::optional ptr = st.lookupGEP(op.getAddr()); if (!ptr) return op->emitOpError("load address not from a tracked GEP"); @@ -617,6 +857,10 @@ static LogicalResult handleStore(LLVM::StoreOp op, LLVMTranslationState &st) { auto &builder = ctx.getBuilder(); auto loc = op.getLoc(); + // LDS store (ptr<3>). + if (getLLVMAddrSpace(op.getAddr()) == 3) + return handleLDSStore(op, op.getAddr(), st); + std::optional ptr = st.lookupGEP(op.getAddr()); if (!ptr) return op->emitOpError("store address not from a tracked GEP"); @@ -647,6 +891,334 @@ static LogicalResult handleStore(LLVM::StoreOp op, LLVMTranslationState &st) { return success(); } +//===----------------------------------------------------------------------===// +// LDS global / addressof handlers +//===----------------------------------------------------------------------===// + +static LogicalResult handleAddressOf(LLVM::AddressOfOp op, + LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + LLVM::GlobalOp global = st.lookupGlobal(op.getGlobalName()); + if (!global) + return op->emitOpError("referenced global not found"); + + IntegerAttr addrSpaceAttr = global->getAttrOfType("addr_space"); + if (!addrSpaceAttr || addrSpaceAttr.getInt() != 3) + return op->emitOpError( + "llvm.mlir.addressof currently supports only LDS globals in " + "addrspace(3)"); + + FailureOr sizeBytes = getLLVMTypeBytes(global.getType()); + if (failed(sizeBytes)) + return op->emitOpError( + "unsupported LDS global type for byte size computation " + "without DataLayout: ") + << global.getType(); + + int64_t baseOffset = 0; + if (std::optional existingOffset = + st.lookupLDSGlobalOffset(global)) { + baseOffset = *existingOffset; + } else { + baseOffset = ctx.getLDSAllocOffset(); + assert(baseOffset >= 0 && + static_cast(baseOffset) <= + std::numeric_limits::max() && + "LDS base offset must fit in the 32-bit DS vaddr field"); + 
st.setLDSGlobalOffset(global, baseOffset); + ctx.addLDSSize(*sizeBytes); + ctx.advanceLDSAllocOffset(*sizeBytes); + } + + // Map to the byte offset assigned to this LDS global. + OpBuilder &builder = ctx.getBuilder(); + ImmType immTy = ctx.createImmType(baseOffset); + Value baseImm = ConstantOp::create(builder, op.getLoc(), immTy, baseOffset); + VRegType vregTy = ctx.createVRegType(); + Value mov = V_MOV_B32::create(builder, op.getLoc(), vregTy, baseImm); + ctx.getMapper().mapValue(op.getResult(), mov); + return success(); +} + +//===----------------------------------------------------------------------===// +// LDS load/store handlers +//===----------------------------------------------------------------------===// + +static LogicalResult handleLDSLoad(LLVM::LoadOp op, Value addr, + LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + + int64_t numBytes = getBufferAccessBytes(op.getResult().getType()); + FailureOr offset = resolve(addr, ctx); + if (failed(offset)) + return op->emitOpError("unmapped LDS address"); + *offset = materializeLDSVAddr(*offset, builder, loc, ctx); + VRegType vregTy = ctx.createVRegType(); + + Operation *loadOp = nullptr; + if (numBytes == 2) + loadOp = DS_READ_U16::create(builder, loc, TypeRange{vregTy}, *offset); + else if (numBytes == 4) + loadOp = DS_READ_B32::create(builder, loc, TypeRange{vregTy}, *offset); + else if (numBytes == 8) { + VRegType wideTy = ctx.createVRegType(2, 2); + loadOp = DS_READ_B64::create(builder, loc, TypeRange{wideTy}, *offset); + } else if (numBytes == 16) { + VRegType wideTy = ctx.createVRegType(4, 4); + loadOp = DS_READ_B128::create(builder, loc, TypeRange{wideTy}, *offset); + } else { + return op->emitOpError("unsupported LDS load size: ") + << numBytes << " bytes"; + } + + ctx.getMapper().mapValue(op.getResult(), loadOp->getResult(0)); + return success(); +} + +static LogicalResult handleLDSStore(LLVM::StoreOp op, Value 
addr, + LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + + FailureOr data = resolve(op.getValue(), ctx); + FailureOr offset = resolve(addr, ctx); + if (failed(data) || failed(offset)) + return op->emitOpError("unmapped LDS operand"); + *offset = materializeLDSVAddr(*offset, builder, loc, ctx); + int64_t numBytes = getBufferAccessBytes(op.getValue().getType()); + + if (numBytes == 2) + DS_WRITE_B16::create(builder, loc, *data, *offset); + else if (numBytes == 4) + DS_WRITE_B32::create(builder, loc, *data, *offset); + else if (numBytes == 8) + DS_WRITE_B64::create(builder, loc, *data, *offset); + else if (numBytes == 16) + DS_WRITE_B128::create(builder, loc, *data, *offset); + else + return op->emitOpError("unsupported LDS store size: ") + << numBytes << " bytes"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// Signed div/rem handlers +//===----------------------------------------------------------------------===// + +static LogicalResult handleSDiv(LLVM::SDivOp op, LLVMTranslationState &st) { + return op->emitOpError( + "llvm.sdiv must be legalized before LLVM->WaveASM translation"); +} + +static LogicalResult handleSRem(LLVM::SRemOp op, LLVMTranslationState &st) { + return op->emitOpError( + "llvm.srem must be legalized before LLVM->WaveASM translation"); +} + +//===----------------------------------------------------------------------===// +// Memory fence / barrier handlers +//===----------------------------------------------------------------------===// + +static LogicalResult handleFence(LLVM::FenceOp, LLVMTranslationState &) { + // Memory fences are handled implicitly by s_barrier and waitcnt insertion. 
+ return success(); +} + +template +static LogicalResult handleBarrier(OpTy op, LLVMTranslationState &st) { + S_BARRIER::create(st.ctx.getBuilder(), op.getLoc()); + return success(); +} + +//===----------------------------------------------------------------------===// +// Vector shuffle handler +//===----------------------------------------------------------------------===// + +static LogicalResult handleShuffleVector(LLVM::ShuffleVectorOp op, + LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + FailureOr src = resolve(op.getV1(), ctx); + if (failed(src)) + return op->emitOpError("unmapped operand in shufflevector"); + + // shufflevector with a single index extracts one element. + llvm::ArrayRef mask = op.getMask(); + if (mask.size() == 1) { + int64_t idx = mask[0]; + VRegType vregTy = ctx.createVRegType(); + Value extract = ExtractOp::create(builder, op.getLoc(), vregTy, *src, idx); + ctx.getMapper().mapValue(op.getResult(), extract); + return success(); + } + + return op->emitOpError("multi-element shufflevector not yet supported"); +} + +//===----------------------------------------------------------------------===// +// MFMA handler +//===----------------------------------------------------------------------===// + +static LogicalResult handleMFMA_F32_16x16x16_F16(ROCDL::mfma_f32_16x16x16f16 op, + LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + + FailureOr a = resolve(op.getA(), ctx); + FailureOr b = resolve(op.getB(), ctx); + FailureOr c = resolve(op.getC(), ctx); + if (failed(a) || failed(b) || failed(c)) + return op->emitOpError("unmapped operand in MFMA"); + + // Result and accumulator are vector<4xf32> -> 4 VGPRs. 
+ VRegType accTy = ctx.createVRegType(4, 4); + Value mfma = V_MFMA_F32_16X16X16_F16::create(builder, loc, accTy, *a, *b, *c); + ctx.getMapper().mapValue(op.getResult(), mfma); + return success(); +} + +//===----------------------------------------------------------------------===// +// SCF for/yield handler +//===----------------------------------------------------------------------===// + +static LogicalResult translateOp(Operation *op, LLVMTranslationState &st); + +static FailureOr materializeI32SReg(Value v, StringRef name, + scf::ForOp op, + TranslationContext &ctx, + OpBuilder &builder, Location loc) { + if (getRegSize(v.getType()) != 1) { + op->emitOpError(name) << " must lower to i32"; + return failure(); + } + SRegType sregTy = ctx.createSRegType(); + if (isSGPRType(v.getType())) + return v; + if (isImmType(v.getType())) + return S_MOV_B32::create(builder, loc, sregTy, v).getResult(); + if (isVGPRType(v.getType())) + return V_READFIRSTLANE_B32::create(builder, loc, sregTy, v).getResult(); + op->emitOpError(name) << " must lower to an SGPR or VGPR i32"; + return failure(); +} + +struct SCFForLoopBounds { + Value initialIV; + Value upperBound; + Value step; +}; + +/// Resolve and scalarize scf.for loop control in the preheader. +/// +/// scf.for evaluates lb/ub/step once before the first iteration. Capture those +/// values outside the lowered waveasm.loop body so the upper bound is not +/// implicitly rematerialized on the backedge. 
+static FailureOr +resolveSCFForLoopBounds(scf::ForOp op, LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + + FailureOr lowerBound = resolve(op.getLowerBound(), ctx); + FailureOr upperBound = resolve(op.getUpperBound(), ctx); + FailureOr step = resolve(op.getStep(), ctx); + if (failed(lowerBound) || failed(upperBound) || failed(step)) { + op->emitOpError("unmapped operand in scf.for"); + return failure(); + } + + // TODO: Preserve full scf.for zero-trip semantics. waveasm.loop is do-while + // and this lowering currently assumes the trip count is known positive. + FailureOr initialIV = materializeI32SReg( + *lowerBound, "scf.for lower bound", op, ctx, builder, loc); + FailureOr loopUpperBound = materializeI32SReg( + *upperBound, "scf.for upper bound", op, ctx, builder, loc); + FailureOr loopStep = + materializeI32SReg(*step, "scf.for step", op, ctx, builder, loc); + if (failed(initialIV) || failed(loopUpperBound) || failed(loopStep)) + return failure(); + + return SCFForLoopBounds{ + /*initialIV=*/*initialIV, + /*upperBound=*/*loopUpperBound, + /*step=*/*loopStep, + }; +} + +static LogicalResult handleSCFFor(scf::ForOp op, LLVMTranslationState &st) { + TranslationContext &ctx = st.ctx; + OpBuilder &builder = ctx.getBuilder(); + Location loc = op.getLoc(); + + FailureOr loopBounds = resolveSCFForLoopBounds(op, st); + if (failed(loopBounds)) + return failure(); + + // Build init args: [lower_bound, iter_args...]. + SmallVector initArgs; + initArgs.push_back(loopBounds->initialIV); + for (Value arg : op.getInitArgs()) { + FailureOr resolved = resolve(arg, ctx); + if (failed(resolved)) + return op->emitOpError("unmapped init arg in scf.for"); + initArgs.push_back(*resolved); + } + + // Create the waveasm.loop (do-while semantics). + LoopOp loopOp = LoopOp::create(builder, loc, initArgs); + Block &bodyBlock = loopOp.getBodyBlock(); + + // Map the induction variable (block arg 0). 
+ ctx.getMapper().mapValue(op.getInductionVar(), bodyBlock.getArgument(0)); + + // Map iter_args (block args 1..N). + for (auto i : llvm::seq(op.getInitArgs().size())) + ctx.getMapper().mapValue(op.getRegionIterArgs()[i], + bodyBlock.getArgument(i + 1)); + + // Translate the loop body. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + for (Operation &bodyOp : op.getBody()->without_terminator()) + if (failed(translateOp(&bodyOp, st))) + return failure(); + + // Build loop increment and condition. + Value inductionVar = bodyBlock.getArgument(0); + SRegType sregTy = ctx.createSRegType(); + SCCType sccTy = ctx.createSCCType(); + Value nextIV = S_ADD_U32::create(builder, loc, sregTy, sccTy, inductionVar, + loopBounds->step) + .getDst(); + Value cond = + S_CMP_LT_U32::create(builder, loc, sccTy, nextIV, loopBounds->upperBound); + + // Collect iter args from yield. + scf::YieldOp yieldOp = cast(op.getBody()->getTerminator()); + SmallVector condIterArgs; + condIterArgs.push_back(nextIV); + for (Value v : yieldOp.getOperands()) { + FailureOr resolved = resolve(v, ctx); + if (failed(resolved)) + return op->emitOpError("unmapped yield operand in scf.for"); + condIterArgs.push_back(*resolved); + } + + ConditionOp::create(builder, loc, cond, condIterArgs); + + // Map loop results. scf.for results are iter_args only (no IV), + // but waveasm.loop results include the IV at index 0. 
+ for (auto i : llvm::seq(op.getNumResults())) + ctx.getMapper().mapValue(op.getResult(i), loopOp.getResult(i + 1)); + + return success(); +} + //===----------------------------------------------------------------------===// // Op dispatch //===----------------------------------------------------------------------===// @@ -676,6 +1248,7 @@ static LogicalResult translateOp(Operation *op, LLVMTranslationState &st) { .Case([&](LLVM::AddOp o) { return handleBinaryOp(o, st); }) + .Case([&](LLVM::AShrOp o) { return handleAShr(o, st); }) .Case([&](LLVM::OrOp o) { return handleBinaryOp(o, st); }) @@ -688,6 +1261,18 @@ static LogicalResult translateOp(Operation *op, LLVMTranslationState &st) { .Case([&](LLVM::GEPOp o) { return handleGEP(o, st); }) .Case([&](LLVM::LoadOp o) { return handleLoad(o, st); }) .Case([&](LLVM::StoreOp o) { return handleStore(o, st); }) + .Case([&](LLVM::AddressOfOp o) { return handleAddressOf(o, st); }) + .Case([&](LLVM::SDivOp o) { return handleSDiv(o, st); }) + .Case([&](LLVM::SRemOp o) { return handleSRem(o, st); }) + .Case([&](LLVM::FenceOp o) { return handleFence(o, st); }) + .Case([&](LLVM::ShuffleVectorOp o) { return handleShuffleVector(o, st); }) + .Case([&](ROCDL::BarrierOp o) { return handleBarrier(o, st); }) + .Case([&](ROCDL::SBarrierOp o) { return handleBarrier(o, st); }) + .Case([&](ROCDL::mfma_f32_16x16x16f16 o) { + return handleMFMA_F32_16x16x16_F16(o, st); + }) + .Case([&](scf::ForOp o) { return handleSCFFor(o, st); }) + .Case([&](scf::YieldOp) { return success(); }) .Default([](Operation *op) { return op->emitOpError("unhandled op in LLVM->WaveASM translation"); }); @@ -726,7 +1311,9 @@ static LogicalResult translateLLVMModule(Operation *rootOp, return failure(); builder.setInsertionPointToStart(&program.getBodyBlock()); TranslationContext ctx(builder, program, target); - LLVMTranslationState st(ctx); + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(func); + assert(symbolTableOp && "expected nearest symbol table for 
llvm.func"); + LLVMTranslationState st(ctx, symbolTableOp); // Map llvm.func arguments: pointers get SRD setup, scalars get mapped // to their preloaded SGPR positions directly. @@ -755,9 +1342,10 @@ static LogicalResult translateLLVMModule(Operation *rootOp, ctx.getMapper().mapValue(arg, sreg); } - // Enable all workgroup IDs so the SGPR layout is predictable. - // Note: LLVM enables them selectively via amdgpu-no-workgroup-id-{y,z} - // attributes. We enable all three unconditionally for simplicity. + // Keep all workgroup IDs enabled so the SGPR layout is stable across + // translated kernels. LLVM can disable y/z selectively via + // amdgpu-no-workgroup-id-{y,z}, but the fixed WaveASM prologue layout here + // assumes all three slots exist. ctx.enableAllWorkgroupIds(); for (Operation &op : func.getBody().front()) { diff --git a/waveasm/test/Transforms/arith-legalization-bitwise.mlir b/waveasm/test/Transforms/arith-legalization-bitwise.mlir index da3f9fda7b..dc7ffd1f04 100644 --- a/waveasm/test/Transforms/arith-legalization-bitwise.mlir +++ b/waveasm/test/Transforms/arith-legalization-bitwise.mlir @@ -14,6 +14,7 @@ waveasm.program @test_or_i32 %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg %s1 = waveasm.precolored.sreg 1 : !waveasm.sreg + %c1 = waveasm.constant 1 : !waveasm.imm<1> %c42 = waveasm.constant 42 : !waveasm.imm<42> // VGPR | VGPR -> v_or_b32. @@ -33,6 +34,10 @@ waveasm.program @test_or_i32 // CHECK: waveasm.s_or_b32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.imm<42> %or_si = waveasm.arith.or %s0, %c42 : (!waveasm.sreg, !waveasm.imm<42>) -> !waveasm.sreg + // imm | SGPR -> s_or_b32 with the SGPR normalized to operand 0. + // CHECK: waveasm.s_or_b32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.imm<1> + %or_is = waveasm.arith.or %c1, %s0 : (!waveasm.imm<1>, !waveasm.sreg) -> !waveasm.sreg + // CHECK-NOT: waveasm.arith. 
waveasm.s_endpgm } @@ -50,6 +55,7 @@ waveasm.program @test_and_i32 %v0 = waveasm.precolored.vreg 0 : !waveasm.vreg %s0 = waveasm.precolored.sreg 0 : !waveasm.sreg %s1 = waveasm.precolored.sreg 1 : !waveasm.sreg + %c1 = waveasm.constant 1 : !waveasm.imm<1> // VGPR & VGPR -> v_and_b32. // CHECK: waveasm.v_and_b32 %{{.*}}, %{{.*}} : !waveasm.vreg, !waveasm.vreg @@ -64,6 +70,10 @@ waveasm.program @test_and_i32 // CHECK: waveasm.v_and_b32 %and_sv = waveasm.arith.and %s0, %v0 : (!waveasm.sreg, !waveasm.vreg) -> !waveasm.vreg + // imm & SGPR -> s_and_b32 with the SGPR normalized to operand 0. + // CHECK: waveasm.s_and_b32 %{{.*}}, %{{.*}} : !waveasm.sreg, !waveasm.imm<1> + %and_is = waveasm.arith.and %c1, %s0 : (!waveasm.imm<1>, !waveasm.sreg) -> !waveasm.sreg + // CHECK-NOT: waveasm.arith. waveasm.s_endpgm } diff --git a/waveasm/test/Transforms/llvm-sdiv-srem-legalization.mlir b/waveasm/test/Transforms/llvm-sdiv-srem-legalization.mlir new file mode 100644 index 0000000000..638f94d283 --- /dev/null +++ b/waveasm/test/Transforms/llvm-sdiv-srem-legalization.mlir @@ -0,0 +1,31 @@ +// RUN: waveasm-translate %s --waveasm-llvm-sdiv-srem-legalization | FileCheck %s +// Verify that the LLVM pre-pass rewrites signed div/rem by positive power-of-2 +// constants before LLVM->WaveASM translation. 
+ +// CHECK-LABEL: llvm.func @test +// CHECK: llvm.icmp "slt" +// CHECK: llvm.select +// CHECK: llvm.ashr +// CHECK: llvm.and +// CHECK: llvm.icmp "ne" +// CHECK: llvm.select + +gpu.module @gpu_module { + llvm.mlir.global private @scratch() {addr_space = 3 : i32} : i32 + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %scratch = llvm.mlir.addressof @scratch : !llvm.ptr<3> + %tid = rocdl.workitem.id.x range : i32 + %16 = llvm.mlir.constant(16 : i32) : i32 + %neg5 = llvm.mlir.constant(-5 : i32) : i32 + %4 = llvm.mlir.constant(4 : i32) : i32 + %div = llvm.sdiv %tid, %16 : i32 + %rem = llvm.srem %tid, %16 : i32 + %negdiv = llvm.sdiv %neg5, %4 : i32 + %negrem = llvm.srem %neg5, %4 : i32 + %sum0 = llvm.add %div, %rem : i32 + %sum1 = llvm.add %sum0, %negdiv : i32 + %sum2 = llvm.add %sum1, %negrem : i32 + llvm.store %sum2, %scratch : i32, !llvm.ptr<3> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-barrier.mlir b/waveasm/test/Transforms/translate-from-llvm-barrier.mlir new file mode 100644 index 0000000000..eef8ddbc94 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-barrier.mlir @@ -0,0 +1,19 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify barrier and fence translation. +// rocdl.barrier -> s_barrier; rocdl.s.barrier -> s_barrier; llvm.fence -> no-op. + +// CHECK: waveasm.program @test__waveasm +// CHECK: waveasm.s_barrier +// CHECK: waveasm.s_barrier +// fence produces no output. 
+// CHECK-NOT: fence +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + rocdl.barrier + rocdl.s.barrier + llvm.fence acquire + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-addressof-non-lds.mlir b/waveasm/test/Transforms/translate-from-llvm-error-addressof-non-lds.mlir new file mode 100644 index 0000000000..16aa7cc6a8 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-addressof-non-lds.mlir @@ -0,0 +1,12 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that llvm.mlir.addressof is currently limited to LDS globals. + +// CHECK: llvm.mlir.addressof currently supports only LDS globals in addrspace(3) + +gpu.module @gpu_module { + llvm.mlir.global private @global() : i32 + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.addressof @global : !llvm.ptr + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-bare-ptr-multi-index-gep.mlir b/waveasm/test/Transforms/translate-from-llvm-error-bare-ptr-multi-index-gep.mlir new file mode 100644 index 0000000000..06562490ba --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-bare-ptr-multi-index-gep.mlir @@ -0,0 +1,12 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that non-zero multi-index bare-pointer GEPs fail instead of silently +// dropping the aggregate offset. 
+ +// CHECK: bare-pointer GEP must have a single index or all-zero structural indices + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.getelementptr %arg0[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-gep-addrspace.mlir b/waveasm/test/Transforms/translate-from-llvm-error-gep-addrspace.mlir new file mode 100644 index 0000000000..c0c2fd6472 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-gep-addrspace.mlir @@ -0,0 +1,12 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that GEPs on unsupported address spaces produce a diagnostic. + +// CHECK: unsupported address space 5 + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr<5>) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.getelementptr %arg0[%0] : (!llvm.ptr<5>, i32) -> !llvm.ptr<5>, i8 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-gep-const-index.mlir b/waveasm/test/Transforms/translate-from-llvm-error-gep-const-index.mlir new file mode 100644 index 0000000000..f107d750d1 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-gep-const-index.mlir @@ -0,0 +1,16 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that buffer GEPs with constant attr indices (not dynamic Values) +// produce a diagnostic. 
+ +// CHECK: buffer GEP with constant index not yet supported + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.constant(0 : i16) : i16 + %1 = llvm.mlir.constant(2147483645 : i64) : i64 + %2 = llvm.mlir.constant(822243328 : i32) : i32 + %3 = rocdl.make.buffer.rsrc %arg0, %0, %1, %2 : !llvm.ptr to <7> + %4 = llvm.getelementptr %3[42] : (!llvm.ptr<7>) -> !llvm.ptr<7>, i8 + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-gep-elem-type.mlir b/waveasm/test/Transforms/translate-from-llvm-error-gep-elem-type.mlir new file mode 100644 index 0000000000..4556186ab7 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-gep-elem-type.mlir @@ -0,0 +1,13 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that single-index GEPs with unsupported element types fail instead of +// silently treating the index as a byte offset. + +// CHECK: unsupported GEP element type for byte offset computation + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %idx = llvm.mlir.constant(1 : i64) : i64 + %gep = llvm.getelementptr %arg0[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, i32)> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir b/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir index aa926f03e7..42e61d33e7 100644 --- a/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir +++ b/waveasm/test/Transforms/translate-from-llvm-error-multi-index-gep.mlir @@ -1,7 +1,7 @@ // RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s // Verify that multi-index GEPs produce a diagnostic instead of crashing. 
-// CHECK: GEP with multiple indices not yet supported +// CHECK: buffer GEP must have a single index gpu.module @gpu_module { llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { diff --git a/waveasm/test/Transforms/translate-from-llvm-error-scf-for-i64.mlir b/waveasm/test/Transforms/translate-from-llvm-error-scf-for-i64.mlir new file mode 100644 index 0000000000..33bef58425 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-error-scf-for-i64.mlir @@ -0,0 +1,17 @@ +// RUN: not waveasm-translate %s --waveasm-translate-from-llvm 2>&1 | FileCheck %s +// Verify that scf.for loop control stays explicitly limited to i32 lowering. + +// CHECK: scf.for lower bound must lower to i32 + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %lb = llvm.mlir.constant(0 : i64) : i64 + %ub = llvm.mlir.constant(64 : i64) : i64 + %step = llvm.mlir.constant(16 : i64) : i64 + %init = llvm.mlir.constant(0 : i32) : i32 + %result = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (i32) : i64 { + scf.yield %acc : i32 + } + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-lds.mlir b/waveasm/test/Transforms/translate-from-llvm-lds.mlir new file mode 100644 index 0000000000..dcbd579008 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-lds.mlir @@ -0,0 +1,46 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify LDS addressof, GEP, load, and store translation. +// addressof assigns per-global LDS byte offsets; LDS GEPs produce byte +// offsets; loads/stores dispatch to ds_read/ds_write by width; and LDS size is +// counted in bytes. + +// CHECK: waveasm.program @test__waveasm +// CHECK: lds_size = 1040 : i64 +// First addressof uses byte offset 0. 
+// CHECK: waveasm.constant 0 +// CHECK: waveasm.v_mov_b32 +// Second addressof uses the next available LDS byte offset. +// CHECK: waveasm.constant 1024 +// GEP [0,0] on the array type passes through (all-zero indices). +// GEP with constant index 128 (byte offset 512) produces an arith.add. +// CHECK: waveasm.arith.add +// 4-byte LDS load -> ds_read_b32. +// CHECK: waveasm.ds_read_b32 +// 4-byte LDS store -> ds_write_b32. +// CHECK: waveasm.ds_write_b32 +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.mlir.global private @alloca() {addr_space = 3 : i32} : !llvm.array<256 x i32> + llvm.mlir.global private @scratch() {addr_space = 3 : i32} : !llvm.array<4 x i32> + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %0 = llvm.mlir.addressof @alloca : !llvm.ptr<3> + %scratch = llvm.mlir.addressof @scratch : !llvm.ptr<3> + %1 = llvm.mlir.constant(42 : i32) : i32 + // Multi-index GEP with all-zero indices -> passthrough. + %2 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<256 x i32> + // Constant-offset GEP. 128 * sizeof(i32) = 512 bytes. + %3 = llvm.getelementptr %2[128] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i32 + // Dynamic-offset GEP. + %tid = rocdl.workitem.id.x range : i32 + %tidext = llvm.sext %tid : i32 to i64 + %4 = llvm.getelementptr nusw %2[%tidext] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // LDS load (4 bytes). + %5 = llvm.load %4 : !llvm.ptr<3> -> i32 + // LDS store (4 bytes). + llvm.store %1, %3 : i32, !llvm.ptr<3> + // Distinct LDS global uses a distinct non-zero base offset.
+ llvm.store %1, %scratch : i32, !llvm.ptr<3> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-mfma.mlir b/waveasm/test/Transforms/translate-from-llvm-mfma.mlir new file mode 100644 index 0000000000..1cee2889da --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-mfma.mlir @@ -0,0 +1,31 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify MFMA and shufflevector (element extraction) translation. +// Dense vector constant -> wide v_mov_b32; mfma -> v_mfma; shuffle -> extract. + +// CHECK: waveasm.program @test__waveasm +// Zero-init accumulator (dense<0.0> : vector<4xf32>). +// CHECK: waveasm.v_mov_b32 {{.*}} -> !waveasm.vreg<4, 4> +// MFMA instruction. +// CHECK: waveasm.v_mfma_f32_16x16x16_f16 +// Extract elements from MFMA result. +// CHECK: waveasm.extract {{.*}}[0] : !waveasm.vreg<4, 4> -> !waveasm.vreg +// CHECK: waveasm.extract {{.*}}[1] : !waveasm.vreg<4, 4> -> !waveasm.vreg +// CHECK: waveasm.extract {{.*}}[2] : !waveasm.vreg<4, 4> -> !waveasm.vreg +// CHECK: waveasm.extract {{.*}}[3] : !waveasm.vreg<4, 4> -> !waveasm.vreg +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test(%arg0: !llvm.ptr) attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %c0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf32>) : vector<4xf32> + // Use zero for A and B inputs (vector<4xf16>). + %a = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf16>) : vector<4xf16> + %b = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf16>) : vector<4xf16> + %mfma = rocdl.mfma.f32.16x16x16f16 %a, %b, %c0, 0, 0, 0 : (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32> + // Extract individual elements. 
+ %e0 = llvm.shufflevector %mfma, %mfma [0] : vector<4xf32> + %e1 = llvm.shufflevector %mfma, %mfma [1] : vector<4xf32> + %e2 = llvm.shufflevector %mfma, %mfma [2] : vector<4xf32> + %e3 = llvm.shufflevector %mfma, %mfma [3] : vector<4xf32> + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-scf-for.mlir b/waveasm/test/Transforms/translate-from-llvm-scf-for.mlir new file mode 100644 index 0000000000..e7ae148c02 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-scf-for.mlir @@ -0,0 +1,29 @@ +// RUN: waveasm-translate %s --waveasm-translate-from-llvm | FileCheck %s +// Verify scf.for translation to waveasm.loop with condition terminator. +// The loop body contains an add; iter_args carry the accumulator. + +// CHECK: waveasm.program @test__waveasm +// Loop with one iter_arg (the accumulator). +// CHECK: waveasm.loop +// Loop body: arith.add for the accumulation. +// CHECK: waveasm.arith.add +// IV increment stays scalar. +// CHECK: waveasm.s_add_u32 +// Condition: s_cmp_lt_u32 for the loop back-edge. 
+// CHECK: waveasm.s_cmp_lt_u32 +// CHECK: waveasm.condition +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %lb = llvm.mlir.constant(0 : i32) : i32 + %ub = llvm.mlir.constant(64 : i32) : i32 + %step = llvm.mlir.constant(16 : i32) : i32 + %init = llvm.mlir.constant(0 : i32) : i32 + %result = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (i32) : i32 { + %sum = llvm.add %acc, %iv : i32 + scf.yield %sum : i32 + } + llvm.return + } +} diff --git a/waveasm/test/Transforms/translate-from-llvm-sdiv-srem.mlir b/waveasm/test/Transforms/translate-from-llvm-sdiv-srem.mlir new file mode 100644 index 0000000000..b1bcd96f99 --- /dev/null +++ b/waveasm/test/Transforms/translate-from-llvm-sdiv-srem.mlir @@ -0,0 +1,35 @@ +// RUN: waveasm-translate %s --waveasm-llvm-sdiv-srem-legalization --waveasm-translate-from-llvm | FileCheck %s +// Verify sdiv and srem by positive power-of-2 constant divisors. +// Negative dividends require bias/correction so we should see compare/select +// scaffolding in addition to the final shift/mask-shaped arithmetic. + +// CHECK: waveasm.program @test__waveasm +// sdiv lowers through signed-bias correction before the arithmetic shift. +// CHECK: waveasm.arith.cmp slt +// CHECK: waveasm.arith.select +// CHECK: waveasm.v_ashrrev_i32 +// srem keeps LLVM sign semantics via mask + conditional correction.
+// CHECK: waveasm.arith.and +// CHECK: waveasm.arith.add +// CHECK: waveasm.arith.select +// CHECK: waveasm.s_endpgm + +gpu.module @gpu_module { + llvm.mlir.global private @scratch() {addr_space = 3 : i32} : i32 + llvm.func @test() attributes {gpu.kernel, gpu.known_block_size = array, rocdl.kernel, rocdl.reqd_work_group_size = array} { + %scratch = llvm.mlir.addressof @scratch : !llvm.ptr<3> + %tid = rocdl.workitem.id.x range : i32 + %16 = llvm.mlir.constant(16 : i32) : i32 + %neg5 = llvm.mlir.constant(-5 : i32) : i32 + %4 = llvm.mlir.constant(4 : i32) : i32 + %div = llvm.sdiv %tid, %16 : i32 + %rem = llvm.srem %tid, %16 : i32 + %negdiv = llvm.sdiv %neg5, %4 : i32 + %negrem = llvm.srem %neg5, %4 : i32 + %sum0 = llvm.add %div, %rem : i32 + %sum1 = llvm.add %sum0, %negdiv : i32 + %sum2 = llvm.add %sum1, %negrem : i32 + llvm.store %sum2, %scratch : i32, !llvm.ptr<3> + llvm.return + } +}