Skip to content

Add gelu_tanh activation to no-quant CK 2-stage fused MoE#3972

Open
jonahbernard wants to merge 4 commits into
ROCm:mainfrom
jonahbernard:gelu-tanh-moe-bf16
Open

Add gelu_tanh activation to no-quant CK 2-stage fused MoE#3972
jonahbernard wants to merge 4 commits into
ROCm:mainfrom
jonahbernard:gelu-tanh-moe-bf16

Conversation

@jonahbernard

Copy link
Copy Markdown

Motivation

DiffusionGemma / Gemma4 MoE models use the tanh approximation of GELU (gelu_tanh) in their expert MLPs. The AITER fused 2-stage CK MoE kernel previously supported only silu and gelu (erf) activations, so on ROCm these models either fell back to the slower unfused Triton implementation or could not run the fused path with the correct activation. This PR adds gelu_tanh as an activation in the bf16 (no-quant) CK 2-stage MoE path so Gemma MoE can use the fused kernel with the mathematically correct activation.

Technical Details

Paired with this PR into CK: ROCm/rocm-libraries#8886

  • Codegen (csrc/ck_gemm_moe_2stages_codegen/): added gelu_tanh to the activation tables. ACT_TO_INT = {"gelu": 0, "silu": 1, "gelu_tanh": 3} and matching INT_TO_ACT; widened
    ActOP from bool to int. In gen_instances.py, the no-quant loop now iterates acts + ["gelu_tanh"] so the new instance is generated only for the f16/b16 no-quant families (it stays out of the quant families). gelu_tanh added to the -act CLI choices.
  • Enum remap (gemm_moe_ck2stages.cu): AITER's ActivationType enum (Silu=0, Gelu=1, Swiglu=2, GeluTanh=3) differs from CK's Activation enum (gelu_and_mul=0, silu_and_mul=1, …, gelu_tanh_and_mul=3). Replaced the old activation = !activation bool flip with an explicit aiter_act_to_ck() remap (Silu 0→1, Gelu 1→0, GeluTanh 3→3) at both stage1 and stage2 dispatch.
  • Python op map (aiter/ops/moe_op.py): added ActivationType.GeluTanh: "gelu_tanh" to act2str_dict.
  • Torch reference (aiter/ops/quant.py, get_torch_act): added GeluTanh → F.gelu(x, approximate="tanh") so the test oracle matches the CK FastGelu.
  • Arg parsing (aiter/utility/dtypes.py, str2ActivationType): fixed snake_case→CamelCase parsing so -a gelu_tanh resolves to ActivationType.GeluTanh (previously "gelu_tanh".capitalize()"Gelu_tanh"AttributeError).

Test Plan

Verified on a gfx950 box using the standalone CK 2-stage MoE accuracy test (op_tests/test_moe_2stage.py), no-quant bf16, GeluTanh activation. The test compares the fused AITER/CK
kernel against a torch reference using F.gelu(x, approximate="tanh").

 python3 op_tests/test_moe_2stage.py \
  -q 0 -a gelu_tanh -d bf16 -e 32 -k 5 -t 64 -dim 512,256 --no-flydsl-csv
  • -q 0 = no-quant (a16w16), the path this PR adds gelu_tanh to.
  • -e 32 -k 5 keeps topk ≤ experts.

Test Result

shape (model_dim, inter_dim) logits_diff result
512, 256 6.99e-06 pass

Submission Checklist

Plumb a GeluTanh activation through the CK 2-stage MoE path:
- ActivationType::GeluTanh enum + pybind export (aiter_enum.h, rocm_ops.hpp)
- explicit ActivationType->name and AITER->CK enum-value maps; the integer
  values differ (Silu 0->1, Gelu 1->0, GeluTanh 3->3), handled by
  aiter_act_to_ck() in the host dispatch and ACT_TO_INT in codegen
- ActOP widened from bool to int so codegen can emit the gelu_tanh instances
- generate the no-quant bf16/fp16 gelu_tanh instances

Requires the matching CK gelu_tanh_and_mul kernel support (rocm-libraries
gelu-tanh-moe-bf16); submodule bump deferred until that lands and mirrors.
@jonahbernard jonahbernard requested a review from a team June 27, 2026 20:07
@github-actions

Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3972 --add-label <label>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant