1use std::io::{stderr, Write};
18
19use crate::ivf_flat::{IndexParams, SearchParams};
20use crate::dlpack::ManagedTensor;
21use crate::error::{check_cuvs, Result};
22use crate::resources::Resources;
23
24#[derive(Debug)]
26pub struct Index(ffi::cuvsIvfFlatIndex_t);
27
28impl Index {
29 pub fn build<T: Into<ManagedTensor>>(
37 res: &Resources,
38 params: &IndexParams,
39 dataset: T,
40 ) -> Result<Index> {
41 let dataset: ManagedTensor = dataset.into();
42 let index = Index::new()?;
43 unsafe {
44 check_cuvs(ffi::cuvsIvfFlatBuild(
45 res.0,
46 params.0,
47 dataset.as_ptr(),
48 index.0,
49 ))?;
50 }
51 Ok(index)
52 }
53
54 pub fn new() -> Result<Index> {
56 unsafe {
57 let mut index = std::mem::MaybeUninit::<ffi::cuvsIvfFlatIndex_t>::uninit();
58 check_cuvs(ffi::cuvsIvfFlatIndexCreate(index.as_mut_ptr()))?;
59 Ok(Index(index.assume_init()))
60 }
61 }
62
63 pub fn search(
73 self,
74 res: &Resources,
75 params: &SearchParams,
76 queries: &ManagedTensor,
77 neighbors: &ManagedTensor,
78 distances: &ManagedTensor,
79 ) -> Result<()> {
80 unsafe {
81 let prefilter = ffi::cuvsFilter {
82 addr: 0,
83 type_: ffi::cuvsFilterType::NO_FILTER,
84 };
85
86 check_cuvs(ffi::cuvsIvfFlatSearch(
87 res.0,
88 params.0,
89 self.0,
90 queries.as_ptr(),
91 neighbors.as_ptr(),
92 distances.as_ptr(),
93 prefilter,
94 ))
95 }
96 }
97}
98
99impl Drop for Index {
100 fn drop(&mut self) {
101 if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfFlatIndexDestroy(self.0) }) {
102 write!(stderr(), "failed to call cuvsIvfFlatIndexDestroy {:?}", e)
103 .expect("failed to write to stderr");
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// End-to-end check: build an IVF-Flat index over random data, query it with
    /// the first few dataset rows, and expect each query's closest neighbour to be
    /// the very row it was copied from.
    #[test]
    fn test_ivf_flat() {
        let build_params = IndexParams::new().unwrap().set_n_lists(64);
        let res = Resources::new().unwrap();

        let (n_rows, n_cols) = (1024, 16);
        let data = ndarray::Array::<f32, _>::random((n_rows, n_cols), Uniform::new(0., 1.0));
        let data_on_device = ManagedTensor::from(&data).to_device(&res).unwrap();

        let index = Index::build(&res, &build_params, data_on_device)
            .expect("failed to create ivf-flat index");

        // The queries are dataset rows themselves, so each should match itself first.
        let (n_queries, k) = (4, 10);
        let query_rows = data.slice(s![0..n_queries, ..]);
        let queries_on_device = ManagedTensor::from(&query_rows).to_device(&res).unwrap();

        let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
        let neighbors_on_device = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let distances_on_device = ManagedTensor::from(&distances_host).to_device(&res).unwrap();

        let search_params = SearchParams::new().unwrap();
        index
            .search(
                &res,
                &search_params,
                &queries_on_device,
                &neighbors_on_device,
                &distances_on_device,
            )
            .unwrap();

        distances_on_device.to_host(&res, &mut distances_host).unwrap();
        neighbors_on_device.to_host(&res, &mut neighbors_host).unwrap();

        for row in 0..n_queries {
            assert_eq!(neighbors_host[[row, 0]], row as i64);
        }
    }
}