Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 88 additions & 42 deletions libcudacxx/include/cuda/__argument/argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
#include <cuda/std/__iterator/iterator_traits.h>
#include <cuda/std/__iterator/readable_traits.h>
#include <cuda/std/__ranges/concepts.h>
#include <cuda/std/__type_traits/extent.h>
#include <cuda/std/__type_traits/is_arithmetic.h>
#include <cuda/std/__type_traits/is_array.h>
#include <cuda/std/__type_traits/is_bounded_array.h>
#include <cuda/std/__type_traits/is_integer.h>
#include <cuda/std/__type_traits/is_integral.h>
#include <cuda/std/__type_traits/is_same.h>
Expand All @@ -40,6 +42,7 @@
#include <cuda/std/__utility/declval.h>
#include <cuda/std/__utility/forward.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/array>
#include <cuda/std/cstddef>
#include <cuda/std/limits>

Expand Down Expand Up @@ -111,16 +114,60 @@ class constant
};

//! @brief Wraps a compile-time constant argument sequence.
template <auto _Value>
template <class _Tp, _Tp... _Vs>
class __constant_sequence
{
public:
using value_type = ::cuda::std::remove_cvref_t<decltype(_Value)>;
using __element_type = __element_type_of_t<value_type>;

static_assert(__is_sequence_v<value_type>, "The value type of __constant_sequence must be a sequence");
using __element_type = ::cuda::std::remove_cvref_t<_Tp>;
using value_type = ::cuda::std::array<__element_type, sizeof...(_Vs)>;
static constexpr ::cuda::std::size_t size = sizeof...(_Vs);
};

//! @brief Expands the elements of `_Arr` at indices `_Is...` into a `__constant_sequence`.
//!
//! For a bounded C-style array the element type is the cv-unqualified `remove_extent_t` of the array type;
//! for a `cuda::std::array` it is the array's `value_type`. Any other type fails a `static_assert`.
template <const auto& _Arr, ::cuda::std::size_t... _Is>
_CCCL_API constexpr auto __make_constant_sequence_impl(::cuda::std::index_sequence<_Is...>)
{
  using __array_type = ::cuda::std::remove_cvref_t<decltype(_Arr)>;

  if constexpr (::cuda::std::is_bounded_array_v<__array_type>)
  {
    // C-style array: strip the extent, then any cv-qualification of the element.
    return __constant_sequence<::cuda::std::remove_cv_t<::cuda::std::remove_extent_t<__array_type>>, _Arr[_Is]...>{};
  }
  else if constexpr (::cuda::std::__is_cuda_std_array_v<__array_type>)
  {
    // cuda::std::array: the element type is spelled out by the container itself.
    return __constant_sequence<typename __array_type::value_type, _Arr[_Is]...>{};
  }
  else
  {
    static_assert(::cuda::std::__always_false_v<__array_type>, "unsupported array type");
  }
}

//! @brief Makes a compile-time constant argument sequence from an array object.
//!
//! @tparam _Arr Reference to a bounded C-style array or a `cuda::std::array`.
//!              In C++17, `_Arr` must have static storage duration.
//! @return A `__constant_sequence` whose value pack holds the elements of `_Arr` in index order.
template <const auto& _Arr>
_CCCL_API constexpr auto __make_constant_sequence()
{
  using __raw_array = ::cuda::std::remove_cvref_t<decltype(_Arr)>;

  static_assert(::cuda::std::is_bounded_array_v<__raw_array> || ::cuda::std::__is_cuda_std_array_v<__raw_array>,
                "make_constant_sequence requires a cuda::std::array or non-empty C-style array");

  // The element count comes from the array extent or from tuple_size, depending on the array flavour.
  // `if constexpr` ensures only the trait matching the array flavour is instantiated.
  constexpr ::cuda::std::size_t _Np = []() constexpr {
    if constexpr (::cuda::std::is_bounded_array_v<__raw_array>)
    {
      return ::cuda::std::extent_v<__raw_array>;
    }
    else
    {
      return ::cuda::std::tuple_size_v<__raw_array>;
    }
  }();

  return __make_constant_sequence_impl<_Arr>(::cuda::std::make_index_sequence<_Np>{});
}

