25 #include <hip/hip_runtime.h>
// binary_generic (thresholding_host.hpp): binary threshold over a batch of images.
// For each batch index in [0, maxBatchSize) a threshold and a max value are read
// from `thresh` / `maxVal`; every channel element of every pixel is then written
// as that max value when the input element is strictly greater than the
// threshold, and as 0 otherwise.
// NOTE(review): this listing omits the signature line, the `base_type` alias and
// the declaration of `outputVal` — confirm them against the header.
33 template <
typename SrcWrapper,
typename DstWrapper>
// Pixel value types exposed by the input and output wrappers.
37 using src_type =
typename SrcWrapper::ValueType;
38 using dst_type =
typename DstWrapper::ValueType;
// Outer loop: one iteration per batch entry.
41 for (
int z_idx = 0; z_idx < maxBatchSize; z_idx++) {
// Per-batch parameters, looked up once per batch entry.
42 double th = thresh.
at(z_idx);
43 double mv = maxVal.
at(z_idx);
// Row/column bounds come from the output wrapper; assumes the input has at
// least the same extent — TODO confirm with callers.
44 for (
int y_idx = 0; y_idx < output.height(); y_idx++) {
45 for (
int x_idx = 0; x_idx < output.width(); x_idx++) {
// Fetch the whole pixel (index 0 in the last coordinate); individual channel
// elements are extracted with GetElement below.
46 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
48 for (
int i = 0; i < output.channels(); i++) {
// The comparison is carried out in double regardless of the source element type.
49 double ip = StaticCast<double>(
GetElement(inputVal, i));
// Binary rule: strictly above threshold -> max value, else 0.
50 double outVal = ip > th ? mv : 0;
// Convert back to the destination element type.
51 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
// Store the assembled pixel at the same coordinates in the output.
53 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
// binary_inv_generic (thresholding_host.hpp): inverted binary threshold over a
// batch of images. For each batch index in [0, maxBatchSize) a threshold and a
// max value are read from `thresh` / `maxVal`; every channel element of every
// pixel is written as 0 when the input element is strictly greater than the
// threshold, and as the max value otherwise.
// NOTE(review): this listing omits the signature line, the `base_type` alias and
// the declaration of `outputVal` — confirm them against the header.
59 template <
typename SrcWrapper,
typename DstWrapper>
// Pixel value types exposed by the input and output wrappers.
63 using src_type =
typename SrcWrapper::ValueType;
64 using dst_type =
typename DstWrapper::ValueType;
// Outer loop: one iteration per batch entry.
67 for (
int z_idx = 0; z_idx < maxBatchSize; z_idx++) {
// Per-batch parameters, looked up once per batch entry.
68 double th = thresh.
at(z_idx);
69 double mv = maxVal.
at(z_idx);
// Row/column bounds come from the output wrapper; assumes the input has at
// least the same extent — TODO confirm with callers.
70 for (
int y_idx = 0; y_idx < output.height(); y_idx++) {
71 for (
int x_idx = 0; x_idx < output.width(); x_idx++) {
// Fetch the whole pixel (index 0 in the last coordinate); individual channel
// elements are extracted with GetElement below.
72 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
74 for (
int i = 0; i < output.channels(); i++) {
// The comparison is carried out in double regardless of the source element type.
75 double ip = StaticCast<double>(
GetElement(inputVal, i));
// Inverted binary rule: strictly above threshold -> 0, else max value.
76 double outVal = ip > th ? 0 : mv;
// Convert back to the destination element type.
77 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
// Store the assembled pixel at the same coordinates in the output.
79 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
// trunc_generic (thresholding_host.hpp): truncating threshold over a batch of
// images. For each batch index in [0, maxBatchSize) a threshold is read from
// `thresh`; every channel element of every pixel that is strictly greater than
// the threshold is replaced by the threshold, and all other elements are copied
// through unchanged (after the round trip through double).
// NOTE(review): this listing omits the signature line, the `base_type` alias and
// the declaration of `outputVal` — confirm them against the header.
85 template <
typename SrcWrapper,
typename DstWrapper>
// Pixel value types exposed by the input and output wrappers.
89 using src_type =
typename SrcWrapper::ValueType;
90 using dst_type =
typename DstWrapper::ValueType;
// Outer loop: one iteration per batch entry.
93 for (
int z_idx = 0; z_idx < maxBatchSize; z_idx++) {
// Per-batch threshold, looked up once per batch entry.
94 double th = thresh.
at(z_idx);
// Row/column bounds come from the output wrapper; assumes the input has at
// least the same extent — TODO confirm with callers.
95 for (
int y_idx = 0; y_idx < output.height(); y_idx++) {
96 for (
int x_idx = 0; x_idx < output.width(); x_idx++) {
// Fetch the whole pixel (index 0 in the last coordinate); individual channel
// elements are extracted with GetElement below.
97 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
99 for (
int i = 0; i < output.channels(); i++) {
// The comparison is carried out in double regardless of the source element type.
100 double ip = StaticCast<double>(
GetElement(inputVal, i));
// Truncate rule: values strictly above the threshold are capped at the threshold.
101 double outVal = ip > th ? th : ip;
// Convert back to the destination element type.
102 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
// Store the assembled pixel at the same coordinates in the output.
104 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
// tozero_generic (thresholding_host.hpp): to-zero threshold over a batch of
// images. For each batch index in [0, maxBatchSize) a threshold is read from
// `thresh`; every channel element of every pixel that is strictly greater than
// the threshold is kept, and all other elements are written as 0.
// NOTE(review): this listing omits the signature line, the `base_type` alias and
// the declaration of `outputVal` — confirm them against the header.
110 template <
typename SrcWrapper,
typename DstWrapper>
// Pixel value types exposed by the input and output wrappers.
114 using src_type =
typename SrcWrapper::ValueType;
115 using dst_type =
typename DstWrapper::ValueType;
// Outer loop: one iteration per batch entry.
118 for (
int z_idx = 0; z_idx < maxBatchSize; z_idx++) {
// Per-batch threshold, looked up once per batch entry.
119 double th = thresh.
at(z_idx);
// Row/column bounds come from the output wrapper; assumes the input has at
// least the same extent — TODO confirm with callers.
120 for (
int y_idx = 0; y_idx < output.height(); y_idx++) {
121 for (
int x_idx = 0; x_idx < output.width(); x_idx++) {
// Fetch the whole pixel (index 0 in the last coordinate); individual channel
// elements are extracted with GetElement below.
122 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
124 for (
int i = 0; i < output.channels(); i++) {
// The comparison is carried out in double regardless of the source element type.
125 double ip = StaticCast<double>(
GetElement(inputVal, i));
// To-zero rule: keep values strictly above the threshold, zero the rest.
126 double outVal = ip > th ? ip : 0;
// Convert back to the destination element type.
127 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
// Store the assembled pixel at the same coordinates in the output.
129 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
// tozeroinv_generic (thresholding_host.hpp): inverted to-zero threshold over a
// batch of images. For each batch index in [0, maxBatchSize) a threshold is read
// from `thresh`; every channel element of every pixel that is strictly greater
// than the threshold is written as 0, and all other elements are kept.
// NOTE(review): this listing omits the signature line, the `base_type` alias and
// the declaration of `outputVal` — confirm them against the header.
135 template <
typename SrcWrapper,
typename DstWrapper>
// Pixel value types exposed by the input and output wrappers.
139 using src_type =
typename SrcWrapper::ValueType;
140 using dst_type =
typename DstWrapper::ValueType;
// Outer loop: one iteration per batch entry.
143 for (
int z_idx = 0; z_idx < maxBatchSize; z_idx++) {
// Per-batch threshold, looked up once per batch entry.
144 double th = thresh.
at(z_idx);
// Row/column bounds come from the output wrapper; assumes the input has at
// least the same extent — TODO confirm with callers.
145 for (
int y_idx = 0; y_idx < output.height(); y_idx++) {
146 for (
int x_idx = 0; x_idx < output.width(); x_idx++) {
// Fetch the whole pixel (index 0 in the last coordinate); individual channel
// elements are extracted with GetElement below.
147 src_type inputVal = input.at(z_idx, y_idx, x_idx, 0);
149 for (
int i = 0; i < output.channels(); i++) {
// The comparison is carried out in double regardless of the source element type.
150 double ip = StaticCast<double>(
GetElement(inputVal, i));
// Inverted to-zero rule: zero values strictly above the threshold, keep the rest.
151 double outVal = ip > th ? 0 : ip;
// Convert back to the destination element type.
152 GetElement(outputVal, i) = StaticCast<base_type>(outVal);
// Store the assembled pixel at the same coordinates in the output.
154 output.at(z_idx, y_idx, x_idx, 0) = outputVal;
Definition: generic_tensor_wrapper.hpp:28
__device__ __host__ T & at(ARGS... idx)
Definition: generic_tensor_wrapper.hpp:48
void tozero_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, const int32_t maxBatchSize)
Definition: thresholding_host.hpp:111
void binary_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, roccv::GenericTensorWrapper< double > maxVal, const int32_t maxBatchSize)
Definition: thresholding_host.hpp:34
void trunc_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, const int32_t maxBatchSize)
Definition: thresholding_host.hpp:86
void binary_inv_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, roccv::GenericTensorWrapper< double > maxVal, const int32_t maxBatchSize)
Definition: thresholding_host.hpp:60
void tozeroinv_generic(SrcWrapper input, DstWrapper output, roccv::GenericTensorWrapper< double > thresh, const int32_t maxBatchSize)
Definition: thresholding_host.hpp:136
Definition: non_max_suppression_helpers.hpp:26
Definition: strided_data_wrap.hpp:33
typename TypeTraits< T >::base_type BaseType
Returns the base type of a given HIP vectorized type.
Definition: type_traits.hpp:117
__host__ __device__ RT & GetElement(T &v, int idx)
Definition: type_traits.hpp:128