1use std::io::{stderr, Write};
18
19use crate::cagra::{IndexParams, SearchParams};
20use crate::dlpack::ManagedTensor;
21use crate::error::{check_cuvs, Result};
22use crate::resources::Resources;
23
24#[derive(Debug)]
26pub struct Index(ffi::cuvsCagraIndex_t);
27
28impl Index {
29 pub fn build<T: Into<ManagedTensor>>(
37 res: &Resources,
38 params: &IndexParams,
39 dataset: T,
40 ) -> Result<Index> {
41 let dataset: ManagedTensor = dataset.into();
42 let index = Index::new()?;
43 unsafe {
44 check_cuvs(ffi::cuvsCagraBuild(
45 res.0,
46 params.0,
47 dataset.as_ptr(),
48 index.0,
49 ))?;
50 }
51 Ok(index)
52 }
53
54 pub fn new() -> Result<Index> {
56 unsafe {
57 let mut index = std::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
58 check_cuvs(ffi::cuvsCagraIndexCreate(index.as_mut_ptr()))?;
59 Ok(Index(index.assume_init()))
60 }
61 }
62
63 pub fn search(
73 self,
74 res: &Resources,
75 params: &SearchParams,
76 queries: &ManagedTensor,
77 neighbors: &ManagedTensor,
78 distances: &ManagedTensor,
79 ) -> Result<()> {
80 unsafe {
81 let prefilter = ffi::cuvsFilter {
82 addr: 0,
83 type_: ffi::cuvsFilterType::NO_FILTER,
84 };
85
86 check_cuvs(ffi::cuvsCagraSearch(
87 res.0,
88 params.0,
89 self.0,
90 queries.as_ptr(),
91 neighbors.as_ptr(),
92 distances.as_ptr(),
93 prefilter,
94 ))
95 }
96 }
97}
98
99impl Drop for Index {
100 fn drop(&mut self) {
101 if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexDestroy(self.0) }) {
102 write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
103 .expect("failed to write to stderr");
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds an index over random data using `build_params`, then queries it
    /// with the first few dataset rows. Each query's closest neighbor must be
    /// the row it was taken from.
    fn test_cagra(build_params: IndexParams) {
        let res = Resources::new().unwrap();

        // Random dataset: 256 points with 16 features, drawn uniformly from [0, 1).
        let (n_datapoints, n_features) = (256, 16);
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

        let index =
            Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");

        // Query with rows taken from the dataset itself, so the nearest hit is known.
        let (n_queries, k) = (4, 10);
        let query_rows = dataset.slice(s![0..n_queries, ..]);
        let queries = ManagedTensor::from(&query_rows).to_device(&res).unwrap();

        // Host-side result buffers, mirrored onto the device for the search.
        let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
        let neighbors = ManagedTensor::from(&neighbors_host)
            .to_device(&res)
            .unwrap();
        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let distances = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        let search_params = SearchParams::new().unwrap();
        index
            .search(&res, &search_params, &queries, &neighbors, &distances)
            .unwrap();

        // Bring the results back to the host for inspection.
        distances.to_host(&res, &mut distances_host).unwrap();
        neighbors.to_host(&res, &mut neighbors_host).unwrap();

        // Every query is dataset row `row`, so its top neighbor must be `row`.
        for row in 0..n_queries {
            assert_eq!(neighbors_host[[row, 0]], row as u32);
        }
    }

    #[test]
    fn test_cagra_index() {
        test_cagra(IndexParams::new().unwrap());
    }

    #[test]
    fn test_cagra_compression() {
        use crate::cagra::CompressionParams;
        let index_params = IndexParams::new().unwrap();
        let compression = CompressionParams::new().unwrap();
        test_cagra(index_params.set_compression(compression));
    }
}