13 #include <type_traits> 
   48 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
   53         is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
   54         "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
 
   56     double compute_error = 0;
 
   66     static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
   67                   "Warning: Unhandled OutDataType for setting up the relative threshold!");
 
   69     double output_error = 0;
 
   78     double midway_error = 
std::max(compute_error, output_error);
 
   80     static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
   81                   "Warning: Unhandled AccDataType for setting up the relative threshold!");
 
   92     return std::max(acc_error, midway_error);
 
  108 template <
typename ComputeDataType, 
typename OutDataType, 
typename AccDataType = ComputeDataType>
 
  110                                            const int number_of_accumulations = 1)
 
  114         is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
  115         "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
 
  117     auto expo            = std::log2(std::abs(max_possible_num));
 
  118     double compute_error = 0;
 
  128     static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
  129                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");
 
  131     double output_error = 0;
 
  140     double midway_error = 
std::max(compute_error, output_error);
 
  142     static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
 
  143                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");
 
  145     double acc_error = 0;
 
  155     return std::max(acc_error, midway_error);
 
  168 template <
typename T>
 
  169 std::ostream& 
operator<<(std::ostream& os, 
const std::vector<T>& v)
 
  171     using size_type = 
typename std::vector<T>::size_type;
 
  174     for(size_type idx = 0; idx < v.size(); ++idx)
 
  197 template <
typename Range, 
typename RefRange>
 
  200                                       const std::string& msg = 
"Error: Incorrect results!")
 
  202     if(out.size() != ref.size())
 
  204         std::cerr << msg << 
" out.size() != ref.size(), :" << out.size() << 
" != " << ref.size()
 
  222     const float error_percent =
 
  223         static_cast<float>(err_count) / 
static_cast<float>(total_size) * 100.f;
 
  224     std::cerr << 
"max err: " << max_err;
 
  225     std::cerr << 
", number of errors: " << err_count;
 
  226     std::cerr << 
", " << error_percent << 
"% wrong values" << std::endl;
 
  245 template <
typename Range, 
typename RefRange>
 
  247     std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  248         std::is_floating_point_v<ranges::range_value_t<Range>> &&
 
  249         !std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  253           const std::string& msg  = 
"Error: Incorrect results!",
 
  256           bool allow_infinity_ref = 
false)
 
  262     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  263         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  264         const bool both_infinite_and_same =
 
  265             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  267         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  274     for(std::size_t i = 0; i < ref.size(); ++i)
 
  276         const double o = *std::next(std::begin(out), i);
 
  277         const double r = *std::next(std::begin(ref), i);
 
  278         err            = std::abs(o - r);
 
  279         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  281             max_err = err > max_err ? err : max_err;
 
  285                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  286                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  314 template <
typename Range, 
typename RefRange>
 
  316     std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  317         std::is_same_v<ranges::range_value_t<Range>, 
bf16_t>,
 
  321           const std::string& msg  = 
"Error: Incorrect results!",
 
  324           bool allow_infinity_ref = 
false)
 
  329     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  330         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  331         const bool both_infinite_and_same =
 
  332             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  334         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  342     for(std::size_t i = 0; i < ref.size(); ++i)
 
  344         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  345         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  346         err            = std::abs(o - r);
 
  347         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  349             max_err = err > max_err ? err : max_err;
 
  353                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  354                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  383 template <
typename Range, 
typename RefRange>
 
  385     std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  386         std::is_same_v<ranges::range_value_t<Range>, 
half_t>,
 
  390           const std::string& msg  = 
"Error: Incorrect results!",
 
  393           bool allow_infinity_ref = 
false)
 
  398     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  399         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  400         const bool both_infinite_and_same =
 
  401             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  403         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  409     double max_err = 
static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>
::min());
 
  410     for(std::size_t i = 0; i < ref.size(); ++i)
 
  412         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  413         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  414         err            = std::abs(o - r);
 
  415         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  417             max_err = err > max_err ? err : max_err;
 
  421                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  422                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
  449 template <
typename Range, 
typename RefRange>
 
  450 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  451                   std::is_integral_v<ranges::range_value_t<Range>> &&
 
  452                   !std::is_same_v<ranges::range_value_t<Range>, 
