1use std::io::{stderr, Write};
18
19use crate::dlpack::ManagedTensor;
20use crate::error::{check_cuvs, Result};
21use crate::ivf_pq::{IndexParams, SearchParams};
22use crate::resources::Resources;
23
/// Owning wrapper around a raw `cuvsIvfPqIndex_t` handle.
///
/// The handle is obtained from `cuvsIvfPqIndexCreate` (see `Index::new`) and is
/// passed to `cuvsIvfPqIndexDestroy` when the wrapper is dropped.
#[derive(Debug)]
pub struct Index(ffi::cuvsIvfPqIndex_t);
27
28impl Index {
29 pub fn build<T: Into<ManagedTensor>>(
37 res: &Resources,
38 params: &IndexParams,
39 dataset: T,
40 ) -> Result<Index> {
41 let dataset: ManagedTensor = dataset.into();
42 let index = Index::new()?;
43 unsafe {
44 check_cuvs(ffi::cuvsIvfPqBuild(
45 res.0,
46 params.0,
47 dataset.as_ptr(),
48 index.0,
49 ))?;
50 }
51 Ok(index)
52 }
53
54 pub fn new() -> Result<Index> {
56 unsafe {
57 let mut index = std::mem::MaybeUninit::<ffi::cuvsIvfPqIndex_t>::uninit();
58 check_cuvs(ffi::cuvsIvfPqIndexCreate(index.as_mut_ptr()))?;
59 Ok(Index(index.assume_init()))
60 }
61 }
62
63 pub fn search(
73 self,
74 res: &Resources,
75 params: &SearchParams,
76 queries: &ManagedTensor,
77 neighbors: &ManagedTensor,
78 distances: &ManagedTensor,
79 ) -> Result<()> {
80 unsafe {
81 check_cuvs(ffi::cuvsIvfPqSearch(
82 res.0,
83 params.0,
84 self.0,
85 queries.as_ptr(),
86 neighbors.as_ptr(),
87 distances.as_ptr(),
88 ))
89 }
90 }
91}
92
93impl Drop for Index {
94 fn drop(&mut self) {
95 if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfPqIndexDestroy(self.0) }) {
96 write!(stderr(), "failed to call cuvsIvfPqIndexDestroy {:?}", e)
97 .expect("failed to write to stderr");
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// End-to-end check: build an index over random data, query it with rows
    /// taken from that same data, and verify each row finds itself first.
    #[test]
    fn test_ivf_pq() {
        let build_params = IndexParams::new().unwrap().set_n_lists(64);
        let res = Resources::new().unwrap();

        // Random host dataset, then a device copy that the index is built from.
        let n_datapoints = 1024;
        let n_features = 16;
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
        let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap();

        let index = Index::build(&res, &build_params, dataset_device)
            .expect("failed to create ivf-pq index");

        // The queries are simply the first `n_queries` rows of the dataset.
        let n_queries = 4;
        let k = 10;
        let query_view = dataset.slice(s![0..n_queries, ..]);
        let queries_device = ManagedTensor::from(&query_view).to_device(&res).unwrap();

        // Host buffers that receive the results, plus their device counterparts
        // that are handed to the search call.
        let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
        let neighbors_device = ManagedTensor::from(&neighbors_host)
            .to_device(&res)
            .unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let distances_device = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        let search_params = SearchParams::new().unwrap();
        index
            .search(
                &res,
                &search_params,
                &queries_device,
                &neighbors_device,
                &distances_device,
            )
            .unwrap();

        // Copy the results back to the host buffers.
        distances_device.to_host(&res, &mut distances_host).unwrap();
        neighbors_device.to_host(&res, &mut neighbors_host).unwrap();

        // Query `i` is dataset row `i`, so its closest neighbour is expected to
        // be row `i` itself.
        for (row, expected_id) in (0..n_queries).zip(0_i64..) {
            assert_eq!(neighbors_host[[row, 0]], expected_id);
        }
    }
}