use std::fmt;
use std::io::{stderr, Write};

use crate::error::{check_cuvs, Result};
21pub type BuildAlgo = ffi::cuvsCagraGraphBuildAlgo;
22
23pub struct CompressionParams(pub ffi::cuvsCagraCompressionParams_t);
25
26impl CompressionParams {
27 pub fn new() -> Result<CompressionParams> {
29 unsafe {
30 let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraCompressionParams_t>::uninit();
31 check_cuvs(ffi::cuvsCagraCompressionParamsCreate(params.as_mut_ptr()))?;
32 Ok(CompressionParams(params.assume_init()))
33 }
34 }
35
36 pub fn set_pq_bits(self, pq_bits: u32) -> CompressionParams {
38 unsafe {
39 (*self.0).pq_bits = pq_bits;
40 }
41 self
42 }
43
44 pub fn set_pq_dim(self, pq_dim: u32) -> CompressionParams {
47 unsafe {
48 (*self.0).pq_dim = pq_dim;
49 }
50 self
51 }
52
53 pub fn set_vq_n_centers(self, vq_n_centers: u32) -> CompressionParams {
56 unsafe {
57 (*self.0).vq_n_centers = vq_n_centers;
58 }
59 self
60 }
61
62 pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> CompressionParams {
65 unsafe {
66 (*self.0).kmeans_n_iters = kmeans_n_iters;
67 }
68 self
69 }
70
71 pub fn set_vq_kmeans_trainset_fraction(
74 self,
75 vq_kmeans_trainset_fraction: f64,
76 ) -> CompressionParams {
77 unsafe {
78 (*self.0).vq_kmeans_trainset_fraction = vq_kmeans_trainset_fraction;
79 }
80 self
81 }
82
83 pub fn set_pq_kmeans_trainset_fraction(
86 self,
87 pq_kmeans_trainset_fraction: f64,
88 ) -> CompressionParams {
89 unsafe {
90 (*self.0).pq_kmeans_trainset_fraction = pq_kmeans_trainset_fraction;
91 }
92 self
93 }
94}
95
96pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t, Option<CompressionParams>);
97
98impl IndexParams {
99 pub fn new() -> Result<IndexParams> {
101 unsafe {
102 let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
103 check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
104 Ok(IndexParams(params.assume_init(), None))
105 }
106 }
107
108 pub fn set_intermediate_graph_degree(self, intermediate_graph_degree: usize) -> IndexParams {
110 unsafe {
111 (*self.0).intermediate_graph_degree = intermediate_graph_degree;
112 }
113 self
114 }
115
116 pub fn set_graph_degree(self, graph_degree: usize) -> IndexParams {
118 unsafe {
119 (*self.0).graph_degree = graph_degree;
120 }
121 self
122 }
123
124 pub fn set_build_algo(self, build_algo: BuildAlgo) -> IndexParams {
126 unsafe {
127 (*self.0).build_algo = build_algo;
128 }
129 self
130 }
131
132 pub fn set_nn_descent_niter(self, nn_descent_niter: usize) -> IndexParams {
134 unsafe {
135 (*self.0).nn_descent_niter = nn_descent_niter;
136 }
137 self
138 }
139
140 pub fn set_compression(mut self, compression: CompressionParams) -> IndexParams {
141 unsafe {
142 (*self.0).compression = compression.0;
143 }
144 self.1 = Some(compression);
147 self
148 }
149}
150
151impl fmt::Debug for IndexParams {
152 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153 write!(f, "IndexParams({:?})", unsafe { *self.0 })
156 }
157}
158
159impl fmt::Debug for CompressionParams {
160 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161 write!(f, "CompressionParams({:?})", unsafe { *self.0 })
162 }
163}
164
165impl Drop for IndexParams {
166 fn drop(&mut self) {
167 if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) {
168 write!(
169 stderr(),
170 "failed to call cuvsCagraIndexParamsDestroy {:?}",
171 e
172 )
173 .expect("failed to write to stderr");
174 }
175 }
176}
177
178impl Drop for CompressionParams {
179 fn drop(&mut self) {
180 if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraCompressionParamsDestroy(self.0) }) {
181 write!(
182 stderr(),
183 "failed to call cuvsCagraCompressionParamsDestroy {:?}",
184 e
185 )
186 .expect("failed to write to stderr");
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds index params with every setter applied (including attached
    /// compression params) and checks each value landed in the C structs.
    #[test]
    fn test_index_params() {
        let params = IndexParams::new()
            .unwrap()
            .set_intermediate_graph_degree(128)
            .set_graph_degree(16)
            .set_build_algo(BuildAlgo::NN_DESCENT)
            .set_nn_descent_niter(10)
            .set_compression(CompressionParams::new().unwrap().set_pq_bits(4).set_pq_dim(8));

        // SAFETY: `params.0` and the compression handle stored inside it are both
        // owned by `params`, which outlives these references.
        let (index, compression) = unsafe { (&*params.0, &*(*params.0).compression) };

        assert_eq!(index.graph_degree, 16);
        assert_eq!(index.intermediate_graph_degree, 128);
        assert_eq!(index.build_algo, BuildAlgo::NN_DESCENT);
        assert_eq!(index.nn_descent_niter, 10);
        assert_eq!(compression.pq_dim, 8);
        assert_eq!(compression.pq_bits, 4);
    }
}