13 #include <type_traits>
48 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
53 is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
54 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
56 double compute_error = 0;
66 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
67 "Warning: Unhandled OutDataType for setting up the relative threshold!");
69 double output_error = 0;
78 double midway_error =
std::max(compute_error, output_error);
80 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
81 "Warning: Unhandled AccDataType for setting up the relative threshold!");
92 return std::max(acc_error, midway_error);
108 template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
110 const int number_of_accumulations = 1)
114 is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
115 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
117 auto expo = std::log2(std::abs(max_possible_num));
118 double compute_error = 0;
128 static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
129 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
131 double output_error = 0;
140 double midway_error =
std::max(compute_error, output_error);
142 static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
143 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
145 double acc_error = 0;
155 return std::max(acc_error, midway_error);
168 template <
typename T>
169 std::ostream&
operator<<(std::ostream& os,
const std::vector<T>& v)
171 using size_type =
typename std::vector<T>::size_type;
174 for(size_type idx = 0; idx < v.size(); ++idx)
197 template <
typename Range,
typename RefRange>
200 const std::string& msg =
"Error: Incorrect results!")
202 if(out.size() != ref.size())
204 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
222 const float error_percent =
223 static_cast<float>(err_count) /
static_cast<float>(total_size) * 100.f;
224 std::cerr <<
"max err: " << max_err;
225 std::cerr <<
", number of errors: " << err_count;
226 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
245 template <
typename Range,
typename RefRange>
247 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
248 std::is_floating_point_v<ranges::range_value_t<Range>> &&
249 !std::is_same_v<ranges::range_value_t<Range>,
half_t>,
253 const std::string& msg =
"Error: Incorrect results!",
256 bool allow_infinity_ref =
false)
262 const auto is_infinity_error = [=](
auto o,
auto r) {
263 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
264 const bool both_infinite_and_same =
265 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
267 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
274 for(std::size_t i = 0; i < ref.size(); ++i)
276 const double o = *std::next(std::begin(out), i);
277 const double r = *std::next(std::begin(ref), i);
278 err = std::abs(o - r);
279 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
281 max_err = err > max_err ? err : max_err;
285 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
286 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
314 template <
typename Range,
typename RefRange>
316 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
317 std::is_same_v<ranges::range_value_t<Range>,
bf16_t>,
321 const std::string& msg =
"Error: Incorrect results!",
324 bool allow_infinity_ref =
false)
329 const auto is_infinity_error = [=](
auto o,
auto r) {
330 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
331 const bool both_infinite_and_same =
332 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
334 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
342 for(std::size_t i = 0; i < ref.size(); ++i)
344 const double o = type_convert<float>(*std::next(std::begin(out), i));
345 const double r = type_convert<float>(*std::next(std::begin(ref), i));
346 err = std::abs(o - r);
347 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
349 max_err = err > max_err ? err : max_err;
353 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
354 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
383 template <
typename Range,
typename RefRange>
385 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
386 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
390 const std::string& msg =
"Error: Incorrect results!",
393 bool allow_infinity_ref =
false)
398 const auto is_infinity_error = [=](
auto o,
auto r) {
399 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
400 const bool both_infinite_and_same =
401 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
403 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
409 double max_err =
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
410 for(std::size_t i = 0; i < ref.size(); ++i)
412 const double o = type_convert<float>(*std::next(std::begin(out), i));
413 const double r = type_convert<float>(*std::next(std::begin(ref), i));
414 err = std::abs(o - r);
415 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
417 max_err = err > max_err ? err : max_err;
421 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
422 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
449 template <
typename Range,
typename RefRange>
450 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
451 std::is_integral_v<ranges::range_value_t<Range>> &&
452 !std::is_same_v<ranges::range_value_t<Range>,
bf16_t>)
453 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
460 const std::string& msg =
"Error: Incorrect results!",
471 for(std::size_t i = 0; i < ref.size(); ++i)
473 const int64_t o = *std::next(std::begin(out), i);
474 const int64_t r = *std::next(std::begin(ref), i);
475 err = std::abs(o - r);
479 max_err = err > max_err ? err : max_err;
483 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
513 template <
typename Range,
typename RefRange>
514 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
515 std::is_same_v<ranges::range_value_t<Range>,
fp8_t>),
519 const std::string& msg =
"Error: Incorrect results!",
520 unsigned max_rounding_point_distance = 1,
522 bool allow_infinity_ref =
false)
527 const auto is_infinity_error = [=](
auto o,
auto r) {
528 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
529 const bool both_infinite_and_same =
530 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
532 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
535 static const auto get_rounding_point_distance = [](
fp8_t o,
fp8_t r) ->
unsigned {
536 static const auto get_sign_bit = [](
fp8_t v) ->
bool {
537 return 0x80 & bit_cast<uint8_t>(v);
540 if(get_sign_bit(o) ^ get_sign_bit(r))
546 return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
554 for(std::size_t i = 0; i < ref.size(); ++i)
556 const fp8_t o_fp8 = *std::next(std::begin(out), i);
557 const fp8_t r_fp8 = *std::next(std::begin(ref), i);
558 const double o_fp64 = type_convert<float>(o_fp8);
559 const double r_fp64 = type_convert<float>(r_fp8);
560 err = std::abs(o_fp64 - r_fp64);
562 get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
563 is_infinity_error(o_fp64, r_fp64))
565 max_err = err > max_err ? err : max_err;
569 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
570 <<
"] != ref[" << i <<
"]: " << o_fp64 <<
" != " << r_fp64 << std::endl;
598 template <
typename Range,
typename RefRange>
599 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
600 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
604 const std::string& msg =
"Error: Incorrect results!",
607 bool allow_infinity_ref =
false)
612 const auto is_infinity_error = [=](
auto o,
auto r) {
613 const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
614 const bool both_infinite_and_same =
615 std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
617 return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
624 for(std::size_t i = 0; i < ref.size(); ++i)
626 const double o = type_convert<float>(*std::next(std::begin(out), i));
627 const double r = type_convert<float>(*std::next(std::begin(ref), i));
628 err = std::abs(o - r);
629 if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
631 max_err = err > max_err ? err : max_err;
635 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
636 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
#define CK_TILE_HOST
Definition: config.hpp:39
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:30
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:28
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:198
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:49
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:220
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:109
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:169
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:251
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:26
int32_t int32_t
Definition: integer.hpp:10
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:24
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:34
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
_Float16 half_t
Definition: half.hpp:111
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:32
_BitInt(4) int4_t
Definition: data_type.hpp:31
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:461
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Definition: type_traits.hpp:115
Definition: numeric.hpp:81