hipvs/ivf_flat/search_params.rs

/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use crate::error::{check_cuvs, Result};
18use std::fmt;
19use std::io::{stderr, Write};
20
21/// Supplemental parameters to search IvfFlat index
22pub struct SearchParams(pub ffi::cuvsIvfFlatSearchParams_t);
23
24impl SearchParams {
25    /// Returns a new SearchParams object
26    pub fn new() -> Result<SearchParams> {
27        unsafe {
28            let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfFlatSearchParams_t>::uninit();
29            check_cuvs(ffi::cuvsIvfFlatSearchParamsCreate(params.as_mut_ptr()))?;
30            Ok(SearchParams(params.assume_init()))
31        }
32    }
33
34    /// Supplemental parameters to search IVF-Flat index
35    pub fn set_n_probes(self, n_probes: u32) -> SearchParams {
36        unsafe {
37            (*self.0).n_probes = n_probes;
38        }
39        self
40    }
41}
42
43impl fmt::Debug for SearchParams {
44    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45        // custom debug trait here, default value will show the pointer address
46        // for the inner params object which isn't that useful.
47        write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
48    }
49}
50
51impl Drop for SearchParams {
52    fn drop(&mut self) {
53        if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfFlatSearchParamsDestroy(self.0) }) {
54            write!(
55                stderr(),
56                "failed to call cuvsIvfFlatSearchParamsDestroy {:?}",
57                e
58            )
59            .expect("failed to write to stderr");
60        }
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    /// `set_n_probes` must write its value through to the underlying FFI struct.
    #[test]
    fn test_search_params() {
        let search_params = SearchParams::new().unwrap();
        let search_params = search_params.set_n_probes(128);

        // SAFETY: the handle was just created by `SearchParams::new` and is
        // still owned by `search_params`.
        let n_probes = unsafe { (*search_params.0).n_probes };
        assert_eq!(n_probes, 128);
    }
}
76}