hipvs/distance/mod.rs

/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17
18use crate::distance_type::DistanceType;
19use crate::dlpack::ManagedTensor;
20use crate::error::{check_cuvs, Result};
21use crate::resources::Resources;
22
23/// Compute pairwise distances between X and Y
24///
25/// # Arguments
26///
27/// * `res` - Resources to use
28/// * `x` - A matrix in device memory - shape (m, k)
29/// * `y` - A matrix in device memory - shape (n, k)
30/// * `distances` - A matrix in device memory that receives the output distances - shape (m, n)
31/// * `metric` - DistanceType to use for building the index
32/// * `metric_arg` - Optional value of `p` for Minkowski distances
33pub fn pairwise_distance(
34    res: &Resources,
35    x: &ManagedTensor,
36    y: &ManagedTensor,
37    distances: &ManagedTensor,
38    metric: DistanceType,
39    metric_arg: Option<f32>,
40) -> Result<()> {
41    unsafe {
42        check_cuvs(ffi::cuvsPairwiseDistance(
43            res.0,
44            x.as_ptr(),
45            y.as_ptr(),
46            distances.as_ptr(),
47            metric,
48            metric_arg.unwrap_or(2.0),
49        ))
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    // Tolerance for comparing device output to the host reference. The expanded L2 form
    // (||x||^2 + ||y||^2 - 2<x, y>) accumulates f32 rounding error, so exact equality —
    // even for the zero self-distance on the diagonal — is too strict.
    const TOLERANCE: f32 = 1e-3;

    #[test]
    fn test_pairwise_distance() {
        let res = Resources::new().unwrap();

        // Random dataset; it is compared against itself below, so x == y.
        let n_datapoints = 256;
        let n_features = 16;
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
        let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_datapoints, n_datapoints));
        let distances = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        pairwise_distance(
            &res,
            &dataset_device,
            &dataset_device,
            &distances,
            DistanceType::L2Expanded,
            None,
        )
        .unwrap();

        // Copy back to host memory
        distances.to_host(&res, &mut distances_host).unwrap();

        // Host reference: L2Expanded is the squared Euclidean distance (no square root).
        let squared_l2 = |i: usize, j: usize| -> f32 {
            dataset
                .row(i)
                .iter()
                .zip(dataset.row(j).iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum()
        };

        // Check every entry; the diagonal (self distance) has an expected value of 0.
        for i in 0..n_datapoints {
            for j in 0..n_datapoints {
                let expected = squared_l2(i, j);
                let actual = distances_host[[i, j]];
                assert!(
                    (actual - expected).abs() <= TOLERANCE * expected.max(1.0),
                    "distance[{}, {}] = {}, expected {}",
                    i,
                    j,
                    actual,
                    expected
                );
            }
        }
    }
}