hipvs/cagra/
search_params.rs1use crate::error::{check_cuvs, Result};
18use std::fmt;
19use std::io::{stderr, Write};
20
21pub type SearchAlgo = ffi::cuvsCagraSearchAlgo;
22pub type HashMode = ffi::cuvsCagraHashMode;
23
24pub struct SearchParams(pub ffi::cuvsCagraSearchParams_t);
26
27impl SearchParams {
28 pub fn new() -> Result<SearchParams> {
30 unsafe {
31 let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
32 check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
33 Ok(SearchParams(params.assume_init()))
34 }
35 }
36
37 pub fn set_max_queries(self, max_queries: usize) -> SearchParams {
39 unsafe {
40 (*self.0).max_queries = max_queries;
41 }
42 self
43 }
44
45 pub fn set_itopk_size(self, itopk_size: usize) -> SearchParams {
49 unsafe {
50 (*self.0).itopk_size = itopk_size;
51 }
52 self
53 }
54
55 pub fn set_max_iterations(self, max_iterations: usize) -> SearchParams {
57 unsafe {
58 (*self.0).max_iterations = max_iterations;
59 }
60 self
61 }
62
63 pub fn set_algo(self, algo: SearchAlgo) -> SearchParams {
65 unsafe {
66 (*self.0).algo = algo;
67 }
68 self
69 }
70
71 pub fn set_team_size(self, team_size: usize) -> SearchParams {
73 unsafe {
74 (*self.0).team_size = team_size;
75 }
76 self
77 }
78
79 pub fn set_min_iterations(self, min_iterations: usize) -> SearchParams {
81 unsafe {
82 (*self.0).min_iterations = min_iterations;
83 }
84 self
85 }
86
87 pub fn set_thread_block_size(self, thread_block_size: usize) -> SearchParams {
89 unsafe {
90 (*self.0).thread_block_size = thread_block_size;
91 }
92 self
93 }
94
95 pub fn set_hashmap_mode(self, hashmap_mode: HashMode) -> SearchParams {
97 unsafe {
98 (*self.0).hashmap_mode = hashmap_mode;
99 }
100 self
101 }
102
103 pub fn set_hashmap_min_bitlen(self, hashmap_min_bitlen: usize) -> SearchParams {
105 unsafe {
106 (*self.0).hashmap_min_bitlen = hashmap_min_bitlen;
107 }
108 self
109 }
110
111 pub fn set_hashmap_max_fill_rate(self, hashmap_max_fill_rate: f32) -> SearchParams {
113 unsafe {
114 (*self.0).hashmap_max_fill_rate = hashmap_max_fill_rate;
115 }
116 self
117 }
118
119 pub fn set_num_random_samplings(self, num_random_samplings: u32) -> SearchParams {
121 unsafe {
122 (*self.0).num_random_samplings = num_random_samplings;
123 }
124 self
125 }
126
127 pub fn set_rand_xor_mask(self, rand_xor_mask: u64) -> SearchParams {
129 unsafe {
130 (*self.0).rand_xor_mask = rand_xor_mask;
131 }
132 self
133 }
134}
135
136impl fmt::Debug for SearchParams {
137 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
138 write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
141 }
142}
143
144impl Drop for SearchParams {
145 fn drop(&mut self) {
146 if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.0) }) {
147 write!(
148 stderr(),
149 "failed to call cuvsCagraSearchParamsDestroy {:?}",
150 e
151 )
152 .expect("failed to write to stderr");
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// A value written through a builder setter must be readable back from
    /// the struct behind the handle.
    #[test]
    fn test_search_params() {
        let expected = 128;
        let params = SearchParams::new().unwrap().set_itopk_size(expected);

        // SAFETY: `params.0` was just produced by `SearchParams::new` and
        // `params` is still alive, so the handle has not been destroyed.
        let stored = unsafe { (*params.0).itopk_size };
        assert_eq!(stored, expected);
    }
}