Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions aiter/ops/triton/_gluon_kernels/gfx1250/moe/reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

"""Gluon grouped row-reduce for the gfx1250 MoE scatter-combine (replaces the Triton ``_reduce_grouped``).

One workgroup per group sums the ``K*B`` rows ``indx[g, :]`` (TDM-loaded, summed
in-register, no cross-wave communication) into ``out[g, :N]``, with optional
external residual fold-in.
"""

from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def reduce_grouped_gluon(
X, # [B, M, N] (flattened to [B*M, N] in the descriptor)
Out, # [num_groups, N]
InIndx, # [num_groups, K] int
Residual, # [num_groups, N] external residual to fold in (dummy ptr if unused)
stride_xm,
stride_om,
stride_on,
stride_res_m,
stride_res_n,
M,
N: gl.constexpr,
NPAD: gl.constexpr, # next_pow2(N)
B: gl.constexpr,
K: gl.constexpr,
NUM_WARPS: gl.constexpr,
HAS_EXT_RESIDUAL: gl.constexpr,
):
group = gl.program_id(0)
gl.static_assert(NPAD >= 32, "NPAD must be >= 32")
gl.static_assert(NPAD % (NUM_WARPS * 32) == 0, "NPAD must be a multiple of NUM_WARPS*32")

# Load a power-of-2 column tile NPAD>=N (TDM block dims must be pow2) while the descriptor shape stays at true N, so TDM zero-pads cols [N:NPAD) (masked off on store).
SIZE_N: gl.constexpr = NPAD // (NUM_WARPS * 32)
BLKN: gl.constexpr = gl.BlockedLayout([1, SIZE_N], [1, 32], [1, NUM_WARPS], [1, 0])
SH: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])

x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
X, [B * M, N], [stride_xm, 1], [1, NPAD], SH
)
smem = gl.allocate_shared_memory(X.dtype.element_ty, [K * B, 1, NPAD], SH)

# issue all K*B row loads (overlapped), then reduce
buf = 0
for i in gl.static_range(K):
idx_i = gl.load(InIndx + group * K + i)
for b in gl.static_range(B):
row = b * M + idx_i
gl.amd.gfx1250.tdm.async_load(x_desc, [row, 0], smem.index(buf))
buf += 1
gl.amd.gfx1250.tdm.async_wait(0)

acc = gl.zeros([1, NPAD], dtype=gl.float32, layout=BLKN)
buf = 0
for i in gl.static_range(K):
for b in gl.static_range(B):
acc += smem.index(buf).load(BLKN).to(gl.float32)
buf += 1

offs_n = gl.arange(0, NPAD, layout=gl.SliceLayout(0, BLKN))
o_offs = group * stride_om + offs_n[None, :] * stride_on
o_mask = offs_n[None, :] < N

# Fold in the external residual before writeback (matches the Triton HAS_EXT_RESIDUAL path).
if HAS_EXT_RESIDUAL:
r_offs = group * stride_res_m + offs_n[None, :] * stride_res_n
res = gl.amd.gfx1250.buffer_load(Residual, r_offs, mask=o_mask, other=0.0)
acc += res.to(gl.float32)

gl.amd.gfx1250.buffer_store(acc.to(Out.dtype.element_ty), Out, o_offs, mask=o_mask)


def reduce_grouped_gluon_num_warps(npad: int) -> int:
"""Pick the largest wave count W in {8,4,2,1} with ``npad % (W*32) == 0``."""
for w in (8, 4, 2, 1):
if npad % (w * 32) == 0:
return w
return 1
51 changes: 51 additions & 0 deletions aiter/ops/triton/moe/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
from aiter.ops.triton._triton_kernels.moe.reduce import _reduce_grouped
from aiter.ops.triton.utils._triton.arch_info import is_tdm_avail

try:
from aiter.ops.triton._gluon_kernels.gfx1250.moe.reduce import (
reduce_grouped_gluon as _reduce_grouped_gluon,
reduce_grouped_gluon_num_warps as _reduce_grouped_gluon_num_warps,
)
except (ImportError, ModuleNotFoundError):
_reduce_grouped_gluon = None
_reduce_grouped_gluon_num_warps = None


def reduce_grouped(
x: torch.Tensor,
Expand Down Expand Up @@ -52,6 +61,48 @@ def reduce_grouped(
K = 1 if indx is None else indx.shape[1]
out_dtype = x.dtype if out_dtype is None else out_dtype
assert x.shape[-1] % reduction_n == 0

# Gluon path on gfx1250 for the plain grouped combine; swiglu-fused (MoE1 split-k) reductions, reduction_n != 1, and non-contiguous inputs stay on the Triton _reduce_grouped.
use_gluon = (
is_tdm_avail()
and indx is not None
and not apply_swiglu
and reduction_n == 1
and x.ndim == 3
and x.is_contiguous()
and indx.is_contiguous()
)
if use_gluon:
B, M, N = x.shape[0], x.shape[1], x.shape[2]
npad = triton.next_power_of_2(N)
has_ext_residual = residual is not None
if has_ext_residual:
assert residual.shape == out.shape, (
f"residual.shape {tuple(residual.shape)} must match "
f"out.shape {tuple(out.shape)}"
)
gluon_num_warps = _reduce_grouped_gluon_num_warps(npad)
_reduce_grouped_gluon[(num_groups,)](
X=x,
Out=out,
InIndx=indx,
Residual=residual if has_ext_residual else out,
stride_xm=x.stride(1),
stride_om=out.stride(0),
stride_on=out.stride(1),
stride_res_m=residual.stride(0) if has_ext_residual else 0,
stride_res_n=residual.stride(1) if has_ext_residual else 0,
M=M,
N=N,
NPAD=npad,
B=B,
K=K,
NUM_WARPS=gluon_num_warps,
HAS_EXT_RESIDUAL=has_ext_residual,
num_warps=gluon_num_warps,
)
return out

BLOCK_N = 512
num_blocks = triton.cdiv(x.shape[-1], BLOCK_N)

Expand Down
Loading