hipvs/ivf_pq/search_params.rs
1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::error::{check_cuvs, Result};
18use std::fmt;
19use std::io::{stderr, Write};
20
21pub use ffi::cudaDataType_t;
22
23/// Supplemental parameters to search IvfPq index
24pub struct SearchParams(pub ffi::cuvsIvfPqSearchParams_t);
25
26impl SearchParams {
27 /// Returns a new SearchParams object
28 pub fn new() -> Result<SearchParams> {
29 unsafe {
30 let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfPqSearchParams_t>::uninit();
31 check_cuvs(ffi::cuvsIvfPqSearchParamsCreate(params.as_mut_ptr()))?;
32 Ok(SearchParams(params.assume_init()))
33 }
34 }
35
36 /// The number of clusters to search.
37 pub fn set_n_probes(self, n_probes: u32) -> SearchParams {
38 unsafe {
39 (*self.0).n_probes = n_probes;
40 }
41 self
42 }
43
44 /// Data type of look up table to be created dynamically at search
45 /// time. The use of low-precision types reduces the amount of shared
46 /// memory required at search time, so fast shared memory kernels can
47 /// be used even for datasets with large dimansionality. Note that
48 /// the recall is slightly degraded when low-precision type is
49 /// selected.
50 pub fn set_lut_dtype(self, lut_dtype: cudaDataType_t) -> SearchParams {
51 unsafe {
52 (*self.0).lut_dtype = lut_dtype;
53 }
54 self
55 }
56
57 /// Storage data type for distance/similarity computation.
58 pub fn set_internal_distance_dtype(
59 self,
60 internal_distance_dtype: cudaDataType_t,
61 ) -> SearchParams {
62 unsafe {
63 (*self.0).internal_distance_dtype = internal_distance_dtype;
64 }
65 self
66 }
67}
68
69impl fmt::Debug for SearchParams {
70 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71 // custom debug trait here, default value will show the pointer address
72 // for the inner params object which isn't that useful.
73 write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
74 }
75}
76
77impl Drop for SearchParams {
78 fn drop(&mut self) {
79 if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfPqSearchParamsDestroy(self.0) }) {
80 write!(
81 stderr(),
82 "failed to call cuvsIvfPqSearchParamsDestroy {:?}",
83 e
84 )
85 .expect("failed to write to stderr");
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates search params with a non-default probe count and verifies the
    /// value is written through to the underlying C struct.
    #[test]
    fn test_search_params() {
        let expected_probes: u32 = 128;
        let params = SearchParams::new().unwrap().set_n_probes(expected_probes);

        // SAFETY: `params.0` is the live handle created by `SearchParams::new` above.
        let actual_probes = unsafe { (*params.0).n_probes };
        assert_eq!(actual_probes, expected_probes);
    }
}