[Tune] Add qwen3.5-397B MXFP4 a16w16 GEMM tuning configs #3974
Open
yichiche wants to merge 1 commit into
Contributor
🏷️ CI Guide
Runs automatically on every PR:
Extended tests (opt-in via labels):
1342ee1 to
6f8e6f2
Compare
Motivation
Add a16w16 (bf16) GEMM tuning configs for Qwen3.5-397B-A17B-MoE-MXFP4 (TP2). In this MXFP4 MoE model the experts run on the fp4 fused-MoE path; the remaining dense bf16 GEMMs (attention q/k/v/o, router gate, KV proj, and the MTP/EAGLE draft-layer projections) go through the a16w16 path and benefit from per-shape kernel tuning.

Technical Details
Adds two files under aiter/configs/model_configs/:
- qwen3_5_397b_untuned_gemm.csv: 188 shapes to tune
- qwen3_5_397b_bf16_tuned_gemm.csv: 187 tuned shapes (gfx950, cu_num=256)

Shape selection (TP2):
- 10240×4096, 8704×4096, 4096×4096, 4096×512, 1024×4096, 512×4096, 64×4096.
- 8192×4096 (output-gated), fc fusion / o_proj 4096×8192 (K=8192), sharded kv 256×4096.
- 4096×512 shapes at M=1024/4096 were omitted because they already exist in kimi_bf16_tuned_gemm.csv (the repo enforces one owner per shape across bf16 configs; qwen reuses kimi's entry at runtime on the same gfx950).
- 1×bs, target verify at num_draft_tokens×bs): 1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024 (multiples of the runtime's 16/32 getPaddedM lookup granularity), plus a prefill tail 2048, 4096, 8192.
- 16384/32768 intentionally dropped (the gl=1 fallback caps padded_M at 8192 for N>4096).

Tuned with csrc/gemm_a16w16/gemm_a16w16_tune.py, --libtype asm,opus,flydsl,triton,torch,skinny --shape_grouped. hipBLASLt is excluded (opt-in --with-hipblaslt not used). flydsl wins ~68% of shapes.

Test Plan
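As a cross-check on the shape selection under Technical Details, the following minimal sketch enumerates the shapes as a full cross product of the listed M values and N×K pairs, minus the two shapes said to be owned by kimi_bf16_tuned_gemm.csv. It assumes the pairs are written as N×K (consistent with the "(K=8192)" note) and that every M value is paired with every N×K pair; neither is stated outright in the description, and the names below are illustrative rather than taken from the repo. Under those assumptions the count matches the 188 shapes quoted for the untuned file.

```python
from itertools import product

# (N, K) pairs from the shape-selection bullets, read as N×K (an assumption).
NK_PAIRS = [
    (10240, 4096), (8704, 4096), (4096, 4096), (4096, 512),
    (1024, 4096), (512, 4096), (64, 4096),
    (8192, 4096), (4096, 8192), (256, 4096),
]

# M sweep from the description: small/decode sizes plus the prefill tail.
M_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 192, 256,
            384, 512, 768, 1024, 2048, 4096, 8192]

# (M, N, K) shapes described as already present in kimi_bf16_tuned_gemm.csv.
OMITTED = {(1024, 4096, 512), (4096, 4096, 512)}


def enumerate_shapes():
    """Return all (M, N, K) tuples in the assumed cross product, minus omissions."""
    return [
        (m, n, k)
        for m, (n, k) in product(M_VALUES, NK_PAIRS)
        if (m, n, k) not in OMITTED
    ]


shapes = enumerate_shapes()
print(len(shapes))  # 10 pairs x 19 M values - 2 omitted = 188
```

The sketch does not account for the difference between 188 shapes to tune and 187 tuned shapes; the description does not say which shape is absent from the tuned file.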
End-to-end before/after serving sweep on Qwen3.5-397B-A17B-MoE-MXFP4 (TP2, --attention-backend aiter, AITER_FLYDSL_FORCE=1) via sglang.bench_serving, dataset=random, IL=8192 / OL=1024, concurrency 4→256 (num_prompts = conc×10). before = run without these tuned configs (default kernels); after = with the tuned configs.

Test Result
Before → After (tuned), IL=8192 / OL=1024:
Mean: +2.08% total throughput, −1.94% median e2e latency. Gains are ~2–3% at conc 4–64 (decode-time bf16 GEMMs at small M), tapering to ~0 at conc=256 (memory/KV-bound saturation). No regressions.
A prefill-bound run (IL=70000 / OL=300) showed ~0–1% as expected — that workload is attention/MoE/bandwidth-bound and the heavy expert compute is on the (untuned) MXFP4 path.
Accuracy: pure config addition, with no model or runtime logic changed. During tuning, each candidate kernel is validated against the reference GEMM within err_ratio ≤ 0.05, so the selected kernels match the reference within that tolerance and no accuracy impact is expected.

Submission Checklist
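The validation criterion mentioned under Accuracy can be illustrated with a small sketch. It assumes that err_ratio means the fraction of output elements that differ from the reference beyond an elementwise tolerance; the description does not define it, and the function names and the rtol/atol values below are hypothetical, not taken from the tuning script.

```python
import math

ERR_RATIO_LIMIT = 0.05  # threshold quoted in the PR description


def gemm_ref(a, b):
    """Naive reference GEMM: a is M×K and b is K×N, both as lists of rows."""
    inner, cols = len(b), len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def err_ratio(out, ref, rtol=1e-2, atol=1e-2):
    """Fraction of elements in `out` outside the tolerance (assumed definition)."""
    flat_out = [x for row in out for x in row]
    flat_ref = [x for row in ref for x in row]
    mismatched = sum(
        not math.isclose(o, r, rel_tol=rtol, abs_tol=atol)
        for o, r in zip(flat_out, flat_ref)
    )
    return mismatched / len(flat_ref)


def accept(out, ref):
    """A candidate kernel's output passes if its error ratio is within the limit."""
    return err_ratio(out, ref) <= ERR_RATIO_LIMIT


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
ref = gemm_ref(a, b)

# Small rounding-style noise on every element stays inside the tolerance.
close = [[v + 1e-3 for v in row] for row in ref]
# One of four elements is far off, so the ratio is 0.25 and the check fails.
off = [[ref[0][0] + 5.0, ref[0][1]], list(ref[1])]

print(accept(close, ref), accept(off, ref))
```

Under this reading, a passing kernel may still differ from the reference on up to 5% of elements, which is why the claim is equivalence within a tolerance rather than bitwise equality.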