// __assert_in_range
// =====================================================================

Expand Down Expand Up @@ -621,8 +668,8 @@ template <class _Arg, class _StaticBounds>
inline constexpr bool __is_wrapper_v<immediate<_Arg, _StaticBounds>> = true;
template <auto _Value, class _Tp>
inline constexpr bool __is_wrapper_v<constant<_Value, _Tp>> = true;
template <auto _Value>
inline constexpr bool __is_wrapper_v<__constant_sequence<_Value>> = true;
template <class _Tp, _Tp... _Vs>
inline constexpr bool __is_wrapper_v<__constant_sequence<_Tp, _Vs...>> = true;
template <class _Arg, class _StaticBounds>
inline constexpr bool __is_wrapper_v<__immediate_sequence<_Arg, _StaticBounds>> = true;
template <class _Arg, class _StaticBounds>
Expand Down Expand Up @@ -662,11 +709,11 @@ __unwrap(const constant<_Value, _Tp>&) noexcept
return constant<_Value, _Tp>::__get_value();
}

template <auto _Value>
[[nodiscard]] _CCCL_API constexpr ::cuda::std::remove_cvref_t<decltype(_Value)>
__unwrap(const __constant_sequence<_Value>&) noexcept
//! Unwraps a compile-time constant argument sequence into a canonical cuda::std::array value.
template <class _Tp, _Tp... _Vs>
[[nodiscard]] _CCCL_API constexpr auto __unwrap(const __constant_sequence<_Tp, _Vs...>&) noexcept
{
return _Value;
return ::cuda::std::array<::cuda::std::remove_cvref_t<_Tp>, sizeof...(_Vs)>{_Vs...};
}

template <class _Arg, class _StaticBounds>
Expand Down Expand Up @@ -735,32 +782,32 @@ _CCCL_API constexpr auto __constant_compute_highest() noexcept
return constant<_Value, _Tp>::__get_value();
}

template <auto _Value>
_CCCL_API constexpr auto __constant_sequence_compute_lowest() noexcept
template <class _Tp, _Tp... _Vs>
_CCCL_API constexpr _Tp __constant_sequence_compute_lowest() noexcept
{
using _ElementType = __element_type_of_t<::cuda::std::remove_cvref_t<decltype(_Value)>>;
auto __first = _Value.begin();
auto __last = _Value.end();

if (__first == __last)
if constexpr (sizeof...(_Vs) == 0)
{
return __type_lowest<_ElementType>();
return __type_lowest<_Tp>();
}
else
{
constexpr _Tp __values[] = {_Vs...};
return static_cast<_Tp>(*::cuda::std::min_element(__values, __values + sizeof...(_Vs)));
}
return static_cast<_ElementType>(*::cuda::std::min_element(__first, __last));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why no longer min_element?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review!

I think at one point during development I wasn't able to use it, so I replaced it here.
I have since found a way to use it again, which makes the code simpler.

Applied the same for the max computation too.

}

template <auto _Value>
_CCCL_API constexpr auto __constant_sequence_compute_highest() noexcept
template <class _Tp, _Tp... _Vs>
_CCCL_API constexpr _Tp __constant_sequence_compute_highest() noexcept
{
using _ElementType = __element_type_of_t<::cuda::std::remove_cvref_t<decltype(_Value)>>;
auto __first = _Value.begin();
auto __last = _Value.end();

if (__first == __last)
if constexpr (sizeof...(_Vs) == 0)
{
return __type_highest<_ElementType>();
return __type_highest<_Tp>();
}
else
{
constexpr _Tp __values[] = {_Vs...};
return static_cast<_Tp>(*::cuda::std::max_element(__values, __values + sizeof...(_Vs)));
}
return static_cast<_ElementType>(*::cuda::std::max_element(__first, __last));
}

// =====================================================================
Expand Down Expand Up @@ -811,17 +858,16 @@ struct __traits_impl<immediate<_Arg, _StaticBounds>>
static constexpr element_type highest = __wrapper_static_highest<element_type, _StaticBounds>();
};

template <auto _Value>
struct __traits_impl<__constant_sequence<_Value>>
template <class _Tp, _Tp... _Vs>
struct __traits_impl<__constant_sequence<_Tp, _Vs...>>
{
using value_type = ::cuda::std::remove_cvref_t<decltype(_Value)>;
using element_type = __element_type_of_t<value_type>;
static_assert(__is_sequence_v<value_type>, "The value type of __constant_sequence must be a sequence");
using element_type = ::cuda::std::remove_cvref_t<_Tp>;
using value_type = ::cuda::std::array<element_type, sizeof...(_Vs)>;
static constexpr bool is_constant = true;
static constexpr bool is_deferred = false;
static constexpr bool is_single_value = false;
static constexpr element_type lowest = __constant_sequence_compute_lowest<_Value>();
static constexpr element_type highest = __constant_sequence_compute_highest<_Value>();
static constexpr element_type lowest = __constant_sequence_compute_lowest<_Tp, _Vs...>();
static constexpr element_type highest = __constant_sequence_compute_highest<_Tp, _Vs...>();
};

template <class _Arg, class _StaticBounds>
Expand Down Expand Up @@ -896,10 +942,10 @@ template <auto _Value, class _Tp>
return __constant_compute_lowest<_Value, _Tp>();
}

template <auto _Value>
[[nodiscard]] _CCCL_API constexpr auto __lowest_(__constant_sequence<_Value>) noexcept
template <class _Tp, _Tp... _Vs>
[[nodiscard]] _CCCL_API constexpr auto __lowest_(__constant_sequence<_Tp, _Vs...>) noexcept
{
return __constant_sequence_compute_lowest<_Value>();
return __constant_sequence_compute_lowest<_Tp, _Vs...>();
}

template <class _Arg, class _StaticBounds>
Expand Down Expand Up @@ -949,10 +995,10 @@ template <auto _Value, class _Tp>
return __constant_compute_highest<_Value, _Tp>();
}

template <auto _Value>
[[nodiscard]] _CCCL_API constexpr auto __highest_(__constant_sequence<_Value>) noexcept
template <class _Tp, _Tp... _Vs>
[[nodiscard]] _CCCL_API constexpr auto __highest_(__constant_sequence<_Tp, _Vs...>) noexcept
{
return __constant_sequence_compute_highest<_Value>();
return __constant_sequence_compute_highest<_Tp, _Vs...>();
}

template <class _Arg, class _StaticBounds>
Expand Down
54 changes: 28 additions & 26 deletions libcudacxx/test/libcudacxx/cuda/argument/argument_traits.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,13 @@ TEST_FUNC void test()
static_assert(!cuda::args::__traits<cuda::args::immediate<int>>::is_deferred);
static_assert(!cuda::args::__traits<cuda::args::__immediate_sequence<cuda::std::span<int>>>::is_deferred);
static_assert(!cuda::args::__traits<cuda::args::constant<42>>::is_deferred);
#if TEST_HAS_CLASS_NTTP
static_assert(!cuda::args::__traits<cuda::args::__constant_sequence<cuda::std::array<int, 3>{1, 2, 3}>>::is_deferred);
#endif // TEST_HAS_CLASS_NTTP

static_assert(!cuda::args::__traits<cuda::args::__constant_sequence<int, 1, 2, 3>>::is_deferred);
static constexpr int carr[] = {1, 2, 3};
static constexpr ::cuda::std::array<int, 3> cudaarr = {1, 2, 3};
static_assert(!cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<carr>())>::is_deferred);
static_assert(!cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<cudaarr>())>::is_deferred);

static_assert(cuda::args::__traits<cuda::args::deferred<cuda::std::span<int, 1>>>::is_deferred);
static_assert(cuda::args::__traits<cuda::args::deferred_sequence<cuda::std::span<int>>>::is_deferred);

Expand All @@ -118,10 +122,11 @@ TEST_FUNC void test()
static_assert(cuda::args::__traits<cuda::args::immediate<cuda::counting_iterator<int>>>::is_single_value);
static_assert(!cuda::args::__traits<cuda::args::__immediate_sequence<cuda::std::span<int>>>::is_single_value);
static_assert(cuda::args::__traits<cuda::args::constant<42>>::is_single_value);
#if TEST_HAS_CLASS_NTTP
static_assert(
!cuda::args::__traits<cuda::args::__constant_sequence<cuda::std::array<int, 3>{1, 2, 3}>>::is_single_value);
#endif // TEST_HAS_CLASS_NTTP

static_assert(!cuda::args::__traits<cuda::args::__constant_sequence<int, 1, 2, 3>>::is_single_value);
static_assert(!cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<carr>())>::is_single_value);
static_assert(!cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<cudaarr>())>::is_single_value);

static_assert(cuda::args::__traits<cuda::args::deferred<int*>>::is_single_value);
static_assert(!cuda::args::__traits<cuda::args::deferred_sequence<cuda::std::span<int>>>::is_single_value);

Expand All @@ -134,11 +139,15 @@ TEST_FUNC void test()
cuda::std::span<int>>);
static_assert(cuda::std::is_same_v<cuda::args::__traits<cuda::args::constant<42>>::value_type, int>);
static_assert(cuda::std::is_same_v<cuda::args::__traits<cuda::args::constant<10, float>>::value_type, float>);
#if TEST_HAS_CLASS_NTTP
static_assert(cuda::std::is_same_v<
cuda::args::__traits<cuda::args::__constant_sequence<cuda::std::array<int, 3>{1, 2, 3}>>::value_type,
cuda::std::array<int, 3>>);
#endif // TEST_HAS_CLASS_NTTP

static_assert(cuda::std::is_same_v<cuda::args::__traits<cuda::args::__constant_sequence<int, 1, 2, 3>>::value_type,
cuda::std::array<int, 3>>);
static_assert(
cuda::std::is_same_v<cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<carr>())>::value_type,
cuda::std::array<int, 3>>);
static_assert(
cuda::std::is_same_v<cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<cudaarr>())>::value_type,
cuda::std::array<int, 3>>);

// --- argument_traits: lowest / highest ---

Expand All @@ -155,27 +164,20 @@ TEST_FUNC void test()
== 8);
static_assert(cuda::args::__traits<cuda::args::constant<10, float>>::lowest == 10.0f);
static_assert(cuda::args::__traits<cuda::args::constant<10, float>>::highest == 10.0f);
#if TEST_HAS_CLASS_NTTP
static_assert(cuda::args::__traits<cuda::args::__constant_sequence<cuda::std::array<int, 3>{3, 1, 2}>>::lowest == 1);
static_assert(cuda::args::__traits<cuda::args::__constant_sequence<cuda::std::array<int, 3>{3, 1, 2}>>::highest == 3);
#endif // TEST_HAS_CLASS_NTTP

static_assert(cuda::args::__traits<cuda::args::__constant_sequence<int, 1, 2, 3>>::lowest == 1);
static_assert(cuda::args::__traits<cuda::args::__constant_sequence<int, 1, 2, 3>>::highest == 3);
static_assert(cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<carr>())>::lowest == 1);
static_assert(cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<carr>())>::highest == 3);
static_assert(cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<cudaarr>())>::lowest == 1);
static_assert(cuda::args::__traits<decltype(cuda::args::__make_constant_sequence<cudaarr>())>::highest == 3);

// --- Free function bounds on plain values ---

static_assert(cuda::args::__lowest_(42) == cuda::std::numeric_limits<int>::lowest());
static_assert(cuda::args::__highest_(42) == (cuda::std::numeric_limits<int>::max)());
static_assert(cuda::args::__lowest_(1.0f) == cuda::std::numeric_limits<float>::lowest());
static_assert(cuda::args::__highest_(1.0f) == (cuda::std::numeric_limits<float>::max)());

// --- Scalar and sequence wrappers expose distinct single-value traits ---

static_assert(cuda::args::__traits<cuda::args::constant<42>>::is_single_value);
static_assert(cuda::args::__traits<cuda::args::immediate<int>>::is_single_value);
static_assert(!cuda::args::__traits<cuda::args::__immediate_sequence<cuda::std::span<int>>>::is_single_value);
#if TEST_HAS_CLASS_NTTP
static_assert(
!cuda::args::__traits<cuda::args::__constant_sequence<cuda::std::array<int, 3>{1, 2, 3}>>::is_single_value);
#endif // TEST_HAS_CLASS_NTTP
}

int main(int, char**)
Expand Down
Loading