From 5becee905207c24ef4a3cd76aced44858c33ebdb Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 27 Jun 2026 18:47:29 +0000 Subject: [PATCH 1/4] Add gluon moe reduce --- .../_gluon_kernels/gfx1250/moe/reduce.py | 82 +++++++++++++++++++ aiter/ops/triton/moe/reduce.py | 53 ++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py new file mode 100644 index 0000000000..0173325d12 --- /dev/null +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Gluon grouped row-reduce for the gfx1250 MoE scatter-combine (replaces the Triton ``_reduce_grouped``). + +One workgroup per group sums the ``K*B`` rows ``indx[g, :]`` (TDM-loaded, summed +in-register, no cross-wave communication) into ``out[g, :N]``, with optional +external residual fold-in. Swiglu stays on the Triton path. +""" + +import triton +from triton.experimental import gluon +from triton.experimental.gluon import language as gl + + +@gluon.jit +def reduce_grouped_gluon( + X, # [B, M, N] (flattened to [B*M, N] in the descriptor) + Out, # [num_groups, N] + InIndx, # [num_groups, K] int + Residual, # [num_groups, N] external residual to fold in (dummy ptr if unused) + stride_xm, + stride_om, + stride_on, + stride_res_m, + stride_res_n, + M, + N: gl.constexpr, + NPAD: gl.constexpr, # next_pow2(N) + B: gl.constexpr, + K: gl.constexpr, + NUM_WARPS: gl.constexpr, + HAS_EXT_RESIDUAL: gl.constexpr, +): + group = gl.program_id(0) + + # Load a power-of-2 column tile NPAD>=N (TDM block dims must be pow2) while the descriptor shape stays at true N, so TDM zero-pads cols [N:NPAD) (masked off on store). 
+ SIZE_N: gl.constexpr = NPAD // (NUM_WARPS * 32) + BLKN: gl.constexpr = gl.BlockedLayout([1, SIZE_N], [1, 32], [1, NUM_WARPS], [1, 0]) + SH: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + X, [B * M, N], [stride_xm, 1], [1, NPAD], SH + ) + smem = gl.allocate_shared_memory(X.dtype.element_ty, [K * B, 1, NPAD], SH) + + # issue all K*B row loads (overlapped), then reduce + buf = 0 + for i in gl.static_range(K): + idx_i = gl.load(InIndx + group * K + i) + for b in gl.static_range(B): + row = b * M + idx_i + gl.amd.gfx1250.tdm.async_load(x_desc, [row, 0], smem.index(buf)) + buf += 1 + gl.amd.gfx1250.tdm.async_wait(0) + + acc = gl.zeros([1, NPAD], dtype=gl.float32, layout=BLKN) + buf = 0 + for i in gl.static_range(K): + for b in gl.static_range(B): + acc += smem.index(buf).load(BLKN).to(gl.float32) + buf += 1 + + offs_n = gl.arange(0, NPAD, layout=gl.SliceLayout(0, BLKN)) + o_offs = group * stride_om + offs_n[None, :] * stride_on + o_mask = offs_n[None, :] < N + + # Fold in the external residual before writeback (matches the Triton HAS_EXT_RESIDUAL path). 
+ if HAS_EXT_RESIDUAL: + r_offs = group * stride_res_m + offs_n[None, :] * stride_res_n + res = gl.amd.gfx1250.buffer_load(Residual, r_offs, mask=o_mask, other=0.0) + acc += res.to(gl.float32) + + gl.amd.gfx1250.buffer_store(acc.to(Out.dtype.element_ty), Out, o_offs, mask=o_mask) + + +def reduce_grouped_gluon_num_warps(npad: int) -> int: + """Pick the largest wave count W in {8,4,2,1} with ``npad % (W*32) == 0``.""" + for w in (8, 4, 2, 1): + if npad % (w * 32) == 0: + return w + return 1 diff --git a/aiter/ops/triton/moe/reduce.py b/aiter/ops/triton/moe/reduce.py index 50e9f4f77f..da98c06be5 100644 --- a/aiter/ops/triton/moe/reduce.py +++ b/aiter/ops/triton/moe/reduce.py @@ -4,6 +4,15 @@ from aiter.ops.triton._triton_kernels.moe.reduce import _reduce_grouped from aiter.ops.triton.utils._triton.arch_info import is_tdm_avail +try: + from aiter.ops.triton._gluon_kernels.gfx1250.moe.reduce import ( + reduce_grouped_gluon as _reduce_grouped_gluon, + reduce_grouped_gluon_num_warps as _reduce_grouped_gluon_num_warps, + ) +except (ImportError, ModuleNotFoundError): + _reduce_grouped_gluon = None + _reduce_grouped_gluon_num_warps = None + def reduce_grouped( x: torch.Tensor, @@ -52,6 +61,50 @@ def reduce_grouped( K = 1 if indx is None else indx.shape[1] out_dtype = x.dtype if out_dtype is None else out_dtype assert x.shape[-1] % reduction_n == 0 + + # Gluon fast path: always used on gfx1250 for the post-MoE2 expert combine; fused swiglu and non-contiguous inputs fall back to the Triton _reduce_grouped. 
+ use_gluon = ( + is_tdm_avail() + and _reduce_grouped_gluon is not None + and indx is not None + and not apply_swiglu + and reduction_n == 1 + and x.ndim == 3 + and x.is_contiguous() + and indx.is_contiguous() + ) + if use_gluon: + B, M, N = x.shape[0], x.shape[1], x.shape[2] + npad = triton.next_power_of_2(N) + if npad >= 32: + has_ext_residual = residual is not None + if has_ext_residual: + assert residual.shape == out.shape, ( + f"residual.shape {tuple(residual.shape)} must match " + f"out.shape {tuple(out.shape)}" + ) + gluon_num_warps = _reduce_grouped_gluon_num_warps(npad) + _reduce_grouped_gluon[(num_groups,)]( + X=x, + Out=out, + InIndx=indx, + Residual=residual if has_ext_residual else out, + stride_xm=x.stride(1), + stride_om=out.stride(0), + stride_on=out.stride(1), + stride_res_m=residual.stride(0) if has_ext_residual else 0, + stride_res_n=residual.stride(1) if has_ext_residual else 0, + M=M, + N=N, + NPAD=npad, + B=B, + K=K, + NUM_WARPS=gluon_num_warps, + HAS_EXT_RESIDUAL=has_ext_residual, + num_warps=gluon_num_warps, + ) + return out + BLOCK_N = 512 num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) From c9ab195683a250b1c26c7c221bc2cb87715934a6 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 27 Jun 2026 18:53:06 +0000 Subject: [PATCH 2/4] Lint --- aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py index 0173325d12..2efd0e204c 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py @@ -8,7 +8,6 @@ external residual fold-in. Swiglu stays on the Triton path. 
""" -import triton from triton.experimental import gluon from triton.experimental.gluon import language as gl From 0b3051af9ffa9893827eb1c2f2d5732229e449bf Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Sat, 27 Jun 2026 22:05:32 +0000 Subject: [PATCH 3/4] Simplify gfx1250 gluon reduce_grouped gate Drop the redundant _reduce_grouped_gluon None check (the kernel is always available on gfx1250), gate the gluon path on is_tdm_avail() plus the plain non-swiglu/reduction_n==1/contiguous combine, and assert the NPAD power-of-2 invariants in-kernel via gl.static_assert. Swiglu-fused (MoE1 split-k) and reduction_n!=1 reductions still fall back to the Triton _reduce_grouped. Co-Authored-By: Claude Opus 4.8 --- .../_gluon_kernels/gfx1250/moe/reduce.py | 4 +- aiter/ops/triton/moe/reduce.py | 56 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py index 2efd0e204c..d5c5b081c3 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py @@ -5,7 +5,7 @@ One workgroup per group sums the ``K*B`` rows ``indx[g, :]`` (TDM-loaded, summed in-register, no cross-wave communication) into ``out[g, :N]``, with optional -external residual fold-in. Swiglu stays on the Triton path. +external residual fold-in. """ from triton.experimental import gluon @@ -32,6 +32,8 @@ def reduce_grouped_gluon( HAS_EXT_RESIDUAL: gl.constexpr, ): group = gl.program_id(0) + gl.static_assert(NPAD >= 32, "NPAD must be >= 32") + gl.static_assert(NPAD % (NUM_WARPS * 32) == 0, "NPAD must be a multiple of NUM_WARPS*32") # Load a power-of-2 column tile NPAD>=N (TDM block dims must be pow2) while the descriptor shape stays at true N, so TDM zero-pads cols [N:NPAD) (masked off on store). 
SIZE_N: gl.constexpr = NPAD // (NUM_WARPS * 32) diff --git a/aiter/ops/triton/moe/reduce.py b/aiter/ops/triton/moe/reduce.py index da98c06be5..90b4a6fa37 100644 --- a/aiter/ops/triton/moe/reduce.py +++ b/aiter/ops/triton/moe/reduce.py @@ -62,10 +62,9 @@ def reduce_grouped( out_dtype = x.dtype if out_dtype is None else out_dtype assert x.shape[-1] % reduction_n == 0 - # Gluon fast path: always used on gfx1250 for the post-MoE2 expert combine; fused swiglu and non-contiguous inputs fall back to the Triton _reduce_grouped. + # Gluon path on gfx1250 for the plain grouped combine; swiglu-fused (MoE1 split-k) reductions, reduction_n != 1, and non-contiguous inputs stay on the Triton _reduce_grouped. use_gluon = ( is_tdm_avail() - and _reduce_grouped_gluon is not None and indx is not None and not apply_swiglu and reduction_n == 1 @@ -76,34 +75,33 @@ def reduce_grouped( if use_gluon: B, M, N = x.shape[0], x.shape[1], x.shape[2] npad = triton.next_power_of_2(N) - if npad >= 32: - has_ext_residual = residual is not None - if has_ext_residual: - assert residual.shape == out.shape, ( - f"residual.shape {tuple(residual.shape)} must match " - f"out.shape {tuple(out.shape)}" - ) - gluon_num_warps = _reduce_grouped_gluon_num_warps(npad) - _reduce_grouped_gluon[(num_groups,)]( - X=x, - Out=out, - InIndx=indx, - Residual=residual if has_ext_residual else out, - stride_xm=x.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - stride_res_m=residual.stride(0) if has_ext_residual else 0, - stride_res_n=residual.stride(1) if has_ext_residual else 0, - M=M, - N=N, - NPAD=npad, - B=B, - K=K, - NUM_WARPS=gluon_num_warps, - HAS_EXT_RESIDUAL=has_ext_residual, - num_warps=gluon_num_warps, + has_ext_residual = residual is not None + if has_ext_residual: + assert residual.shape == out.shape, ( + f"residual.shape {tuple(residual.shape)} must match " + f"out.shape {tuple(out.shape)}" ) - return out + gluon_num_warps = _reduce_grouped_gluon_num_warps(npad) + 
_reduce_grouped_gluon[(num_groups,)]( + X=x, + Out=out, + InIndx=indx, + Residual=residual if has_ext_residual else out, + stride_xm=x.stride(1), + stride_om=out.stride(0), + stride_on=out.stride(1), + stride_res_m=residual.stride(0) if has_ext_residual else 0, + stride_res_n=residual.stride(1) if has_ext_residual else 0, + M=M, + N=N, + NPAD=npad, + B=B, + K=K, + NUM_WARPS=gluon_num_warps, + HAS_EXT_RESIDUAL=has_ext_residual, + num_warps=gluon_num_warps, + ) + return out BLOCK_N = 512 num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) From 36001a05bf9b4955fd85d862ed13394a1d1a4e3d Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Mon, 29 Jun 2026 21:08:30 +0000 Subject: [PATCH 4/4] Apply black formatting to reduce.py --- aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py index d5c5b081c3..f48121b50e 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py @@ -33,7 +33,9 @@ def reduce_grouped_gluon( ): group = gl.program_id(0) gl.static_assert(NPAD >= 32, "NPAD must be >= 32") - gl.static_assert(NPAD % (NUM_WARPS * 32) == 0, "NPAD must be a multiple of NUM_WARPS*32") + gl.static_assert( + NPAD % (NUM_WARPS * 32) == 0, "NPAD must be a multiple of NUM_WARPS*32" + ) # Load a power-of-2 column tile NPAD>=N (TDM block dims must be pow2) while the descriptor shape stays at true N, so TDM zero-pads cols [N:NPAD) (masked off on store). SIZE_N: gl.constexpr = NPAD // (NUM_WARPS * 32)