diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index cffe598634..91d1afbf6c 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -480,6 +480,12 @@ def moe_cktile2stages_gemm2( "b16": dtypes.bf16, } +act2str_dict = { + ActivationType.Silu: "silu", + ActivationType.Gelu: "gelu", + ActivationType.GeluTanh: "gelu_tanh", +} + @functools.lru_cache(maxsize=1024) def get_moe_stage_module( @@ -509,7 +515,7 @@ def get_moe_stage_module( quant_type = ( QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type ) - act = str(activation).split(".")[-1].lower() + act = act2str_dict[activation] quant_type = str(quant_type).split(".")[-1].lower() parts = [ diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 9c3a9ff4dd..f337694235 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -645,6 +645,7 @@ def get_torch_act(aType): ActivationType.No: lambda *a, **k: a[0], ActivationType.Silu: F.silu, ActivationType.Gelu: F.gelu, + ActivationType.GeluTanh: lambda x: F.gelu(x, approximate="tanh"), } return tmp.get(aType, NotImplementedError) diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index 3f90b6f152..f5606a58eb 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -142,4 +142,5 @@ def _convert(s): def str2ActivationType(s): """Convert string to ActivationType.""" - return getattr(ActivationType, s.capitalize()) + name = "".join(p.capitalize() for p in s.split("_")) + return getattr(ActivationType, name) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu index b92dabad12..ad5440ebc8 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu @@ -13,6 +13,17 @@ using MoeKernelMap = std::unordered_map; +static inline int aiter_act_to_ck(int activation) +{ + switch (activation) + { + case 0: return 1; // Silu + case 1: return 0; // Gelu + case 3: return 3; // GeluTanh + default: return !activation; + } +} + // API for user aiter.ck_moe_stage1(...) template @@ -108,7 +119,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token K *= 2; } - activation = !activation; + activation = aiter_act_to_ck(activation); auto kernel = moe_dispatch<1>(kernelName, MPerBlock, N, hidden_states.dtype().toScalarType(), w1.dtype().toScalarType(), out.dtype().toScalarType(), activation, quant_type, MulRoutedWeight, is_shuffled); @@ -172,7 +183,7 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token K *= 2; } - activation = !activation; + activation = aiter_act_to_ck(activation); auto kernel = moe_dispatch<2>(kernelName, MPerBlock, K, inter_states.dtype().toScalarType(), w1.dtype().toScalarType(), out.dtype().toScalarType(), activation, quant_type, MulRoutedWeight, is_shuffled); kernel(at::hip::getCurrentHIPStream(), diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index f035a7079d..578c158ba0 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -17,6 +17,9 @@ from chip_info import get_gfx # noqa: E402 +ACT_TO_INT = {"gelu": 0, "silu": 1, "gelu_tanh": 3} +INT_TO_ACT = {v: k for k, v in ACT_TO_INT.items()} + @dataclass class kernelInstanceGEMM1: @@ -29,7 +32,7 @@ class kernelInstanceGEMM1: GemmPipelineVersion: int Nswizzle: bool = False MulRoutedWeight: bool = False - ActOP: bool = False + ActOP: int = 0 CDEElementOp: str = "TypeCast" QuantType: int = 1 stage: int = 1 @@ -59,7 +62,7 @@ def name(self) -> str: "Nswizzle" + str(int(self.Nswizzle)), "Quant" + str(self.QuantType), "MulRoutedWeight" + str(int(self.MulRoutedWeight)), - "silu" if self.ActOP else "gelu", + INT_TO_ACT[int(self.ActOP)], self.Adtype.upper(), self.Bdtype.upper(), self.Cdtype.upper(), @@ -410,7 +413,7 @@ def get_gemm1_kernels_list( kernels_list = {k: copy.deepcopy(v) for k, v in gemm1_kernels_dict[tag].items()} for id, kernel in kernels_list.items(): kernel.MulRoutedWeight = MulRoutedWeight - kernel.ActOP = ActOP == "silu" + kernel.ActOP = ACT_TO_INT[ActOP] kernel.Nswizzle = Nswizzle kernel.QuantType = QuantType kernel.Adtype = Adtype diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index b778c74b17..535147da63 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -3,7 +3,11 @@ import os import argparse import itertools -from gemm_moe_ck2stages_common import get_gemm1_kernels_list, get_gemm2_kernels_list +from gemm_moe_ck2stages_common import ( + get_gemm1_kernels_list, + get_gemm2_kernels_list, + ACT_TO_INT, +) STG_INSTANCE_IMPL = """// SPDX-License-Identifier: MIT // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. @@ -926,9 +930,7 @@ def generate_instance_and_lookUpTable(self): Nswizzle=str(self.nswizzle).lower(), Quant=self.quant_type, ActOP=( - int(self.activation == "silu") - if kernel.stage == 1 - else 0 + ACT_TO_INT[self.activation] if kernel.stage == 1 else 0 ), Stage=kernel.stage, BlockSize=kernel.BLOCK_SIZE, @@ -958,7 +960,7 @@ def generate_instance_and_lookUpTable(self): CDEElementOp=kernel.CDEElementOp, Nswizzle=str(self.nswizzle).lower(), Quant=self.quant_type, - ActOP=int(self.activation == "silu") if kernel.stage == 1 else 0, + ActOP=ACT_TO_INT[self.activation] if kernel.stage == 1 else 0, Stage=kernel.stage, BlockSize=kernel.BLOCK_SIZE, MPerBlock=kernel.MPerBlock, @@ -989,7 +991,7 @@ def generate_instance_and_lookUpTable(self): CDEElementOp=kernel_list[0].CDEElementOp, Nswizzle=str(self.nswizzle).lower(), Quant=self.quant_type, - ActOP=str(int(self.activation == "silu")), + ActOP=str(ACT_TO_INT[self.activation]), MulRoutedWeight=str(self.mul_routed_weight_stage == 1).lower(), Preshuffle=str(self.preshuffle).lower(), ) @@ -1071,7 +1073,7 @@ def generate_instance_and_lookUpTable(self): default="silu", required=False, type=str, - choices=["silu", "gelu"], + choices=["silu", "gelu", "gelu_tanh"], help="select activation", ) @@ -1184,11 +1186,12 @@ def generate_instance_and_lookUpTable(self): "f16", "b16", ] + no_quant_acts = acts + ["gelu_tanh"] for ( b_dtype, act, routed_weight, - ) in itertools.product(b_quant_dtypes, acts, routed_weight_l): + ) in itertools.product(b_quant_dtypes, no_quant_acts, routed_weight_l): c_dtype = a_dtype = b_dtype codegen = ck_moe_2stage_gemm_codegen( diff --git a/csrc/include/aiter_enum.h b/csrc/include/aiter_enum.h index df56bdb518..ed8b795156 100644 --- a/csrc/include/aiter_enum.h +++ b/csrc/include/aiter_enum.h @@ -6,10 +6,11 @@ enum class ActivationType : int { - No = -1, - Silu = 0, - Gelu = 1, - Swiglu = 2, + No = -1, + Silu = 0, + Gelu = 1, + Swiglu = 2, + GeluTanh = 3, }; enum class QuantType : int { diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 319940594b..7fbe44a498 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -31,6 +31,7 @@ namespace py = pybind11; .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ .value("Swiglu", ActivationType::Swiglu) \ + .value("GeluTanh", ActivationType::GeluTanh) \ .export_values(); \ pybind11::enum_(m, "MxScaleRoundMode") \ .value("RoundDown", aiter::MxScaleRoundMode::RoundDown) \