hipvs/ivf_pq/
index_params.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::distance_type::DistanceType;
18use crate::error::{check_cuvs, Result};
19use std::fmt;
20use std::io::{stderr, Write};
21
22pub use ffi::codebook_gen;
23
24pub struct IndexParams(pub ffi::cuvsIvfPqIndexParams_t);
25
26impl IndexParams {
27    /// Returns a new IndexParams
28    pub fn new() -> Result<IndexParams> {
29        unsafe {
30            let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfPqIndexParams_t>::uninit();
31            check_cuvs(ffi::cuvsIvfPqIndexParamsCreate(params.as_mut_ptr()))?;
32            Ok(IndexParams(params.assume_init()))
33        }
34    }
35
36    /// The number of clusters used in the coarse quantizer.
37    pub fn set_n_lists(self, n_lists: u32) -> IndexParams {
38        unsafe {
39            (*self.0).n_lists = n_lists;
40        }
41        self
42    }
43
44    /// DistanceType to use for building the index
45    pub fn set_metric(self, metric: DistanceType) -> IndexParams {
46        unsafe {
47            (*self.0).metric = metric;
48        }
49        self
50    }
51
52    /// The number of iterations searching for kmeans centers during index building.
53    pub fn set_metric_arg(self, metric_arg: f32) -> IndexParams {
54        unsafe {
55            (*self.0).metric_arg = metric_arg;
56        }
57        self
58    }
59
60    /// The number of iterations searching for kmeans centers during index building.
61    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> IndexParams {
62        unsafe {
63            (*self.0).kmeans_n_iters = kmeans_n_iters;
64        }
65        self
66    }
67
68    /// If kmeans_trainset_fraction is less than 1, then the dataset is
69    /// subsampled, and only n_samples * kmeans_trainset_fraction rows
70    /// are used for training.
71    pub fn set_kmeans_trainset_fraction(self, kmeans_trainset_fraction: f64) -> IndexParams {
72        unsafe {
73            (*self.0).kmeans_trainset_fraction = kmeans_trainset_fraction;
74        }
75        self
76    }
77
78    /// The bit length of the vector element after quantization.
79    pub fn set_pq_bits(self, pq_bits: u32) -> IndexParams {
80        unsafe {
81            (*self.0).pq_bits = pq_bits;
82        }
83        self
84    }
85
86    /// The dimensionality of a the vector after product quantization.
87    /// When zero, an optimal value is selected using a heuristic. Note
88    /// pq_dim * pq_bits must be a multiple of 8. Hint: a smaller 'pq_dim'
89    /// results in a smaller index size and better search performance, but
90    /// lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number,
91    /// but multiple of 8 are desirable for good performance. If 'pq_bits'
92    /// is not 8, 'pq_dim' should be a multiple of 8. For good performance,
93    /// it is desirable that 'pq_dim' is a multiple of 32. Ideally,
94    /// 'pq_dim' should be also a divisor of the dataset dim.
95    pub fn set_pq_dim(self, pq_dim: u32) -> IndexParams {
96        unsafe {
97            (*self.0).pq_dim = pq_dim;
98        }
99        self
100    }
101
102    pub fn set_codebook_kind(self, codebook_kind: codebook_gen) -> IndexParams {
103        unsafe {
104            (*self.0).codebook_kind = codebook_kind;
105        }
106        self
107    }
108
109    /// Apply a random rotation matrix on the input data and queries even
110    /// if `dim % pq_dim == 0`. Note: if `dim` is not multiple of `pq_dim`,
111    /// a random rotation is always applied to the input data and queries
112    /// to transform the working space from `dim` to `rot_dim`, which may
113    /// be slightly larger than the original space and and is a multiple
114    /// of `pq_dim` (`rot_dim % pq_dim == 0`). However, this transform is
115    /// not necessary when `dim` is multiple of `pq_dim` (`dim == rot_dim`,
116    /// hence no need in adding "extra" data columns / features). By
117    /// default, if `dim == rot_dim`, the rotation transform is
118    /// initialized with the identity matrix. When
119    /// `force_random_rotation == True`, a random orthogonal transform
120    pub fn set_force_random_rotation(self, force_random_rotation: bool) -> IndexParams {
121        unsafe {
122            (*self.0).force_random_rotation = force_random_rotation;
123        }
124        self
125    }
126
127    /// The max number of data points to use per PQ code during PQ codebook training. Using more data
128    /// points per PQ code may increase the quality of PQ codebook but may also increase the build
129    /// time. The parameter is applied to both PQ codebook generation methods, i.e., PER_SUBSPACE and
130    /// PER_CLUSTER. In both cases, we will use `pq_book_size * max_train_points_per_pq_code` training
131    /// points to train each codebook.
132    pub fn set_max_train_points_per_pq_code(self, max_pq_points: u32)-> IndexParams {
133        unsafe {
134            (*self.0).max_train_points_per_pq_code = max_pq_points;
135        }
136        self
137    }
138
139    /// After training the coarse and fine quantizers, we will populate
140    /// the index with the dataset if add_data_on_build == true, otherwise
141    /// the index is left empty, and the extend method can be used
142    /// to add new vectors to the index.
143    pub fn set_add_data_on_build(self, add_data_on_build: bool) -> IndexParams {
144        unsafe {
145            (*self.0).add_data_on_build = add_data_on_build;
146        }
147        self
148    }
149}
150
151impl fmt::Debug for IndexParams {
152    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153        // custom debug trait here, default value will show the pointer address
154        // for the inner params object which isn't that useful.
155        write!(f, "IndexParams({:?})", unsafe { *self.0 })
156    }
157}
158
159impl Drop for IndexParams {
160    fn drop(&mut self) {
161        if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfPqIndexParamsDestroy(self.0) }) {
162            write!(
163                stderr(),
164                "failed to call cuvsIvfPqIndexParamsDestroy {:?}",
165                e
166            )
167            .expect("failed to write to stderr");
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Values written through the chained setters must be readable back
    /// from the underlying parameters object.
    #[test]
    fn test_index_params() {
        let fresh = IndexParams::new().unwrap();
        let configured = fresh.set_n_lists(128).set_add_data_on_build(false);

        // Read both fields back through the raw handle in one place.
        let (n_lists, add_data_on_build) =
            unsafe { ((*configured.0).n_lists, (*configured.0).add_data_on_build) };

        assert_eq!(n_lists, 128);
        assert_eq!(add_data_on_build, false);
    }
}