hipvs/cagra/
index_params.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::error::{check_cuvs, Result};
18use std::fmt;
19use std::io::{stderr, Write};
20
/// Short alias for the FFI type `cuvsCagraGraphBuildAlgo`.
///
/// Values of this type are passed to `IndexParams::set_build_algo` to choose
/// the ANN algorithm used to build the knn graph.
pub type BuildAlgo = ffi::cuvsCagraGraphBuildAlgo;
22
23/// Supplemental parameters to build CAGRA Index
24pub struct CompressionParams(pub ffi::cuvsCagraCompressionParams_t);
25
26impl CompressionParams {
27    /// Returns a new CompressionParams
28    pub fn new() -> Result<CompressionParams> {
29        unsafe {
30            let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraCompressionParams_t>::uninit();
31            check_cuvs(ffi::cuvsCagraCompressionParamsCreate(params.as_mut_ptr()))?;
32            Ok(CompressionParams(params.assume_init()))
33        }
34    }
35
36    /// The bit length of the vector element after compression by PQ.
37    pub fn set_pq_bits(self, pq_bits: u32) -> CompressionParams {
38        unsafe {
39            (*self.0).pq_bits = pq_bits;
40        }
41        self
42    }
43
44    /// The dimensionality of the vector after compression by PQ. When zero,
45    /// an optimal value is selected using a heuristic.
46    pub fn set_pq_dim(self, pq_dim: u32) -> CompressionParams {
47        unsafe {
48            (*self.0).pq_dim = pq_dim;
49        }
50        self
51    }
52
53    /// Vector Quantization (VQ) codebook size - number of "coarse cluster
54    /// centers". When zero, an optimal value is selected using a heuristic.
55    pub fn set_vq_n_centers(self, vq_n_centers: u32) -> CompressionParams {
56        unsafe {
57            (*self.0).vq_n_centers = vq_n_centers;
58        }
59        self
60    }
61
62    /// The number of iterations searching for kmeans centers (both VQ & PQ
63    /// phases).
64    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> CompressionParams {
65        unsafe {
66            (*self.0).kmeans_n_iters = kmeans_n_iters;
67        }
68        self
69    }
70
71    /// The fraction of data to use during iterative kmeans building (VQ
72    /// phase). When zero, an optimal value is selected using a heuristic.
73    pub fn set_vq_kmeans_trainset_fraction(
74        self,
75        vq_kmeans_trainset_fraction: f64,
76    ) -> CompressionParams {
77        unsafe {
78            (*self.0).vq_kmeans_trainset_fraction = vq_kmeans_trainset_fraction;
79        }
80        self
81    }
82
83    /// The fraction of data to use during iterative kmeans building (PQ
84    /// phase). When zero, an optimal value is selected using a heuristic.
85    pub fn set_pq_kmeans_trainset_fraction(
86        self,
87        pq_kmeans_trainset_fraction: f64,
88    ) -> CompressionParams {
89        unsafe {
90            (*self.0).pq_kmeans_trainset_fraction = pq_kmeans_trainset_fraction;
91        }
92        self
93    }
94}
95
96pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t, Option<CompressionParams>);
97
98impl IndexParams {
99    /// Returns a new IndexParams
100    pub fn new() -> Result<IndexParams> {
101        unsafe {
102            let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
103            check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
104            Ok(IndexParams(params.assume_init(), None))
105        }
106    }
107
108    /// Degree of input graph for pruning
109    pub fn set_intermediate_graph_degree(self, intermediate_graph_degree: usize) -> IndexParams {
110        unsafe {
111            (*self.0).intermediate_graph_degree = intermediate_graph_degree;
112        }
113        self
114    }
115
116    /// Degree of output graph
117    pub fn set_graph_degree(self, graph_degree: usize) -> IndexParams {
118        unsafe {
119            (*self.0).graph_degree = graph_degree;
120        }
121        self
122    }
123
124    /// ANN algorithm to build knn graph
125    pub fn set_build_algo(self, build_algo: BuildAlgo) -> IndexParams {
126        unsafe {
127            (*self.0).build_algo = build_algo;
128        }
129        self
130    }
131
132    /// Number of iterations to run if building with NN_DESCENT
133    pub fn set_nn_descent_niter(self, nn_descent_niter: usize) -> IndexParams {
134        unsafe {
135            (*self.0).nn_descent_niter = nn_descent_niter;
136        }
137        self
138    }
139
140    pub fn set_compression(mut self, compression: CompressionParams) -> IndexParams {
141        unsafe {
142            (*self.0).compression = compression.0;
143        }
144        // Note: we're moving the ownership of compression here to avoid having it cleaned up
145        // and leaving a dangling pointer
146        self.1 = Some(compression);
147        self
148    }
149}
150
151impl fmt::Debug for IndexParams {
152    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153        // custom debug trait here, default value will show the pointer address
154        // for the inner params object which isn't that useful.
155        write!(f, "IndexParams({:?})", unsafe { *self.0 })
156    }
157}
158
159impl fmt::Debug for CompressionParams {
160    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161        write!(f, "CompressionParams({:?})", unsafe { *self.0 })
162    }
163}
164
165impl Drop for IndexParams {
166    fn drop(&mut self) {
167        if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) {
168            write!(
169                stderr(),
170                "failed to call cuvsCagraIndexParamsDestroy {:?}",
171                e
172            )
173            .expect("failed to write to stderr");
174        }
175    }
176}
177
178impl Drop for CompressionParams {
179    fn drop(&mut self) {
180        if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraCompressionParamsDestroy(self.0) }) {
181            write!(
182                stderr(),
183                "failed to call cuvsCagraCompressionParamsDestroy {:?}",
184                e
185            )
186            .expect("failed to write to stderr");
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the builder-style setters, then reads the values back through the
    /// raw handles to confirm they reached the underlying C structs.
    #[test]
    fn test_index_params() {
        let index = IndexParams::new()
            .unwrap()
            .set_intermediate_graph_degree(128)
            .set_graph_degree(16)
            .set_build_algo(BuildAlgo::NN_DESCENT)
            .set_nn_descent_niter(10);
        let compression = CompressionParams::new().unwrap().set_pq_bits(4).set_pq_dim(8);
        let params = index.set_compression(compression);

        // Inspect the c-struct directly: the setters must have updated its
        // internal representation, not just the Rust-side wrapper.
        let raw = params.0;
        unsafe {
            assert_eq!((*raw).graph_degree, 16);
            assert_eq!((*raw).intermediate_graph_degree, 128);
            assert_eq!((*raw).build_algo, BuildAlgo::NN_DESCENT);
            assert_eq!((*raw).nn_descent_niter, 10);

            let compression_raw = (*raw).compression;
            assert_eq!((*compression_raw).pq_dim, 8);
            assert_eq!((*compression_raw).pq_bits, 4);
        }
    }
}