25 #include <hip/hip_runtime.h>
33 template <
typename SrcWrapper,
typename DstWrapper>
37 const int32_t maxBatchSize) {
39 const auto x_idx = threadIdx.x + blockIdx.x * blockDim.x;
40 const auto y_idx = threadIdx.y + blockIdx.y * blockDim.y;
41 const auto z_idx = threadIdx.z + blockIdx.z * blockDim.z;
43 using src_type =
typename SrcWrapper::ValueType;
44 using dst_type =
typename DstWrapper::ValueType;
47 if (x_idx < output.width() && y_idx < output.height() && z_idx < maxBatchSize) {
48 double th = thresh.
at(z_idx);
49 double mv = maxVal.
at(z_idx);
50 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
53 for (
int i = 0; i < output.channels(); i++) {
54 double ip = StaticCast<double>(
GetElement(inputVal, i));
55 double outVal = ip > th ? mv : 0;
56 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
58 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
62 template <
typename SrcWrapper,
typename DstWrapper>
66 const int32_t maxBatchSize) {
68 const auto x_idx = threadIdx.x + blockIdx.x * blockDim.x;
69 const auto y_idx = threadIdx.y + blockIdx.y * blockDim.y;
70 const auto z_idx = threadIdx.z + blockIdx.z * blockDim.z;
72 using src_type =
typename SrcWrapper::ValueType;
73 using dst_type =
typename DstWrapper::ValueType;
76 if (x_idx < output.width() && y_idx < output.height() && z_idx < maxBatchSize) {
77 double th = thresh.
at(z_idx);
78 double mv = maxVal.
at(z_idx);
79 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
82 for (
int i = 0; i < output.channels(); i++) {
83 double ip = StaticCast<double>(
GetElement(inputVal, i));
84 double outVal = ip > th ? 0 : mv;
85 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
87 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
91 template <
typename SrcWrapper,
typename DstWrapper>
94 const int32_t maxBatchSize) {
96 const auto x_idx = threadIdx.x + blockIdx.x * blockDim.x;
97 const auto y_idx = threadIdx.y + blockIdx.y * blockDim.y;
98 const auto z_idx = threadIdx.z + blockIdx.z * blockDim.z;
100 using src_type =
typename SrcWrapper::ValueType;
101 using dst_type =
typename DstWrapper::ValueType;
104 if (x_idx < output.width() && y_idx < output.height() && z_idx < maxBatchSize) {
105 double th = thresh.
at(z_idx);
106 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
109 for (
int i = 0; i < output.channels(); i++) {
110 double ip = StaticCast<double>(
GetElement(inputVal, i));
111 double outVal = ip > th ? th : ip;
112 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
114 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
118 template <
typename SrcWrapper,
typename DstWrapper>
121 const int32_t maxBatchSize) {
123 const auto x_idx = threadIdx.x + blockIdx.x * blockDim.x;
124 const auto y_idx = threadIdx.y + blockIdx.y * blockDim.y;
125 const auto z_idx = threadIdx.z + blockIdx.z * blockDim.z;
127 using src_type =
typename SrcWrapper::ValueType;
128 using dst_type =
typename DstWrapper::ValueType;
131 if (x_idx < output.width() && y_idx < output.height() && z_idx < maxBatchSize) {
132 double th = thresh.
at(z_idx);
133 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
136 for (
int i = 0; i < output.channels(); i++) {
137 double ip = StaticCast<double>(
GetElement(inputVal, i));
138 double outVal = ip > th ? ip : 0;
139 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
141 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
145 template <
typename SrcWrapper,
typename DstWrapper>
148 const int32_t maxBatchSize) {
150 const auto x_idx = threadIdx.x + blockIdx.x * blockDim.x;
151 const auto y_idx = threadIdx.y + blockIdx.y * blockDim.y;
152 const auto z_idx = threadIdx.z + blockIdx.z * blockDim.z;
154 using src_type =
typename SrcWrapper::ValueType;
155 using dst_type =
typename DstWrapper::ValueType;
158 if (x_idx >= output.width() || y_idx >= output.height())
return;
160 double th = thresh.
at(z_idx);
161 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
164 for (
int i = 0; i < output.channels(); i++) {
165 double ip = StaticCast<double>(
GetElement(inputVal, i));
166 double outVal = ip > th ? 0 : ip;
167 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
169 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
Definition: generic_tensor_wrapper.hpp:28
__device__ __host__ T & at(ARGS... idx)
Definition: generic_tensor_wrapper.hpp:48
__global__ void tozero_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, const int32_t maxBatchSize)
Definition: thresholding_device.hpp:119
__global__ void binary_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, roccv::GenericTensorWrapper< double > maxVal, const int32_t maxBatchSize)
Definition: thresholding_device.hpp:34
__global__ void tozeroinv_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, const int32_t maxBatchSize)
Definition: thresholding_device.hpp:146
__global__ void binary_inv_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, roccv::GenericTensorWrapper< double > maxVal, const int32_t maxBatchSize)
Definition: thresholding_device.hpp:63
__global__ void trunc_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, const int32_t maxBatchSize)
Definition: thresholding_device.hpp:92
Definition: non_max_suppression_helpers.hpp:26
Definition: strided_data_wrap.hpp:33
typename TypeTraits< T >::base_type BaseType
Returns the base type of a given HIP vectorized type.
Definition: type_traits.hpp:117
__host__ __device__ RT & GetElement(T &v, int idx)
Definition: type_traits.hpp:128