/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/host/check_err.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/host/check_err.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/host/check_err.hpp Source File
check_err.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cmath>
8 #include <cstdlib>
9 #include <iostream>
10 #include <iomanip>
11 #include <iterator>
12 #include <limits>
13 #include <type_traits>
14 #include <vector>
15 
16 #include "ck_tile/core.hpp"
17 #include "ck_tile/host/ranges.hpp"
18 
19 namespace ck_tile {
20 
30 using F32 = float;
32 using I8 = int8_t;
34 using I32 = int32_t;
35 
48 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
49 CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
50 {
51 
52  static_assert(
54  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
55 
56  double compute_error = 0;
58  {
59  return 0;
60  }
61  else
62  {
63  compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
64  }
65 
67  "Warning: Unhandled OutDataType for setting up the relative threshold!");
68 
69  double output_error = 0;
71  {
72  return 0;
73  }
74  else
75  {
76  output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
77  }
78  double midway_error = std::max(compute_error, output_error);
79 
81  "Warning: Unhandled AccDataType for setting up the relative threshold!");
82 
83  double acc_error = 0;
85  {
86  return 0;
87  }
88  else
89  {
90  acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
91  }
92  return std::max(acc_error, midway_error);
93 }
94 
108 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
109 CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
110  const int number_of_accumulations = 1)
111 {
112 
113  static_assert(
115  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
116 
117  auto expo = std::log2(std::abs(max_possible_num));
118  double compute_error = 0;
120  {
121  return 0;
122  }
123  else
124  {
125  compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
126  }
127 
129  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
130 
131  double output_error = 0;
133  {
134  return 0;
135  }
136  else
137  {
138  output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
139  }
140  double midway_error = std::max(compute_error, output_error);
141 
143  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
144 
145  double acc_error = 0;
147  {
148  return 0;
149  }
150  else
151  {
152  acc_error =
153  std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
154  }
155  return std::max(acc_error, midway_error);
156 }
157 
168 template <typename T>
169 std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
170 {
171  using size_type = typename std::vector<T>::size_type;
172 
173  os << "[";
174  for(size_type idx = 0; idx < v.size(); ++idx)
175  {
176  if(0 < idx)
177  {
178  os << ", ";
179  }
180  os << v[idx];
181  }
182  return os << "]";
183 }
184 
197 template <typename Range, typename RefRange>
198 CK_TILE_HOST bool check_size_mismatch(const Range& out,
199  const RefRange& ref,
200  const std::string& msg = "Error: Incorrect results!")
201 {
202  if(out.size() != ref.size())
203  {
204  std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
205  << std::endl;
206  return true;
207  }
208  return false;
209 }
210 
220 CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
221 {
222  const float error_percent =
223  static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
224  std::cerr << "max err: " << max_err;
225  std::cerr << ", number of errors: " << err_count;
226  std::cerr << ", " << error_percent << "% wrong values" << std::endl;
227 }
228 
245 template <typename Range, typename RefRange>
246 typename std::enable_if<
247  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
248  std::is_floating_point_v<ranges::range_value_t<Range>> &&
249  !std::is_same_v<ranges::range_value_t<Range>, half_t>,
250  bool>::type CK_TILE_HOST
251 check_err(const Range& out,
252  const RefRange& ref,
253  const std::string& msg = "Error: Incorrect results!",
254  double rtol = 1e-5,
255  double atol = 3e-6,
256  bool allow_infinity_ref = false)
257 {
258 
259  if(check_size_mismatch(out, ref, msg))
260  return false;
261 
262  const auto is_infinity_error = [=](auto o, auto r) {
263  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
264  const bool both_infinite_and_same =
265  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
266 
267  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
268  };
269 
270  bool res{true};
271  int err_count = 0;
272  double err = 0;
273  double max_err = std::numeric_limits<double>::min();
274  for(std::size_t i = 0; i < ref.size(); ++i)
275  {
276  const double o = *std::next(std::begin(out), i);
277  const double r = *std::next(std::begin(ref), i);
278  err = std::abs(o - r);
279  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
280  {
281  max_err = err > max_err ? err : max_err;
282  err_count++;
283  if(err_count < 5)
284  {
285  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
286  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
287  }
288  res = false;
289  }
290  }
291  if(!res)
292  {
293  report_error_stats(err_count, max_err, ref.size());
294  }
295  return res;
296 }
297 
314 template <typename Range, typename RefRange>
315 typename std::enable_if<
316  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
317  std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
318  bool>::type CK_TILE_HOST
319 check_err(const Range& out,
320  const RefRange& ref,
321  const std::string& msg = "Error: Incorrect results!",
322  double rtol = 1e-3,
323  double atol = 1e-3,
324  bool allow_infinity_ref = false)
325 {
326  if(check_size_mismatch(out, ref, msg))
327  return false;
328 
329  const auto is_infinity_error = [=](auto o, auto r) {
330  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
331  const bool both_infinite_and_same =
332  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
333 
334  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
335  };
336 
337  bool res{true};
338  int err_count = 0;
339  double err = 0;
340  // TODO: This is a hack. We should have proper specialization for bf16_t data type.
341  double max_err = std::numeric_limits<float>::min();
342  for(std::size_t i = 0; i < ref.size(); ++i)
343  {
344  const double o = type_convert<float>(*std::next(std::begin(out), i));
345  const double r = type_convert<float>(*std::next(std::begin(ref), i));
346  err = std::abs(o - r);
347  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
348  {
349  max_err = err > max_err ? err : max_err;
350  err_count++;
351  if(err_count < 5)
352  {
353  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
354  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
355  }
356  res = false;
357  }
358  }
359  if(!res)
360  {
361  report_error_stats(err_count, max_err, ref.size());
362  }
363  return res;
364 }
365 
383 template <typename Range, typename RefRange>
384 typename std::enable_if<
385  std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
386  std::is_same_v<ranges::range_value_t<Range>, half_t>,
387  bool>::type CK_TILE_HOST
388 check_err(const Range& out,
389  const RefRange& ref,
390  const std::string& msg = "Error: Incorrect results!",
391  double rtol = 1e-3,
392  double atol = 1e-3,
393  bool allow_infinity_ref = false)
394 {
395  if(check_size_mismatch(out, ref, msg))
396  return false;
397 
398  const auto is_infinity_error = [=](auto o, auto r) {
399  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
400  const bool both_infinite_and_same =
401  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
402 
403  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
404  };
405 
406  bool res{true};
407  int err_count = 0;
408  double err = 0;
409  double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
410  for(std::size_t i = 0; i < ref.size(); ++i)
411  {
412  const double o = type_convert<float>(*std::next(std::begin(out), i));
413  const double r = type_convert<float>(*std::next(std::begin(ref), i));
414  err = std::abs(o - r);
415  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
416  {
417  max_err = err > max_err ? err : max_err;
418  err_count++;
419  if(err_count < 5)
420  {
421  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
422  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
423  }
424  res = false;
425  }
426  }
427  if(!res)
428  {
429  report_error_stats(err_count, max_err, ref.size());
430  }
431  return res;
432 }
433 
/// @brief Check errors between integral ranges.
///
/// This overload participates when both ranges have the same value type and that type is
/// integral and not bf16_t, or (only when CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is
/// defined) when the value type is int4_t.
///
/// Elements are widened to int64_t; a pair (o, r) is counted as an error when
/// |o - r| > atol. The fourth parameter is unnamed and therefore ignored; it keeps the
/// positional argument order of the floating point overloads (relative tolerance slot).
///
/// The first four mismatches are printed to std::cerr, followed by a summary line from
/// report_error_stats() if any mismatch was found.
///
/// @param out   Range holding the values under test.
/// @param ref   Range holding the reference values.
/// @param msg   Prefix of every diagnostic line.
/// @param atol  Absolute tolerance.
/// @return true if the sizes match and no element is an error, false otherwise.
///
/// NOTE(review): listing line 470 is elided here; it presumably declares `max_err`, which
/// is used below without a visible declaration -- confirm against the real header.
449 template <typename Range, typename RefRange>
450 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
451  std::is_integral_v<ranges::range_value_t<Range>> &&
452  !std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
453 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
454  || std::is_same_v<ranges::range_value_t<Range>, int4_t>
455 #endif
456  ,
457  bool>
458  CK_TILE_HOST check_err(const Range& out,
459  const RefRange& ref,
460  const std::string& msg = "Error: Incorrect results!",
461  double = 0,
462  double atol = 0)
463 {
// A size mismatch is reported by check_size_mismatch() and fails the check immediately.
464  if(check_size_mismatch(out, ref, msg))
465  return false;
466 
467  bool res{true};
468  int err_count = 0;
469  int64_t err = 0;
471  for(std::size_t i = 0; i < ref.size(); ++i)
472  {
// Widen to int64_t before subtracting.
473  const int64_t o = *std::next(std::begin(out), i);
474  const int64_t r = *std::next(std::begin(ref), i);
475  err = std::abs(o - r);
476 
// The integer error is compared against the double tolerance.
477  if(err > atol)
478  {
479  max_err = err > max_err ? err : max_err;
480  err_count++;
// Only the first four mismatches are printed individually.
481  if(err_count < 5)
482  {
483  std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
484  << std::endl;
485  }
486  res = false;
487  }
488  }
489  if(!res)
490  {
491  report_error_stats(err_count, static_cast<double>(max_err), ref.size());
492  }
493  return res;
494 }
495 
/// @brief Check errors between fp8_t ranges.
///
/// This overload participates when both ranges have value type fp8_t.
///
/// An element pair passes when its absolute difference (after conversion to float) is
/// at most @p atol, OR when the distance between the two raw 8-bit encodings is at most
/// @p max_rounding_point_distance. Independently of that, a pair is an error when either
/// value is not finite, unless @p allow_infinity_ref is true and both are the same
/// infinity.
///
/// The first four mismatches are printed to std::cerr, followed by a summary line from
/// report_error_stats() if any mismatch was found.
///
/// @param out                          Range holding the values under test.
/// @param ref                          Range holding the reference values.
/// @param msg                          Prefix of every diagnostic line.
/// @param max_rounding_point_distance  Largest accepted distance between raw encodings.
/// @param atol                         Absolute tolerance.
/// @param allow_infinity_ref           Accept matching infinities in @p out and @p ref.
/// @return true if the sizes match and no element is an error, false otherwise.
///
/// NOTE(review): listing line 542 is elided here; it is the body of the branch taken when
/// the two values have different sign bits, so the distance returned in that case is not
/// visible in this view -- confirm against the real header.
513 template <typename Range, typename RefRange>
514 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
515  std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
516  bool>
517  CK_TILE_HOST check_err(const Range& out,
518  const RefRange& ref,
519  const std::string& msg = "Error: Incorrect results!",
520  unsigned max_rounding_point_distance = 1,
521  double atol = 1e-1,
522  bool allow_infinity_ref = false)
523 {
524  if(check_size_mismatch(out, ref, msg))
525  return false;
526 
// True when a non-finite value makes the pair an error; operates on the widened values.
527  const auto is_infinity_error = [=](auto o, auto r) {
528  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
529  const bool both_infinite_and_same =
530  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
531 
532  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
533  };
534 
// Distance between two fp8 values measured on their raw 8-bit encodings.
535  static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
// Bit 7 (mask 0x80) of the raw encoding is treated as the sign bit.
536  static const auto get_sign_bit = [](fp8_t v) -> bool {
537  return 0x80 & bit_cast<uint8_t>(v);
538  };
539 
// Different signs: handled by the elided line 542.
540  if(get_sign_bit(o) ^ get_sign_bit(r))
541  {
543  }
544  else
545  {
// Same sign: absolute difference of the encodings reinterpreted as int8_t.
546  return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
547  }
548  };
549 
550  bool res{true};
551  int err_count = 0;
552  double err = 0;
553  double max_err = std::numeric_limits<float>::min();
554  for(std::size_t i = 0; i < ref.size(); ++i)
555  {
// Keep the raw fp8 values for the encoding distance and widened copies for the tolerance.
556  const fp8_t o_fp8 = *std::next(std::begin(out), i);
557  const fp8_t r_fp8 = *std::next(std::begin(ref), i);
558  const double o_fp64 = type_convert<float>(o_fp8);
559  const double r_fp64 = type_convert<float>(r_fp8);
560  err = std::abs(o_fp64 - r_fp64);
// Error if neither acceptance criterion holds, or if a non-finite value is involved.
561  if(!(less_equal<double>{}(err, atol) ||
562  get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
563  is_infinity_error(o_fp64, r_fp64))
564  {
565  max_err = err > max_err ? err : max_err;
566  err_count++;
// Only the first four mismatches are printed individually.
567  if(err_count < 5)
568  {
569  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
570  << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
571  }
572  res = false;
573  }
574  }
575  if(!res)
576  {
577  report_error_stats(err_count, max_err, ref.size());
578  }
579  return res;
580 }
581 
598 template <typename Range, typename RefRange>
599 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
600  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
601  bool>
602  CK_TILE_HOST check_err(const Range& out,
603  const RefRange& ref,
604  const std::string& msg = "Error: Incorrect results!",
605  double rtol = 1e-3,
606  double atol = 1e-3,
607  bool allow_infinity_ref = false)
608 {
609  if(check_size_mismatch(out, ref, msg))
610  return false;
611 
612  const auto is_infinity_error = [=](auto o, auto r) {
613  const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
614  const bool both_infinite_and_same =
615  std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
616 
617  return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
618  };
619 
620  bool res{true};
621  int err_count = 0;
622  double err = 0;
623  double max_err = std::numeric_limits<float>::min();
624  for(std::size_t i = 0; i < ref.size(); ++i)
625  {
626  const double o = type_convert<float>(*std::next(std::begin(out), i));
627  const double r = type_convert<float>(*std::next(std::begin(ref), i));
628  err = std::abs(o - r);
629  if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
630  {
631  max_err = err > max_err ? err : max_err;
632  err_count++;
633  if(err_count < 5)
634  {
635  std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
636  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
637  }
638  res = false;
639  }
640  }
641  if(!res)
642  {
643  report_error_stats(err_count, max_err, ref.size());
644  }
645  return res;
646 }
647 
648 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:39
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:30
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:28
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST bool check_size_mismatch(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!")
Check for size mismatch between output and reference ranges.
Definition: check_err.hpp:198
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations=1)
Calculate relative error threshold for numerical comparisons.
Definition: check_err.hpp:49
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
Report error statistics for numerical comparisons.
Definition: check_err.hpp:220
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Calculate absolute error threshold for numerical comparisons.
Definition: check_err.hpp:109
std::ostream & operator<<(std::ostream &os, const std::vector< T > &v)
Stream operator overload for vector output.
Definition: check_err.hpp:169
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_floating_point_v< ranges::range_value_t< Range > > &&!std::is_same_v< ranges::range_value_t< Range >, half_t >, bool >::type CK_TILE_HOST check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-6, bool allow_infinity_ref=false)
Check errors between floating point ranges using the specified tolerances.
Definition: check_err.hpp:251
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:26
int32_t int32_t
Definition: integer.hpp:10
ck_tile::bf8_t BF8
8-bit brain floating point type
Definition: check_err.hpp:24
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:34
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
ck_tile::fp8_t F8
8-bit floating point type
Definition: check_err.hpp:22
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
_Float16 half_t
Definition: half.hpp:111
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:32
_BitInt(4) int4_t
Definition: data_type.hpp:31
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
long int64_t
Definition: data_type.hpp:461
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Definition: type_traits.hpp:115
Definition: math.hpp:395
Definition: numeric.hpp:81