Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion aiter/ops/moe_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,12 @@ def moe_cktile2stages_gemm2(
"b16": dtypes.bf16,
}

# Maps ActivationType members to their lowercase string names.
# Only the members listed here are present; indexing with any other
# member raises KeyError.
act2str_dict = {
ActivationType.Silu: "silu",
ActivationType.Gelu: "gelu",
ActivationType.GeluTanh: "gelu_tanh",
}


@functools.lru_cache(maxsize=1024)
def get_moe_stage_module(
Expand Down Expand Up @@ -509,7 +515,7 @@ def get_moe_stage_module(
quant_type = (
QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type
)
act = str(activation).split(".")[-1].lower()
act = act2str_dict[activation]
quant_type = str(quant_type).split(".")[-1].lower()

parts = [
Expand Down
1 change: 1 addition & 0 deletions aiter/ops/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def get_torch_act(aType):
ActivationType.No: lambda *a, **k: a[0],
ActivationType.Silu: F.silu,
ActivationType.Gelu: F.gelu,
ActivationType.GeluTanh: lambda x: F.gelu(x, approximate="tanh"),
}
return tmp.get(aType, NotImplementedError)

Expand Down
3 changes: 2 additions & 1 deletion aiter/utility/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,5 @@ def _convert(s):

def str2ActivationType(s):
    """Convert a snake_case activation name to its ActivationType member.

    Each underscore-separated part of ``s`` is capitalized and the parts are
    joined, so ``"silu"`` resolves to ``ActivationType.Silu`` and
    ``"gelu_tanh"`` resolves to ``ActivationType.GeluTanh``.

    Raises:
        AttributeError: if ActivationType has no attribute with the
            resulting name.
    """
    # A single ``s.capitalize()`` would turn "gelu_tanh" into "Gelu_tanh",
    # which is not a member name, so capitalize every part separately.
    name = "".join(p.capitalize() for p in s.split("_"))
    return getattr(ActivationType, name)
15 changes: 13 additions & 2 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@

using MoeKernelMap = std::unordered_map<std::string, MoeKernel>;

// Translate an aiter activation code into the code handed to the CK kernel
// dispatch: 0 (Silu) -> 1, 1 (Gelu) -> 0, 3 (GeluTanh) -> 3.
// Every other value yields 0, which is what `!activation` evaluates to for
// any non-zero input.
static inline int aiter_act_to_ck(int activation)
{
    if (activation == 0) // Silu
    {
        return 1;
    }
    if (activation == 3) // GeluTanh
    {
        return 3;
    }
    // Gelu (1) and all remaining non-zero codes collapse to 0.
    return 0;
}

// API for user aiter.ck_moe_stage1(...)

template <int stage = 1>
Expand Down Expand Up @@ -108,7 +119,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token
K *= 2;
}

activation = !activation;
activation = aiter_act_to_ck(activation);

auto kernel = moe_dispatch<1>(kernelName, MPerBlock, N, hidden_states.dtype().toScalarType(), w1.dtype().toScalarType(), out.dtype().toScalarType(), activation, quant_type, MulRoutedWeight, is_shuffled);

Expand Down Expand Up @@ -172,7 +183,7 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token
K *= 2;
}

activation = !activation;
activation = aiter_act_to_ck(activation);
auto kernel = moe_dispatch<2>(kernelName, MPerBlock, K, inter_states.dtype().toScalarType(), w1.dtype().toScalarType(), out.dtype().toScalarType(), activation, quant_type, MulRoutedWeight, is_shuffled);

