diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index eb9455dc23c7..d2d1c0ee9f4f 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -499,6 +499,9 @@ class Envs: # decode without runtime permutes. SGLANG_AITER_KV_CACHE_LAYOUT = EnvStr("nhd") SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False) + # Fuse the bf16 kv_b_proj GEMM + nope/v split + k_pe cat + fp8 cast into a + # single aiter Triton kernel (fused_gemm_a16w16_split_cat) on ROCm. + SGLANG_AITER_FUSED_KVB_SPLIT_CAT = EnvBool(False) SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False) SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(4096) # Enable dual-stream MoE (shared experts vs routed experts) on the diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 9d990a9ad35a..9bc63070c869 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -61,6 +61,7 @@ ) from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.environ import envs from sglang.srt.layers.attention.aiter_utils import ( forward_decode_vectorized_5d, forward_extend_vectorized_5d, @@ -84,6 +85,27 @@ get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported() ) +# Opt-in (default off): fuse the bf16 kv_b_proj GEMM with its nope/v split, +# k_pe cat, and fp8 cast into a single Triton kernel (bf16 analog of the MXFP4 +# fused_gemm_afp4wfp4_split_cat path). Only valid on gfx95 with bf16/fp16 +# kv_b_proj weights and fp8 prefill attention. Requires a recent aiter exposing +# fused_gemm_a16w16_split_cat; falls back to the unfused path if not present. +_has_fused_gemm_a16w16_split_cat = False +try: + from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_split_cat import ( + fused_gemm_a16w16_split_cat, + ) + + _has_fused_gemm_a16w16_split_cat = True +except ImportError: + pass + +_use_fused_kvb_split_cat = ( + envs.SGLANG_AITER_FUSED_KVB_SPLIT_CAT.get() + and is_gfx95_supported() + and _has_fused_gemm_a16w16_split_cat +) + # Persist # fast_mode=True if _use_mla_ps_kernel else False # intra_batch_mode=False if _use_mla_ps_kernel else True @@ -2028,6 +2050,25 @@ def forward_extend( fp8_dtype, ) )[0] + elif ( + _use_fp8_prefill_attn + and _use_fused_kvb_split_cat + and layer.kv_b_proj.weight.dtype + in (torch.bfloat16, torch.float16) + ): + # BF16 weights + FP8 prefill: fuse the kv_b_proj GEMM, + # nope/v split, and k_pe cat into a single kernel + # (fused_gemm_a16w16_split_cat) that writes k and v + # directly in FP8, avoiding separate split / cat / + # float8_copy passes. + k, v = fused_gemm_a16w16_split_cat( + x=kvc.squeeze(1).contiguous(), + w=layer.kv_b_proj.weight, + y=k_pe.expand(-1, layer.tp_k_head_num, -1), + S1=qk_nope_head_dim, + S2=layer.v_head_dim, + dtype=fp8_dtype, + ) else: kv = layer.kv_b_proj(kvc.contiguous())[0] diff --git a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py index 729be49044eb..a9d3265dd402 100644 --- a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py +++ b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py @@ -39,12 +39,32 @@ elif _is_musa: from sgl_kernel import concat_mla_k +# Opt-in (default off): fuse the bf16 kv_b_proj GEMM with its nope/v split, +# k_pe cat, and fp8 cast into a single Triton kernel (the bf16 analog of the +# MXFP4 fused_gemm_afp4wfp4_split_cat path). Requires a recent aiter exposing +# fused_gemm_a16w16_split_cat; falls back to the unfused path if not present. +_has_fused_gemm_a16w16_split_cat = False if _use_aiter_gfx95: from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant + try: + from aiter.ops.triton.gemm.fused.fused_gemm_a16w16_split_cat import ( + fused_gemm_a16w16_split_cat, + ) + + _has_fused_gemm_a16w16_split_cat = True + except ImportError: + pass + +_use_fused_kvb_split_cat = ( + envs.SGLANG_AITER_FUSED_KVB_SPLIT_CAT.get() + and _use_aiter_gfx95 + and _has_fused_gemm_a16w16_split_cat +) + def _resolve_attn_backend(forward_batch: ForwardBatch): backend = get_attn_backend() @@ -289,6 +309,23 @@ def forward_normal_prepare( fp8_dtype, ) )[0] + elif ( + _use_fp8_prefill_attn + and _use_fused_kvb_split_cat + and self.kv_b_proj.weight.dtype in (torch.bfloat16, torch.float16) + ): + # BF16 weights + FP8 prefill: fuse the kv_b_proj GEMM, nope/v split, + # and k_pe cat into a single kernel (fused_gemm_a16w16_split_cat) + # that writes k and v directly in FP8, avoiding separate split / cat + # / float8_copy passes. + k, v = fused_gemm_a16w16_split_cat( + x=kv_a, + w=self.kv_b_proj.weight, + y=k_pe.expand(-1, self.num_local_heads, -1), + S1=self.qk_nope_head_dim, + S2=self.v_head_dim, + dtype=fp8_dtype, + ) else: if _use_aiter_gfx95 and self.kv_b_proj.weight.dtype == torch.float8_e4m3fn: kv = self.kv_b_proj(kv_a_quanted)[0]