hipvs/
resources.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::error::{check_cuvs, Result};
18use std::io::{stderr, Write};
19
/// Resources are objects that are shared between function calls,
/// and include things like CUDA streams, cuBLAS handles and other
/// resources that are expensive to create.
///
/// This is a thin wrapper around a raw `ffi::cuvsResources_t` handle. The
/// handle is obtained in [`Resources::new`] and released again when the
/// value is dropped.
///
/// NOTE(review): the inner handle is a public field, so callers can read or
/// replace it directly — confirm that nothing outside this type destroys it.
#[derive(Debug)]
pub struct Resources(pub ffi::cuvsResources_t);
25
26impl Resources {
27    /// Returns a new Resources object
28    pub fn new() -> Result<Resources> {
29        let mut res: ffi::cuvsResources_t = 0;
30        unsafe {
31            check_cuvs(ffi::cuvsResourcesCreate(&mut res))?;
32        }
33        Ok(Resources(res))
34    }
35
36    /// Sets the current cuda stream
37    pub fn set_cuda_stream(&self, stream: ffi::cudaStream_t) -> Result<()> {
38        unsafe { check_cuvs(ffi::cuvsStreamSet(self.0, stream)) }
39    }
40
41    /// Gets the current cuda stream
42    pub fn get_cuda_stream(&self) -> Result<ffi::cudaStream_t> {
43        unsafe {
44            let mut stream = std::mem::MaybeUninit::<ffi::cudaStream_t>::uninit();
45            check_cuvs(ffi::cuvsStreamGet(self.0, stream.as_mut_ptr()))?;
46            Ok(stream.assume_init())
47        }
48    }
49
50    /// Syncs the current cuda stream
51    pub fn sync_stream(&self) -> Result<()> {
52        unsafe { check_cuvs(ffi::cuvsStreamSync(self.0)) }
53    }
54}
55
56impl Drop for Resources {
57    fn drop(&mut self) {
58        unsafe {
59            if let Err(e) = check_cuvs(ffi::cuvsResourcesDestroy(self.0)) {
60                write!(stderr(), "failed to call cuvsResourcesDestroy {:?}", e)
61                    .expect("failed to write to stderr");
62            }
63        }
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a `Resources` handle must succeed. The handle is kept alive
    /// until the end of the test so that dropping it is exercised as well.
    #[test]
    fn test_resources_create() {
        // Previously the `Result` was discarded, so a failed create could
        // never fail this test.
        let _resources = Resources::new().expect("failed to create cuvs resources");
    }
}
75}