hipvs/
error.rs

1/*
2 * Copyright (c) 2024, NVIDIA CORPORATION.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use std::fmt;
18
19#[derive(Debug, Clone)]
20pub struct CuvsError {
21    code: ffi::cuvsError_t,
22    text: String,
23}
24
25#[derive(Debug, Clone)]
26pub enum Error {
27    CudaError(ffi::cudaError_t),
28    CuvsError(CuvsError),
29}
30
31impl std::error::Error for Error {}
32impl std::error::Error for CuvsError {}
33
34pub type Result<T> = std::result::Result<T, Error>;
35
36impl fmt::Display for Error {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        match self {
39            Error::CudaError(cuda_error) => write!(f, "cudaError={:?}", cuda_error),
40            Error::CuvsError(cuvs_error) => write!(f, "cuvsError={:?}", cuvs_error),
41        }
42    }
43}
44
45impl fmt::Display for CuvsError {
46    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47        write!(f, "{:?}:{}", self.code, self.text)
48    }
49}
50
51/// Simple wrapper to convert a cuvsError_t into a Result
52pub fn check_cuvs(err: ffi::cuvsError_t) -> Result<()> {
53    match err {
54        ffi::cuvsError_t::CUVS_SUCCESS => Ok(()),
55        _ => {
56            // get a description of the error from cuvs
57            let cstr = unsafe {
58                let text_ptr = ffi::cuvsGetLastErrorText();
59                std::ffi::CStr::from_ptr(text_ptr)
60            };
61            let text = std::string::String::from_utf8_lossy(cstr.to_bytes()).to_string();
62
63            Err(Error::CuvsError(CuvsError { code: err, text }))
64        }
65    }
66}
67
68pub fn check_cuda(err: ffi::cudaError_t) -> Result<()> {
69    match err {
70        ffi::cudaError::cudaSuccess => Ok(()),
71        _ => Err(Error::CudaError(err)),
72    }
73}