hipvs/
dlpack.rs

/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use std::convert::From;

use crate::error::{check_cuda, check_cuvs, Result};
use crate::resources::Resources;

/// ManagedTensor is a wrapper around a dlpack DLManagedTensor object.
/// This lets you pass matrices in device or host memory into cuvs.
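///
/// A minimal sketch of wrapping a host ndarray as a non-owning view (marked
/// `ignore` since it is illustrative only):
///
/// ```ignore
/// let arr = ndarray::Array::<f32, _>::zeros((8, 4));
/// let view = ManagedTensor::from(&arr);
/// ```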
#[derive(Debug)]
pub struct ManagedTensor(ffi::DLManagedTensor);

/// Maps a Rust element type to the corresponding dlpack DLDataType.
pub trait IntoDtype {
    fn ffi_dtype() -> ffi::DLDataType;
}

impl ManagedTensor {
    /// Returns a raw pointer to the wrapped DLManagedTensor, suitable for
    /// passing to the cuvs C API.
    pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor {
        &self.0 as *const _ as *mut _
    }

    /// Creates a new ManagedTensor on the current GPU device, and copies
    /// the data into it.
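    ///
    /// A minimal sketch of the intended use (marked `ignore` since it needs a
    /// GPU and a live `Resources` handle):
    ///
    /// ```ignore
    /// let arr = ndarray::Array::<f32, _>::zeros((8, 4));
    /// let res = Resources::new()?;
    /// let device_tensor = ManagedTensor::from(&arr).to_device(&res)?;
    /// ```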
    pub fn to_device(&self, res: &Resources) -> Result<ManagedTensor> {
        unsafe {
            let bytes = dl_tensor_bytes(&self.0.dl_tensor);
            let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();

            // allocate device storage through RMM, then copy the data over
            check_cuvs(ffi::cuvsRMMAlloc(res.0, &mut device_data as *mut _, bytes))?;

            check_cuda(ffi::cudaMemcpyAsync(
                device_data,
                self.0.dl_tensor.data,
                bytes,
                ffi::cudaMemcpyKind_cudaMemcpyDefault,
                res.get_cuda_stream()?,
            ))?;

            let mut ret = self.0.clone();
            ret.dl_tensor.data = device_data;
            ret.deleter = Some(rmm_free_tensor);
            ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA;

            Ok(ManagedTensor(ret))
        }
    }

    /// Copies data from device memory into host memory
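    ///
    /// A sketch of copying back to the host (marked `ignore` since it needs a
    /// GPU; `device_tensor` and `res` are assumed to come from `to_device`
    /// above):
    ///
    /// ```ignore
    /// let mut out = ndarray::Array::<f32, _>::zeros((8, 4));
    /// device_tensor.to_host(&res, &mut out)?;
    /// ```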
    pub fn to_host<
        T: IntoDtype,
        S: ndarray::RawData<Elem = T> + ndarray::RawDataMut,
        D: ndarray::Dimension,
    >(
        &self,
        res: &Resources,
        arr: &mut ndarray::ArrayBase<S, D>,
    ) -> Result<()> {
        unsafe {
            let bytes = dl_tensor_bytes(&self.0.dl_tensor);
            check_cuda(ffi::cudaMemcpyAsync(
                arr.as_mut_ptr() as *mut std::ffi::c_void,
                self.0.dl_tensor.data,
                bytes,
                ffi::cudaMemcpyKind_cudaMemcpyDefault,
                res.get_cuda_stream()?,
            ))?;
            Ok(())
        }
    }
}

/// Figures out how many bytes are in a DLTensor
fn dl_tensor_bytes(tensor: &ffi::DLTensor) -> usize {
    let mut bytes: usize = 1;
    for dim in 0..tensor.ndim {
        bytes *= unsafe { (*tensor.shape.add(dim as usize)) as usize };
    }
    bytes *= (tensor.dtype.bits / 8) as usize;
    bytes
}

/// Deleter used for tensors allocated by `to_device`: frees the device
/// memory through RMM.
unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) {
    let bytes = dl_tensor_bytes(&(*self_).dl_tensor);
    let res = Resources::new().unwrap();
    let _ = ffi::cuvsRMMFree(res.0, (*self_).dl_tensor.data as *mut _, bytes);
}

/// Create a non-owning view of a Tensor from an ndarray
impl<T: IntoDtype, S: ndarray::RawData<Elem = T>, D: ndarray::Dimension>
    From<&ndarray::ArrayBase<S, D>> for ManagedTensor
{
    fn from(arr: &ndarray::ArrayBase<S, D>) -> Self {
        // There is a draft PR for creating dlpack directly from ndarray
        // (https://github.com/rust-ndarray/ndarray/pull/1306/files), but until
        // it's merged we have to implement this ourselves.
        unsafe {
            let mut ret = std::mem::MaybeUninit::<ffi::DLTensor>::uninit();
            let tensor = ret.as_mut_ptr();
            (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void;
            (*tensor).device = ffi::DLDevice {
                device_type: ffi::DLDeviceType::kDLCPU,
                device_id: 0,
            };
            (*tensor).byte_offset = 0;
            (*tensor).strides = std::ptr::null_mut(); // TODO: error if not row-major
            (*tensor).ndim = arr.ndim() as i32;
            (*tensor).shape = arr.shape().as_ptr() as *mut _;
            (*tensor).dtype = T::ffi_dtype();
            ManagedTensor(ffi::DLManagedTensor {
                dl_tensor: ret.assume_init(),
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
            })
        }
    }
}

impl Drop for ManagedTensor {
    fn drop(&mut self) {
        unsafe {
            if let Some(deleter) = self.0.deleter {
                deleter(&mut self.0 as *mut _);
            }
        }
    }
}

impl IntoDtype for f32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLFloat as _,
            bits: 32,
            lanes: 1,
        }
    }
}

impl IntoDtype for f64 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLFloat as _,
            bits: 64,
            lanes: 1,
        }
    }
}

impl IntoDtype for i32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLInt as _,
            bits: 32,
            lanes: 1,
        }
    }
}

impl IntoDtype for i64 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLInt as _,
            bits: 64,
            lanes: 1,
        }
    }
}

impl IntoDtype for u32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLUInt as _,
            bits: 32,
            lanes: 1,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_ndarray() {
        let arr = ndarray::Array::<f32, _>::zeros((8, 4));

        let tensor = unsafe { (*(ManagedTensor::from(&arr).as_ptr())).dl_tensor };

        assert_eq!(tensor.ndim, 2);

        // make sure we can get the shape ok
        assert_eq!(unsafe { *tensor.shape }, 8);
        assert_eq!(unsafe { *tensor.shape.add(1) }, 4);
    }
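
    // A small additional check (a sketch, not part of the original suite):
    // dl_tensor_bytes should agree with the array's element count times the
    // element size for a contiguous f32 array. This runs entirely on the host.
    #[test]
    fn test_dl_tensor_bytes() {
        let arr = ndarray::Array::<f32, _>::zeros((8, 4));
        let tensor = ManagedTensor::from(&arr);

        let bytes = unsafe { dl_tensor_bytes(&(*tensor.as_ptr()).dl_tensor) };

        // 8 * 4 elements, 4 bytes per f32
        assert_eq!(bytes, 8 * 4 * std::mem::size_of::<f32>());
    }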
}