[AMD] Fuse bf16 kv_b_proj prefill epilogue via fused_gemm_a16w16_split_cat#29344
Open
rbrugaro-amd wants to merge 1 commit into
Open
[AMD] Fuse bf16 kv_b_proj prefill epilogue via fused_gemm_a16w16_split_cat#29344rbrugaro-amd wants to merge 1 commit into
rbrugaro-amd wants to merge 1 commit into
Conversation
…16_split_cat In DeepSeek-R1 MLA prefill (MHA path) the compact latent KV is expanded through kv_b_proj and then materialized into per-head K/V by splitting nope/v, concatenating k_pe, and casting to FP8 -- several separate elementwise passes (concat_and_cast, CatArrayBatchedCopy, float8 casts) around the GEMM. For MXFP4 weights this is already fused via fused_gemm_afp4wfp4_split_cat; the bf16-weight path was left unfused. Add an opt-in flag SGLANG_AITER_FUSED_KVB_SPLIT_CAT (EnvBool, default off, gfx95 only) that routes the bf16 kv_b_proj materialization through aiter's fused_gemm_a16w16_split_cat kernel, which folds the GEMM, the nope/v split, the k_pe cat, and the FP8 write into one launch. Wired at both kv_b sites: the current-token materialization in forward_mha (forward_normal_prepare) and the gathered-prefix materialization in the aiter backend forward_extend. The unfused path is unchanged and remains the default fallback. Requires aiter's fused_gemm_a16w16_split_cat (ROCm/aiter). Numerically this is the same bf16 GEMM with the split/cat/cast folded into the epilogue and the result written in FP8, matching the unfused path. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
[AMD] Fuse bf16
kv_b_projprefill epilogue via aiterfused_gemm_a16w16_split_catMotivation
In DeepSeek-R1 MLA prefill (the MHA path on the aiter backend), the compact
latent KV is expanded through
kv_b_projand then materialized into per-headK/V by:
kv_b_projGEMM,k_nope/v,k_peontok_nopeto form K (concat_and_cast_mha_k),torch.catof prefix+current K/V(
CatArrayBatchedCopy),float8_copy).For MXFP4 weights this epilogue is already fused via
fused_gemm_afp4wfp4_split_cat. The bf16-weight path was left unfused andpays the full set of separate elementwise passes around the GEMM.
A DeepSeek-R1 MXFP4 prefill trace on MI355X shows this cluster
(
concat_and_cast_mha_k,CatArrayBatchedCopy, K/Vfloat8_copy) is a largefraction of MLA-prefill data movement.
Modifications
SGLANG_AITER_FUSED_KVB_SPLIT_CAT(EnvBool(False)) inpython/sglang/srt/environ.py, next to the other ROCm/AITER MLA flags.kv_b_projmaterialization through aiter'sfused_gemm_a16w16_split_cat(GEMM + nope/v split + k_pe cat + FP8 write in onekernel) when the flag is on, FP8 prefill attention is active, and the weights
are bf16/fp16. Wired at both
kv_bmaterialization sites:deepseek_common/attention_forward_methods/forward_mha.py(
forward_normal_prepare),layers/attention/aiter_backend.py(
forward_extend).+63 lines across
environ.py(+3),aiter_backend.py(+32),forward_mha.py(+28).
Dependency
Requires aiter
fused_gemm_a16w16_split_cat— companion aiter PR:ROCm/aiter#3940 (ROCm/aiter#3940). The import is
guarded by a
try/except ImportErrorcapability check (_has_fused_gemm_a16w16_split_cat)and the
_use_aiter_gfx95/ flag gate, so this PR is safe to merge before theaiter side lands: if the pinned aiter lacks the kernel, the flag is forced off
and the existing unfused path runs unchanged.
Before and After this PR
Before:


After"
Accuracy
This swaps the Tensile/hipBLASLt bf16
kv_b_projGEMM for a Triton bf16 GEMMwith the split/cat/cast folded into the epilogue (output written in FP8), so it
is not bit-identical — but it is accuracy-neutral within run-to-run noise.
GSM8K (DeepSeek-R1 MXFP4, 8×MI355X, 1319 questions, 5-shot, greedy):
The baseline alone spans 0.944–0.952 across identical repeats (DeepSeek-R1 +
DP-attention + fp8 dynamic batching is nondeterministic even at temperature 0),
and the flag-on result sits inside that band. Default-off, so no behavior change
unless explicitly enabled.
Benchmarking and Profiling
DeepSeek-R1 MXFP4, 8×MI355X, aiter backend, MoRI-EP.
sglang.bench_serving(random, 500 prompts, in≈1024 / out≈1024, rate=inf),split_cat off vs on:
As expected for a prefill-only fusion, TTFT improves materially while decode
(TPOT) is unchanged. A prefill trace confirms the kernel swap:
concat_and_cast_mha_k, the prefixCatArrayBatchedCopy, and the K/Vfloat8_copykernels are replaced by a singlefused_gemm_a16w16_split_catperMLA layer.
Checklist
main.CI States
Latest PR Test (Base): ❌ Run #28206637152
Latest PR Test (Extra): ❌ Run #28206637080