Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 10 additions & 46 deletions csrc/include/aiter_opus_plus.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,9 @@ OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale)
float a = tmp[0], b = tmp[1];
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx1250__)
int w;
asm volatile("v_med3_f32 %1, %1, %3, %4\n"
"v_med3_f32 %2, %2, %3, %4\n"
"v_cvt_pk_fp8_f32 %0, %1, %2"
: "=v"(w), "+v"(a), "+v"(b)
: "v"(lo), "v"(hi));
return __builtin_bit_cast(fp8x2_t, static_cast<int16_t>(w));
a = med3<float>(a, lo, hi);
b = med3<float>(b, lo, hi);
return fp32_to_fp8_packed_x2(fp32x2_t{a, b});
#else
// Arches without packed fp8-cvt (RDNA3/3.5, host): compile-only stub.
// fp8 KV-cache is unused on these arches; never executed at runtime.
Expand All @@ -85,13 +81,9 @@ OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale)
float a = tmp[0], b = tmp[1];
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx1250__)
int w;
asm volatile("v_med3_f32 %1, %1, %3, %4\n"
"v_med3_f32 %2, %2, %3, %4\n"
"v_cvt_pk_bf8_f32 %0, %1, %2"
: "=v"(w), "+v"(a), "+v"(b)
: "v"(lo), "v"(hi));
return __builtin_bit_cast(bf8x2_t, static_cast<int16_t>(w));
a = med3<float>(a, lo, hi);
b = med3<float>(b, lo, hi);
return fp32_to_bf8_packed_x2(fp32x2_t{a, b});
#else
(void)a; (void)b; (void)lo; (void)hi; return bf8x2_t{};
#endif
Expand Down Expand Up @@ -127,50 +119,22 @@ OPUS_D decltype(auto) fp32_to_i8_scaled_x4(const S& s, float inverted_scale)

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// fp16x2 -> fp4 with scale (v_cvt_scalef32_pk_fp4_f16, gfx950 only)
// opus.hpp provides the packed fp32->fp4, bf16->fp4 and fp16->fp4 conversions; the wrappers below keep the *_scaled_* names
#if defined(__gfx950__)
// delegates to opus fp16_to_fp4_packed_x2/x4/x8 (arch handling lives in opus)
// fp16x2 -> fp4 with scale.
// Thin wrapper kept for the *_scaled_* naming: the conversion is implemented by
// opus fp16_to_fp4_packed_x2, so there is a single copy of the intrinsic code
// and no local scratch word that is read before being written.
//   s     : the fp16x2_t pair to convert
//   scale : forwarded unchanged to the opus helper
//   sel   : forwarded unchanged as number<sel>
template <typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp16x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x2(const S& s, float scale, number<sel> = {})
{
    return fp16_to_fp4_packed_x2(s, scale, number<sel>{});
}
// fp16x4 -> fp4 with scale.
// Delegates to opus fp16_to_fp4_packed_x4 instead of repeating the intrinsic
// sequence here; the previous inline version passed an uninitialized word as
// the first builtin argument.
//   s     : the fp16x4_t vector to convert
//   scale : forwarded unchanged to the opus helper
template <typename S, std::enable_if_t<std::is_same_v<S, fp16x4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x4(const S& s, float scale)
{
    return fp16_to_fp4_packed_x4(s, scale);
}
// fp16x8 -> fp4 with scale.
// Delegates to opus fp16_to_fp4_packed_x8, matching the x2/x4 wrappers, so the
// intrinsic sequence is maintained in one place and no uninitialized scratch
// word is read here.
//   s     : the fp16x8_t vector to convert
//   scale : forwarded unchanged to the opus helper
template <typename S, std::enable_if_t<std::is_same_v<S, fp16x8_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x8(const S& s, float scale)
{
    return fp16_to_fp4_packed_x8(s, scale);
}
#else
template <typename S, std::enable_if_t<std::is_same_v<S, fp16x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x2(const S&, float)
{
return array<fp4_t, 1>{};
return fp16_to_fp4_packed_x8(s, scale);
}
// Stub used when __gfx950__ is not defined: both arguments are ignored and a
// value-initialized array<fp4_t, 2> is returned.
template <typename S, std::enable_if_t<std::is_same_v<S, fp16x4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x4(const S&, float)
{
    using packed_t = array<fp4_t, 2>;
    return packed_t{};
}
// Stub used when __gfx950__ is not defined: both arguments are ignored and a
// value-initialized array<fp4_t, 4> is returned.
template <typename S, std::enable_if_t<std::is_same_v<S, fp16x8_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x8(const S&, float)
{
    using packed_t = array<fp4_t, 4>;
    return packed_t{};
}
#endif

