/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-roccv/checkouts/latest/include/core/tensor_data.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-roccv/checkouts/latest/include/core/tensor_data.hpp Source File

4 min read time

Applies to Linux

rocCV: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-roccv/checkouts/latest/include/core/tensor_data.hpp Source File
tensor_data.hpp
Go to the documentation of this file.
1 
23 #pragma once
24 
25 #include <stdint.h>
26 
27 #include <iostream>
28 #include <optional>
29 
30 #include "data_type.hpp"
31 #include "tensor_buffer.hpp"
32 #include "tensor_shape.hpp"
33 
34 namespace roccv {
35 
/// Holds the underlying tensor data alongside metadata (shape, layout, datatype).
/// Instances cannot be default-constructed; the only constructor visible in this
/// listing is protected.
42 class TensorData {
43  public:
    /// Default construction is disabled.
44  TensorData() = delete;
    /// Virtual destructor, defaulted.
45  virtual ~TensorData() = default;
46 
    /// Returns the rank (the number of dimensions) of the tensor data.
52  virtual int rank() const;
53 
    /// Returns the shape of the tensor. Callable on lvalues only (const & qualified).
59  virtual const TensorShape &shape() const &;
60 
    /// Returns a 64-bit value for dimension index d.
    /// NOTE(review): presumably the extent of the shape at dimension d; the
    /// implementation is not visible here -- confirm.
67  virtual const int64_t shape(int d) const &;
68 
    /// Retrieves the data type of the tensor's elements.
74  virtual const DataType &dtype() const;
75 
    /// Returns the base pointer of the tensor data in memory.
81  virtual void *basePtr() const;
82 
    /// Retrieves the location where the tensor data is allocated, either on the
    /// device or the host.
89  virtual const eDeviceType device() const;
90 
    /// Builds a Derived object from this object's shape, data type, buffer and
    /// device type, and returns it wrapped in a std::optional.
    /// Compile-time requirements: Derived must derive from TensorData and must
    /// have the same size as TensorData (i.e. add no data members).
    /// As written, the returned optional always contains a value.
91  template <typename Derived>
92  std::optional<Derived> cast() {
93  static_assert(std::is_base_of<TensorData, Derived>::value, "Cannot cast TensorData to an unrelated type.");
94  static_assert(sizeof(Derived) == sizeof(TensorData), "Derived type must not add any additional data members.");
    // A new Derived is constructed from the stored members; this is not a pointer/reference cast.
95  return std::optional(Derived(m_shape, m_dtype, m_buffer, m_deviceType));
96  }
97 
98  protected:
    /// Protected constructor taking the shape, data type and strided buffer.
    /// NOTE: the rest of this declaration (listing line 100) and the member
    /// declarations (listing lines 102-105: m_shape, m_dtype, m_deviceType,
    /// m_buffer) are elided from this listing. The symbol index further down the
    /// page gives the full signature, ending in
    /// `const eDeviceType device=eDeviceType::GPU`.
99  TensorData(const TensorShape &tshape, const DataType &dtype, const TensorBufferStrided &buffer,
101 
106 };
107 
115  public:
126  TensorDataStrided(const TensorShape &tshape, const DataType &dtype, const TensorBufferStrided &buffer,
128 
135  const int64_t stride(int d) const;
136 };
137 } // namespace roccv
Supported data types for use with the Tensor utilities.
Definition: data_type.hpp:33
Holds the underlying tensor data alongside metadata (shape, layout, datatype). Non-strided tensor dat...
Definition: tensor_data.hpp:42
TensorShape m_shape
Definition: tensor_data.hpp:102
virtual int rank() const
Returns the rank (the number of dimensions) of the tensor data.
std::optional< Derived > cast()
Definition: tensor_data.hpp:92
virtual void * basePtr() const
Returns the base pointer of the tensor data in memory.
TensorData(const TensorShape &tshape, const DataType &dtype, const TensorBufferStrided &buffer, const eDeviceType device=eDeviceType::GPU)
virtual const eDeviceType device() const
Retrieves the location where the tensor data is allocated, either on the device or the host.
virtual const DataType & dtype() const
Retrieves the data type of the tensor's elements.
TensorBufferStrided m_buffer
Definition: tensor_data.hpp:105
virtual const TensorShape & shape() const &
Returns the shape of the tensor.
DataType m_dtype
Definition: tensor_data.hpp:103
virtual ~TensorData()=default
eDeviceType m_deviceType
Definition: tensor_data.hpp:104
Holds the underlying tensor data alongside tensor metadata. This particular tensor data type is used ...
Definition: tensor_data.hpp:114
const int64_t stride(int d) const
Returns the stride at a given dimension.
TensorDataStrided(const TensorShape &tshape, const DataType &dtype, const TensorBufferStrided &buffer, const eDeviceType device=eDeviceType::GPU)
Constructs a TensorDataStrided object.
TensorShape class.
Definition: tensor_shape.hpp:34
Definition: strided_data_wrap.hpp:33
A tensor buffer with strided data.
Definition: tensor_buffer.hpp:46
eDeviceType
Describes the device type. Used to determine where Tensor data should be allocated and whether operat...
Definition: util_enums.h:69