hipvs/distance/mod.rs
1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18use crate::distance_type::DistanceType;
19use crate::dlpack::ManagedTensor;
20use crate::error::{check_cuvs, Result};
21use crate::resources::Resources;
22
23/// Compute pairwise distances between X and Y
24///
25/// # Arguments
26///
27/// * `res` - Resources to use
28/// * `x` - A matrix in device memory - shape (m, k)
29/// * `y` - A matrix in device memory - shape (n, k)
30/// * `distances` - A matrix in device memory that receives the output distances - shape (m, n)
31/// * `metric` - DistanceType to use for building the index
32/// * `metric_arg` - Optional value of `p` for Minkowski distances
33pub fn pairwise_distance(
34 res: &Resources,
35 x: &ManagedTensor,
36 y: &ManagedTensor,
37 distances: &ManagedTensor,
38 metric: DistanceType,
39 metric_arg: Option<f32>,
40) -> Result<()> {
41 unsafe {
42 check_cuvs(ffi::cuvsPairwiseDistance(
43 res.0,
44 x.as_ptr(),
45 y.as_ptr(),
46 distances.as_ptr(),
47 metric,
48 metric_arg.unwrap_or(2.0),
49 ))
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Absolute tolerance for values expected to be (near) zero. The expanded L2 form
    /// computes ||x||^2 + ||y||^2 - 2<x, y> in f32, so cancellation can leave a small
    /// residual rather than an exact 0.0.
    const TOLERANCE: f32 = 1e-3;

    #[test]
    fn test_pairwise_distance() {
        let res = Resources::new().unwrap();

        // Random dataset whose rows are compared against themselves (X == Y).
        let n_datapoints = 256;
        let n_features = 16;
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
        let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_datapoints, n_datapoints));
        let distances = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        pairwise_distance(
            &res,
            &dataset_device,
            &dataset_device,
            &distances,
            DistanceType::L2Expanded,
            None,
        )
        .unwrap();

        // Copy back to host memory
        distances.to_host(&res, &mut distances_host).unwrap();

        for i in 0..n_datapoints {
            // Self distance should be (approximately) 0 for every row, not just the first.
            let self_dist = distances_host[[i, i]];
            assert!(
                self_dist.abs() <= TOLERANCE,
                "self distance d({i}, {i}) = {self_dist} is not ~0"
            );

            for j in 0..n_datapoints {
                let d_ij = distances_host[[i, j]];
                let d_ji = distances_host[[j, i]];
                // Distances are non-negative up to rounding error.
                assert!(d_ij >= -TOLERANCE, "d({i}, {j}) = {d_ij} is negative");
                // With X == Y the matrix must be symmetric; allow a relative slack
                // because the two entries may be accumulated in different orders.
                assert!(
                    (d_ij - d_ji).abs() <= TOLERANCE * (1.0 + d_ij.abs()),
                    "asymmetric: d({i}, {j}) = {d_ij}, d({j}, {i}) = {d_ji}"
                );
            }
        }
    }
}