From 394f62a514e8a68a51f173166560479f8ea89033 Mon Sep 17 00:00:00 2001 From: jonahbernard Date: Sat, 27 Jun 2026 18:38:35 +0000 Subject: [PATCH] CK 2-stage MoE: route inter_dim % 128 != 0 shapes to PerBlock=64 instances The gfx950 heuristic dispatch sent all inter_dim > 192 shapes to the NPerBlock/KPerBlock=128 fast path, which fails CK's N%NPerBlock (stage1) and K%KPerBlock (stage2) divisibility checks when inter_dim is not a multiple of 128 (e.g. DiffusionGemma moe_inter=704=64*11). Route those shapes to the PerBlock=64 instances, which divide any multiple of 64. --- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index b778c74b17..1d29258387 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -84,7 +84,7 @@ }} else if (block_m == 128) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 128, 64, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -95,7 +95,7 @@ }} else if (block_m == 256) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 256, 64, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -370,7 +370,7 @@ {{ if (block_m == 32) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 32, 64, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -381,7 +381,7 @@ }} else if (block_m == 64) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 64, 128, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -392,7 +392,7 @@ }} else if (block_m == 128) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 128, 64, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -403,7 +403,7 @@ }} else if (block_m == 256) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 256, 128, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }}