diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 5616a9019c..e66b4a7989 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -166,6 +167,8 @@ void bitset_view::repeat(const raft::resources& res, index_t times, bitset_t* output_device_ptr) const { + // Only a copy and kernel run below this point. + if (resource::get_dry_run_flag(res)) { return; } constexpr index_t bits_per_element = sizeof(bitset_t) * 8; if (bitset_len_ % bits_per_element == 0) { diff --git a/cpp/include/raft/core/bitset.hpp b/cpp/include/raft/core/bitset.hpp index fe47557ce4..3a8a363c62 100644 --- a/cpp/include/raft/core/bitset.hpp +++ b/cpp/include/raft/core/bitset.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -133,9 +134,11 @@ struct bitset_view { auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); count(res, count_gpu_scalar.view()); index_t count_cpu = 0; - raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); + if (!resource::get_dry_run_flag(res)) { + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + } return count_cpu; } @@ -408,9 +411,11 @@ struct bitset { auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); count(res, count_gpu_scalar.view()); index_t count_cpu = 0; - raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); + if (!resource::get_dry_run_flag(res)) { + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + } return count_cpu; } /** diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp index 45bf3d3d54..f201b27afe 100644 --- a/cpp/include/raft/core/coo_matrix.hpp +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -180,8 +180,8 @@ class coordinate_structure : public coordinate_structure_tget_n_rows() + 1); - c_indices_.resize(nnz); + c_indptr_.reallocate(this->get_n_rows() + 1); + c_indices_.reallocate(nnz); } protected: diff --git a/cpp/include/raft/core/detail/copy.hpp b/cpp/include/raft/core/detail/copy.hpp index 785665a99a..354d619411 100644 --- a/cpp/include/raft/core/detail/copy.hpp +++ b/cpp/include/raft/core/detail/copy.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -399,6 +400,10 @@ mdspan_copyable_t copy(resources const& res, DstType&& dst, Sr RAFT_EXPECTS(src.extent(i) == dst.extent(i), "Must copy between mdspans of the same shape"); } + // Dry-run guard: raft::copy is a pure data-movement utility with no + // allocations that callers would need tracked. + if (resource::get_dry_run_flag(res)) { return; } + if constexpr (config::use_intermediate_src) { #ifndef RAFT_DISABLE_CUDA // Copy to intermediate source on device, then perform necessary diff --git a/cpp/include/raft/core/device_container_policy.hpp b/cpp/include/raft/core/device_container_policy.hpp index 30233b69e6..acabec54ff 100644 --- a/cpp/include/raft/core/device_container_policy.hpp +++ b/cpp/include/raft/core/device_container_policy.hpp @@ -127,6 +127,29 @@ class device_uvector { void resize(size_type size) { data_.resize(size, data_.stream()); } + /** + * @brief Resize the internal buffer without copying old data. + * + * Unlike resize(), this never copies old data. + * Thus, unlike in resize(), there's no point in time where the old and the new buffers are both + * alive, and the peak memory usage is lower. + * + * Unlike resize(), this deallocates the old buffer even if the new size is smaller. + * This ensures the memory is released promptly. + */ + void reallocate(size_type size) + { + if (size != data_.size()) { + auto stream = data_.stream(); + auto mr = data_.memory_resource(); + // Resize and shrink rmm::device_uvector: force deallocation without copying old data + data_.resize(0, data_.stream()); + data_.shrink_to_fit(data_.stream()); + // Assign a new value after the old one is deallocated + data_ = rmm::device_uvector(size, stream, mr); + } + } + [[nodiscard]] auto data() noexcept -> pointer { return data_.data(); } [[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data(); } }; diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index f7f564283c..28bae1ce1f 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -164,7 +165,7 @@ auto make_device_scalar(raft::resources const& handle, ElementType const& v) using policy_t = typename device_scalar::container_policy_type; policy_t policy{}; auto scalar = device_scalar{handle, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/dry_run_resources.hpp b/cpp/include/raft/core/dry_run_resources.hpp new file mode 100644 index 0000000000..50e06973e5 --- /dev/null +++ b/cpp/include/raft/core/dry_run_resources.hpp @@ -0,0 +1,253 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft { + +/** + * @defgroup dry_run_memory Dry-run memory resources + * @{ + */ + +/** + * @brief Resources handle that wraps all reachable memory resources with + * dry-run adaptors and tracks peak allocation usage. + * + * Inherits from raft::resources, so it can be passed anywhere a + * raft::resources& is expected. On construction the handle: + * - If dry-run mode is already active, does nothing (no-op). + * - Materializes all tracked resource types (host, device, pinned, + * managed, workspace, large_workspace). + * - Takes a snapshot of the original resources to keep them alive. + * - Wraps each with dry_run_resource. + * - Replaces global host and device resources with dry-run versions. + * - Sets the dry-run flag. + * + * On destruction the handle resets the flag and restores global resources. + * Composable with memory_tracking_resources in either order. + */ +class dry_run_resources : public resources { + public: + explicit dry_run_resources(const resources& existing) + : resources(existing), + active_(!resource::get_dry_run_flag(existing)), + old_host_(raft::mr::get_default_host_resource()), + old_device_(rmm::mr::get_current_device_resource_ref()) + { + if (active_) init(); + } + + ~dry_run_resources() override + { + if (!active_) return; + resource::set_dry_run_flag(*this, false); + raft::mr::set_default_host_resource(old_host_); + rmm::mr::set_current_device_resource(old_device_); + + // Drop all base-class entries so that probe container RAII cleanup runs + // while old_device_ and snapshot_ are still alive + resources_.clear(); + factories_.clear(); + } + + dry_run_resources(dry_run_resources const&) = delete; + dry_run_resources& operator=(dry_run_resources const&) = delete; + dry_run_resources(dry_run_resources&&) = delete; + dry_run_resources& operator=(dry_run_resources&&) = delete; + + [[nodiscard]] auto get_bytes_peak() const -> memory_stats + { + if (!active_) return {}; + return { + .device_workspace = ws_stats_->get_peak_bytes(), + .device_large_workspace = lws_stats_->get_peak_bytes(), + .device_global = device_stats_->get_peak_bytes(), + .device_managed = managed_stats_->get_peak_bytes(), + .host = host_stats_->get_peak_bytes(), + .host_pinned = pinned_stats_->get_peak_bytes(), + }; + } + + [[nodiscard]] auto get_bytes_current() const -> memory_stats + { + if (!active_) return {}; + return { + .device_workspace = ws_stats_->get_allocated_bytes(), + .device_large_workspace = lws_stats_->get_allocated_bytes(), + .device_global = device_stats_->get_allocated_bytes(), + .device_managed = managed_stats_->get_allocated_bytes(), + .host = host_stats_->get_allocated_bytes(), + .host_pinned = pinned_stats_->get_allocated_bytes(), + }; + } + + private: + // Declaration order determines destruction order. + // snapshot_ is destroyed last (keeps original resource shared_ptrs alive + // while dry-run adaptors hold non-owning refs into them). + // old_device_ is destroyed after device_adaptor_ so the probe can + // deallocate through it during device_adaptor_ destruction. + std::vector snapshot_; + + bool active_; + raft::mr::host_resource old_host_; + raft::mr::device_resource old_device_; + + using host_dry_run_t = raft::mr::dry_run_resource; + using device_dry_run_t = raft::mr::dry_run_resource; + std::unique_ptr host_adaptor_; + std::unique_ptr device_adaptor_; + + using counter_t = raft::mr::detail::dry_run_memory_counter; + std::shared_ptr host_stats_; + std::shared_ptr pinned_stats_; + std::shared_ptr managed_stats_; + std::shared_ptr ws_stats_; + std::shared_ptr lws_stats_; + std::shared_ptr device_stats_; + + void init() + { + // Independent-counting invariant + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // 1. Force-initialize all lazily-created resources (workspace, large workspace, + // pinned, managed) so that their factories resolve against the *original* + // global device MR, not a tracking wrapper we install later. + // 2. Capture every upstream ref while it still points to the original resource. + // 3. Snapshot the resource map to keep the originals alive. + // 4. Only *then* replace the global device resource with the tracking bridge. + // 5. Wrap each captured upstream with a separate dry_run_resource adaptor. + // + // Because step 2 happens before step 4, workspace/lws allocations flow through + // their own adaptor directly to the original device MR, bypassing the device adaptor. + // Each allocation is therefore counted in exactly one category, and + // memory_stats::total() returns an accurate, non-overlapping sum. + auto* ws = resource::get_workspace_resource(*this); + auto ws_free = resource::get_workspace_free_bytes(*this); + auto ws_upstream = ws->get_upstream_resource(); + auto lws_ref = resource::get_large_workspace_resource_ref(*this); + auto pinned_ref = resource::get_pinned_memory_resource_ref(*this); + auto managed_ref = resource::get_managed_memory_resource_ref(*this); + + // Snapshot keeps original resource objects alive while dry-run + // adaptors hold non-owning refs into them. + snapshot_ = resources_; + + // --- Host (global) --- + { + host_adaptor_ = std::make_unique(raft::mr::host_resource_ref{old_host_}); + host_stats_ = host_adaptor_->get_counter(); + mr::set_default_host_resource(mr::host_resource_ref{*host_adaptor_}); + } + + // --- Pinned --- + { + mr::dry_run_resource dr{pinned_ref}; + pinned_stats_ = dr.get_counter(); + resource::set_pinned_memory_resource(*this, std::move(dr)); + } + + // --- Managed --- + { + mr::dry_run_resource dr{managed_ref}; + managed_stats_ = dr.get_counter(); + resource::set_managed_memory_resource(*this, std::move(dr)); + } + + // --- Device (global) --- + // Invalidate the cached thrust policy (the resource_ref it captured + // will be stale once we replace the global device resource). + factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( + resource::resource_type::LAST_KEY, std::make_shared()); + resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( + resource::resource_type::LAST_KEY, std::make_shared()); + { + device_dry_run_t dr{rmm::device_async_resource_ref{old_device_}}; + device_stats_ = dr.get_counter(); + device_adaptor_ = std::make_unique(std::move(dr)); + rmm::mr::set_current_device_resource(*device_adaptor_); + } + + // --- Workspace --- + { + mr::dry_run_resource dr{ws_upstream}; + ws_stats_ = dr.get_counter(); + resource::set_workspace_resource(*this, std::move(dr), ws_free); + } + + // --- Large workspace --- + { + mr::dry_run_resource dr{lws_ref}; + lws_stats_ = dr.get_counter(); + resource::set_large_workspace_resource(*this, std::move(dr)); + } + + resource::set_dry_run_flag(*this, true); + } +}; + +/** @} */ + +} // namespace raft + +namespace raft::util { + +/** + * @brief Execute an action in dry-run mode and return peak memory usage. + * + * Creates an independent copy of the resources handle with all memory resources + * replaced by dry-run versions, executes the action, and returns peak usage stats. + * + * The action receives the dry-run resources handle (as const raft::resources&) + * and can check the dry-run flag via raft::resource::get_dry_run_flag(res) to + * skip kernel execution. + * + * @tparam Action A callable with signature void(const raft::resources&, Args...). + * @tparam Args Additional argument types to forward to the action. + * @param res The raft resources handle. + * @param action The action to execute in dry-run mode. + * @param args Additional arguments to forward to the action. + * @return memory_stats with peak memory usage from the dry run. + * + * @code{.cpp} + * raft::resources res; + * auto stats = raft::util::dry_run_execute(res, [](const raft::resources& r) { + * my_algorithm(r); + * }); + * std::cout << "Peak workspace: " << stats.device_workspace << " bytes\n"; + * @endcode + */ +template +auto dry_run_execute(const raft::resources& res, Action&& action, Args&&... args) + -> raft::memory_stats +{ + raft::dry_run_resources dry_res(res); + std::forward(action)(static_cast(dry_res), + std::forward(args)...); + return dry_res.get_bytes_peak(); +} + +} // namespace raft::util diff --git a/cpp/include/raft/core/host_container_policy.hpp b/cpp/include/raft/core/host_container_policy.hpp index 6839431945..296b4d1710 100644 --- a/cpp/include/raft/core/host_container_policy.hpp +++ b/cpp/include/raft/core/host_container_policy.hpp @@ -105,6 +105,27 @@ requires cuda::mr::synchronous_resource_with *this = std::move(new_container); } + /** + * @brief Resize the internal buffer without copying old data. + * + * Unlike resize(), this never copies old data. + * Thus, unlike in resize(), there's no point in time where the old and the new buffers are both + * alive, and the peak memory usage is lower. + * + * Unlike resize(), this deallocates the old buffer even if the new size is smaller. + * This ensures the memory is released promptly. + */ + void reallocate(size_type count) + { + if (bytesize_ == sizeof(value_type) * count) { return; } + if (data_ != nullptr) { + mr_.deallocate_sync(data_, bytesize_); + data_ = nullptr; + } + auto tmp = host_container{count, mr_}; + std::swap(tmp, *this); + } + [[nodiscard]] auto data() noexcept -> pointer { return data_; } [[nodiscard]] auto data() const noexcept -> const_pointer { return data_; } }; diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 712170b00e..09857cd2c1 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -224,7 +225,7 @@ auto make_host_scalar(raft::resources const& res, ElementType const& v) using policy_t = typename host_scalar::container_policy_type; policy_t policy; auto scalar = host_scalar{res, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(res)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/managed_mdarray.hpp b/cpp/include/raft/core/managed_mdarray.hpp index d6084a69ad..57e9eaf7bb 100644 --- a/cpp/include/raft/core/managed_mdarray.hpp +++ b/cpp/include/raft/core/managed_mdarray.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -118,7 +119,7 @@ auto make_managed_scalar(raft::resources const& handle, ElementType const& v) using policy_t = typename managed_scalar::container_policy_type; policy_t policy{}; auto scalar = managed_scalar{handle, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/pinned_mdarray.hpp b/cpp/include/raft/core/pinned_mdarray.hpp index 287430b69a..0ad69ceb17 100644 --- a/cpp/include/raft/core/pinned_mdarray.hpp +++ b/cpp/include/raft/core/pinned_mdarray.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -118,7 +119,7 @@ auto make_pinned_scalar(raft::resources const& handle, ElementType const& v) using policy_t = typename pinned_scalar::container_policy_type; policy_t policy{}; auto scalar = pinned_scalar{handle, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp index b66c16f199..a20653db5f 100644 --- a/cpp/include/raft/core/resource/cuda_stream.hpp +++ b/cpp/include/raft/core/resource/cuda_stream.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -84,13 +85,18 @@ inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_v */ inline void sync_stream(const resources& res, rmm::cuda_stream_view stream) { + if (raft::resource::get_dry_run_flag(res)) { return; } interruptible::synchronize(stream); } /** * @brief synchronize main stream on the resources instance */ -inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream(res)); } +inline void sync_stream(const resources& res) +{ + if (raft::resource::get_dry_run_flag(res)) { return; } + sync_stream(res, get_cuda_stream(res)); +} /** * @} diff --git a/cpp/include/raft/core/resource/dry_run_flag.hpp b/cpp/include/raft/core/resource/dry_run_flag.hpp new file mode 100644 index 0000000000..4d0c9e27b5 --- /dev/null +++ b/cpp/include/raft/core/resource/dry_run_flag.hpp @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +#include + +namespace raft::resource { + +/** + * @defgroup dry_run_flag Dry-run flag resource + * @{ + */ + +/** + * @brief Resource that holds a boolean dry-run flag. + * + * When the dry-run flag is set, algorithms should skip kernel execution + * and only perform allocations to measure memory usage. + */ +class dry_run_flag_resource : public resource { + public: + dry_run_flag_resource() = default; + explicit dry_run_flag_resource(bool value) : flag_(value) {} + ~dry_run_flag_resource() override = default; + + auto get_resource() -> void* override { return &flag_; } + + void set(bool value) { flag_ = value; } + [[nodiscard]] auto get() const -> bool { return flag_; } + + private: + bool flag_{false}; +}; + +/** + * @brief Factory that creates a dry_run_flag_resource. + */ +class dry_run_flag_resource_factory : public resource_factory { + public: + explicit dry_run_flag_resource_factory(bool initial_value = false) : initial_value_(initial_value) + { + } + + auto get_resource_type() -> resource_type override { return resource_type::DRY_RUN_FLAG; } + auto make_resource() -> resource* override { return new dry_run_flag_resource(initial_value_); } + + private: + bool initial_value_; +}; + +/** + * @brief Get the dry-run flag from a resources handle. + * + * @param res raft resources object + * @return true if dry-run mode is active + */ +inline auto get_dry_run_flag(resources const& res) -> bool +{ + if (!res.has_resource_factory(resource_type::DRY_RUN_FLAG)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::DRY_RUN_FLAG); +} + +/** + * @brief Set the dry-run flag on a resources handle. + * + * @param res raft resources object + * @param value true to enable dry-run mode, false to disable + */ +inline void set_dry_run_flag(resources const& res, bool value) +{ + if (!res.has_resource_factory(resource_type::DRY_RUN_FLAG)) { + res.add_resource_factory(std::make_shared(value)); + } else { + // The resource may already be instantiated; update it directly + auto* flag = res.get_resource(resource_type::DRY_RUN_FLAG); + *flag = value; + } +} + +/** @} */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index e3af719eda..ae2c9b21cf 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -42,6 +42,7 @@ enum resource_type { MULTI_GPU, // resource that tracks resource of each device in multi-gpu world PINNED_MEMORY_RESOURCE, // memory resource for pinned (page-locked) host allocations MANAGED_MEMORY_RESOURCE, // resource for managed (unified) allocations + DRY_RUN_FLAG, // dry-run mode flag for allocation profiling LAST_KEY // reserved for the last key }; diff --git a/cpp/include/raft/core/sparse_types.hpp b/cpp/include/raft/core/sparse_types.hpp index 1657a8e494..3b7d9b9c59 100644 --- a/cpp/include/raft/core/sparse_types.hpp +++ b/cpp/include/raft/core/sparse_types.hpp @@ -178,7 +178,7 @@ class sparse_matrix { ~sparse_matrix() noexcept(std::is_nothrow_destructible::value) = default; - void initialize_sparsity(nnz_type nnz) { c_elements_.resize(nnz); }; + void initialize_sparsity(nnz_type nnz) { c_elements_.reallocate(nnz); }; raft::span get_elements() { diff --git a/cpp/include/raft/label/classlabels.cuh b/cpp/include/raft/label/classlabels.cuh index d02bf8feaf..02b6f3cb93 100644 --- a/cpp/include/raft/label/classlabels.cuh +++ b/cpp/include/raft/label/classlabels.cuh @@ -8,11 +8,37 @@ #pragma once #include +#include +#include +#include #include namespace raft { namespace label { +/** + * Get unique class labels. + * + * The y array is assumed to store class labels. The unique values are selected + * from this array. + * + * @tparam value_t numeric type of the arrays with class labels + * @param [in] handle raft resources handle (dry-run aware) + * @param [inout] unique output unique labels + * @param [in] y device array of labels, size [n] + * @param [in] n number of labels + * @returns number of unique labels (upper bound in dry-run mode) + */ +template +int getUniquelabels(raft::resources const& handle, + rmm::device_uvector& unique, + value_t* y, + size_t n) +{ + return detail::getUniquelabels( + resource::get_dry_run_flag(handle), unique, y, n, resource::get_cuda_stream(handle)); +} + /** * Get unique class labels. * diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index f0e9a14f69..2a3d7b50eb 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -30,15 +30,17 @@ namespace detail { * from this array. * * \tparam value_t numeric type of the arrays with class labels - * \param [in] y device array of labels, size [n] - * \param [in] n number of labels + * \param [in] dry_run if true, perform allocations but skip CUDA work * \param [out] unique device array of unique labels, unallocated on entry, * on exit it has size [n_unique] - * \param [out] n_unique number of unique labels + * \param [in] y device array of labels, size [n] + * \param [in] n number of labels * \param [in] stream cuda stream + * \return number of unique labels (upper bound when dry_run is true) */ template -int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, cudaStream_t stream) +int getUniquelabels( + bool dry_run, rmm::device_uvector& unique, value_t* y, size_t n, cudaStream_t stream) { rmm::device_scalar d_num_selected(stream); rmm::device_uvector workspace(n, stream); @@ -54,6 +56,11 @@ int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, bytes = std::max(bytes, bytes2); rmm::device_uvector cub_storage(bytes, stream); + if (dry_run) { + if (unique.size() < n) { unique = rmm::device_uvector(n, stream); } + return static_cast(n); + } + // Select Unique classes cub::DeviceRadixSort::SortKeys( cub_storage.data(), bytes, y, workspace.data(), n, 0, sizeof(value_t) * 8, stream); @@ -73,6 +80,26 @@ int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, return n_unique; } +/** + * Get unique class labels. + * + * The y array is assumed to store class labels. The unique values are selected + * from this array. + * + * \tparam value_t numeric type of the arrays with class labels + * \param [out] unique device array of unique labels, unallocated on entry, + * on exit it has size [n_unique] + * \param [in] y device array of labels, size [n] + * \param [in] n number of labels + * \param [in] stream cuda stream + * \return number of unique labels + */ +template +int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, cudaStream_t stream) +{ + return getUniquelabels(false, unique, y, n, stream); +} + /** * Assign one versus rest labels. * diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index 4171b53a27..d87d146a51 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace raft { @@ -103,6 +104,7 @@ template > void add(raft::resources const& handle, InType in1, InType in2, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -140,6 +142,7 @@ void add_scalar(raft::resources const& handle, OutType out, raft::device_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -175,6 +178,7 @@ void add_scalar(raft::resources const& handle, OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 818eee0ec3..835c81bf7a 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace raft { @@ -63,7 +64,7 @@ void coalescedReduction(OutType* dots, FinalLambda final_op = raft::identity_op()) { detail::coalescedReduction( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + false, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } /** @@ -121,30 +122,32 @@ void coalesced_reduction(raft::resources const& handle, RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), "Output should be equal to number of rows in Input"); - coalescedReduction(dots.data_handle(), - data.data_handle(), - data.extent(1), - data.extent(0), - init, - resource::get_cuda_stream(handle), - inplace, - main_op, - reduce_op, - final_op); + detail::coalescedReduction(resource::get_dry_run_flag(handle), + dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + resource::get_cuda_stream(handle), + inplace, + main_op, + reduce_op, + final_op); } else if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), "Output should be equal to number of columns in Input"); - coalescedReduction(dots.data_handle(), - data.data_handle(), - data.extent(0), - data.extent(1), - init, - resource::get_cuda_stream(handle), - inplace, - main_op, - reduce_op, - final_op); + detail::coalescedReduction(resource::get_dry_run_flag(handle), + dots.data_handle(), + data.data_handle(), + data.extent(0), + data.extent(1), + init, + resource::get_cuda_stream(handle), + inplace, + main_op, + reduce_op, + final_op); } } diff --git a/cpp/include/raft/linalg/detail/axpy.cuh b/cpp/include/raft/linalg/detail/axpy.cuh index 6347522138..488cad5bec 100644 --- a/cpp/include/raft/linalg/detail/axpy.cuh +++ b/cpp/include/raft/linalg/detail/axpy.cuh @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -26,6 +27,7 @@ void axpy(raft::resources const& handle, const int incy, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } auto cublas_h = resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream)); diff --git a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh index ae1f82a74f..2c3131451c 100644 --- a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -54,6 +55,7 @@ void choleskyRank1Update(raft::resources const& handle, *n_bytes = offset + 1 * sizeof(math_t); return; } + if (resource::get_dry_run_flag(handle)) { return; } math_t* s = reinterpret_cast(((char*)workspace) + offset); math_t* L_22 = L + (n - 1) * ld + n - 1; diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index 4cc549f79e..e1ec857b8a 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -499,7 +499,8 @@ template -void coalescedReductionThick(OutType* dots, +void coalescedReductionThick(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -518,6 +519,8 @@ void coalescedReductionThick(OutType* dots, rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); + if (dry_run) { return; } + /* We apply a two-step reduction: * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It * applies the main_op but not the final op. @@ -551,7 +554,8 @@ template -void coalescedReductionThickDispatcher(OutType* dots, +void coalescedReductionThickDispatcher(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -565,7 +569,7 @@ void coalescedReductionThickDispatcher(OutType* dots, // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + dry_run, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } // Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along @@ -580,7 +584,8 @@ template -void coalescedReduction(OutType* dots, +void coalescedReduction(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -601,12 +606,16 @@ void coalescedReduction(OutType* dots, */ const IdxType numSMs = raft::getMultiProcessorCount(); if (D <= IdxType(512) || (N >= IdxType(16) * numSMs && D < IdxType(2048))) { + if (dry_run) { return; } coalescedReductionThinDispatcher( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (N < numSMs && D >= IdxType(1 << 17)) { + // Must call through to coalescedReductionThick even in dry-run so workspace + // allocations are recorded (coalescedReductionThick allocates before guarding). coalescedReductionThickDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + dry_run, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { + if (dry_run) { return; } coalescedReductionMediumDispatcher( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } diff --git a/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp b/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp index 2337413fbd..06d087d755 100644 --- a/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -284,6 +285,8 @@ template batch_scope( "linalg::matmul(m = %d, n = %d, k = %d)", m, n, k); std::shared_ptr mm_desc{nullptr}; diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index 5dca01d87d..38f41deb21 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -45,9 +46,13 @@ void eigDC_legacy(raft::resources const& handle, eig_vals, &lwork)); + // TODO(achirkin): Consider using the workspace resource for these temporary allocations. rmm::device_uvector d_work(lwork, stream); rmm::device_scalar d_dev_info(stream); + // The workspace is already allocated, no more allocation are foreseeable. + if (resource::get_dry_run_flag(handle)) { return; } + raft::matrix::copy(handle, make_device_matrix_view(in, n_rows, n_cols), make_device_matrix_view(eig_vectors, n_rows, n_cols)); @@ -122,6 +127,12 @@ void eigDC(raft::resources const& handle, rmm::device_scalar d_dev_info(stream_new); std::vector h_work(workspaceHost / sizeof(math_t)); + if (resource::get_dry_run_flag(handle)) { + // No more allocations beyond this points, but need to cleanup. + RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params)); + return; + } + raft::copy(eig_vectors, in, n_rows * n_cols, stream_new); RAFT_CUSOLVER_TRY(cusolverDnxsyevd(cusolverH, @@ -188,7 +199,9 @@ void eigSelDC(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar d_dev_info(stream); - rmm::device_uvector d_eig_vectors(0, stream); + rmm::device_uvector d_eig_vectors(memUsage == COPY_INPUT ? n_rows * n_cols : 0, stream); + + if (resource::get_dry_run_flag(handle)) { return; } if (memUsage == OVERWRITE_INPUT) { RAFT_CUSOLVER_TRY(cusolverDnsyevdx(cusolverH, @@ -209,7 +222,6 @@ void eigSelDC(raft::resources const& handle, d_dev_info.data(), stream)); } else if (memUsage == COPY_INPUT) { - d_eig_vectors.resize(n_rows * n_cols, stream); raft::matrix::copy(handle, make_device_matrix_view(in, n_rows, n_cols), make_device_matrix_view(eig_vectors, n_rows, n_cols)); @@ -286,6 +298,12 @@ void eigJacobi(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar dev_info(stream); + if (resource::get_dry_run_flag(handle)) { + // No more allocations beyond this points, but need to cleanup. + RAFT_CUSOLVER_TRY(cusolverDnDestroySyevjInfo(syevj_params)); + return; + } + raft::matrix::copy(handle, make_device_matrix_view(in, n_rows, n_cols), make_device_matrix_view(eig_vectors, n_rows, n_cols)); diff --git a/cpp/include/raft/linalg/detail/gemv.hpp b/cpp/include/raft/linalg/detail/gemv.hpp index 8e5760f706..5ddcbf9ad9 100644 --- a/cpp/include/raft/linalg/detail/gemv.hpp +++ b/cpp/include/raft/linalg/detail/gemv.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -32,6 +33,7 @@ void gemv(raft::resources const& handle, const int incy, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } cublasHandle_t cublas_h = resource::get_cublas_handle(handle); detail::cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(detail::cublasgemv(cublas_h, @@ -110,6 +112,7 @@ void gemv(raft::resources const& handle, const math_t beta, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } cublasHandle_t cublas_h = resource::get_cublas_handle(handle); cublasOperation_t op_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; RAFT_CUBLAS_TRY( diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 2f0d2aa5c3..8df37527c9 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -131,6 +132,9 @@ void lstsqSvdQR(raft::resources const& handle, + 1 // devInfo , stream); + + if (resource::get_dry_run_flag(handle)) { return; } + math_t* cusolverWorkSet = workset.data(); math_t* U = cusolverWorkSet + cusolverWorkSetSize; math_t* Vt = U + n_rows * minmn; @@ -205,6 +209,12 @@ void lstsqSvdJacobi(raft::resources const& handle, + 1 // devInfo , stream); + + if (resource::get_dry_run_flag(handle)) { + RAFT_CUSOLVER_TRY(cusolverDnDestroyGesvdjInfo(gesvdj_params)); + return; + } + math_t* cusolverWorkSet = workset.data(); math_t* U = cusolverWorkSet + cusolverWorkSetSize; math_t* V = U + n_rows * minmn; @@ -249,21 +259,27 @@ void lstsqEig(raft::resources const& handle, { rmm::cuda_stream_view mainStream = rmm::cuda_stream_view(stream); rmm::cuda_stream_view multAbStream = resource::get_next_usable_stream(handle); + bool dry_run = resource::get_dry_run_flag(handle); bool concurrent; - // Check if the two streams can run concurrently. This is needed because a legacy default stream - // would synchronize with other blocking streams. To avoid synchronization in such case, we try to - // use an additional stream from the pool. - if (!are_implicitly_synchronized(mainStream, multAbStream)) { - concurrent = true; - } else if (resource::get_stream_pool_size(handle) > 1) { - mainStream = resource::get_next_usable_stream(handle); - concurrent = true; + if (dry_run) { + concurrent = false; } else { - multAbStream = mainStream; - concurrent = false; + // Check if the two streams can run concurrently. This is needed because a legacy default stream + // would synchronize with other blocking streams. To avoid synchronization in such case, we try + // to use an additional stream from the pool. + if (!are_implicitly_synchronized(mainStream, multAbStream)) { + concurrent = true; + } else if (resource::get_stream_pool_size(handle) > 1) { + mainStream = resource::get_next_usable_stream(handle); + concurrent = true; + } else { + multAbStream = mainStream; + concurrent = false; + } } rmm::device_uvector workset(n_cols * n_cols * 3 + n_cols * 2, mainStream); + // the event is created only if the given raft handle is capable of running // at least two CUDA streams without implicit synchronization. DeviceEvent worksetDone(concurrent); @@ -303,8 +319,8 @@ void lstsqEig(raft::resources const& handle, raft::common::nvtx::pop_range(); // QS <- Q invS - raft::linalg::matrixVectorOp( - QS, Q, S, n_cols, n_cols, DivideByNonZero(), mainStream); + raft::linalg::detail::matrixVectorOp( + dry_run, QS, Q, S, n_cols, n_cols, DivideByNonZero(), mainStream); // covA <- QS Q* == Q invS Q* == inv(A* A) raft::linalg::gemm(handle, QS, @@ -393,6 +409,8 @@ void lstsqQR(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); + if (resource::get_dry_run_flag(handle)) { return; } + // #TODO: Call from public API when ready RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngeqrf( cusolverH, m, n, A, lda, d_tau.data(), d_work.data(), lwork, d_info.data(), stream)); diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index 714869aaa5..97df85a3ff 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -208,6 +209,7 @@ template > void map(const raft::resources& res, OutType out, Func f, InTypes... ins) { + if (resource::get_dry_run_flag(res)) { return; } RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); (map_check_shape(out, ins), ...); diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index af9632a7da..c238d0961e 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -7,6 +7,7 @@ #include #include +#include #include namespace raft { @@ -20,7 +21,8 @@ template -void matrixVectorOp(MatT* out, +void matrixVectorOp(bool dry_run, + MatT* out, const MatT* matrix, const VecT* vec, IdxType D, @@ -28,6 +30,7 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { + if (dry_run) { return; } raft::resources handle; resource::set_cuda_stream(handle, stream); constexpr raft::Apply apply = @@ -57,7 +60,8 @@ template -void matrixVectorOp(MatT* out, +void matrixVectorOp(bool dry_run, + MatT* out, const MatT* matrix, const Vec1T* vec1, const Vec2T* vec2, @@ -66,6 +70,7 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { + if (dry_run) { return; } raft::resources handle; resource::set_cuda_stream(handle, stream); constexpr raft::Apply apply = diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index 782438fdd2..9a563ee23d 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -20,18 +20,23 @@ template -void rowNormCaller( - OutType* dots, const Type* data, IdxType D, IdxType N, cudaStream_t stream, Lambda fin_op) +void rowNormCaller(bool dry_run, + OutType* dots, + const Type* data, + IdxType D, + IdxType N, + cudaStream_t stream, + Lambda fin_op) { if constexpr (norm_type == L1Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == L2Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == LinfNorm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); } else { THROW("Unsupported norm type: %d", norm_type); } @@ -43,18 +48,23 @@ template -void colNormCaller( - OutType* dots, const Type* data, IdxType D, IdxType N, cudaStream_t stream, Lambda fin_op) +void colNormCaller(bool dry_run, + OutType* dots, + const Type* data, + IdxType D, + IdxType N, + cudaStream_t stream, + Lambda fin_op) { if constexpr (norm_type == L1Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == L2Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == LinfNorm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); + reduce( + false, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); } else { THROW("Unsupported norm type: %d", norm_type); } diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index bf981ecae0..41e6ad87fd 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -40,15 +41,26 @@ void qrGetQ_inplace( { RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle); + auto is_dry_run = resource::get_dry_run_flag(handle); rmm::device_uvector tau(n_cols, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); + } rmm::device_scalar dev_info(stream); - int ws_size; + int ws_size_Dngeqrf; + int ws_size_Dnorgqr; + + RAFT_CUSOLVER_TRY( + cusolverDngeqrf_bufferSize(cusolver, n_rows, n_cols, Q, n_rows, &ws_size_Dngeqrf)); + RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize( + cusolver, n_rows, n_cols, n_cols, Q, n_rows, tau.data(), &ws_size_Dnorgqr)); + + rmm::device_uvector workspace(std::max(ws_size_Dngeqrf, ws_size_Dnorgqr), stream); + + if (is_dry_run) { return; } - RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize(cusolver, n_rows, n_cols, Q, n_rows, &ws_size)); - rmm::device_uvector workspace(ws_size, stream); RAFT_CUSOLVER_TRY(cusolverDngeqrf(cusolver, n_rows, n_cols, @@ -56,13 +68,10 @@ void qrGetQ_inplace( n_rows, tau.data(), workspace.data(), - ws_size, + ws_size_Dngeqrf, dev_info.data(), stream)); - RAFT_CUSOLVER_TRY( - cusolverDnorgqr_bufferSize(cusolver, n_rows, n_cols, n_cols, Q, n_rows, tau.data(), &ws_size)); - workspace.resize(ws_size, stream); RAFT_CUSOLVER_TRY(cusolverDnorgqr(cusolver, n_rows, n_cols, @@ -71,7 +80,7 @@ void qrGetQ_inplace( n_rows, tau.data(), workspace.data(), - ws_size, + ws_size_Dnorgqr, dev_info.data(), stream)); } @@ -84,7 +93,7 @@ void qrGetQ(raft::resources const& handle, int n_cols, cudaStream_t stream) { - raft::copy(Q, M, n_rows * n_cols, stream); + if (!resource::get_dry_run_flag(handle)) { raft::copy(Q, M, n_rows * n_cols, stream); } qrGetQ_inplace(handle, Q, n_rows, n_cols, stream); } @@ -100,19 +109,32 @@ void qrGetQR(raft::resources const& handle, cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int m = n_rows, n = n_cols; + int R_full_nrows = m, R_full_ncols = n; + int Q_nrows = m, Q_ncols = n; + int Lwork_Dngeqrf, Lwork_Dnorgqr; rmm::device_uvector R_full(m * n, stream); rmm::device_uvector tau(std::min(m, n), stream); + rmm::device_scalar devInfo(stream); + + RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize( + cusolverH, R_full_nrows, R_full_ncols, R_full.data(), R_full_nrows, &Lwork_Dngeqrf)); + RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize(cusolverH, + Q_nrows, + Q_ncols, + std::min(Q_ncols, Q_nrows), + Q, + Q_nrows, + tau.data(), + &Lwork_Dnorgqr)); + + rmm::device_uvector workspace(std::max(Lwork_Dngeqrf, Lwork_Dnorgqr), stream); + + if (resource::get_dry_run_flag(handle)) { return; } + RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * std::min(m, n), stream)); - int R_full_nrows = m, R_full_ncols = n; RAFT_CUDA_TRY( cudaMemcpyAsync(R_full.data(), M, sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); - int Lwork; - rmm::device_scalar devInfo(stream); - - RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize( - cusolverH, R_full_nrows, R_full_ncols, R_full.data(), R_full_nrows, &Lwork)); - rmm::device_uvector workspace(Lwork, stream); RAFT_CUSOLVER_TRY(cusolverDngeqrf(cusolverH, R_full_nrows, R_full_ncols, @@ -120,7 +142,7 @@ void qrGetQR(raft::resources const& handle, R_full_nrows, tau.data(), workspace.data(), - Lwork, + Lwork_Dngeqrf, devInfo.data(), stream)); @@ -131,11 +153,7 @@ void qrGetQR(raft::resources const& handle, RAFT_CUDA_TRY( cudaMemcpyAsync(Q, R_full.data(), sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); - int Q_nrows = m, Q_ncols = n; - RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize( - cusolverH, Q_nrows, Q_ncols, std::min(Q_ncols, Q_nrows), Q, Q_nrows, tau.data(), &Lwork)); - workspace.resize(Lwork, stream); RAFT_CUSOLVER_TRY(cusolverDnorgqr(cusolverH, Q_nrows, Q_ncols, @@ -144,7 +162,7 @@ void qrGetQR(raft::resources const& handle, Q_nrows, tau.data(), workspace.data(), - Lwork, + Lwork_Dnorgqr, devInfo.data(), stream)); } diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index 2a689649b4..f58dc12f67 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -22,7 +22,8 @@ template -void reduce(OutType* dots, +void reduce(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -34,17 +35,19 @@ void reduce(OutType* dots, FinalLambda final_op = raft::identity_op()) { if constexpr (rowMajor && alongRows) { - raft::linalg::coalescedReduction( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + coalescedReduction( + dry_run, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if constexpr (rowMajor && !alongRows) { + if (dry_run) { return; } // no allocations in strided reduction raft::linalg::stridedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if constexpr (!rowMajor && alongRows) { + if (dry_run) { return; } // no allocations in strided reduction raft::linalg::stridedReduction( dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); } else { - raft::linalg::coalescedReduction( - dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); + coalescedReduction( + dry_run, dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); } } diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 8adf3bfb48..7220feea6a 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -86,6 +87,8 @@ void randomized_svd(const raft::resources& handle, auto h_workspace = raft::make_host_vector(workspaceHost); auto devInfo = raft::make_device_scalar(handle, 0); + if (resource::get_dry_run_flag(handle)) { return; } + RAFT_CUSOLVER_TRY(cusolverDnxgesvdr(cusolverH, jobu, jobv, @@ -155,6 +158,7 @@ void rsvdFixedRank(raft::resources const& handle, int max_sweeps, cudaStream_t stream) { + bool is_dry_run = resource::get_dry_run_flag(handle); cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); cublasHandle_t cublasH = resource::get_cublas_handle(handle); @@ -172,7 +176,9 @@ void rsvdFixedRank(raft::resources const& handle, // Build temporary U, S, V matrices rmm::device_uvector S_vec_tmp(l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(S_vec_tmp.data(), 0, sizeof(math_t) * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(S_vec_tmp.data(), 0, sizeof(math_t) * l, stream)); + } // build random matrix rmm::device_uvector RN(n * l, stream); @@ -188,9 +194,11 @@ void rsvdFixedRank(raft::resources const& handle, rmm::device_uvector Z(n * l, stream); rmm::device_uvector Yorth(m * l, stream); rmm::device_uvector Zorth(n * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Z.data(), 0, sizeof(math_t) * n * l, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(Yorth.data(), 0, sizeof(math_t) * m * l, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(Zorth.data(), 0, sizeof(math_t) * n * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Z.data(), 0, sizeof(math_t) * n * l, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(Yorth.data(), 0, sizeof(math_t) * m * l, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(Zorth.data(), 0, sizeof(math_t) * n * l, stream)); + } // power sampling scheme for (int j = 1; j < q; j++) { @@ -237,30 +245,40 @@ void rsvdFixedRank(raft::resources const& handle, // orthogonalize on exit from loop to get Q rmm::device_uvector Q(m * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Q.data(), 0, sizeof(math_t) * m * l, stream)); + if (!is_dry_run) { RAFT_CUDA_TRY(cudaMemsetAsync(Q.data(), 0, sizeof(math_t) * m * l, stream)); } raft::linalg::qrGetQ(handle, Y.data(), Q.data(), m, l, stream); // either QR of B^T method, or eigendecompose BB^T method if (!use_bbt) { // form Bt = Mt*Q : nxm * mxl = nxl rmm::device_uvector Bt(n * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Bt.data(), 0, sizeof(math_t) * n * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Bt.data(), 0, sizeof(math_t) * n * l, stream)); + } raft::linalg::gemm( handle, M, m, n, Q.data(), Bt.data(), n, l, CUBLAS_OP_T, CUBLAS_OP_N, alpha, beta, stream); // compute QR factorization of Bt // M is mxn ; Q is mxn ; R is min(m,n) x min(m,n) */ rmm::device_uvector Qhat(n * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Qhat.data(), 0, sizeof(math_t) * n * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Qhat.data(), 0, sizeof(math_t) * n * l, stream)); + } rmm::device_uvector Rhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Rhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Rhat.data(), 0, sizeof(math_t) * l * l, stream)); + } raft::linalg::qrGetQR(handle, Bt.data(), Qhat.data(), Rhat.data(), n, l, stream); // compute SVD of Rhat (lxl) rmm::device_uvector Uhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + } rmm::device_uvector Vhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Vhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Vhat.data(), 0, sizeof(math_t) * l * l, stream)); + } if (use_jacobi) raft::linalg::svdJacobi(handle, Rhat.data(), @@ -351,9 +369,13 @@ void rsvdFixedRank(raft::resources const& handle, // compute eigendecomposition of BBt rmm::device_uvector Uhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + } rmm::device_uvector Uhat_dup(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Uhat_dup.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Uhat_dup.data(), 0, sizeof(math_t) * l * l, stream)); + } raft::matrix::upper_triangular( handle, @@ -398,9 +420,13 @@ void rsvdFixedRank(raft::resources const& handle, // Sigma^{-1}[(p+1):l, (p+1):l] nxl * lxk * kxk = nxk if (gen_right_vec) { rmm::device_uvector Sinv(k * k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Sinv.data(), 0, sizeof(math_t) * k * k, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Sinv.data(), 0, sizeof(math_t) * k * k, stream)); + } rmm::device_uvector UhatSinv(l * k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(UhatSinv.data(), 0, sizeof(math_t) * l * k, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(UhatSinv.data(), 0, sizeof(math_t) * l * k, stream)); + } math_t scalar = 1.0; raft::matrix::reciprocal( handle, diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 15396324cc..7589edd6f9 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -60,6 +61,8 @@ void svdQR(raft::resources const& handle, RAFT_CUSOLVER_TRY(cusolverDngesvd_bufferSize(cusolverH, n_rows, n_cols, &lwork)); rmm::device_uvector d_work(lwork, stream); + if (resource::get_dry_run_flag(handle)) { return; } + char jobu = 'S'; char jobvt = 'A'; @@ -217,6 +220,11 @@ void svdJacobi(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); + if (resource::get_dry_run_flag(handle)) { + RAFT_CUSOLVER_TRY(cusolverDnDestroyGesvdjInfo(gesvdj_params)); + return; + } + RAFT_CUSOLVER_TRY(cusolverDngesvdj(cusolverH, CUSOLVER_EIG_MODE_VECTOR, econ, @@ -281,16 +289,19 @@ bool evaluateSVDByL2Norm(raft::resources const& handle, math_t tol, cudaStream_t stream) { - cublasHandle_t cublasH = resource::get_cublas_handle(handle); - int m = n_rows, n = n_cols; + bool is_dry_run = resource::get_dry_run_flag(handle); // form product matrix rmm::device_uvector P_d(m * n, stream); rmm::device_uvector S_mat(k * k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, sizeof(math_t) * k * k, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, sizeof(math_t) * k * k, stream)); + } + + // These RAFT functions have their own dry-run guards at the leaf level raft::matrix::set_diagonal(handle, make_device_vector_view(S_vec, k), make_device_matrix_view(S_mat.data(), k, k)); @@ -308,8 +319,12 @@ bool evaluateSVDByL2Norm(raft::resources const& handle, // calculate percent error const math_t alpha = 1.0, beta = -1.0; rmm::device_uvector A_minus_P(m * n, stream); + + if (is_dry_run) { return false; } + RAFT_CUDA_TRY(cudaMemsetAsync(A_minus_P.data(), 0, sizeof(math_t) * m * n, stream)); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); RAFT_CUBLAS_TRY(cublasgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 82fdb1c6f7..bf068d7049 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -88,6 +89,7 @@ void transpose_half(raft::resources const& handle, const IndexType stride_out = 1) { if (n_cols == 0 || n_rows == 0) return; + if (resource::get_dry_run_flag(handle)) { return; } auto stream = resource::get_cuda_stream(handle); int dev_id, sm_count; @@ -135,6 +137,7 @@ void transpose(raft::resources const& handle, int n_cols, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } int out_n_rows = n_cols; int out_n_cols = n_rows; @@ -189,6 +192,7 @@ void transpose_row_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { + if (resource::get_dry_run_flag(handle)) { return; } auto out_n_rows = in.extent(1); auto out_n_cols = in.extent(0); T constexpr kOne = 1; @@ -231,6 +235,7 @@ void transpose_col_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { + if (resource::get_dry_run_flag(handle)) { return; } auto out_n_rows = in.extent(1); auto out_n_cols = in.extent(0); T constexpr kOne = 1; diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index b5cbacbce3..cbe5aec0f3 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -62,6 +63,7 @@ void divide_scalar(raft::resources const& handle, OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index c8684341a8..b0e4792338 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -42,6 +43,7 @@ void dot(raft::resources const& handle, { RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); + if (resource::get_dry_run_flag(handle)) { return; } RAFT_CUBLAS_TRY(detail::cublasdot(resource::get_cublas_handle(handle), x.size(), @@ -72,6 +74,7 @@ void dot(raft::resources const& handle, { RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); + if (resource::get_dry_run_flag(handle)) { return; } RAFT_CUBLAS_TRY(detail::cublasdot(resource::get_cublas_handle(handle), x.size(), diff --git a/cpp/include/raft/linalg/map_reduce.cuh b/cpp/include/raft/linalg/map_reduce.cuh index 66d8a1d6a2..2a678738ea 100644 --- a/cpp/include/raft/linalg/map_reduce.cuh +++ b/cpp/include/raft/linalg/map_reduce.cuh @@ -12,6 +12,7 @@ #include #include #include +#include namespace raft { namespace linalg { @@ -91,6 +92,7 @@ void map_reduce(raft::resources const& handle, ReduceLambda op, Args... args) { + if (resource::get_dry_run_flag(handle)) { return; } mapReduce( out.data_handle(), in.extent(0), diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 6eca1ea9e8..766d2a433b 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -57,7 +58,7 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - detail::matrixVectorOp(out, matrix, vec, D, N, op, stream); + detail::matrixVectorOp(false, out, matrix, vec, D, N, op, stream); } /** @@ -101,7 +102,8 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - detail::matrixVectorOp(out, matrix, vec1, vec2, D, N, op, stream); + detail::matrixVectorOp( + false, out, matrix, vec1, vec2, D, N, op, stream); } /** @@ -157,13 +159,14 @@ void matrix_vector_op(raft::resources const& handle, "Size mismatch between matrix and vector"); } - matrixVectorOp(out.data_handle(), - matrix.data_handle(), - vec.data_handle(), - out.extent(1), - out.extent(0), - op, - resource::get_cuda_stream(handle)); + detail::matrixVectorOp(resource::get_dry_run_flag(handle), + out.data_handle(), + matrix.data_handle(), + vec.data_handle(), + out.extent(1), + out.extent(0), + op, + resource::get_cuda_stream(handle)); } /** @@ -222,14 +225,15 @@ void matrix_vector_op(raft::resources const& handle, "Size mismatch between matrix and vector"); } - matrixVectorOp(out.data_handle(), - matrix.data_handle(), - vec1.data_handle(), - vec2.data_handle(), - out.extent(1), - out.extent(0), - op, - resource::get_cuda_stream(handle)); + detail::matrixVectorOp(resource::get_dry_run_flag(handle), + out.data_handle(), + matrix.data_handle(), + vec1.data_handle(), + vec2.data_handle(), + out.extent(1), + out.extent(0), + op, + resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_vector_op diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index f14a64a7c8..b700e92495 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -12,6 +12,7 @@ #include #include #include +#include namespace raft { namespace linalg { @@ -58,6 +59,7 @@ void mean_squared_error(raft::resources const& handle, raft::device_scalar_view out, OutValueType weight) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs"); meanSquaredError(out.data_handle(), diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh index 30d9be2611..7a901500a0 100644 --- a/cpp/include/raft/linalg/multiply.cuh +++ b/cpp/include/raft/linalg/multiply.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace raft { @@ -64,6 +65,7 @@ void multiply_scalar( OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 7395c41925..a98e61d72a 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -55,7 +56,7 @@ void rowNorm(OutType* dots, cudaStream_t stream, Lambda fin_op = raft::identity_op()) { - detail::rowNormCaller(dots, data, D, N, stream, fin_op); + detail::rowNormCaller(false, dots, data, D, N, stream, fin_op); } /** @@ -86,7 +87,7 @@ void colNorm(OutType* dots, cudaStream_t stream, Lambda fin_op = raft::identity_op()) { - detail::colNormCaller(dots, data, D, N, stream, fin_op); + detail::colNormCaller(false, dots, data, D, N, stream, fin_op); } /** @@ -129,21 +130,23 @@ void norm(raft::resources const& handle, if constexpr (along_rows) { RAFT_EXPECTS(static_cast(out.size()) == in.extent(0), "Output should be equal to number of rows in Input"); - rowNorm(out.data_handle(), - in.data_handle(), - in.extent(1), - in.extent(0), - resource::get_cuda_stream(handle), - fin_op); + detail::rowNormCaller(resource::get_dry_run_flag(handle), + out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + resource::get_cuda_stream(handle), + fin_op); } else { RAFT_EXPECTS(static_cast(out.size()) == in.extent(1), "Output should be equal to number of columns in Input"); - colNorm(out.data_handle(), - in.data_handle(), - in.extent(1), - in.extent(0), - resource::get_cuda_stream(handle), - fin_op); + detail::colNormCaller(resource::get_dry_run_flag(handle), + out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + resource::get_cuda_stream(handle), + fin_op); } } diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index ca1f65b26c..6e9cde8bad 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -54,6 +55,7 @@ void row_normalize(raft::resources const& handle, FinalLambda fin_op, ElementType eps = ElementType(1e-8)) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(in.extent(0) == out.extent(0), diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index f3ddc4037a..5f1cc2d2ac 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -75,6 +76,7 @@ template > void power(raft::resources const& handle, InType in1, InType in2, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -113,6 +115,7 @@ void power_scalar( OutType out, const raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 63db7d3ce6..6ae82d5a17 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -72,7 +73,7 @@ void reduce(OutType* dots, FinalLambda final_op = raft::identity_op()) { detail::reduce( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + false, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } /** @@ -167,16 +168,18 @@ void reduce(raft::resources const& handle, "Output should be equal to number of columns in Input"); } - reduce(dots.data_handle(), - data.data_handle(), - data.extent(1), - data.extent(0), - init, - resource::get_cuda_stream(handle), - inplace, - main_op, - reduce_op, - final_op); + detail::reduce( + resource::get_dry_run_flag(handle), + dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + resource::get_cuda_stream(handle), + inplace, + main_op, + reduce_op, + final_op); } /** @} */ // end of group reduction diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh index 07759ec206..3eda80c1a9 100644 --- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace raft { @@ -82,6 +83,7 @@ void reduce_cols_by_key( IndexType nkeys = 0, bool reset_sums = true) { + if (resource::get_dry_run_flag(handle)) { return; } if (nkeys > 0) { RAFT_EXPECTS(out.extent(1) == nkeys, "Output doesn't have nkeys columns"); } else { diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index dd2f54c7bc..61bce8bb03 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace raft { @@ -148,6 +149,7 @@ void reduce_rows_by_key( std::optional> d_weights = std::nullopt, bool reset_sums = true) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(d_A.extent(0) == d_A.extent(0) && d_sums.extent(1) == n_unique_keys, "Output is not of size ncols * n_unique_keys"); RAFT_EXPECTS(d_keys.extent(0) == d_A.extent(1), "Keys is not of size nrows"); diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index c571b68ae5..7bc1e2f4bd 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace raft { @@ -52,6 +53,7 @@ template > void sqrt(raft::resources const& handle, InType in, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 9480eb9fa0..bd293aff36 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -128,6 +129,7 @@ void strided_reduction(raft::resources const& handle, ReduceLambda reduce_op = raft::add_op(), FinalLambda final_op = raft::identity_op()) { + if (resource::get_dry_run_flag(handle)) { return; } if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), "Output should be equal to number of columns in Input"); diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh index 8e1b9ca9db..51b66ffbd2 100644 --- a/cpp/include/raft/linalg/subtract.cuh +++ b/cpp/include/raft/linalg/subtract.cuh @@ -14,6 +14,7 @@ #include #include #include +#include #include namespace raft { @@ -99,6 +100,7 @@ template > void subtract(raft::resources const& handle, InType in1, InType in2, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -137,6 +139,7 @@ void subtract_scalar( OutType out, raft::device_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -173,6 +176,7 @@ void subtract_scalar( OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh index abba6113a1..efa3082b88 100644 --- a/cpp/include/raft/linalg/unary_op.cuh +++ b/cpp/include/raft/linalg/unary_op.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -110,6 +111,7 @@ template > void write_only_unary_op(const raft::resources& handle, OutType out, Lambda op) { + if (resource::get_dry_run_flag(handle)) { return; } return writeOnlyUnaryOp(out.data_handle(), out.size(), op, resource::get_cuda_stream(handle)); } diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 83736ba2c0..0337edce02 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -8,6 +8,7 @@ #include #include #include +#include #include namespace raft { @@ -29,6 +30,7 @@ void argmax(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); detail::argmax(in.data_handle(), diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh index c5d37e05cd..4e746b4305 100644 --- a/cpp/include/raft/matrix/argmin.cuh +++ b/cpp/include/raft/matrix/argmin.cuh @@ -8,6 +8,7 @@ #include #include #include +#include #include namespace raft { @@ -29,6 +30,7 @@ void argmin(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); detail::argmin(in.data_handle(), diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index fc0f3f1063..fed94e4511 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace raft { @@ -40,8 +41,16 @@ void sort_cols_per_row(const InType* in, cudaStream_t stream, InType* sortedKeys = nullptr) { - detail::sortColumnsPerRow( - in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); + detail::sortColumnsPerRow(false, + in, + out, + n_rows, + n_columns, + bAllocWorkspace, + workspacePtr, + workspaceSize, + stream, + sortedKeys); } /** @@ -80,12 +89,14 @@ void sort_cols_per_row(raft::resources const& handle, "Input and `sorted_keys` matrices must have the same shape."); } + bool dry_run = resource::get_dry_run_flag(handle); size_t workspace_size = 0; bool alloc_workspace = false; in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; - detail::sortColumnsPerRow(in.data_handle(), + detail::sortColumnsPerRow(dry_run, + in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), @@ -98,7 +109,10 @@ void sort_cols_per_row(raft::resources const& handle, if (alloc_workspace) { auto workspace = raft::make_device_vector(handle, workspace_size); - detail::sortColumnsPerRow(in.data_handle(), + if (dry_run) { return; } + + detail::sortColumnsPerRow(dry_run, + in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index f673835915..b5478113e8 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -36,6 +37,7 @@ void copy_rows(raft::resources const& handle, raft::device_matrix_view out, raft::device_vector_view indices) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(in.extent(1) == out.extent(1), "Input and output matrices must have same number of columns"); RAFT_EXPECTS(indices.extent(0) == out.extent(0), @@ -61,6 +63,7 @@ void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); @@ -81,6 +84,7 @@ void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); @@ -102,6 +106,7 @@ void trunc_zero_origin(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(out.extent(0) <= in.extent(0) && out.extent(1) <= in.extent(1), "Output matrix must have less or equal number of rows and columns"); diff --git a/cpp/include/raft/matrix/detail/columnWiseSort.cuh b/cpp/include/raft/matrix/detail/columnWiseSort.cuh index c8c8b9090d..2487ce8b8d 100644 --- a/cpp/include/raft/matrix/detail/columnWiseSort.cuh +++ b/cpp/include/raft/matrix/detail/columnWiseSort.cuh @@ -164,7 +164,8 @@ cudaError_t layoutSortOffset(T* in, T value, int n_times, cudaStream_t stream) * @param sortedKeys: Optional, output matrix for sorted keys (input) */ template -void sortColumnsPerRow(const InType* in, +void sortColumnsPerRow(bool dry_run, + const InType* in, OutType* out, int n_rows, int n_columns, @@ -204,6 +205,8 @@ void sortColumnsPerRow(const InType* in, // more elements per thread --> more register pressure // 512(blockSize) * 8 elements per thread = 71 register / thread + if (dry_run) { return; } + // instantiate some kernel combinations if (n_columns <= 512) INST_BLOCK_SORT(in, sortedKeys, out, n_rows, n_columns, 128, 4, stream); @@ -256,6 +259,8 @@ void sortColumnsPerRow(const InType* in, // for segment offsets (numOffsets = numSegments + 1, see above) workspaceSize += raft::alignTo(sizeof(int) * (size_t)numOffsets, memAlignWidth); } else { + if (dry_run) { return; } + size_t workspaceOffset = 0; if (!sortedKeys) { @@ -307,6 +312,8 @@ void sortColumnsPerRow(const InType* in, workspaceSize += raft::alignTo(sizeof(OutType) * (size_t)n_columns, memAlignWidth); } else { + if (dry_run) { return; } + size_t workspaceOffset = 0; bool userKeyOutputBuffer = true; diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index c1686b2f55..08b2755710 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -551,13 +552,15 @@ void gather(raft::resources const& res, device_vector_view indices, raft::device_matrix_view output) { + auto dry_run = resource::get_dry_run_flag(res); raft::common::nvtx::range fun_scope("gather"); IdxT n_dim = output.extent(1); IdxT n_train = output.extent(0); auto indices_host = raft::make_host_vector(n_train); - raft::copy( - indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); - resource::sync_stream(res); + if (!dry_run) { + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + } const size_t buffer_size = 32768 * 1024; // bytes const size_t max_batch_size = @@ -569,6 +572,10 @@ void gather(raft::resources const& res, auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + if (dry_run) { return; } + + resource::sync_stream(res); + // Usually a limited number of threads provide sufficient bandwidth for gathering data. #if defined(_OPENMP) int n_threads = std::min(omp_get_max_threads(), 32); diff --git a/cpp/include/raft/matrix/detail/gather_inplace.cuh b/cpp/include/raft/matrix/detail/gather_inplace.cuh index 1cfd7664ec..7eaf05539b 100644 --- a/cpp/include/raft/matrix/detail/gather_inplace.cuh +++ b/cpp/include/raft/matrix/detail/gather_inplace.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -39,12 +40,14 @@ void gatherInplaceImpl(raft::resources const& handle, // re-assign batch_size for default case if (batch_size == 0 || batch_size > n) batch_size = n; + auto scratch_space = raft::make_device_vector(handle, map_length * batch_size); + + if (resource::get_dry_run_flag(handle)) { return; } + auto exec_policy = resource::get_thrust_policy(handle); IndexT n_batches = raft::ceildiv(n, batch_size); - auto scratch_space = raft::make_device_vector(handle, map_length * batch_size); - for (IndexT bid = 0; bid < n_batches; bid++) { IndexT batch_offset = bid * batch_size; IndexT cols_per_batch = min(batch_size, n - batch_offset); diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 14a7846704..bd6a6a0144 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -187,10 +188,10 @@ template void ratio( raft::resources const& handle, const math_t* src, math_t* dest, IdxType len, cudaStream_t stream) { - auto d_src = src; - auto d_dest = dest; - rmm::device_scalar d_sum(stream); + if (resource::get_dry_run_flag(handle)) { return; } + auto d_src = src; + auto d_dest = dest; auto* d_sum_ptr = d_sum.data(); raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::identity_op{}, stream, src); raft::linalg::unaryOp( @@ -201,15 +202,16 @@ template ( - data, data, vec, n_col, n_row, raft::mul_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::mul_op(), stream); } template void matrixVectorBinaryMultSkipZero( Type* data, const Type* vec, IdxType n_row, IdxType n_col, cudaStream_t stream) { - raft::linalg::matrixVectorOp( + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, @@ -228,8 +230,8 @@ template ( - data, data, vec, n_col, n_row, raft::div_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::div_op(), stream); } template @@ -241,7 +243,8 @@ void matrixVectorBinaryDivSkipZero(Type* data, bool return_zero = false) { if (return_zero) { - raft::linalg::matrixVectorOp( + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, @@ -255,7 +258,8 @@ void matrixVectorBinaryDivSkipZero(Type* data, }, stream); } else { - raft::linalg::matrixVectorOp( + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, @@ -275,16 +279,16 @@ template ( - data, data, vec, n_col, n_row, raft::add_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::add_op(), stream); } template void matrixVectorBinarySub( Type* data, const Type* vec, IdxType n_row, IdxType n_col, cudaStream_t stream) { - raft::linalg::matrixVectorOp( - data, data, vec, n_col, n_row, raft::sub_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::sub_op(), stream); } // Computes an argmin/argmax column-wise in a DxN matrix diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index f3545fb103..a7d41e19f8 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -297,6 +298,7 @@ void getDiagonalInverseMatrix(m_t* in, idx_t len, cudaStream_t stream) template m_t getL2Norm(raft::resources const& handle, const m_t* in, idx_t size, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return m_t{0}; } cublasHandle_t cublasH = resource::get_cublas_handle(handle); m_t normval = 0; RAFT_EXPECTS( diff --git a/cpp/include/raft/matrix/detail/scatter_inplace.cuh b/cpp/include/raft/matrix/detail/scatter_inplace.cuh index 2c735e3fda..ecad4a0477 100644 --- a/cpp/include/raft/matrix/detail/scatter_inplace.cuh +++ b/cpp/include/raft/matrix/detail/scatter_inplace.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -64,12 +65,14 @@ void scatterInplaceImpl( // re-assign batch_size for default case if (batch_size == 0 || batch_size > n) batch_size = n; + auto scratch_space = raft::make_device_vector(handle, m * batch_size); + + if (resource::get_dry_run_flag(handle)) { return; } + auto exec_policy = resource::get_thrust_policy(handle); IndexT n_batches = raft::ceildiv(n, batch_size); - auto scratch_space = raft::make_device_vector(handle, m * batch_size); - for (IndexT bid = 0; bid < n_batches; bid++) { IndexT batch_offset = bid * batch_size; IndexT cols_per_batch = min(batch_size, n - batch_offset); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index f693f986c6..d22d0c24ce 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -127,6 +128,8 @@ void segmented_sort_by_key(raft::resources const& handle, auto d_temp_storage = raft::make_device_mdarray( handle, mr, raft::make_extents(temp_storage_bytes)); + if (resource::get_dry_run_flag(handle)) { return; } + if (asc) { // Run sorting operation cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(), diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 718096c466..ea64a0f524 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -877,7 +878,8 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) } template -void radix_topk(const T* in, +void radix_topk(bool dry_run, + const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -911,6 +913,8 @@ void radix_topk(const T* in, rmm::device_buffer bufs(max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); + if (dry_run) { return; } + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( @@ -1152,7 +1156,8 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // used. It's used when len is relatively small or when the number of blocks per row calculated by // `calc_grid_dim()` is 1. template -void radix_topk_one_block(const T* in, +void radix_topk_one_block(bool dry_run, + const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -1174,6 +1179,8 @@ void radix_topk_one_block(const T* in, rmm::device_buffer bufs(max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); + if (dry_run) { return; } + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; @@ -1270,9 +1277,11 @@ void select_k(raft::resources const& res, RAFT_EXPECTS(RowLayout::is_uniform || len_i != nullptr, "CSR layout requires a non-null indptr array (len_i)!"); - auto stream = resource::get_cuda_stream(res); - auto mr = resource::get_workspace_resource_ref(res); + bool dry_run = resource::get_dry_run_flag(res); + auto stream = resource::get_cuda_stream(res); + auto mr = resource::get_workspace_resource_ref(res); if (k == len && RowLayout::is_uniform) { + if (dry_run) { return; } RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { @@ -1292,15 +1301,27 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); + dry_run, in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); + impl::radix_topk_one_block(dry_run, + in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + len_i, + sm_cnt, + stream, + mr); } else { - impl::radix_topk(in, + impl::radix_topk(dry_run, + in, in_idx, batch_size, len, diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index b517ef8c10..830720c42d 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -1043,7 +1044,8 @@ template