Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
99dc47f
Add dry-run memory resources for allocation profiling without real me…
achirkin Feb 18, 2026
695a8a3
First batch of dry-run guards
achirkin Feb 18, 2026
42d8ad4
Dry run compliance for raft::linalg namespace
achirkin Feb 19, 2026
6db7ec8
Update developer guide with the dry run protocol
achirkin Feb 19, 2026
d91a1c6
BREAKING CHANGE: replaced pinned_container with host_container using …
achirkin Feb 19, 2026
1a114f6
Dry run compliance for raft::matrix namespace
achirkin Feb 19, 2026
dec5e95
Dry run compliance for raft::random namespace
achirkin Feb 19, 2026
f84d9a9
Dry run compliance for raft::solver namespace
achirkin Feb 19, 2026
44793cd
Dry run compliance for raft::sparse namespace
achirkin Feb 20, 2026
d566fe9
Dry run compliance for raft::spectral namespace
achirkin Feb 20, 2026
fc3bde6
Dry run compliance for raft::stats namespace
achirkin Feb 20, 2026
b0ddbc8
Add a little bit more tests
achirkin Feb 20, 2026
15c07a1
Add the Dry Run Protocol Overview
achirkin Feb 20, 2026
1c57abb
Fix C++ example in the docs
achirkin Feb 20, 2026
d916b45
Merge branch 'main' into fea-dry-run-protocol
achirkin Feb 20, 2026
9d24480
Add a few more tests and fix a missed CUDA call in QR algorithm
achirkin Feb 20, 2026
7577e56
Fix excess subsample doing work in dry run
achirkin Feb 20, 2026
99faf68
Add dry run compliance to the raft::copy on mdspans
achirkin Feb 20, 2026
b859894
Merge branch 'main' into fea-dry-run-protocol
achirkin Feb 20, 2026
57d4c19
Revert changing includes from public to detail namespace to avoid bre…
achirkin Feb 23, 2026
694ec63
Merge branch 'main' into fea-dry-run-protocol
achirkin Feb 23, 2026
a2dd18c
Merge rapidsai/main into fea-dry-run-protocol
achirkin Feb 24, 2026
45e2d49
Rename device_uvector_policy -> device_container_policy and add non-i…
achirkin Feb 26, 2026
65d4570
Declare the new resources in raft handle
achirkin Feb 26, 2026
d86638f
Renamed managed policy
achirkin Feb 26, 2026
d6788f6
Add raft::resources for pinned and managed resources and the type-era…
achirkin Feb 26, 2026
e7bea48
Updated container policies
achirkin Feb 26, 2026
2514621
All but host memory resource are done
achirkin Feb 26, 2026
49735a5
Simplify the implementation
achirkin Feb 27, 2026
22b4048
Make the host container policy use the resource concept
achirkin Feb 27, 2026
557cc8c
Settle down with raft::mr::*et_default_host_resource()
achirkin Feb 27, 2026
cc7a4b0
Add some thread-safety
achirkin Feb 27, 2026
e77fe2a
Merge branch 'main' into fea-unify-memory-resources
achirkin Feb 27, 2026
8922b8f
Merge branch 'main' into fea-dry-run-protocol
achirkin Feb 27, 2026
866211e
C++17 backwards-compatibility
achirkin Feb 28, 2026
c171d84
Merge branch 'main' into fea-unify-memory-resources
achirkin Feb 28, 2026
268eb1b
newline
achirkin Feb 28, 2026
5c718d6
Add raft::mr::device_resource wrapper for cuda::mr::any_resource
achirkin Mar 1, 2026
c5ab9c4
Copy semantics and return resource refs
achirkin Mar 2, 2026
6af142e
Rework workspace resources to avoid nesting bridge layers
achirkin Mar 2, 2026
ece1990
Fix the argument order in tests
achirkin Mar 2, 2026
3c17e3e
Merge branch 'main' into fea-dry-run-protocol
achirkin Mar 2, 2026
4dd256b
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 3, 2026
a26357d
Add explicit conversion through cuda::mr refs to rmm ref
achirkin Mar 3, 2026
2a90680
Switch from rmm host and host_device resource reference wrappers to r…
achirkin Mar 4, 2026
59c3793
Merge branch 'main' into fea-unify-memory-resources
achirkin Mar 4, 2026
3a40d22
Prefer rmm::mr::get_current_device_resource_ref() over rmm::mr::get_c…
achirkin Mar 4, 2026
cce4f45
Remove raft pinned and managed memory resources in favor of cuda::mr …
achirkin Mar 4, 2026
fb56025
Merge branch 'main' into fea-dry-run-protocol
achirkin Mar 4, 2026
ff20962
Merge fea-unify-memory-resources into fea-dry-run-protocol
achirkin Mar 5, 2026
e76bf7c
Adapt to fea-unify-memory-resources
achirkin Mar 5, 2026
2d3f8fc
Refactor dry_run_resources as a child of raft::resources to better ke…
achirkin Mar 5, 2026
d2cf85e
Merge branch 'main' into fea-dry-run-protocol
achirkin Mar 9, 2026
e86b56d
Merge branch 'main' into fea-dry-run-protocol
achirkin Mar 14, 2026
d9a0abf
Fix style after merge commit
achirkin Mar 16, 2026
16324fb
Merge branch 'main' into fea-dry-run-protocol
achirkin Mar 18, 2026
ced0e6e
Fix merge commit typo
achirkin Mar 18, 2026
c9bf618
Merge branch 'main' into fea-dry-run-protocol
achirkin Mar 19, 2026
1acf6cf
Fix some sparse routines not being dry-run compliant
achirkin Mar 23, 2026
2326061
Unify the looks of the three custom raft::resources
achirkin Mar 24, 2026
d4ff16e
Expand test coverage Part 1
achirkin Mar 24, 2026
fee4b62
Expand test coverage Part 2
achirkin Mar 24, 2026
f1b7aca
Update docs to reflect unify memory resources PR changes
achirkin Mar 24, 2026
156a437
Fix segfault in sparse tests caused by invalid thrust exec policy
achirkin Mar 25, 2026
51c0b16
Better allocation estimates in the sparse namespace
achirkin Mar 30, 2026
c99b879
Fixing more failing tests
achirkin Apr 1, 2026
e535124
Fixing last failing tests
achirkin Apr 1, 2026
9971c71
Merge branch 'main' into fea-dry-run-protocol
achirkin Apr 1, 2026
b682d46
Fix not initialize the mdarray scalars only in dry run mode
achirkin Apr 1, 2026
69543a1
Clarify that all workspace resources are actually counted independent…
achirkin Apr 2, 2026
5db5727
Rename the dry_run_resources header file for conistency
achirkin Apr 2, 2026
d1cf594
Dry-run compliance for coo_sort
achirkin Apr 2, 2026
e47c41f
Fix the expected minimum allocation calculation
achirkin Apr 2, 2026
3d93b0f
Merge branch 'main' into fea-dry-run-protocol
achirkin Apr 3, 2026
e60a048
Merge branch 'main' into fea-dry-run-protocol
achirkin Apr 8, 2026
f8754d9
Merge branch 'main' into fea-dry-run-protocol
achirkin Apr 10, 2026
0f65503
Merge branch 'main' into fea-dry-run-protocol
achirkin Apr 13, 2026
969b868
Merge remote-tracking branch 'rapidsai/main' into fea-dry-run-protocol
achirkin Apr 22, 2026
217ca58
Fix tests after rmm breaking change
achirkin Apr 22, 2026
e72e872
Store the device resources by values to safely keep them alive while …
achirkin Apr 23, 2026
1c4deb8
Switch to owning semantics for both host and per-device resources
achirkin Apr 24, 2026
8fdf194
Don't let allocations cross dry-run/normal scopes
achirkin Apr 29, 2026
1a7501c
More thorough tests for bitset/bitmap in dry run mode
achirkin Apr 29, 2026
65dda3b
Merge branch 'main' into fea-dry-run-protocol
achirkin May 6, 2026
a710e97
make bitset.count() dry-run-compliant
achirkin May 7, 2026
3a65638
Merge branch 'main' into fea-dry-run-protocol
achirkin May 7, 2026
7f1210e
Merge branch 'main' into fea-dry-run-protocol
achirkin May 11, 2026
592b8e0
Merge branch 'main' into fea-dry-run-protocol
achirkin May 15, 2026
0f5641e
Merge branch 'main' into fea-dry-run-protocol
achirkin Jun 10, 2026
205cffa
Make randomized SVD dry run compliant (adopting new features to dry run)
achirkin Jun 10, 2026
ec1a794
Merge branch 'main' into fea-dry-run-protocol
achirkin Jun 15, 2026
8d56307
Merge remote-tracking branch main into fea-dry-run-protocol
achirkin Jun 17, 2026
a9bc0af
Follow up on merge commit
achirkin Jun 17, 2026
91107c3
Merge branch 'main' into fea-dry-run-protocol
achirkin Jun 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/dry_run_flag.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/reduce.cuh>
Expand Down Expand Up @@ -166,6 +167,8 @@ void bitset_view<bitset_t, index_t>::repeat(const raft::resources& res,
index_t times,
bitset_t* output_device_ptr) const
{
// Only a copy and kernel run below this point.
if (resource::get_dry_run_flag(res)) { return; }
constexpr index_t bits_per_element = sizeof(bitset_t) * 8;

if (bitset_len_ % bits_per_element == 0) {
Expand Down
17 changes: 11 additions & 6 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <raft/core/detail/macros.hpp>
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/dry_run_flag.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>
Expand Down Expand Up @@ -133,9 +134,11 @@ struct bitset_view {
auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, 0.0);
count(res, count_gpu_scalar.view());
index_t count_cpu = 0;
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
resource::sync_stream(res);
if (!resource::get_dry_run_flag(res)) {
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
resource::sync_stream(res);
}
return count_cpu;
}

Expand Down Expand Up @@ -408,9 +411,11 @@ struct bitset {
auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, 0.0);
count(res, count_gpu_scalar.view());
index_t count_cpu = 0;
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
resource::sync_stream(res);
if (!resource::get_dry_run_flag(res)) {
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
resource::sync_stream(res);
}
return count_cpu;
}
/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/core/coo_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ class coordinate_structure : public coordinate_structure_t<RowType, ColType, NZT
void initialize_sparsity(nnz_type nnz)
{
sparse_structure_type::initialize_sparsity(nnz);
c_rows_.resize(nnz);
c_cols_.resize(nnz);
c_rows_.reallocate(nnz);
c_cols_.reallocate(nnz);
}

