/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mma.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mma.hpp Source File
mma.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 #pragma once
6 
7 #include "amdgcn_mma.hpp"
8 #include "mma_selector.hpp"
9 #include "mma_traits.hpp"
10 #include "mma_transforms.hpp"
11 
12 #include "mfma/mfma.hpp"
13 #include "wmma/wmma.hpp"
14 
15 namespace ck_tile::core::arch::mma {
16 
/**
 * @brief Accumulation order for Mma decomposition.
 *
 * Selects the order in which the output blocks of a fragment are visited
 * when the fragment is decomposed into block-wise Mma operations.
 */
enum struct MmaAccumPolicy
{
    // Decomposition and accumulation in row-major block order
    ROW_MAJOR,
    // Decomposition and accumulation in col-major block order
    COL_MAJOR
};
27 
54 template <typename ADataType,
55  typename BDataType,
56  typename CDataType,
57  uint32_t FragM,
58  uint32_t FragN,
59  uint32_t FragK,
61  typename CompilerTarget =
62  decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
63  // get_compiler_target(),
64  typename MmaOp =
65  typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
66  // MmaDefaultSelector<ADataType,
67  BDataType,
68  CDataType,
69  FragM,
70  FragN,
71  FragK,
72  CompilerTarget>::SelectedOp,
73  typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
74  typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
76 {
77 
78  using BlockWiseMmaOp = MmaOp;
80 
81  // Block dimensions
82  constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
83  constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
84  constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
85 
86  // Block counts for decomposition
87  constexpr static uint32_t BlocksM = FragM / BlockM;
88  constexpr static uint32_t BlocksN = FragN / BlockN;
89  constexpr static uint32_t BlocksK = FragK / BlockK;
90  constexpr static uint32_t BlocksC = BlocksM * BlocksN;
91 
92  // Vector types for packed registers in each block
96 
97  // Buffer types for fragments
101 
102  // Transforms
103  using ATransform = typename MmaTransforms::ATransform;
104  using BTransform = typename MmaTransforms::BTransform;
105  using CTransform = typename MmaTransforms::CTransform;
106  using DTransform = typename MmaTransforms::DTransform;
107 
108  // Sanity checks
109  static_assert(FragM >= BlockM, "FragM must be larger than BlockM");
110  static_assert(FragN >= BlockN, "FragN must be larger than BlockN");
111  static_assert(FragK >= BlockK, "FragK must be larger than BlockK");
112  static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
113  static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
114  static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");
115 
116  private:
/**
 * @brief Views a read-only fragment as DstT without converting any data.
 *
 * The fragment is reinterpreted as `DstT const&`. Because the return type is
 * plain `auto`, the usual decay applies to that reference:
 *  - for an array DstT the result is a pointer to the first element of the
 *    array view, which aliases inputBuffer (inputBuffer must outlive it);
 *  - for a non-array DstT the result is a copy.
 *
 * @tparam DstT Type to view the fragment as; must have the same size as SrcT.
 * @tparam SrcT Type of the incoming fragment (deduced).
 * @param inputBuffer Fragment to reinterpret.
 */
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
{
    // TODO: Implement formatting logic as needed.
    // This is intended to convert input fragments to the native vector types
    // required by the BlockWiseMma operation for iteration
    static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
    return reinterpret_cast<DstT const&>(inputBuffer);
}
126 
/**
 * @brief Views a mutable fragment as DstT without converting any data.
 *
 * The fragment is reinterpreted as `DstT&`. Because the return type is plain
 * `auto`, the usual decay applies to that reference:
 *  - for an array DstT the result is a pointer to the first element of the
 *    array view; writes through it land in inputBuffer, which must outlive it;
 *  - for a non-array DstT the result is an independent copy.
 *
 * @tparam DstT Type to view the fragment as; must have the same size as SrcT.
 * @tparam SrcT Type of the incoming fragment (deduced).
 * @param inputBuffer Fragment to reinterpret.
 */
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
{
    // TODO: Implement formatting logic as needed.
    // This is intended to convert input fragments to the native vector types
    // required by the BlockWiseMma operation for iteration
    static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
    return reinterpret_cast<DstT&>(inputBuffer);
}
136 
142  template <typename VecTA, typename VecTB, typename VecTC>
143  CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
144  {
145  // We implement an example wave-tile pipeline here.
146  // First, we apply the necessary transforms to the input fragments,
147  // then we convert the result into buffers of native vector formats
148  // that we can easily index. Native vector formats are necessary inputs
149  // to the given MmaOp exec function.
150  auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
151  auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
152  auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
153 
154  // "Col-major" accumulation over the M-dimension blocks first.
155  // Pseudo code here, but we would basically iterate over the blocks in col-major order
156  for(uint32_t bn = 0u; bn < BlocksN; ++bn)
157  {
158  for(uint32_t bm = 0u; bm < BlocksM; ++bm)
159  {
160  for(uint32_t bk = 0u; bk < BlocksK; ++bk)
161  {
162  c_frag[bm][bn] =
163  BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
164  }
165  }
166  }
167 
168  // Convert native vector results back to the output fragment format
169  // and then return after we apply the final output transform.
170  return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
171  }
172 
178  template <typename VecTA, typename VecTB, typename VecTC>
179  CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
180  {
181  // We implement an example wave-tile pipeline here.
182  // First, we apply the necessary transforms to the input fragments,
183  // then we convert the result into buffers of native vector formats
184  // that we can easily index. Native vector formats are necessary inputs
185  // to the given MmaOp exec function.
186  auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
187  auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
188  auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
189 
190  // "Row-major" accumulation over the N-dimension blocks first.
191  // Pseudo code here, but we would basically iterate over the blocks in row-major order.
192  // We also have to ensure that the incoming vector fragments are converted to native vector
193  // types before passing to the BlockWiseMma exec function.
194  for(uint32_t bm = 0u; bm < BlocksM; ++bm)
195  {
196  for(uint32_t bn = 0u; bn < BlocksN; ++bn)
197  {
198  for(uint32_t bk = 0u; bk < BlocksK; ++bk)
199  {
200  c_frag[bm][bn] =
201  BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
202  }
203  }
204  }
205 
206  // Convert native vector results back to the output fragment format
207  // and then return after we apply the final output transform.
208  return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
209  }
210 
211  public:
217  template <typename VecTA, typename VecTB, typename VecTC>
218  CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
219  {
220  if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
221  {
222  return exec_row_major(
223  std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
224  }
225  else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
226  {
227  return exec_col_major(
228  std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
229  }
230  }
231 };
232 
233 } // namespace ck_tile::core::arch::mma
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: amdgcn_mma.hpp:10
MmaAccumPolicy
Accumulation order for Mma decomposition.
Definition: mma.hpp:21
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
unsigned int uint32_t
Definition: stdint.h:126
Reflects the template parameters and static members of a given MmaOp.
Definition: mma_traits.hpp:125
typename MmaOp::CVecType CVecType
Definition: mma_traits.hpp:130
typename MmaOp::BVecType BVecType
Definition: mma_traits.hpp:129
typename MmaOp::AVecType AVecType
Definition: mma_traits.hpp:128
typename MmaTransforms::CTransform CTransform
Definition: mma.hpp:105
constexpr static uint32_t BlocksM
Definition: mma.hpp:87
typename MmaTransforms::ATransform ATransform
Definition: mma.hpp:103
constexpr static uint32_t BlockK
Definition: mma.hpp:84
constexpr static uint32_t BlocksC
Definition: mma.hpp:90
constexpr static uint32_t BlockM
Definition: mma.hpp:82
typename BlockWiseMmaOpTraits::BVecType BVecType
Definition: mma.hpp:94
constexpr static uint32_t BlockN
Definition: mma.hpp:83
static decltype(auto) CK_TILE_DEVICE exec(VecTA &&a, VecTB &&b, VecTC &&accum)
Forward to Mma operation with specified accumulation order.
Definition: mma.hpp:218
constexpr static uint32_t BlocksK
Definition: mma.hpp:89
typename MmaTransforms::BTransform BTransform
Definition: mma.hpp:104
AVecType[BlocksM][BlocksK] ABufferType
Definition: mma.hpp:98
MmaOp BlockWiseMmaOp
Definition: mma.hpp:78
typename BlockWiseMmaOpTraits::AVecType AVecType
Definition: mma.hpp:93
typename MmaTransforms::DTransform DTransform
Definition: mma.hpp:106
BVecType[BlocksN][BlocksK] BBufferType
Definition: mma.hpp:99
CVecType[BlocksM][BlocksN] CBufferType
Definition: mma.hpp:100
constexpr static uint32_t BlocksN
Definition: mma.hpp:88
typename BlockWiseMmaOpTraits::CVecType CVecType
Definition: mma.hpp:95