hipvs/ivf_flat/
index_params.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::error::{check_cuvs, Result};
18use crate::distance_type::DistanceType;
19use std::fmt;
20use std::io::{stderr, Write};
21
/// Owning wrapper around a raw `cuvsIvfFlatIndexParams_t` handle.
///
/// The handle is obtained through [`IndexParams::new`] and handed back to
/// `cuvsIvfFlatIndexParamsDestroy` when the wrapper is dropped. The inner
/// field is public so the raw handle can be passed to other FFI calls.
pub struct IndexParams(pub ffi::cuvsIvfFlatIndexParams_t);
23
24impl IndexParams {
25    /// Returns a new IndexParams
26    pub fn new() -> Result<IndexParams> {
27        unsafe {
28            let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfFlatIndexParams_t>::uninit();
29            check_cuvs(ffi::cuvsIvfFlatIndexParamsCreate(params.as_mut_ptr()))?;
30            Ok(IndexParams(params.assume_init()))
31        }
32    }
33
34    /// The number of clusters used in the coarse quantizer.
35    pub fn set_n_lists(self, n_lists: u32) -> IndexParams {
36        unsafe {
37            (*self.0).n_lists = n_lists;
38        }
39        self
40    }
41
42    /// DistanceType to use for building the index
43    pub fn set_metric(self, metric: DistanceType) -> IndexParams {
44        unsafe {
45            (*self.0).metric = metric;
46        }
47        self
48    }
49
50    /// The number of iterations searching for kmeans centers during index building.
51    pub fn set_metric_arg(self, metric_arg: f32) -> IndexParams {
52        unsafe {
53            (*self.0).metric_arg = metric_arg;
54        }
55        self
56    }
57    /// The number of iterations searching for kmeans centers during index building.
58    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> IndexParams {
59        unsafe {
60            (*self.0).kmeans_n_iters = kmeans_n_iters;
61        }
62        self
63    }
64
65    /// If kmeans_trainset_fraction is less than 1, then the dataset is
66    /// subsampled, and only n_samples * kmeans_trainset_fraction rows
67    /// are used for training.
68    pub fn set_kmeans_trainset_fraction(self, kmeans_trainset_fraction: f64) -> IndexParams {
69        unsafe {
70            (*self.0).kmeans_trainset_fraction = kmeans_trainset_fraction;
71        }
72        self
73    }
74
75    /// After training the coarse and fine quantizers, we will populate
76    /// the index with the dataset if add_data_on_build == true, otherwise
77    /// the index is left empty, and the extend method can be used
78    /// to add new vectors to the index.
79    pub fn set_add_data_on_build(self, add_data_on_build: bool) -> IndexParams {
80        unsafe {
81            (*self.0).add_data_on_build = add_data_on_build;
82        }
83        self
84    }
85}
86
87impl fmt::Debug for IndexParams {
88    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
89        // custom debug trait here, default value will show the pointer address
90        // for the inner params object which isn't that useful.
91        write!(f, "IndexParams({:?})", unsafe { *self.0 })
92    }
93}
94
95impl Drop for IndexParams {
96    fn drop(&mut self) {
97        if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfFlatIndexParamsDestroy(self.0) }) {
98            write!(
99                stderr(),
100                "failed to call cuvsIvfFlatIndexParamsDestroy {:?}",
101                e
102            )
103            .expect("failed to write to stderr");
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Values written through the builder-style setters must be readable
    /// back from the underlying params object.
    #[test]
    fn test_index_params() {
        let created = IndexParams::new().unwrap();
        let with_lists = created.set_n_lists(128);
        let params = with_lists.set_add_data_on_build(false);

        // Read both fields back through the raw handle, then assert in safe code.
        let (n_lists, add_data_on_build) =
            unsafe { ((*params.0).n_lists, (*params.0).add_data_on_build) };

        assert_eq!(n_lists, 128);
        assert_eq!(add_data_on_build, false);
    }
}