protected:
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/core/csr_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ class compressed_structure
void initialize_sparsity(NZType nnz) override
{
sparse_structure_type::initialize_sparsity(nnz);
c_indptr_.resize(this->get_n_rows() + 1);
c_indices_.resize(nnz);
c_indptr_.reallocate(this->get_n_rows() + 1);
c_indices_.reallocate(nnz);
}

protected:
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/core/detail/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resource/dry_run_flag.hpp>
#include <raft/core/resource/stream_view.hpp>
#include <raft/core/resources.hpp>

Expand Down Expand Up @@ -399,6 +400,10 @@ mdspan_copyable_t<DstType, SrcType> copy(resources const& res, DstType&& dst, Sr
RAFT_EXPECTS(src.extent(i) == dst.extent(i), "Must copy between mdspans of the same shape");
}

// Dry-run guard: raft::copy is a pure data-movement utility with no
// allocations that callers would need tracked.
if (resource::get_dry_run_flag(res)) { return; }

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't realize we were adding this as a new resource. This would make it hard to use the dry-run for pre-initializing resources.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but that's fine! We can still push the initialized resources back to the original resources handle on destruction of the dry run resources

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NB: with #3052 , resources initialized in dry run mode will be automatically shared back with the input resources.


if constexpr (config::use_intermediate_src) {
#ifndef RAFT_DISABLE_CUDA
// Copy to intermediate source on device, then perform necessary
Expand Down
23 changes: 23 additions & 0 deletions cpp/include/raft/core/device_container_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ class device_uvector {

void resize(size_type size) { data_.resize(size, data_.stream()); }

/**
* @brief Resize the internal buffer without copying old data.
*
* Unlike resize(), this never copies old data.
* Thus, unlike in resize(), there's no point in time where the old and the new buffers are both
* alive, and the peak memory usage is lower.
*
* Unlike resize(), this deallocates the old buffer even if the new size is smaller.
* This ensures the memory is released promptly.
*/
void reallocate(size_type size)
{
if (size != data_.size()) {
auto stream = data_.stream();
auto mr = data_.memory_resource();
// Resize and shrink rmm::device_uvector: force deallocation without copying old data
data_.resize(0, data_.stream());
data_.shrink_to_fit(data_.stream());
// Assign a new value after the old one is deallocated
data_ = rmm::device_uvector<T>(size, stream, mr);
}
}

[[nodiscard]] auto data() noexcept -> pointer { return data_.data(); }
[[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data(); }
};
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/core/device_mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/resource/dry_run_flag.hpp>
#include <raft/core/resources.hpp>

#include <rmm/resource_ref.hpp>
Expand Down Expand Up @@ -164,7 +165,7 @@ auto make_device_scalar(raft::resources const& handle, ElementType const& v)
using policy_t = typename device_scalar<ElementType, IndexType>::container_policy_type;
policy_t policy{};
auto scalar = device_scalar<ElementType, IndexType>{handle, extents, policy};
scalar(0) = v;
if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; }
return scalar;
}