bf16_t>)
 
  453 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 
  460                            const std::string& msg = 
"Error: Incorrect results!",
 
  471     for(std::size_t i = 0; i < ref.size(); ++i)
 
  473         const int64_t o = *std::next(std::begin(out), i);
 
  474         const int64_t r = *std::next(std::begin(ref), i);
 
  475         err             = std::abs(o - r);
 
  479             max_err = err > max_err ? err : max_err;
 
  483                 std::cerr << msg << 
" out[" << i << 
"] != ref[" << i << 
"]: " << o << 
" != " << r
 
  513 template <
typename Range, 
typename RefRange>
 
  514 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  515                   std::is_same_v<ranges::range_value_t<Range>, 
fp8_t>),
 
  519                            const std::string& msg               = 
"Error: Incorrect results!",
 
  520                            unsigned max_rounding_point_distance = 1,
 
  522                            bool allow_infinity_ref              = 
false)
 
  527     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  528         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  529         const bool both_infinite_and_same =
 
  530             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  532         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  535     static const auto get_rounding_point_distance = [](
fp8_t o, 
fp8_t r) -> 
unsigned {
 
  536         static const auto get_sign_bit = [](
fp8_t v) -> 
bool {
 
  537             return 0x80 & bit_cast<uint8_t>(v);
 
  540         if(get_sign_bit(o) ^ get_sign_bit(r))
 
  546             return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
 
  554     for(std::size_t i = 0; i < ref.size(); ++i)
 
  556         const fp8_t o_fp8   = *std::next(std::begin(out), i);
 
  557         const fp8_t r_fp8   = *std::next(std::begin(ref), i);
 
  558         const double o_fp64 = type_convert<float>(o_fp8);
 
  559         const double r_fp64 = type_convert<float>(r_fp8);
 
  560         err                 = std::abs(o_fp64 - r_fp64);
 
  562              get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
 
  563            is_infinity_error(o_fp64, r_fp64))
 
  565             max_err = err > max_err ? err : max_err;
 
  569                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  570                           << 
"] != ref[" << i << 
"]: " << o_fp64 << 
" != " << r_fp64 << std::endl;
 
  598 template <
typename Range, 
typename RefRange>
 
  599 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
 
  600                   std::is_same_v<ranges::range_value_t<Range>, 
bf8_t>),
 
  604                            const std::string& msg  = 
"Error: Incorrect results!",
 
  607                            bool allow_infinity_ref = 
false)
 
  612     const auto is_infinity_error = [=](
auto o, 
auto r) {
 
  613         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
 
  614         const bool both_infinite_and_same =
 
  615             std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
  617         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
 
  624     for(std::size_t i = 0; i < ref.size(); ++i)
 
  626         const double o = type_convert<float>(*std::next(std::begin(out), i));
 
  627         const double r = type_convert<float>(*std::next(std::begin(ref), i));
 
  628         err            = std::abs(o - r);
 
  629         if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
 
  631             max_err = err > max_err ? err : max_err;
 
  635                 std::cerr << msg << std::setw(12) << std::setprecision(7) << 
" out[" << i
 
  636                           << 
"] != ref[" << i << 
"]: " << o << 
" != " << r << std::endl;
 
#define CK_TILE_HOST
Definition: config.hpp:39
 
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
Definition: cluster_descriptor.hpp:13
 
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:30
 
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:28
 
_BitInt(8) fp8_t
Definition: float8.hpp:204
 
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:198
 
int8_t int8_t
Definition: int8.hpp:20
 
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
 
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:49
 
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:220
 
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:109
 
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:169
 
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > && std::is_floating_point_v< ranges::range_value_t< Range > > && !std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:251
 
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:26
 
int32_t int32_t
Definition: integer.hpp:10
 
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:24
 
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:34
 
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
 
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:22
 
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
 
_Float16 half_t
Definition: half.hpp:111
 
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:32
 
_BitInt(4) int4_t
Definition: data_type.hpp:31
 
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
 
constexpr bool is_same_v
Definition: type.hpp:283
 
long int64_t
Definition: data_type.hpp:461
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
Definition: type_traits.hpp:115
 
Definition: numeric.hpp:81