[Triton] Add fused_gemm_a16w16_split_cat#3940
Open
rbrugaro-amd wants to merge 1 commit into
Open
Conversation
Add a bf16 (a16w16) analog of fused_gemm_afp4wfp4_split_cat / fused_gemm_a8w8_blockscale_split_cat: a single Triton kernel that computes C = X @ W^T from bf16/fp16 inputs (no quant scales), reshapes the N axis to [D, S1 + S2], splits it, concatenates an auxiliary tensor Y onto the first part, and writes both outputs directly in the requested output dtype (e.g. fp8_e4m3) -- folding a split, a concat, and a dtype cast into the GEMM epilogue. Motivation: DeepSeek-R1 MLA prefill expands the compact latent KV through kv_b_proj (bf16 weights) and then materializes per-head K/V by splitting nope/v, concatenating k_pe, and casting to fp8 -- several separate elementwise passes around the GEMM. This kernel fuses that epilogue for the bf16-weight case, mirroring the existing fp4/fp8 split_cat kernels. Adds the kernel, the public wrapper (registered in ops/triton/__init__.py), and an op_test that checks the GEMM+split+cat result against a torch reference for bf16 and fp8_e4m3 outputs across representative shapes (including the DeepSeek-R1 kv_b shape N=32768, K=512, D=128). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Open
5 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
[Triton] Add
fused_gemm_a16w16_split_catMotivation
fused_gemm_afp4wfp4_split_catandfused_gemm_a8w8_blockscale_split_catalready fuse a GEMM with a nope/v split + auxiliary concat + dtype cast epilogue
for fp4 and fp8 weights. There was no equivalent for bf16/fp16 weights.
DeepSeek-R1 MLA prefill expands the compact latent KV through
kv_b_proj(bf16 weights) and then materializes per-head K/V by splitting nope/v,
concatenating
k_pe, and casting to fp8 — several separate elementwise passesaround the GEMM. This kernel fuses that epilogue for the bf16-weight case,
mirroring the existing fp4/fp8 split_cat kernels so the SGLang MLA prefill path
can use it (companion SGLang PR wires it behind an opt-in flag).
Modifications
aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_split_cat.py:C = X @ W^T(bf16/fp16 inputs, fp32 accumulate, no quant scales), reshape theN axis to
[D, S1 + S2], writec1 = [M, D, S1 + S3](S1 from C, S3 from Y) andc2 = [M, D, S2]directly in the output dtype (e.g.fp8_e4m3).aiter/ops/triton/gemm/fused/fused_gemm_a16w16_split_cat.py:fused_gemm_a16w16_split_cat(x, w, y, S1, S2, dtype=fp8_e4m3, config=None).Reuses
fused_gemm_a16w16_quant_x._get_configfor tile selection.aiter/ops/triton/__init__.py(lazy-import map), next tothe other
*_split_catentries.op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_split_cat.pychecking the GEMM+split+cat result against a torch reference for
bf16andfp8_e4m3outputs across representative shapes, including the DeepSeek-R1kv_bshape (N=32768, K=512, D=128).No existing behavior changes (purely additive).
Equivalence
The kernel is equivalent to:
The op_test asserts this against the torch reference (bf16:
atol=1e-2; fp8:compared in fp32 within one e4m3 step).
Test status
blackandruffclean on the added files (aiter CI hooks).reference exactly; fp8_e4m3 output matches within one e4m3 step, across the
parametrized shapes including the DeepSeek-R1 kv_b shape (N=32768, K=512, D=128).
Kernel microbenchmark
Fused kernel vs the unfused
GEMM + nope/v split + k_pe cat + 2× fp8-castsequence at the DeepSeek-R1
kv_bshape (K=512, N=32768, D=128, S1=S2=128,S3=64), MI355X / gfx950,
triton.testing.do_bench, output dtype fp8_e4m3:Outputs match the unfused sequence at every size (bf16 GEMM + fp8 cast, within
one e4m3 step). Folding the split/cat/cast into the GEMM epilogue removes the
separate elementwise passes and their HBM round-trips, giving ~1.8–3.4×.
Notes
fused_gemm_a16w16_quant_x._get_configand the standardtriton utils (
pid_grid,remap_xcd,make_kernel_repr).