13 #include <type_traits>
26 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
38 static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
39 is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
40 is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
41 is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
42 is_same_v<ComputeDataType, int>,
43 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
44 double compute_error = 0;
45 if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
46 is_same_v<ComputeDataType, int>)
// Every alternative must test OutDataType: the TF32 case used to test ComputeDataType,
// which let any OutDataType through whenever the compute type was TF32.
55 static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
56 is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
57 is_same_v<OutDataType, F32> || is_same_v<OutDataType, TF32> ||
58 is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
59 is_same_v<OutDataType, int>,
60 "Warning: Unhandled OutDataType for setting up the relative threshold!");
61 double output_error = 0;
62 if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
63 is_same_v<OutDataType, int>)
71 double midway_error =
std::max(compute_error, output_error);
// Every alternative must test AccDataType: the TF32 case used to test ComputeDataType,
// which let any AccDataType through whenever the compute type was TF32.
73 static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
74 is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
75 is_same_v<AccDataType, F32> || is_same_v<AccDataType, TF32> ||
76 is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
77 is_same_v<AccDataType, int>,
78 "Warning: Unhandled AccDataType for setting up the relative threshold!");
80 if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
81 is_same_v<AccDataType, int>)
89 return std::max(acc_error, midway_error);
92 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
104 static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
105 is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
106 is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
107 is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
108 is_same_v<ComputeDataType, int>,
109 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
110 auto expo = std::log2(std::abs(max_possible_num));
111 double compute_error = 0;
112 if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
113 is_same_v<ComputeDataType, int>)
// Every alternative must test OutDataType: the TF32 case used to test ComputeDataType,
// which let any OutDataType through whenever the compute type was TF32.
122 static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
123 is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
124 is_same_v<OutDataType, F32> || is_same_v<OutDataType, TF32> ||
125 is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
126 is_same_v<OutDataType, int>,
127 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
128 double output_error = 0;
129 if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
130 is_same_v<OutDataType, int>)
138 double midway_error =
std::max(compute_error, output_error);
// Every alternative must test AccDataType: the TF32 case used to test ComputeDataType,
// which let any AccDataType through whenever the compute type was TF32.
140 static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
141 is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
142 is_same_v<AccDataType, F32> || is_same_v<AccDataType, TF32> ||
143 is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
144 is_same_v<AccDataType, int>,
145 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
146 double acc_error = 0;
147 if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
148 is_same_v<AccDataType, int>)
157 return std::max(acc_error, midway_error);
160 template <
typename Range,
165 std::is_same_v<ranges::range_value_t<Range>,
float> &&
166 std::is_same_v<ComputeDataType, ck::tf32_t>,
170 const std::string& msg =
"Error: Incorrect results!",
174 if(out.size() != ref.size())
176 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
185 for(std::size_t i = 0; i < ref.size(); ++i)
187 const double o = *std::next(std::begin(out), i);
188 const double r = *std::next(std::begin(ref), i);
189 err = std::abs(o - r);
190 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
192 max_err = err > max_err ? err : max_err;
195 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
196 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
204 const float error_percent =
205 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
206 std::cerr <<
"max err: " << max_err;
207 std::cerr <<
", number of errors: " << err_count;
208 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
213 template <
typename Range,
218 std::is_floating_point_v<ranges::range_value_t<Range>> &&
219 !std::is_same_v<ranges::range_value_t<Range>,
half_t> &&
220 !std::is_same_v<ComputeDataType, ck::tf32_t>,
224 const std::string& msg =
"Error: Incorrect results!",
228 if(out.size() != ref.size())
230 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
239 for(std::size_t i = 0; i < ref.size(); ++i)
241 const double o = *std::next(std::begin(out), i);
242 const double r = *std::next(std::begin(ref), i);
243 err = std::abs(o - r);
244 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
246 max_err = err > max_err ? err : max_err;
249 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
250 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
258 const float error_percent =
259 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
260 std::cerr <<
"max err: " << max_err;
261 std::cerr <<
", number of errors: " << err_count;
262 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
267 template <
typename Range,
272 std::is_same_v<ranges::range_value_t<Range>,
bhalf_t>,
276 const std::string& msg =
"Error: Incorrect results!",
280 if(out.size() != ref.size())
282 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
292 for(std::size_t i = 0; i < ref.size(); ++i)
294 const double o = type_convert<float>(*std::next(std::begin(out), i));
295 const double r = type_convert<float>(*std::next(std::begin(ref), i));
296 err = std::abs(o - r);
297 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
299 max_err = err > max_err ? err : max_err;
303 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
304 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
311 const float error_percent =
312 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
313 std::cerr <<
"max err: " << max_err;
314 std::cerr <<
", number of errors: " << err_count;
315 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
320 template <
typename Range,
325 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
329 const std::string& msg =
"Error: Incorrect results!",
333 if(out.size() != ref.size())
335 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
344 for(std::size_t i = 0; i < ref.size(); ++i)
346 const double o = type_convert<float>(*std::next(std::begin(out), i));
347 const double r = type_convert<float>(*std::next(std::begin(ref), i));
348 err = std::abs(o - r);
349 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
351 max_err = err > max_err ? err : max_err;
355 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
356 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
363 const float error_percent =
364 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
365 std::cerr <<
"max err: " << max_err;
366 std::cerr <<
", number of errors: " << err_count;
367 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
372 template <
typename Range,
376 std::is_integral_v<ranges::range_value_t<Range>> &&
377 !std::is_same_v<ranges::range_value_t<Range>,
bhalf_t> &&
378 !std::is_same_v<ranges::range_value_t<Range>,
f8_t> &&
379 !std::is_same_v<ranges::range_value_t<Range>,
bf8_t>)
380 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
387 const std::string& msg =
"Error: Incorrect results!",
391 if(out.size() != ref.size())
393 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
402 for(std::size_t i = 0; i < ref.size(); ++i)
404 const int64_t o = *std::next(std::begin(out), i);
405 const int64_t r = *std::next(std::begin(ref), i);
406 err = std::abs(o - r);
410 max_err = err > max_err ? err : max_err;
414 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
422 const float error_percent =
423 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
424 std::cerr <<
"max err: " << max_err;
425 std::cerr <<
", number of errors: " << err_count;
426 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
431 template <
typename Range,
435 std::is_same_v<ranges::range_value_t<Range>,
f8_t>),
439 const std::string& msg =
"Error: Incorrect results!",
443 if(out.size() != ref.size())
445 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
455 for(std::size_t i = 0; i < ref.size(); ++i)
457 const double o = type_convert<float>(*std::next(std::begin(out), i));
458 const double r = type_convert<float>(*std::next(std::begin(ref), i));
459 err = std::abs(o - r);
461 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
463 max_err = err > max_err ? err : max_err;
467 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
468 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
476 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err
477 <<
" number of errors: " << err_count << std::endl;
482 template <
typename Range,
486 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
490 const std::string& msg =
"Error: Incorrect results!",
494 if(out.size() != ref.size())
496 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
505 for(std::size_t i = 0; i < ref.size(); ++i)
507 const double o = type_convert<float>(*std::next(std::begin(out), i));
508 const double r = type_convert<float>(*std::next(std::begin(ref), i));
509 err = std::abs(o - r);
510 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
512 max_err = err > max_err ? err : max_err;
516 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
517 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
524 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err << std::endl;
529 template <
typename Range,
533 std::is_same_v<ranges::range_value_t<Range>,
f4_t>),
537 const std::string& msg =
"Error: Incorrect results!",
541 if(out.size() != ref.size())
543 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
553 for(std::size_t i = 0; i < ref.size(); ++i)
555 const double o = type_convert<float>(*std::next(std::begin(out), i));
556 const double r = type_convert<float>(*std::next(std::begin(ref), i));
557 err = std::abs(o - r);
559 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
561 max_err = err > max_err ? err : max_err;
565 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
566 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
574 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err
575 <<
" number of errors: " << err_count << std::endl;
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition: ranges.hpp:28
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition: check_err.hpp:93
double get_relative_threshold(const int number_of_accumulations=1)
Definition: check_err.hpp:27
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_same_v< ranges::range_value_t< Range >, float > &&std::is_same_v< ComputeDataType, ck::tf32_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-5)
Definition: check_err.hpp:168
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:25
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
unsigned _BitInt(4) f4_t
Definition: data_type.hpp:33
_Float16 half_t
Definition: data_type.hpp:31
_BitInt(19) tf32_t
Definition: data_type.hpp:29
ushort bhalf_t
Definition: data_type.hpp:30
_BitInt(4) int4_t
Definition: data_type.hpp:32
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:464
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: numeric_utils.hpp:10
Definition: amd_ck_fp8.hpp:49
Definition: amd_ck_fp8.hpp:36