1use std::io::{stderr, Write};
19
20use crate::distance_type::DistanceType;
21use crate::dlpack::ManagedTensor;
22use crate::error::{check_cuvs, Result};
23use crate::resources::Resources;
24
25#[derive(Debug)]
27pub struct Index(ffi::cuvsBruteForceIndex_t);
28
29impl Index {
30 pub fn build<T: Into<ManagedTensor>>(
39 res: &Resources,
40 metric: DistanceType,
41 metric_arg: Option<f32>,
42 dataset: T,
43 ) -> Result<Index> {
44 let dataset: ManagedTensor = dataset.into();
45 let index = Index::new()?;
46 unsafe {
47 check_cuvs(ffi::cuvsBruteForceBuild(
48 res.0,
49 dataset.as_ptr(),
50 metric,
51 metric_arg.unwrap_or(2.0),
52 index.0,
53 ))?;
54 }
55 Ok(index)
56 }
57
58 pub fn new() -> Result<Index> {
60 unsafe {
61 let mut index = std::mem::MaybeUninit::<ffi::cuvsBruteForceIndex_t>::uninit();
62 check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?;
63 Ok(Index(index.assume_init()))
64 }
65 }
66
67 pub fn search(
76 self,
77 res: &Resources,
78 queries: &ManagedTensor,
79 neighbors: &ManagedTensor,
80 distances: &ManagedTensor,
81 ) -> Result<()> {
82 unsafe {
83 let prefilter = ffi::cuvsFilter {
84 addr: 0,
85 type_: ffi::cuvsFilterType::NO_FILTER,
86 };
87
88 check_cuvs(ffi::cuvsBruteForceSearch(
89 res.0,
90 self.0,
91 queries.as_ptr(),
92 neighbors.as_ptr(),
93 distances.as_ptr(),
94 prefilter,
95 ))
96 }
97 }
98}
99
100impl Drop for Index {
101 fn drop(&mut self) {
102 if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.0) }) {
103 write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
104 .expect("failed to write to stderr");
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use mark_flaky_tests::flaky;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds a brute-force index over a small random dataset with `metric`,
    /// queries it with the first few dataset rows, and checks that each query
    /// reports its own row as the first neighbor.
    fn test_bfknn(metric: DistanceType) {
        let res = Resources::new().unwrap();

        // 16 x 8 dataset with values drawn from Uniform::new(0., 1.0).
        let (rows, cols) = (16, 8);
        let host_data = ndarray::Array::<f32, _>::random((rows, cols), Uniform::new(0., 1.0));
        let device_data = ManagedTensor::from(&host_data).to_device(&res).unwrap();
        println!("dataset {:#?}", host_data);

        let index = Index::build(&res, metric, None, device_data)
            .expect("failed to create brute force index");
        res.sync_stream().unwrap();

        // Query with the first `n_queries` dataset rows, `k` neighbors each.
        let (n_queries, k) = (4, 4);
        let query_view = host_data.slice(s![0..n_queries, ..]);
        println!("queries! {:#?}", query_view);
        let queries = ManagedTensor::from(&query_view).to_device(&res).unwrap();

        let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
        let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap();
        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap();

        index.search(&res, &queries, &neighbors, &distances).unwrap();

        // Copy both outputs back, then wait for the stream before reading them.
        distances.to_host(&res, &mut distances_host).unwrap();
        neighbors.to_host(&res, &mut neighbors_host).unwrap();
        res.sync_stream().unwrap();

        println!("distances {:#?}", distances_host);
        println!("neighbors {:#?}", neighbors_host);

        // Query `row` is dataset row `row`, so it must be its own nearest neighbor.
        for row in 0..n_queries {
            assert_eq!(neighbors_host[[row, 0]], row as i64);
        }
    }

    #[flaky]
    fn test_l2() {
        test_bfknn(DistanceType::L2Expanded);
    }
}