Expand Down
253 changes: 253 additions & 0 deletions cpp/include/raft/core/dry_run_resources.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <raft/core/memory_stats_resources.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/dry_run_flag.hpp>
#include <raft/core/resource/managed_memory_resource.hpp>
#include <raft/core/resource/pinned_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/mr/dry_run_resource.hpp>
#include <raft/mr/host_device_resource.hpp>
#include <raft/mr/host_memory_resource.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/stream_ref>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

namespace raft {

/**
* @defgroup dry_run_memory Dry-run memory resources
* @{
*/

/**
* @brief Resources handle that wraps all reachable memory resources with
* dry-run adaptors and tracks peak allocation usage.
*
* Inherits from raft::resources, so it can be passed anywhere a
* raft::resources& is expected. On construction the handle:
* - If dry-run mode is already active, does nothing (no-op).
* - Materializes all tracked resource types (host, device, pinned,
* managed, workspace, large_workspace).
* - Takes a snapshot of the original resources to keep them alive.
* - Wraps each with dry_run_resource.
* - Replaces global host and device resources with dry-run versions.
* - Sets the dry-run flag.
*
* On destruction the handle resets the flag and restores global resources.
* Composable with memory_tracking_resources in either order.
*/
class dry_run_resources : public resources {
public:
explicit dry_run_resources(const resources& existing)
: resources(existing),
active_(!resource::get_dry_run_flag(existing)),
old_host_(raft::mr::get_default_host_resource()),
old_device_(rmm::mr::get_current_device_resource_ref())
{
if (active_) init();
Comment on lines +54 to +60

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🔴 Critical | ⚡ Quick win

CRITICAL: Restore global resources if init() throws.

Issue: init() replaces global host/device resources, then continues through operations that can throw; if construction fails, ~dry_run_resources() is never called.
Why: the global resource can be left pointing at a dry-run adaptor that is being unwound, causing later allocations to use a dangling resource.

Suggested fix
   {
-    if (active_) init();
+    if (active_) {
+      try {
+        init();
+      } catch (...) {
+        resource::set_dry_run_flag(*this, false);
+        raft::mr::set_default_host_resource(old_host_);
+        rmm::mr::set_current_device_resource(old_device_);
+        resources_.clear();
+        factories_.clear();
+        throw;
+      }
+    }
   }

As per path instructions, dry-run guards must not skip required resource setup/teardown that can affect later non-dry-run calls.

Also applies to: 159-208

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/include/raft/core/dry_run_resources.hpp` around lines 54 - 60, The
`dry_run_resources` construction path in `init()` can leave global host/device
resources pointing at the dry-run adaptor if an exception is thrown before
`~dry_run_resources()` runs. Update the `dry_run_resources(const resources&)`
constructor and/or `init()` so the old resources saved in `old_host_` and
`old_device_` are restored immediately on any failure during setup, using an
exception-safe guard/rollback path that also covers the later setup block
referenced by the same issue.

Source: Path instructions

}

~dry_run_resources() override
{
if (!active_) return;
resource::set_dry_run_flag(*this, false);
raft::mr::set_default_host_resource(old_host_);
rmm::mr::set_current_device_resource(old_device_);

// Drop all base-class entries so that probe container RAII cleanup runs
// while old_device_ and snapshot_ are still alive
resources_.clear();
factories_.clear();
}

dry_run_resources(dry_run_resources const&) = delete;
dry_run_resources& operator=(dry_run_resources const&) = delete;
dry_run_resources(dry_run_resources&&) = delete;
dry_run_resources& operator=(dry_run_resources&&) = delete;

[[nodiscard]] auto get_bytes_peak() const -> memory_stats
{
if (!active_) return {};
return {
.device_workspace = ws_stats_->get_peak_bytes(),
.device_large_workspace = lws_stats_->get_peak_bytes(),
.device_global = device_stats_->get_peak_bytes(),
.device_managed = managed_stats_->get_peak_bytes(),
.host = host_stats_->get_peak_bytes(),
.host_pinned = pinned_stats_->get_peak_bytes(),
};
}

[[nodiscard]] auto get_bytes_current() const -> memory_stats
{
if (!active_) return {};
return {
.device_workspace = ws_stats_->get_allocated_bytes(),
.device_large_workspace = lws_stats_->get_allocated_bytes(),
.device_global = device_stats_->get_allocated_bytes(),
.device_managed = managed_stats_->get_allocated_bytes(),
.host = host_stats_->get_allocated_bytes(),
.host_pinned = pinned_stats_->get_allocated_bytes(),
};
}

private:
// Declaration order determines destruction order.
// snapshot_ is destroyed last (keeps original resource shared_ptrs alive
// while dry-run adaptors hold non-owning refs into them).
// old_device_ is destroyed after device_adaptor_ so the probe can
// deallocate through it during device_adaptor_ destruction.
std::vector<pair_resource> snapshot_;

bool active_;
raft::mr::host_resource old_host_;
raft::mr::device_resource old_device_;

using host_dry_run_t = raft::mr::dry_run_resource<raft::mr::host_resource_ref>;
using device_dry_run_t = raft::mr::dry_run_resource<rmm::device_async_resource_ref>;
std::unique_ptr<host_dry_run_t> host_adaptor_;
std::unique_ptr<device_dry_run_t> device_adaptor_;

using counter_t = raft::mr::detail::dry_run_memory_counter;
std::shared_ptr<counter_t> host_stats_;
std::shared_ptr<counter_t> pinned_stats_;
std::shared_ptr<counter_t> managed_stats_;
std::shared_ptr<counter_t> ws_stats_;
std::shared_ptr<counter_t> lws_stats_;
std::shared_ptr<counter_t> device_stats_;

void init()
{
// Independent-counting invariant
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// 1. Force-initialize all lazily-created resources (workspace, large workspace,
// pinned, managed) so that their factories resolve against the *original*
// global device MR, not a tracking wrapper we install later.
// 2. Capture every upstream ref while it still points to the original resource.
// 3. Snapshot the resource map to keep the originals alive.
// 4. Only *then* replace the global device resource with the tracking bridge.
// 5. Wrap each captured upstream with a separate dry_run_resource adaptor.
//
// Because step 2 happens before step 4, workspace/lws allocations flow through
// their own adaptor directly to the original device MR, bypassing the device adaptor.
// Each allocation is therefore counted in exactly one category, and
// memory_stats::total() returns an accurate, non-overlapping sum.
auto* ws = resource::get_workspace_resource(*this);
auto ws_free = resource::get_workspace_free_bytes(*this);
auto ws_upstream = ws->get_upstream_resource();
auto lws_ref = resource::get_large_workspace_resource_ref(*this);
auto pinned_ref = resource::get_pinned_memory_resource_ref(*this);
auto managed_ref = resource::get_managed_memory_resource_ref(*this);

// Snapshot keeps original resource objects alive while dry-run
// adaptors hold non-owning refs into them.
snapshot_ = resources_;

// --- Host (global) ---
{
host_adaptor_ = std::make_unique<host_dry_run_t>(raft::mr::host_resource_ref{old_host_});
host_stats_ = host_adaptor_->get_counter();
mr::set_default_host_resource(mr::host_resource_ref{*host_adaptor_});
}

// --- Pinned ---
{
mr::dry_run_resource<mr::host_device_resource_ref> dr{pinned_ref};
pinned_stats_ = dr.get_counter();
resource::set_pinned_memory_resource(*this, std::move(dr));
}

// --- Managed ---
{
mr::dry_run_resource<mr::host_device_resource_ref> dr{managed_ref};
managed_stats_ = dr.get_counter();
resource::set_managed_memory_resource(*this, std::move(dr));
}

// --- Device (global) ---
// Invalidate the cached thrust policy (the resource_ref it captured
// will be stale once we replace the global device resource).
factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair(
resource::resource_type::LAST_KEY, std::make_shared<resource::empty_resource_factory>());
resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair(
resource::resource_type::LAST_KEY, std::make_shared<resource::empty_resource>());
{
device_dry_run_t dr{rmm::device_async_resource_ref{old_device_}};
device_stats_ = dr.get_counter();
device_adaptor_ = std::make_unique<device_dry_run_t>(std::move(dr));
rmm::mr::set_current_device_resource(*device_adaptor_);
}

// --- Workspace ---
{
mr::dry_run_resource<rmm::device_async_resource_ref> dr{ws_upstream};
ws_stats_ = dr.get_counter();
resource::set_workspace_resource(*this, std::move(dr), ws_free);
}

// --- Large workspace ---
{
mr::dry_run_resource<rmm::device_async_resource_ref> dr{lws_ref};
lws_stats_ = dr.get_counter();
resource::set_large_workspace_resource(*this, std::move(dr));
}

resource::set_dry_run_flag(*this, true);
}
};

/** @} */

} // namespace raft

namespace raft::util {

/**
* @brief Execute an action in dry-run mode and return peak memory usage.
*
* Creates an independent copy of the resources handle with all memory resources
* replaced by dry-run versions, executes the action, and returns peak usage stats.
*
* The action receives the dry-run resources handle (as const raft::resources&)
* and can check the dry-run flag via raft::resource::get_dry_run_flag(res) to
* skip kernel execution.
*
* @tparam Action A callable with signature void(const raft::resources&, Args...).
* @tparam Args Additional argument types to forward to the action.
* @param res The raft resources handle.
* @param action The action to execute in dry-run mode.
* @param args Additional arguments to forward to the action.
* @return memory_stats with peak memory usage from the dry run.
*
* @code{.cpp}
* raft::resources res;
* auto stats = raft::util::dry_run_execute(res, [](const raft::resources& r) {
* my_algorithm(r);
* });
* std::cout << "Peak workspace: " << stats.device_workspace << " bytes\n";
* @endcode
*/
template <typename Action, typename... Args>
auto dry_run_execute(const raft::resources& res, Action&& action, Args&&... args)
-> raft::memory_stats
{
raft::dry_run_resources dry_res(res);
std::forward<Action>(action)(static_cast<const raft::resources&>(dry_res),
std::forward<Args>(args)...);
return dry_res.get_bytes_peak();
}

} // namespace raft::util
Loading
Loading