// bf16 -> fp4 larger vectors (bf16x4/x8) using opus bf16_to_fp4_packed_x2
template <typename S, std::enable_if_t<std::is_same_v<S, bf16x4_t>, bool> = true>
Expand Down
31 changes: 29 additions & 2 deletions csrc/include/opus/opus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,11 @@ OPUS_D constexpr decltype(auto) fp8_to_fp32_packed_x4(const S& s) {
auto x = __builtin_amdgcn_cvt_pk_f32_fp8(bitwise, 0); auto y = __builtin_amdgcn_cvt_pk_f32_fp8(bitwise, 1);
return fp32x4_t{x[0], x[1], y[0], y[1]};
}
// fp32x2 -> bf8x2 via __builtin_amdgcn_cvt_pk_bf8_f32.
//   s   : the two fp32 values to convert (s[0], s[1])
//   sel : forwarded as the builtin's last argument
// The builtin takes a previous-value word as its third argument; pass a
// defined 0 rather than an uninitialized local so no indeterminate value is read.
// NOTE(review): only the low 16 bits of the result word are returned; confirm
// that is the intended half when sel != 0.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp32x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp32_to_bf8_packed_x2(const S& s, number<sel> = {}) {
    const int packed = __builtin_amdgcn_cvt_pk_bf8_f32(s[0], s[1], 0, sel);
    return __builtin_bit_cast(bf8x2_t, static_cast<short>(packed));
}

