hipvs/ivf_flat/
index_params.rs1use crate::error::{check_cuvs, Result};
18use crate::distance_type::DistanceType;
19use std::fmt;
20use std::io::{stderr, Write};
21
22pub struct IndexParams(pub ffi::cuvsIvfFlatIndexParams_t);
23
24impl IndexParams {
25 pub fn new() -> Result<IndexParams> {
27 unsafe {
28 let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfFlatIndexParams_t>::uninit();
29 check_cuvs(ffi::cuvsIvfFlatIndexParamsCreate(params.as_mut_ptr()))?;
30 Ok(IndexParams(params.assume_init()))
31 }
32 }
33
34 pub fn set_n_lists(self, n_lists: u32) -> IndexParams {
36 unsafe {
37 (*self.0).n_lists = n_lists;
38 }
39 self
40 }
41
42 pub fn set_metric(self, metric: DistanceType) -> IndexParams {
44 unsafe {
45 (*self.0).metric = metric;
46 }
47 self
48 }
49
50 pub fn set_metric_arg(self, metric_arg: f32) -> IndexParams {
52 unsafe {
53 (*self.0).metric_arg = metric_arg;
54 }
55 self
56 }
57 pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> IndexParams {
59 unsafe {
60 (*self.0).kmeans_n_iters = kmeans_n_iters;
61 }
62 self
63 }
64
65 pub fn set_kmeans_trainset_fraction(self, kmeans_trainset_fraction: f64) -> IndexParams {
69 unsafe {
70 (*self.0).kmeans_trainset_fraction = kmeans_trainset_fraction;
71 }
72 self
73 }
74
75 pub fn set_add_data_on_build(self, add_data_on_build: bool) -> IndexParams {
80 unsafe {
81 (*self.0).add_data_on_build = add_data_on_build;
82 }
83 self
84 }
85}
86
87impl fmt::Debug for IndexParams {
88 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
89 write!(f, "IndexParams({:?})", unsafe { *self.0 })
92 }
93}
94
95impl Drop for IndexParams {
96 fn drop(&mut self) {
97 if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfFlatIndexParamsDestroy(self.0) }) {
98 write!(
99 stderr(),
100 "failed to call cuvsIvfFlatIndexParamsDestroy {:?}",
101 e
102 )
103 .expect("failed to write to stderr");
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `n_lists` and `add_data_on_build` are written through
    /// to the underlying parameter struct.
    #[test]
    fn test_index_params() {
        let params = IndexParams::new()
            .expect("cuvsIvfFlatIndexParamsCreate failed")
            .set_n_lists(128)
            .set_add_data_on_build(false);

        // SAFETY: `params` owns the live handle created by `IndexParams::new`.
        let (n_lists, add_data_on_build) =
            unsafe { ((*params.0).n_lists, (*params.0).add_data_on_build) };
        assert_eq!(n_lists, 128);
        assert!(!add_data_on_build);
    }

    /// Checks that the numeric setters not covered above also reach the
    /// underlying parameter struct.
    #[test]
    fn test_numeric_setters() {
        let params = IndexParams::new()
            .expect("cuvsIvfFlatIndexParamsCreate failed")
            .set_metric_arg(3.0)
            .set_kmeans_n_iters(7)
            .set_kmeans_trainset_fraction(0.25);

        // SAFETY: `params` owns the live handle created by `IndexParams::new`.
        let (metric_arg, kmeans_n_iters, trainset_fraction) = unsafe {
            (
                (*params.0).metric_arg,
                (*params.0).kmeans_n_iters,
                (*params.0).kmeans_trainset_fraction,
            )
        };
        // 3.0 and 0.25 are exactly representable, so exact comparison is sound.
        assert_eq!(metric_arg, 3.0);
        assert_eq!(kmeans_n_iters, 7);
        assert_eq!(trainset_fraction, 0.25);
    }
}