13 #include <type_traits> 
   26 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
   37     static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
 
   38                       is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
 
   39                       is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
 
   40                       is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
 
   41                   "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
 
   42     double compute_error = 0;
 
   43     if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
 
   44                  is_same_v<ComputeDataType, int>)
 
   53     static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
 
   54                       is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
 
   55                       is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
 
   56                       is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
 
   57                   "Warning: Unhandled OutDataType for setting up the relative threshold!");
 
   58     double output_error = 0;
 
   59     if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
 
   60                  is_same_v<OutDataType, int>)
 
   68     double midway_error = 
std::max(compute_error, output_error);
 
   70     static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
 
   71                       is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
 
   72                       is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
 
   73                       is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
 
   74                   "Warning: Unhandled AccDataType for setting up the relative threshold!");
 
   76     if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
 
   77                  is_same_v<AccDataType, int>)
 
   85     return std::max(acc_error, midway_error);
 
   88 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
   99     static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
 
  100                       is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
 
  101                       is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
 
  102                       is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
 
  103                   "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
 
  104     auto expo            = std::log2(std::abs(max_possible_num));
 
  105     double compute_error = 0;
 
  106     if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
 
  107                  is_same_v<ComputeDataType, int>)
 
  116     static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
 
  117                       is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
 
  118                       is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
 
  119                       is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
 
  120                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");
 
  121     double output_error = 0;
 
  122     if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
 
  123                  is_same_v<OutDataType, int>)
 
  131     double midway_error = 
std::max(compute_error, output_error);
 
  133     static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
 
  134                       is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
 
  135                       is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
 
  136                       is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
 
  137                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");
 
  138     double acc_error = 0;
 
  139     if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
 
  140                  is_same_v<AccDataType, int>)
 
  149     return std::max(acc_error, midway_error);
 
  152 template <
typename Range, 
typename RefRange>
 
  155         std::is_floating_point_v<ranges::range_value_t<Range>> &&
 
  156         !std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  160           const std::string& msg = 
"Error: Incorrect results!",
 
  164     if(out.size() != ref.size())
 
  166         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  175     for(std::size_t i = 0; i < ref.size(); ++i)
 
  177         const double o = *std::next(std::begin(out), i);
 
  178         const double r = *std::next(std::begin(ref), i);
 
  179         err            = std::abs(o - r);
 
  180         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  182             max_err = err > max_err ? err : max_err;
 
  186                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  187                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  194         const float error_percent =
 
  195             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  196         std::cerr << 
"max err: " << max_err;
 
  197         std::cerr << 
", number of errors: " << err_count;
 
  198         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  203 template <
typename Range, 
typename RefRange>
 
  206         std::is_same_v<ranges::range_value_t<Range>, 
bhalf_t>,
 
  210           const std::string& msg = 
"Error: Incorrect results!",
 
  214     if(out.size() != ref.size())
 
  216         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  226     for(std::size_t i = 0; i < ref.size(); ++i)
 
  228         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  229         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  230         err            = std::abs(o - r);
 
  231         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  233             max_err = err > max_err ? err : max_err;
 
  237                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  238                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  245         const float error_percent =
 
  246             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  247         std::cerr << 
"max err: " << max_err;
 
  248         std::cerr << 
", number of errors: " << err_count;
 
  249         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  254 template <
typename Range, 
typename RefRange>
 
  257         std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  261           const std::string& msg = 
"Error: Incorrect results!",
 
  265     if(out.size() != ref.size())
 
  267         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  276     for(std::size_t i = 0; i < ref.size(); ++i)
 
  278         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  279         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  280         err            = std::abs(o - r);
 
  281         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  283             max_err = err > max_err ? err : max_err;
 
  287                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  288                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  295         const float error_percent =
 
  296             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  297         std::cerr << 
"max err: " << max_err;
 
  298         std::cerr << 
", number of errors: " << err_count;
 
  299         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  304 template <
