hipvs/cagra/mod.rs
1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! CAGRA is a graph-based nearest-neighbors implementation with state-of-the-art
18//! query performance for both small- and large-batch searches.
19//!
20//! Example:
21//! ```
22//!
23//! use hipvs::cagra::{Index, IndexParams, SearchParams};
24//! use hipvs::{ManagedTensor, Resources, Result};
25//!
26//! use ndarray::s;
27//! use ndarray_rand::rand_distr::Uniform;
28//! use ndarray_rand::RandomExt;
29//!
30//! fn cagra_example() -> Result<()> {
31//! let res = Resources::new()?;
32//!
33//! // Create a new random dataset to index
34//! let n_datapoints = 65536;
35//! let n_features = 512;
36//! let dataset =
37//! ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
38//!
39//! // build the cagra index
40//! let build_params = IndexParams::new()?;
41//! let index = Index::build(&res, &build_params, &dataset)?;
42//! println!(
43//! "Indexed {}x{} datapoints into cagra index",
44//! n_datapoints, n_features
45//! );
46//!
47//! // use the first 4 points from the dataset as queries: this tests that we get each of
48//! // them back as its own nearest neighbor
49//! let n_queries = 4;
50//! let queries = dataset.slice(s![0..n_queries, ..]);
51//!
52//! let k = 10;
53//!
54//! // The CAGRA search API requires queries and outputs to be in device memory:
55//! // copy the query data over, and allocate new device memory for the distances/neighbors
56//! // outputs
57//! let queries = ManagedTensor::from(&queries).to_device(&res)?;
58//! let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
59//! let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?;
60//!
61//! let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
62//! let distances = ManagedTensor::from(&distances_host).to_device(&res)?;
63//!
64//! let search_params = SearchParams::new()?;
65//!
66//! index.search(&res, &search_params, &queries, &neighbors, &distances)?;
67//!
68//! // Copy back to host memory
69//! distances.to_host(&res, &mut distances_host)?;
70//! neighbors.to_host(&res, &mut neighbors_host)?;
71//!
72//! // nearest neighbors should be themselves, since queries are from the
73//! // dataset
74//! println!("Neighbors {:?}", neighbors_host);
75//! println!("Distances {:?}", distances_host);
76//! Ok(())
77//! }
78//! ```
79
// Implementation submodules. They are declared without `pub`, so they are not
// reachable by path from outside this module.
80mod index;
81mod index_params;
82mod search_params;
83
// Public surface of the module: the selected types are re-exported here so that
// callers can import them directly, e.g. `use hipvs::cagra::{Index, IndexParams, SearchParams};`.
84pub use index::Index;
85pub use index_params::{BuildAlgo, CompressionParams, IndexParams};
86pub use search_params::{HashMode, SearchAlgo, SearchParams};