use std::convert::From;

use crate::error::{check_cuda, check_cuvs, Result};
use crate::resources::Resources;

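/// `ManagedTensor` wraps a DLPack `DLManagedTensor`, allowing host and device
/// buffers to be passed across the cuVS C API. A minimal round-trip sketch
/// (illustrative only, not compiled as a doc-test: it assumes a CUDA-capable
/// GPU and this crate's `Resources` handle):
///
/// ```ignore
/// let res = Resources::new()?;
/// let host = ndarray::Array::<f32, _>::zeros((8, 4));
/// let view = ManagedTensor::from(&host);  // non-owning host view
/// let device = view.to_device(&res)?;     // RMM-backed device copy
/// let mut out = ndarray::Array::<f32, _>::zeros((8, 4));
/// device.to_host(&res, &mut out)?;        // async copy back to host
/// ```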
#[derive(Debug)]
pub struct ManagedTensor(ffi::DLManagedTensor);

/// Maps a Rust element type to its DLPack `DLDataType` descriptor.
pub trait IntoDtype {
    fn ffi_dtype() -> ffi::DLDataType;
}

impl ManagedTensor {
    /// Returns a mutable pointer to the underlying `DLManagedTensor`, as
    /// expected by the cuVS C API.
    pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor {
        &self.0 as *const _ as *mut _
    }

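    /// Copies this tensor into device memory, returning a new `ManagedTensor`
    /// whose buffer is allocated through RMM and handed back to RMM on drop.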
    pub fn to_device(&self, res: &Resources) -> Result<ManagedTensor> {
        unsafe {
            let bytes = dl_tensor_bytes(&self.0.dl_tensor);
            let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();

            // allocate a device buffer of the same size through RMM
            check_cuvs(ffi::cuvsRMMAlloc(res.0, &mut device_data as *mut _, bytes))?;

            // asynchronously copy the current contents on the resources' stream
            check_cuda(ffi::cudaMemcpyAsync(
                device_data,
                self.0.dl_tensor.data,
                bytes,
                ffi::cudaMemcpyKind_cudaMemcpyDefault,
                res.get_cuda_stream()?,
            ))?;

            // the copy keeps the original metadata but owns its device buffer;
            // rmm_free_tensor releases the allocation when it is dropped
            let mut ret = self.0.clone();
            ret.dl_tensor.data = device_data;
            ret.deleter = Some(rmm_free_tensor);
            ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA;

            Ok(ManagedTensor(ret))
        }
    }

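    /// Copies this tensor into the supplied host `ndarray`. The destination
    /// must already have the correct shape and a contiguous layout; no bounds
    /// checking is performed here.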
    pub fn to_host<
        T: IntoDtype,
        S: ndarray::RawData<Elem = T> + ndarray::RawDataMut,
        D: ndarray::Dimension,
    >(
        &self,
        res: &Resources,
        arr: &mut ndarray::ArrayBase<S, D>,
    ) -> Result<()> {
        unsafe {
            let bytes = dl_tensor_bytes(&self.0.dl_tensor);
            check_cuda(ffi::cudaMemcpyAsync(
                arr.as_mut_ptr() as *mut std::ffi::c_void,
                self.0.dl_tensor.data,
                bytes,
                ffi::cudaMemcpyKind_cudaMemcpyDefault,
                res.get_cuda_stream()?,
            ))?;
            Ok(())
        }
    }
}

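// Computes the tensor's size in bytes: the product of its dimensions times the
// element width. Note this assumes a byte-aligned dtype (bits % 8 == 0) and a
// single lane.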
fn dl_tensor_bytes(tensor: &ffi::DLTensor) -> usize {
    let mut bytes: usize = 1;
    for dim in 0..tensor.ndim {
        bytes *= unsafe { (*tensor.shape.add(dim as usize)) as usize };
    }
    bytes *= (tensor.dtype.bits / 8) as usize;
    bytes
}

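// DLPack deleter used for device tensors: returns the allocation to RMM. A
// fresh `Resources` is created here because the deleter callback receives no
// context to thread an existing one through.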
unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) {
    let bytes = dl_tensor_bytes(&(*self_).dl_tensor);
    let res = Resources::new().unwrap();
    let _ = ffi::cuvsRMMFree(res.0, (*self_).dl_tensor.data as *mut _, bytes);
}

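/// Creates a non-owning DLPack view over a host `ndarray`: the tensor borrows
/// both the array's data and its shape, so the array must outlive the view.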
impl<T: IntoDtype, S: ndarray::RawData<Elem = T>, D: ndarray::Dimension>
    From<&ndarray::ArrayBase<S, D>> for ManagedTensor
{
    fn from(arr: &ndarray::ArrayBase<S, D>) -> Self {
        unsafe {
            let mut ret = std::mem::MaybeUninit::<ffi::DLTensor>::uninit();
            let tensor = ret.as_mut_ptr();
            (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void;
            (*tensor).device = ffi::DLDevice {
                device_type: ffi::DLDeviceType::kDLCPU,
                device_id: 0,
            };
            (*tensor).byte_offset = 0;
            // null strides signals a compact row-major layout in DLPack
            (*tensor).strides = std::ptr::null_mut();
            (*tensor).ndim = arr.ndim() as i32;
            (*tensor).shape = arr.shape().as_ptr() as *mut _;
            (*tensor).dtype = T::ffi_dtype();
            // no deleter: this tensor does not own its memory
            ManagedTensor(ffi::DLManagedTensor {
                dl_tensor: ret.assume_init(),
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
            })
        }
    }
}

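// Run the DLPack deleter, if any, when the wrapper is dropped. Device copies
// free their RMM allocation here; non-owning host views have no deleter.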
impl Drop for ManagedTensor {
    fn drop(&mut self) {
        unsafe {
            if let Some(deleter) = self.0.deleter {
                deleter(&mut self.0 as *mut _);
            }
        }
    }
}

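// Mappings from Rust scalar types to DLPack dtype descriptors.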
impl IntoDtype for f32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLFloat as _,
            bits: 32,
            lanes: 1,
        }
    }
}

impl IntoDtype for f64 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLFloat as _,
            bits: 64,
            lanes: 1,
        }
    }
}

impl IntoDtype for i32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLInt as _,
            bits: 32,
            lanes: 1,
        }
    }
}

impl IntoDtype for i64 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLInt as _,
            bits: 64,
            lanes: 1,
        }
    }
}

impl IntoDtype for u32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLUInt as _,
            bits: 32,
            lanes: 1,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_ndarray() {
        let arr = ndarray::Array::<f32, _>::zeros((8, 4));

        // bind the ManagedTensor so the pointer returned by as_ptr() stays
        // valid while the DLTensor is inspected
        let managed = ManagedTensor::from(&arr);
        let tensor = unsafe { (*managed.as_ptr()).dl_tensor };

        assert_eq!(tensor.ndim, 2);

        assert_eq!(unsafe { *tensor.shape }, 8);
        assert_eq!(unsafe { *tensor.shape.add(1) }, 4);
    }
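
    // A sketch of a direct check on the dl_tensor_bytes helper: an 8 x 4
    // array of f32 should come out to 8 * 4 * 4 = 128 bytes.
    #[test]
    fn test_dl_tensor_bytes() {
        let arr = ndarray::Array::<f32, _>::zeros((8, 4));
        let managed = ManagedTensor::from(&arr);
        let bytes = dl_tensor_bytes(unsafe { &(*managed.as_ptr()).dl_tensor });
        assert_eq!(bytes, 8 * 4 * std::mem::size_of::<f32>());
    }

    // The dtype mapping is pure and host-only, so it can be verified without
    // a GPU; the cast mirrors the one used in the impls above.
    #[test]
    fn test_ffi_dtype() {
        let dt = <f32 as IntoDtype>::ffi_dtype();
        assert_eq!(dt.code, ffi::DLDataTypeCode::kDLFloat as _);
        assert_eq!(dt.bits, 32);
        assert_eq!(dt.lanes, 1);
    }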
}