From 6be7b6cc5e26dea1af76947904d2a9a77f7c0b54 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 29 Apr 2026 12:50:13 +0200 Subject: [PATCH 1/5] Lazy-initialization propagates through copies --- cpp/include/raft/core/device_resources.hpp | 9 +- cpp/include/raft/core/handle.hpp | 8 +- cpp/include/raft/core/resource/comms.hpp | 4 +- .../raft/core/resource/cublas_handle.hpp | 7 +- .../raft/core/resource/cublaslt_handle.hpp | 9 +- .../raft/core/resource/cuda_stream.hpp | 8 +- .../raft/core/resource/cuda_stream_pool.hpp | 13 +- .../raft/core/resource/cusolver_dn_handle.hpp | 7 +- .../raft/core/resource/cusolver_sp_handle.hpp | 7 +- .../raft/core/resource/cusparse_handle.hpp | 7 +- .../raft/core/resource/custom_resource.hpp | 6 +- .../resource/detail/stream_sync_event.hpp | 6 +- cpp/include/raft/core/resource/device_id.hpp | 6 +- .../core/resource/device_memory_resource.hpp | 16 +-- .../raft/core/resource/device_properties.hpp | 8 +- .../core/resource/managed_memory_resource.hpp | 6 +- cpp/include/raft/core/resource/multi_gpu.hpp | 25 ++-- cpp/include/raft/core/resource/nccl_comm.hpp | 4 +- .../core/resource/pinned_memory_resource.hpp | 6 +- .../raft/core/resource/resource_types.hpp | 40 ++---- .../raft/core/resource/stream_view.hpp | 8 +- cpp/include/raft/core/resource/sub_comms.hpp | 15 +-- .../raft/core/resource/thrust_policy.hpp | 8 +- cpp/include/raft/core/resources.hpp | 116 ++++++++++-------- .../raft/util/memory_tracking_resources.hpp | 4 +- cpp/tests/core/handle.cpp | 70 +++++++++++ 26 files changed, 223 insertions(+), 200 deletions(-) diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 753ac769d3..18dbae5f0d 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -36,10 +36,8 @@ #include #include -#include #include #include -#include #include #include @@ -59,9 +57,10 @@ class device_resources : public resources { resource::set_workspace_resource(*this, std::move(workspace_resource), allocation_limit); } - device_resources(const device_resources& handle) : resources{handle} {} - device_resources(device_resources&&) = delete; - device_resources& operator=(device_resources&&) = delete; + device_resources(const device_resources&) = default; + device_resources(device_resources&&) = default; + device_resources& operator=(const device_resources&) = default; + device_resources& operator=(device_resources&&) = default; /** * @brief Construct a resources instance with a stream view and stream pool diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index ac2b5705b6..63eb519f8f 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -26,10 +26,10 @@ class handle_t : public raft::device_resources { { } - handle_t(const handle_t& handle) : device_resources{handle} {} - - handle_t(handle_t&&) = delete; - handle_t& operator=(handle_t&&) = delete; + handle_t(const handle_t&) = default; + handle_t(handle_t&&) = default; + handle_t& operator=(const handle_t&) = default; + handle_t& operator=(handle_t&&) = default; /** * @brief Construct a resources instance with a stream view and stream pool diff --git a/cpp/include/raft/core/resource/comms.hpp b/cpp/include/raft/core/resource/comms.hpp index 58e7cffd9f..4b8b2a9859 100644 --- a/cpp/include/raft/core/resource/comms.hpp +++ b/cpp/include/raft/core/resource/comms.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -56,7 +56,7 @@ inline comms::comms_t const& get_comms(resources const& res) return *(*res.get_resource>(resource_type::COMMUNICATOR)); } -inline void set_comms(resources const& res, std::shared_ptr communicator) +inline void set_comms(resources& res, std::shared_ptr communicator) { res.add_resource_factory(std::make_shared(communicator)); } diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp index ba7443f708..e7bb43ea0c 100644 --- a/cpp/include/raft/core/resource/cublas_handle.hpp +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -57,10 +57,7 @@ class cublas_resource_factory : public resource_factory { */ inline cublasHandle_t get_cublas_handle(resources const& res) { - if (!res.has_resource_factory(resource_type::CUBLAS_HANDLE)) { - cudaStream_t stream = get_cuda_stream(res); - res.add_resource_factory(std::make_shared(stream)); - } + res.ensure_default_factory(std::make_shared(get_cuda_stream(res))); auto ret = *res.get_resource(resource_type::CUBLAS_HANDLE); RAFT_CUBLAS_TRY(cublasSetStream(ret, get_cuda_stream(res))); return ret; diff --git a/cpp/include/raft/core/resource/cublaslt_handle.hpp b/cpp/include/raft/core/resource/cublaslt_handle.hpp index 074393b18b..67c27d3076 100644 --- a/cpp/include/raft/core/resource/cublaslt_handle.hpp +++ b/cpp/include/raft/core/resource/cublaslt_handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -44,11 +44,8 @@ class cublaslt_resource_factory : public resource_factory { */ inline auto get_cublaslt_handle(resources const& res) -> cublasLtHandle_t { - if (!res.has_resource_factory(resource_type::CUBLASLT_HANDLE)) { - res.add_resource_factory(std::make_shared()); - } - auto ret = *res.get_resource(resource_type::CUBLASLT_HANDLE); - return ret; + res.ensure_default_factory(std::make_shared()); + return *res.get_resource(resource_type::CUBLASLT_HANDLE); }; /** diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp index 690bd610f9..9df44d634b 100644 --- a/cpp/include/raft/core/resource/cuda_stream.hpp +++ b/cpp/include/raft/core/resource/cuda_stream.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -57,9 +57,7 @@ class cuda_stream_resource_factory : public resource_factory { */ inline rmm::cuda_stream_view get_cuda_stream(resources const& res) { - if (!res.has_resource_factory(resource_type::CUDA_STREAM_VIEW)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *res.get_resource(resource_type::CUDA_STREAM_VIEW); }; @@ -69,7 +67,7 @@ inline rmm::cuda_stream_view get_cuda_stream(resources const& res) * @param[in] res raft resources object for managing resources * @param[in] stream_view cuda stream view */ -inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_view) +inline void set_cuda_stream(resources& res, rmm::cuda_stream_view stream_view) { res.add_resource_factory(std::make_shared(stream_view)); }; diff --git a/cpp/include/raft/core/resource/cuda_stream_pool.hpp b/cpp/include/raft/core/resource/cuda_stream_pool.hpp index 0df76aa13f..3e6173bdf7 100644 --- a/cpp/include/raft/core/resource/cuda_stream_pool.hpp +++ b/cpp/include/raft/core/resource/cuda_stream_pool.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -67,9 +67,7 @@ inline bool is_stream_pool_initialized(const resources& res) */ inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res) { - if (!res.has_resource_factory(resource_type::CUDA_STREAM_POOL)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *( *res.get_resource>(resource_type::CUDA_STREAM_POOL)); }; @@ -80,8 +78,7 @@ inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res) * @param res * @param stream_pool */ -inline void set_cuda_stream_pool(const resources& res, - std::shared_ptr stream_pool) +inline void set_cuda_stream_pool(resources& res, std::shared_ptr stream_pool) { res.add_resource_factory(std::make_shared(stream_pool)); }; @@ -163,9 +160,7 @@ inline void sync_stream_pool(const resources& res, const std::vector()); - } + res.ensure_default_factory(std::make_shared()); cudaEvent_t event = detail::get_cuda_stream_sync_event(res); RAFT_CUDA_TRY(cudaEventRecord(event, get_cuda_stream(res))); diff --git a/cpp/include/raft/core/resource/cusolver_dn_handle.hpp b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp index 66b09a108b..d981d04220 100644 --- a/cpp/include/raft/core/resource/cusolver_dn_handle.hpp +++ b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -63,10 +63,7 @@ class cusolver_dn_resource_factory : public resource_factory { */ inline cusolverDnHandle_t get_cusolver_dn_handle(resources const& res) { - if (!res.has_resource_factory(resource_type::CUSOLVER_DN_HANDLE)) { - cudaStream_t stream = get_cuda_stream(res); - res.add_resource_factory(std::make_shared(stream)); - } + res.ensure_default_factory(std::make_shared(get_cuda_stream(res))); return *res.get_resource(resource_type::CUSOLVER_DN_HANDLE); }; diff --git a/cpp/include/raft/core/resource/cusolver_sp_handle.hpp b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp index 9568274ef0..78ee8c09f6 100644 --- a/cpp/include/raft/core/resource/cusolver_sp_handle.hpp +++ b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -60,10 +60,7 @@ class cusolver_sp_resource_factory : public resource_factory { */ inline cusolverSpHandle_t get_cusolver_sp_handle(resources const& res) { - if (!res.has_resource_factory(resource_type::CUSOLVER_SP_HANDLE)) { - cudaStream_t stream = get_cuda_stream(res); - res.add_resource_factory(std::make_shared(stream)); - } + res.ensure_default_factory(std::make_shared(get_cuda_stream(res))); return *res.get_resource(resource_type::CUSOLVER_SP_HANDLE); }; diff --git a/cpp/include/raft/core/resource/cusparse_handle.hpp b/cpp/include/raft/core/resource/cusparse_handle.hpp index 364f617e13..005a87f64b 100644 --- a/cpp/include/raft/core/resource/cusparse_handle.hpp +++ b/cpp/include/raft/core/resource/cusparse_handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -55,10 +55,7 @@ class cusparse_resource_factory : public resource_factory { */ inline cusparseHandle_t get_cusparse_handle(resources const& res) { - if (!res.has_resource_factory(resource_type::CUSPARSE_HANDLE)) { - rmm::cuda_stream_view stream = get_cuda_stream(res); - res.add_resource_factory(std::make_shared(stream)); - } + res.ensure_default_factory(std::make_shared(get_cuda_stream(res))); return *res.get_resource(resource_type::CUSPARSE_HANDLE); }; diff --git a/cpp/include/raft/core/resource/custom_resource.hpp b/cpp/include/raft/core/resource/custom_resource.hpp index bde81055d4..8cbc5b2817 100644 --- a/cpp/include/raft/core/resource/custom_resource.hpp +++ b/cpp/include/raft/core/resource/custom_resource.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -69,9 +69,7 @@ template auto get_custom_resource(resources const& res) -> ResourceT* { static_assert(std::is_default_constructible_v); - if (!res.has_resource_factory(resource_type::CUSTOM)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return res.get_resource(resource_type::CUSTOM)->load(); }; diff --git a/cpp/include/raft/core/resource/detail/stream_sync_event.hpp b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp index d34e60fe72..d822b88096 100644 --- a/cpp/include/raft/core/resource/detail/stream_sync_event.hpp +++ b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -31,9 +31,7 @@ class cuda_stream_sync_event_resource_factory : public resource_factory { */ inline cudaEvent_t& get_cuda_stream_sync_event(resources const& res) { - if (!res.has_resource_factory(resource_type::CUDA_STREAM_SYNC_EVENT)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *res.get_resource(resource_type::CUDA_STREAM_SYNC_EVENT); }; diff --git a/cpp/include/raft/core/resource/device_id.hpp b/cpp/include/raft/core/resource/device_id.hpp index e710f28b89..eba09c503c 100644 --- a/cpp/include/raft/core/resource/device_id.hpp +++ b/cpp/include/raft/core/resource/device_id.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -53,9 +53,7 @@ class device_id_resource_factory : public resource_factory { */ inline int get_device_id(resources const& res) { - if (!res.has_resource_factory(resource_type::DEVICE_ID)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *res.get_resource(resource_type::DEVICE_ID); }; diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 17c929aae9..5654197227 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -166,9 +166,7 @@ namespace detail { inline auto get_workspace_adaptor(resources const& res) -> rmm::mr::limiting_resource_adaptor* { - if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return res.get_resource(resource_type::WORKSPACE_RESOURCE); } @@ -243,7 +241,7 @@ inline auto get_workspace_free_bytes(resources const& res) -> size_t * the total amount of memory in bytes available to the temporary workspace resources. * @param alignment optional alignment requirements passed to allocations */ -inline void set_workspace_resource(resources const& res, +inline void set_workspace_resource(resources& res, raft::mr::device_resource mr, std::optional allocation_limit = std::nullopt, std::optional alignment = std::nullopt) @@ -263,7 +261,7 @@ inline void set_workspace_resource(resources const& res, * */ inline void set_workspace_to_pool_resource( - resources const& res, std::optional allocation_limit = std::nullopt) + resources& res, std::optional allocation_limit = std::nullopt) { if (!allocation_limit.has_value()) { allocation_limit = get_workspace_total_bytes(res); } res.add_resource_factory(std::make_shared( @@ -284,7 +282,7 @@ inline void set_workspace_to_pool_resource( * the total amount of memory in bytes available to the temporary workspace resources. */ inline void set_workspace_to_global_resource( - resources const& res, std::optional allocation_limit = std::nullopt) + resources& res, std::optional allocation_limit = std::nullopt) { res.add_resource_factory(std::make_shared( raft::mr::device_resource{rmm::mr::get_current_device_resource_ref()}, @@ -300,9 +298,7 @@ inline void set_workspace_to_global_resource( */ inline auto get_large_workspace_resource_ref(resources const& res) -> rmm::device_async_resource_ref { - if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return rmm::device_async_resource_ref{ *res.get_resource(resource_type::LARGE_WORKSPACE_RESOURCE)}; } @@ -313,7 +309,7 @@ inline auto get_large_workspace_resource_ref(resources const& res) -> rmm::devic * @param res raft resources object for managing resources * @param mr device memory resource */ -inline void set_large_workspace_resource(resources const& res, raft::mr::device_resource mr) +inline void set_large_workspace_resource(resources& res, raft::mr::device_resource mr) { res.add_resource_factory(std::make_shared(std::move(mr))); } diff --git a/cpp/include/raft/core/resource/device_properties.hpp b/cpp/include/raft/core/resource/device_properties.hpp index 5bdab71d79..03b432583c 100644 --- a/cpp/include/raft/core/resource/device_properties.hpp +++ b/cpp/include/raft/core/resource/device_properties.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -54,10 +54,8 @@ class device_properties_resource_factory : public resource_factory { */ inline cudaDeviceProp& get_device_properties(resources const& res) { - if (!res.has_resource_factory(resource_type::DEVICE_PROPERTIES)) { - int dev_id = get_device_id(res); - res.add_resource_factory(std::make_shared(dev_id)); - } + res.ensure_default_factory( + std::make_shared(get_device_id(res))); return *res.get_resource(resource_type::DEVICE_PROPERTIES); }; diff --git a/cpp/include/raft/core/resource/managed_memory_resource.hpp b/cpp/include/raft/core/resource/managed_memory_resource.hpp index 171d0b4cd4..92b1356f26 100644 --- a/cpp/include/raft/core/resource/managed_memory_resource.hpp +++ b/cpp/include/raft/core/resource/managed_memory_resource.hpp @@ -56,9 +56,7 @@ class managed_memory_resource_factory : public resource_factory { inline auto get_managed_memory_resource_ref(resources const& res) -> raft::mr::host_device_resource_ref { - if (!res.has_resource_factory(resource_type::MANAGED_MEMORY_RESOURCE)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); auto& mr = *res.get_resource(resource_type::MANAGED_MEMORY_RESOURCE); return raft::mr::host_device_resource_ref{mr}; @@ -70,7 +68,7 @@ inline auto get_managed_memory_resource_ref(resources const& res) * @param res raft resources object for managing resources * @param mr host+device accessible memory resource */ -inline void set_managed_memory_resource(resources const& res, raft::mr::host_device_resource mr) +inline void set_managed_memory_resource(resources& res, raft::mr::host_device_resource mr) { res.add_resource_factory(std::make_shared(std::move(mr))); } diff --git a/cpp/include/raft/core/resource/multi_gpu.hpp b/cpp/include/raft/core/resource/multi_gpu.hpp index 841743e222..4560695666 100644 --- a/cpp/include/raft/core/resource/multi_gpu.hpp +++ b/cpp/include/raft/core/resource/multi_gpu.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -31,7 +31,7 @@ class multi_gpu_resource_factory : public resource_factory { class root_rank_resource : public resource { public: - root_rank_resource() : root_rank_(0) {} + explicit root_rank_resource(int root_rank = 0) : root_rank_(root_rank) {} void* get_resource() override { return &root_rank_; } ~root_rank_resource() override {} @@ -42,15 +42,17 @@ class root_rank_resource : public resource { class root_rank_resource_factory : public resource_factory { public: + explicit root_rank_resource_factory(int root_rank = 0) : root_rank_(root_rank) {} resource_type get_resource_type() override { return resource_type::ROOT_RANK; } - resource* make_resource() override { return new root_rank_resource(); } + resource* make_resource() override { return new root_rank_resource(root_rank_); } + + private: + int root_rank_; }; -inline int& get_root_rank(resources const& res) +inline int get_root_rank(resources const& res) { - if (!res.has_resource_factory(resource_type::ROOT_RANK)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *res.get_resource(resource_type::ROOT_RANK); }; @@ -63,9 +65,7 @@ inline int& get_root_rank(resources const& res) */ inline std::vector& get_multi_gpu_resource(resources const& res) { - if (!res.has_resource_factory(resource_type::MULTI_GPU)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *res.get_resource>(resource_type::MULTI_GPU); }; @@ -118,10 +118,9 @@ inline const raft::resources& set_current_device_to_root_rank(resources const& r /** * @brief Set the root rank to given rank */ -inline void set_root_rank(resources const& res, int root_rank) +inline void set_root_rank(resources& res, int root_rank) { - int& root_rank_ = get_root_rank(res); - root_rank_ = root_rank; + res.add_resource_factory(std::make_shared(root_rank)); }; } // namespace raft::resource diff --git a/cpp/include/raft/core/resource/nccl_comm.hpp b/cpp/include/raft/core/resource/nccl_comm.hpp index 885112c571..086e45d373 100644 --- a/cpp/include/raft/core/resource/nccl_comm.hpp +++ b/cpp/include/raft/core/resource/nccl_comm.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -74,7 +74,7 @@ inline void _init_nccl_comms(const resources& res) inline std::vector& get_nccl_comms(const resources& res) { if (!res.has_resource_factory(resource_type::NCCL_COMM)) { - res.add_resource_factory(std::make_shared()); + res.ensure_default_factory(std::make_shared()); _init_nccl_comms(res); } return *res.get_resource>(resource_type::NCCL_COMM); diff --git a/cpp/include/raft/core/resource/pinned_memory_resource.hpp b/cpp/include/raft/core/resource/pinned_memory_resource.hpp index fda7af292e..75f74b91cb 100644 --- a/cpp/include/raft/core/resource/pinned_memory_resource.hpp +++ b/cpp/include/raft/core/resource/pinned_memory_resource.hpp @@ -54,9 +54,7 @@ class pinned_memory_resource_factory : public resource_factory { inline auto get_pinned_memory_resource_ref(resources const& res) -> raft::mr::host_device_resource_ref { - if (!res.has_resource_factory(resource_type::PINNED_MEMORY_RESOURCE)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); auto& mr = *res.get_resource(resource_type::PINNED_MEMORY_RESOURCE); return raft::mr::host_device_resource_ref{mr}; @@ -68,7 +66,7 @@ inline auto get_pinned_memory_resource_ref(resources const& res) * @param res raft resources object for managing resources * @param mr host+device accessible memory resource */ -inline void set_pinned_memory_resource(resources const& res, raft::mr::host_device_resource mr) +inline void set_pinned_memory_resource(resources& res, raft::mr::host_device_resource mr) { res.add_resource_factory(std::make_shared(std::move(mr))); } diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index cda3c8ecae..65e199a3c7 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -5,6 +5,9 @@ #pragma once +#include +#include + namespace raft::resource { /** @@ -56,15 +59,6 @@ class resource { virtual ~resource() {} }; -class empty_resource : public resource { - public: - empty_resource() : resource() {} - - void* get_resource() override { return nullptr; } - - ~empty_resource() override {} -}; - /** * @brief A resource factory knows how to construct an instance of * a specific raft::resource::resource. @@ -87,26 +81,16 @@ class resource_factory { }; /** - * @brief A resource factory knows how to construct an instance of - * a specific raft::resource::resource. + * @brief Shared cell holding a factory and a lazily-created resource. + * + * Multiple raft::resources handles can share a cell via shared_ptr. + * Lazy initialization stores the concrete resource atomically, so all + * handles sharing the cell see the update. Explicit set replaces the + * shared_ptr in the local handle, isolating the change. */ -class empty_resource_factory : public resource_factory { - public: - empty_resource_factory() : resource_factory() {} - /** - * @brief Return the resource_type associated with the current factory - * @return resource_type corresponding to the current factory - */ - resource_type get_resource_type() override { return resource_type::LAST_KEY; } - - /** - * @brief Construct an instance of the factory's underlying resource. - * @return resource instance - */ - resource* make_resource() override { return &res; } - - private: - empty_resource res; +struct resource_cell { + std::atomic> factory{}; + std::atomic> res{}; }; /** diff --git a/cpp/include/raft/core/resource/stream_view.hpp b/cpp/include/raft/core/resource/stream_view.hpp index b2df678220..022aaacb08 100644 --- a/cpp/include/raft/core/resource/stream_view.hpp +++ b/cpp/include/raft/core/resource/stream_view.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -50,9 +50,7 @@ struct stream_view_resource_factory : public resource_factory { */ inline raft::stream_view get_stream_view(resources const& res) { - if (!res.has_resource_factory(resource_type::STREAM_VIEW)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); return *res.get_resource(resource_type::STREAM_VIEW); }; @@ -62,7 +60,7 @@ inline raft::stream_view get_stream_view(resources const& res) * @param[in] res raft resources object for managing resources * @param[in] view raft stream view */ -inline void set_stream_view(resources const& res, raft::stream_view view) +inline void set_stream_view(resources& res, raft::stream_view view) { res.add_resource_factory(std::make_shared(view)); }; diff --git a/cpp/include/raft/core/resource/sub_comms.hpp b/cpp/include/raft/core/resource/sub_comms.hpp index 9feec30105..e1bd2d5698 100644 --- a/cpp/include/raft/core/resource/sub_comms.hpp +++ b/cpp/include/raft/core/resource/sub_comms.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -38,9 +38,7 @@ class sub_comms_resource_factory : public resource_factory { inline const comms::comms_t& get_subcomm(const resources& res, std::string key) { - if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); auto sub_comms = res.get_resource>>( @@ -51,13 +49,16 @@ inline const comms::comms_t& get_subcomm(const resources& res, std::string key) return *sub_comm; } +// In-place mutation: sub-communicators are typically set once during initialization +// and should be visible to all copies of the resources handle (Goal 1 / lazy-init +// propagation semantics). +// Note: we don't replace the _resource_ here (sub_comms_resource), so `res` is passed +// by const reference. inline void set_subcomm(resources const& res, std::string key, std::shared_ptr subcomm) { - if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { - res.add_resource_factory(std::make_shared()); - } + res.ensure_default_factory(std::make_shared()); auto sub_comms = res.get_resource>>( resource_type::SUB_COMMUNICATOR); diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp index 93a68fcec3..3e1dd90e86 100644 --- a/cpp/include/raft/core/resource/thrust_policy.hpp +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -51,10 +51,8 @@ class thrust_policy_resource_factory : public resource_factory { */ inline rmm::exec_policy_nosync& get_thrust_policy(resources const& res) { - if (!res.has_resource_factory(resource_type::THRUST_POLICY)) { - rmm::cuda_stream_view stream = get_cuda_stream(res); - res.add_resource_factory(std::make_shared(stream)); - } + res.ensure_default_factory( + std::make_shared(get_cuda_stream(res))); return *res.get_resource(resource_type::THRUST_POLICY); }; diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index dbf1c701ee..10aaa34a0a 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -9,9 +9,7 @@ #include // RAFT_EXPECTS #include -#include -#include -#include +#include #include namespace raft { @@ -25,6 +23,16 @@ namespace raft { * accessor functions can then register and load resources as needed in order * to keep its usage somewhat opaque to end-users. * + * Copies of a resources handle share the underlying resource_cell objects. + * Lazy initialization (via get_resource / ensure_default_factory) stores into the + * shared cell's atomics, so all copies see the update. Explicit modification + * (via add_resource_factory) replaces the shared_ptr in the local + * vector, isolating the change from other copies. + * + * Thread safety: concurrent const operations on the same handle are safe + * (inner-cell atomics). Concurrent const + non-const on the same handle + * requires external synchronization (standard C++ rules). + * * @code{.cpp} * #include * #include @@ -37,30 +45,17 @@ namespace raft { */ class resources { public: - template - using pair_res = std::pair>; - - using pair_res_factory = pair_res; - using pair_resource = pair_res; - - resources() - : factories_(resource::resource_type::LAST_KEY), resources_(resource::resource_type::LAST_KEY) + resources() : cells_(resource::resource_type::LAST_KEY) { - for (int i = 0; i < resource::resource_type::LAST_KEY; ++i) { - factories_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, - std::make_shared()); - resources_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, - std::make_shared()); + for (auto& c : cells_) { + c = std::make_shared(); } } - /** - * @brief Shallow copy of underlying resources instance. - * Note that this does not create any new resources. - */ - resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} - resources(resources&&) = delete; - resources& operator=(resources&&) = delete; + resources(const resources&) = default; + resources(resources&&) = default; + resources& operator=(const resources&) = default; + resources& operator=(resources&&) = default; virtual ~resources() {} /** @@ -71,34 +66,50 @@ class resources { */ virtual bool has_resource_factory(resource::resource_type resource_type) const { - std::lock_guard _(mutex_); - return factories_.at(resource_type).first != resource::resource_type::LAST_KEY; + return cells_[resource_type]->factory.load() != nullptr; } /** - * @brief Register a resource_factory with the current instance. - * This will overwrite any existing resource factories. + * @brief Register a resource_factory with the current instance (explicit set). + * + * Creates a new resource_cell with the given factory. Other copies of this + * handle continue to point at the old cell, so the change does not propagate. + * * @param factory resource factory to register on the current instance */ - void add_resource_factory(std::shared_ptr factory) const + void add_resource_factory(std::shared_ptr factory) { - std::lock_guard _(mutex_); - resource::resource_type rtype = factory.get()->get_resource_type(); + resource::resource_type rtype = factory->get_resource_type(); RAFT_EXPECTS(rtype != resource::resource_type::LAST_KEY, "LAST_KEY is a placeholder and not a valid resource factory type."); - factories_.at(rtype) = std::make_pair(rtype, factory); - // Clear the corresponding resource, so that on next `get_resource` the new factory is used - if (resources_.at(rtype).first != resource::resource_type::LAST_KEY) { - resources_.at(rtype) = std::make_pair(resource::resource_type::LAST_KEY, - std::make_shared()); - } + auto new_cell = std::make_shared(); + new_cell->factory.store(std::move(factory)); + cells_[rtype] = std::move(new_cell); + } + + /** + * @brief Register a default factory if none has been set yet (lazy default). + * + * CAS's the factory into the existing shared cell. If another thread or copy + * already set a factory, this is a no-op. Because the cell is shared, all + * copies see the registered default. + * + * @param factory default resource factory + */ + void ensure_default_factory(std::shared_ptr factory) const + { + resource::resource_type rtype = factory->get_resource_type(); + std::shared_ptr expected{}; + cells_[rtype]->factory.compare_exchange_strong(expected, std::move(factory)); } /** * @brief Retrieve a resource for the given resource_type and cast to given pointer type. - * Note that the resources are loaded lazily on-demand and resources which don't yet - * exist on the current instance will be created using the corresponding factory, if - * it exists. + * + * Resources are created lazily on first access using the registered factory. + * The created resource is stored atomically in the shared cell, so all copies + * of this handle that share the same cell see the resource. + * * @tparam res_t pointer type for which retrieved resource will be casted * @param resource_type resource type to retrieve * @return the given resource, if it exists. @@ -106,24 +117,25 @@ class resources { template res_t* get_resource(resource::resource_type resource_type) const { - std::lock_guard _(mutex_); - - if (resources_.at(resource_type).first == resource::resource_type::LAST_KEY) { - RAFT_EXPECTS(factories_.at(resource_type).first != resource::resource_type::LAST_KEY, + auto& cell = cells_[resource_type]; + auto res = cell->res.load(); + if (!res) { + auto factory = cell->factory.load(); + RAFT_EXPECTS(factory != nullptr, "No resource factory has been registered for the given resource %d.", resource_type); - resource::resource_factory* factory = factories_.at(resource_type).second.get(); - resources_.at(resource_type) = std::make_pair( - resource_type, std::shared_ptr(factory->make_resource())); + auto new_res = std::shared_ptr(factory->make_resource()); + std::shared_ptr expected{}; + if (cell->res.compare_exchange_strong(expected, new_res)) { + res = new_res; + } else { + res = expected; + } } - - resource::resource* res = resources_.at(resource_type).second.get(); return reinterpret_cast(res->get_resource()); } protected: - mutable std::mutex mutex_; - mutable std::vector factories_; - mutable std::vector resources_; + std::vector> cells_; }; } // namespace raft diff --git a/cpp/include/raft/util/memory_tracking_resources.hpp b/cpp/include/raft/util/memory_tracking_resources.hpp index 306c10ce1e..f71665c6bb 100644 --- a/cpp/include/raft/util/memory_tracking_resources.hpp +++ b/cpp/include/raft/util/memory_tracking_resources.hpp @@ -136,7 +136,7 @@ class memory_tracking_resources : public resources { // snapshot_ is destroyed last (keeps original resource shared_ptrs alive). // owned_stream_ outlives report_ (report_ writes to it). // report_ is destroyed first of the three (stops background thread). - std::vector snapshot_; + std::vector> snapshot_; std::unique_ptr owned_stream_; raft::mr::resource_monitor report_; @@ -163,7 +163,7 @@ class memory_tracking_resources : public resources { auto managed_ref = raft::resource::get_managed_memory_resource_ref(*this); // Keeps original resource objects alive while tracking refs point into them. - snapshot_ = resources_; + snapshot_ = cells_; // --- Host (global) --- { diff --git a/cpp/tests/core/handle.cpp b/cpp/tests/core/handle.cpp index 928490c8ab..9480b55c4e 100644 --- a/cpp/tests/core/handle.cpp +++ b/cpp/tests/core/handle.cpp @@ -342,4 +342,74 @@ TEST(Raft, HandleAssign) assert_handles_equal(handle, copied_handle); } +TEST(Raft, LazyInitPropagates) +{ + raft::resources a; + raft::resources b(a); + + // Neither a nor b has a CUDA stream factory yet + ASSERT_FALSE(a.has_resource_factory(resource::resource_type::CUDA_STREAM_VIEW)); + ASSERT_FALSE(b.has_resource_factory(resource::resource_type::CUDA_STREAM_VIEW)); + + // Trigger lazy init on a + auto stream_a = resource::get_cuda_stream(a); + + // b should see the same resource (Goal 1: lazy init propagates) + ASSERT_TRUE(b.has_resource_factory(resource::resource_type::CUDA_STREAM_VIEW)); + auto stream_b = resource::get_cuda_stream(b); + ASSERT_EQ(stream_a, stream_b); +} + +TEST(Raft, ExplicitSetRequiresNonConst) +{ + // Goal 2: add_resource_factory is non-const. + // This is a compile-time check -- the following must NOT compile: + // const raft::resources r; + // r.add_resource_factory(std::make_shared()); + // We verify the non-const overload works: + raft::resources r; + r.add_resource_factory(std::make_shared()); + ASSERT_TRUE(r.has_resource_factory(resource::resource_type::CUDA_STREAM_VIEW)); +} + +TEST(Raft, ExplicitSetDoesNotPropagate) +{ + raft::resources a; + raft::resources b(a); + + // Trigger lazy init on both + auto stream_orig = resource::get_cuda_stream(a); + ASSERT_EQ(stream_orig, resource::get_cuda_stream(b)); + + // Explicit set on a -- creates a new cell + cudaStream_t raw_stream; + RAFT_CUDA_TRY(cudaStreamCreate(&raw_stream)); + rmm::cuda_stream_view new_stream(raw_stream); + resource::set_cuda_stream(a, new_stream); + + // a sees the new stream + ASSERT_EQ(new_stream, resource::get_cuda_stream(a)); + + // b still sees the old stream (Goal 3: explicit set does not propagate) + ASSERT_EQ(stream_orig, resource::get_cuda_stream(b)); + ASSERT_NE(resource::get_cuda_stream(a), resource::get_cuda_stream(b)); + + RAFT_CUDA_TRY(cudaStreamDestroy(raw_stream)); +} + +TEST(Raft, ResourcesDefaultMovable) +{ + raft::resources a; + resource::get_cuda_stream(a); + + // Move construct + raft::resources b(std::move(a)); + ASSERT_TRUE(b.has_resource_factory(resource::resource_type::CUDA_STREAM_VIEW)); + + // Move assign + raft::resources c; + c = std::move(b); + ASSERT_TRUE(c.has_resource_factory(resource::resource_type::CUDA_STREAM_VIEW)); +} + } // namespace raft From 21b773366843145fb1e8d7f99df84294b1c6235e Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 29 Apr 2026 14:12:05 +0200 Subject: [PATCH 2/5] Require C++20 --- cpp/CMakeLists.txt | 2 +- cpp/internal/CMakeLists.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2b8c6934f9..a548b8e88b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -191,7 +191,7 @@ target_include_directories( # Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target. target_link_libraries(raft INTERFACE rapids_logger::rapids_logger rmm::rmm CCCL::CCCL) -target_compile_features(raft INTERFACE cxx_std_17 $) +target_compile_features(raft INTERFACE cxx_std_20 $) target_compile_options( raft INTERFACE $<$:--expt-extended-lambda --expt-relaxed-constexpr> diff --git a/cpp/internal/CMakeLists.txt b/cpp/internal/CMakeLists.txt index cda3ce416e..459d59160b 100644 --- a/cpp/internal/CMakeLists.txt +++ b/cpp/internal/CMakeLists.txt @@ -1,6 +1,6 @@ # ============================================================================= # cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # cmake-format: on # ============================================================================= @@ -10,5 +10,5 @@ if(BUILD_TESTS OR BUILD_PRIMS_BENCH) target_include_directories( raft_internal INTERFACE "$" ) - target_compile_features(raft_internal INTERFACE cxx_std_17 $) + target_compile_features(raft_internal INTERFACE cxx_std_20 $) endif() From 5054e97bda1d2f429cee0c77fa95ee89960aff1d Mon Sep 17 00:00:00 2001 From: achirkin Date: Sat, 9 May 2026 10:29:50 +0200 Subject: [PATCH 3/5] Fix style --- cpp/include/raft/core/resource/resource_types.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index 356114dd60..31fb0d5906 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -10,7 +10,6 @@ #include #include - namespace RAFT_EXPORT raft { namespace resource { From f8865fd9b28f75c338d4d7f2f5312e653844422f Mon Sep 17 00:00:00 2001 From: Artem Chirkin <9253178+achirkin@users.noreply.github.com> Date: Wed, 17 Jun 2026 06:00:04 -0700 Subject: [PATCH 4/5] Follow up on merge commit --- cpp/include/raft/core/memory_stats_resources.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/memory_stats_resources.hpp b/cpp/include/raft/core/memory_stats_resources.hpp index f0ec64903d..92534eb418 100644 --- a/cpp/include/raft/core/memory_stats_resources.hpp +++ b/cpp/include/raft/core/memory_stats_resources.hpp @@ -141,7 +141,7 @@ class memory_stats_resources : public resources { }; } - std::vector snapshot_; + std::vector> snapshot_; raft::mr::host_resource old_host_; raft::mr::device_resource old_device_; @@ -182,7 +182,7 @@ class memory_stats_resources : public resources { auto pinned_ref = resource::get_pinned_memory_resource_ref(*this); auto managed_ref = resource::get_managed_memory_resource_ref(*this); - snapshot_ = resources_; + snapshot_ = cells_; // --- Host (global) --- { From 7c80d208e74fac2faf0d15d966311d0c1bcb9585 Mon Sep 17 00:00:00 2001 From: Artem Chirkin <9253178+achirkin@users.noreply.github.com> Date: Wed, 17 Jun 2026 06:30:24 -0700 Subject: [PATCH 5/5] Follow up on merge commit --- cpp/include/raft/core/memory_stats_resources.hpp | 9 ++++----- cpp/include/raft/core/memory_tracking_resources.hpp | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/core/memory_stats_resources.hpp b/cpp/include/raft/core/memory_stats_resources.hpp index 92534eb418..d449065fab 100644 --- a/cpp/include/raft/core/memory_stats_resources.hpp +++ b/cpp/include/raft/core/memory_stats_resources.hpp @@ -207,11 +207,10 @@ class memory_stats_resources : public resources { // --- Device (global) --- // Invalidate the cached thrust policy (the resource_ref it captured - // will be stale once we replace the global device resource). - factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( - resource::resource_type::LAST_KEY, std::make_shared()); - resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( - resource::resource_type::LAST_KEY, std::make_shared()); + // will be stale once we replace the global device resource). Swapping in a + // fresh cell drops the cached factory/resource locally while snapshot_ keeps + // the originals alive, so it gets lazily rebuilt against the new device MR. + cells_[resource::resource_type::THRUST_POLICY] = std::make_shared(); { device_stats_adaptor_t sa{rmm::device_async_resource_ref{old_device_}}; device_stats_ = sa.get_stats(); diff --git a/cpp/include/raft/core/memory_tracking_resources.hpp b/cpp/include/raft/core/memory_tracking_resources.hpp index 28de8e823d..0e98b436ec 100644 --- a/cpp/include/raft/core/memory_tracking_resources.hpp +++ b/cpp/include/raft/core/memory_tracking_resources.hpp @@ -209,11 +209,10 @@ class memory_tracking_resources : public resources { // --- Device (global) --- // Invalidate the cached thrust policy (the resource_ref it captured - // will be stale once we replace the global device resource). - factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( - resource::resource_type::LAST_KEY, std::make_shared()); - resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( - resource::resource_type::LAST_KEY, std::make_shared()); + // will be stale once we replace the global device resource). Swapping in a + // fresh cell drops the cached factory/resource locally while snapshot_ keeps + // the originals alive, so it gets lazily rebuilt against the new device MR. + cells_[resource::resource_type::THRUST_POLICY] = std::make_shared(); { device_stats_t sa{rmm::device_async_resource_ref{old_device_}}; report_.register_source("device", sa.get_stats());