typename Range, 
typename RefRange>
 
  306                   std::is_integral_v<ranges::range_value_t<Range>> &&
 
  307                   !std::is_same_v<ranges::range_value_t<Range>, 
bhalf_t> &&
 
  308                   !std::is_same_v<ranges::range_value_t<Range>, 
f8_t> &&
 
  309                   !std::is_same_v<ranges::range_value_t<Range>, 
bf8_t>)
 
  310 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 
  317           const std::string& msg = 
"Error: Incorrect results!",
 
  321     if(out.size() != ref.size())
 
  323         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  332     for(std::size_t i = 0; i < ref.size(); ++i)
 
  334         const int64_t o = *std::next(std::begin(out), i);
 
  335         const int64_t r = *std::next(std::begin(ref), i);
 
  336         err             = std::abs(o - r);
 
  340             max_err = err > max_err ? err : max_err;
 
  344                 std::cerr << msg << 
" out[" << i << 
"] != ref[" << i << 
"]: " << o << 
" != " << r
 
  352         const float error_percent =
 
  353             static_cast<float>(err_count) / 
static_cast<float>(out.size()) * 100.f;
 
  354         std::cerr << 
"max err: " << max_err;
 
  355         std::cerr << 
", number of errors: " << err_count;
 
  356         std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  361 template <
typename Range, 
typename RefRange>
 
  363                   std::is_same_v<ranges::range_value_t<Range>, 
f8_t>),
 
  367           const std::string& msg = 
"Error: Incorrect results!",
 
  371     if(out.size() != ref.size())
 
  373         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  383     for(std::size_t i = 0; i < ref.size(); ++i)
 
  385         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  386         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  387         err            = std::abs(o - r);
 
  389         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  391             max_err = err > max_err ? err : max_err;
 
  395                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  396                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  404         std::cerr << std::setw(12) << std::setprecision(7) << 
"max err: " << max_err
 
  405                   << 
" number of errors: " << err_count << std::endl;
 
  410 template <
typename Range, 
typename RefRange>
 
  412                   std::is_same_v<ranges::range_value_t<Range>, 
bf8_t>),
 
  416           const std::string& msg = 
"Error: Incorrect results!",
 
  420     if(out.size() != ref.size())
 
  422         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  431     for(std::size_t i = 0; i < ref.size(); ++i)
 
  433         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  434         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  435         err            = std::abs(o - r);
 
  436         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  438             max_err = err > max_err ? err : max_err;
 
  442                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  443                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  450         std::cerr << std::setw(12) << std::setprecision(7) << 
"max err: " << max_err << std::endl;
 
  455 template <
typename Range, 
typename RefRange>
 
  457                   std::is_same_v<ranges::range_value_t<Range>, 
f4_t>),
 
  461           const std::string& msg = 
"Error: Incorrect results!",
 
  465     if(out.size() != ref.size())
 
  467         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  477     for(std::size_t i = 0; i < ref.size(); ++i)
 
  479         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  480         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  481         err            = std::abs(o - r);
 
  483         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
 
  485             max_err = err > max_err ? err : max_err;
 
  489                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  490                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  498         std::cerr << std::setw(12) << std::setprecision(7) << 
"max err: " << max_err
 
  499                   << 
" number of errors: " << err_count << std::endl;
 
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition: ranges.hpp:28
 
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6)
Definition: check_err.hpp:158
 
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:89
 
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:27
 
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:30
 
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:28
 
int8_t int8_t
Definition: int8.hpp:20
 
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:26
 
int32_t int32_t
Definition: integer.hpp:10
 
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:34
 
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:22
 
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:32
 
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:1738
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
 
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:32
 
_Float16 half_t
Definition: data_type.hpp:30
 
ushort bhalf_t
Definition: data_type.hpp:29
 
_BitInt(4) int4_t
Definition: data_type.hpp:31
 
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
 
constexpr bool is_same_v
Definition: type.hpp:283
 
long int64_t
Definition: data_type.hpp:461
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
Definition: numeric_limits.hpp:309
 
Definition: numeric_utils.hpp:10