Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions lit_tests/kernel/wave/water_host_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_wave_compile_options(
canonicalize: bool = False,
dynamic_symbols=[],
additional_symbols={},
wave_runtime: bool = False,
location_capture_config=LocationCaptureConfig(
level=LocationCaptureLevel.FILE_LINE_COL
),
Expand Down Expand Up @@ -61,6 +62,7 @@ def get_wave_compile_options(
location_capture_config=location_capture_config,
drop_debug_info_before_mlir=drop_debug_info_before_mlir,
use_water_backend=True,
wave_runtime=wave_runtime,
)


Expand Down Expand Up @@ -227,3 +229,52 @@ def read_write(
# CHECK-SAME: blocks in (%[[C1]], %[[C1]], %[[C1]])
# CHECK-SAME: threads in (%[[C64]], %[[C1]], %[[C1]])
# CHECK: return


@run_test
def test_dynamic_strides_output_placeholder_first():
constraints = get_constraints()

@tkw.wave(constraints)
def output_then_input(
out: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
inp: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
res = tkw.read(inp)
tkw.write(res, out)

output_then_input = wave_compile(
get_wave_compile_options(
canonicalize=True,
wave_runtime=True,
drop_debug_info_before_mlir=True,
),
output_then_input,
)
print(output_then_input.asm)

# Even though the Python placeholder order is (out, inp), kernel ABI order
# is linearized to input first, output second.
# CHECK-LABEL: test_dynamic_strides_output_placeholder_first
# CHECK: gpu.func @output_then_input
# CHECK-SAME: (%[[IN:.*]]: memref<f16> {llvm.inreg}, %[[OUT:.*]]: memref<f16> {llvm.inreg}, %[[IN_STRIDE:.*]]: index {llvm.inreg}, %[[OUT_STRIDE:.*]]: index {llvm.inreg})
# CHECK: %[[IN_VIEW:.*]] = memref.reinterpret_cast %[[IN]] to offset: [0], sizes: [16, 16], strides: [%[[IN_STRIDE]], 1]
# CHECK: %[[OUT_VIEW:.*]] = memref.reinterpret_cast %[[OUT]] to offset: [0], sizes: [16, 16], strides: [%[[OUT_STRIDE]], 1]

# Host wrapper must mirror that same ABI order for both buffers and strides.
# CHECK-LABEL: func.func @isolated_benchmark
# CHECK-SAME: (%[[STREAM:.*]]: !llvm.ptr, %[[OUT_ARG:.*]]: !llvm.ptr, %[[IN_ARG:.*]]: !llvm.ptr)
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
# CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
# CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
# CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
# CHECK: %[[IN_BUF:.*]] = call @wave_get_buffer(%[[IN_ARG]]) : (!llvm.ptr) -> memref<?xi8>
# CHECK: %[[IN_VIEW0:.*]] = memref.view %[[IN_BUF]][%[[C0]]][] : memref<?xi8> to memref<f16>
# CHECK: %[[OUT_BUF:.*]] = call @wave_get_buffer(%[[OUT_ARG]]) : (!llvm.ptr) -> memref<?xi8>
# CHECK: %[[OUT_VIEW0:.*]] = memref.view %[[OUT_BUF]][%[[C0]]][] : memref<?xi8> to memref<f16>
# CHECK: %[[IN_STRIDE_I64:.*]] = call @wave_get_stride(%[[IN_ARG]], %[[C0_I32]]) : (!llvm.ptr, i32) -> i64
# CHECK: %[[IN_STRIDE_IDX:.*]] = arith.index_cast %[[IN_STRIDE_I64]] : i64 to index
# CHECK: %[[OUT_STRIDE_I64:.*]] = call @wave_get_stride(%[[OUT_ARG]], %[[C0_I32]]) : (!llvm.ptr, i32) -> i64
# CHECK: %[[OUT_STRIDE_IDX:.*]] = arith.index_cast %[[OUT_STRIDE_I64]] : i64 to index
# CHECK: gpu.launch_func @gpu_module::@output_then_input blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C64]], %[[C1]], %[[C1]]) args(%[[IN_VIEW0]] : memref<f16>, %[[OUT_VIEW0]] : memref<f16>, %[[IN_STRIDE_IDX]] : index, %[[OUT_STRIDE_IDX]] : index)
# CHECK: return
64 changes: 64 additions & 0 deletions tests/kernel/e2e/test_gemm_waveasm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2026 The Wave Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""GEMM through the water+waveasm pipeline (LLVM dialect -> WaveASM -> binary)."""

import pytest
import torch
from torch.testing import assert_close

from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.kernel.wave.utils.torch_utils import device_randn, device_zeros

from ..common.utils import require_cdna4, require_e2e


@require_e2e
@require_cdna4
@pytest.mark.parametrize(
"shape,block_shape,waves_per_block",
[
((64, 64, 64), (32, 32, 16), (1, 1)),
],
ids=["64x64x64"],
)
def test_gemm_water_waveasm(
shape: tuple[int, int, int],
block_shape: tuple[int, int, int],
waves_per_block: tuple[int, int],
) -> None:
"""Test GEMM through the water+waveasm pipeline."""
m, n, k = shape

gemm, hyperparams, _ = get_gemm_kernel(
shape=shape,
dynamic_dims=False,
mfma_variant=MMAType.F32_16x16x16_F16,
block_shape=block_shape,
waves_per_block=waves_per_block,
)

options = WaveCompileOptions(
subs=hyperparams,
canonicalize=True,
use_water_backend=True,
use_buffer_ops=True,
backend="asm",
wave_runtime=True,
)
options = set_default_run_config(options)
compiled = wave_compile(options, gemm)

a = device_randn((m, k), dtype=torch.float16)
b = device_randn((n, k), dtype=torch.float16)
c = device_zeros((m, n), dtype=torch.float32)
compiled(a, b, c)

expected = torch.matmul(a, b.T).float()
assert_close(c, expected, rtol=1e-3, atol=1e-3)
133 changes: 88 additions & 45 deletions wave_lang/kernel/compiler/wave_codegen/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
BindingDesc,
BoundKernelSignature,
create_argument_locations,
get_dynamic_stride_arg_count,
)
from ...lang.wave_types import IndexSymbol
from ...wave.compile_options import WaveCompileOptions
Expand Down Expand Up @@ -137,40 +138,70 @@ def emit_program_invariants(self):
),
]

def emit_func(self) -> Operation:
bindings = self.root_sig.sig.linear_bindings
def _abi_type(self, binding: BindingDesc) -> IrType:
if binding.binding_type == BindingType.KERNEL_BUFFER:
# Buffer passed to kernel as 0D memrefs to simplify ABI.
element_type = IrType.parse(binding.kernel_buffer_type.dtype.ir_type_asm())
return MemRefType.get([], element_type=element_type)

def abi_type(binding: BindingDesc):
if binding.binding_type == BindingType.KERNEL_BUFFER:
# Buffer passed to kernel as 0D memrefs to simplify ABI.
element_type = IrType.parse(
binding.kernel_buffer_type.dtype.ir_type_asm()
)
return MemRefType.get([], element_type=element_type)
# Scalars are passed as is.
return binding.as_mlir_type()

# Scalars are passed as is.
return binding.as_mlir_type()
def _create_stride_argument_locations(
self, bindings: list[BindingDesc], base_locations: list[Location]
) -> list[Location]:
if not self.options.dynamic_strides:
return []

arg_types = [abi_type(b) for b in bindings]
stride_locations = []
for binding, base_location in zip(bindings, base_locations):
if binding.binding_type != BindingType.KERNEL_BUFFER:
continue
leading_count = max(0, len(binding.kernel_buffer_type.symbolic_shape) - 1)
argument_name = binding.name or "argument"
for dim_idx in range(leading_count):
stride_locations.append(
Location.name(
f"{argument_name}.stride.{dim_idx}",
base_location,
)
)
return stride_locations

def _ordered_kernel_buffer_bindings(
self, bindings: list[BindingDesc]
) -> list[BindingDesc]:
return [
binding
for binding in bindings
if binding.binding_type == BindingType.KERNEL_BUFFER
]

# Dynamic strides only with Wave runtime and LLVM backend (not ASM).
# Stride args only for leading dimensions (innermost always has unit stride).
stride_arg_count = 0
if self.options.dynamic_strides:
stride_arg_count = sum(
max(0, len(b.kernel_buffer_type.symbolic_shape) - 1)
for b in self.root_sig.sig.kernel_buffer_bindings
)
if stride_arg_count > 0:
arg_types += [IndexType.get()] * stride_arg_count
def _kernel_argument_abi(
self, bindings: list[BindingDesc]
) -> tuple[list[IrType], list[Location], int]:
arg_types = [self._abi_type(binding) for binding in bindings]
arg_locations = create_argument_locations(bindings)
kernel_buffer_bindings = self._ordered_kernel_buffer_bindings(bindings)
stride_arg_count = get_dynamic_stride_arg_count(
self.options.dynamic_strides,
kernel_buffer_bindings,
)
stride_locations = self._create_stride_argument_locations(
bindings, arg_locations
)
assert len(stride_locations) == stride_arg_count
if stride_arg_count > 0:
arg_types += [IndexType.get()] * stride_arg_count
arg_locations += stride_locations
return arg_types, arg_locations, stride_arg_count

def emit_func(self) -> Operation:
bindings = self.root_sig.sig.linear_bindings
arg_types, arg_locations, stride_arg_count = self._kernel_argument_abi(bindings)
ftype = FunctionType.get(arg_types, [])
func_op = func_d.FuncOp(self.kernel_name, ftype, visibility="private")

locs = create_argument_locations(bindings)
if stride_arg_count > 0:
locs += [Location.unknown()] * stride_arg_count
entry_block = func_op.add_entry_block(locs)
entry_block = func_op.add_entry_block(arg_locations)

# Map dynamic symbols to buffer argument indices and dimensions.
for bind, arg in zip(bindings, entry_block.arguments):
Expand Down Expand Up @@ -285,8 +316,9 @@ def _declare_runtime_func(
return func_op, symbol

def emit_host_func(self, kernel_func: Operation) -> Operation:
# TODO: kernel bindings order may not be the same as the kernel function
# arguments order, so map kernel order to host function arguments order.
# Host placeholders follow the original Python signature, while kernel ABI
# arguments use linear_bindings order. Track both so launch args can be
# materialized in kernel ABI order from host-visible inputs.
binding_map = {}
symbol_map = {}

Expand All @@ -308,22 +340,10 @@ def emit_host_func(self, kernel_func: Operation) -> Operation:
bindings = self.root_sig.sig.linear_bindings

ptr = llvm_d.PointerType.get()

def abi_type(binding: BindingDesc):
if binding.binding_type == BindingType.KERNEL_BUFFER:
# Buffer passed to kernel as 0D memrefs to simplify ABI.
element_type = IrType.parse(
binding.kernel_buffer_type.dtype.ir_type_asm()
)
return MemRefType.get([], element_type=element_type)

# Scalars are passed as is.
return binding.as_mlir_type()

arg_types = [abi_type(b) for b in bindings]

arg_types, kernel_arg_locations, stride_arg_count = self._kernel_argument_abi(
bindings
)
ftype = FunctionType.get(arg_types, [])
locs = [a.location for a in kernel_func.body.blocks[0].arguments]

gpu_module = gpu_d.module("gpu_module")
gpu_module.parent.operation.attributes["gpu.container_module"] = UnitAttr.get()
Expand All @@ -345,7 +365,7 @@ def abi_type(binding: BindingDesc):

new_kernel_entry_block = kernel_func_wrapper.body.blocks.append(
*arg_types,
arg_locs=locs,
arg_locs=kernel_arg_locations,
)

# Inline the kernel function into the gpu module function body and erase the original function
Expand Down Expand Up @@ -396,6 +416,11 @@ def abi_type(binding: BindingDesc):
"wave_get_float64", [ptr], [f64], emit_c_interface=True
)

# Get tensor stride function from PyObject*.
get_stride_func, get_stride_func_symbol = self._declare_runtime_func(
"wave_get_stride", [ptr, i32], [i64], emit_c_interface=True
)

# Declare host function
# First argument is stream pointer
# Rest are kernel arguments as PyObject*
Expand Down Expand Up @@ -490,6 +515,24 @@ def abi_type(binding: BindingDesc):
else:
raise CodegenError(f"Unsupported binding type: {binding}")

# Append stride arguments matching the trailing index args
# added by emit_func. One stride per leading dimension (rank-1)
# for each kernel buffer.
if self.options.dynamic_strides:
for binding in self._ordered_kernel_buffer_bindings(bindings):
rank = len(binding.kernel_buffer_type.symbolic_shape)
leading_count = max(0, rank - 1)
arg = func_args[binding_map[id(binding)]]
for dim_idx in range(leading_count):
dim = arith_d.constant(i32, dim_idx)
stride = func_d.call(
get_stride_func.type.results,
get_stride_func_symbol,
[arg, dim],
)
stride = arith_d.index_cast(IndexType.get(), stride)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just pass it as i32 and not bother with index?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, probably, but it will require more changes across the codebase (StreamExecutable.define_entrypoint and other places), lets keep it as index for now.

launch_args.append(stride)

gpu_d.launch_func(
kernel=[gpu_module.sym_name.value, self.kernel_name],
grid_size=grid,
Expand Down
35 changes: 35 additions & 0 deletions wave_lang/kernel/wave/execution_engine/buffer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,38 @@ extern "C" int64_t _mlir_ciface_wave_get_dim(PyObject *obj_ptr,

return dim_size;
}

extern "C" int64_t _mlir_ciface_wave_get_stride(PyObject *obj_ptr,
int32_t dim_idx) {
GILState gil_state;

PyObjectPtr stride_method(PyObject_GetAttrString(obj_ptr, "stride"));
if (!stride_method) {
PyErr_Clear();
throw std::runtime_error(
"wave_get_stride: object does not have 'stride' attribute");
}

PyObjectPtr dim_arg(PyLong_FromLong(dim_idx));
if (!dim_arg) {
PyErr_Clear();
throw std::runtime_error(
"wave_get_stride: failed to create dimension argument");
}

PyObjectPtr stride_result(
PyObject_CallOneArg(stride_method.get(), dim_arg.get()));
if (!stride_result) {
PyErr_Clear();
throw std::runtime_error("wave_get_stride: failed to call stride()");
}

int64_t stride_val = PyLong_AsLongLong(stride_result.get());
if (PyErr_Occurred()) {
PyErr_Clear();
throw std::runtime_error(
"wave_get_stride: stride() returned invalid value");
}

return stride_val;
}
14 changes: 14 additions & 0 deletions wave_lang/kernel/wave/execution_engine/buffer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,18 @@ double _mlir_ciface_wave_get_float64(PyObject *obj);
/// std::runtime_error if the object doesn't have a size() method or
/// if the dimension index is invalid
int64_t _mlir_ciface_wave_get_dim(PyObject *obj, int32_t dim_idx);

/// Extract the stride of a specific dimension from a PyObject (PyTorch tensor).
///
/// Args:
/// obj: PyObject* pointing to a PyTorch tensor
/// dim_idx: Dimension index to query (0-based)
///
/// Returns:
/// Stride of the specified dimension in elements as int64_t.
///
/// Throws:
/// std::runtime_error if the object doesn't have a stride() method or
/// if the dimension index is invalid.
int64_t _mlir_ciface_wave_get_stride(PyObject *obj, int32_t dim_idx);
}
1 change: 1 addition & 0 deletions wave_lang/kernel/wave/execution_engine/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _load_runtime_helpers():
_add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_int64")
_add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_float64")
_add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_dim")
_add_symbol(symbol_map, lib, "_mlir_ciface_wave_get_stride")

return symbol_map

Expand Down
Loading
Loading