[AMD] Fuse bf16 kv_b_proj prefill epilogue via fused_gemm_a16w16_split_cat #29344

Open
rbrugaro-amd wants to merge 1 commit into sgl-project:main from rbrugaro-amd:rbrugaro/mla-prefill-fused-kvb-split-cat

rbrugaro-amd (Contributor) commented Jun 25, 2026

[AMD] Fuse bf16 kv_b_proj prefill epilogue via aiter fused_gemm_a16w16_split_cat

Motivation

In DeepSeek-R1 MLA prefill (the MHA path on the aiter backend), the compact
latent KV is expanded through kv_b_proj and then materialized into per-head
K/V by:

  1. the kv_b_proj GEMM,
  2. splitting the output into k_nope / v,
  3. concatenating k_pe onto k_nope to form K (concat_and_cast_mha_k),
  4. (for the cached prefix) a torch.cat of prefix+current K/V
    (CatArrayBatchedCopy),
  5. casting K and V to FP8 (float8_copy).

For MXFP4 weights this epilogue is already fused via
fused_gemm_afp4wfp4_split_cat. The bf16-weight path was left unfused and
pays the full set of separate elementwise passes around the GEMM.
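
The unfused epilogue listed above can be sketched in plain Python to make the data layout concrete. This is only an illustration, not the sglang or aiter code: the function names are hypothetical, the per-head `[k_nope | v]` output layout and the single k_pe vector shared by all heads of a token are assumptions, and the final FP8 cast (step 5) is left out because the standard library has no FP8 type.

```python
def matmul(a, b):
    """Plain-Python matrix product: a is [m][k], b is [k][n]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def unfused_kv_materialize(latent, w_kv_b, k_pe, num_heads, nope_dim, v_dim,
                           prefix_k=None, prefix_v=None):
    """Hypothetical reference for steps 1-4 of the unfused path.

    latent:  [tokens][latent_dim]
    w_kv_b:  [latent_dim][num_heads * (nope_dim + v_dim)]
    k_pe:    [tokens][rope_dim], assumed shared by every head of a token
    """
    out = matmul(latent, w_kv_b)                      # step 1: kv_b_proj GEMM
    head_width = nope_dim + v_dim
    k, v = [], []
    for token, row in enumerate(out):
        k_heads, v_heads = [], []
        for h in range(num_heads):
            base = h * head_width
            k_nope = row[base:base + nope_dim]        # step 2: split k_nope / v
            v_head = row[base + nope_dim:base + head_width]
            k_heads.append(k_nope + k_pe[token])      # step 3: cat k_pe onto k_nope
            v_heads.append(v_head)
        k.append(k_heads)
        v.append(v_heads)
    k = (prefix_k or []) + k                          # step 4: prefix + current cat
    v = (prefix_v or []) + v
    return k, v                                       # step 5 (FP8 cast) omitted


# Tiny example: 1 token, 2 heads, nope_dim = v_dim = rope_dim = 1.
k, v = unfused_kv_materialize(
    latent=[[1, 2]],
    w_kv_b=[[1, 0, 0, 1], [0, 1, 1, 0]],
    k_pe=[[9]],
    num_heads=2, nope_dim=1, v_dim=1,
)
```

Each numbered step here is a separate pass over the GEMM output; the fused kernel is described as doing all of them, plus the FP8 write, inside the GEMM epilogue.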

A DeepSeek-R1 MXFP4 prefill trace on MI355X shows this cluster
(concat_and_cast_mha_k, CatArrayBatchedCopy, K/V float8_copy) is a large
fraction of MLA-prefill data movement.

Modifications

  • Register opt-in flag SGLANG_AITER_FUSED_KVB_SPLIT_CAT (EnvBool(False)) in
    python/sglang/srt/environ.py, next to the other ROCm/AITER MLA flags.
  • Route the bf16 kv_b_proj materialization through aiter's
    fused_gemm_a16w16_split_cat (GEMM + nope/v split + k_pe cat + FP8 write in one
    kernel) when the flag is on, FP8 prefill attention is active, and the weights
    are bf16/fp16. Wired at both kv_b materialization sites:
    • current-token path in
      deepseek_common/attention_forward_methods/forward_mha.py
      (forward_normal_prepare),
    • gathered-prefix path in layers/attention/aiter_backend.py
      (forward_extend).
  • The unfused path is unchanged and remains the default fallback.
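
The routing conditions in the list above amount to a single boolean gate. The sketch below is a hedged illustration rather than the code in this PR: `env_bool` is a hypothetical stand-in for sglang's `EnvBool`, and the function and parameter names are invented for the example.

```python
import os


def env_bool(name, default=False):
    """Hypothetical stand-in for EnvBool: parse a boolean environment variable."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


def use_fused_kvb_split_cat(*, flag_on, kernel_available, is_gfx95,
                            fp8_prefill, weight_dtype):
    """Return True only when every condition for the fused path holds.

    Any False input falls back to the unchanged, unfused path.
    """
    return (
        flag_on
        and kernel_available              # aiter exposes the fused kernel
        and is_gfx95                      # gfx95-only gate
        and fp8_prefill                   # FP8 prefill attention is active
        and weight_dtype in ("bfloat16", "float16")
    )


# Default-off: with the variable unset, the flag reads False.
flag = env_bool("SGLANG_AITER_FUSED_KVB_SPLIT_CAT", default=False)
```

Keeping the gate as one pure function makes it easy to see that the default configuration never reaches the fused kernel.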

+63 lines across environ.py (+3), aiter_backend.py (+32), forward_mha.py
(+28).

Dependency

Requires aiter's fused_gemm_a16w16_split_cat (companion aiter PR:
ROCm/aiter#3940). The import is guarded by a try/except ImportError
capability check (_has_fused_gemm_a16w16_split_cat) and by the
_use_aiter_gfx95 / flag gate, so this PR is safe to merge before the aiter
side lands: if the pinned aiter lacks the kernel, the flag is forced off and
the existing unfused path runs unchanged.
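
A capability check of this kind can be sketched as follows. This is a generic pattern, not the PR's code: `has_symbol` and `effective_flag` are hypothetical helpers, and the module path that actually exports the kernel in aiter is not stated here, so the example probes standard-library modules instead.

```python
import importlib


def has_symbol(module_name, symbol):
    """Return True if `module_name` imports cleanly and exposes `symbol`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Module missing entirely (e.g. an older pinned dependency).
        return False
    return hasattr(module, symbol)


def effective_flag(requested, module_name, symbol):
    """Force the opt-in flag off when the kernel is not available."""
    return bool(requested) and has_symbol(module_name, symbol)


# Demonstration with stdlib names; a real check would name the aiter module.
print(effective_flag(True, "math", "sqrt"))              # symbol present
print(effective_flag(True, "no_such_module_xyz", "f"))   # module absent
```

Because the check degrades to False rather than raising, enabling the flag against an aiter build without the kernel silently keeps the unfused path.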

Before and After this PR

Before: [image]
After: [image]

Accuracy

This swaps the Tensile/hipBLASLt bf16 kv_b_proj GEMM for a Triton bf16 GEMM
with the split/cat/cast folded into the epilogue (output written in FP8), so it
is not bit-identical — but it is accuracy-neutral within run-to-run noise.

GSM8K (DeepSeek-R1 MXFP4, 8×MI355X, 1319 questions, 5-shot, greedy):

Configuration                     Accuracy
baseline (flag off), 3 repeats    0.952 / 0.948 / 0.944
split_cat (flag on)               0.943

The baseline alone spans 0.944–0.952 across identical repeats (DeepSeek-R1 +
DP-attention + fp8 dynamic batching is nondeterministic even at temperature 0),
and the flag-on result (0.943) sits 0.001 below the bottom of that band, far
less than the band's own 0.008 width. The flag is default-off, so there is no
behavior change unless it is explicitly enabled.
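
The noise argument can be sanity-checked from the reported numbers alone. The snippet below is a rough back-of-the-envelope check, not part of the PR: it treats each question as an independent Bernoulli trial, which is an assumption, and compares the flag-on result against the baseline spread and a binomial standard error.

```python
import math

baseline = [0.952, 0.948, 0.944]   # three flag-off repeats from the table
flag_on = 0.943                    # single flag-on run
n_questions = 1319

mean_baseline = sum(baseline) / len(baseline)
spread = max(baseline) - min(baseline)           # width of the baseline band
gap_to_band = min(baseline) - flag_on            # how far flag-on is below it
gap_to_mean = mean_baseline - flag_on

# Binomial standard error of one accuracy measurement at this sample size.
std_err = math.sqrt(mean_baseline * (1.0 - mean_baseline) / n_questions)

print(f"baseline spread : {spread:.3f}")
print(f"gap to band     : {gap_to_band:.3f}")
print(f"gap to mean     : {gap_to_mean:.3f}")
print(f"standard error  : {std_err:.4f}")
```

Under that assumption the gap between the flag-on run and the baseline mean is smaller than one standard error of a single run, which is consistent with the accuracy-neutral reading.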

Benchmarking and Profiling

DeepSeek-R1 MXFP4, 8×MI355X, aiter backend, MoRI-EP.

sglang.bench_serving (random, 500 prompts, in≈1024 / out≈1024, rate=inf),
split_cat off vs on:

Metric                                off      on       Δ
Total token throughput (tok/s)        25945    26378    +1.7%
Prefill (input) throughput (tok/s)    12979    13196    +1.7%
Mean TTFT (ms)                        3216     2780     −13.6%
Median TTFT (ms)                      3151     2643     −16.1%
Mean TPOT (ms)                        31.64    31.50    ~same

As expected for a prefill-only fusion, TTFT improves materially while decode
(TPOT) is unchanged. A prefill trace confirms the kernel swap:
concat_and_cast_mha_k, the prefix CatArrayBatchedCopy, and the K/V
float8_copy kernels are replaced by a single fused_gemm_a16w16_split_cat per
MLA layer.
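
For readers who want to re-derive the Δ column, the relative changes follow directly from the off/on values in the table. A small sketch (the dictionary keys are just labels chosen for this example):

```python
def pct_change(off, on):
    """Relative change of `on` versus `off`, in percent."""
    return (on - off) / off * 100.0


results = {
    "total_tok_per_s": (25945, 26378),
    "prefill_tok_per_s": (12979, 13196),
    "mean_ttft_ms": (3216, 2780),
    "median_ttft_ms": (3151, 2643),
    "mean_tpot_ms": (31.64, 31.50),
}

deltas = {name: pct_change(off, on) for name, (off, on) in results.items()}
for name, delta in deltas.items():
    print(f"{name}: {delta:+.1f}%")
```

The TPOT change works out to well under one percent, which is what the table abbreviates as "~same".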

Checklist

  • Rebased on latest main.
  • Commit signed off (DCO).
  • Flag is opt-in and defaults to off (no behavior change by default).
  • Lint/format (pre-commit) pass (isort, ruff, black, codespell).
  • Companion aiter PR merged / available in the pinned aiter.

CI States

Latest PR Test (Base): ❌ Run #28206637152
Latest PR Test (Extra): ❌ Run #28206637080

…16_split_cat

In DeepSeek-R1 MLA prefill (MHA path) the compact latent KV is expanded
through kv_b_proj and then materialized into per-head K/V by splitting
nope/v, concatenating k_pe, and casting to FP8 -- several separate
elementwise passes (concat_and_cast, CatArrayBatchedCopy, float8 casts)
around the GEMM. For MXFP4 weights this is already fused via
fused_gemm_afp4wfp4_split_cat; the bf16-weight path was left unfused.

Add an opt-in flag SGLANG_AITER_FUSED_KVB_SPLIT_CAT (EnvBool, default off,
gfx95 only) that routes the bf16 kv_b_proj materialization through aiter's
fused_gemm_a16w16_split_cat kernel, which folds the GEMM, the nope/v split,
the k_pe cat, and the FP8 write into one launch. Wired at both kv_b sites:
the current-token materialization in forward_mha (forward_normal_prepare)
and the gathered-prefix materialization in the aiter backend forward_extend.
The unfused path is unchanged and remains the default fallback.

Requires aiter's fused_gemm_a16w16_split_cat (ROCm/aiter). Numerically this
is the same bf16 GEMM with the split/cat/cast folded into the epilogue and
the result written in FP8, matching the unfused path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
