hipvs/cagra/
index.rs

/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use std::io::{stderr, Write};
18
19use crate::cagra::{IndexParams, SearchParams};
20use crate::dlpack::ManagedTensor;
21use crate::error::{check_cuvs, Result};
22use crate::resources::Resources;
23
24/// CAGRA ANN Index
25#[derive(Debug)]
26pub struct Index(ffi::cuvsCagraIndex_t);
27
28impl Index {
29    /// Builds a new Index from the dataset for efficient search.
30    ///
31    /// # Arguments
32    ///
33    /// * `res` - Resources to use
34    /// * `params` - Parameters for building the index
35    /// * `dataset` - A row-major matrix on either the host or device to index
36    pub fn build<T: Into<ManagedTensor>>(
37        res: &Resources,
38        params: &IndexParams,
39        dataset: T,
40    ) -> Result<Index> {
41        let dataset: ManagedTensor = dataset.into();
42        let index = Index::new()?;
43        unsafe {
44            check_cuvs(ffi::cuvsCagraBuild(
45                res.0,
46                params.0,
47                dataset.as_ptr(),
48                index.0,
49            ))?;
50        }
51        Ok(index)
52    }
53
54    /// Creates a new empty index
55    pub fn new() -> Result<Index> {
56        unsafe {
57            let mut index = std::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
58            check_cuvs(ffi::cuvsCagraIndexCreate(index.as_mut_ptr()))?;
59            Ok(Index(index.assume_init()))
60        }
61    }
62
63    /// Perform a Approximate Nearest Neighbors search on the Index
64    ///
65    /// # Arguments
66    ///
67    /// * `res` - Resources to use
68    /// * `params` - Parameters to use in searching the index
69    /// * `queries` - A matrix in device memory to query for
70    /// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
71    /// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
72    pub fn search(
73        self,
74        res: &Resources,
75        params: &SearchParams,
76        queries: &ManagedTensor,
77        neighbors: &ManagedTensor,
78        distances: &ManagedTensor,
79    ) -> Result<()> {
80        unsafe {
81            let prefilter = ffi::cuvsFilter {
82                addr: 0,
83                type_: ffi::cuvsFilterType::NO_FILTER,
84            };
85
86            check_cuvs(ffi::cuvsCagraSearch(
87                res.0,
88                params.0,
89                self.0,
90                queries.as_ptr(),
91                neighbors.as_ptr(),
92                distances.as_ptr(),
93                prefilter,
94            ))
95        }
96    }
97}
98
99impl Drop for Index {
100    fn drop(&mut self) {
101        if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexDestroy(self.0) }) {
102            write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
103                .expect("failed to write to stderr");
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds an index over a random dataset with `build_params`, queries it
    /// with the first few dataset rows and checks that each query row comes
    /// back as its own nearest neighbor.
    fn test_cagra(build_params: IndexParams) {
        let res = Resources::new().unwrap();

        // Random dataset to index.
        let (n_datapoints, n_features) = (256, 16);
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

        let index =
            Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");

        // The first `n_queries` rows of the dataset double as the queries.
        let n_queries = 4;
        let k = 10;
        let query_rows = dataset.slice(s![0..n_queries, ..]);

        // The search API works on device memory: upload the queries and
        // allocate device-side outputs backed by zeroed host arrays.
        let queries = ManagedTensor::from(&query_rows).to_device(&res).unwrap();

        let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let neighbors = ManagedTensor::from(&neighbors_host)
            .to_device(&res)
            .unwrap();
        let distances = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        let search_params = SearchParams::new().unwrap();
        index
            .search(&res, &search_params, &queries, &neighbors, &distances)
            .unwrap();

        // Bring the results back to the host for inspection.
        distances.to_host(&res, &mut distances_host).unwrap();
        neighbors.to_host(&res, &mut neighbors_host).unwrap();

        // Query `i` is dataset row `i`, so its closest neighbor must be `i`.
        for (row, expected) in (0..n_queries).zip(0u32..) {
            assert_eq!(neighbors_host[[row, 0]], expected);
        }
    }

    /// Exercises build and search with default index parameters.
    #[test]
    fn test_cagra_index() {
        test_cagra(IndexParams::new().unwrap());
    }

    /// Exercises build and search with compression enabled on the index.
    #[test]
    fn test_cagra_compression() {
        use crate::cagra::CompressionParams;

        let build_params = IndexParams::new().unwrap();
        let build_params = build_params.set_compression(CompressionParams::new().unwrap());
        test_cagra(build_params);
    }
}