namespace impl {
template<typename S, index_t... Xs> OPUS_D constexpr decltype(auto) fold_as_tuple_of_vec(const S& s, seq<Xs...>) {
Expand Down Expand Up @@ -1340,7 +1345,23 @@ OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& s, float scale =
return value.fp4_pack[0];
}
// fp4 -> bf16x2 with scale via __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4.
//   s     : the packed fp4 value to convert
//   scale : passed straight to the builtin (defaults to 1.0f)
//   sel   : forwarded as the builtin's last argument
// The fp4_t value is bit-cast to u8_t before the call so the builtin receives
// an integer operand instead of the fp4_t wrapper type.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& s, float scale = 1.0f, number<sel> = {}) {
    return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(__builtin_bit_cast(u8_t, s), scale, sel);
}
// fp16x2 -> array<fp4_t, 1> with scale via __builtin_amdgcn_cvt_scalef32_pk_fp4_f16.
//   s     : the fp16 pair to convert
//   scale : passed straight to the builtin (defaults to 1.0f)
//   sel   : forwarded as the builtin's last argument
// The builtin's first argument is a previous-value word; seed it with 0 so the
// call never reads an uninitialized variable.
// NOTE(review): only the low byte of the result word is returned; confirm that
// is the intended byte when sel != 0.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp16x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x2(const S& s, float scale = 1.0f, number<sel> = {}) {
    const u32_t packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(u32_t{0}, s, scale, sel);
    return __builtin_bit_cast(array<fp4_t, 1>, static_cast<u8_t>(packed));
}
// fp16x4 -> array<fp4_t, 2> with scale.
// Converts the input two elements at a time, threading the same 32-bit word
// through successive builtin calls with selector 0 then 1, and returns the low
// 16 bits of that word.
//   s     : the four fp16 values to convert
//   scale : passed straight to each builtin call (defaults to 1.0f)
// The accumulator starts at 0 so the first call does not read an
// uninitialized variable.
template<typename S, std::enable_if_t<std::is_same_v<S, fp16x4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x4(const S& s, float scale = 1.0f) {
    u32_t packed = 0;
    packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(packed, fp16x2_t{s[0], s[1]}, scale, 0);
    packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(packed, fp16x2_t{s[2], s[3]}, scale, 1);
    return __builtin_bit_cast(array<fp4_t, 2>, static_cast<u16_t>(packed));
}
// fp16x8 -> array<fp4_t, 4> with scale.
// Converts the input two elements at a time, threading the same 32-bit word
// through four builtin calls with selectors 0..3, and returns the whole word.
//   s     : the eight fp16 values to convert
//   scale : passed straight to each builtin call (defaults to 1.0f)
// The accumulator starts at 0 so the first call does not read an
// uninitialized variable.
template<typename S, std::enable_if_t<std::is_same_v<S, fp16x8_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x8(const S& s, float scale = 1.0f) {
    u32_t packed = 0;
    packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(packed, fp16x2_t{s[0], s[1]}, scale, 0);
    packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(packed, fp16x2_t{s[2], s[3]}, scale, 1);
    packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(packed, fp16x2_t{s[4], s[5]}, scale, 2);
    packed = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(packed, fp16x2_t{s[6], s[7]}, scale, 3);
    return __builtin_bit_cast(array<fp4_t, 4>, packed);
}
#elif defined(__gfx1250__)
// gfx1250: pk8 builtins convert 8 fp4 <-> 8 f32 at once
// f32->fp4: __builtin_amdgcn_cvt_scalef32_pk8_fp4_f32(v8f32 src, float scale) -> i32
Expand Down Expand Up @@ -1385,11 +1406,14 @@ OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& s, float scale =
fp32x8_t r = __builtin_amdgcn_cvt_scale_pk8_f32_fp4(static_cast<i32_t>(__builtin_bit_cast(u32_t, s)), scale_e8m0, 0);
return fp32x8_t{r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7]};
}
// bf16<->fp4 and fp16->fp4 stubs for gfx1250 (no pk bf16<->fp4 / fp16->fp4 builtins available)
// gfx1250 stub for bf16x2 -> fp4: every argument is ignored and a
// value-initialized fp4_t is returned.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, bf16x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number<sel> = {})
{
    using result_t = fp4_t;
    return result_t{};
}
// gfx1250 stub for fp4 -> bf16x2: every argument is ignored and a
// value-initialized bf16x2_t is returned.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number<sel> = {})
{
    using result_t = bf16x2_t;
    return result_t{};
}
// gfx1250 stubs for fp16 -> fp4 (x2 / x4 / x8): each ignores its arguments and
// returns a value-initialized fp4 array of the matching length.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp16x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number<sel> = {})
{
    using packed_t = array<fp4_t, 1>;
    return packed_t{};
}
template<typename S, std::enable_if_t<std::is_same_v<S, fp16x4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using packed_t = array<fp4_t, 2>;
    return packed_t{};
}
template<typename S, std::enable_if_t<std::is_same_v<S, fp16x8_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using packed_t = array<fp4_t, 4>;
    return packed_t{};
}
#else
// Fallback stubs for targets not handled by the arch-specific branches above:
// fp32 -> fp4 (x2 / x4). Arguments are ignored; a value-initialized fp4 array
// of the matching length is returned.
template<typename S, std::enable_if_t<std::is_same_v<S, fp32x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using packed_t = array<fp4_t, 1>;
    return packed_t{};
}
template<typename S, std::enable_if_t<std::is_same_v<S, fp32x4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using packed_t = array<fp4_t, 2>;
    return packed_t{};
}
Expand All @@ -1399,6 +1423,9 @@ template<typename S, std::enable_if_t<std::is_same_v<S, array<fp4_t, 2>>, bool>
// Fallback stub for targets not handled by the arch-specific branches above:
// fp4 x8 -> fp32x8. Arguments are ignored; a value-initialized fp32x8_t is returned.
template<typename S, std::enable_if_t<std::is_same_v<S, array<fp4_t, 4>>, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using result_t = fp32x8_t;
    return result_t{};
}
template<typename S, std::enable_if_t<std::is_same_v<S, bf16x2_t>, bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return fp4_t{}; }
template<typename S, std::enable_if_t<std::is_same_v<S, fp4_t>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return bf16x2_t{}; }
// Fallback stubs for targets not handled by the arch-specific branches above:
// fp16 -> fp4 (x2 / x4 / x8). Arguments are ignored; a value-initialized fp4
// array of the matching length is returned.
template<typename S, index_t sel = 0, std::enable_if_t<std::is_same_v<S, fp16x2_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number<sel> = {})
{
    using packed_t = array<fp4_t, 1>;
    return packed_t{};
}
template<typename S, std::enable_if_t<std::is_same_v<S, fp16x4_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using packed_t = array<fp4_t, 2>;
    return packed_t{};
}
template<typename S, std::enable_if_t<std::is_same_v<S, fp16x8_t>, bool> = true>
OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f)
{
    using packed_t = array<fp4_t, 4>;
    return packed_t{};
}
#endif
#pragma clang diagnostic pop

Expand Down
Loading