Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ target_include_directories(
# Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target.
target_link_libraries(raft INTERFACE rapids_logger::rapids_logger rmm::rmm CCCL::CCCL)

target_compile_features(raft INTERFACE cxx_std_17 $<BUILD_INTERFACE:cuda_std_17>)
target_compile_features(raft INTERFACE cxx_std_20 $<BUILD_INTERFACE:cuda_std_20>)
target_compile_options(
raft INTERFACE $<$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>:--expt-extended-lambda
--expt-relaxed-constexpr>
Expand Down
9 changes: 4 additions & 5 deletions cpp/include/raft/core/device_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@
#include <cusparse.h>

#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

Expand All @@ -60,9 +58,10 @@ class device_resources : public resources {
resource::set_workspace_resource(*this, std::move(workspace_resource), allocation_limit);
}

device_resources(const device_resources& handle) : resources{handle} {}
device_resources(device_resources&&) = delete;
device_resources& operator=(device_resources&&) = delete;
device_resources(const device_resources&) = default;
device_resources(device_resources&&) = default;
device_resources& operator=(const device_resources&) = default;
device_resources& operator=(device_resources&&) = default;

/**
* @brief Construct a resources instance with a stream view and stream pool
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/core/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class handle_t : public raft::device_resources {
{
}

handle_t(const handle_t& handle) : device_resources{handle} {}

handle_t(handle_t&&) = delete;
handle_t& operator=(handle_t&&) = delete;
handle_t(const handle_t&) = default;
handle_t(handle_t&&) = default;
handle_t& operator=(const handle_t&) = default;
handle_t& operator=(handle_t&&) = default;

/**
* @brief Construct a resources instance with a stream view and stream pool
Expand Down
13 changes: 6 additions & 7 deletions cpp/include/raft/core/memory_stats_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class memory_stats_resources : public resources {
};
}

std::vector<pair_resource> snapshot_;
std::vector<std::shared_ptr<resource::resource_cell>> snapshot_;

raft::mr::host_resource old_host_;
raft::mr::device_resource old_device_;
Expand Down Expand Up @@ -182,7 +182,7 @@ class memory_stats_resources : public resources {
auto pinned_ref = resource::get_pinned_memory_resource_ref(*this);
auto managed_ref = resource::get_managed_memory_resource_ref(*this);

snapshot_ = resources_;
snapshot_ = cells_;

// --- Host (global) ---
{
Expand All @@ -207,11 +207,10 @@ class memory_stats_resources : public resources {

// --- Device (global) ---
// Invalidate the cached thrust policy (the resource_ref it captured
// will be stale once we replace the global device resource).
factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair(
resource::resource_type::LAST_KEY, std::make_shared<resource::empty_resource_factory>());
resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair(
resource::resource_type::LAST_KEY, std::make_shared<resource::empty_resource>());
// will be stale once we replace the global device resource). Swapping in a
// fresh cell drops the cached factory/resource locally while snapshot_ keeps
// the originals alive, so it gets lazily rebuilt against the new device MR.
cells_[resource::resource_type::THRUST_POLICY] = std::make_shared<resource::resource_cell>();
{
device_stats_adaptor_t sa{rmm::device_async_resource_ref{old_device_}};
device_stats_ = sa.get_stats();
Expand Down
13 changes: 6 additions & 7 deletions cpp/include/raft/core/memory_tracking_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class memory_tracking_resources : public resources {
// snapshot_ is destroyed last (keeps original resource shared_ptrs alive).
// owned_stream_ outlives report_ (report_ writes to it).
// report_ is destroyed first of the three (stops background thread).
std::vector<pair_resource> snapshot_;
std::vector<std::shared_ptr<resource::resource_cell>> snapshot_;
std::unique_ptr<std::ofstream> owned_stream_;
raft::mr::resource_monitor report_;

Expand Down Expand Up @@ -177,7 +177,7 @@ class memory_tracking_resources : public resources {
auto managed_ref = raft::resource::get_managed_memory_resource_ref(*this);

// Keeps original resource objects alive while tracking refs point into them.
snapshot_ = resources_;
snapshot_ = cells_;

// --- Host (global) ---
{
Expand Down Expand Up @@ -209,11 +209,10 @@ class memory_tracking_resources : public resources {

// --- Device (global) ---
// Invalidate the cached thrust policy (the resource_ref it captured
// will be stale once we replace the global device resource).
factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair(
resource::resource_type::LAST_KEY, std::make_shared<resource::empty_resource_factory>());
resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair(
resource::resource_type::LAST_KEY, std::make_shared<resource::empty_resource>());
// will be stale once we replace the global device resource). Swapping in a
// fresh cell drops the cached factory/resource locally while snapshot_ keeps
// the originals alive, so it gets lazily rebuilt against the new device MR.
cells_[resource::resource_type::THRUST_POLICY] = std::make_shared<resource::resource_cell>();
{
device_stats_t sa{rmm::device_async_resource_ref{old_device_}};
report_.register_source("device", sa.get_stats());
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/resource/comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ inline comms::comms_t const& get_comms(resources const& res)
return *(*res.get_resource<std::shared_ptr<comms::comms_t>>(resource_type::COMMUNICATOR));
}

inline void set_comms(resources const& res, std::shared_ptr<comms::comms_t> communicator)
inline void set_comms(resources& res, std::shared_ptr<comms::comms_t> communicator)
Comment thread
achirkin marked this conversation as resolved.
{
res.add_resource_factory(std::make_shared<comms_resource_factory>(communicator));
}
Expand Down
5 changes: 1 addition & 4 deletions cpp/include/raft/core/resource/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ class cublas_resource_factory : public resource_factory {
*/
inline cublasHandle_t get_cublas_handle(resources const& res)
{
if (!res.has_resource_factory(resource_type::CUBLAS_HANDLE)) {
cudaStream_t stream = get_cuda_stream(res);
res.add_resource_factory(std::make_shared<cublas_resource_factory>(stream));
}
res.ensure_default_factory(std::make_shared<cublas_resource_factory>(get_cuda_stream(res)));
auto ret = *res.get_resource<cublasHandle_t>(resource_type::CUBLAS_HANDLE);
RAFT_CUBLAS_TRY(cublasSetStream(ret, get_cuda_stream(res)));
Comment on lines +62 to 64

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | 🏗️ Heavy lift

HIGH: Shared cuBLAS handle is rebound across divergent stream copies.

After a resources copy changes its CUDA_STREAM_VIEW, both copies can still share the same CUBLAS_HANDLE cell, and Line 64 mutates that shared handle’s stream on every access. Interleaved use of the two handles can therefore enqueue BLAS work on the wrong stream or introduce false ordering between otherwise independent copies.

As per path instructions, "different instances of raft::resources are safe" and "ensure work enqueued on raft::resources’ internal streams is correctly ordered w.r.t. get_cuda_stream()".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/include/raft/core/resource/cublas_handle.hpp` around lines 62 - 64, The
cublas handle access path in cublas_handle.hpp is mutating a shared
CUBLAS_HANDLE cell by rebinding its stream on every call, which breaks safety
when copied resources diverge via CUDA_STREAM_VIEW. Update the cublas handle
setup/access logic around cublas_resource_factory and get_cublas_handle so each
raft::resources instance keeps stream ordering isolated, rather than reusing and
retargeting the same handle across copies. Ensure the handle’s stream is bound
per-resource or per-stream-view in a way that does not affect sibling copies and
preserves correct ordering with get_cuda_stream().

Source: Path instructions

return ret;
Expand Down
7 changes: 2 additions & 5 deletions cpp/include/raft/core/resource/cublaslt_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,8 @@ class cublaslt_resource_factory : public resource_factory {
*/
inline auto get_cublaslt_handle(resources const& res) -> cublasLtHandle_t
{
if (!res.has_resource_factory(resource_type::CUBLASLT_HANDLE)) {
res.add_resource_factory(std::make_shared<cublaslt_resource_factory>());
}
auto ret = *res.get_resource<cublasLtHandle_t>(resource_type::CUBLASLT_HANDLE);
return ret;
res.ensure_default_factory(std::make_shared<cublaslt_resource_factory>());
return *res.get_resource<cublasLtHandle_t>(resource_type::CUBLASLT_HANDLE);
};

/**
Expand Down
6 changes: 2 additions & 4 deletions cpp/include/raft/core/resource/cuda_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ class cuda_stream_resource_factory : public resource_factory {
*/
inline rmm::cuda_stream_view get_cuda_stream(resources const& res)
{
if (!res.has_resource_factory(resource_type::CUDA_STREAM_VIEW)) {
res.add_resource_factory(std::make_shared<cuda_stream_resource_factory>());
}
res.ensure_default_factory(std::make_shared<cuda_stream_resource_factory>());
return *res.get_resource<rmm::cuda_stream_view>(resource_type::CUDA_STREAM_VIEW);
};

Expand All @@ -71,7 +69,7 @@ inline rmm::cuda_stream_view get_cuda_stream(resources const& res)
* @param[in] res raft resources object for managing resources
* @param[in] stream_view cuda stream view
*/
inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_view)
inline void set_cuda_stream(resources& res, rmm::cuda_stream_view stream_view)
{
res.add_resource_factory(std::make_shared<cuda_stream_resource_factory>(stream_view));
};
Expand Down
11 changes: 3 additions & 8 deletions cpp/include/raft/core/resource/cuda_stream_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ inline bool is_stream_pool_initialized(const resources& res)
*/
inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res)
{
if (!res.has_resource_factory(resource_type::CUDA_STREAM_POOL)) {
res.add_resource_factory(std::make_shared<cuda_stream_pool_resource_factory>());
}
res.ensure_default_factory(std::make_shared<cuda_stream_pool_resource_factory>());
return *(
*res.get_resource<std::shared_ptr<rmm::cuda_stream_pool>>(resource_type::CUDA_STREAM_POOL));
};
Expand All @@ -82,8 +80,7 @@ inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res)
* @param res
* @param stream_pool
*/
inline void set_cuda_stream_pool(const resources& res,
std::shared_ptr<rmm::cuda_stream_pool> stream_pool)
inline void set_cuda_stream_pool(resources& res, std::shared_ptr<rmm::cuda_stream_pool> stream_pool)
{
res.add_resource_factory(std::make_shared<cuda_stream_pool_resource_factory>(stream_pool));
};
Expand Down Expand Up @@ -165,9 +162,7 @@ inline void sync_stream_pool(const resources& res, const std::vector<std::size_t
*/
inline void wait_stream_pool_on_stream(const resources& res)
{
if (!res.has_resource_factory(resource_type::CUDA_STREAM_POOL)) {
res.add_resource_factory(std::make_shared<cuda_stream_pool_resource_factory>());
}
res.ensure_default_factory(std::make_shared<cuda_stream_pool_resource_factory>());

cudaEvent_t event = detail::get_cuda_stream_sync_event(res);
RAFT_CUDA_TRY(cudaEventRecord(event, get_cuda_stream(res)));
Expand Down
5 changes: 1 addition & 4 deletions cpp/include/raft/core/resource/cusolver_dn_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ class cusolver_dn_resource_factory : public resource_factory {
*/
inline cusolverDnHandle_t get_cusolver_dn_handle(resources const& res)
{
if (!res.has_resource_factory(resource_type::CUSOLVER_DN_HANDLE)) {
cudaStream_t stream = get_cuda_stream(res);
res.add_resource_factory(std::make_shared<cusolver_dn_resource_factory>(stream));
}
res.ensure_default_factory(std::make_shared<cusolver_dn_resource_factory>(get_cuda_stream(res)));
return *res.get_resource<cusolverDnHandle_t>(resource_type::CUSOLVER_DN_HANDLE);
Comment on lines +68 to 69

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -n -C3 'get_cusolver|get_cusparse|SetStream|get_cuda_stream\(res\)' cpp/include/raft/core/resource

Repository: rapidsai/raft

Length of output: 9241


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect resource ownership/copy semantics and the sibling handle getters.
ast-grep outline cpp/include/raft/core/resource/resource.hpp --view expanded || true
ast-grep outline cpp/include/raft/core/resource/cublas_handle.hpp --view expanded || true
ast-grep outline cpp/include/raft/core/resource/cusparse_handle.hpp --view expanded || true
ast-grep outline cpp/include/raft/core/resource/cusolver_sp_handle.hpp --view expanded || true
ast-grep outline cpp/include/raft/core/resource/cusolver_dn_handle.hpp --view expanded || true

echo '--- resource/copy semantics search ---'
rg -n -C3 'shared_ptr|copy|copy-construct|clone|resource_cell|ensure_default_factory|get_resource<|set_cuda_stream|CUDA_STREAM_VIEW' cpp/include/raft/core/resource cpp/include/raft/core | sed -n '1,220p'

Repository: rapidsai/raft

Length of output: 18111


🏁 Script executed:

#!/bin/bash
set -euo pipefail

ast-grep outline cpp/include/raft/core/resource/cuda_stream.hpp --view expanded || true
sed -n '1,220p' cpp/include/raft/core/resource/cuda_stream.hpp

Repository: rapidsai/raft

Length of output: 2874


HIGH: Rebind stream-bound solver handles on access

set_cuda_stream() can replace the stream resource after this handle has already been cached, so get_cusolver_dn_handle() may keep enqueuing work on the old stream. get_cusparse_handle() and get_cusolver_sp_handle() have the same pattern; cusolverDnSetStream(...)/cusparseSetStream(...)/cusolverSpSetStream(...) on each access would keep the cached handle aligned with the current raft::resources stream.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/include/raft/core/resource/cusolver_dn_handle.hpp` around lines 68 - 69,
get_cusolver_dn_handle currently returns a cached cusolverDnHandle_t without
reattaching it to the current stream, so after set_cuda_stream() the handle may
still target an old stream. Update get_cusolver_dn_handle to call
cusolverDnSetStream(...) on the retrieved handle using get_cuda_stream(res)
before returning it, and apply the same rebinding pattern to get_cusparse_handle
and get_cusolver_sp_handle so all stream-bound solver handles stay aligned with
the active raft::resources stream.

Source: Path instructions

};

Expand Down
5 changes: 1 addition & 4 deletions cpp/include/raft/core/resource/cusolver_sp_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ class cusolver_sp_resource_factory : public resource_factory {
*/
inline cusolverSpHandle_t get_cusolver_sp_handle(resources const& res)
{
if (!res.has_resource_factory(resource_type::CUSOLVER_SP_HANDLE)) {
cudaStream_t stream = get_cuda_stream(res);
res.add_resource_factory(std::make_shared<cusolver_sp_resource_factory>(stream));
}
res.ensure_default_factory(std::make_shared<cusolver_sp_resource_factory>(get_cuda_stream(res)));
return *res.get_resource<cusolverSpHandle_t>(resource_type::CUSOLVER_SP_HANDLE);
};

Expand Down
5 changes: 1 addition & 4 deletions cpp/include/raft/core/resource/cusparse_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ class cusparse_resource_factory : public resource_factory {
*/
inline cusparseHandle_t get_cusparse_handle(resources const& res)
{
if (!res.has_resource_factory(resource_type::CUSPARSE_HANDLE)) {
rmm::cuda_stream_view stream = get_cuda_stream(res);
res.add_resource_factory(std::make_shared<cusparse_resource_factory>(stream));
}
res.ensure_default_factory(std::make_shared<cusparse_resource_factory>(get_cuda_stream(res)));
return *res.get_resource<cusparseHandle_t>(resource_type::CUSPARSE_HANDLE);
};

Expand Down
4 changes: 1 addition & 3 deletions cpp/include/raft/core/resource/custom_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ template <typename ResourceT>
auto get_custom_resource(resources const& res) -> ResourceT*
{
static_assert(std::is_default_constructible_v<ResourceT>);
if (!res.has_resource_factory(resource_type::CUSTOM)) {
res.add_resource_factory(std::make_shared<custom_resource_factory>());
}
res.ensure_default_factory(std::make_shared<custom_resource_factory>());
return res.get_resource<custom_resource>(resource_type::CUSTOM)->load<ResourceT>();
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ class cuda_stream_sync_event_resource_factory : public resource_factory {
*/
inline cudaEvent_t& get_cuda_stream_sync_event(resources const& res)
{
if (!res.has_resource_factory(resource_type::CUDA_STREAM_SYNC_EVENT)) {
res.add_resource_factory(std::make_shared<cuda_stream_sync_event_resource_factory>());
}
res.ensure_default_factory(std::make_shared<cuda_stream_sync_event_resource_factory>());
return *res.get_resource<cudaEvent_t>(resource_type::CUDA_STREAM_SYNC_EVENT);
};

Expand Down
4 changes: 1 addition & 3 deletions cpp/include/raft/core/resource/device_id.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ class device_id_resource_factory : public resource_factory {
*/
inline int get_device_id(resources const& res)
{
if (!res.has_resource_factory(resource_type::DEVICE_ID)) {
res.add_resource_factory(std::make_shared<device_id_resource_factory>());
}
res.ensure_default_factory(std::make_shared<device_id_resource_factory>());
return *res.get_resource<int>(resource_type::DEVICE_ID);
};

Expand Down
16 changes: 6 additions & 10 deletions cpp/include/raft/core/resource/device_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,7 @@ namespace detail {

inline auto get_workspace_adaptor(resources const& res) -> rmm::mr::limiting_resource_adaptor*
{
if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) {
res.add_resource_factory(std::make_shared<workspace_resource_factory>());
}
res.ensure_default_factory(std::make_shared<workspace_resource_factory>());
return res.get_resource<rmm::mr::limiting_resource_adaptor>(resource_type::WORKSPACE_RESOURCE);
}

Expand Down Expand Up @@ -254,7 +252,7 @@ inline auto get_workspace_free_bytes(resources const& res) -> size_t
* the total amount of memory in bytes available to the temporary workspace resources.
* @param alignment optional alignment requirements passed to allocations
*/
inline void set_workspace_resource(resources const& res,
inline void set_workspace_resource(resources& res,
raft::mr::device_resource mr,
std::optional<std::size_t> allocation_limit = std::nullopt,
std::optional<std::size_t> alignment = std::nullopt)
Expand All @@ -274,7 +272,7 @@ inline void set_workspace_resource(resources const& res,
*
*/
inline void set_workspace_to_pool_resource(
resources const& res, std::optional<std::size_t> allocation_limit = std::nullopt)
resources& res, std::optional<std::size_t> allocation_limit = std::nullopt)
{
if (!allocation_limit.has_value()) { allocation_limit = get_workspace_total_bytes(res); }
res.add_resource_factory(std::make_shared<workspace_resource_factory>(
Expand All @@ -295,7 +293,7 @@ inline void set_workspace_to_pool_resource(
* the total amount of memory in bytes available to the temporary workspace resources.
*/
inline void set_workspace_to_global_resource(
resources const& res, std::optional<std::size_t> allocation_limit = std::nullopt)
resources& res, std::optional<std::size_t> allocation_limit = std::nullopt)
{
res.add_resource_factory(std::make_shared<workspace_resource_factory>(
raft::mr::device_resource{rmm::mr::get_current_device_resource_ref()},
Expand All @@ -311,9 +309,7 @@ inline void set_workspace_to_global_resource(
*/
inline auto get_large_workspace_resource_ref(resources const& res) -> rmm::device_async_resource_ref
{
if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) {
res.add_resource_factory(std::make_shared<large_workspace_resource_factory>());
}
res.ensure_default_factory(std::make_shared<large_workspace_resource_factory>());
return rmm::device_async_resource_ref{
*res.get_resource<raft::mr::device_resource>(resource_type::LARGE_WORKSPACE_RESOURCE)};
}
Expand All @@ -324,7 +320,7 @@ inline auto get_large_workspace_resource_ref(resources const& res) -> rmm::devic
* @param res raft resources object for managing resources
* @param mr device memory resource
*/
inline void set_large_workspace_resource(resources const& res, raft::mr::device_resource mr)
inline void set_large_workspace_resource(resources& res, raft::mr::device_resource mr)
{
res.add_resource_factory(std::make_shared<large_workspace_resource_factory>(std::move(mr)));
}
Expand Down
6 changes: 2 additions & 4 deletions cpp/include/raft/core/resource/device_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ class device_properties_resource_factory : public resource_factory {
*/
inline cudaDeviceProp& get_device_properties(resources const& res)
{
if (!res.has_resource_factory(resource_type::DEVICE_PROPERTIES)) {
int dev_id = get_device_id(res);
res.add_resource_factory(std::make_shared<device_properties_resource_factory>(dev_id));
}
res.ensure_default_factory(
std::make_shared<device_properties_resource_factory>(get_device_id(res)));
return *res.get_resource<cudaDeviceProp>(resource_type::DEVICE_PROPERTIES);
};

Expand Down
6 changes: 2 additions & 4 deletions cpp/include/raft/core/resource/managed_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ class managed_memory_resource_factory : public resource_factory {
inline auto get_managed_memory_resource_ref(resources const& res)
-> raft::mr::host_device_resource_ref
{
if (!res.has_resource_factory(resource_type::MANAGED_MEMORY_RESOURCE)) {
res.add_resource_factory(std::make_shared<managed_memory_resource_factory>());
}
res.ensure_default_factory(std::make_shared<managed_memory_resource_factory>());
auto& mr =
*res.get_resource<raft::mr::host_device_resource>(resource_type::MANAGED_MEMORY_RESOURCE);
return raft::mr::host_device_resource_ref{mr};
Expand All @@ -72,7 +70,7 @@ inline auto get_managed_memory_resource_ref(resources const& res)
* @param res raft resources object for managing resources
* @param mr host+device accessible memory resource
*/
inline void set_managed_memory_resource(resources const& res, raft::mr::host_device_resource mr)
inline void set_managed_memory_resource(resources& res, raft::mr::host_device_resource mr)
{
res.add_resource_factory(std::make_shared<managed_memory_resource_factory>(std::move(mr)));
}
Expand Down
Loading
Loading