diff --git a/csrc/include/aiter_opus_plus.h b/csrc/include/aiter_opus_plus.h index f6ec84e942..8e1ff90917 100644 --- a/csrc/include/aiter_opus_plus.h +++ b/csrc/include/aiter_opus_plus.h @@ -53,13 +53,9 @@ OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale) float a = tmp[0], b = tmp[1]; #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1200__) || \ defined(__gfx1201__) || defined(__gfx1250__) - int w; - asm volatile("v_med3_f32 %1, %1, %3, %4\n" - "v_med3_f32 %2, %2, %3, %4\n" - "v_cvt_pk_fp8_f32 %0, %1, %2" - : "=v"(w), "+v"(a), "+v"(b) - : "v"(lo), "v"(hi)); - return __builtin_bit_cast(fp8x2_t, static_cast(w)); + a = med3(a, lo, hi); + b = med3(b, lo, hi); + return fp32_to_fp8_packed_x2(fp32x2_t{a, b}); #else // Arches without packed fp8-cvt (RDNA3/3.5, host): compile-only stub. // fp8 KV-cache is unused on these arches; never executed at runtime. @@ -85,13 +81,9 @@ OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale) float a = tmp[0], b = tmp[1]; #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1200__) || \ defined(__gfx1201__) || defined(__gfx1250__) - int w; - asm volatile("v_med3_f32 %1, %1, %3, %4\n" - "v_med3_f32 %2, %2, %3, %4\n" - "v_cvt_pk_bf8_f32 %0, %1, %2" - : "=v"(w), "+v"(a), "+v"(b) - : "v"(lo), "v"(hi)); - return __builtin_bit_cast(bf8x2_t, static_cast(w)); + a = med3(a, lo, hi); + b = med3(b, lo, hi); + return fp32_to_bf8_packed_x2(fp32x2_t{a, b}); #else (void)a; (void)b; (void)lo; (void)hi; return bf8x2_t{}; #endif @@ -127,50 +119,22 @@ OPUS_D decltype(auto) fp32_to_i8_scaled_x4(const S& s, float inverted_scale) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // fp16x2 -> fp4 with scale (v_cvt_scalef32_pk_fp4_f16, gfx950 only) -// opus.hpp has fp32->fp4 and bf16->fp4 but NOT fp16->fp4 -#if defined(__gfx950__) +// delegates to opus fp16_to_fp4_packed_x2/x4/x8 (arch handling lives in opus) template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x2(const S& s, float scale, number = {}) { - u32_t w; - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, s, scale, sel); - return __builtin_bit_cast(array, static_cast(w)); + return fp16_to_fp4_packed_x2(s, scale, number{}); } template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x4(const S& s, float scale) { - u32_t w; - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[0], s[1]}, scale, 0); - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[2], s[3]}, scale, 1); - return __builtin_bit_cast(array, static_cast(w)); + return fp16_to_fp4_packed_x4(s, scale); } template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x8(const S& s, float scale) { - u32_t w; - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[0], s[1]}, scale, 0); - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[2], s[3]}, scale, 1); - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[4], s[5]}, scale, 2); - w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[6], s[7]}, scale, 3); - return __builtin_bit_cast(array, w); -} -#else -template , bool> = true> -OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x2(const S&, float) -{ - return array{}; + return fp16_to_fp4_packed_x8(s, scale); } -template , bool> = true> -OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x4(const S&, float) -{ - return array{}; -} -template , bool> = true> -OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x8(const S&, float) -{ - return array{}; -} -#endif // bf16 -> fp4 larger vectors (bf16x4/x8) using opus bf16_to_fp4_packed_x2 template , bool> = true> diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 08eecad214..32a6a86984 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -1253,6 +1253,11 @@ OPUS_D constexpr decltype(auto) fp8_to_fp32_packed_x4(const S& s) { auto x = __builtin_amdgcn_cvt_pk_f32_fp8(bitwise, 0); auto y = __builtin_amdgcn_cvt_pk_f32_fp8(bitwise, 1); return fp32x4_t{x[0], x[1], y[0], y[1]}; } +template, bool> = true> +OPUS_D constexpr decltype(auto) fp32_to_bf8_packed_x2(const S& s, number = {}) { + int w ; w = __builtin_amdgcn_cvt_pk_bf8_f32(s[0], s[1], w, sel); + return __builtin_bit_cast(bf8x2_t, static_cast(w)); +} namespace impl { template OPUS_D constexpr decltype(auto) fold_as_tuple_of_vec(const S& s, seq) { @@ -1340,7 +1345,23 @@ OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& s, float scale = return value.fp4_pack[0]; } template, bool> = true> -OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& s, float scale = 1.0f, number = {}) { return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(s, scale, sel); } +OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& s, float scale = 1.0f, number = {}) { return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(__builtin_bit_cast(u8_t, s), scale, sel); } +template, bool> = true> +OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x2(const S& s, float scale = 1.0f, number = {}) { + u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, s, scale, sel); + return __builtin_bit_cast(array, static_cast(w)); +} +template, bool> = true> +OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x4(const S& s, float scale = 1.0f) { + u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[0], s[1]}, scale, 0); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[2], s[3]}, scale, 1); + return __builtin_bit_cast(array, static_cast(w)); +} +template, bool> = true> +OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x8(const S& s, float scale = 1.0f) { + u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[0], s[1]}, scale, 0); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[2], s[3]}, scale, 1); + w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[4], s[5]}, scale, 2); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[6], s[7]}, scale, 3); + return __builtin_bit_cast(array, w); +} #elif defined(__gfx1250__) // gfx1250: pk8 builtins convert 8 fp4 <-> 8 f32 at once // f32->fp4: __builtin_amdgcn_cvt_scalef32_pk8_fp4_f32(v8f32 src, float scale) -> i32 @@ -1385,11 +1406,14 @@ OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& s, float scale = fp32x8_t r = __builtin_amdgcn_cvt_scale_pk8_f32_fp4(static_cast(__builtin_bit_cast(u32_t, s)), scale_e8m0, 0); return fp32x8_t{r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7]}; } -// bf16<->fp4 stubs for gfx1250 (no pk bf16<->fp4 builtins available) +// bf16<->fp4 and fp16->fp4 stubs for gfx1250 (no pk bf16<->fp4 / fp16->fp4 builtins available) template, bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number = {}) { return fp4_t{}; } template, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number = {}) { return bf16x2_t{}; } +template, bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number = {}) { return array{}; } +template, bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; } +template, bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; } #else template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; } template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; } @@ -1399,6 +1423,9 @@ template>, bool> template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f) { return fp32x8_t{}; } template, bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return fp4_t{}; } template, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return bf16x2_t{}; } +template, bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number = {}) { return array{}; } +template, bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; } +template, bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; } #endif #pragma clang diagnostic pop