[Triton] Add fused_gemm_a16w16_split_cat #3940

Open
rbrugaro-amd wants to merge 1 commit into ROCm:main from rbrugaro-amd:rbrugaro/fused-gemm-a16w16-split-cat

@rbrugaro-amd

Contributor

Motivation

fused_gemm_afp4wfp4_split_cat and fused_gemm_a8w8_blockscale_split_cat
already fuse a GEMM with a nope/v split + auxiliary concat + dtype cast epilogue
for fp4 and fp8 weights. There was no equivalent for bf16/fp16 weights.

DeepSeek-R1 MLA prefill expands the compact latent KV through kv_b_proj
(bf16 weights) and then materializes per-head K/V by splitting nope/v,
concatenating k_pe, and casting to fp8, which takes several separate
elementwise passes around the GEMM. This kernel fuses that epilogue for the
bf16-weight case, mirroring the existing fp4/fp8 split_cat kernels so the
SGLang MLA prefill path can use it (a companion SGLang PR wires it in behind
an opt-in flag).

Modifications

  • New kernel aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a16w16_split_cat.py:
    computes C = X @ W^T (bf16/fp16 inputs, fp32 accumulation, no quant scales),
    reshapes the N axis to [D, S1 + S2], and writes c1 with shape [M, D, S1 + S3]
    (the first S1 columns from C, the last S3 from Y) and c2 with shape [M, D, S2]
    directly in the output dtype (e.g. fp8_e4m3).
  • New wrapper aiter/ops/triton/gemm/fused/fused_gemm_a16w16_split_cat.py:
    fused_gemm_a16w16_split_cat(x, w, y, S1, S2, dtype=fp8_e4m3, config=None).
    Reuses fused_gemm_a16w16_quant_x._get_config for tile selection.
  • Registers the op in aiter/ops/triton/__init__.py (lazy-import map), next to
    the other *_split_cat entries.
  • New op_test op_tests/triton_tests/gemm/fused/test_fused_gemm_a16w16_split_cat.py:
    checks the GEMM+split+cat result against a torch reference for bf16 and
    fp8_e4m3 outputs across representative shapes, including the DeepSeek-R1
    kv_b shape (N=32768, K=512, D=128).

No existing behavior changes (purely additive).
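
The index mapping behind the split and concat can be sketched in plain Python on nested lists. This is only an illustrative model, not the Triton kernel: the function name split_cat_reference is hypothetical, y is assumed to be indexable as y[m][d] with S3 values per head, and the dtype cast is left out.

```python
def split_cat_reference(c, y, D, S1, S2):
    """Pure-Python model of the split + concat epilogue (no dtype cast).

    c: GEMM result as M rows of N = D * (S1 + S2) values.
    y: auxiliary tensor, assumed here to be indexed as y[m][d] -> S3 values.
    """
    M, N = len(c), len(c[0])
    assert N == D * (S1 + S2), "N must factor as D * (S1 + S2)"
    c1, c2 = [], []
    for m in range(M):
        c1_row, c2_row = [], []
        for d in range(D):
            base = d * (S1 + S2)  # start of head d along the N axis
            # c1 takes the first S1 columns of the head, then the S3 values of y
            c1_row.append(list(c[m][base:base + S1]) + list(y[m][d]))
            # c2 takes the remaining S2 columns of the head
            c2_row.append(list(c[m][base + S1:base + S1 + S2]))
        c1.append(c1_row)
        c2.append(c2_row)
    return c1, c2


# Tiny example: M=1, D=2, S1=2, S2=1, S3=1
c = [[0, 1, 2, 3, 4, 5]]
y = [[[10], [11]]]
c1, c2 = split_cat_reference(c, y, D=2, S1=2, S2=1)
```

In this toy case each head owns three consecutive columns of C: two go to c1 (followed by the y value for that head) and one goes to c2.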

Equivalence

The kernel is equivalent to:

c = (x @ w.T).view(-1, D, S1 + S2)
c1, c2 = c.split([S1, S2], dim=-1)
c1 = torch.cat([c1, y], dim=-1)        # y = k_pe, broadcast across D
return c1.to(dtype), c2.to(dtype)

The op_test asserts this against the torch reference (bf16: atol=1e-2; fp8:
compared in fp32 within one e4m3 step).
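
One way to read "within one e4m3 step" is sketched below. It assumes the OCP e4m3 layout (4 exponent bits with bias 7, 3 mantissa bits, smallest normal exponent -6); the helper names e4m3_step and within_one_step are hypothetical, and the actual op_test may implement its tolerance differently.

```python
import math


def e4m3_step(v):
    """Spacing between adjacent e4m3 values near v (OCP e4m3 layout assumed)."""
    if v == 0.0:
        return 2.0 ** -9  # subnormal spacing: 2**-6 * 2**-3
    _, exp = math.frexp(abs(v))  # abs(v) = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)         # floor(log2(abs(v))), clamped to the subnormal range
    return 2.0 ** (e - 3)        # 3 mantissa bits -> 8 steps per binade


def within_one_step(a, b):
    """True if two fp32 values differ by at most one e4m3 step."""
    return abs(a - b) <= max(e4m3_step(a), e4m3_step(b))
```

Under these assumptions the step near 1.0 is 0.125 and the step near the e4m3 maximum of 448 is 32, so the absolute tolerance grows with the magnitude of the values being compared.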

Test status

  • Code style: black and ruff clean on the added files (aiter CI hooks).
  • op_test: 28/28 passed on MI355X (gfx950) — bf16 output matches the torch
    reference exactly; fp8_e4m3 output matches within one e4m3 step, across the
    parametrized shapes including the DeepSeek-R1 kv_b shape (N=32768, K=512, D=128).
  • Validated end-to-end in the DeepSeek-R1 MI355X deployment image (gsm8k + serving).

Kernel microbenchmark

Fused kernel vs the unfused GEMM + nope/v split + k_pe cat + 2× fp8-cast
sequence at the DeepSeek-R1 kv_b shape (K=512, N=32768, D=128, S1=S2=128,
S3=64), MI355X / gfx950, triton.testing.do_bench, output dtype fp8_e4m3:

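| M (tokens) | unfused (ms) | fused (ms) | speedup |
|-----------:|-------------:|-----------:|--------:|
|        256 |       0.0475 |     0.0263 |   1.81× |
|        512 |       0.0849 |     0.0469 |   1.81× |
|       1024 |       0.1508 |     0.0446 |   3.39× |
|       2048 |       0.2680 |     0.0879 |   3.05× |
|       4096 |       0.5031 |     0.2127 |   2.37× |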

Outputs match the unfused sequence at every size (bf16 GEMM + fp8 cast, within
one e4m3 step). Folding the split/cat/cast into the GEMM epilogue removes the
separate elementwise passes and their HBM round-trips, giving ~1.8–3.4×.
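
The shape bookkeeping and the speedup column can be re-derived from the numbers quoted in this section with a few lines of Python. This is only a consistency check on the reported figures, not a benchmark; because the timings are rounded to four decimals, a recomputed ratio can differ from the reported one in the last digit (the M=1024 row gives about 3.38 against the reported 3.39).

```python
# DeepSeek-R1 kv_b shape as quoted above
K, N, D, S1, S2, S3 = 512, 32768, 128, 128, 128, 64
assert N == D * (S1 + S2)   # the N axis factors into D heads of S1 + S2 columns
c1_last_dim = S1 + S3       # last dim of c1: nope part plus k_pe
c2_last_dim = S2            # last dim of c2: v part

# (M, unfused ms, fused ms, reported speedup) copied from the table
rows = [
    (256, 0.0475, 0.0263, 1.81),
    (512, 0.0849, 0.0469, 1.81),
    (1024, 0.1508, 0.0446, 3.39),
    (2048, 0.2680, 0.0879, 3.05),
    (4096, 0.5031, 0.2127, 2.37),
]
speedups = {m: unfused / fused for m, unfused, fused, _ in rows}
max_gap = max(abs(speedups[m] - reported) for m, _, _, reported in rows)
```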

Notes

  • ROCm/gfx95 Triton path; additive only.
  • Depends on existing fused_gemm_a16w16_quant_x._get_config and the standard
    triton utils (pid_grid, remap_xcd, make_kernel_repr).

Add a bf16 (a16w16) analog of fused_gemm_afp4wfp4_split_cat /
fused_gemm_a8w8_blockscale_split_cat: a single Triton kernel that computes
C = X @ W^T from bf16/fp16 inputs (no quant scales), reshapes the N axis to
[D, S1 + S2], splits it, concatenates an auxiliary tensor Y onto the first
part, and writes both outputs directly in the requested output dtype
(e.g. fp8_e4m3) -- folding a split, a concat, and a dtype cast into the GEMM
epilogue.

Motivation: DeepSeek-R1 MLA prefill expands the compact latent KV through
kv_b_proj (bf16 weights) and then materializes per-head K/V by splitting
nope/v, concatenating k_pe, and casting to fp8 -- several separate
elementwise passes around the GEMM. This kernel fuses that epilogue for the
bf16-weight case, mirroring the existing fp4/fp8 split_cat kernels.

Adds the kernel, the public wrapper (registered in ops/triton/__init__.py),
and an op_test that checks the GEMM+split+cat result against a torch
reference for bf16 and fp8_e4m3 outputs across representative shapes
(including the DeepSeek-R1 kv_b shape N=32768, K=512, D=128).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>

rbrugaro-amd requested a review from a team on June 25, 2026, 23:17

@github-actions

Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|-------|-------|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3940 --add-label <label>
