/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp Source File
threadwise_welford.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/math_v2.hpp"
7 
8 namespace ck {
9 
10 // Assume
11 // 1) XDesc is known at compile-time
12 // 2) MeanVarDesc is known at compile-time
13 // 3) XBuffer is static buffer
14 // 4) MeanBuffer is static buffer
15 // 5) VarBuffer is static buffer
16 template <typename T, typename XThreadDesc_M_K, typename MeanVarThreadDesc_M>
18 {
19  static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{};
20  static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{};
21 
22  static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{});
23  static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{});
24  static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{});
25 
27  "lengths of source and mean/var buffer must match!");
28 
29  __device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {}
30 
31  __device__ inline void Update(T& mean, T& var, T x)
32  {
33  using ck::math::isnan;
34 
35  if(isnan(x))
36  {
37  mean = x;
38  var = x;
39  }
40  else
41  {
42  T delta = x - mean;
43  mean += delta / cur_count_;
44  T delta2 = x - mean;
45  var += delta * delta2;
46  }
47  }
48 
49  template <typename XBufferType, typename MeanBufferType, typename VarBufferType>
50  __device__ void
51  Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m)
52  {
53  // FIXME - Better naming for var_buf_m
54 
57  {
58  ++cur_count_;
59 
61  constexpr index_t out_offset =
62  mean_var_thread_desc_m.CalculateOffset(make_tuple(iM));
63 
64  constexpr auto in_offset =
65  x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
66  Update(mean_buf_m(Number<out_offset>{}),
67  var_buf_m(Number<out_offset>{}),
68  x_buf_m_k[Number<in_offset>{}]);
69  });
70  }
71  });
72  };
73 
76 };
77 
78 template <typename T,
79  typename SrcMeanVarCountThreadDesc_M_K,
80  typename DstMeanVarThreadDesc_M,
81  bool GetActualVariance = false>
83 {
84  static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
85  static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{};
86 
87  static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
88  static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
89  static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
90 
91  static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
92 
93  __device__ static void
94  Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
95  {
96  int count = count_a + count_b;
97  T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
98  T delta = mean_b - mean_a;
99  mean_a += delta * count_b_over_count;
100  var_a += var_b + delta * delta * count_a * count_b_over_count;
101  count_a = count;
102  }
103 
104  template <typename SrcMeanBufferType,
105  typename SrcVarBufferType,
106  typename SrcCountBufferType,
107  typename DstMeanBufferType,
108  typename DstVarBufferType,
109  typename DstCountBufferType>
110  __device__ static void Run(const SrcMeanBufferType& src_mean_buf,
111  const SrcVarBufferType& src_var_buf,
112  const SrcCountBufferType& src_count_buf,
113  DstMeanBufferType& dst_mean_buf,
114  DstVarBufferType& dst_var_buf,
115  DstCountBufferType& dst_count_buf)
116  {
117  static_for<0, src_length_m, 1>{}([&](auto iM) {
118  static_for<0, src_length_k, 1>{}([&](auto iK) {
119  constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
120 
121  Merge(dst_mean_buf(iM),
122  dst_var_buf(iM),
123  dst_count_buf(iM),
124  src_mean_buf[Number<src_offset>{}],
125  src_var_buf[Number<src_offset>{}],
126  src_count_buf[Number<src_offset>{}]);
127  });
128 
129  if constexpr(GetActualVariance)
130  {
131  dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
132  };
133  });
134  };
135 };
136 
137 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
Definition: threadwise_welford.hpp:18
__device__ void Update(T &mean, T &var, T x)
Definition: threadwise_welford.hpp:31
static constexpr auto thread_mean_var_length_m
Definition: threadwise_welford.hpp:24
int cur_count_
Definition: threadwise_welford.hpp:72
constexpr __device__ ThreadwiseWelford()
Definition: threadwise_welford.hpp:29
static constexpr auto mean_var_thread_desc_m
Definition: threadwise_welford.hpp:20
static constexpr auto x_thread_desc_m_k
Definition: threadwise_welford.hpp:19
__device__ void Run(const XBufferType &x_buf_m_k, MeanBufferType &mean_buf_m, VarBufferType &var_buf_m)
Definition: threadwise_welford.hpp:51
int max_count_
Definition: threadwise_welford.hpp:75
static constexpr auto thread_x_length_m
Definition: threadwise_welford.hpp:22
static constexpr auto thread_x_length_k
Definition: threadwise_welford.hpp:23
Definition: threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition: threadwise_welford.hpp:110
static constexpr auto src_length_k
Definition: threadwise_welford.hpp:88
static __device__ void Merge(T &mean_a, T &var_a, int32_t &count_a, T mean_b, T var_b, int32_t count_b)
Definition: threadwise_welford.hpp:94
static constexpr auto dst_thread_desc_m
Definition: threadwise_welford.hpp:85
static constexpr auto dst_length_m
Definition: threadwise_welford.hpp:89
static constexpr auto src_thread_desc_m_k
Definition: threadwise_welford.hpp:84
static constexpr auto src_length_m
Definition: threadwise_welford.hpp:87
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31