From 1ee4dad3d67815ae46c182016341b2ac9aa18b61 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 11 May 2026 09:11:54 +0000 Subject: [PATCH 01/13] Add trailing_nans parameter to control NaN placement in sort results All sorting and selection routines (qsort, qselect, partial_qsort, argsort, argselect) now accept an optional `bool trailing_nans` parameter (default: true). When hasnan=true: - trailing_nans=true (default): NaN values appear at the end of the result, independent of sort direction. - trailing_nans=false: NaN values appear at the beginning of the result, independent of sort direction. Previously, NaN placement was coupled to the sort direction: ascending always put NaN at the end, descending always put NaN at the beginning. This change decouples the two. Implementation notes: - qsort path: replace_inf_with_nan now takes both `descending` and `trailing_nans`; when they differ, the real-value portion is rotated before NaN slots are written back. - qselect path: NaN is moved to the desired end before partitioning, keeping the existing move_nans_to_{end,start}_of_array helpers. - argsort path: after std_argsort_withnan + optional reverse, a std::rotate moves NaN indices to the desired end if needed. - argselect path: comparator lambda respects trailing_nans. - keyvalue_qsort/select path: descending reverse now only reverses the real portion [0, index_last_elem], leaving trailing NaN in place. - Scalar fallback: four NaN-aware comparators (compare_nan_end/begin, compare_arg_nan_end/begin) added to utils/custom-compare.h. Tests updated to use compare_nan_end for descending reference sorts and to skip value-comparison checks when arr[k] is NaN. --- README.md | 30 ++++++---- lib/x86simdsort-avx2.cpp | 34 +++++++----- lib/x86simdsort-icl.cpp | 51 +++++++++++------ lib/x86simdsort-internal.h | 15 +++-- lib/x86simdsort-scalar.h | 82 +++++++++++++++++++-------- lib/x86simdsort-skx.cpp | 34 +++++++----- lib/x86simdsort-spr.cpp | 17 ++++-- lib/x86simdsort.cpp | 98 ++++++++++++++++++++------------- lib/x86simdsort.h | 15 +++-- src/avx512-16bit-qsort.hpp | 18 +++--- src/avx512fp16-16bit-qsort.hpp | 18 ++++-- src/x86simdsort-static-incl.h | 85 ++++++++++++++++++---------- src/xss-common-argsort.h | 71 ++++++++++++++++-------- src/xss-common-keyvaluesort.hpp | 15 +++-- src/xss-common-qsort.h | 78 ++++++++++++++++++-------- tests/test-qsort-common.h | 17 ++++-- tests/test-qsort.cpp | 8 +-- utils/custom-compare.h | 64 +++++++++++++++++++++ 18 files changed, 510 insertions(+), 240 deletions(-) diff --git a/README.md b/README.md index da2f3f85..da10e51a 100644 --- a/README.md +++ b/README.md @@ -37,9 +37,9 @@ how fast this is relative to `std::sort`. ## Sort an array of built-in integers and floats ```cpp -void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending); -void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending); -void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending); +void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending, bool trailing_nans); +void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending, bool trailing_nans); +void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending, bool trailing_nans); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` @@ -56,8 +56,8 @@ data types. ## Arg sort routines on arrays ```cpp -std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending); -std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan); +std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending, bool trailing_nans); +std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan, bool trailing_nans); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` Note that argsort and argselect are not accelerated with SIMD when using 16-bit @@ -174,13 +174,19 @@ Supported datatypes: `uint16_t, int16_t, _Float16, uint32_t, int32_t, float, uint64_t, int64_t, double`. Note that `_Float16` will require building this library with g++ >= 12.x. All the functions have an optional argument `bool hasnan` set to `false` by default (these are relevant to floating point data -types only). If your array has NAN's, the the behaviour of the sorting routine -is undefined. If `hasnan` is set to true, NAN's are always sorted to the end of -the array. In addition to that, qsort will replace all your NAN's with -`std::numeric_limits::quiet_NaN`. The original bit-exact NaNs in -the input are not preserved. Also note that the arg methods (argsort and -argselect) will not use the SIMD based algorithms if they detect NAN's in the -array. You can read details of all the implementations +types only). If your array has NaN values, the behaviour of the sorting routine +is undefined unless `hasnan` is set to `true`. When `hasnan=true`, NaN placement +is controlled by the optional `bool trailing_nans` parameter (default `true`): + +- `trailing_nans=true` (default): NaN values are placed at the **end** of the + result, regardless of sort direction. +- `trailing_nans=false`: NaN values are placed at the **beginning** of the + result, regardless of sort direction. + +Note that `qsort` will replace all NaN values with `std::numeric_limits::quiet_NaN`; +the original bit-exact NaN payload is not preserved. Also note that the arg +methods (argsort and argselect) will not use the SIMD based algorithms if they +detect NaN values in the array. You can read details of all the implementations [here](https://github.com/intel/x86-simd-sort/blob/main/src/README.md). ## Performance comparison on AVX-512: `object_qsort` v/s `std::sort` diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 36b43e89..527e5e49 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -5,33 +5,39 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - x86simdsortStatic::qsort(arr, arrsize, hasnan, descending); \ + x86simdsortStatic::qsort(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - void qselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - void partial_qsort( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - std::vector argsort( \ - const type *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort(const type *arr, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ + return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - std::vector argselect( \ - const type *arr, size_t k, size_t arrsize, bool hasnan) \ + std::vector argselect(const type *arr, size_t k, size_t arrsize, \ + bool hasnan, bool trailing_nans) \ { \ - return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ + return x86simdsortStatic::argselect(arr, k, arrsize, hasnan, \ + trailing_nans); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 3e5c4b5b..8d04ee64 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -8,76 +8,91 @@ namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending) + void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending, + bool trailing_nans) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); } template <> void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, + trailing_nans); } template <> void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, + trailing_nans); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan, bool descending) + void qsort(int16_t *arr, size_t size, bool hasnan, bool descending, + bool trailing_nans) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); } template <> void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, + trailing_nans); } template <> void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, + trailing_nans); } } // namespace avx512 namespace fp16_icl { #ifdef __FLT16_MAX__ template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) + void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending, + bool trailing_nans) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); } template <> void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, + trailing_nans); } template <> void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, + trailing_nans); } #endif } // namespace fp16_icl diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 055df2bd..7bd91afc 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -10,7 +10,8 @@ XSS_HIDE_SYMBOL void qsort(T *arr, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool trailing_nans = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \ T2 *val, \ @@ -22,7 +23,8 @@ size_t k, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool trailing_nans = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \ T2 *val, \ @@ -35,7 +37,8 @@ size_t k, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool trailing_nans = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \ T2 *val, \ @@ -47,10 +50,12 @@ XSS_HIDE_SYMBOL std::vector argsort(const T *arr, \ size_t arrsize, \ bool hasnan = false, \ - bool descending = false); \ + bool descending = false, \ + bool trailing_nans = true); \ template \ XSS_HIDE_SYMBOL std::vector \ - argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); \ + argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false, \ + bool trailing_nans = true); \ } namespace xss { diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 9f08f9b2..03c23572 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -24,13 +24,26 @@ namespace utils { } } template - decltype(auto) get_cmp_func(bool hasnan, bool reverse) + decltype(auto) get_cmp_func(bool hasnan, bool reverse, + bool trailing_nans = true) { std::function cmp; if (hasnan) { - if (reverse == true) { cmp = compare>(); } + if (trailing_nans) { + if (reverse == true) { + cmp = compare_nan_end>(); + } + else { + cmp = compare>(); + } + } else { - cmp = compare>(); + if (reverse == true) { + cmp = compare>(); + } + else { + cmp = compare_nan_begin>(); + } } } else { @@ -45,66 +58,89 @@ namespace utils { namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed, + bool trailing_nans) { std::sort(arr, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + xss::utils::get_cmp_func(hasnan, reversed, + trailing_nans)); } template - void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed, + bool trailing_nans) { std::nth_element(arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + xss::utils::get_cmp_func(hasnan, reversed, + trailing_nans)); } template - void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, + bool reversed, bool trailing_nans) { std::partial_sort(arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + xss::utils::get_cmp_func(hasnan, reversed, + trailing_nans)); } template std::vector - argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed) + argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed, + bool trailing_nans) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - if (reversed) { - std::sort(arg.begin(), - arg.end(), - compare_arg>(arr)); + if (trailing_nans) { + if (reversed) { + std::sort(arg.begin(), arg.end(), + compare_arg_nan_end>(arr)); + } + else { + std::sort(arg.begin(), arg.end(), + compare_arg>(arr)); + } } else { - std::sort( - arg.begin(), arg.end(), compare_arg>(arr)); + if (reversed) { + std::sort(arg.begin(), arg.end(), + compare_arg>(arr)); + } + else { + std::sort(arg.begin(), arg.end(), + compare_arg_nan_begin>(arr)); + } } return arg; } template std::vector - argselect(const T *arr, size_t k, size_t arrsize, bool hasnan) + argselect(const T *arr, size_t k, size_t arrsize, bool hasnan, + bool trailing_nans) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - std::nth_element(arg.begin(), - arg.begin() + k, - arg.end(), - compare_arg>(arr)); + if (hasnan && !trailing_nans) { + std::nth_element(arg.begin(), arg.begin() + k, arg.end(), + compare_arg_nan_begin>(arr)); + } + else { + std::nth_element(arg.begin(), arg.begin() + k, arg.end(), + compare_arg>(arr)); + } return arg; } template void keyvalue_qsort( T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending) { - std::vector arg = argsort(key, arrsize, hasnan, descending); + std::vector arg = argsort(key, arrsize, hasnan, descending, + true); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 6260bab7..f7b462ae 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -5,33 +5,39 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - x86simdsortStatic::qsort(arr, arrsize, hasnan, descending); \ + x86simdsortStatic::qsort(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - void qselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - void partial_qsort( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); \ + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - std::vector argsort( \ - const type *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort(const type *arr, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ + return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } \ template <> \ - std::vector argselect( \ - const type *arr, size_t k, size_t arrsize, bool hasnan) \ + std::vector argselect(const type *arr, size_t k, size_t arrsize, \ + bool hasnan, bool trailing_nans) \ { \ - return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ + return x86simdsortStatic::argselect(arr, k, arrsize, hasnan, \ + trailing_nans); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index 7587640a..59f0a829 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -5,27 +5,32 @@ namespace xss { namespace fp16_spr { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) + void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending, + bool trailing_nans) { - x86simdsortStatic::qsort(arr, size, hasnan, descending); + x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); } template <> void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, + trailing_nans); } template <> void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, + trailing_nans); } } // namespace fp16_spr } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 776ec56d..62b2f8b9 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -101,122 +101,144 @@ namespace x86simdsort { #ifdef _MSC_VER #define DECLARE_INTERNAL_qsort(TYPE) \ static void CAT(resolve_qsort, TYPE)(void); \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool, bool) \ + = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qsort( \ - TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qsort(TYPE *arr, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ if (internal_qsort##TYPE == NULL) { CAT(resolve_qsort, TYPE)(); } \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ + (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ static void CAT(resolve_qselect, TYPE)(void); \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool, \ + bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qselect( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qselect(TYPE *arr, size_t k, size_t arrsize, \ + bool hasnan, bool descending, \ + bool trailing_nans) \ { \ if (internal_qselect##TYPE == NULL) { CAT(resolve_qselect, TYPE)(); } \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ static void CAT(resolve_partial_qsort, TYPE)(void); \ - static void (*internal_partial_qsort##TYPE)( \ - TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool, \ + bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL partial_qsort( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, size_t k, size_t arrsize, \ + bool hasnan, bool descending, \ + bool trailing_nans) \ { \ if (internal_partial_qsort##TYPE == NULL) { \ CAT(resolve_partial_qsort, TYPE)(); \ } \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ static void CAT(resolve_argsort, TYPE)(void); \ static std::vector (*internal_argsort##TYPE)( \ - const TYPE *, size_t, bool, bool) \ + const TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ std::vector XSS_EXPORT_SYMBOL argsort( \ - const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + const TYPE *arr, size_t arrsize, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ if (internal_argsort##TYPE == NULL) { CAT(resolve_argsort, TYPE)(); } \ - return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ + return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ static void CAT(resolve_argselect, TYPE)(void); \ static std::vector (*internal_argselect##TYPE)( \ - const TYPE *, size_t, size_t, bool) \ + const TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ std::vector XSS_EXPORT_SYMBOL argselect( \ - const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + const TYPE *arr, size_t k, size_t arrsize, bool hasnan, \ + bool trailing_nans) \ { \ if (internal_argselect##TYPE == NULL) { \ CAT(resolve_argselect, TYPE)(); \ } \ - return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \ + return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan, \ + trailing_nans); \ } #else #define DECLARE_INTERNAL_qsort(TYPE) \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool, bool) \ + = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qsort( \ - TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qsort(TYPE *arr, size_t arrsize, bool hasnan, \ + bool descending, bool trailing_nans) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ + (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool, \ + bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qselect( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL qselect(TYPE *arr, size_t k, size_t arrsize, \ + bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ - static void (*internal_partial_qsort##TYPE)( \ - TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool, \ + bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL partial_qsort( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, size_t k, size_t arrsize, \ + bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ static std::vector (*internal_argsort##TYPE)( \ - const TYPE *, size_t, bool, bool) \ + const TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ std::vector XSS_EXPORT_SYMBOL argsort( \ - const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + const TYPE *arr, size_t arrsize, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ + return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending, \ + trailing_nans); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ static std::vector (*internal_argselect##TYPE)( \ - const TYPE *, size_t, size_t, bool) \ + const TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ std::vector XSS_EXPORT_SYMBOL argselect( \ - const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + const TYPE *arr, size_t k, size_t arrsize, bool hasnan, \ + bool trailing_nans) \ { \ - return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \ + return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan, \ + trailing_nans); \ } #endif // _MSC_VER diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index e30120ef..80c6ce67 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -20,7 +20,8 @@ namespace x86simdsort { // quicksort template XSS_EXPORT_SYMBOL void -qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false, + bool trailing_nans = true); // quickselect template @@ -28,7 +29,8 @@ XSS_EXPORT_SYMBOL void qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); // partial sort template @@ -36,19 +38,22 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); // argsort template XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, size_t arrsize, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); // argselect template XSS_EXPORT_SYMBOL std::vector -argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); +argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false, + bool trailing_nans = true); // keyvalue sort template diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index fbe18567..9fd79c47 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -584,7 +584,8 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool trailing_nans = true) { using vtype = zmm_vector; @@ -599,7 +600,7 @@ avx512_qsort_fp16(uint16_t *arr, else { avx512_qsort_fp16_helper>(arr, arrsize); } - replace_inf_with_nan(arr, arrsize, nan_count, descending); + replace_inf_with_nan(arr, arrsize, nan_count, descending, trailing_nans); } #ifdef __MMX__ @@ -613,7 +614,8 @@ avx512_qselect_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool trailing_nans = true) { using vtype = zmm_vector; @@ -624,7 +626,7 @@ avx512_qselect_fp16(uint16_t *arr, arrsize_t index_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - if (descending) { + if (!trailing_nans) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -662,10 +664,12 @@ avx512_partial_qsort_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool trailing_nans = true) { if (k == 0) return; - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending); - avx512_qsort_fp16(arr, k - 1, hasnan, descending); + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending, + trailing_nans); + avx512_qsort_fp16(arr, k - 1, hasnan, descending, trailing_nans); } #endif // AVX512_QSORT_16BIT diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8f85e599..6af50148 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -179,21 +179,27 @@ template <> X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, arrsize_t size, arrsize_t nan_count, - bool descending) + bool descending, + bool trailing_nans) { + if (nan_count == 0) return; Fp16Bits val; val.i_ = 0x7c01; - if (descending) { - for (arrsize_t ii = 0; nan_count > 0; ++ii) { + if (descending && trailing_nans) { + std::rotate(arr, arr + nan_count, arr + size); + } + else if (!descending && !trailing_nans) { + std::rotate(arr, arr + (size - nan_count), arr + size); + } + if (trailing_nans) { + for (arrsize_t ii = size - nan_count; ii < size; ++ii) { arr[ii] = val.f_; - nan_count -= 1; } } else { - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { + for (arrsize_t ii = 0; ii < nan_count; ++ii) { arr[ii] = val.f_; - nan_count -= 1; } } } diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 2b0a11e0..a06ec15a 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -8,27 +8,31 @@ namespace x86simdsortStatic { template X86_SIMD_SORT_FINLINE void -qsort(T *arr, size_t size, bool hasnan = false, bool descending = false); +qsort(T *arr, size_t size, bool hasnan = false, bool descending = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE void qselect(T *arr, size_t k, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, size_t k, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE std::vector argsort(const T *arr, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); /* argsort API required by NumPy: */ template @@ -36,16 +40,19 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr, size_t *arg, size_t size, bool hasnan = false, - bool descending = false); + bool descending = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE std::vector -argselect(const T *arr, size_t k, size_t size, bool hasnan = false); +argselect(const T *arr, size_t k, size_t size, bool hasnan = false, + bool trailing_nans = true); /* argselect API required by NumPy: */ template void X86_SIMD_SORT_FINLINE argselect( - const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false); + const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, @@ -75,54 +82,62 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, #define XSS_METHODS(ISA) \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::qsort( \ - T *arr, size_t size, bool hasnan, bool descending) \ + T *arr, size_t size, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - ISA##_qsort(arr, size, hasnan, descending); \ + ISA##_qsort(arr, size, hasnan, descending, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::qselect( \ - T *arr, size_t k, size_t size, bool hasnan, bool descending) \ + T *arr, size_t k, size_t size, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - ISA##_qselect(arr, k, size, hasnan, descending); \ + ISA##_qselect(arr, k, size, hasnan, descending, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::partial_qsort( \ - T *arr, size_t k, size_t size, bool hasnan, bool descending) \ + T *arr, size_t k, size_t size, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ - ISA##_partial_qsort(arr, k, size, hasnan, descending); \ + ISA##_partial_qsort(arr, k, size, hasnan, descending, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \ size_t *arg, \ size_t size, \ bool hasnan, \ - bool descending) \ + bool descending, \ + bool trailing_nans) \ { \ - ISA##_argsort(arr, arg, size, hasnan, descending); \ + ISA##_argsort(arr, arg, size, hasnan, descending, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ - const T *arr, size_t size, bool hasnan, bool descending) \ + const T *arr, size_t size, bool hasnan, bool descending, \ + bool trailing_nans) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ x86simdsortStatic::argsort( \ - arr, indices.data(), size, hasnan, descending); \ + arr, indices.data(), size, hasnan, descending, trailing_nans); \ return indices; \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argselect( \ - const T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \ + const T *arr, size_t *arg, size_t k, size_t size, bool hasnan, \ + bool trailing_nans) \ { \ - ISA##_argselect(arr, arg, k, size, hasnan); \ + ISA##_argselect(arr, arg, k, size, hasnan, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argselect( \ - const T *arr, size_t k, size_t size, bool hasnan) \ + const T *arr, size_t k, size_t size, bool hasnan, \ + bool trailing_nans) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ - x86simdsortStatic::argselect(arr, indices.data(), k, size, hasnan); \ + x86simdsortStatic::argselect( \ + arr, indices.data(), k, size, hasnan, trailing_nans); \ return indices; \ } \ template \ @@ -185,23 +200,35 @@ template <> void x86simdsortStatic::qsort<_Float16>(_Float16 *arr, size_t size, bool hasnan, - bool descending) + bool descending, + bool trailing_nans) { - avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending); + avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending, + trailing_nans); } template <> [[maybe_unused]] -void x86simdsortStatic::qselect<_Float16>( - _Float16 *arr, size_t k, size_t size, bool hasnan, bool descending) +void x86simdsortStatic::qselect<_Float16>(_Float16 *arr, + size_t k, + size_t size, + bool hasnan, + bool descending, + bool trailing_nans) { - avx512_qselect_fp16((uint16_t *)arr, k, size, hasnan, descending); + avx512_qselect_fp16((uint16_t *)arr, k, size, hasnan, descending, + trailing_nans); } template <> [[maybe_unused]] -void x86simdsortStatic::partial_qsort<_Float16>( - _Float16 *arr, size_t k, size_t size, bool hasnan, bool descending) +void x86simdsortStatic::partial_qsort<_Float16>(_Float16 *arr, + size_t k, + size_t size, + bool hasnan, + bool descending, + bool trailing_nans) { - avx512_partial_qsort_fp16((uint16_t *)arr, k, size, hasnan, descending); + avx512_partial_qsort_fp16((uint16_t *)arr, k, size, hasnan, descending, + trailing_nans); } #endif diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 1bec821b..e2267257 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -15,21 +15,18 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, - arrsize_t right) + arrsize_t right, + bool trailing_nans = true) { std::nth_element(arg + left, arg + k, arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } + [arr, trailing_nans](arrsize_t a, arrsize_t b) -> bool { + bool a_nan = std::isnan(arr[a]); + bool b_nan = std::isnan(arr[b]); + if (!a_nan && !b_nan) { return arr[a] < arr[b]; } + if (a_nan && b_nan) { return false; } + return trailing_nans ? !a_nan : a_nan; }); } @@ -599,7 +596,8 @@ X86_SIMD_SORT_INLINE void xss_argsort(const T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool trailing_nans = true) { using vectype = typename std::conditional) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argsort_withnan(arr, arg, 0, arrsize); - + // std_argsort_withnan produces ascending order with NaN at end. + // After optional reverse, NaN is at beginning for descending. if (descending) { std::reverse(arg, arg + arrsize); } + // Now adjust NaN position if it doesn't match trailing_nans: + // descending=false → NaN at end; descending=true → NaN at beginning + bool nan_currently_at_end = !descending; + if (nan_currently_at_end != trailing_nans) { + arrsize_t nan_count = 0; + for (arrsize_t i = 0; i < arrsize; i++) { + nan_count += is_a_nan(arr[i]); + } + if (trailing_nans) { + // NaN is at beginning, rotate to end + std::rotate(arg, arg + nan_count, arg + arrsize); + } + else { + // NaN is at end, rotate to beginning + std::rotate(arg, + arg + arrsize - nan_count, + arg + arrsize); + } + } return; } @@ -678,10 +696,11 @@ X86_SIMD_SORT_INLINE void avx512_argsort(const T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool trailing_nans = true) { xss_argsort( - arr, arg, arrsize, hasnan, descending); + arr, arg, arrsize, hasnan, descending, trailing_nans); } template @@ -689,10 +708,11 @@ X86_SIMD_SORT_INLINE void avx2_argsort(const T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, - bool descending = false) + bool descending = false, + bool trailing_nans = true) { xss_argsort( - arr, arg, arrsize, hasnan, descending); + arr, arg, arrsize, hasnan, descending, trailing_nans); } /* argselect methods for 32-bit and 64-bit dtypes */ @@ -705,7 +725,8 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool trailing_nans = true) { /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ using vectype = typename std::conditional 1) { if constexpr (xss::fp::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); + std_argselect_withnan( + arr, arg, k, 0, arrsize, trailing_nans); return; } } @@ -740,9 +762,11 @@ X86_SIMD_SORT_INLINE void avx512_argselect(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool trailing_nans = true) { - xss_argselect(arr, arg, k, arrsize, hasnan); + xss_argselect( + arr, arg, k, arrsize, hasnan, trailing_nans); } template @@ -750,10 +774,11 @@ X86_SIMD_SORT_INLINE void avx2_argselect(const T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool trailing_nans = true) { xss_argselect( - arr, arg, k, arrsize, hasnan); + arr, arg, k, arrsize, hasnan, trailing_nans); } #endif // XSS_COMMON_ARGSORT diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 3a07e01b..fcf8778e 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -641,8 +641,9 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( #endif if (descending) { - std::reverse(keys, keys + arrsize); - std::reverse(indexes, indexes + arrsize); + // Only reverse the real portion; NaN at the end stays in place + std::reverse(keys, keys + index_last_elem + 1); + std::reverse(indexes, indexes + index_last_elem + 1); } } @@ -688,8 +689,6 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, #endif // XSS_TEST_KEYVALUE_BASE_CASE if (minarrsize) { - if (descending) { k = arrsize - 1 - k; } - arrsize_t index_last_elem = arrsize - 1; if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { @@ -699,6 +698,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, } } + // For descending: map k to ascending position within real portion + if (descending) { k = index_last_elem - k; } + UNUSED(hasnan); if (index_last_elem >= k) { kvselect_( @@ -706,8 +708,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, } if (descending) { - std::reverse(keys, keys + arrsize); - std::reverse(indexes, indexes + arrsize); + // Only reverse the real portion; NaN at the end stays in place + std::reverse(keys, keys + index_last_elem + 1); + std::reverse(indexes, indexes + index_last_elem + 1); } } diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 2bf7ca61..541f5436 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -108,28 +108,39 @@ template X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, arrsize_t size, arrsize_t nan_count, - bool descending = false) + bool descending, + bool trailing_nans) { - if (descending) { - for (arrsize_t ii = 0; nan_count > 0; ++ii) { + if (nan_count == 0) return; + // After ascending sort +inf lands at the end; after descending at the start. + // When the desired NaN position differs from where +inf landed, rotate first. + if (descending && trailing_nans) { + // +inf at beginning, want NaN at end: rotate left + std::rotate(arr, arr + nan_count, arr + size); + } + else if (!descending && !trailing_nans) { + // +inf at end, want NaN at beginning: rotate right + std::rotate(arr, arr + (size - nan_count), arr + size); + } + // Write NaN at the now-correct position + if (trailing_nans) { + for (arrsize_t ii = size - nan_count; ii < size; ++ii) { if constexpr (xss::fp::is_floating_point_v) { arr[ii] = xss::fp::quiet_NaN(); } else { arr[ii] = 0x7c01; // std::quiet_nan } - nan_count -= 1; } } else { - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { + for (arrsize_t ii = 0; ii < nan_count; ++ii) { if constexpr (xss::fp::is_floating_point_v) { arr[ii] = xss::fp::quiet_NaN(); } else { arr[ii] = 0x7c01; // std::quiet_nan } - nan_count -= 1; } } } @@ -650,7 +661,8 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, // Quicksort routines: template -X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void +xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool trailing_nans = true) { using comparator = typename std::conditional -X86_SIMD_SORT_INLINE void -xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void xss_qselect(T *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool trailing_nans = true) { using comparator = typename std::conditional) { if (UNLIKELY(hasnan)) { - if constexpr (descending) { + if (!trailing_nans) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -753,12 +768,16 @@ xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) // Partial sort methods: template -X86_SIMD_SORT_INLINE void -xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool trailing_nans = true) { if (k == 0) return; - xss_qselect(arr, k - 1, arrsize, hasnan); - xss_qsort(arr, k - 1, hasnan); + xss_qselect(arr, k - 1, arrsize, hasnan, + trailing_nans); + xss_qsort(arr, k - 1, hasnan, trailing_nans); } #define DEFINE_METHODS(ISA, VTYPE) \ @@ -766,11 +785,14 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + bool descending = false, \ + bool trailing_nans = true) \ { \ - if (descending) { xss_qsort(arr, size, hasnan); } \ + if (descending) { \ + xss_qsort(arr, size, hasnan, trailing_nans); \ + } \ else { \ - xss_qsort(arr, size, hasnan); \ + xss_qsort(arr, size, hasnan, trailing_nans); \ } \ } \ template \ @@ -778,11 +800,16 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) arrsize_t k, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + bool descending = false, \ + bool trailing_nans = true) \ { \ - if (descending) { xss_qselect(arr, k, size, hasnan); } \ + if (descending) { \ + xss_qselect(arr, k, size, hasnan, \ + trailing_nans); \ + } \ else { \ - xss_qselect(arr, k, size, hasnan); \ + xss_qselect(arr, k, size, hasnan, \ + trailing_nans); \ } \ } \ template \ @@ -790,13 +817,16 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) arrsize_t k, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + bool descending = false, \ + bool trailing_nans = true) \ { \ if (descending) { \ - xss_partial_qsort(arr, k, size, hasnan); \ + xss_partial_qsort(arr, k, size, hasnan, \ + trailing_nans); \ } \ else { \ - xss_partial_qsort(arr, k, size, hasnan); \ + xss_partial_qsort(arr, k, size, hasnan, \ + trailing_nans); \ } \ } diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index e894a86e..b568d540 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -60,20 +60,25 @@ void IS_ARR_PARTITIONED(std::vector arr, cmp_eq = compare>(); if (!descending) { - cmp_less = compare>(); - cmp_leq = compare>(); - cmp_geq = compare>(); + cmp_less = compare_nan_end>(); + cmp_leq = compare_nan_end>(); + cmp_geq = compare_nan_end>(); } else { - cmp_less = compare>(); - cmp_leq = compare>(); - cmp_geq = compare>(); + cmp_less = compare_nan_end>(); + cmp_leq = compare_nan_end>(); + cmp_geq = compare_nan_end>(); } // 1) arr[k] == sorted[k]; use memcmp to handle nan if (!cmp_eq(arr[k], true_kth)) { REPORT_FAIL("kth element is incorrect", arr.size(), type, k); } + // If arr[k] is NaN, k is in the trailing NaN block; value comparisons are + // not meaningful, so skip the left/right partition checks. + if constexpr (xss::fp::is_floating_point_v) { + if (xss::fp::isnan(arr[k])) return; + } // ( 2) Elements to the left of k should be atmost arr[k] if (k >= 1) { T max_left diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index f2ce3a6b..672ea573 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -76,7 +76,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending) #ifndef XSS_ASAN_CI_NOCHECK std::sort(sortedarr.begin(), sortedarr.end(), - compare>()); + compare_nan_end>()); IS_SORTED(sortedarr, arr, type); #endif arr.clear(); @@ -119,7 +119,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending) #ifndef XSS_ASAN_CI_NOCHECK std::sort(sortedarr.begin(), sortedarr.end(), - compare>()); + compare_nan_end>()); IS_ARG_SORTED(sortedarr, arr, arg, type); #endif arr.clear(); @@ -172,7 +172,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending) std::nth_element(sortedarr.begin(), sortedarr.begin() + k, sortedarr.end(), - compare>()); + compare_nan_end>()); if (size == 0) continue; IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); #endif @@ -248,7 +248,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_descending) #ifndef XSS_ASAN_CI_NOCHECK std::sort(sortedarr.begin(), sortedarr.end(), - compare>()); + compare_nan_end>()); if (size == 0) continue; IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); #endif diff --git a/utils/custom-compare.h b/utils/custom-compare.h index f2c8d61e..80274f46 100644 --- a/utils/custom-compare.h +++ b/utils/custom-compare.h @@ -46,4 +46,68 @@ struct compare_arg { const T *arr; }; +/* + * Comparator that always places NaN at the end of the sorted array, + * regardless of whether Comparator is ascending or descending. + */ +template +struct compare_nan_end { + static constexpr auto op = Comparator {}; + bool operator()(const T a, const T b) const + { + if constexpr (xss::fp::is_floating_point_v) { + bool a_nan = xss::fp::isnan(a); + bool b_nan = xss::fp::isnan(b); + if (!a_nan && !b_nan) { return op(a, b); } + if (a_nan && b_nan) { return false; } + return !a_nan; // b is NaN → a before b → NaN at end + } + else { + return op(a, b); + } + } +}; + +/* + * Comparator that always places NaN at the beginning of the sorted array, + * regardless of whether Comparator is ascending or descending. + */ +template +struct compare_nan_begin { + static constexpr auto op = Comparator {}; + bool operator()(const T a, const T b) const + { + if constexpr (xss::fp::is_floating_point_v) { + bool a_nan = xss::fp::isnan(a); + bool b_nan = xss::fp::isnan(b); + if (!a_nan && !b_nan) { return op(a, b); } + if (a_nan && b_nan) { return false; } + return a_nan; // a is NaN → a before b → NaN at beginning + } + else { + return op(a, b); + } + } +}; + +template +struct compare_arg_nan_end { + compare_arg_nan_end(const T *arr) : arr(arr) {} + bool operator()(const int64_t a, const int64_t b) const + { + return compare_nan_end()(arr[a], arr[b]); + } + const T *arr; +}; + +template +struct compare_arg_nan_begin { + compare_arg_nan_begin(const T *arr) : arr(arr) {} + bool operator()(const int64_t a, const int64_t b) const + { + return compare_nan_begin()(arr[a], arr[b]); + } + const T *arr; +}; + #endif // UTILS_CUSTOM_COMPARE \ No newline at end of file From be7c2da85dba02b6b673253d7b23803929f1a8b7 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 04:01:14 +0000 Subject: [PATCH 02/13] qsort: partition NaNs to end/start instead of replacing with inf --- src/avx512-16bit-qsort.hpp | 72 +++++++++---------- src/avx512fp16-16bit-qsort.hpp | 28 -------- src/xss-common-qsort.h | 128 ++++++++++----------------------- 3 files changed, 70 insertions(+), 158 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 9fd79c47..6efe1a50 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -519,41 +519,25 @@ comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } -template <> -X86_SIMD_SORT_INLINE_ONLY arrsize_t -replace_nan_with_inf>(uint16_t *arr, arrsize_t arrsize) -{ - arrsize_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - for (arrsize_t ii = 0; ii < arrsize; - ii = ii + zmm_vector::numlanes / 2) { - if (arrsize - ii < 16) { - loadmask = (0x0001 << (arrsize - ii)) - 0x0001; - } - __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr); - __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm); - __mmask16 nanmask = _mm512_cmp_ps_mask( - in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF); - arr += 16; - } - return nan_count; -} template [[maybe_unused]] X86_SIMD_SORT_INLINE void -avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) +avx512_qsort_fp16_helper(uint16_t *arr, + arrsize_t arrsize, + arrsize_t index_first_elem, + arrsize_t index_last_elem) { using T = uint16_t; using vtype = zmm_vector; #ifdef XSS_COMPILE_OPENMP - bool use_parallel = arrsize > 100000; + bool use_parallel = (index_last_elem - index_first_elem + 1) > 100000; if (use_parallel) { int thread_count = xss_get_num_threads(); - arrsize_t task_threshold = std::max((arrsize_t)100000, arrsize / 100); + arrsize_t task_threshold = std::max( + (arrsize_t)100000, + (index_last_elem - index_first_elem + 1) / 100); // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems @@ -561,22 +545,25 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) #pragma omp parallel num_threads(thread_count) #pragma omp single qsort_(arr, - 0, - arrsize - 1, + index_first_elem, + index_last_elem, 2 * (arrsize_t)log2(arrsize), task_threshold); } else { qsort_(arr, - 0, - arrsize - 1, + index_first_elem, + index_last_elem, 2 * (arrsize_t)log2(arrsize), std::numeric_limits::max()); } #pragma omp taskwait #else - qsort_( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); + qsort_(arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + 0); #endif } @@ -590,17 +577,26 @@ avx512_qsort_fp16(uint16_t *arr, using vtype = zmm_vector; if (arrsize > 1) { - arrsize_t nan_count = 0; + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf(arr, arrsize); + if (!trailing_nans) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } - if (descending) { - avx512_qsort_fp16_helper>(arr, arrsize); - } - else { - avx512_qsort_fp16_helper>(arr, arrsize); + if (index_first_elem <= index_last_elem && index_last_elem < arrsize) { + if (descending) { + avx512_qsort_fp16_helper>( + arr, arrsize, index_first_elem, index_last_elem); + } + else { + avx512_qsort_fp16_helper>( + arr, arrsize, index_first_elem, index_last_elem); + } } - replace_inf_with_nan(arr, arrsize, nan_count, descending, trailing_nans); } #ifdef __MMX__ diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 6af50148..d4bebe74 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -175,32 +175,4 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<_Float16>(_Float16 elem) return elem != elem; } -template <> -X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, - arrsize_t size, - arrsize_t nan_count, - bool descending, - bool trailing_nans) -{ - if (nan_count == 0) return; - Fp16Bits val; - val.i_ = 0x7c01; - - if (descending && trailing_nans) { - std::rotate(arr, arr + nan_count, arr + size); - } - else if (!descending && !trailing_nans) { - std::rotate(arr, arr + (size - nan_count), arr + size); - } - if (trailing_nans) { - for (arrsize_t ii = size - nan_count; ii < size; ++ii) { - arr[ii] = val.f_; - } - } - else { - for (arrsize_t ii = 0; ii < nan_count; ++ii) { - arr[ii] = val.f_; - } - } -} #endif // AVX512FP16_QSORT_16BIT diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 541f5436..35a78613 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -51,32 +51,6 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); } -template -X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size) -{ - arrsize_t nan_count = 0; - using opmask_t = typename vtype::opmask_t; - using reg_t = typename vtype::reg_t; - opmask_t loadmask; - reg_t in; - /* - * (ii + numlanes) can never overflow: max val of size is 2**63 on 64-bit - * and 2**31 on 32-bit systems - */ - for (arrsize_t ii = 0; ii < size; ii = ii + vtype::numlanes) { - if (size - ii < vtype::numlanes) { - loadmask = vtype::get_partial_loadmask(size - ii); - in = vtype::maskz_loadu(loadmask, arr + ii); - } - else { - in = vtype::loadu(arr + ii); - } - opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); - nan_count += _mm_popcnt_u32(vtype::convert_mask_to_int(nanmask)); - vtype::mask_storeu(arr + ii, nanmask, vtype::zmm_max()); - } - return nan_count; -} template X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size) @@ -104,46 +78,6 @@ X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size) return found_nan; } -template -X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, - arrsize_t size, - arrsize_t nan_count, - bool descending, - bool trailing_nans) -{ - if (nan_count == 0) return; - // After ascending sort +inf lands at the end; after descending at the start. - // When the desired NaN position differs from where +inf landed, rotate first. - if (descending && trailing_nans) { - // +inf at beginning, want NaN at end: rotate left - std::rotate(arr, arr + nan_count, arr + size); - } - else if (!descending && !trailing_nans) { - // +inf at end, want NaN at beginning: rotate right - std::rotate(arr, arr + (size - nan_count), arr + size); - } - // Write NaN at the now-correct position - if (trailing_nans) { - for (arrsize_t ii = size - nan_count; ii < size; ++ii) { - if constexpr (xss::fp::is_floating_point_v) { - arr[ii] = xss::fp::quiet_NaN(); - } - else { - arr[ii] = 0x7c01; // std::quiet_nan - } - } - } - else { - for (arrsize_t ii = 0; ii < nan_count; ++ii) { - if constexpr (xss::fp::is_floating_point_v) { - arr[ii] = xss::fp::quiet_NaN(); - } - else { - arr[ii] = 0x7c01; // std::quiet_nan - } - } - } -} /* * Sort all the NAN's to end of the array and return the index of the last elem @@ -670,49 +604,59 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool trailing_nans = true) Comparator>::type; if (arrsize > 1) { - arrsize_t nan_count = 0; + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf(arr, arrsize); + if (!trailing_nans) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } } UNUSED(hasnan); + if (index_first_elem <= index_last_elem && index_last_elem < arrsize) { #ifdef XSS_COMPILE_OPENMP - bool use_parallel = arrsize > 100000; + bool use_parallel = (index_last_elem - index_first_elem + 1) > 100000; - if (use_parallel) { - int thread_count = xss_get_num_threads(); - arrsize_t task_threshold - = std::max((arrsize_t)100000, arrsize / 100); + if (use_parallel) { + int thread_count = xss_get_num_threads(); + arrsize_t task_threshold = std::max( + (arrsize_t)100000, + (index_last_elem - index_first_elem + 1) / 100); - // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ - // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems - // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays + // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ + // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems + // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays #pragma omp parallel num_threads(thread_count) #pragma omp single - qsort_(arr, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize), - task_threshold); + qsort_(arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + task_threshold); #pragma omp taskwait - } - else { + } + else { + qsort_(arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + std::numeric_limits::max()); + } +#else qsort_(arr, - 0, - arrsize - 1, + index_first_elem, + index_last_elem, 2 * (arrsize_t)log2(arrsize), - std::numeric_limits::max()); - } -#else - qsort_( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); + 0); #endif - - replace_inf_with_nan(arr, arrsize, nan_count, descending, trailing_nans); + } } #ifdef __MMX__ From 8d476a587d29bc21844e9b2d1e1a318495ee9b94 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 04:19:47 +0000 Subject: [PATCH 03/13] argsort: fold descending+trailing_nans logic into std_argsort_withnan --- src/xss-common-argsort.h | 48 +++++++++++----------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index e2267257..f012b52d 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -30,25 +30,25 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr, }); } -/* argsort using std::sort */ +/* argsort using std::sort, handles NaN placement and descending order */ template X86_SIMD_SORT_INLINE void std_argsort_withnan(const T *arr, arrsize_t *arg, arrsize_t left, - arrsize_t right) + arrsize_t right, + bool trailing_nans = true, + bool descending = false) { std::sort(arg + left, arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; + [arr, trailing_nans, descending](arrsize_t a, arrsize_t b) -> bool { + bool a_nan = std::isnan(arr[a]); + bool b_nan = std::isnan(arr[b]); + if (!a_nan && !b_nan) { + return descending ? arr[a] > arr[b] : arr[a] < arr[b]; } + if (a_nan && b_nan) { return false; } + return trailing_nans ? !a_nan : a_nan; }); } @@ -613,30 +613,8 @@ X86_SIMD_SORT_INLINE void xss_argsort(const T *arr, /* simdargsort does not work for float/double arrays with nan */ if constexpr (xss::fp::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - // std_argsort_withnan produces ascending order with NaN at end. - // After optional reverse, NaN is at beginning for descending. - if (descending) { std::reverse(arg, arg + arrsize); } - // Now adjust NaN position if it doesn't match trailing_nans: - // descending=false → NaN at end; descending=true → NaN at beginning - bool nan_currently_at_end = !descending; - if (nan_currently_at_end != trailing_nans) { - arrsize_t nan_count = 0; - for (arrsize_t i = 0; i < arrsize; i++) { - nan_count += is_a_nan(arr[i]); - } - if (trailing_nans) { - // NaN is at beginning, rotate to end - std::rotate(arg, arg + nan_count, arg + arrsize); - } - else { - // NaN is at end, rotate to beginning - std::rotate(arg, - arg + arrsize - nan_count, - arg + arrsize); - } - } - + std_argsort_withnan( + arr, arg, 0, arrsize, trailing_nans, descending); return; } } From 6265c9a07587c65440e00181a1d3cfd62de8cbac Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 04:40:07 +0000 Subject: [PATCH 04/13] tests: add trailing/leading NaN coverage for qsort, argsort, qselect --- tests/test-qsort.cpp | 275 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 275 insertions(+) diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 672ea573..467acd6a 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -128,6 +128,91 @@ TYPED_TEST_P(simdsort, test_argsort_descending) } } +TYPED_TEST_P(simdsort, test_argsort_trailing_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types = {"rand_with_nan", + "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_argsort_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types = {"rand_with_nan", + "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_ARG_SORTED(sortedarr, arr, arg, type); +#endif + } + } + } + } +} + TYPED_TEST_P(simdsort, test_qselect_ascending) { for (auto type : this->arrtype) { @@ -182,6 +267,105 @@ TYPED_TEST_P(simdsort, test_qselect_descending) } } +TYPED_TEST_P(simdsort, test_qselect_trailing_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types = {"rand_with_nan", + "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_qselect_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types = {"rand_with_nan", + "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qselect( + arr.data(), k, arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); +#endif + } + } + } + } +} + TYPED_TEST_P(simdsort, test_argselect) { for (auto type : this->arrtype) { @@ -291,14 +475,105 @@ TYPED_TEST_P(simdsort, test_comparator) } } +TYPED_TEST_P(simdsort, test_qsort_trailing_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types = {"rand_with_nan", + "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_qsort_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types = {"rand_with_nan", + "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize_long) { + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + x86simdsort::qsort( + arr.data(), arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::sort(sortedarr.begin(), + sortedarr.end(), + compare_nan_begin>()); + IS_SORTED(sortedarr, arr, type); +#endif + } + } + } + } +} + REGISTER_TYPED_TEST_SUITE_P(simdsort, test_qsort_ascending, test_qsort_descending, + test_qsort_trailing_nans, + test_qsort_leading_nans, test_argsort_ascending, test_argsort_descending, + test_argsort_trailing_nans, + test_argsort_leading_nans, test_argselect, test_qselect_ascending, test_qselect_descending, + test_qselect_trailing_nans, + test_qselect_leading_nans, test_partial_qsort_ascending, test_partial_qsort_descending, test_comparator); From e9f6e4967df2ebe1b5e39211c332788a556b90b7 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 04:49:31 +0000 Subject: [PATCH 05/13] Clang format files --- lib/x86simdsort-avx2.cpp | 55 +++++++++----- lib/x86simdsort-icl.cpp | 39 ++++++---- lib/x86simdsort-internal.h | 8 ++- lib/x86simdsort-scalar.h | 90 ++++++++++++++--------- lib/x86simdsort-skx.cpp | 55 +++++++++----- lib/x86simdsort-spr.cpp | 13 ++-- lib/x86simdsort.cpp | 130 ++++++++++++++++++++-------------- lib/x86simdsort.h | 16 +++-- src/avx512-16bit-qsort.hpp | 10 ++- src/x86simdsort-static-incl.h | 75 +++++++++++++------- src/xss-common-argsort.h | 26 +++---- src/xss-common-qsort.h | 45 ++++++------ tests/test-qsort.cpp | 74 +++++++++---------- 13 files changed, 379 insertions(+), 257 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 527e5e49..3ac5bb12 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -5,39 +5,56 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending, \ + void qsort(type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ - x86simdsortStatic::qsort(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + x86simdsortStatic::qsort( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + void qselect(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + x86simdsortStatic::qselect( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + void partial_qsort(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + x86simdsortStatic::partial_qsort( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - std::vector argsort(const type *arr, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + std::vector argsort(const type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + return x86simdsortStatic::argsort( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - std::vector argselect(const type *arr, size_t k, size_t arrsize, \ - bool hasnan, bool trailing_nans) \ + std::vector argselect(const type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool trailing_nans) \ { \ - return x86simdsortStatic::argselect(arr, k, arrsize, hasnan, \ - trailing_nans); \ + return x86simdsortStatic::argselect( \ + arr, k, arrsize, hasnan, trailing_nans); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 8d04ee64..27d891fb 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -8,7 +8,10 @@ namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending, + void qsort(uint16_t *arr, + size_t size, + bool hasnan, + bool descending, bool trailing_nans) { x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); @@ -21,8 +24,8 @@ namespace avx512 { bool descending, bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, trailing_nans); } template <> void partial_qsort(uint16_t *arr, @@ -32,11 +35,14 @@ namespace avx512 { bool descending, bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, trailing_nans); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan, bool descending, + void qsort(int16_t *arr, + size_t size, + bool hasnan, + bool descending, bool trailing_nans) { x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); @@ -49,8 +55,8 @@ namespace avx512 { bool descending, bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, trailing_nans); } template <> void partial_qsort(int16_t *arr, @@ -60,14 +66,17 @@ namespace avx512 { bool descending, bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, trailing_nans); } } // namespace avx512 namespace fp16_icl { #ifdef __FLT16_MAX__ template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending, + void qsort(_Float16 *arr, + size_t size, + bool hasnan, + bool descending, bool trailing_nans) { x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); @@ -80,8 +89,8 @@ namespace fp16_icl { bool descending, bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, trailing_nans); } template <> void partial_qsort(_Float16 *arr, @@ -91,8 +100,8 @@ namespace fp16_icl { bool descending, bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, trailing_nans); } #endif } // namespace fp16_icl diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 7bd91afc..93e158af 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -53,9 +53,11 @@ bool descending = false, \ bool trailing_nans = true); \ template \ - XSS_HIDE_SYMBOL std::vector \ - argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false, \ - bool trailing_nans = true); \ + XSS_HIDE_SYMBOL std::vector argselect(const T *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan = false, \ + bool trailing_nans = true); \ } namespace xss { diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 03c23572..ad703da6 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -24,8 +24,8 @@ namespace utils { } } template - decltype(auto) get_cmp_func(bool hasnan, bool reverse, - bool trailing_nans = true) + decltype(auto) + get_cmp_func(bool hasnan, bool reverse, bool trailing_nans = true) { std::function cmp; if (hasnan) { @@ -38,9 +38,7 @@ namespace utils { } } else { - if (reverse == true) { - cmp = compare>(); - } + if (reverse == true) { cmp = compare>(); } else { cmp = compare_nan_begin>(); } @@ -58,79 +56,101 @@ namespace utils { namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed, + void qsort(T *arr, + size_t arrsize, + bool hasnan, + bool reversed, bool trailing_nans) { std::sort(arr, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed, - trailing_nans)); + xss::utils::get_cmp_func(hasnan, reversed, trailing_nans)); } template - void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed, + void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool reversed, bool trailing_nans) { - std::nth_element(arr, - arr + k, - arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed, - trailing_nans)); + std::nth_element( + arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed, trailing_nans)); } template - void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, - bool reversed, bool trailing_nans) + void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool reversed, + bool trailing_nans) { - std::partial_sort(arr, - arr + k, - arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed, - trailing_nans)); + std::partial_sort( + arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed, trailing_nans)); } template - std::vector - argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed, - bool trailing_nans) + std::vector argsort(const T *arr, + size_t arrsize, + bool hasnan, + bool reversed, + bool trailing_nans) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); if (trailing_nans) { if (reversed) { - std::sort(arg.begin(), arg.end(), + std::sort(arg.begin(), + arg.end(), compare_arg_nan_end>(arr)); } else { - std::sort(arg.begin(), arg.end(), + std::sort(arg.begin(), + arg.end(), compare_arg>(arr)); } } else { if (reversed) { - std::sort(arg.begin(), arg.end(), + std::sort(arg.begin(), + arg.end(), compare_arg>(arr)); } else { - std::sort(arg.begin(), arg.end(), + std::sort(arg.begin(), + arg.end(), compare_arg_nan_begin>(arr)); } } return arg; } template - std::vector - argselect(const T *arr, size_t k, size_t arrsize, bool hasnan, - bool trailing_nans) + std::vector argselect(const T *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool trailing_nans) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); if (hasnan && !trailing_nans) { - std::nth_element(arg.begin(), arg.begin() + k, arg.end(), + std::nth_element(arg.begin(), + arg.begin() + k, + arg.end(), compare_arg_nan_begin>(arr)); } else { - std::nth_element(arg.begin(), arg.begin() + k, arg.end(), + std::nth_element(arg.begin(), + arg.begin() + k, + arg.end(), compare_arg>(arr)); } return arg; @@ -139,8 +159,8 @@ namespace scalar { void keyvalue_qsort( T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending) { - std::vector arg = argsort(key, arrsize, hasnan, descending, - true); + std::vector arg + = argsort(key, arrsize, hasnan, descending, true); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index f7b462ae..813c607e 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -5,39 +5,56 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending, \ + void qsort(type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ - x86simdsortStatic::qsort(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + x86simdsortStatic::qsort( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + void qselect(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + x86simdsortStatic::qselect( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + void partial_qsort(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + x86simdsortStatic::partial_qsort( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - std::vector argsort(const type *arr, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + std::vector argsort(const type *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + return x86simdsortStatic::argsort( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } \ template <> \ - std::vector argselect(const type *arr, size_t k, size_t arrsize, \ - bool hasnan, bool trailing_nans) \ + std::vector argselect(const type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool trailing_nans) \ { \ - return x86simdsortStatic::argselect(arr, k, arrsize, hasnan, \ - trailing_nans); \ + return x86simdsortStatic::argselect( \ + arr, k, arrsize, hasnan, trailing_nans); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index 59f0a829..c615d785 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -5,7 +5,10 @@ namespace xss { namespace fp16_spr { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending, + void qsort(_Float16 *arr, + size_t size, + bool hasnan, + bool descending, bool trailing_nans) { x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); @@ -18,8 +21,8 @@ namespace fp16_spr { bool descending, bool trailing_nans) { - x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::qselect( + arr, k, arrsize, hasnan, descending, trailing_nans); } template <> void partial_qsort(_Float16 *arr, @@ -29,8 +32,8 @@ namespace fp16_spr { bool descending, bool trailing_nans) { - x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending, - trailing_nans); + x86simdsortStatic::partial_qsort( + arr, k, arrsize, hasnan, descending, trailing_nans); } } // namespace fp16_spr } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 62b2f8b9..0da4cc47 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -104,44 +104,53 @@ namespace x86simdsort { static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qsort(TYPE *arr, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + void XSS_EXPORT_SYMBOL qsort(TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ if (internal_qsort##TYPE == NULL) { CAT(resolve_qsort, TYPE)(); } \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + (*internal_qsort##TYPE)( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ static void CAT(resolve_qselect, TYPE)(void); \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool, \ - bool) \ + static void (*internal_qselect##TYPE)( \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qselect(TYPE *arr, size_t k, size_t arrsize, \ - bool hasnan, bool descending, \ + void XSS_EXPORT_SYMBOL qselect(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ if (internal_qselect##TYPE == NULL) { CAT(resolve_qselect, TYPE)(); } \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + (*internal_qselect##TYPE)( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ static void CAT(resolve_partial_qsort, TYPE)(void); \ - static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool, \ - bool, bool) \ + static void (*internal_partial_qsort##TYPE)( \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, size_t k, size_t arrsize, \ - bool hasnan, bool descending, \ + void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ if (internal_partial_qsort##TYPE == NULL) { \ CAT(resolve_partial_qsort, TYPE)(); \ } \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + (*internal_partial_qsort##TYPE)( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ @@ -150,13 +159,15 @@ namespace x86simdsort { const TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argsort( \ - const TYPE *arr, size_t arrsize, bool hasnan, bool descending, \ - bool trailing_nans) \ + std::vector XSS_EXPORT_SYMBOL argsort(const TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ if (internal_argsort##TYPE == NULL) { CAT(resolve_argsort, TYPE)(); } \ - return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + return (*internal_argsort##TYPE)( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ @@ -165,15 +176,17 @@ namespace x86simdsort { const TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argselect( \ - const TYPE *arr, size_t k, size_t arrsize, bool hasnan, \ - bool trailing_nans) \ + std::vector XSS_EXPORT_SYMBOL argselect(const TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool trailing_nans) \ { \ if (internal_argselect##TYPE == NULL) { \ CAT(resolve_argselect, TYPE)(); \ } \ - return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan, \ - trailing_nans); \ + return (*internal_argselect##TYPE)( \ + arr, k, arrsize, hasnan, trailing_nans); \ } #else @@ -182,37 +195,46 @@ namespace x86simdsort { static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qsort(TYPE *arr, size_t arrsize, bool hasnan, \ - bool descending, bool trailing_nans) \ + void XSS_EXPORT_SYMBOL qsort(TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + (*internal_qsort##TYPE)( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool, \ - bool) \ + static void (*internal_qselect##TYPE)( \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL qselect(TYPE *arr, size_t k, size_t arrsize, \ - bool hasnan, bool descending, \ + void XSS_EXPORT_SYMBOL qselect(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + (*internal_qselect##TYPE)( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ - static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool, \ - bool, bool) \ + static void (*internal_partial_qsort##TYPE)( \ + TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ - void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, size_t k, size_t arrsize, \ - bool hasnan, bool descending, \ + void XSS_EXPORT_SYMBOL partial_qsort(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending, \ - trailing_nans); \ + (*internal_partial_qsort##TYPE)( \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ @@ -220,12 +242,14 @@ namespace x86simdsort { const TYPE *, size_t, bool, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argsort( \ - const TYPE *arr, size_t arrsize, bool hasnan, bool descending, \ - bool trailing_nans) \ + std::vector XSS_EXPORT_SYMBOL argsort(const TYPE *arr, \ + size_t arrsize, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ - return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending, \ - trailing_nans); \ + return (*internal_argsort##TYPE)( \ + arr, arrsize, hasnan, descending, trailing_nans); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ @@ -233,12 +257,14 @@ namespace x86simdsort { const TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - std::vector XSS_EXPORT_SYMBOL argselect( \ - const TYPE *arr, size_t k, size_t arrsize, bool hasnan, \ - bool trailing_nans) \ + std::vector XSS_EXPORT_SYMBOL argselect(const TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool trailing_nans) \ { \ - return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan, \ - trailing_nans); \ + return (*internal_argselect##TYPE)( \ + arr, k, arrsize, hasnan, trailing_nans); \ } #endif // _MSC_VER diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 80c6ce67..b9e94b2e 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -19,9 +19,11 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void -qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true); +XSS_EXPORT_SYMBOL void qsort(T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false, + bool trailing_nans = true); // quickselect template @@ -51,9 +53,11 @@ XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, // argselect template -XSS_EXPORT_SYMBOL std::vector -argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool trailing_nans = true); +XSS_EXPORT_SYMBOL std::vector argselect(const T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool trailing_nans = true); // keyvalue sort template diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 6efe1a50..7ae78778 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -519,7 +519,6 @@ comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } - template [[maybe_unused]] X86_SIMD_SORT_INLINE void avx512_qsort_fp16_helper(uint16_t *arr, @@ -535,9 +534,9 @@ avx512_qsort_fp16_helper(uint16_t *arr, if (use_parallel) { int thread_count = xss_get_num_threads(); - arrsize_t task_threshold = std::max( - (arrsize_t)100000, - (index_last_elem - index_first_elem + 1) / 100); + arrsize_t task_threshold + = std::max((arrsize_t)100000, + (index_last_elem - index_first_elem + 1) / 100); // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems @@ -664,8 +663,7 @@ avx512_partial_qsort_fp16(uint16_t *arr, bool trailing_nans = true) { if (k == 0) return; - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending, - trailing_nans); + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending, trailing_nans); avx512_qsort_fp16(arr, k - 1, hasnan, descending, trailing_nans); } #endif // AVX512_QSORT_16BIT diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index a06ec15a..d958c515 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -7,9 +7,11 @@ // Supported methods declared here for a quick reference: namespace x86simdsortStatic { template -X86_SIMD_SORT_FINLINE void -qsort(T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); +X86_SIMD_SORT_FINLINE void qsort(T *arr, + size_t size, + bool hasnan = false, + bool descending = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE void qselect(T *arr, @@ -44,15 +46,20 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr, bool trailing_nans = true); template -X86_SIMD_SORT_FINLINE std::vector -argselect(const T *arr, size_t k, size_t size, bool hasnan = false, - bool trailing_nans = true); +X86_SIMD_SORT_FINLINE std::vector argselect(const T *arr, + size_t k, + size_t size, + bool hasnan = false, + bool trailing_nans = true); /* argselect API required by NumPy: */ template -void X86_SIMD_SORT_FINLINE argselect( - const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false, - bool trailing_nans = true); +void X86_SIMD_SORT_FINLINE argselect(const T *arr, + size_t *arg, + size_t k, + size_t size, + bool hasnan = false, + bool trailing_nans = true); template X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, @@ -81,22 +88,31 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, #define XSS_METHODS(ISA) \ template \ - X86_SIMD_SORT_FINLINE void x86simdsortStatic::qsort( \ - T *arr, size_t size, bool hasnan, bool descending, \ - bool trailing_nans) \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::qsort(T *arr, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ ISA##_qsort(arr, size, hasnan, descending, trailing_nans); \ } \ template \ - X86_SIMD_SORT_FINLINE void x86simdsortStatic::qselect( \ - T *arr, size_t k, size_t size, bool hasnan, bool descending, \ - bool trailing_nans) \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::qselect(T *arr, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending, \ + bool trailing_nans) \ { \ ISA##_qselect(arr, k, size, hasnan, descending, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::partial_qsort( \ - T *arr, size_t k, size_t size, bool hasnan, bool descending, \ + T *arr, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ ISA##_partial_qsort(arr, k, size, hasnan, descending, trailing_nans); \ @@ -113,7 +129,10 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ - const T *arr, size_t size, bool hasnan, bool descending, \ + const T *arr, \ + size_t size, \ + bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ std::vector indices(size); \ @@ -124,14 +143,21 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argselect( \ - const T *arr, size_t *arg, size_t k, size_t size, bool hasnan, \ + const T *arr, \ + size_t *arg, \ + size_t k, \ + size_t size, \ + bool hasnan, \ bool trailing_nans) \ { \ ISA##_argselect(arr, arg, k, size, hasnan, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argselect( \ - const T *arr, size_t k, size_t size, bool hasnan, \ + const T *arr, \ + size_t k, \ + size_t size, \ + bool hasnan, \ bool trailing_nans) \ { \ std::vector indices(size); \ @@ -203,8 +229,7 @@ void x86simdsortStatic::qsort<_Float16>(_Float16 *arr, bool descending, bool trailing_nans) { - avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending, - trailing_nans); + avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending, trailing_nans); } template <> [[maybe_unused]] @@ -215,8 +240,8 @@ void x86simdsortStatic::qselect<_Float16>(_Float16 *arr, bool descending, bool trailing_nans) { - avx512_qselect_fp16((uint16_t *)arr, k, size, hasnan, descending, - trailing_nans); + avx512_qselect_fp16( + (uint16_t *)arr, k, size, hasnan, descending, trailing_nans); } template <> [[maybe_unused]] @@ -227,8 +252,8 @@ void x86simdsortStatic::partial_qsort<_Float16>(_Float16 *arr, bool descending, bool trailing_nans) { - avx512_partial_qsort_fp16((uint16_t *)arr, k, size, hasnan, descending, - trailing_nans); + avx512_partial_qsort_fp16( + (uint16_t *)arr, k, size, hasnan, descending, trailing_nans); } #endif diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index f012b52d..dfe331e4 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -39,17 +39,18 @@ X86_SIMD_SORT_INLINE void std_argsort_withnan(const T *arr, bool trailing_nans = true, bool descending = false) { - std::sort(arg + left, - arg + right, - [arr, trailing_nans, descending](arrsize_t a, arrsize_t b) -> bool { - bool a_nan = std::isnan(arr[a]); - bool b_nan = std::isnan(arr[b]); - if (!a_nan && !b_nan) { - return descending ? arr[a] > arr[b] : arr[a] < arr[b]; - } - if (a_nan && b_nan) { return false; } - return trailing_nans ? !a_nan : a_nan; - }); + std::sort( + arg + left, + arg + right, + [arr, trailing_nans, descending](arrsize_t a, arrsize_t b) -> bool { + bool a_nan = std::isnan(arr[a]); + bool b_nan = std::isnan(arr[b]); + if (!a_nan && !b_nan) { + return descending ? arr[a] > arr[b] : arr[a] < arr[b]; + } + if (a_nan && b_nan) { return false; } + return trailing_nans ? !a_nan : a_nan; + }); } /* argsort using std::sort */ @@ -719,8 +720,7 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, if (arrsize > 1) { if constexpr (xss::fp::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan( - arr, arg, k, 0, arrsize, trailing_nans); + std_argselect_withnan(arr, arg, k, 0, arrsize, trailing_nans); return; } } diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 35a78613..6fe34bbd 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -51,7 +51,6 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); } - template X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size) { @@ -78,7 +77,6 @@ X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size) return found_nan; } - /* * Sort all the NAN's to end of the array and return the index of the last elem * in the array which is not a nan @@ -609,7 +607,8 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool trailing_nans = true) if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { if (!trailing_nans) { - index_first_elem = move_nans_to_start_of_array(arr, arrsize); + index_first_elem + = move_nans_to_start_of_array(arr, arrsize); } else { index_last_elem = move_nans_to_end_of_array(arr, arrsize); @@ -622,7 +621,8 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool trailing_nans = true) if (index_first_elem <= index_last_elem && index_last_elem < arrsize) { #ifdef XSS_COMPILE_OPENMP - bool use_parallel = (index_last_elem - index_first_elem + 1) > 100000; + bool use_parallel + = (index_last_elem - index_first_elem + 1) > 100000; if (use_parallel) { int thread_count = xss_get_num_threads(); @@ -643,11 +643,12 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool trailing_nans = true) #pragma omp taskwait } else { - qsort_(arr, - index_first_elem, - index_last_elem, - 2 * (arrsize_t)log2(arrsize), - std::numeric_limits::max()); + qsort_( + arr, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize), + std::numeric_limits::max()); } #else qsort_(arr, @@ -713,14 +714,14 @@ X86_SIMD_SORT_INLINE void xss_qselect(T *arr, // Partial sort methods: template X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, - arrsize_t k, - arrsize_t arrsize, - bool hasnan, - bool trailing_nans = true) + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool trailing_nans = true) { if (k == 0) return; - xss_qselect(arr, k - 1, arrsize, hasnan, - trailing_nans); + xss_qselect( + arr, k - 1, arrsize, hasnan, trailing_nans); xss_qsort(arr, k - 1, hasnan, trailing_nans); } @@ -748,12 +749,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, bool trailing_nans = true) \ { \ if (descending) { \ - xss_qselect(arr, k, size, hasnan, \ - trailing_nans); \ + xss_qselect(arr, k, size, hasnan, trailing_nans); \ } \ else { \ - xss_qselect(arr, k, size, hasnan, \ - trailing_nans); \ + xss_qselect(arr, k, size, hasnan, trailing_nans); \ } \ } \ template \ @@ -765,12 +764,12 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, bool trailing_nans = true) \ { \ if (descending) { \ - xss_partial_qsort(arr, k, size, hasnan, \ - trailing_nans); \ + xss_partial_qsort( \ + arr, k, size, hasnan, trailing_nans); \ } \ else { \ - xss_partial_qsort(arr, k, size, hasnan, \ - trailing_nans); \ + xss_partial_qsort( \ + arr, k, size, hasnan, trailing_nans); \ } \ } diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 467acd6a..b4baeb05 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -131,8 +131,8 @@ TYPED_TEST_P(simdsort, test_argsort_descending) TYPED_TEST_P(simdsort, test_argsort_trailing_nans) { if constexpr (xss::fp::is_floating_point_v) { - std::vector nan_types = {"rand_with_nan", - "rand_with_max_and_nan"}; + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; for (auto type : nan_types) { for (auto size : this->arrsize_long) { std::vector base = get_array(type, size); @@ -144,9 +144,10 @@ TYPED_TEST_P(simdsort, test_argsort_trailing_nans) auto arg = x86simdsort::argsort( arr.data(), arr.size(), true, false, true); #ifndef XSS_ASAN_CI_NOCHECK - std::sort(sortedarr.begin(), - sortedarr.end(), - compare_nan_end>()); + std::sort( + sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); IS_ARG_SORTED(sortedarr, arr, arg, type); #endif } @@ -161,7 +162,7 @@ TYPED_TEST_P(simdsort, test_argsort_trailing_nans) std::sort(sortedarr.begin(), sortedarr.end(), compare_nan_end>()); + std::greater>()); IS_ARG_SORTED(sortedarr, arr, arg, type); #endif } @@ -173,8 +174,8 @@ TYPED_TEST_P(simdsort, test_argsort_trailing_nans) TYPED_TEST_P(simdsort, test_argsort_leading_nans) { if constexpr (xss::fp::is_floating_point_v) { - std::vector nan_types = {"rand_with_nan", - "rand_with_max_and_nan"}; + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; for (auto type : nan_types) { for (auto size : this->arrsize_long) { std::vector base = get_array(type, size); @@ -189,7 +190,7 @@ TYPED_TEST_P(simdsort, test_argsort_leading_nans) std::sort(sortedarr.begin(), sortedarr.end(), compare_nan_begin>()); + std::less>()); IS_ARG_SORTED(sortedarr, arr, arg, type); #endif } @@ -204,7 +205,7 @@ TYPED_TEST_P(simdsort, test_argsort_leading_nans) std::sort(sortedarr.begin(), sortedarr.end(), compare_nan_begin>()); + std::greater>()); IS_ARG_SORTED(sortedarr, arr, arg, type); #endif } @@ -254,10 +255,11 @@ TYPED_TEST_P(simdsort, test_qselect_descending) x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); #ifndef XSS_ASAN_CI_NOCHECK - std::nth_element(sortedarr.begin(), - sortedarr.begin() + k, - sortedarr.end(), - compare_nan_end>()); + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); if (size == 0) continue; IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); #endif @@ -270,8 +272,8 @@ TYPED_TEST_P(simdsort, test_qselect_descending) TYPED_TEST_P(simdsort, test_qselect_trailing_nans) { if constexpr (xss::fp::is_floating_point_v) { - std::vector nan_types = {"rand_with_nan", - "rand_with_max_and_nan"}; + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; for (auto type : nan_types) { for (auto size : this->arrsize) { size_t k = size != 0 ? rand() % size : 0; @@ -306,7 +308,7 @@ TYPED_TEST_P(simdsort, test_qselect_trailing_nans) sortedarr.begin() + k, sortedarr.end(), compare_nan_end>()); + std::greater>()); if (size == 0) continue; IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); #endif @@ -319,8 +321,8 @@ TYPED_TEST_P(simdsort, test_qselect_trailing_nans) TYPED_TEST_P(simdsort, test_qselect_leading_nans) { if constexpr (xss::fp::is_floating_point_v) { - std::vector nan_types = {"rand_with_nan", - "rand_with_max_and_nan"}; + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; for (auto type : nan_types) { for (auto size : this->arrsize) { size_t k = size != 0 ? rand() % size : 0; @@ -333,12 +335,11 @@ TYPED_TEST_P(simdsort, test_qselect_leading_nans) x86simdsort::qselect( arr.data(), k, arr.size(), true, false, false); #ifndef XSS_ASAN_CI_NOCHECK - std::nth_element( - sortedarr.begin(), - sortedarr.begin() + k, - sortedarr.end(), - compare_nan_begin>()); + std::nth_element(sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); if (size == 0) continue; IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); #endif @@ -356,7 +357,7 @@ TYPED_TEST_P(simdsort, test_qselect_leading_nans) sortedarr.begin() + k, sortedarr.end(), compare_nan_begin>()); + std::greater>()); if (size == 0) continue; IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); #endif @@ -478,8 +479,8 @@ TYPED_TEST_P(simdsort, test_comparator) TYPED_TEST_P(simdsort, test_qsort_trailing_nans) { if constexpr (xss::fp::is_floating_point_v) { - std::vector nan_types = {"rand_with_nan", - "rand_with_max_and_nan"}; + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; for (auto type : nan_types) { for (auto size : this->arrsize_long) { std::vector base = get_array(type, size); @@ -491,9 +492,10 @@ TYPED_TEST_P(simdsort, test_qsort_trailing_nans) x86simdsort::qsort( arr.data(), arr.size(), true, false, true); #ifndef XSS_ASAN_CI_NOCHECK - std::sort(sortedarr.begin(), - sortedarr.end(), - compare_nan_end>()); + std::sort( + sortedarr.begin(), + sortedarr.end(), + compare_nan_end>()); IS_SORTED(sortedarr, arr, type); #endif } @@ -508,7 +510,7 @@ TYPED_TEST_P(simdsort, test_qsort_trailing_nans) std::sort(sortedarr.begin(), sortedarr.end(), compare_nan_end>()); + std::greater>()); IS_SORTED(sortedarr, arr, type); #endif } @@ -520,8 +522,8 @@ TYPED_TEST_P(simdsort, test_qsort_trailing_nans) TYPED_TEST_P(simdsort, test_qsort_leading_nans) { if constexpr (xss::fp::is_floating_point_v) { - std::vector nan_types = {"rand_with_nan", - "rand_with_max_and_nan"}; + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; for (auto type : nan_types) { for (auto size : this->arrsize_long) { std::vector base = get_array(type, size); @@ -536,7 +538,7 @@ TYPED_TEST_P(simdsort, test_qsort_leading_nans) std::sort(sortedarr.begin(), sortedarr.end(), compare_nan_begin>()); + std::less>()); IS_SORTED(sortedarr, arr, type); #endif } @@ -551,7 +553,7 @@ TYPED_TEST_P(simdsort, test_qsort_leading_nans) std::sort(sortedarr.begin(), sortedarr.end(), compare_nan_begin>()); + std::greater>()); IS_SORTED(sortedarr, arr, type); #endif } From edce53272ad69e8c10c45811035aabff5988dcab Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 04:52:06 +0000 Subject: [PATCH 06/13] docs: update NaN/trailing_nans docs in README files --- README.md | 6 ++---- src/README.md | 25 ++++++++++++++----------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index da10e51a..44c719d7 100644 --- a/README.md +++ b/README.md @@ -183,10 +183,8 @@ is controlled by the optional `bool trailing_nans` parameter (default `true`): - `trailing_nans=false`: NaN values are placed at the **beginning** of the result, regardless of sort direction. -Note that `qsort` will replace all NaN values with `std::numeric_limits::quiet_NaN`; -the original bit-exact NaN payload is not preserved. Also note that the arg -methods (argsort and argselect) will not use the SIMD based algorithms if they -detect NaN values in the array. You can read details of all the implementations +Note that the arg methods (argsort and argselect) will not use the SIMD based +algorithms if they detect NaN values in the array. You can read details of all the implementations [here](https://github.com/intel/x86-simd-sort/blob/main/src/README.md). ## Performance comparison on AVX-512: `object_qsort` v/s `std::sort` diff --git a/src/README.md b/src/README.md index ad5fc7ba..57d3e12e 100644 --- a/src/README.md +++ b/src/README.md @@ -18,13 +18,14 @@ Equivalent to `qsort` in `std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort). ```cpp -void x86simdsortStatic::qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `arr` contains -NaNs, they are moved to the end and replaced with a quiet NaN. That is, the -original, bit-exact NaNs in the input are not preserved. +NaNs, their placement is controlled by `trailing_nans`: `true` (default) places +NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads +are preserved. #### Quickselect Equivalent to `std::nth_element` in @@ -34,13 +35,14 @@ Equivalent to `std::nth_element` in ```cpp -void x86simdsortStatic::qselect(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::qselect(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `bool hasnan` is -set, NaNs are moved to the end of the array, preserving the bit-exact NaNs in -the input. If NaNs are present but `hasnan` is `false`, the behavior is +set, NaN placement is controlled by `trailing_nans`: `true` (default) places +NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads +are preserved. If NaNs are present but `hasnan` is `false`, the behavior is undefined. #### Partialsort @@ -49,13 +51,14 @@ Equivalent to `std::partial_sort` in ```cpp -void x86simdsortStatic::partial_qsort(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false) +void x86simdsortStatic::partial_qsort(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `bool hasnan` is -set, NaNs are moved to the end of the array, preserving the bit-exact NaNs in -the input. If NaNs are present but `hasnan` is `false`, the behavior is +set, NaN placement is controlled by `trailing_nans`: `true` (default) places +NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads +are preserved. If NaNs are present but `hasnan` is `false`, the behavior is undefined. #### Argsort @@ -63,7 +66,7 @@ Equivalent to `np.argsort` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html). ```cpp -void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. @@ -75,7 +78,7 @@ Equivalent to `np.argselect` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html). ```cpp -void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false); +void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false, bool trailing_nans = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. From 05747989ffe46fcfdea6152b4afb2ec089a65017 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 05:31:04 +0000 Subject: [PATCH 07/13] Add descending parameter to argselect - Add bool descending = false parameter to all argselect APIs and internal implementations (avx512, avx2, scalar, static, dispatch) - In xss_argselect: reverse the index array after partitioning when descending is true; pass descending to std_argselect_withnan for NaN-array fast path - std_argselect_withnan: honour descending flag in the comparator - scalar argselect: pick comparator once based on descending and trailing_nans, then call nth_element once - Tests: replace test_argselect with four typed tests covering ascending, descending, trailing NaNs and leading NaNs - IS_ARG_PARTITIONED: add descending parameter forwarded to IS_ARR_PARTITIONED --- lib/x86simdsort-avx2.cpp | 3 +- lib/x86simdsort-internal.h | 1 + lib/x86simdsort-scalar.h | 17 +++-- lib/x86simdsort-skx.cpp | 3 +- lib/x86simdsort.cpp | 10 +-- lib/x86simdsort.h | 1 + src/x86simdsort-static-incl.h | 8 ++- src/xss-common-argsort.h | 38 ++++++---- tests/.test-qsort.cpp.swp | Bin 0 -> 16384 bytes tests/test-qsort-common.h | 5 +- tests/test-qsort.cpp | 132 +++++++++++++++++++++++++++++++++- 11 files changed, 183 insertions(+), 35 deletions(-) create mode 100644 tests/.test-qsort.cpp.swp diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 3ac5bb12..a5ae0ae8 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -51,10 +51,11 @@ size_t k, \ size_t arrsize, \ bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ return x86simdsortStatic::argselect( \ - arr, k, arrsize, hasnan, trailing_nans); \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 93e158af..7d589686 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -57,6 +57,7 @@ size_t k, \ size_t arrsize, \ bool hasnan = false, \ + bool descending = false, \ bool trailing_nans = true); \ } diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index ad703da6..3964e834 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -136,23 +136,22 @@ namespace scalar { size_t k, size_t arrsize, bool hasnan, + bool descending, bool trailing_nans) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - if (hasnan && !trailing_nans) { - std::nth_element(arg.begin(), - arg.begin() + k, - arg.end(), - compare_arg_nan_begin>(arr)); + std::function cmp; + if (trailing_nans) { + if (descending) { cmp = compare_arg_nan_end>(arr); } + else { cmp = compare_arg>(arr); } } else { - std::nth_element(arg.begin(), - arg.begin() + k, - arg.end(), - compare_arg>(arr)); + if (descending) { cmp = compare_arg>(arr); } + else { cmp = compare_arg_nan_begin>(arr); } } + std::nth_element(arg.begin(), arg.begin() + k, arg.end(), cmp); return arg; } template diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 813c607e..eaf329a3 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -51,10 +51,11 @@ size_t k, \ size_t arrsize, \ bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ return x86simdsortStatic::argselect( \ - arr, k, arrsize, hasnan, trailing_nans); \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 0da4cc47..0c53b4d1 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -173,20 +173,21 @@ namespace x86simdsort { #define DECLARE_INTERNAL_argselect(TYPE) \ static void CAT(resolve_argselect, TYPE)(void); \ static std::vector (*internal_argselect##TYPE)( \ - const TYPE *, size_t, size_t, bool, bool) \ + const TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ std::vector XSS_EXPORT_SYMBOL argselect(const TYPE *arr, \ size_t k, \ size_t arrsize, \ bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ if (internal_argselect##TYPE == NULL) { \ CAT(resolve_argselect, TYPE)(); \ } \ return (*internal_argselect##TYPE)( \ - arr, k, arrsize, hasnan, trailing_nans); \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #else @@ -254,17 +255,18 @@ namespace x86simdsort { #define DECLARE_INTERNAL_argselect(TYPE) \ static std::vector (*internal_argselect##TYPE)( \ - const TYPE *, size_t, size_t, bool, bool) \ + const TYPE *, size_t, size_t, bool, bool, bool) \ = NULL; \ template <> \ std::vector XSS_EXPORT_SYMBOL argselect(const TYPE *arr, \ size_t k, \ size_t arrsize, \ bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ return (*internal_argselect##TYPE)( \ - arr, k, arrsize, hasnan, trailing_nans); \ + arr, k, arrsize, hasnan, descending, trailing_nans); \ } #endif // _MSC_VER diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index b9e94b2e..6f46e05f 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -57,6 +57,7 @@ XSS_EXPORT_SYMBOL std::vector argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false, + bool descending = false, bool trailing_nans = true); // keyvalue sort diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index d958c515..2aa33238 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -50,6 +50,7 @@ X86_SIMD_SORT_FINLINE std::vector argselect(const T *arr, size_t k, size_t size, bool hasnan = false, + bool descending = false, bool trailing_nans = true); /* argselect API required by NumPy: */ @@ -59,6 +60,7 @@ void X86_SIMD_SORT_FINLINE argselect(const T *arr, size_t k, size_t size, bool hasnan = false, + bool descending = false, bool trailing_nans = true); template @@ -148,9 +150,10 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t k, \ size_t size, \ bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ - ISA##_argselect(arr, arg, k, size, hasnan, trailing_nans); \ + ISA##_argselect(arr, arg, k, size, hasnan, descending, trailing_nans); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argselect( \ @@ -158,12 +161,13 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t k, \ size_t size, \ bool hasnan, \ + bool descending, \ bool trailing_nans) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ x86simdsortStatic::argselect( \ - arr, indices.data(), k, size, hasnan, trailing_nans); \ + arr, indices.data(), k, size, hasnan, descending, trailing_nans); \ return indices; \ } \ template \ diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index dfe331e4..885fb14d 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -16,18 +16,22 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr, arrsize_t k, arrsize_t left, arrsize_t right, - bool trailing_nans = true) + bool trailing_nans = true, + bool descending = false) { - std::nth_element(arg + left, - arg + k, - arg + right, - [arr, trailing_nans](arrsize_t a, arrsize_t b) -> bool { - bool a_nan = std::isnan(arr[a]); - bool b_nan = std::isnan(arr[b]); - if (!a_nan && !b_nan) { return arr[a] < arr[b]; } - if (a_nan && b_nan) { return false; } - return trailing_nans ? !a_nan : a_nan; - }); + std::nth_element( + arg + left, + arg + k, + arg + right, + [arr, trailing_nans, descending](arrsize_t a, arrsize_t b) -> bool { + bool a_nan = std::isnan(arr[a]); + bool b_nan = std::isnan(arr[b]); + if (!a_nan && !b_nan) { + return descending ? arr[a] > arr[b] : arr[a] < arr[b]; + } + if (a_nan && b_nan) { return false; } + return trailing_nans ? !a_nan : a_nan; + }); } /* argsort using std::sort, handles NaN placement and descending order */ @@ -705,6 +709,7 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, + bool descending = false, bool trailing_nans = true) { /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ @@ -720,13 +725,16 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, if (arrsize > 1) { if constexpr (xss::fp::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize, trailing_nans); + std_argselect_withnan( + arr, arg, k, 0, arrsize, trailing_nans, descending); return; } } UNUSED(hasnan); argselect_( arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + if (descending) { std::reverse(arg, arg + arrsize); } } #ifdef __MMX__ @@ -741,10 +749,11 @@ X86_SIMD_SORT_INLINE void avx512_argselect(const T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, + bool descending = false, bool trailing_nans = true) { xss_argselect( - arr, arg, k, arrsize, hasnan, trailing_nans); + arr, arg, k, arrsize, hasnan, descending, trailing_nans); } template @@ -753,10 +762,11 @@ X86_SIMD_SORT_INLINE void avx2_argselect(const T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false, + bool descending = false, bool trailing_nans = true) { xss_argselect( - arr, arg, k, arrsize, hasnan, trailing_nans); + arr, arg, k, arrsize, hasnan, descending, trailing_nans); } #endif // XSS_COMMON_ARGSORT diff --git a/tests/.test-qsort.cpp.swp b/tests/.test-qsort.cpp.swp new file mode 100644 index 0000000000000000000000000000000000000000..71a71009ac435f623262c05551053caf4600dc24 GIT binary patch literal 16384 zcmeI2O^*~s7{?32HxM*nVmw(c2WQF7%gzE~yDSVd6CHJT2WC*j#I|<2cBZALd$zj< zcU=LMgC3L*V1fyW7d?0}@n96gfp~?G=)veg;|1c&ClLSDFEc&E&N_P_fL15J>6d!y znW}$1U0qe(bD0-Q)AXs-IKlQ1As=4c_WGI9I`ZC5LR?mFw2yP{%CuFs(t+t@`K3(Z zxuZ|+TE&c<>^Mv}ms%EEs=H3xHqwiGb|*cvX(T)~dG*P}Io$nNhjT@TZL?90DCJijO0YhlJHe|TB#lmbctrGQdEDWDWk z3Md7X0!jg;fKosypcJ?X3NV9^Ec|Yf3jlckAI<--JV?l=;3Sv<(;x+s;1Tfq7D9dl z--AowBXATH!1c|9`~W@&r@>ob8N3Siff29){Q3YP=fTI|6nF)sKmyzaE^b18@HRLF z7ud<)KlGvEU-5B37to}0niHX3g8j!P4) zEgb5ZZ*rAEnDAGM{rCi#C{NGn?+Erl;dTm)~(D#8QioKJ6QJ zeaee*pS}^?EtBfYw=U0ggL^e@8zzcC$O+$d@GDEx zJ(6e0lB!vpxrwk0@}k^Ok#53v>203)3E@wvqUR1hqgQ4RRf~m${9A4y&zQ_8f9fbx zw`TF2!Ew!LwwTK&t1v1vmo>+zCk#!qxaV~*PE937`wG@|)%nvDikBE2Bi$q#IFlCL zvZdyD-L$3Qc5~h~_&hyQK`kn|89iUpXJ+%y7W2>dIPu#TbwQKi3!F%8O$&uiNRbRC z7<~GI$)=abK|`b9J>LBK9p^RSxUm}2PBri^a2SL_RFwgFUE{rIt*dnyv2~X-!Q)MOYlF5Y zQvL74{v}Rvr@Cz!EhMhyfXM5#!J4CWj(O;3`rwJKL%a3!{8VvDh@Ce5xj{H)CH%L zvfXUfa4Z4lIBi*bjUg?A4|ePzopolmbctrGQdEDWDWk3Md7X z0!o29PJz-)rJ9?mmU38{TVbtk(evc{rsz%N!3tiw$u)v1mIb{2CDv%kqfiiN;}N3;tE-06x+jmBCCsG&87Rm{vR`ntlt0t literal 0 HcmV?d00001 diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index b568d540..598aee49 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -113,13 +113,14 @@ void IS_ARG_PARTITIONED(std::vector arr, std::vector arg, T true_kth, size_t k, - std::string type) + std::string type, + bool descending = false) { EXPECT_UNIQUE(arg) std::vector part_arr; for (auto ii : arg) { part_arr.push_back(arr[ii]); } - IS_ARR_PARTITIONED(part_arr, k, true_kth, type); + IS_ARR_PARTITIONED(part_arr, k, true_kth, type, descending); } #endif diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index b4baeb05..a165211e 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -367,7 +367,7 @@ TYPED_TEST_P(simdsort, test_qselect_leading_nans) } } -TYPED_TEST_P(simdsort, test_argselect) +TYPED_TEST_P(simdsort, test_argselect_ascending) { for (auto type : this->arrtype) { bool hasnan = is_nan_test(type); @@ -391,6 +391,131 @@ TYPED_TEST_P(simdsort, test_argselect) } } +TYPED_TEST_P(simdsort, test_argselect_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = is_nan_test(type); + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector arr = get_array(type, size); + std::vector sortedarr = arr; + + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), hasnan, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type, true); +#endif + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_argselect_trailing_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, false, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type); +#endif + } + + // descending, NaNs at end + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, true, true); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_end>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type, true); +#endif + } + } + } + } +} + +TYPED_TEST_P(simdsort, test_argselect_leading_nans) +{ + if constexpr (xss::fp::is_floating_point_v) { + std::vector nan_types + = {"rand_with_nan", "rand_with_max_and_nan"}; + for (auto type : nan_types) { + for (auto size : this->arrsize) { + size_t k = size != 0 ? rand() % size : 0; + std::vector base = get_array(type, size); + + // ascending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, false, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type); +#endif + } + + // descending, NaNs at start + { + std::vector arr = base; + std::vector sortedarr = base; + auto arg = x86simdsort::argselect( + arr.data(), k, arr.size(), true, true, false); +#ifndef XSS_ASAN_CI_NOCHECK + std::nth_element( + sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); + if (size == 0) continue; + IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type, true); +#endif + } + } + } + } +} + TYPED_TEST_P(simdsort, test_partial_qsort_ascending) { for (auto type : this->arrtype) { @@ -571,7 +696,10 @@ REGISTER_TYPED_TEST_SUITE_P(simdsort, test_argsort_descending, test_argsort_trailing_nans, test_argsort_leading_nans, - test_argselect, + test_argselect_ascending, + test_argselect_descending, + test_argselect_trailing_nans, + test_argselect_leading_nans, test_qselect_ascending, test_qselect_descending, test_qselect_trailing_nans, From a854d5887726d89179b9a57597d7ed79d3cd69aa Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 05:35:07 +0000 Subject: [PATCH 08/13] Fix descending argselect: partition at mirror position before reversing argselect_ always partitions in ascending order. For descending, we must select at position arrsize-1-k so the k-th largest lands at that mirror index; reversing then moves it to position k with the correct left/right partition invariant. --- src/xss-common-argsort.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 885fb14d..02ef06d6 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -731,8 +731,11 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, } } UNUSED(hasnan); + /* For descending, partition at the mirror position so the k-th + * largest lands at arrsize-1-k; reversal then moves it to k. */ + arrsize_t pos = descending ? arrsize - 1 - k : k; argselect_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + arr, arg, pos, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); if (descending) { std::reverse(arg, arg + arrsize); } } From 81bfcff2f1cac89012f0f96197f6f42a95fa5019 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 05:36:52 +0000 Subject: [PATCH 09/13] Clang format files --- lib/x86simdsort-scalar.h | 12 +++++++++--- src/x86simdsort-static-incl.h | 9 +++++++-- tests/test-qsort.cpp | 11 +++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 3964e834..1995370c 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -144,12 +144,18 @@ namespace scalar { std::iota(arg.begin(), arg.end(), 0); std::function cmp; if (trailing_nans) { - if (descending) { cmp = compare_arg_nan_end>(arr); } - else { cmp = compare_arg>(arr); } + if (descending) { + cmp = compare_arg_nan_end>(arr); + } + else { + cmp = compare_arg>(arr); + } } else { if (descending) { cmp = compare_arg>(arr); } - else { cmp = compare_arg_nan_begin>(arr); } + else { + cmp = compare_arg_nan_begin>(arr); + } } std::nth_element(arg.begin(), arg.begin() + k, arg.end(), cmp); return arg; diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 2aa33238..b13c2525 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -166,8 +166,13 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ - x86simdsortStatic::argselect( \ - arr, indices.data(), k, size, hasnan, descending, trailing_nans); \ + x86simdsortStatic::argselect(arr, \ + indices.data(), \ + k, \ + size, \ + hasnan, \ + descending, \ + trailing_nans); \ return indices; \ } \ template \ diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index a165211e..084fc889 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -483,12 +483,11 @@ TYPED_TEST_P(simdsort, test_argselect_leading_nans) auto arg = x86simdsort::argselect( arr.data(), k, arr.size(), true, false, false); #ifndef XSS_ASAN_CI_NOCHECK - std::nth_element( - sortedarr.begin(), - sortedarr.begin() + k, - sortedarr.end(), - compare_nan_begin>()); + std::nth_element(sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare_nan_begin>()); if (size == 0) continue; IS_ARG_PARTITIONED(arr, arg, sortedarr[k], k, type); #endif From 8c4cc0ddcb8a8c6fda272c8b79bf4562fc412c21 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 05:56:45 +0000 Subject: [PATCH 10/13] Unused parameter --- src/xss-common-argsort.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 02ef06d6..88c88c32 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -624,6 +624,7 @@ X86_SIMD_SORT_INLINE void xss_argsort(const T *arr, } } UNUSED(hasnan); + UNUSED(trailing_nans); /* early exit for already sorted arrays: float/double with nan never reach here*/ auto comp = descending ? Comparator::STDSortComparator @@ -731,6 +732,7 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, } } UNUSED(hasnan); + UNUSED(trailing_nans); /* For descending, partition at the mirror position so the k-th * largest lands at arrsize-1-k; reversal then moves it to k. */ arrsize_t pos = descending ? arrsize - 1 - k : k; From 030227a08b12ed3e5b13b0092ad3d15caf013a9b Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 25 May 2026 08:01:02 +0000 Subject: [PATCH 11/13] Update argselect docs to reflect new descending parameter --- README.md | 7 ++++++- src/README.md | 9 +++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 44c719d7..0eab5f4a 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ data types. ## Arg sort routines on arrays ```cpp std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending, bool trailing_nans); -std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan, bool trailing_nans); +std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan, bool descending, bool trailing_nans); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` Note that argsort and argselect are not accelerated with SIMD when using 16-bit @@ -183,6 +183,11 @@ is controlled by the optional `bool trailing_nans` parameter (default `true`): - `trailing_nans=false`: NaN values are placed at the **beginning** of the result, regardless of sort direction. +All routines accept an optional `bool descending` parameter (default `false`). +When `descending=true`, results are in descending order. For `argselect`, the +k-th element becomes the k-th **largest**, with all elements before index k +being greater than or equal to it. + Note that the arg methods (argsort and argselect) will not use the SIMD based algorithms if they detect NaN values in the array. You can read details of all the implementations [here](https://github.com/intel/x86-simd-sort/blob/main/src/README.md). diff --git a/src/README.md b/src/README.md index 57d3e12e..6100c897 100644 --- a/src/README.md +++ b/src/README.md @@ -78,12 +78,17 @@ Equivalent to `np.argselect` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html). ```cpp -void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false, bool trailing_nans = true); +void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. -The algorithm resorts to scalar `std::sort` if the array contains NaNs. +When `descending=true`, the k-th element is the k-th **largest** and elements +before index k are all greater than or equal to it. For floating-point types, +if `bool hasnan` is set, NaN placement is controlled by `trailing_nans`: +`true` (default) places NaNs at the end; `false` places them at the beginning. + +The algorithm resorts to scalar `std::nth_element` if the array contains NaNs. #### Key-value sort ```cpp From 740d2b102a2500f9fc7d7ca7cf94bf1a9c6e0533 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 16 Jun 2026 15:35:56 +0000 Subject: [PATCH 12/13] remove temp file --- tests/.test-qsort.cpp.swp | Bin 16384 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/.test-qsort.cpp.swp diff --git a/tests/.test-qsort.cpp.swp b/tests/.test-qsort.cpp.swp deleted file mode 100644 index 71a71009ac435f623262c05551053caf4600dc24..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeI2O^*~s7{?32HxM*nVmw(c2WQF7%gzE~yDSVd6CHJT2WC*j#I|<2cBZALd$zj< zcU=LMgC3L*V1fyW7d?0}@n96gfp~?G=)veg;|1c&ClLSDFEc&E&N_P_fL15J>6d!y znW}$1U0qe(bD0-Q)AXs-IKlQ1As=4c_WGI9I`ZC5LR?mFw2yP{%CuFs(t+t@`K3(Z zxuZ|+TE&c<>^Mv}ms%EEs=H3xHqwiGb|*cvX(T)~dG*P}Io$nNhjT@TZL?90DCJijO0YhlJHe|TB#lmbctrGQdEDWDWk z3Md7X0!jg;fKosypcJ?X3NV9^Ec|Yf3jlckAI<--JV?l=;3Sv<(;x+s;1Tfq7D9dl z--AowBXATH!1c|9`~W@&r@>ob8N3Siff29){Q3YP=fTI|6nF)sKmyzaE^b18@HRLF z7ud<)KlGvEU-5B37to}0niHX3g8j!P4) zEgb5ZZ*rAEnDAGM{rCi#C{NGn?+Erl;dTm)~(D#8QioKJ6QJ zeaee*pS}^?EtBfYw=U0ggL^e@8zzcC$O+$d@GDEx zJ(6e0lB!vpxrwk0@}k^Ok#53v>203)3E@wvqUR1hqgQ4RRf~m${9A4y&zQ_8f9fbx zw`TF2!Ew!LwwTK&t1v1vmo>+zCk#!qxaV~*PE937`wG@|)%nvDikBE2Bi$q#IFlCL zvZdyD-L$3Qc5~h~_&hyQK`kn|89iUpXJ+%y7W2>dIPu#TbwQKi3!F%8O$&uiNRbRC z7<~GI$)=abK|`b9J>LBK9p^RSxUm}2PBri^a2SL_RFwgFUE{rIt*dnyv2~X-!Q)MOYlF5Y zQvL74{v}Rvr@Cz!EhMhyfXM5#!J4CWj(O;3`rwJKL%a3!{8VvDh@Ce5xj{H)CH%L zvfXUfa4Z4lIBi*bjUg?A4|ePzopolmbctrGQdEDWDWk3Md7X z0!o29PJz-)rJ9?mmU38{TVbtk(evc{rsz%N!3tiw$u)v1mIb{2CDv%kqfiiN;}N3;tE-06x+jmBCCsG&87Rm{vR`ntlt0t From 0de11c79202410f995b8a73558590bea4f42910a Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 16 Jun 2026 15:40:29 +0000 Subject: [PATCH 13/13] rename trailing_nans to nans_last --- lib/x86simdsort-avx2.cpp | 20 ++++++------- lib/x86simdsort-icl.cpp | 36 +++++++++++------------ lib/x86simdsort-internal.h | 10 +++---- lib/x86simdsort-scalar.h | 24 ++++++++-------- lib/x86simdsort-skx.cpp | 20 ++++++------- lib/x86simdsort-spr.cpp | 12 ++++---- lib/x86simdsort.cpp | 40 +++++++++++++------------- lib/x86simdsort.h | 10 +++---- src/README.md | 18 ++++++------ src/avx512-16bit-qsort.hpp | 14 ++++----- src/x86simdsort-static-incl.h | 54 +++++++++++++++++------------------ src/xss-common-argsort.h | 40 +++++++++++++------------- src/xss-common-qsort.h | 32 ++++++++++----------- tests/test-qsort.cpp | 16 +++++------ 14 files changed, 173 insertions(+), 173 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index a5ae0ae8..afac33c7 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -9,10 +9,10 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ x86simdsortStatic::qsort( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ void qselect(type *arr, \ @@ -20,10 +20,10 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ x86simdsortStatic::qselect( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ void partial_qsort(type *arr, \ @@ -31,20 +31,20 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ x86simdsortStatic::partial_qsort( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ std::vector argsort(const type *arr, \ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ return x86simdsortStatic::argsort( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ std::vector argselect(const type *arr, \ @@ -52,10 +52,10 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ return x86simdsortStatic::argselect( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 27d891fb..d1d9adff 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -12,9 +12,9 @@ namespace avx512 { size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(uint16_t *arr, @@ -22,10 +22,10 @@ namespace avx512 { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::qselect( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(uint16_t *arr, @@ -33,19 +33,19 @@ namespace avx512 { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::partial_qsort( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } template <> void qsort(int16_t *arr, size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(int16_t *arr, @@ -53,10 +53,10 @@ namespace avx512 { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::qselect( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(int16_t *arr, @@ -64,10 +64,10 @@ namespace avx512 { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::partial_qsort( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } } // namespace avx512 namespace fp16_icl { @@ -77,9 +77,9 @@ namespace fp16_icl { size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(_Float16 *arr, @@ -87,10 +87,10 @@ namespace fp16_icl { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::qselect( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(_Float16 *arr, @@ -98,10 +98,10 @@ namespace fp16_icl { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::partial_qsort( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } #endif } // namespace fp16_icl diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 7d589686..d6e30b5e 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -11,7 +11,7 @@ size_t arrsize, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true); \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \ T2 *val, \ @@ -24,7 +24,7 @@ size_t arrsize, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true); \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \ T2 *val, \ @@ -38,7 +38,7 @@ size_t arrsize, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true); \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \ T2 *val, \ @@ -51,14 +51,14 @@ size_t arrsize, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true); \ + bool nans_last = true); \ template \ XSS_HIDE_SYMBOL std::vector argselect(const T *arr, \ size_t k, \ size_t arrsize, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true); \ + bool nans_last = true); \ } namespace xss { diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 1995370c..7faf4446 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -25,11 +25,11 @@ namespace utils { } template decltype(auto) - get_cmp_func(bool hasnan, bool reverse, bool trailing_nans = true) + get_cmp_func(bool hasnan, bool reverse, bool nans_last = true) { std::function cmp; if (hasnan) { - if (trailing_nans) { + if (nans_last) { if (reverse == true) { cmp = compare_nan_end>(); } @@ -60,11 +60,11 @@ namespace scalar { size_t arrsize, bool hasnan, bool reversed, - bool trailing_nans) + bool nans_last) { std::sort(arr, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed, trailing_nans)); + xss::utils::get_cmp_func(hasnan, reversed, nans_last)); } template @@ -73,13 +73,13 @@ namespace scalar { size_t arrsize, bool hasnan, bool reversed, - bool trailing_nans) + bool nans_last) { std::nth_element( arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed, trailing_nans)); + xss::utils::get_cmp_func(hasnan, reversed, nans_last)); } template void partial_qsort(T *arr, @@ -87,25 +87,25 @@ namespace scalar { size_t arrsize, bool hasnan, bool reversed, - bool trailing_nans) + bool nans_last) { std::partial_sort( arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed, trailing_nans)); + xss::utils::get_cmp_func(hasnan, reversed, nans_last)); } template std::vector argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed, - bool trailing_nans) + bool nans_last) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - if (trailing_nans) { + if (nans_last) { if (reversed) { std::sort(arg.begin(), arg.end(), @@ -137,13 +137,13 @@ namespace scalar { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); std::function cmp; - if (trailing_nans) { + if (nans_last) { if (descending) { cmp = compare_arg_nan_end>(arr); } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index eaf329a3..b38653e7 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -9,10 +9,10 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ x86simdsortStatic::qsort( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ void qselect(type *arr, \ @@ -20,10 +20,10 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ x86simdsortStatic::qselect( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ void partial_qsort(type *arr, \ @@ -31,20 +31,20 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ x86simdsortStatic::partial_qsort( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } \ template <> \ std::vector argsort(const type *arr, \ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ return x86simdsortStatic::argsort( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } \ template <> \ std::vector argselect(const type *arr, \ @@ -52,10 +52,10 @@ size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ return x86simdsortStatic::argselect( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index c615d785..9e8b2887 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -9,9 +9,9 @@ namespace fp16_spr { size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { - x86simdsortStatic::qsort(arr, size, hasnan, descending, trailing_nans); + x86simdsortStatic::qsort(arr, size, hasnan, descending, nans_last); } template <> void qselect(_Float16 *arr, @@ -19,10 +19,10 @@ namespace fp16_spr { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::qselect( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } template <> void partial_qsort(_Float16 *arr, @@ -30,10 +30,10 @@ namespace fp16_spr { size_t arrsize, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { x86simdsortStatic::partial_qsort( - arr, k, arrsize, hasnan, descending, trailing_nans); + arr, k, arrsize, hasnan, descending, nans_last); } } // namespace fp16_spr } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 0c53b4d1..c90661e0 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -108,11 +108,11 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ if (internal_qsort##TYPE == NULL) { CAT(resolve_qsort, TYPE)(); } \ (*internal_qsort##TYPE)( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ @@ -126,11 +126,11 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ if (internal_qselect##TYPE == NULL) { CAT(resolve_qselect, TYPE)(); } \ (*internal_qselect##TYPE)( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ @@ -144,13 +144,13 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ if (internal_partial_qsort##TYPE == NULL) { \ CAT(resolve_partial_qsort, TYPE)(); \ } \ (*internal_partial_qsort##TYPE)( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ @@ -163,11 +163,11 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ if (internal_argsort##TYPE == NULL) { CAT(resolve_argsort, TYPE)(); } \ return (*internal_argsort##TYPE)( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ @@ -181,13 +181,13 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ if (internal_argselect##TYPE == NULL) { \ CAT(resolve_argselect, TYPE)(); \ } \ return (*internal_argselect##TYPE)( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #else @@ -200,10 +200,10 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ (*internal_qsort##TYPE)( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ @@ -216,10 +216,10 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ (*internal_qselect##TYPE)( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ @@ -232,10 +232,10 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ (*internal_partial_qsort##TYPE)( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ @@ -247,10 +247,10 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ return (*internal_argsort##TYPE)( \ - arr, arrsize, hasnan, descending, trailing_nans); \ + arr, arrsize, hasnan, descending, nans_last); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ @@ -263,10 +263,10 @@ namespace x86simdsort { size_t arrsize, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ return (*internal_argselect##TYPE)( \ - arr, k, arrsize, hasnan, descending, trailing_nans); \ + arr, k, arrsize, hasnan, descending, nans_last); \ } #endif // _MSC_VER diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 6f46e05f..37c7bb6e 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -23,7 +23,7 @@ XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); // quickselect template @@ -32,7 +32,7 @@ XSS_EXPORT_SYMBOL void qselect(T *arr, size_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); // partial sort template @@ -41,7 +41,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); // argsort template @@ -49,7 +49,7 @@ XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, size_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); // argselect template @@ -58,7 +58,7 @@ XSS_EXPORT_SYMBOL std::vector argselect(const T *arr, size_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); // keyvalue sort template diff --git a/src/README.md b/src/README.md index 6100c897..a89f2d08 100644 --- a/src/README.md +++ b/src/README.md @@ -18,12 +18,12 @@ Equivalent to `qsort` in `std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort). ```cpp -void x86simdsortStatic::qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); +void x86simdsortStatic::qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `arr` contains -NaNs, their placement is controlled by `trailing_nans`: `true` (default) places +NaNs, their placement is controlled by `nans_last`: `true` (default) places NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads are preserved. @@ -35,12 +35,12 @@ Equivalent to `std::nth_element` in ```cpp -void x86simdsortStatic::qselect(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); +void x86simdsortStatic::qselect(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `bool hasnan` is -set, NaN placement is controlled by `trailing_nans`: `true` (default) places +set, NaN placement is controlled by `nans_last`: `true` (default) places NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads are preserved. If NaNs are present but `hasnan` is `false`, the behavior is undefined. @@ -51,12 +51,12 @@ Equivalent to `std::partial_sort` in ```cpp -void x86simdsortStatic::partial_qsort(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); +void x86simdsortStatic::partial_qsort(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support 32-bit and 64-bit dtypes only. For floating-point types, if `bool hasnan` is -set, NaN placement is controlled by `trailing_nans`: `true` (default) places +set, NaN placement is controlled by `nans_last`: `true` (default) places NaNs at the end; `false` places them at the beginning. Bit-exact NaN payloads are preserved. If NaNs are present but `hasnan` is `false`, the behavior is undefined. @@ -66,7 +66,7 @@ Equivalent to `np.argsort` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html). ```cpp -void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); +void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. @@ -78,14 +78,14 @@ Equivalent to `np.argselect` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html). ```cpp -void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool trailing_nans = true); +void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false, bool descending = false, bool nans_last = true); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. When `descending=true`, the k-th element is the k-th **largest** and elements before index k are all greater than or equal to it. For floating-point types, -if `bool hasnan` is set, NaN placement is controlled by `trailing_nans`: +if `bool hasnan` is set, NaN placement is controlled by `nans_last`: `true` (default) places NaNs at the end; `false` places them at the beginning. The algorithm resorts to scalar `std::nth_element` if the array contains NaNs. diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 7ae78778..f8c327af 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -571,7 +571,7 @@ avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { using vtype = zmm_vector; @@ -579,7 +579,7 @@ avx512_qsort_fp16(uint16_t *arr, arrsize_t index_first_elem = 0; arrsize_t index_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - if (!trailing_nans) { + if (!nans_last) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -610,7 +610,7 @@ avx512_qselect_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { using vtype = zmm_vector; @@ -621,7 +621,7 @@ avx512_qselect_fp16(uint16_t *arr, arrsize_t index_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - if (!trailing_nans) { + if (!nans_last) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -660,10 +660,10 @@ avx512_partial_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { if (k == 0) return; - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending, trailing_nans); - avx512_qsort_fp16(arr, k - 1, hasnan, descending, trailing_nans); + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending, nans_last); + avx512_qsort_fp16(arr, k - 1, hasnan, descending, nans_last); } #endif // AVX512_QSORT_16BIT diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index b13c2525..bdc91baf 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -11,7 +11,7 @@ X86_SIMD_SORT_FINLINE void qsort(T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); template X86_SIMD_SORT_FINLINE void qselect(T *arr, @@ -19,7 +19,7 @@ X86_SIMD_SORT_FINLINE void qselect(T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); template X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, @@ -27,14 +27,14 @@ X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); template X86_SIMD_SORT_FINLINE std::vector argsort(const T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); /* argsort API required by NumPy: */ template @@ -43,7 +43,7 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); template X86_SIMD_SORT_FINLINE std::vector argselect(const T *arr, @@ -51,7 +51,7 @@ X86_SIMD_SORT_FINLINE std::vector argselect(const T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); /* argselect API required by NumPy: */ template @@ -61,7 +61,7 @@ void X86_SIMD_SORT_FINLINE argselect(const T *arr, size_t size, bool hasnan = false, bool descending = false, - bool trailing_nans = true); + bool nans_last = true); template X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, @@ -94,9 +94,9 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ - ISA##_qsort(arr, size, hasnan, descending, trailing_nans); \ + ISA##_qsort(arr, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::qselect(T *arr, \ @@ -104,9 +104,9 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ - ISA##_qselect(arr, k, size, hasnan, descending, trailing_nans); \ + ISA##_qselect(arr, k, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::partial_qsort( \ @@ -115,9 +115,9 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ - ISA##_partial_qsort(arr, k, size, hasnan, descending, trailing_nans); \ + ISA##_partial_qsort(arr, k, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \ @@ -125,9 +125,9 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ - ISA##_argsort(arr, arg, size, hasnan, descending, trailing_nans); \ + ISA##_argsort(arr, arg, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ @@ -135,12 +135,12 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ x86simdsortStatic::argsort( \ - arr, indices.data(), size, hasnan, descending, trailing_nans); \ + arr, indices.data(), size, hasnan, descending, nans_last); \ return indices; \ } \ template \ @@ -151,9 +151,9 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ - ISA##_argselect(arr, arg, k, size, hasnan, descending, trailing_nans); \ + ISA##_argselect(arr, arg, k, size, hasnan, descending, nans_last); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argselect( \ @@ -162,7 +162,7 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size_t size, \ bool hasnan, \ bool descending, \ - bool trailing_nans) \ + bool nans_last) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ @@ -172,7 +172,7 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, size, \ hasnan, \ descending, \ - trailing_nans); \ + nans_last); \ return indices; \ } \ template \ @@ -236,9 +236,9 @@ void x86simdsortStatic::qsort<_Float16>(_Float16 *arr, size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { - avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending, trailing_nans); + avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending, nans_last); } template <> [[maybe_unused]] @@ -247,10 +247,10 @@ void x86simdsortStatic::qselect<_Float16>(_Float16 *arr, size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { avx512_qselect_fp16( - (uint16_t *)arr, k, size, hasnan, descending, trailing_nans); + (uint16_t *)arr, k, size, hasnan, descending, nans_last); } template <> [[maybe_unused]] @@ -259,10 +259,10 @@ void x86simdsortStatic::partial_qsort<_Float16>(_Float16 *arr, size_t size, bool hasnan, bool descending, - bool trailing_nans) + bool nans_last) { avx512_partial_qsort_fp16( - (uint16_t *)arr, k, size, hasnan, descending, trailing_nans); + (uint16_t *)arr, k, size, hasnan, descending, nans_last); } #endif diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 88c88c32..b4fd433a 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -16,21 +16,21 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr, arrsize_t k, arrsize_t left, arrsize_t right, - bool trailing_nans = true, + bool nans_last = true, bool descending = false) { std::nth_element( arg + left, arg + k, arg + right, - [arr, trailing_nans, descending](arrsize_t a, arrsize_t b) -> bool { + [arr, nans_last, descending](arrsize_t a, arrsize_t b) -> bool { bool a_nan = std::isnan(arr[a]); bool b_nan = std::isnan(arr[b]); if (!a_nan && !b_nan) { return descending ? arr[a] > arr[b] : arr[a] < arr[b]; } if (a_nan && b_nan) { return false; } - return trailing_nans ? !a_nan : a_nan; + return nans_last ? !a_nan : a_nan; }); } @@ -40,20 +40,20 @@ X86_SIMD_SORT_INLINE void std_argsort_withnan(const T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right, - bool trailing_nans = true, + bool nans_last = true, bool descending = false) { std::sort( arg + left, arg + right, - [arr, trailing_nans, descending](arrsize_t a, arrsize_t b) -> bool { + [arr, nans_last, descending](arrsize_t a, arrsize_t b) -> bool { bool a_nan = std::isnan(arr[a]); bool b_nan = std::isnan(arr[b]); if (!a_nan && !b_nan) { return descending ? arr[a] > arr[b] : arr[a] < arr[b]; } if (a_nan && b_nan) { return false; } - return trailing_nans ? !a_nan : a_nan; + return nans_last ? !a_nan : a_nan; }); } @@ -602,7 +602,7 @@ X86_SIMD_SORT_INLINE void xss_argsort(const T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { using vectype = typename std::conditional) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argsort_withnan( - arr, arg, 0, arrsize, trailing_nans, descending); + arr, arg, 0, arrsize, nans_last, descending); return; } } UNUSED(hasnan); - UNUSED(trailing_nans); + UNUSED(nans_last); /* early exit for already sorted arrays: float/double with nan never reach here*/ auto comp = descending ? Comparator::STDSortComparator @@ -681,10 +681,10 @@ X86_SIMD_SORT_INLINE void avx512_argsort(const T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { xss_argsort( - arr, arg, arrsize, hasnan, descending, trailing_nans); + arr, arg, arrsize, hasnan, descending, nans_last); } template @@ -693,10 +693,10 @@ X86_SIMD_SORT_INLINE void avx2_argsort(const T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { xss_argsort( - arr, arg, arrsize, hasnan, descending, trailing_nans); + arr, arg, arrsize, hasnan, descending, nans_last); } /* argselect methods for 32-bit and 64-bit dtypes */ @@ -711,7 +711,7 @@ X86_SIMD_SORT_INLINE void xss_argselect(const T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ using vectype = typename std::conditional) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argselect_withnan( - arr, arg, k, 0, arrsize, trailing_nans, descending); + arr, arg, k, 0, arrsize, nans_last, descending); return; } } UNUSED(hasnan); - UNUSED(trailing_nans); + UNUSED(nans_last); /* For descending, partition at the mirror position so the k-th * largest lands at arrsize-1-k; reversal then moves it to k. */ arrsize_t pos = descending ? arrsize - 1 - k : k; @@ -755,10 +755,10 @@ X86_SIMD_SORT_INLINE void avx512_argselect(const T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { xss_argselect( - arr, arg, k, arrsize, hasnan, descending, trailing_nans); + arr, arg, k, arrsize, hasnan, descending, nans_last); } template @@ -768,10 +768,10 @@ X86_SIMD_SORT_INLINE void avx2_argselect(const T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false, - bool trailing_nans = true) + bool nans_last = true) { xss_argselect( - arr, arg, k, arrsize, hasnan, descending, trailing_nans); + arr, arg, k, arrsize, hasnan, descending, nans_last); } #endif // XSS_COMMON_ARGSORT diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 6fe34bbd..70aa3465 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -594,7 +594,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, // Quicksort routines: template X86_SIMD_SORT_INLINE void -xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool trailing_nans = true) +xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool nans_last = true) { using comparator = typename std::conditional) { if (UNLIKELY(hasnan)) { - if (!trailing_nans) { + if (!nans_last) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } @@ -672,7 +672,7 @@ X86_SIMD_SORT_INLINE void xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, - bool trailing_nans = true) + bool nans_last = true) { using comparator = typename std::conditional) { if (UNLIKELY(hasnan)) { - if (!trailing_nans) { + if (!nans_last) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } else { @@ -717,12 +717,12 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, - bool trailing_nans = true) + bool nans_last = true) { if (k == 0) return; xss_qselect( - arr, k - 1, arrsize, hasnan, trailing_nans); - xss_qsort(arr, k - 1, hasnan, trailing_nans); + arr, k - 1, arrsize, hasnan, nans_last); + xss_qsort(arr, k - 1, hasnan, nans_last); } #define DEFINE_METHODS(ISA, VTYPE) \ @@ -731,13 +731,13 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t size, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true) \ + bool nans_last = true) \ { \ if (descending) { \ - xss_qsort(arr, size, hasnan, trailing_nans); \ + xss_qsort(arr, size, hasnan, nans_last); \ } \ else { \ - xss_qsort(arr, size, hasnan, trailing_nans); \ + xss_qsort(arr, size, hasnan, nans_last); \ } \ } \ template \ @@ -746,13 +746,13 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t size, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true) \ + bool nans_last = true) \ { \ if (descending) { \ - xss_qselect(arr, k, size, hasnan, trailing_nans); \ + xss_qselect(arr, k, size, hasnan, nans_last); \ } \ else { \ - xss_qselect(arr, k, size, hasnan, trailing_nans); \ + xss_qselect(arr, k, size, hasnan, nans_last); \ } \ } \ template \ @@ -761,15 +761,15 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t size, \ bool hasnan = false, \ bool descending = false, \ - bool trailing_nans = true) \ + bool nans_last = true) \ { \ if (descending) { \ xss_partial_qsort( \ - arr, k, size, hasnan, trailing_nans); \ + arr, k, size, hasnan, nans_last); \ } \ else { \ xss_partial_qsort( \ - arr, k, size, hasnan, trailing_nans); \ + arr, k, size, hasnan, nans_last); \ } \ } diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 084fc889..89e1435b 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -128,7 +128,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending) } } -TYPED_TEST_P(simdsort, test_argsort_trailing_nans) +TYPED_TEST_P(simdsort, test_argsort_nans_last) { if constexpr (xss::fp::is_floating_point_v) { std::vector nan_types @@ -269,7 +269,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending) } } -TYPED_TEST_P(simdsort, test_qselect_trailing_nans) +TYPED_TEST_P(simdsort, test_qselect_nans_last) { if constexpr (xss::fp::is_floating_point_v) { std::vector nan_types @@ -417,7 +417,7 @@ TYPED_TEST_P(simdsort, test_argselect_descending) } } -TYPED_TEST_P(simdsort, test_argselect_trailing_nans) +TYPED_TEST_P(simdsort, test_argselect_nans_last) { if constexpr (xss::fp::is_floating_point_v) { std::vector nan_types @@ -600,7 +600,7 @@ TYPED_TEST_P(simdsort, test_comparator) } } -TYPED_TEST_P(simdsort, test_qsort_trailing_nans) +TYPED_TEST_P(simdsort, test_qsort_nans_last) { if constexpr (xss::fp::is_floating_point_v) { std::vector nan_types @@ -689,19 +689,19 @@ TYPED_TEST_P(simdsort, test_qsort_leading_nans) REGISTER_TYPED_TEST_SUITE_P(simdsort, test_qsort_ascending, test_qsort_descending, - test_qsort_trailing_nans, + test_qsort_nans_last, test_qsort_leading_nans, test_argsort_ascending, test_argsort_descending, - test_argsort_trailing_nans, + test_argsort_nans_last, test_argsort_leading_nans, test_argselect_ascending, test_argselect_descending, - test_argselect_trailing_nans, + test_argselect_nans_last, test_argselect_leading_nans, test_qselect_ascending, test_qselect_descending, - test_qselect_trailing_nans, + test_qselect_nans_last, test_qselect_leading_nans, test_partial_qsort_ascending, test_partial_qsort_descending,