Multi partition cagra search #2035
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -712,6 +712,34 @@ cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
| DLManagedTensor* distances, | ||
| cuvsFilter filter); | ||
|
|
||
| /** | ||
| * @brief Search multiple CAGRA index segments concurrently using a single GPU kernel launch. | ||
| * | ||
| * Launches a single kernel with grid (1, num_queries, num_segments) so each CTA handles one | ||
| * (query, segment) pair concurrently. All results land in the caller-supplied device buffers | ||
| * on the same CUDA stream, so downstream operations (e.g. selectK) see them via stream ordering | ||
| * with no explicit synchronization needed. | ||
| * | ||
| * Only float32 datasets are currently supported. Distance values are comparable across segments | ||
| * (same scale) but are not postprocessed (no kScale correction) — they are suitable for | ||
| * relative comparison (selectK / recall). | ||
| * | ||
| * @param[in] res cuvsResources_t opaque C handle | ||
| * @param[in] params search parameters | ||
| * @param[in] num_segments number of index segments | ||
| * @param[in] indices array of num_segments cuvsCagraIndex_t pointers | ||
| * @param[in] queries array of num_segments DLManagedTensor* (device, float32, [nq, dim]) | ||
| * @param[out] neighbors array of num_segments DLManagedTensor* (device, uint32, [nq, topk]) | ||
| * @param[out] distances array of num_segments DLManagedTensor* (device, float32, [nq, topk]) | ||
| */ | ||
| cuvsError_t cuvsCagraSearchMultiSegment(cuvsResources_t res, | ||
|
Contributor
Just to align on nomenclature a bit, I wonder if we can think of a more general name. Maybe "Partition"? "Segment" is pretty closely coupled to databases, and more specifically to LSM-based databases, but cuVS the library is more general than that. cuVS is at the level of "hash partitioning" or "blind sharding" (those are the terms we tend to use in this context). I think "MultiPartition" would be a more fitting name.
Contributor
Author
Aligning to "partition" for now. FYI, I also considered "MultiShard", "MultiIndex", and "Federated", but these might come with unintended connotations. |
||
| cuvsCagraSearchParams_t params, | ||
| uint32_t num_segments, | ||
| cuvsCagraIndex_t* indices, | ||
| DLManagedTensor** queries, | ||
| DLManagedTensor** neighbors, | ||
| DLManagedTensor** distances); | ||
|
|
||
| /** | ||
| * @} | ||
| */ | ||
|
|
||
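The doc comment above says the per-segment distances share a scale and are suitable for a downstream top-k selection. As a rough host-side illustration of that merge step for a single query, here is a minimal sketch. The `Candidate` struct and `merge_topk` function are hypothetical names introduced for this example; they are not part of cuVS, and the real merge would run on the device.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side record for illustration; not a cuVS type.
struct Candidate {
  std::uint32_t segment;   // which index segment the hit came from
  std::uint32_t neighbor;  // neighbor id local to that segment
  float distance;          // raw distance as written by the search
};

// Merge the per-segment result lists of ONE query into a global top-k.
// neighbors[s] and distances[s] hold the hits reported by segment s.
std::vector<Candidate> merge_topk(const std::vector<std::vector<std::uint32_t>>& neighbors,
                                  const std::vector<std::vector<float>>& distances,
                                  std::size_t k)
{
  std::vector<Candidate> all;
  for (std::size_t s = 0; s < neighbors.size(); ++s) {
    for (std::size_t j = 0; j < neighbors[s].size(); ++j) {
      all.push_back({static_cast<std::uint32_t>(s), neighbors[s][j], distances[s][j]});
    }
  }
  k = std::min(k, all.size());
  // Smallest distance first, matching the "nearest neighbor" ordering.
  std::partial_sort(all.begin(),
                    all.begin() + static_cast<std::ptrdiff_t>(k),
                    all.end(),
                    [](const Candidate& a, const Candidate& b) { return a.distance < b.distance; });
  all.resize(k);
  return all;
}
```

Keeping the segment id next to each neighbor id matters because neighbor ids are only meaningful within the segment that produced them.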
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <cuvs/core/c_api.h> | ||
| #include <dlpack/dlpack.h> | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /** | ||
| * @brief Select the k smallest values from a flat device array of n candidates. | ||
| * | ||
| * Treats `in_val` as a matrix of shape [1, n] and selects the `k` smallest | ||
| * float values. `out_idx` receives the int64 column positions of the selected | ||
| * values in [0, n), so the caller can recover per-segment identity as: | ||
| * | ||
| * segment_index = out_idx[j] / segment_k | ||
| * position_in_segment = out_idx[j] % segment_k | ||
| * | ||
| * @param[in] res cuvsResources_t handle | ||
| * @param[in] in_val DLManagedTensor* shape [1, n], float32, device memory | ||
| * @param[out] out_val DLManagedTensor* shape [1, k], float32, device memory | ||
| * @param[out] out_idx DLManagedTensor* shape [1, k], int64, device memory | ||
| * @return cuvsError_t | ||
| */ | ||
| cuvsError_t cuvsSelectK(cuvsResources_t res, | ||
|
Contributor
Oh this is great. I was just working on code examples for the new docs and realized we only have a C++ API for select_k. It'll be great to get the C APIs, and later on the Python and other language wrappers for select-k.
Contributor
Author
With the refactoring prompted by your other comment, select-k is no longer needed for this work. Leaving the C API intact in case it might be useful to others. |
||
| DLManagedTensor* in_val, | ||
| DLManagedTensor* out_val, | ||
| DLManagedTensor* out_idx); | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif | ||
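The index arithmetic described in the `cuvsSelectK` doc comment can be sketched on the host as follows. This assumes every segment contributed exactly `segment_k` candidates and that the candidates were concatenated in segment order. `SegmentHit` and `locate` are hypothetical names used only for this illustration.

```cpp
#include <cstdint>

// Hypothetical helper type for illustration; not part of the cuVS C API.
struct SegmentHit {
  std::int64_t segment_index;        // which segment the candidate came from
  std::int64_t position_in_segment;  // its rank within that segment's results
};

// Map a flat column position in [0, num_segments * segment_k) back to
// (segment, position). Assumes segment_k > 0 and flat_idx >= 0.
SegmentHit locate(std::int64_t flat_idx, std::int64_t segment_k)
{
  return {flat_idx / segment_k, flat_idx % segment_k};
}
```

For example, with `segment_k = 3`, flat index 7 maps to segment 2, position 1.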
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -689,6 +689,54 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
| }); | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsCagraSearchMultiSegment(cuvsResources_t res, | ||
| cuvsCagraSearchParams_t params, | ||
| uint32_t num_segments, | ||
| cuvsCagraIndex_t* indices, | ||
| DLManagedTensor** queries, | ||
| DLManagedTensor** neighbors, | ||
| DLManagedTensor** distances) | ||
| { | ||
| return cuvs::core::translate_exceptions([=] { | ||
| RAFT_EXPECTS(num_segments > 0, "num_segments must be > 0"); | ||
| RAFT_EXPECTS(indices != nullptr && queries != nullptr && neighbors != nullptr && | ||
| distances != nullptr, | ||
| "All pointer arrays must be non-null"); | ||
|
|
||
| auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
| auto search_params = cuvs::neighbors::cagra::search_params(); | ||
| convert_c_search_params(*params, &search_params); | ||
|
|
||
| // Only float32 is supported for multi-segment search. | ||
| RAFT_EXPECTS( | ||
| indices[0]->dtype.code == kDLFloat && indices[0]->dtype.bits == 32, | ||
| "Multi-segment search only supports float32 indices"); | ||
|
|
||
| using T = float; | ||
| using IdxT = uint32_t; | ||
| using OutIdxT = uint32_t; | ||
| using DistanceT = float; | ||
| using IndexT = cuvs::neighbors::cagra::index<T, IdxT>; | ||
|
|
||
| std::vector<const IndexT*> idx_vec(num_segments); | ||
| std::vector<raft::device_matrix_view<const T, int64_t, raft::row_major>> q_vec(num_segments); | ||
| std::vector<raft::device_matrix_view<OutIdxT, int64_t, raft::row_major>> n_vec(num_segments); | ||
| std::vector<raft::device_matrix_view<DistanceT, int64_t, raft::row_major>> d_vec(num_segments); | ||
|
|
||
| for (uint32_t i = 0; i < num_segments; i++) { | ||
| RAFT_EXPECTS(indices[i] != nullptr && indices[i]->addr != 0, | ||
| "Index at position %u is null or not built", i); | ||
| idx_vec[i] = reinterpret_cast<const IndexT*>(indices[i]->addr); | ||
| q_vec[i] = cuvs::core::from_dlpack<std::remove_reference_t<decltype(q_vec[i])>>(queries[i]); | ||
| n_vec[i] = cuvs::core::from_dlpack<std::remove_reference_t<decltype(n_vec[i])>>(neighbors[i]); | ||
| d_vec[i] = cuvs::core::from_dlpack<std::remove_reference_t<decltype(d_vec[i])>>(distances[i]); | ||
| } | ||
|
Validate every segment before the […] Only […] |
||
|
|
||
| cuvs::neighbors::cagra::search_multi_segment( | ||
| *res_ptr, search_params, idx_vec, q_vec, n_vec, d_vec); | ||
|
Reject mixed-distance metrics across segments. This API combines raw distances from all segments into one global ranking, so all indices must use the same metric. Right now nothing checks that […] |
||
| }); | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsCagraMerge(cuvsResources_t res, | ||
| cuvsCagraIndexParams_t params, | ||
| cuvsCagraIndex_t* indices, | ||
|
|
||
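The two review comments above ask for per-segment validation and a same-metric check. A minimal sketch of that shape of check follows, under the assumption that each index exposes a dtype and a metric. `SegmentMeta` and `validate_segments` are hypothetical stand-ins, not cuVS types, and the integer `metric` field is a placeholder for whatever metric enum the real index carries.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for per-index metadata; not a cuVS type.
struct SegmentMeta {
  bool built;               // false models a null or unbuilt index
  std::uint8_t dtype_code;  // DLPack type code
  std::uint8_t dtype_bits;  // element width in bits
  int metric;               // placeholder for the distance metric enum
};

constexpr std::uint8_t kFloatCode = 2;  // numeric value of kDLFloat in DLPack

// Returns an empty string when all segments are usable together,
// otherwise a message describing the first problem found.
std::string validate_segments(const std::vector<SegmentMeta>& segs)
{
  if (segs.empty()) { return "num_segments must be > 0"; }
  for (std::size_t i = 0; i < segs.size(); ++i) {
    // Check the handle first, so nothing below reads an unbuilt index.
    if (!segs[i].built) { return "index " + std::to_string(i) + " is null or not built"; }
    if (segs[i].dtype_code != kFloatCode || segs[i].dtype_bits != 32) {
      return "index " + std::to_string(i) + " is not float32";
    }
    if (segs[i].metric != segs[0].metric) {
      return "index " + std::to_string(i) + " uses a different metric than index 0";
    }
  }
  return "";
}
```

Running every check inside one loop, before any pointer is reinterpreted or passed on, keeps the failure mode a reported error rather than undefined behavior.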
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #include <cuvs/core/c_api.h> | ||
| #include "../core/exceptions.hpp" | ||
| #include <cuvs/selection/select_k.hpp> | ||
| #include <dlpack/dlpack.h> | ||
|
|
||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/resources.hpp> | ||
|
|
||
| extern "C" cuvsError_t cuvsSelectK(cuvsResources_t res, | ||
| DLManagedTensor* in_val, | ||
| DLManagedTensor* out_val, | ||
| DLManagedTensor* out_idx) | ||
| { | ||
| return cuvs::core::translate_exceptions([=] { | ||
| auto* res_ptr = reinterpret_cast<raft::resources*>(res); | ||
|
|
||
| int64_t n = in_val->dl_tensor.shape[1]; | ||
| int64_t k = out_val->dl_tensor.shape[1]; | ||
|
|
||
| auto in_view = raft::make_device_matrix_view<const float, int64_t, raft::row_major>( | ||
| static_cast<const float*>(in_val->dl_tensor.data), 1, n); | ||
|
|
||
| auto out_val_view = raft::make_device_matrix_view<float, int64_t, raft::row_major>( | ||
| static_cast<float*>(out_val->dl_tensor.data), 1, k); | ||
|
|
||
| auto out_idx_view = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>( | ||
| static_cast<int64_t*>(out_idx->dl_tensor.data), 1, k); | ||
|
|
||
| cuvs::selection::select_k( | ||
| *res_ptr, | ||
| in_view, | ||
| std::nullopt, // implicit positions [0, n) as in_idx | ||
| out_val_view, | ||
| out_idx_view, | ||
| true); // select_min = true (smallest distance = nearest neighbor) | ||
|
Comment on lines +14 to +40
Validate the DLPack contract before dereferencing |
||
| }); | ||
| } | ||
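As a rough illustration of the kind of contract check the comment above asks for, here is a sketch against a mock tensor type. `MockTensor` and `MockDType` only mirror the DLTensor fields used here (`data`, `ndim`, `dtype`, `shape`); the real definitions live in `<dlpack/dlpack.h>`. Device-type and stride checks are deliberately left out of this sketch, and `select_k_args_ok` is a hypothetical helper name.

```cpp
#include <cstdint>

// Mock types mirroring a subset of DLPack's DLDataType / DLTensor fields.
struct MockDType {
  std::uint8_t code;
  std::uint8_t bits;
  std::uint16_t lanes;
};
struct MockTensor {
  void* data;
  int ndim;
  MockDType dtype;
  std::int64_t* shape;
};

constexpr std::uint8_t kInt   = 0;  // numeric value of kDLInt in DLPack
constexpr std::uint8_t kFloat = 2;  // numeric value of kDLFloat in DLPack

// True when t is a non-null 2-D tensor of the expected scalar type.
bool is_matrix(const MockTensor* t, std::uint8_t code, std::uint8_t bits)
{
  return t != nullptr && t->data != nullptr && t->shape != nullptr && t->ndim == 2 &&
         t->dtype.code == code && t->dtype.bits == bits && t->dtype.lanes == 1;
}

// Checks the shapes and dtypes documented for cuvsSelectK before any
// shape[1] or data access would happen.
bool select_k_args_ok(const MockTensor* in_val, const MockTensor* out_val, const MockTensor* out_idx)
{
  if (!is_matrix(in_val, kFloat, 32) || !is_matrix(out_val, kFloat, 32) ||
      !is_matrix(out_idx, kInt, 64)) {
    return false;
  }
  const std::int64_t n = in_val->shape[1];
  const std::int64_t k = out_val->shape[1];
  return in_val->shape[0] == 1 && out_val->shape[0] == 1 && out_idx->shape[0] == 1 &&
         out_idx->shape[1] == k && k > 0 && k <= n;
}
```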
The async-memory resource owner cannot be `thread_local` when this API changes the current resource globally.

The implementation at `c/src/core/c_api.cpp:188` stores `cuda_async_memory_resource` in `thread_local async_mr` and passes it to `rmm::mr::set_current_device_resource()`, but the documentation explicitly states this function "will change the memory resource for the whole process" (line 235). This creates a critical lifetime mismatch: `set_current_device_resource()` is device-scoped (affecting all threads), so when the enabling thread exits, its `thread_local async_mr` is destroyed while still registered as the current resource, leaving RMM with a dangling pointer.

The pool resource avoids this issue by passing temporary rvalues to `set_current_device_resource()`, allowing RMM to manage the lifetime. Either make `async_mr` process/device-scoped (not `thread_local`), or narrow the documentation and implementation to clarify thread-local semantics.
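One way to picture the first suggested option (a process-scoped owner) is the sketch below. It uses mock types only: `MockResource`, `g_current`, `process_scoped_resource`, and `enable_from_any_thread` are hypothetical names that model the lifetime relationship and are not RMM or cuVS APIs.

```cpp
#include <atomic>
#include <thread>

// Hypothetical stand-in for a memory resource; not an RMM type.
struct MockResource {
  int id;
};

// Models a process-wide "current resource" pointer shared by all threads.
std::atomic<MockResource*> g_current{nullptr};

// Function-local static: constructed once on first use and destroyed at
// process exit, so it outlives any thread that registers it.
MockResource& process_scoped_resource()
{
  static MockResource mr{42};
  return mr;
}

// Safe to call from a short-lived thread: the registered pointer does not
// refer to thread_local storage, so it stays valid after the thread exits.
void enable_from_any_thread() { g_current.store(&process_scoped_resource()); }
```

Had the owner been `thread_local`, the pointer stored in `g_current` would refer to an object destroyed when the enabling thread exited, which is the dangling-pointer scenario the comment describes.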