hipvs/ivf_pq/index_params.rs
1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::distance_type::DistanceType;
18use crate::error::{check_cuvs, Result};
19use std::fmt;
20use std::io::{stderr, Write};
21
22pub use ffi::codebook_gen;
23
24pub struct IndexParams(pub ffi::cuvsIvfPqIndexParams_t);
25
26impl IndexParams {
27 /// Returns a new IndexParams
28 pub fn new() -> Result<IndexParams> {
29 unsafe {
30 let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfPqIndexParams_t>::uninit();
31 check_cuvs(ffi::cuvsIvfPqIndexParamsCreate(params.as_mut_ptr()))?;
32 Ok(IndexParams(params.assume_init()))
33 }
34 }
35
36 /// The number of clusters used in the coarse quantizer.
37 pub fn set_n_lists(self, n_lists: u32) -> IndexParams {
38 unsafe {
39 (*self.0).n_lists = n_lists;
40 }
41 self
42 }
43
44 /// DistanceType to use for building the index
45 pub fn set_metric(self, metric: DistanceType) -> IndexParams {
46 unsafe {
47 (*self.0).metric = metric;
48 }
49 self
50 }
51
52 /// The number of iterations searching for kmeans centers during index building.
53 pub fn set_metric_arg(self, metric_arg: f32) -> IndexParams {
54 unsafe {
55 (*self.0).metric_arg = metric_arg;
56 }
57 self
58 }
59
60 /// The number of iterations searching for kmeans centers during index building.
61 pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> IndexParams {
62 unsafe {
63 (*self.0).kmeans_n_iters = kmeans_n_iters;
64 }
65 self
66 }
67
68 /// If kmeans_trainset_fraction is less than 1, then the dataset is
69 /// subsampled, and only n_samples * kmeans_trainset_fraction rows
70 /// are used for training.
71 pub fn set_kmeans_trainset_fraction(self, kmeans_trainset_fraction: f64) -> IndexParams {
72 unsafe {
73 (*self.0).kmeans_trainset_fraction = kmeans_trainset_fraction;
74 }
75 self
76 }
77
78 /// The bit length of the vector element after quantization.
79 pub fn set_pq_bits(self, pq_bits: u32) -> IndexParams {
80 unsafe {
81 (*self.0).pq_bits = pq_bits;
82 }
83 self
84 }
85
86 /// The dimensionality of a the vector after product quantization.
87 /// When zero, an optimal value is selected using a heuristic. Note
88 /// pq_dim * pq_bits must be a multiple of 8. Hint: a smaller 'pq_dim'
89 /// results in a smaller index size and better search performance, but
90 /// lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number,
91 /// but multiple of 8 are desirable for good performance. If 'pq_bits'
92 /// is not 8, 'pq_dim' should be a multiple of 8. For good performance,
93 /// it is desirable that 'pq_dim' is a multiple of 32. Ideally,
94 /// 'pq_dim' should be also a divisor of the dataset dim.
95 pub fn set_pq_dim(self, pq_dim: u32) -> IndexParams {
96 unsafe {
97 (*self.0).pq_dim = pq_dim;
98 }
99 self
100 }
101
102 pub fn set_codebook_kind(self, codebook_kind: codebook_gen) -> IndexParams {
103 unsafe {
104 (*self.0).codebook_kind = codebook_kind;
105 }
106 self
107 }
108
109 /// Apply a random rotation matrix on the input data and queries even
110 /// if `dim % pq_dim == 0`. Note: if `dim` is not multiple of `pq_dim`,
111 /// a random rotation is always applied to the input data and queries
112 /// to transform the working space from `dim` to `rot_dim`, which may
113 /// be slightly larger than the original space and and is a multiple
114 /// of `pq_dim` (`rot_dim % pq_dim == 0`). However, this transform is
115 /// not necessary when `dim` is multiple of `pq_dim` (`dim == rot_dim`,
116 /// hence no need in adding "extra" data columns / features). By
117 /// default, if `dim == rot_dim`, the rotation transform is
118 /// initialized with the identity matrix. When
119 /// `force_random_rotation == True`, a random orthogonal transform
120 pub fn set_force_random_rotation(self, force_random_rotation: bool) -> IndexParams {
121 unsafe {
122 (*self.0).force_random_rotation = force_random_rotation;
123 }
124 self
125 }
126
127 /// The max number of data points to use per PQ code during PQ codebook training. Using more data
128 /// points per PQ code may increase the quality of PQ codebook but may also increase the build
129 /// time. The parameter is applied to both PQ codebook generation methods, i.e., PER_SUBSPACE and
130 /// PER_CLUSTER. In both cases, we will use `pq_book_size * max_train_points_per_pq_code` training
131 /// points to train each codebook.
132 pub fn set_max_train_points_per_pq_code(self, max_pq_points: u32)-> IndexParams {
133 unsafe {
134 (*self.0).max_train_points_per_pq_code = max_pq_points;
135 }
136 self
137 }
138
139 /// After training the coarse and fine quantizers, we will populate
140 /// the index with the dataset if add_data_on_build == true, otherwise
141 /// the index is left empty, and the extend method can be used
142 /// to add new vectors to the index.
143 pub fn set_add_data_on_build(self, add_data_on_build: bool) -> IndexParams {
144 unsafe {
145 (*self.0).add_data_on_build = add_data_on_build;
146 }
147 self
148 }
149}
150
151impl fmt::Debug for IndexParams {
152 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153 // custom debug trait here, default value will show the pointer address
154 // for the inner params object which isn't that useful.
155 write!(f, "IndexParams({:?})", unsafe { *self.0 })
156 }
157}
158
159impl Drop for IndexParams {
160 fn drop(&mut self) {
161 if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfPqIndexParamsDestroy(self.0) }) {
162 write!(
163 stderr(),
164 "failed to call cuvsIvfPqIndexParamsDestroy {:?}",
165 e
166 )
167 .expect("failed to write to stderr");
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the builder setters write through to the underlying cuVS struct
    /// and that the custom `Debug` impl wraps its output in `IndexParams(...)`.
    #[test]
    fn test_index_params() {
        let params = IndexParams::new()
            .expect("failed to create IVF-PQ index params")
            .set_n_lists(128)
            .set_kmeans_n_iters(10)
            .set_pq_bits(8)
            .set_pq_dim(32)
            .set_force_random_rotation(true)
            .set_max_train_points_per_pq_code(512)
            .set_add_data_on_build(false);

        // SAFETY: `params.0` was allocated by `IndexParams::new` and is only freed when
        // `params` is dropped at the end of this test.
        unsafe {
            assert_eq!((*params.0).n_lists, 128);
            assert_eq!((*params.0).kmeans_n_iters, 10);
            assert_eq!((*params.0).pq_bits, 8);
            assert_eq!((*params.0).pq_dim, 32);
            assert!((*params.0).force_random_rotation);
            assert_eq!((*params.0).max_train_points_per_pq_code, 512);
            assert!(!(*params.0).add_data_on_build);
        }

        assert!(format!("{:?}", params).starts_with("IndexParams("));
    }
}