From fc2fbe4bb02773b110925f7a8e829d3823c434d2 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Wed, 1 Apr 2026 02:01:50 +0000 Subject: [PATCH 1/7] add magic number trick for dynamic kernels Signed-off-by: Aurore De Spirlet --- .../kernel/wave/magic_number_division.py | 146 ++++++++++++++++++ .../kernel/compiler/wave_codegen/emitter.py | 92 +++++++++++ 2 files changed, 238 insertions(+) create mode 100644 lit_tests/kernel/wave/magic_number_division.py diff --git a/lit_tests/kernel/wave/magic_number_division.py b/lit_tests/kernel/wave/magic_number_division.py new file mode 100644 index 0000000000..71f7baab88 --- /dev/null +++ b/lit_tests/kernel/wave/magic_number_division.py @@ -0,0 +1,146 @@ +# RUN: WAVE_MAGIC_NUMBER_DIV=1 python %s 2>&1 | FileCheck %s --check-prefix=MAGIC +# RUN: WAVE_MAGIC_NUMBER_DIV=0 python %s 2>&1 | FileCheck %s --check-prefix=NOMAGIC + +from sympy import ceiling + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.general_utils import ( + run_test, +) + +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K +GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + +@run_test +def test_magic_number_div(): + """Test that floordiv/mod by dynamic (runtime) divisors are lowered + to the magic-number multiply-high trick instead of expensive hardware + division. + + When kernel dimensions are dynamic, the compiler cannot fold + floordiv/mod into compile-time constants. The magic-number + optimisation precomputes ``ceil(2^32 / d)`` once per unique divisor + and replaces every subsequent division with a 64-bit multiply + shift, + which is significantly cheaper on GPU. + + We use a GEMM with GROUP_SIZE_N workgroup reordering to exercise + this: the reordering delinearises the flat workgroup id via + ``ceildiv(M, BLOCK_M)``, and the GEMM's multiple memory accesses + (read A, read B, write C) each independently compute reordered + indices, producing enough dynamic floordiv/mod expressions to + demonstrate that the expensive magic-number precomputation + (a single divui) is performed once per divisor and then reused + by multiple cheap multiply-and-shift sequences. + """ + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + wg0, wg1 = WORKGROUP_0, WORKGROUP_1 + num_wg_0 = ceiling(M / BLOCK_M) + + flat_wg_index = wg1 * num_wg_0 + wg0 + num_wg_group = GROUP_SIZE_N * num_wg_0 + group_id = flat_wg_index // num_wg_group + first_wg_id_1 = group_id * GROUP_SIZE_N + new_wg0 = (flat_wg_index % num_wg_group) // GROUP_SIZE_N + new_wg1 = first_wg_id_1 + (flat_wg_index % num_wg_group) % GROUP_SIZE_N + + constraints += [tkw.ReorderingConstraint(new_wg0, 0)] + constraints += [tkw.ReorderingConstraint(new_wg1, 1)] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat( + acc: tkl.Register[M, N, tkl.f32], + ) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + b_reg = tkw.read(b) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c) + + options = WaveCompileOptions( + subs={ + M: 512, + N: 1024, + K: 256, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + GROUP_SIZE_N: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + compile_to_mlir=True, + ) + + # Enable dynamic symbols + options.dynamic_symbols = [M, N, K] + for sym in options.dynamic_symbols: + del options.subs[sym] + + gemm = wave_compile(options, gemm) + print(gemm.asm) + + # ---- MAGIC (WAVE_MAGIC_NUMBER_DIV=1) ---- + # MAGIC-LABEL: func.func @gemm + # MAGIC-DAG: arith.constant 4294967295 : i64 + # MAGIC-DAG: %[[C32:.*]] = arith.constant 32 : i64 + # + # Magic precomputation (divui) followed by multiply-high (shrui): + # MAGIC: arith.divui {{.*}} : i64 + # MAGIC: arith.shrui {{.*}}, %[[C32]] : i64 + # + # Consume remaining precomputations from other address calculations. + # MAGIC: arith.divui + # MAGIC: arith.divui + # MAGIC: arith.divui + # MAGIC: arith.shrui {{.*}}, %[[C32]] : i64 + # + # Amortised: mulhi reusing a previously computed magic number + # with a different dividend — no new divui needed. + # MAGIC-NOT: arith.divui + # MAGIC-NOT: arith.divsi + # MAGIC: arith.shrui {{.*}}, %[[C32]] : i64 + # MAGIC-NOT: arith.divsi + # MAGIC: return + + # ---- NOMAGIC (WAVE_MAGIC_NUMBER_DIV=0) ---- + # Without magic numbers the dynamic floordiv/mod stay inside affine + # maps; no arith division ops are emitted. + # NOMAGIC-LABEL: func.func @gemm + # NOMAGIC-NOT: arith.divui + # NOMAGIC-NOT: arith.divsi + # NOMAGIC-NOT: arith.shrui + # NOMAGIC-NOT: 4294967295 + # NOMAGIC: affine.apply diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index c97d311582..98d14ef261 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -633,10 +633,13 @@ def add_emitter_subs( _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0))) _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1))) +_magic_number_enabled = bool(int(environ.get("WAVE_MAGIC_NUMBER_DIV", 1))) _Rational = namedtuple("_Rational", ["numerator", "denominator"]) _ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"]) +_magic_number_cache: dict = {} + def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> Value: use_affine_expr = _use_affine_expr @@ -778,16 +781,105 @@ def muli_expr(lhs, rhs): return op_expr(lhs, rhs, lambda a, b: a * b) + def _is_dynamic_divisor(val) -> bool: + """Check if a value is NOT a compile-time constant.""" + if isinstance(val, _ApplyExpr): + if all( + isinstance(a, OpResult) and get_const_val(a) is not None + for a in val.args + ): + return False + val = _get_ir_value(val) + if isinstance(val, OpResult): + return get_const_val(val) is None + return True + + def _mulhi_u32(n_i32, m_i32): + """Unsigned 32-bit multiply-high: (n * m) >> 32, via 64-bit multiply.""" + i64 = IntegerType.get_signless(64) + c32_i64 = arith_d.constant(i64, 32) + n_i64 = arith_d.extui(i64, n_i32) + m_i64 = arith_d.extui(i64, m_i32) + prod_i64 = arith_d.muli(n_i64, m_i64) + hi_i64 = arith_d.shrui(prod_i64, c32_i64) + i32 = IntegerType.get_signless(32) + return arith_d.trunci(i32, hi_i64) + + def _precompute_magic_number(divisor_index: Value): + """ + Compute magic = ceil(2^32 / d) from a dynamic divisor. + Returns (magic_i32, d_i32) both as i32 Values. + """ + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + d_i32 = arith_d.index_cast(i32, divisor_index) + d_i64 = arith_d.extui(i64, d_i32) + c1_i64 = arith_d.constant(i64, 1) + c32_i64 = arith_d.constant(i64, 32) + pow32 = arith_d.shli(c1_i64, c32_i64) + d_minus_1_i64 = arith_d.subi(d_i64, c1_i64) + numer_i64 = arith_d.addi(pow32, d_minus_1_i64) + magic_i64 = arith_d.divui(numer_i64, d_i64) + magic_i32 = arith_d.trunci(i32, magic_i64) + return magic_i32, d_i32 + + def _get_or_create_magic(divisor: Value): + """Get cached (magic_i32, d_i32) or compute and cache them.""" + key = id(divisor) + if key in _magic_number_cache: + return _magic_number_cache[key] + magic_i32, d_i32 = _precompute_magic_number(divisor) + _magic_number_cache[key] = (magic_i32, d_i32) + return magic_i32, d_i32 + + def _magic_div_and_rem(lhs_val, rhs_val): + """ + Compute (quotient, remainder) of lhs_val // rhs_val using + magic number multiplication: q = mulhi(n, magic), with a + one-step correction for exactness. + Returns (quotient_index, remainder_index). + """ + i32 = IntegerType.get_signless(32) + magic_i32, d_i32 = _get_or_create_magic(rhs_val) + n_i32 = arith_d.index_cast(i32, lhs_val) + q_i32 = _mulhi_u32(n_i32, magic_i32) + qd_i32 = arith_d.muli(q_i32, d_i32) + r_i32 = arith_d.subi(n_i32, qd_i32) + # Correction: ceil(2^32/d) can overestimate quotient by 1. + # Detect via unsigned remainder >= divisor (wraps on overestimate). + too_big = arith_d.cmpi(arith_d.CmpIPredicate.uge, r_i32, d_i32) + c1_i32 = arith_d.constant(i32, 1) + c0_i32 = arith_d.constant(i32, 0) + corr = arith_d.select(too_big, c1_i32, c0_i32) + q_final = arith_d.subi(q_i32, corr) + d_or_zero = arith_d.select(too_big, d_i32, c0_i32) + r_final = arith_d.addi(r_i32, d_or_zero) + q_index = arith_d.index_cast(IndexType.get(), q_final) + r_index = arith_d.index_cast(IndexType.get(), r_final) + return q_index, r_index + def rem_expr(lhs, rhs): if not use_affine_expr or not check_index_types(lhs, rhs): return arith_d.remsi(*_broadcast(lhs, rhs)) + if _magic_number_enabled and _is_dynamic_divisor(rhs): + lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs + rhs_val = _get_ir_value(rhs) if isinstance(rhs, _ApplyExpr) else rhs + _, r = _magic_div_and_rem(lhs_val, rhs_val) + return r + return op_expr(lhs, rhs, lambda a, b: a % b) def floordiv_expr(lhs, rhs): if not use_affine_expr or not check_index_types(lhs, rhs): return arith_d.divsi(*_broadcast(lhs, rhs)) + if _magic_number_enabled and _is_dynamic_divisor(rhs): + lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs + rhs_val = _get_ir_value(rhs) if isinstance(rhs, _ApplyExpr) else rhs + q, _ = _magic_div_and_rem(lhs_val, rhs_val) + return q + return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_floor_div(a, b)) def ceildiv_expr(lhs, rhs): From 64d1bfd4809188204579cb6fa03981972ceac5f2 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Wed, 1 Apr 2026 16:07:26 +0000 Subject: [PATCH 2/7] use compile options instead Signed-off-by: Aurore De Spirlet --- .../kernel/wave/magic_number_division.py | 280 +++++++++--------- .../kernel/compiler/wave_codegen/emitter.py | 6 +- wave_lang/kernel/wave/compile_options.py | 1 + 3 files changed, 140 insertions(+), 147 deletions(-) diff --git a/lit_tests/kernel/wave/magic_number_division.py b/lit_tests/kernel/wave/magic_number_division.py index 71f7baab88..931642d2fd 100644 --- a/lit_tests/kernel/wave/magic_number_division.py +++ b/lit_tests/kernel/wave/magic_number_division.py @@ -1,146 +1,134 @@ -# RUN: WAVE_MAGIC_NUMBER_DIV=1 python %s 2>&1 | FileCheck %s --check-prefix=MAGIC -# RUN: WAVE_MAGIC_NUMBER_DIV=0 python %s 2>&1 | FileCheck %s --check-prefix=NOMAGIC - -from sympy import ceiling - -import wave_lang.kernel.lang as tkl -import wave_lang.kernel.wave as tkw -from wave_lang.kernel.lang.global_symbols import * -from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile -from wave_lang.kernel.wave.utils.general_utils import ( - run_test, -) - -M = tkl.sym.M -N = tkl.sym.N -K = tkl.sym.K -BLOCK_M = tkl.sym.BLOCK_M -BLOCK_N = tkl.sym.BLOCK_N -BLOCK_K = tkl.sym.BLOCK_K -GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N -ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE -ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 - - -@run_test -def test_magic_number_div(): - """Test that floordiv/mod by dynamic (runtime) divisors are lowered - to the magic-number multiply-high trick instead of expensive hardware - division. - - When kernel dimensions are dynamic, the compiler cannot fold - floordiv/mod into compile-time constants. The magic-number - optimisation precomputes ``ceil(2^32 / d)`` once per unique divisor - and replaces every subsequent division with a 64-bit multiply + shift, - which is significantly cheaper on GPU. - - We use a GEMM with GROUP_SIZE_N workgroup reordering to exercise - this: the reordering delinearises the flat workgroup id via - ``ceildiv(M, BLOCK_M)``, and the GEMM's multiple memory accesses - (read A, read B, write C) each independently compute reordered - indices, producing enough dynamic floordiv/mod expressions to - demonstrate that the expensive magic-number precomputation - (a single divui) is performed once per divisor and then reused - by multiple cheap multiply-and-shift sequences. - """ - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K, BLOCK_K)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=tkw.MMAType.F32_16x16x16_F16, - ) - ] - - wg0, wg1 = WORKGROUP_0, WORKGROUP_1 - num_wg_0 = ceiling(M / BLOCK_M) - - flat_wg_index = wg1 * num_wg_0 + wg0 - num_wg_group = GROUP_SIZE_N * num_wg_0 - group_id = flat_wg_index // num_wg_group - first_wg_id_1 = group_id * GROUP_SIZE_N - new_wg0 = (flat_wg_index % num_wg_group) // GROUP_SIZE_N - new_wg1 = first_wg_id_1 + (flat_wg_index % num_wg_group) % GROUP_SIZE_N - - constraints += [tkw.ReorderingConstraint(new_wg0, 0)] - constraints += [tkw.ReorderingConstraint(new_wg1, 1)] - - @tkw.wave(constraints) - def gemm( - a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], - ): - c_reg = tkl.Register[M, N, tkl.f32](0.0) - - @tkw.iterate(K, init_args=[c_reg]) - def repeat( - acc: tkl.Register[M, N, tkl.f32], - ) -> tkl.Register[M, N, tkl.f32]: - a_reg = tkw.read(a) - b_reg = tkw.read(b) - acc = tkw.mma(a_reg, b_reg, acc) - return acc - - tkw.write(repeat, c) - - options = WaveCompileOptions( - subs={ - M: 512, - N: 1024, - K: 256, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K: 32, - GROUP_SIZE_N: 4, - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, - }, - canonicalize=True, - compile_to_mlir=True, - ) - - # Enable dynamic symbols - options.dynamic_symbols = [M, N, K] - for sym in options.dynamic_symbols: - del options.subs[sym] - - gemm = wave_compile(options, gemm) - print(gemm.asm) - - # ---- MAGIC (WAVE_MAGIC_NUMBER_DIV=1) ---- - # MAGIC-LABEL: func.func @gemm - # MAGIC-DAG: arith.constant 4294967295 : i64 - # MAGIC-DAG: %[[C32:.*]] = arith.constant 32 : i64 - # - # Magic precomputation (divui) followed by multiply-high (shrui): - # MAGIC: arith.divui {{.*}} : i64 - # MAGIC: arith.shrui {{.*}}, %[[C32]] : i64 - # - # Consume remaining precomputations from other address calculations. - # MAGIC: arith.divui - # MAGIC: arith.divui - # MAGIC: arith.divui - # MAGIC: arith.shrui {{.*}}, %[[C32]] : i64 - # - # Amortised: mulhi reusing a previously computed magic number - # with a different dividend — no new divui needed. - # MAGIC-NOT: arith.divui - # MAGIC-NOT: arith.divsi - # MAGIC: arith.shrui {{.*}}, %[[C32]] : i64 - # MAGIC-NOT: arith.divsi - # MAGIC: return - - # ---- NOMAGIC (WAVE_MAGIC_NUMBER_DIV=0) ---- - # Without magic numbers the dynamic floordiv/mod stay inside affine - # maps; no arith division ops are emitted. - # NOMAGIC-LABEL: func.func @gemm - # NOMAGIC-NOT: arith.divui - # NOMAGIC-NOT: arith.divsi - # NOMAGIC-NOT: arith.shrui - # NOMAGIC-NOT: 4294967295 - # NOMAGIC: affine.apply +# RUN: python %s | FileCheck %s + +from sympy import ceiling + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.general_utils import ( + run_test, +) + +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K +GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + +@run_test +def test_magic_number_div(): + """Test that floordiv/mod by dynamic (runtime) divisors are lowered + to the magic-number multiply-high trick instead of expensive hardware + division. + + When kernel dimensions are dynamic, the compiler cannot fold + floordiv/mod into compile-time constants. The magic-number + optimisation precomputes ``ceil(2^32 / d)`` once per unique divisor + and replaces every subsequent division with a 64-bit multiply + shift, + which is significantly cheaper on GPU. + + We use a GEMM with GROUP_SIZE_N workgroup reordering to exercise + this: the reordering delinearises the flat workgroup id via + ``ceildiv(M, BLOCK_M)``, and the GEMM's multiple memory accesses + (read A, read B, write C) each independently compute reordered + indices, producing enough dynamic floordiv/mod expressions to + demonstrate that the expensive magic-number precomputation + (a single divui) is performed once per divisor and then reused + by multiple cheap multiply-and-shift sequences. + """ + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + wg0, wg1 = WORKGROUP_0, WORKGROUP_1 + num_wg_0 = ceiling(M / BLOCK_M) + + flat_wg_index = wg1 * num_wg_0 + wg0 + num_wg_group = GROUP_SIZE_N * num_wg_0 + group_id = flat_wg_index // num_wg_group + first_wg_id_1 = group_id * GROUP_SIZE_N + new_wg0 = (flat_wg_index % num_wg_group) // GROUP_SIZE_N + new_wg1 = first_wg_id_1 + (flat_wg_index % num_wg_group) % GROUP_SIZE_N + + constraints += [tkw.ReorderingConstraint(new_wg0, 0)] + constraints += [tkw.ReorderingConstraint(new_wg1, 1)] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat( + acc: tkl.Register[M, N, tkl.f32], + ) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + b_reg = tkw.read(b) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c) + + options = WaveCompileOptions( + subs={ + M: 512, + N: 1024, + K: 256, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + GROUP_SIZE_N: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + compile_to_mlir=True, + magic_number_div=True, + ) + + options.dynamic_symbols = [M, N, K] + for sym in options.dynamic_symbols: + del options.subs[sym] + + gemm = wave_compile(options, gemm) + print(gemm.asm) + + # CHECK-LABEL: func.func @gemm + # CHECK-DAG: arith.constant 4294967295 : i64 + # CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i64 + # + # Magic precomputation (divui) followed by multiply-high (shrui): + # CHECK: arith.divui {{.*}} : i64 + # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 + # + # Consume remaining precomputations from other address calculations. + # CHECK: arith.divui + # CHECK: arith.divui + # CHECK: arith.divui + # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 + # + # Amortised: mulhi reusing a previously computed magic number + # with a different dividend — no new divui needed. + # CHECK-NOT: arith.divui + # CHECK-NOT: arith.divsi + # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 + # CHECK-NOT: arith.divsi + # CHECK: return diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 98d14ef261..9a9567f2d1 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -261,6 +261,10 @@ def get_static_dim(s: Optional[IndexExpr]) -> int: return func_op def emit(self, graph: Optional[fx.Graph] = None) -> Operation: + global _magic_number_enabled, _magic_number_cache + _magic_number_enabled = self.options.magic_number_div + _magic_number_cache = {} + func = self.emit_func() with InsertionPoint.at_block_terminator(func.entry_block), Location.unknown(): self._emit_graph( @@ -633,7 +637,7 @@ def add_emitter_subs( _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0))) _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1))) -_magic_number_enabled = bool(int(environ.get("WAVE_MAGIC_NUMBER_DIV", 1))) +_magic_number_enabled = False _Rational = namedtuple("_Rational", ["numerator", "denominator"]) _ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"]) diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index fafb01453b..e8a783d8dc 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -96,6 +96,7 @@ class WaveCompileOptions: enable_mark_hardware_transpose_candidates: bool = True # === Compiler options === + magic_number_div: bool = False minimize_shared_allocs: bool = True reorder_allocs: bool = True override_schedule: Optional[str] = None From e0160c63fe6dee889a18844a88a2623395b45688 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Wed, 1 Apr 2026 17:56:52 +0000 Subject: [PATCH 3/7] make sure to hoist the magic number computation before loop Signed-off-by: Aurore De Spirlet --- .../kernel/compiler/wave_codegen/emitter.py | 48 ++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 9a9567f2d1..24f6c1794e 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -261,11 +261,12 @@ def get_static_dim(s: Optional[IndexExpr]) -> int: return func_op def emit(self, graph: Optional[fx.Graph] = None) -> Operation: - global _magic_number_enabled, _magic_number_cache + global _magic_number_enabled, _magic_number_cache, _magic_entry_block _magic_number_enabled = self.options.magic_number_div _magic_number_cache = {} func = self.emit_func() + _magic_entry_block = func.entry_block with InsertionPoint.at_block_terminator(func.entry_block), Location.unknown(): self._emit_graph( graph if graph is not None else self.trace.get_root_graph() @@ -638,6 +639,7 @@ def add_emitter_subs( _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0))) _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1))) _magic_number_enabled = False +_magic_entry_block = None _Rational = namedtuple("_Rational", ["numerator", "denominator"]) _ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"]) @@ -788,12 +790,10 @@ def muli_expr(lhs, rhs): def _is_dynamic_divisor(val) -> bool: """Check if a value is NOT a compile-time constant.""" if isinstance(val, _ApplyExpr): - if all( + return not all( isinstance(a, OpResult) and get_const_val(a) is not None for a in val.args - ): - return False - val = _get_ir_value(val) + ) if isinstance(val, OpResult): return get_const_val(val) is None return True @@ -827,24 +827,42 @@ def _precompute_magic_number(divisor_index: Value): magic_i32 = arith_d.trunci(i32, magic_i64) return magic_i32, d_i32 - def _get_or_create_magic(divisor: Value): - """Get cached (magic_i32, d_i32) or compute and cache them.""" - key = id(divisor) + def _get_or_create_magic(divisor_expr): + """Get cached (magic_i32, d_i32) or compute and cache them. + + On cache miss the precomputation is hoisted to the function + entry block so that the magic constant dominates every use. + """ + key = id(divisor_expr) if key in _magic_number_cache: return _magic_number_cache[key] - magic_i32, d_i32 = _precompute_magic_number(divisor) + if _magic_entry_block is not None: + with InsertionPoint.at_block_begin(_magic_entry_block): + divisor_val = ( + _get_ir_value(divisor_expr) + if isinstance(divisor_expr, _ApplyExpr) + else divisor_expr + ) + magic_i32, d_i32 = _precompute_magic_number(divisor_val) + else: + divisor_val = ( + _get_ir_value(divisor_expr) + if isinstance(divisor_expr, _ApplyExpr) + else divisor_expr + ) + magic_i32, d_i32 = _precompute_magic_number(divisor_val) _magic_number_cache[key] = (magic_i32, d_i32) return magic_i32, d_i32 - def _magic_div_and_rem(lhs_val, rhs_val): + def _magic_div_and_rem(lhs_val, rhs_expr): """ - Compute (quotient, remainder) of lhs_val // rhs_val using + Compute (quotient, remainder) of lhs_val // rhs using magic number multiplication: q = mulhi(n, magic), with a one-step correction for exactness. Returns (quotient_index, remainder_index). """ i32 = IntegerType.get_signless(32) - magic_i32, d_i32 = _get_or_create_magic(rhs_val) + magic_i32, d_i32 = _get_or_create_magic(rhs_expr) n_i32 = arith_d.index_cast(i32, lhs_val) q_i32 = _mulhi_u32(n_i32, magic_i32) qd_i32 = arith_d.muli(q_i32, d_i32) @@ -868,8 +886,7 @@ def rem_expr(lhs, rhs): if _magic_number_enabled and _is_dynamic_divisor(rhs): lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs - rhs_val = _get_ir_value(rhs) if isinstance(rhs, _ApplyExpr) else rhs - _, r = _magic_div_and_rem(lhs_val, rhs_val) + _, r = _magic_div_and_rem(lhs_val, rhs) return r return op_expr(lhs, rhs, lambda a, b: a % b) @@ -880,8 +897,7 @@ def floordiv_expr(lhs, rhs): if _magic_number_enabled and _is_dynamic_divisor(rhs): lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs - rhs_val = _get_ir_value(rhs) if isinstance(rhs, _ApplyExpr) else rhs - q, _ = _magic_div_and_rem(lhs_val, rhs_val) + q, _ = _magic_div_and_rem(lhs_val, rhs) return q return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_floor_div(a, b)) From a0ac1e093bdf38695b065f6a5437fbba38c62ed1 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Wed, 8 Apr 2026 22:28:35 +0000 Subject: [PATCH 4/7] expect 2 divisors in test instead of 4 due to caching Signed-off-by: Aurore De Spirlet --- lit_tests/kernel/wave/magic_number_division.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lit_tests/kernel/wave/magic_number_division.py b/lit_tests/kernel/wave/magic_number_division.py index 931642d2fd..892bb228b7 100644 --- a/lit_tests/kernel/wave/magic_number_division.py +++ b/lit_tests/kernel/wave/magic_number_division.py @@ -115,14 +115,11 @@ def repeat( # CHECK-DAG: arith.constant 4294967295 : i64 # CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i64 # - # Magic precomputation (divui) followed by multiply-high (shrui): + # Magic precomputation: one divui per unique dynamic divisor. + # CHECK: arith.divui {{.*}} : i64 # CHECK: arith.divui {{.*}} : i64 - # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 # - # Consume remaining precomputations from other address calculations. - # CHECK: arith.divui - # CHECK: arith.divui - # CHECK: arith.divui + # Multiply-high (shrui >> 32) reusing precomputed magic numbers. # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 # # Amortised: mulhi reusing a previously computed magic number @@ -130,5 +127,6 @@ def repeat( # CHECK-NOT: arith.divui # CHECK-NOT: arith.divsi # CHECK: arith.shrui {{.*}}, %[[C32]] : i64 + # CHECK-NOT: arith.divui # CHECK-NOT: arith.divsi # CHECK: return From 07a606394b6fa048e47f3afb8eab31db65f0e164 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Wed, 8 Apr 2026 23:49:42 +0000 Subject: [PATCH 5/7] automatically detect if magic number trick is worth it Signed-off-by: Aurore De Spirlet --- .../kernel/wave/magic_number_division.py | 1 - .../kernel/compiler/wave_codegen/emitter.py | 59 ++++++++++++++++--- wave_lang/kernel/wave/compile_options.py | 2 +- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/lit_tests/kernel/wave/magic_number_division.py b/lit_tests/kernel/wave/magic_number_division.py index 892bb228b7..5262786b1e 100644 --- a/lit_tests/kernel/wave/magic_number_division.py +++ b/lit_tests/kernel/wave/magic_number_division.py @@ -101,7 +101,6 @@ def repeat( }, canonicalize=True, compile_to_mlir=True, - magic_number_div=True, ) options.dynamic_symbols = [M, N, K] diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 24f6c1794e..d30615f4f0 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -261,9 +261,10 @@ def get_static_dim(s: Optional[IndexExpr]) -> int: return func_op def emit(self, graph: Optional[fx.Graph] = None) -> Operation: - global _magic_number_enabled, _magic_number_cache, _magic_entry_block + global _magic_number_enabled, _magic_number_cache, _magic_entry_block, _magic_divisor_first_seen _magic_number_enabled = self.options.magic_number_div _magic_number_cache = {} + _magic_divisor_first_seen = {} func = self.emit_func() _magic_entry_block = func.entry_block @@ -645,6 +646,7 @@ def add_emitter_subs( _ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"]) _magic_number_cache: dict = {} +_magic_divisor_first_seen: dict = {} def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> Value: @@ -798,6 +800,35 @@ def _is_dynamic_divisor(val) -> bool: return get_const_val(val) is None return True + def _divisor_key(val): + """Hashable key that identifies a divisor by its structure.""" + if isinstance(val, _ApplyExpr): + arg_keys = [] + for a in val.args: + c = get_const_val(a) if isinstance(a, OpResult) else None + arg_keys.append(("const", c) if c is not None else ("val", id(a))) + return ("apply", str(val.expr), tuple(arg_keys)) + if isinstance(val, OpResult): + c = get_const_val(val) + if c is not None: + return ("const", c) + return ("val", id(val)) + return ("other", id(val)) + + def _should_use_magic(rhs_expr) -> bool: + """Return True only when a dynamic divisor is seen for the second time. + + First encounter: record and decline (no benefit over a single div). + Second+ encounter: the precomputation is amortised, so use magic. + """ + key = _divisor_key(rhs_expr) + if key in _magic_number_cache: + return True + if key in _magic_divisor_first_seen: + return True + _magic_divisor_first_seen[key] = True + return False + def _mulhi_u32(n_i32, m_i32): """Unsigned 32-bit multiply-high: (n * m) >> 32, via 64-bit multiply.""" i64 = IntegerType.get_signless(64) @@ -833,7 +864,7 @@ def _get_or_create_magic(divisor_expr): On cache miss the precomputation is hoisted to the function entry block so that the magic constant dominates every use. """ - key = id(divisor_expr) + key = _divisor_key(divisor_expr) if key in _magic_number_cache: return _magic_number_cache[key] if _magic_entry_block is not None: @@ -855,11 +886,13 @@ def _get_or_create_magic(divisor_expr): return magic_i32, d_i32 def _magic_div_and_rem(lhs_val, rhs_expr): - """ - Compute (quotient, remainder) of lhs_val // rhs using - magic number multiplication: q = mulhi(n, magic), with a - one-step correction for exactness. - Returns (quotient_index, remainder_index). + """Compute (quotient, remainder) of lhs_val // rhs via mulhi. + + Uses unsigned 32-bit arithmetic (extui, divui, shrui, uge). + Requires both operands to be non-negative and fit in 32 bits. + This holds for GPU index computations: dividends are + workgroup/thread indices and divisors are derived from + positive kernel dimensions. """ i32 = IntegerType.get_signless(32) magic_i32, d_i32 = _get_or_create_magic(rhs_expr) @@ -884,7 +917,11 @@ def rem_expr(lhs, rhs): if not use_affine_expr or not check_index_types(lhs, rhs): return arith_d.remsi(*_broadcast(lhs, rhs)) - if _magic_number_enabled and _is_dynamic_divisor(rhs): + if ( + _magic_number_enabled + and _is_dynamic_divisor(rhs) + and _should_use_magic(rhs) + ): lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs _, r = _magic_div_and_rem(lhs_val, rhs) return r @@ -895,7 +932,11 @@ def floordiv_expr(lhs, rhs): if not use_affine_expr or not check_index_types(lhs, rhs): return arith_d.divsi(*_broadcast(lhs, rhs)) - if _magic_number_enabled and _is_dynamic_divisor(rhs): + if ( + _magic_number_enabled + and _is_dynamic_divisor(rhs) + and _should_use_magic(rhs) + ): lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs q, _ = _magic_div_and_rem(lhs_val, rhs) return q diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index e8a783d8dc..07e6d463e8 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -96,7 +96,7 @@ class WaveCompileOptions: enable_mark_hardware_transpose_candidates: bool = True # === Compiler options === - magic_number_div: bool = False + magic_number_div: bool = True minimize_shared_allocs: bool = True reorder_allocs: bool = True override_schedule: Optional[str] = None From 421c49e5f0b51ca83cb5451401fdebf2b494f3af Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Thu, 9 Apr 2026 17:42:42 +0000 Subject: [PATCH 6/7] set magic pass back to default off Signed-off-by: Aurore De Spirlet --- lit_tests/kernel/wave/magic_number_division.py | 1 + wave_lang/kernel/wave/compile_options.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/lit_tests/kernel/wave/magic_number_division.py b/lit_tests/kernel/wave/magic_number_division.py index 5262786b1e..892bb228b7 100644 --- a/lit_tests/kernel/wave/magic_number_division.py +++ b/lit_tests/kernel/wave/magic_number_division.py @@ -101,6 +101,7 @@ def repeat( }, canonicalize=True, compile_to_mlir=True, + magic_number_div=True, ) options.dynamic_symbols = [M, N, K] diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index 07e6d463e8..e8a783d8dc 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -96,7 +96,7 @@ class WaveCompileOptions: enable_mark_hardware_transpose_candidates: bool = True # === Compiler options === - magic_number_div: bool = True + magic_number_div: bool = False minimize_shared_allocs: bool = True reorder_allocs: bool = True override_schedule: Optional[str] = None From 78cf1e88d63e1c540d754859467922855406f38d Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Fri, 10 Apr 2026 22:21:13 +0000 Subject: [PATCH 7/7] add guard when d==1 Signed-off-by: Aurore De Spirlet --- wave_lang/kernel/compiler/wave_codegen/emitter.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index d30615f4f0..927d1f7ca4 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -909,6 +909,11 @@ def _magic_div_and_rem(lhs_val, rhs_expr): q_final = arith_d.subi(q_i32, corr) d_or_zero = arith_d.select(too_big, d_i32, c0_i32) r_final = arith_d.addi(r_i32, d_or_zero) + # Guard: when d == 1 the magic number overflows i32 to 0, + # so fall back to the trivial n // 1 = n, n % 1 = 0. + d_is_one = arith_d.cmpi(arith_d.CmpIPredicate.eq, d_i32, c1_i32) + q_final = arith_d.select(d_is_one, n_i32, q_final) + r_final = arith_d.select(d_is_one, c0_i32, r_final) q_index = arith_d.index_cast(IndexType.get(), q_final) r_index = arith_d.index_cast(IndexType.get(), r_final) return q_index, r_index