diff --git a/README.md b/README.md index 88401a94..ca0e25ac 100644 --- a/README.md +++ b/README.md @@ -37,9 +37,9 @@ how fast this is relative to `std::sort`. ## Sort an array of built-in integers and floats ```cpp -void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending); -void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending); -void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending); +void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending, bool nans_last); +void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending, bool nans_last); +void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending, bool nans_last); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` @@ -56,8 +56,8 @@ data types. ## Arg sort routines on arrays ```cpp -std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending); -std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan); +std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending, bool nans_last); +std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan, bool descending, bool nans_last); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` Note that argsort and argselect are not accelerated with SIMD when using 16-bit @@ -174,13 +174,22 @@ Supported datatypes: `uint16_t, int16_t, _Float16, uint32_t, int32_t, float, uint64_t, int64_t, double`. Note that `_Float16` will require building this library with g++ >= 12.x. All the functions have an optional argument `bool hasnan` set to `false` by default (these are relevant to floating point data -types only). 
If your array has NAN's, the the behaviour of the sorting routine -is undefined. If `hasnan` is set to true, NAN's are always sorted to the end of -the array. In addition to that, qsort will replace all your NAN's with -`std::numeric_limits::quiet_NaN`. The original bit-exact NaNs in -the input are not preserved. Also note that the arg methods (argsort and -argselect) will not use the SIMD based algorithms if they detect NAN's in the -array. You can read details of all the implementations +types only). If your array has NaN values, the behaviour of the sorting routine +is undefined unless `hasnan` is set to `true`. When `hasnan=true`, NaN placement +is controlled by the optional `bool nans_last` parameter (default `true`): + +- `nans_last=true` (default): NaN values are placed at the **end** of the + result, regardless of sort direction. +- `nans_last=false`: NaN values are placed at the **beginning** of the + result, regardless of sort direction. + +All routines accept an optional `bool descending` parameter (default `false`). +When `descending=true`, results are in descending order. For `argselect`, the +k-th element becomes the k-th **largest**, with all elements before index k +being greater than or equal to it. + +Note that the arg methods (argsort and argselect) will not use the SIMD based +algorithms if they detect NaN values in the array. You can read details of all the implementations [here](https://github.com/intel/x86-simd-sort/blob/main/src/README.md). 
## Performance comparison on AVX-512: `object_qsort` v/s `std::sort` diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 36b43e89..afac33c7 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -5,33 +5,57 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - x86simdsortStatic::qsort(arr, arrsize, hasnan, descending); \ + x86simdsortStatic::qsort( \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - void qselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::qselect( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - void partial_qsort( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::partial_qsort( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - std::vector argsort( \ - const type *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort(const type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ + return x86simdsortStatic::argsort( \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - std::vector argselect( \ - const type *arr, size_t k, size_t arrsize, bool hasnan) \ + std::vector argselect(const type *arr, \ + size_t k, \ + size_t arrsize, \ + bool 
hasnan, \ + bool descending, \ + bool nans_last) \ { \ - return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ + return x86simdsortStatic::argselect( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 3e5c4b5b..d1d9adff 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -8,76 +8,100 @@ namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending) + void qsort(uint16_t *arr, + size_t size, + bool hasnan, + bool descending, + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, nans_last); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan, bool descending) + void qsort(int16_t *arr, + size_t size, + bool hasnan, + bool descending, + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, nans_last); } template <> void 
partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, nans_last); } } // namespace avx512 namespace fp16_icl { #ifdef __FLT16_MAX__ template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) + void qsort(_Float16 *arr, + size_t size, + bool hasnan, + bool descending, + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, nans_last); } #endif } // namespace fp16_icl diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 055df2bd..d6e30b5e 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -10,7 +10,8 @@ XSS_HIDE_SYMBOL void qsort(T *arr, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \ T2 *val, \ @@ -22,7 +23,8 @@ size_t k, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \ T2 *val, \ @@ -35,7 +37,8 @@ size_t k, \ size_t arrsize, \ bool hasnan = 
false, \ - bool descending = false); \ + bool descending = false, \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \ T2 *val, \ @@ -47,10 +50,15 @@ XSS_HIDE_SYMBOL std::vector argsort(const T *arr, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool nans_last = true); \ template \ - XSS_HIDE_SYMBOL std::vector \ - argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); \ + XSS_HIDE_SYMBOL std::vector argselect(const T *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false, \ + bool nans_last = true); \ } namespace xss { diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 9f08f9b2..7faf4446 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -24,13 +24,24 @@ namespace utils { } } template - decltype(auto) get_cmp_func(bool hasnan, bool reverse) + decltype(auto) + get_cmp_func(bool hasnan, bool reverse, bool nans_last = true) { std::function cmp; if (hasnan) { - if (reverse == true) { cmp = compare>(); } + if (nans_last) { + if (reverse == true) { + cmp = compare_nan_end>(); + } + else { + cmp = compare>(); + } + } else { - cmp = compare>(); + if (reverse == true) { cmp = compare>(); } + else { + cmp = compare_nan_begin>(); + } } } else { @@ -45,66 +56,116 @@ namespace utils { namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + void qsort(T *arr, + size_t arrsize, + bool hasnan, + bool reversed, + bool nans_last) { std::sort(arr, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + xss::utils::get_cmp_func(hasnan, reversed, nans_last)); } template - void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool reversed, + bool nans_last) { - std::nth_element(arr, - arr + k, - arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); 
+ std::nth_element( + arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed, nans_last)); } template - void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool reversed, + bool nans_last) { - std::partial_sort(arr, - arr + k, - arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + std::partial_sort( + arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed, nans_last)); } template - std::vector - argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed) + std::vector argsort(const T *arr, + size_t arrsize, + bool hasnan, + bool reversed, + bool nans_last) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - if (reversed) { - std::sort(arg.begin(), - arg.end(), - compare_arg>(arr)); + if (nans_last) { + if (reversed) { + std::sort(arg.begin(), + arg.end(), + compare_arg_nan_end>(arr)); + } + else { + std::sort(arg.begin(), + arg.end(), + compare_arg>(arr)); + } } else { - std::sort( - arg.begin(), arg.end(), compare_arg>(arr)); + if (reversed) { + std::sort(arg.begin(), + arg.end(), + compare_arg>(arr)); + } + else { + std::sort(arg.begin(), + arg.end(), + compare_arg_nan_begin>(arr)); + } } return arg; } template - std::vector - argselect(const T *arr, size_t k, size_t arrsize, bool hasnan) + std::vector argselect(const T *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending, + bool nans_last) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - std::nth_element(arg.begin(), - arg.begin() + k, - arg.end(), - compare_arg>(arr)); + std::function cmp; + if (nans_last) { + if (descending) { + cmp = compare_arg_nan_end>(arr); + } + else { + cmp = compare_arg>(arr); + } + } + else { + if (descending) { cmp = compare_arg>(arr); } + else { + cmp = compare_arg_nan_begin>(arr); + } + } + std::nth_element(arg.begin(), arg.begin() + k, 
arg.end(), cmp); return arg; } template void keyvalue_qsort( T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending) { - std::vector arg = argsort(key, arrsize, hasnan, descending); + std::vector arg + = argsort(key, arrsize, hasnan, descending, true); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 6260bab7..b38653e7 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -5,33 +5,57 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - x86simdsortStatic::qsort(arr, arrsize, hasnan, descending); \ + x86simdsortStatic::qsort( \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - void qselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::qselect( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - void partial_qsort( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::partial_qsort( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - std::vector argsort( \ - const type *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort(const type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ 
+ return x86simdsortStatic::argsort( \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ - std::vector argselect( \ - const type *arr, size_t k, size_t arrsize, bool hasnan) \ + std::vector argselect(const type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ + return x86simdsortStatic::argselect( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index 7587640a..9e8b2887 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -5,27 +5,35 @@ namespace xss { namespace fp16_spr { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) + void qsort(_Float16 *arr, + size_t size, + bool hasnan, + bool descending, + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, nans_last); } } // namespace fp16_spr } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 776ec56d..c90661e0 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -101,122 +101,172 @@ namespace x86simdsort { #ifdef _MSC_VER #define DECLARE_INTERNAL_qsort(TYPE) \ static void CAT(resolve_qsort, TYPE)(void); \ - 
static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool, bool) \ + = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qsort( \ - TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qsort(TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ if (internal_qsort##TYPE == NULL) { CAT(resolve_qsort, TYPE)(); } \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ + (*internal_qsort##TYPE)( \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ static void CAT(resolve_qselect, TYPE)(void); \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_qselect##TYPE)( \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qselect( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qselect(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ if (internal_qselect##TYPE == NULL) { CAT(resolve_qselect, TYPE)(); } \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_qselect##TYPE)( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ static void CAT(resolve_partial_qsort, TYPE)(void); \ static void (*internal_partial_qsort##TYPE)( \ - TYPE *, size_t, size_t, bool, bool) \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL partial_qsort( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ if (internal_partial_qsort##TYPE == NULL) { \ CAT(resolve_partial_qsort, TYPE)(); \ } \ - 
(*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_partial_qsort##TYPE)( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ static void CAT(resolve_argsort, TYPE)(void); \ static std::vector (*internal_argsort##TYPE)( \ - const TYPE *, size_t, bool, bool) \ + const TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argsort( \ - const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector XSS_EXPORT_SYMBOL argsort(const TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ if (internal_argsort##TYPE == NULL) { CAT(resolve_argsort, TYPE)(); } \ - return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ + return (*internal_argsort##TYPE)( \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ static void CAT(resolve_argselect, TYPE)(void); \ static std::vector (*internal_argselect##TYPE)( \ - const TYPE *, size_t, size_t, bool) \ + const TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argselect( \ - const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + std::vector XSS_EXPORT_SYMBOL argselect(const TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ if (internal_argselect##TYPE == NULL) { \ CAT(resolve_argselect, TYPE)(); \ } \ - return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \ + return (*internal_argselect##TYPE)( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #else #define DECLARE_INTERNAL_qsort(TYPE) \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool, bool) \ + = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qsort( \ - TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL 
qsort(TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ + (*internal_qsort##TYPE)( \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_qselect##TYPE)( \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qselect( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qselect(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_qselect##TYPE)( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ static void (*internal_partial_qsort##TYPE)( \ - TYPE *, size_t, size_t, bool, bool) \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL partial_qsort( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_partial_qsort##TYPE)( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ static std::vector (*internal_argsort##TYPE)( \ - const TYPE *, size_t, bool, bool) \ + const TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argsort( \ - const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector XSS_EXPORT_SYMBOL argsort(const TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - return (*internal_argsort##TYPE)(arr, 
arrsize, hasnan, descending); \ + return (*internal_argsort##TYPE)( \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ static std::vector (*internal_argselect##TYPE)( \ - const TYPE *, size_t, size_t, bool) \ + const TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argselect( \ - const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + std::vector XSS_EXPORT_SYMBOL argselect(const TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \ + return (*internal_argselect##TYPE)( \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #endif // _MSC_VER diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index e30120ef..37c7bb6e 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -19,8 +19,11 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void -qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void qsort(T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false, + bool nans_last = true); // quickselect template @@ -28,7 +31,8 @@ XSS_EXPORT_SYMBOL void qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); // partial sort template @@ -36,19 +40,25 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); // argsort template XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, size_t arrsize, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); // argselect template -XSS_EXPORT_SYMBOL std::vector -argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); +XSS_EXPORT_SYMBOL std::vector argselect(const T *arr, 
+ size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false, + bool nans_last = true); // keyvalue sort template diff --git a/src/README.md b/src/README.md index ad5fc7ba..a89f2d08 100644 --- a/src/README.md +++ b/src/README.md @@ -18,13 +18,14 @@ Equivalent to `qsort` in `std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort). ```cpp -void x86simdsortStatic::qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `arr` contains -NaNs, they are moved to the end and replaced with a quiet NaN. That is, the -original, bit-exact NaNs in the input are not preserved. +NaNs and `hasnan` is set, their placement is controlled by `nans_last`: `true` +(default) places them at the end; `false` at the beginning. Bit-exact NaN payloads +are preserved. If NaNs are present but `hasnan` is `false`, behavior is undefined. #### Quickselect Equivalent to `std::nth_element` in @@ -34,13 +35,14 @@ Equivalent to `std::nth_element` in ```cpp -void x86simdsortStatic::qselect(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::qselect(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `bool hasnan` is -set, NaNs are moved to the end of the array, preserving the bit-exact NaNs in -the input. 
If NaNs are present but `hasnan` is `false`, the behavior is +set, NaN placement is controlled by `nans_last`: `true` (default) places +NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads +are preserved. If NaNs are present but `hasnan` is `false`, the behavior is undefined. #### Partialsort @@ -49,13 +51,14 @@ Equivalent to `std::partial_sort` in ```cpp -void x86simdsortStatic::partial_qsort(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false) +void x86simdsortStatic::partial_qsort(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `bool hasnan` is -set, NaNs are moved to the end of the array, preserving the bit-exact NaNs in -the input. If NaNs are present but `hasnan` is `false`, the behavior is +set, NaN placement is controlled by `nans_last`: `true` (default) places +NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads +are preserved. If NaNs are present but `hasnan` is `false`, the behavior is undefined. #### Argsort @@ -63,7 +66,7 @@ Equivalent to `np.argsort` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html). ```cpp -void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. @@ -75,12 +78,17 @@ Equivalent to `np.argselect` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html). 
```cpp -void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false); +void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. -The algorithm resorts to scalar `std::sort` if the array contains NaNs. +When `descending=true`, the k-th element is the k-th **largest** and elements +before index k are all greater than or equal to it. For floating-point types, +if `bool hasnan` is set, NaN placement is controlled by `nans_last`: +`true` (default) places NaNs at the end; `false` places them at the beginning. + +The algorithm resorts to scalar `std::nth_element` if the array contains NaNs. #### Key-value sort ```cpp diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index fbe18567..f8c327af 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -519,41 +519,24 @@ comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } -template <> -X86_SIMD_SORT_INLINE_ONLY arrsize_t -replace_nan_with_inf>(uint16_t *arr, arrsize_t arrsize) -{ - arrsize_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - for (arrsize_t ii = 0; ii < arrsize; - ii = ii + zmm_vector::numlanes / 2) { - if (arrsize - ii < 16) { - loadmask = (0x0001 << (arrsize - ii)) - 0x0001; - } - __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr); - __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm); - __mmask16 nanmask = _mm512_cmp_ps_mask( - in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF); - arr += 16; - } - return nan_count; -} - template [[maybe_unused]] X86_SIMD_SORT_INLINE void -avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) +avx512_qsort_fp16_helper(uint16_t *arr, + arrsize_t 
arrsize, + arrsize_t index_first_elem, + arrsize_t index_last_elem) { using T = uint16_t; using vtype = zmm_vector; #ifdef XSS_COMPILE_OPENMP - bool use_parallel = arrsize > 100000; + bool use_parallel = (index_last_elem - index_first_elem + 1) > 100000; if (use_parallel) { int thread_count = xss_get_num_threads(); - arrsize_t task_threshold = std::max((arrsize_t)100000, arrsize / 100); + arrsize_t task_threshold + = std::max((arrsize_t)100000, + (index_last_elem - index_first_elem + 1) / 100); // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems @@ -561,22 +544,25 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) #pragma omp parallel num_threads(thread_count) #pragma omp single qsort_(arr, - 0, - arrsize - 1, + index_first_elem, + index_last_elem, 2 * (arrsize_t)log2(arrsize), task_threshold); } else { qsort_(arr, - 0, - arrsize - 1, + index_first_elem, + index_last_elem, 2 * (arrsize_t)log2(arrsize), std::numeric_limits::max()); } #pragma omp taskwait #else - qsort_( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); + qsort_(arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + 0); #endif } @@ -584,22 +570,32 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool nans_last = true) { using vtype = zmm_vector; if (arrsize > 1) { - arrsize_t nan_count = 0; + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf(arr, arrsize); + if (!nans_last) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } - if (descending) { - avx512_qsort_fp16_helper>(arr, 
arrsize); - } - else { - avx512_qsort_fp16_helper>(arr, arrsize); + if (index_first_elem <= index_last_elem && index_last_elem < arrsize) { + if (descending) { + avx512_qsort_fp16_helper>( + arr, arrsize, index_first_elem, index_last_elem); + } + else { + avx512_qsort_fp16_helper>( + arr, arrsize, index_first_elem, index_last_elem); + } } - replace_inf_with_nan(arr, arrsize, nan_count, descending); } #ifdef __MMX__ @@ -613,7 +609,8 @@ avx512_qselect_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool nans_last = true) { using vtype = zmm_vector; @@ -624,7 +621,7 @@ avx512_qselect_fp16(uint16_t *arr, arrsize_t index_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - if (descending) { + if (!nans_last) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -662,10 +659,11 @@ avx512_partial_qsort_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool nans_last = true) { if (k == 0) return; - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending); - avx512_qsort_fp16(arr, k - 1, hasnan, descending); + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending, nans_last); + avx512_qsort_fp16(arr, k - 1, hasnan, descending, nans_last); } #endif // AVX512_QSORT_16BIT diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8f85e599..d4bebe74 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -175,26 +175,4 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<_Float16>(_Float16 elem) return elem != elem; } -template <> -X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, - arrsize_t size, - arrsize_t nan_count, - bool descending) -{ - Fp16Bits val; - val.i_ = 0x7c01; - - if (descending) { - for (arrsize_t ii = 0; nan_count > 0; ++ii) { - arr[ii] = val.f_; - nan_count -= 1; - } - } - else { - for (arrsize_t ii = size - 1; 
nan_count > 0; --ii) { - arr[ii] = val.f_; - nan_count -= 1; - } - } -} #endif // AVX512FP16_QSORT_16BIT diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 2b0a11e0..bdc91baf 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -7,28 +7,34 @@ // Supported methods declared here for a quick reference: namespace x86simdsortStatic { template -X86_SIMD_SORT_FINLINE void -qsort(T *arr, size_t size, bool hasnan = false, bool descending = false); +X86_SIMD_SORT_FINLINE void qsort(T *arr, + size_t size, + bool hasnan = false, + bool descending = false, + bool nans_last = true); template X86_SIMD_SORT_FINLINE void qselect(T *arr, size_t k, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); template X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, size_t k, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); template X86_SIMD_SORT_FINLINE std::vector argsort(const T *arr, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); /* argsort API required by NumPy: */ template @@ -36,16 +42,26 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr, size_t *arg, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool nans_last = true); template -X86_SIMD_SORT_FINLINE std::vector -argselect(const T *arr, size_t k, size_t size, bool hasnan = false); +X86_SIMD_SORT_FINLINE std::vector argselect(const T *arr, + size_t k, + size_t size, + bool hasnan = false, + bool descending = false, + bool nans_last = true); /* argselect API required by NumPy: */ template -void X86_SIMD_SORT_FINLINE argselect( - const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false); +void X86_SIMD_SORT_FINLINE argselect(const T *arr, + size_t *arg, + size_t k, + size_t size, + bool hasnan = false, + bool descending = false, 
+ bool nans_last = true); template X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, @@ -74,55 +90,89 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, #define XSS_METHODS(ISA) \ template \ - X86_SIMD_SORT_FINLINE void x86simdsortStatic::qsort( \ - T *arr, size_t size, bool hasnan, bool descending) \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::qsort(T *arr, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - ISA##_qsort(arr, size, hasnan, descending); \ + ISA##_qsort(arr, size, hasnan, descending, nans_last); \ } \ template \ - X86_SIMD_SORT_FINLINE void x86simdsortStatic::qselect( \ - T *arr, size_t k, size_t size, bool hasnan, bool descending) \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::qselect(T *arr, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - ISA##_qselect(arr, k, size, hasnan, descending); \ + ISA##_qselect(arr, k, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::partial_qsort( \ - T *arr, size_t k, size_t size, bool hasnan, bool descending) \ + T *arr, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - ISA##_partial_qsort(arr, k, size, hasnan, descending); \ + ISA##_partial_qsort(arr, k, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \ size_t *arg, \ size_t size, \ bool hasnan, \ - bool descending) \ + bool descending, \ + bool nans_last) \ { \ - ISA##_argsort(arr, arg, size, hasnan, descending); \ + ISA##_argsort(arr, arg, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ - const T *arr, size_t size, bool hasnan, bool descending) \ + const T *arr, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ 
x86simdsortStatic::argsort( \ - arr, indices.data(), size, hasnan, descending); \ + arr, indices.data(), size, hasnan, descending, nans_last); \ return indices; \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argselect( \ - const T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \ + const T *arr, \ + size_t *arg, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ - ISA##_argselect(arr, arg, k, size, hasnan); \ + ISA##_argselect(arr, arg, k, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argselect( \ - const T *arr, size_t k, size_t size, bool hasnan) \ + const T *arr, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool nans_last) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ - x86simdsortStatic::argselect(arr, indices.data(), k, size, hasnan); \ + x86simdsortStatic::argselect(arr, \ + indices.data(), \ + k, \ + size, \ + hasnan, \ + descending, \ + nans_last); \ return indices; \ } \ template \ @@ -185,23 +235,34 @@ template <> void x86simdsortStatic::qsort<_Float16>(_Float16 *arr, size_t size, bool hasnan, - bool descending) + bool descending, + bool nans_last) { - avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending); + avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending, nans_last); } template <> [[maybe_unused]] -void x86simdsortStatic::qselect<_Float16>( - _Float16 *arr, size_t k, size_t size, bool hasnan, bool descending) +void x86simdsortStatic::qselect<_Float16>(_Float16 *arr, + size_t k, + size_t size, + bool hasnan, + bool descending, + bool nans_last) { - avx512_qselect_fp16((uint16_t *)arr, k, size, hasnan, descending); + avx512_qselect_fp16( + (uint16_t *)arr, k, size, hasnan, descending, nans_last); } template <> [[maybe_unused]] -void x86simdsortStatic::partial_qsort<_Float16>( - _Float16 *arr, size_t k, size_t size, bool hasnan, bool descending) 
+void x86simdsortStatic::partial_qsort<_Float16>(_Float16 *arr, + size_t k, + size_t size, + bool hasnan, + bool descending, + bool nans_last) { - avx512_partial_qsort_fp16((uint16_t *)arr, k, size, hasnan, descending); + avx512_partial_qsort_fp16( + (uint16_t *)arr, k, size, hasnan, descending, nans_last); } #endif diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 1bec821b..b4fd433a 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -15,44 +15,46 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, - arrsize_t right) + arrsize_t right, + bool nans_last = true, + bool descending = false) { - std::nth_element(arg + left, - arg + k, - arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); + std::nth_element( + arg + left, + arg + k, + arg + right, + [arr, nans_last, descending](arrsize_t a, arrsize_t b) -> bool { + bool a_nan = std::isnan(arr[a]); + bool b_nan = std::isnan(arr[b]); + if (!a_nan && !b_nan) { + return descending ? arr[a] > arr[b] : arr[a] < arr[b]; + } + if (a_nan && b_nan) { return false; } + return nans_last ? 
!a_nan : a_nan; + }); } -/* argsort using std::sort */ +/* argsort using std::sort, handles NaN placement and descending order */ template X86_SIMD_SORT_INLINE void std_argsort_withnan(const T *arr, arrsize_t *arg, arrsize_t left, - arrsize_t right) + arrsize_t right, + bool nans_last = true, + bool descending = false) { - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; - } - }); + std::sort( + arg + left, + arg + right, + [arr, nans_last, descending](arrsize_t a, arrsize_t b) -> bool { + bool a_nan = std::isnan(arr[a]); + bool b_nan = std::isnan(arr[b]); + if (!a_nan && !b_nan) { + return descending ? arr[a] > arr[b] : arr[a] < arr[b]; + } + if (a_nan && b_nan) { return false; } + return nans_last ? !a_nan : a_nan; + }); } /* argsort using std::sort */ @@ -599,7 +601,8 @@ X86_SIMD_SORT_INLINE void xss_argsort(const T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool nans_last = true) { using vectype = typename std::conditional) { if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - - if (descending) { std::reverse(arg, arg + arrsize); } - + std_argsort_withnan( + arr, arg, 0, arrsize, nans_last, descending); return; } } UNUSED(hasnan); + UNUSED(nans_last); /* early exit for already sorted arrays: float/double with nan never reach here*/ auto comp = descending ? 
Comparator::STDSortComparator @@ -678,10 +680,11 @@ X86_SIMD_SORT_INLINE void avx512_argsort(const T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool nans_last = true) { xss_argsort( - arr, arg, arrsize, hasnan, descending); + arr, arg, arrsize, hasnan, descending, nans_last); } template @@ -689,10 +692,11 @@ X86_SIMD_SORT_INLINE void avx2_argsort(const T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool nans_last = true) { xss_argsort( - arr, arg, arrsize, hasnan, descending); + arr, arg, arrsize, hasnan, descending, nans_last); } /* argselect methods for 32-bit and 64-bit dtypes */ @@ -705,7 +709,9 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false, + bool nans_last = true) { /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ using vectype = typename std::conditional 1) { if constexpr (xss::fp::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); + std_argselect_withnan( + arr, arg, k, 0, arrsize, nans_last, descending); return; } } UNUSED(hasnan); + UNUSED(nans_last); + /* For descending, partition at the mirror position so the k-th + * largest lands at arrsize-1-k; reversal then moves it to k. */ + arrsize_t pos = descending ? 
arrsize - 1 - k : k; argselect_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + arr, arg, pos, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + if (descending) { std::reverse(arg, arg + arrsize); } } #ifdef __MMX__ @@ -740,9 +753,12 @@ X86_SIMD_SORT_INLINE void avx512_argselect(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false, + bool nans_last = true) { - xss_argselect(arr, arg, k, arrsize, hasnan); + xss_argselect( + arr, arg, k, arrsize, hasnan, descending, nans_last); } template @@ -750,10 +766,12 @@ X86_SIMD_SORT_INLINE void avx2_argselect(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false, + bool nans_last = true) { xss_argselect( - arr, arg, k, arrsize, hasnan); + arr, arg, k, arrsize, hasnan, descending, nans_last); } #endif // XSS_COMMON_ARGSORT diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 3a07e01b..fcf8778e 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -641,8 +641,9 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( #endif if (descending) { - std::reverse(keys, keys + arrsize); - std::reverse(indexes, indexes + arrsize); + // Only reverse the real portion; NaN at the end stays in place + std::reverse(keys, keys + index_last_elem + 1); + std::reverse(indexes, indexes + index_last_elem + 1); } } @@ -688,8 +689,6 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, #endif // XSS_TEST_KEYVALUE_BASE_CASE if (minarrsize) { - if (descending) { k = arrsize - 1 - k; } - arrsize_t index_last_elem = arrsize - 1; if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { @@ -699,6 +698,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, } } + // For descending: map k to ascending position within real portion + if (descending) { k = index_last_elem - k; } + UNUSED(hasnan); if (index_last_elem >= k) { 
kvselect_( @@ -706,8 +708,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, } if (descending) { - std::reverse(keys, keys + arrsize); - std::reverse(indexes, indexes + arrsize); + // Only reverse the real portion; NaN at the end stays in place + std::reverse(keys, keys + index_last_elem + 1); + std::reverse(indexes, indexes + index_last_elem + 1); } } diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 2bf7ca61..70aa3465 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -51,33 +51,6 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); } -template -X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size) -{ - arrsize_t nan_count = 0; - using opmask_t = typename vtype::opmask_t; - using reg_t = typename vtype::reg_t; - opmask_t loadmask; - reg_t in; - /* - * (ii + numlanes) can never overflow: max val of size is 2**63 on 64-bit - * and 2**31 on 32-bit systems - */ - for (arrsize_t ii = 0; ii < size; ii = ii + vtype::numlanes) { - if (size - ii < vtype::numlanes) { - loadmask = vtype::get_partial_loadmask(size - ii); - in = vtype::maskz_loadu(loadmask, arr + ii); - } - else { - in = vtype::loadu(arr + ii); - } - opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); - nan_count += _mm_popcnt_u32(vtype::convert_mask_to_int(nanmask)); - vtype::mask_storeu(arr + ii, nanmask, vtype::zmm_max()); - } - return nan_count; -} - template X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size) { @@ -104,36 +77,6 @@ X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size) return found_nan; } -template -X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, - arrsize_t size, - arrsize_t nan_count, - bool descending = false) -{ - if (descending) { - for (arrsize_t ii = 0; nan_count > 0; ++ii) { - if constexpr (xss::fp::is_floating_point_v) { - arr[ii] = xss::fp::quiet_NaN(); - } - else { - arr[ii] = 0x7c01; // 
std::quiet_nan - } - nan_count -= 1; - } - } - else { - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { - if constexpr (xss::fp::is_floating_point_v) { - arr[ii] = xss::fp::quiet_NaN(); - } - else { - arr[ii] = 0x7c01; // std::quiet_nan - } - nan_count -= 1; - } - } -} - /* * Sort all the NAN's to end of the array and return the index of the last elem * in the array which is not a nan @@ -650,7 +593,8 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, // Quicksort routines: template -X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void +xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool nans_last = true) { using comparator = typename std::conditional>::type; if (arrsize > 1) { - arrsize_t nan_count = 0; + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf(arr, arrsize); + if (!nans_last) { + index_first_elem + = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } } UNUSED(hasnan); + if (index_first_elem <= index_last_elem && index_last_elem < arrsize) { #ifdef XSS_COMPILE_OPENMP - bool use_parallel = arrsize > 100000; + bool use_parallel + = (index_last_elem - index_first_elem + 1) > 100000; - if (use_parallel) { - int thread_count = xss_get_num_threads(); - arrsize_t task_threshold - = std::max((arrsize_t)100000, arrsize / 100); + if (use_parallel) { + int thread_count = xss_get_num_threads(); + arrsize_t task_threshold = std::max( + (arrsize_t)100000, + (index_last_elem - index_first_elem + 1) / 100); - // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ - // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems - // Note that we do not use the if(...) 
clause built into OpenMP, because it causes a performance regression for small arrays + // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ + // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems + // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays #pragma omp parallel num_threads(thread_count) #pragma omp single - qsort_(arr, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize), - task_threshold); + qsort_(arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + task_threshold); #pragma omp taskwait - } - else { + } + else { + qsort_( + arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + std::numeric_limits::max()); + } +#else qsort_(arr, - 0, - arrsize - 1, + index_first_elem, + index_last_elem, 2 * (arrsize_t)log2(arrsize), - std::numeric_limits::max()); - } -#else - qsort_( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); + 0); #endif - - replace_inf_with_nan(arr, arrsize, nan_count, descending); + } } #ifdef __MMX__ @@ -711,8 +668,11 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) // Quick select methods template -X86_SIMD_SORT_INLINE void -xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void xss_qselect(T *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool nans_last = true) { using comparator = typename std::conditional) { if (UNLIKELY(hasnan)) { - if constexpr (descending) { + if (!nans_last) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -753,12 +713,16 @@ xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) // Partial sort methods: template -X86_SIMD_SORT_INLINE void -xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, + 
arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool nans_last = true) { if (k == 0) return; - xss_qselect(arr, k - 1, arrsize, hasnan); - xss_qsort(arr, k - 1, hasnan); + xss_qselect( + arr, k - 1, arrsize, hasnan, nans_last); + xss_qsort(arr, k - 1, hasnan, nans_last); } #define DEFINE_METHODS(ISA, VTYPE) \ @@ -766,11 +730,14 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + bool descending = false, \ + bool nans_last = true) \ { \ - if (descending) { xss_qsort(arr, size, hasnan); } \ + if (descending) { \ + xss_qsort(arr, size, hasnan, nans_last); \ + } \ else { \ - xss_qsort(arr, size, hasnan); \ + xss_qsort(arr, size, hasnan, nans_last); \ } \ } \ template \ @@ -778,11 +745,14 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) arrsize_t k, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + bool descending = false, \ + bool nans_last = true) \ { \ - if (descending) { xss_qselect(arr, k, size, hasnan); } \ + if (descending) { \ + xss_qselect(arr, k, size, hasnan, nans_last); \ + } \ else { \ - xss_qselect(arr, k, size, hasnan); \ + xss_qselect(arr, k, size, hasnan, nans_last); \ } \ } \ template \ @@ -790,13 +760,16 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) arrsize_t k, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + bool descending = false, \ + bool nans_last = true) \ { \ if (descending) { \ - xss_partial_qsort(arr, k, size, hasnan); \ + xss_partial_qsort( \ + arr, k, size, hasnan, nans_last); \ } \ else { \ - xss_partial_qsort(arr, k, size, hasnan); \ + xss_partial_qsort( \ + arr, k, size, hasnan, nans_last); \ } \ } diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index e894a86e..598aee49 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -60,20 +60,25 @@ void 
IS_ARR_PARTITIONED(std::vector arr, cmp_eq = compare>(); if (!descending) { - cmp_less = compare>(); - cmp_leq = compare>(); - cmp_geq = compare>(); + cmp_less = compare_nan_end>(); + cmp_leq = compare_nan_end>(); + cmp_geq = compare_nan_end>(); } else { - cmp_less = compare>(); - cmp_leq = compare>(); - cmp_geq = compare>(); + cmp_less = compare_nan_end>(); + cmp_leq = compare_nan_end>(); + cmp_geq = compare_nan_end>(); } // 1) arr[k] == sorted[k]; use memcmp to handle nan if (!cmp_eq(arr[k], true_kth)) { REPORT_FAIL("kth element is incorrect", arr.size(), type, k); } + // If arr[k] is NaN, k is in the trailing NaN block; value comparisons are + // not meaningful, so skip the left/right partition checks. + if constexpr (xss::fp::is_floating_point_v) { + if (xss::fp::isnan(arr[k])) return; + } // ( 2) Elements to the left of k should be atmost arr[k] if (k >= 1) { T max_left @@ -108,13 +113,14 @@ void IS_ARG_PARTITIONED(std::vector arr, std::vector arg, T true_kth, size_t k, - std::string type) + std::string type, + bool descending = false) { EXPECT_UNIQUE(arg) std::vector part_arr; for (auto ii : arg) { part_arr.push_back(arr[ii]); } - IS_ARR_PARTITIONED(part_arr, k, true_kth, type); + IS_ARR_PARTITIONED(part_arr, k, true_kth, type, descending); } #endif diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index f2ce3a6b..89e1435b 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -76,7 +76,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending) #ifndef XSS_ASAN_CI_NOCHECK std::sort(sortedarr.begin(), sortedarr.end(), - compare>()); + compare_nan_end>()); IS_SORTED(sortedarr, arr, type); #endif arr.clear(); @@ -119,7 +119,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending) #ifndef XSS_ASAN_CI_NOCHECK std::sort(sortedarr.begin(), sortedarr.end(), - compare>()); + compare_nan_end>()); IS_ARG_SORTED(sortedarr, arr, arg, type); #endif arr.clear(); @@ -128,6 +128,92 @@ TYPED_TEST_P(simdsort, test_argsort_descending) } } +TYPED_TEST_P(simdsort, 
test_argsort_nans_last) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort( + sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_argsort_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + } + } + } +} + TYPED_TEST_P(simdsort, test_qselect_ascending) { for (auto type : 
this->arrtype) { @@ -169,10 +255,11 @@ TYPED_TEST_P(simdsort, test_qselect_descending) x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); #ifndef XSS_ASAN_CI_NOCHECK - std::nth_element(sortedarr.begin(), - sortedarr.begin() + k, - sortedarr.end(), - compare>()); + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); if (size == 0) continue; IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); #endif @@ -182,7 +269,105 @@ TYPED_TEST_P(simdsort, test_qselect_descending) } } -TYPED_TEST_P(simdsort, test_argselect) +TYPED_TEST_P(simdsort, test_qselect_nans_last) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_qselect_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? 
rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element(sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_argselect_ascending) { for (auto type : this->arrtype) { bool hasnan = is_nan_test(type); @@ -206,6 +391,130 @@ TYPED_TEST_P(simdsort, test_argselect) } } +TYPED_TEST_P(simdsort, test_argselect_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = is_nan_test(type); + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector arr = get_array(type, size); + std::vector sortedarr = arr; + + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), hasnan, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type, true); +#endif + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_argselect_nans_last) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? 
rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type, true); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_argselect_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? 
rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element(sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type, true); +#endif + } + } + } + } +} + TYPED_TEST_P(simdsort, test_partial_qsort_ascending) { for (auto type : this->arrtype) { @@ -248,7 +557,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_descending) #ifndef XSS_ASAN_CI_NOCHECK std::sort(sortedarr.begin(), sortedarr.end(), - compare>()); + compare_nan_end>()); if (size == 0) continue; IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); #endif @@ -291,14 +600,109 @@ TYPED_TEST_P(simdsort, test_comparator) } } +TYPED_TEST_P(simdsort, test_qsort_nans_last) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort( + sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + 
std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_qsort_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + } + } + } +} + REGISTER_TYPED_TEST_SUITE_P(simdsort, test_qsort_ascending, test_qsort_descending, + test_qsort_nans_last, + test_qsort_leading_nans, test_argsort_ascending, test_argsort_descending, - test_argselect, + test_argsort_nans_last, + test_argsort_leading_nans, + test_argselect_ascending, + test_argselect_descending, + test_argselect_nans_last, + test_argselect_leading_nans, test_qselect_ascending, test_qselect_descending, + test_qselect_nans_last, + test_qselect_leading_nans, test_partial_qsort_ascending, test_partial_qsort_descending, test_comparator); diff --git a/utils/custom-compare.h b/utils/custom-compare.h index f2c8d61e..80274f46 100644 --- a/utils/custom-compare.h +++ b/utils/custom-compare.h @@ -46,4 +46,68 @@ struct compare_arg { const T *arr; }; +/* + * Comparator that always places 
NaN at the end of the sorted array, + * regardless of whether Comparator is ascending or descending. + */ +template +struct compare_nan_end { + static constexpr auto op = Comparator {}; + bool operator()(const T a, const T b) const + { + if constexpr (xss::fp::is_floating_point_v) { + bool a_nan = xss::fp::isnan(a); + bool b_nan = xss::fp::isnan(b); + if (!a_nan && !b_nan) { return op(a, b); } + if (a_nan && b_nan) { return false; } + return !a_nan; // b is NaN → a before b → NaN at end + } + else { + return op(a, b); + } + } +}; + +/* + * Comparator that always places NaN at the beginning of the sorted array, + * regardless of whether Comparator is ascending or descending. + */ +template +struct compare_nan_begin { + static constexpr auto op = Comparator {}; + bool operator()(const T a, const T b) const + { + if constexpr (xss::fp::is_floating_point_v) { + bool a_nan = xss::fp::isnan(a); + bool b_nan = xss::fp::isnan(b); + if (!a_nan && !b_nan) { return op(a, b); } + if (a_nan && b_nan) { return false; } + return a_nan; // a is NaN → a before b → NaN at beginning + } + else { + return op(a, b); + } + } +}; + +template +struct compare_arg_nan_end { + compare_arg_nan_end(const T *arr) : arr(arr) {} + bool operator()(const int64_t a, const int64_t b) const + { + return compare_nan_end()(arr[a], arr[b]); + } + const T *arr; +}; + +template +struct compare_arg_nan_begin { + compare_arg_nan_begin(const T *arr) : arr(arr) {} + bool operator()(const int64_t a, const int64_t b) const + { + return compare_nan_begin()(arr[a], arr[b]); + } + const T *arr; +}; + #endif // UTILS_CUSTOM_COMPARE \ No newline at end of file