kernel(at::hip::getCurrentHIPStream(),
Expand Down
9 changes: 6 additions & 3 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from chip_info import get_gfx # noqa: E402

# Integer code used for each activation name.
ACT_TO_INT = dict(gelu=0, silu=1, gelu_tanh=3)
# Reverse lookup: integer code -> activation name.
INT_TO_ACT = {code: act_name for act_name, code in ACT_TO_INT.items()}


@dataclass
class kernelInstanceGEMM1:
Expand All @@ -29,7 +32,7 @@ class kernelInstanceGEMM1:
GemmPipelineVersion: int
Nswizzle: bool = False
MulRoutedWeight: bool = False
ActOP: bool = False
ActOP: int = 0
CDEElementOp: str = "TypeCast"
QuantType: int = 1
stage: int = 1
Expand Down Expand Up @@ -59,7 +62,7 @@ def name(self) -> str:
"Nswizzle" + str(int(self.Nswizzle)),
"Quant" + str(self.QuantType),
"MulRoutedWeight" + str(int(self.MulRoutedWeight)),
"silu" if self.ActOP else "gelu",
INT_TO_ACT[int(self.ActOP)],
self.Adtype.upper(),
self.Bdtype.upper(),
self.Cdtype.upper(),
Expand Down Expand Up @@ -410,7 +413,7 @@ def get_gemm1_kernels_list(
kernels_list = {k: copy.deepcopy(v) for k, v in gemm1_kernels_dict[tag].items()}
for id, kernel in kernels_list.items():
kernel.MulRoutedWeight = MulRoutedWeight
kernel.ActOP = ActOP == "silu"
kernel.ActOP = ACT_TO_INT[ActOP]
kernel.Nswizzle = Nswizzle
kernel.QuantType = QuantType
kernel.Adtype = Adtype
Expand Down
19 changes: 11 additions & 8 deletions csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import os
import argparse
import itertools
from gemm_moe_ck2stages_common import get_gemm1_kernels_list, get_gemm2_kernels_list
from gemm_moe_ck2stages_common import (
get_gemm1_kernels_list,
get_gemm2_kernels_list,
ACT_TO_INT,
)

STG_INSTANCE_IMPL = """// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
Expand Down Expand Up @@ -926,9 +930,7 @@ def generate_instance_and_lookUpTable(self):
Nswizzle=str(self.nswizzle).lower(),
Quant=self.quant_type,
ActOP=(
int(self.activation == "silu")
if kernel.stage == 1
else 0
ACT_TO_INT[self.activation] if kernel.stage == 1 else 0
),
Stage=kernel.stage,
BlockSize=kernel.BLOCK_SIZE,
Expand Down Expand Up @@ -958,7 +960,7 @@ def generate_instance_and_lookUpTable(self):
CDEElementOp=kernel.CDEElementOp,
Nswizzle=str(self.nswizzle).lower(),
Quant=self.quant_type,
ActOP=int(self.activation == "silu") if kernel.stage == 1 else 0,
ActOP=ACT_TO_INT[self.activation] if kernel.stage == 1 else 0,
Stage=kernel.stage,
BlockSize=kernel.BLOCK_SIZE,
MPerBlock=kernel.MPerBlock,
Expand Down Expand Up @@ -989,7 +991,7 @@ def generate_instance_and_lookUpTable(self):
CDEElementOp=kernel_list[0].CDEElementOp,
Nswizzle=str(self.nswizzle).lower(),
Quant=self.quant_type,
ActOP=str(int(self.activation == "silu")),
ActOP=str(ACT_TO_INT[self.activation]),
MulRoutedWeight=str(self.mul_routed_weight_stage == 1).lower(),
Preshuffle=str(self.preshuffle).lower(),
)
Expand Down Expand Up @@ -1071,7 +1073,7 @@ def generate_instance_and_lookUpTable(self):
default="silu",
required=False,
type=str,
choices=["silu", "gelu"],
choices=["silu", "gelu", "gelu_tanh"],
help="select activation",
)

Expand Down Expand Up @@ -1184,11 +1186,12 @@ def generate_instance_and_lookUpTable(self):
"f16",
"b16",
]
no_quant_acts = acts + ["gelu_tanh"]
for (
b_dtype,
act,
routed_weight,
) in itertools.product(b_quant_dtypes, acts, routed_weight_l):
) in itertools.product(b_quant_dtypes, no_quant_acts, routed_weight_l):
c_dtype = a_dtype = b_dtype

codegen = ck_moe_2stage_gemm_codegen(
Expand Down
9 changes: 5 additions & 4 deletions csrc/include/aiter_enum.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

// Activation function selector with fixed integer values.
// Each enumerator is declared exactly once; new members are appended with
// the next free value so existing values keep their meaning.
enum class ActivationType : int
{
    No       = -1,
    Silu     = 0,
    Gelu     = 1,
    Swiglu   = 2,
    GeluTanh = 3,
};
enum class QuantType : int
{
Expand Down
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace py = pybind11;
.value("Silu", ActivationType::Silu) \
.value("Gelu", ActivationType::Gelu) \
.value("Swiglu", ActivationType::Swiglu) \
.value("GeluTanh", ActivationType::GeluTanh) \
.export_values(); \
pybind11::enum_<aiter::MxScaleRoundMode>(m, "MxScaleRoundMode") \
.value("RoundDown", aiter::MxScaleRoundMode::RoundDown) \
Expand Down
Loading