hipvs/ivf_pq/
search_params.rs

/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use crate::error::{check_cuvs, Result};
18use std::fmt;
19use std::io::{stderr, Write};
20
21pub use ffi::cudaDataType_t;
22
23/// Supplemental parameters to search IvfPq index
24pub struct SearchParams(pub ffi::cuvsIvfPqSearchParams_t);
25
26impl SearchParams {
27    /// Returns a new SearchParams object
28    pub fn new() -> Result<SearchParams> {
29        unsafe {
30            let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfPqSearchParams_t>::uninit();
31            check_cuvs(ffi::cuvsIvfPqSearchParamsCreate(params.as_mut_ptr()))?;
32            Ok(SearchParams(params.assume_init()))
33        }
34    }
35
36    /// The number of clusters to search.
37    pub fn set_n_probes(self, n_probes: u32) -> SearchParams {
38        unsafe {
39            (*self.0).n_probes = n_probes;
40        }
41        self
42    }
43
44    /// Data type of look up table to be created dynamically at search
45    /// time. The use of low-precision types reduces the amount of shared
46    /// memory required at search time, so fast shared memory kernels can
47    /// be used even for datasets with large dimansionality. Note that
48    /// the recall is slightly degraded when low-precision type is
49    /// selected.
50    pub fn set_lut_dtype(self, lut_dtype: cudaDataType_t) -> SearchParams {
51        unsafe {
52            (*self.0).lut_dtype = lut_dtype;
53        }
54        self
55    }
56
57    /// Storage data type for distance/similarity computation.
58    pub fn set_internal_distance_dtype(
59        self,
60        internal_distance_dtype: cudaDataType_t,
61    ) -> SearchParams {
62        unsafe {
63            (*self.0).internal_distance_dtype = internal_distance_dtype;
64        }
65        self
66    }
67}
68
69impl fmt::Debug for SearchParams {
70    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71        // custom debug trait here, default value will show the pointer address
72        // for the inner params object which isn't that useful.
73        write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
74    }
75}
76
77impl Drop for SearchParams {
78    fn drop(&mut self) {
79        if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfPqSearchParamsDestroy(self.0) }) {
80            write!(
81                stderr(),
82                "failed to call cuvsIvfPqSearchParamsDestroy {:?}",
83                e
84            )
85            .expect("failed to write to stderr");
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// A value set through the builder must be visible in the underlying params struct.
    #[test]
    fn test_search_params() {
        let fresh = SearchParams::new().unwrap();
        let configured = fresh.set_n_probes(128);

        // SAFETY: `configured` wraps the handle just created by `SearchParams::new`.
        let stored_probes = unsafe { (*configured.0).n_probes };
        assert_eq!(stored_probes, 128);
    }
}