Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ class Envs:
# decode without runtime permutes.
SGLANG_AITER_KV_CACHE_LAYOUT = EnvStr("nhd")
SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
# Fuse the bf16 kv_b_proj GEMM + nope/v split + k_pe cat + fp8 cast into a
# single aiter Triton kernel (fused_gemm_a16w16_split_cat) on ROCm.
SGLANG_AITER_FUSED_KVB_SPLIT_CAT = EnvBool(False)
SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(4096)
# Enable dual-stream MoE (shared experts vs routed experts) on the
Expand Down
41 changes: 41 additions & 0 deletions python/sglang/srt/layers/attention/aiter_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
)

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.environ import envs
from sglang.srt.layers.attention.aiter_utils import (
forward_decode_vectorized_5d,
forward_extend_vectorized_5d,
Expand All @@ -84,6 +85,27 @@
get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported()
)

# Optional fused kv_b_proj path, disabled unless explicitly requested.
# A single aiter Triton kernel performs the bf16 kv_b_proj GEMM together with
# the nope/v split, the k_pe cat and the fp8 cast (the bf16 analog of the MXFP4
# fused_gemm_afp4wfp4_split_cat path). Only valid on gfx95 with bf16/fp16
# kv_b_proj weights and fp8 prefill attention. A recent aiter exposing
# fused_gemm_a16w16_split_cat is required; when the import fails the flag
# stays False and the unfused path is used instead.
try:
    from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_split_cat import (
        fused_gemm_a16w16_split_cat,
    )
except ImportError:
    _has_fused_gemm_a16w16_split_cat = False
else:
    _has_fused_gemm_a16w16_split_cat = True

# The fused path is taken only when the env flag is set, is_gfx95_supported()
# is truthy, and the kernel import above succeeded.
_use_fused_kvb_split_cat = (
    envs.SGLANG_AITER_FUSED_KVB_SPLIT_CAT.get()
    and is_gfx95_supported()
    and _has_fused_gemm_a16w16_split_cat
)

# Persist
# fast_mode=True if _use_mla_ps_kernel else False
# intra_batch_mode=False if _use_mla_ps_kernel else True
Expand Down Expand Up @@ -2028,6 +2050,25 @@ def forward_extend(
fp8_dtype,
)
)[0]
elif (
_use_fp8_prefill_attn
and _use_fused_kvb_split_cat
and layer.kv_b_proj.weight.dtype
in (torch.bfloat16, torch.float16)
):
# BF16/FP16 weights + FP8 prefill: fuse the kv_b_proj GEMM,
# nope/v split, k_pe cat, and fp8 cast into a single kernel
# (fused_gemm_a16w16_split_cat) that writes k and v
# directly in FP8, avoiding separate split / cat /
# float8_copy passes.
k, v = fused_gemm_a16w16_split_cat(
x=kvc.squeeze(1).contiguous(),
w=layer.kv_b_proj.weight,
y=k_pe.expand(-1, layer.tp_k_head_num, -1),
S1=qk_nope_head_dim,
S2=layer.v_head_dim,
dtype=fp8_dtype,
)
else:
kv = layer.kv_b_proj(kvc.contiguous())[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,32 @@
elif _is_musa:
from sgl_kernel import concat_mla_k

# Opt-in (default off): fuse the bf16 kv_b_proj GEMM with its nope/v split,
# k_pe cat, and fp8 cast into a single Triton kernel (the bf16 analog of the
# MXFP4 fused_gemm_afp4wfp4_split_cat path). Requires a recent aiter exposing
# fused_gemm_a16w16_split_cat; falls back to the unfused path if not present.
# The availability flag starts as False and is only flipped to True after the
# kernel import below succeeds.
_has_fused_gemm_a16w16_split_cat = False
if _use_aiter_gfx95:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

# Probe for the fused kernel; an ImportError is tolerated and simply leaves
# _has_fused_gemm_a16w16_split_cat at False.
try:
from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_split_cat import (
fused_gemm_a16w16_split_cat,
)

_has_fused_gemm_a16w16_split_cat = True
except ImportError:
pass

# The fused path is taken only when the env flag is set, _use_aiter_gfx95 is
# truthy, and the kernel import above succeeded.
_use_fused_kvb_split_cat = (
envs.SGLANG_AITER_FUSED_KVB_SPLIT_CAT.get()
and _use_aiter_gfx95
and _has_fused_gemm_a16w16_split_cat
)


def _resolve_attn_backend(forward_batch: ForwardBatch):
backend = get_attn_backend()
Expand Down Expand Up @@ -289,6 +309,23 @@ def forward_normal_prepare(
fp8_dtype,
)
)[0]
elif (
_use_fp8_prefill_attn
and _use_fused_kvb_split_cat
and self.kv_b_proj.weight.dtype in (torch.bfloat16, torch.float16)
):
# BF16/FP16 weights + FP8 prefill: fuse the kv_b_proj GEMM, nope/v split,
# k_pe cat, and fp8 cast into a single kernel (fused_gemm_a16w16_split_cat)
# that writes k and v directly in FP8, avoiding separate split / cat
# / float8_copy passes.
k, v = fused_gemm_a16w16_split_cat(
x=kv_a,
w=self.kv_b_proj.weight,
y=k_pe.expand(-1, self.num_local_heads, -1),
S1=self.qk_nope_head_dim,
S2=self.v_head_dim,
dtype=fp8_dtype,
)
else:
if _use_aiter_gfx95 and self.kv_b_proj.weight.dtype == torch.float8_e4m3fn:
kv = self.kv_b_proj(kv_a_quanted)[0]
Expand Down
Loading