hipvs/cagra/
search_params.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::error::{check_cuvs, Result};
18use std::fmt;
19use std::io::{stderr, Write};
20
/// Search implementation selector accepted by [`SearchParams::set_algo`].
///
/// Alias of the FFI type `cuvsCagraSearchAlgo`.
pub type SearchAlgo = ffi::cuvsCagraSearchAlgo;
/// Hashmap type selector accepted by [`SearchParams::set_hashmap_mode`].
///
/// Alias of the FFI type `cuvsCagraHashMode`.
pub type HashMode = ffi::cuvsCagraHashMode;
23
/// Supplemental parameters used when searching a CAGRA index.
///
/// This is a thin wrapper around the raw `ffi::cuvsCagraSearchParams_t` handle.
/// A value is created with [`SearchParams::new`], configured by chaining the
/// `set_*` methods (each consumes and returns the value), and the underlying
/// handle is destroyed when the value is dropped.
pub struct SearchParams(pub ffi::cuvsCagraSearchParams_t);
26
27impl SearchParams {
28    /// Returns a new SearchParams object
29    pub fn new() -> Result<SearchParams> {
30        unsafe {
31            let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
32            check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
33            Ok(SearchParams(params.assume_init()))
34        }
35    }
36
37    /// Maximum number of queries to search at the same time (batch size). Auto select when 0
38    pub fn set_max_queries(self, max_queries: usize) -> SearchParams {
39        unsafe {
40            (*self.0).max_queries = max_queries;
41        }
42        self
43    }
44
45    /// Number of intermediate search results retained during the search.
46    /// This is the main knob to adjust trade off between accuracy and search speed.
47    /// Higher values improve the search accuracy
48    pub fn set_itopk_size(self, itopk_size: usize) -> SearchParams {
49        unsafe {
50            (*self.0).itopk_size = itopk_size;
51        }
52        self
53    }
54
55    /// Upper limit of search iterations. Auto select when 0.
56    pub fn set_max_iterations(self, max_iterations: usize) -> SearchParams {
57        unsafe {
58            (*self.0).max_iterations = max_iterations;
59        }
60        self
61    }
62
63    /// Which search implementation to use.
64    pub fn set_algo(self, algo: SearchAlgo) -> SearchParams {
65        unsafe {
66            (*self.0).algo = algo;
67        }
68        self
69    }
70
71    /// Number of threads used to calculate a single distance. 4, 8, 16, or 32.
72    pub fn set_team_size(self, team_size: usize) -> SearchParams {
73        unsafe {
74            (*self.0).team_size = team_size;
75        }
76        self
77    }
78
79    /// Lower limit of search iterations.
80    pub fn set_min_iterations(self, min_iterations: usize) -> SearchParams {
81        unsafe {
82            (*self.0).min_iterations = min_iterations;
83        }
84        self
85    }
86
87    /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
88    pub fn set_thread_block_size(self, thread_block_size: usize) -> SearchParams {
89        unsafe {
90            (*self.0).thread_block_size = thread_block_size;
91        }
92        self
93    }
94
95    /// Hashmap type. Auto selection when AUTO.
96    pub fn set_hashmap_mode(self, hashmap_mode: HashMode) -> SearchParams {
97        unsafe {
98            (*self.0).hashmap_mode = hashmap_mode;
99        }
100        self
101    }
102
103    /// Lower limit of hashmap bit length. More than 8.
104    pub fn set_hashmap_min_bitlen(self, hashmap_min_bitlen: usize) -> SearchParams {
105        unsafe {
106            (*self.0).hashmap_min_bitlen = hashmap_min_bitlen;
107        }
108        self
109    }
110
111    /// Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
112    pub fn set_hashmap_max_fill_rate(self, hashmap_max_fill_rate: f32) -> SearchParams {
113        unsafe {
114            (*self.0).hashmap_max_fill_rate = hashmap_max_fill_rate;
115        }
116        self
117    }
118
119    /// Number of iterations of initial random seed node selection. 1 or more.
120    pub fn set_num_random_samplings(self, num_random_samplings: u32) -> SearchParams {
121        unsafe {
122            (*self.0).num_random_samplings = num_random_samplings;
123        }
124        self
125    }
126
127    /// Bit mask used for initial random seed node selection.
128    pub fn set_rand_xor_mask(self, rand_xor_mask: u64) -> SearchParams {
129        unsafe {
130            (*self.0).rand_xor_mask = rand_xor_mask;
131        }
132        self
133    }
134}
135
136impl fmt::Debug for SearchParams {
137    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
138        // custom debug trait here, default value will show the pointer address
139        // for the inner params object which isn't that useful.
140        write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
141    }
142}
143
144impl Drop for SearchParams {
145    fn drop(&mut self) {
146        if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.0) }) {
147            write!(
148                stderr(),
149                "failed to call cuvsCagraSearchParamsDestroy {:?}",
150                e
151            )
152            .expect("failed to write to stderr");
153        }
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// A value written through a builder setter must be readable back from
    /// the underlying parameter object.
    #[test]
    fn test_search_params() {
        let params = SearchParams::new().unwrap().set_itopk_size(128);

        // SAFETY: `params` was just created by `SearchParams::new` and is
        // still alive, so its handle can be dereferenced.
        let stored_itopk = unsafe { (*params.0).itopk_size };
        assert_eq!(stored_itopk, 128);
    }
}