hipvs/
brute_force.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16//! Brute Force KNN
17
18use std::io::{stderr, Write};
19
20use crate::distance_type::DistanceType;
21use crate::dlpack::ManagedTensor;
22use crate::error::{check_cuvs, Result};
23use crate::resources::Resources;
24
/// Brute Force KNN Index
///
/// Thin wrapper around the raw `cuvsBruteForceIndex_t` handle. The handle is obtained from
/// `cuvsBruteForceIndexCreate` in [`Index::new`] and is handed to `cuvsBruteForceIndexDestroy`
/// when the wrapper is dropped.
#[derive(Debug)]
pub struct Index(ffi::cuvsBruteForceIndex_t);
28
29impl Index {
30    /// Builds a new Brute Force KNN Index from the dataset for efficient search.
31    ///
32    /// # Arguments
33    ///
34    /// * `res` - Resources to use
35    /// * `metric` - DistanceType to use for building the index
36    /// * `metric_arg` - Optional value of `p` for Minkowski distances
37    /// * `dataset` - A row-major matrix on either the host or device to index
38    pub fn build<T: Into<ManagedTensor>>(
39        res: &Resources,
40        metric: DistanceType,
41        metric_arg: Option<f32>,
42        dataset: T,
43    ) -> Result<Index> {
44        let dataset: ManagedTensor = dataset.into();
45        let index = Index::new()?;
46        unsafe {
47            check_cuvs(ffi::cuvsBruteForceBuild(
48                res.0,
49                dataset.as_ptr(),
50                metric,
51                metric_arg.unwrap_or(2.0),
52                index.0,
53            ))?;
54        }
55        Ok(index)
56    }
57
58    /// Creates a new empty index
59    pub fn new() -> Result<Index> {
60        unsafe {
61            let mut index = std::mem::MaybeUninit::<ffi::cuvsBruteForceIndex_t>::uninit();
62            check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?;
63            Ok(Index(index.assume_init()))
64        }
65    }
66
67    /// Perform a Nearest Neighbors search on the Index
68    ///
69    /// # Arguments
70    ///
71    /// * `res` - Resources to use
72    /// * `queries` - A matrix in device memory to query for
73    /// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
74    /// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
75    pub fn search(
76        self,
77        res: &Resources,
78        queries: &ManagedTensor,
79        neighbors: &ManagedTensor,
80        distances: &ManagedTensor,
81    ) -> Result<()> {
82        unsafe {
83            let prefilter = ffi::cuvsFilter {
84                addr: 0,
85                type_: ffi::cuvsFilterType::NO_FILTER,
86            };
87
88            check_cuvs(ffi::cuvsBruteForceSearch(
89                res.0,
90                self.0,
91                queries.as_ptr(),
92                neighbors.as_ptr(),
93                distances.as_ptr(),
94                prefilter,
95            ))
96        }
97    }
98}
99
100impl Drop for Index {
101    fn drop(&mut self) {
102        if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.0) }) {
103            write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
104                .expect("failed to write to stderr");
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use mark_flaky_tests::flaky;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Number of rows in the random dataset that gets indexed.
    const N_DATAPOINTS: usize = 16;
    /// Number of columns (features) per row.
    const N_FEATURES: usize = 8;
    /// Number of leading dataset rows reused as queries.
    const N_QUERIES: usize = 4;
    /// Number of neighbors requested per query.
    const K: usize = 4;

    /// Indexes a random dataset with `metric`, queries it with its own first rows, and checks
    /// that each query row comes back as its own nearest neighbor.
    fn test_bfknn(metric: DistanceType) {
        let res = Resources::new().unwrap();

        // Random host-side dataset, copied to the device for indexing.
        let host_data = ndarray::Array::<f32, _>::random(
            (N_DATAPOINTS, N_FEATURES),
            Uniform::new(0., 1.0),
        );
        let device_data = ManagedTensor::from(&host_data).to_device(&res).unwrap();

        println!("dataset {:#?}", host_data);

        let index = Index::build(&res, metric, None, device_data)
            .expect("failed to create brute force index");

        res.sync_stream().unwrap();

        // The first rows of the dataset double as queries.
        let host_queries = host_data.slice(s![0..N_QUERIES, ..]);

        println!("queries! {:#?}", host_queries);
        let queries = ManagedTensor::from(&host_queries).to_device(&res).unwrap();

        // Output buffers: allocated on the host, mirrored on the device for the search.
        let mut neighbors_host = ndarray::Array::<i64, _>::zeros((N_QUERIES, K));
        let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros((N_QUERIES, K));
        let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap();

        index.search(&res, &queries, &neighbors, &distances).unwrap();

        // Bring the results back to host memory before inspecting them.
        distances.to_host(&res, &mut distances_host).unwrap();
        neighbors.to_host(&res, &mut neighbors_host).unwrap();
        res.sync_stream().unwrap();

        println!("distances {:#?}", distances_host);
        println!("neighbors {:#?}", neighbors_host);

        // Query `i` is dataset row `i`, so its closest neighbor must be row `i` itself.
        for (row, expected) in (0..N_QUERIES).zip(0_i64..) {
            assert_eq!(neighbors_host[[row, 0]], expected);
        }
    }

    // Disabled: cosine-distance variant of the test.
    //
    // #[test]
    // fn test_cosine() {
    //     test_bfknn(DistanceType::CosineExpanded);
    // }

    #[flaky]
    fn test_l2() {
        test_bfknn(DistanceType::L2Expanded);
    }
}