diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index b778c74b17..1d29258387 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -84,7 +84,7 @@ }} else if (block_m == 128) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 128, 64, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -95,7 +95,7 @@ }} else if (block_m == 256) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 256, 64, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -370,7 +370,7 @@ {{ if (block_m == 32) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 32, 64, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -381,7 +381,7 @@ }} else if (block_m == 64) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 64, 128, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -392,7 +392,7 @@ }} else if (block_m == 128) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 128, 64, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -403,7 +403,7 @@ }} else if (block_m == 256) {{ - if (inter_dim <= 192) + if (inter_dim <= 192 || inter_dim % 128 != 0) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 256, 128, 64, 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }}