/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp Source File
block_wp_asmem_breg_creg.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // A is read through block windows on shared memory (LDS)
12 // B is a block distributed tensor held in registers (sliced per warp in operator())
13 // C is block distributed tensor
    // Block-level GEMM (C += A * B) built from a policy-selected warp GEMM.
    // NOTE(review): this listing omits several source lines (e.g. the struct name
    // on line 15, m_preload on line 55); comments below hedge anything on elided lines.
14 template <typename Problem_, typename BlockPolicy_>
16 {
23 
    // Compile-time index constants; idxM/idxN/idxK are aliases of I0/I1/I2.
24  static constexpr auto I0 = number<0>();
25  static constexpr auto I1 = number<1>();
26  static constexpr auto I2 = number<2>();
27  static constexpr auto idxM = I0;
28  static constexpr auto idxN = I1;
29  static constexpr auto idxK = I2;
33 
    // Block tile extents along M, N and K, taken from the Problem's BlockGemmShape.
34  static constexpr index_t MPerBlock = BlockGemmShape::kM;
35  static constexpr index_t NPerBlock = BlockGemmShape::kN;
36  static constexpr index_t KPerBlock = BlockGemmShape::kK;
37 
38  static constexpr index_t kBlockSize = Problem::kBlockSize;
39 
    // The policy returns a tuple-like config: element 0 is the warp GEMM
    // (its type is used as WarpGemm), elements 1 and 2 are the warp counts along M and N.
40  static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
41  using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
42 
43  static constexpr index_t MWarp = config.template at<1>();
44  static constexpr index_t NWarp = config.template at<2>();
45 
    // Warp-GEMM iterations per warp along M, N, K. These are integer divisions:
    // NOTE(review): assumes block extents are exact multiples of the warp tiling;
    // no static_assert in the visible lines enforces that.
46  static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
47  static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
48  static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
49 
    // Extent of one A slice (all M warps for one mIter, one warp-K step) in MakeALoadWindows.
50  static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
51  static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
52 
    // m_preload (source line 55, elided) sizes the A register ring below;
    // presumably it is DsReadPreload clamped to MIterPerWarp * KIterPerWarp -- confirm.
53  static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
54 
58 
    // preloaded_a_warp_tensor (source line 60, elided) is a
    // statically_indexed_array<AWarpTensor, m_preload>; operator() indexes it modulo m_preload.
59  using AWarpTensor = typename WarpGemm::AWarpTensor;
61 
    // Body of MakeABlockDistributionEncode() (signature on elided source line 62):
    // embeds WarpGemm::AWarpDstrEncoding into a block-level outer encoding
    // (outer encoding lines 65-69 elided).
63  {
64  constexpr auto a_block_outer_dstr_encoding =
70  sequence<0, 0>>{};
71  constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
72  a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
73 
74  return a_block_dstr_encode;
75  }
76 
    // Builds a nested tuple of tile windows over the shared-memory A block window,
    // indexed as result[kIter][mIter] (outer lambda is kIter, inner is mIter).
    // Each window slice ends at ((mIter+1)*MPerBlockPerIter, (kIter+1)*KPerBlockPerIter)
    // and uses the distribution from MakeABlockDistributionEncode(). The slice call
    // (presumably get_slice_tile) and slice begins are on elided lines 88/90 -- confirm.
77  template <typename SmemBlockWindow>
78  CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow& a_block_window) const
79  {
80  constexpr auto a_load_dstr = make_static_tile_distribution(MakeABlockDistributionEncode());
81 
82  // create the per-(kIter, mIter) windows, indexed [kIter][mIter]
83  return generate_tuple(
84  [&](auto kIter) {
85  return generate_tuple(
86  [&](auto mIter) {
87  return make_tile_window(
89  a_block_window,
91  sequence<(mIter + 1) * MPerBlockPerIter,
92  (kIter + 1) * KPerBlockPerIter>{}),
93  a_load_dstr);
94  },
96  },
98  }
99 
    // Fills the A register ring before the main loop: loads the first m_preload
    // A warp tiles from a_load_windows, walking the flattened (kIter, mIter) order
    // with mIter fastest. The assignment (source line 108) is elided; presumably
    // preloaded_a_warp_tensor(loadIter) = load_tile(...) -- confirm.
100  template <typename ALoadWindows>
101  CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows& a_load_windows)
102  {
103 
104  static_for<0, m_preload, 1>{}([&](auto loadIter) {
105  constexpr auto mIter = loadIter % MIterPerWarp;
106  constexpr auto kIter = loadIter / MIterPerWarp;
107 
109  a_load_windows[number<kIter>{}][number<mIter>{}]);
110  });
111  }
112 
    // Creates the C accumulator: a distributed tensor of CDataType whose layout
    // embeds WarpGemm::CWarpDstrEncoding in a block-level outer encoding
    // (outer encoding lines 117-120 elided). This function does not fill it;
    // NOTE(review): callers presumably clear it before accumulating -- confirm.
113  CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
114  {
115  constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
116  sequence<>,
121  sequence<0, 0>>{};
122 
123  constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
124  c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
125 
126  constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
127 
128  auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
129  return c_block_tensor;
130  }
131 
132  // C += A * B
    // Loop order: kIter (outer) -> mIter -> nIter (inner).
    // - A: the warp tile comes from ring slot (kIter*MIterPerWarp + mIter) % m_preload;
    //   after the nIter loop, the tile m_preload steps ahead (flattened order) is
    //   loaded from a_load_windows (the assignment target is on elided line 198).
    // - B: warp tiles are sliced out of b_block_tensor.
    // - C: the (mIter, nIter) warp tile is read, accumulated by WarpGemm, written back.
    // The last parameter is used only for its type (BFlatDistribution's Y lengths).
133  template <typename CBlockTensor,
134  typename ALoadWindows,
135  typename BFlatBlockTensor,
136  typename BFlatDistribution>
137  CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
138  const ALoadWindows& a_load_windows,
139  BFlatBlockTensor& b_block_tensor,
140  const BFlatDistribution&)
141  {
    // mIter at which block_sync_lds() is issued during the last kIter: the
    // second-to-last M iteration, or iteration 0 when MIterPerWarp == 1.
142  constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
143 
144  using CWarpDstr = typename WarpGemm::CWarpDstr;
145  using CWarpTensor = typename WarpGemm::CWarpTensor;
146 
147  using BWarpTensor = typename WarpGemm::BWarpTensor;
148 
149  constexpr auto b_block_y_lengths =
150  to_sequence(BFlatDistribution{}.get_ys_to_d_descriptor().get_lengths());
151 
152  constexpr auto c_warp_y_lengths =
153  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
154 
155  constexpr auto b_block_y_index_zeros =
157  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
158 
159  static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
160  static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
161  constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
162  static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
163  // read the B warp tile (slice origin on elided line 168) and the (mIter, nIter) C warp tile
164  BWarpTensor b_warp_tensor;
165  CWarpTensor c_warp_tensor;
166 
167  b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
169  typename sequence_split<decltype(b_block_y_index_zeros),
170  2>::right_type{}),
172  sequence<1, 1>{},
173  typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));
174 
175  c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
176  merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
177  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
178 
179  // warp GEMM
180  WarpGemm{}(
181  c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);
182 
183  // write C warp tensor into C block tensor
184  c_block_tensor.set_y_sliced_thread_data(
185  merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
186  merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
187  c_warp_tensor.get_thread_buffer());
188 
    // Scheduling fence. Mask 0x7F6 leaves bits 0x1 and 0x8 clear; per LLVM's
    // documented sched_barrier mask bits, 0x8 is MFMA/WMMA, so MFMA/WMMA
    // instructions are not moved across this point.
189  __builtin_amdgcn_sched_barrier(0x7F6);
190  });
191  // preload next A from lds into the ring slot just consumed; the bound on
    // elided line 193 is presumably KIterPerWarp * MIterPerWarp - m_preload -- confirm.
192  if constexpr((kIter * MIterPerWarp + mIter) <
194  {
195  constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
196  constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
197 
199  a_load_windows[number<AkIter>{}][number<AmIter>{}]);
200  }
201 
202  // barrier: single block_sync_lds() near the end of the last kIter.
    // NOTE(review): presumably this lets the caller safely overwrite the A LDS
    // buffer once all A reads have been issued -- confirm against the pipeline.
203  if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
204  {
205  block_sync_lds();
206  }
207  });
208  });
209  }
210 };
211 
212 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:837
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1037
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Definition: block_wp_asmem_breg_creg.hpp:16
static constexpr index_t kBlockSize
Definition: block_wp_asmem_breg_creg.hpp:38
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition: block_wp_asmem_breg_creg.hpp:32
static constexpr index_t NWarp
Definition: block_wp_asmem_breg_creg.hpp:44
static constexpr index_t NIterPerWarp
Definition: block_wp_asmem_breg_creg.hpp:47
static constexpr index_t m_preload
Definition: block_wp_asmem_breg_creg.hpp:55
typename WarpGemm::AWarpTensor AWarpTensor
Definition: block_wp_asmem_breg_creg.hpp:59
static constexpr index_t MPerBlockPerIter
Definition: block_wp_asmem_breg_creg.hpp:50
static constexpr index_t NPerBlock
Definition: block_wp_asmem_breg_creg.hpp:35
static constexpr index_t MWarp
Definition: block_wp_asmem_breg_creg.hpp:43
static constexpr auto I0
Definition: block_wp_asmem_breg_creg.hpp:24
static constexpr index_t MPerBlock
Definition: block_wp_asmem_breg_creg.hpp:34
static constexpr CK_TILE_DEVICE auto MakeCBlockTile()
Definition: block_wp_asmem_breg_creg.hpp:113
remove_cvref_t< BlockPolicy_ > BlockPolicy
Definition: block_wp_asmem_breg_creg.hpp:18
remove_cvref_t< typename Problem::ADataType > ADataType
Definition: block_wp_asmem_breg_creg.hpp:19
CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows &a_load_windows)
Definition: block_wp_asmem_breg_creg.hpp:101
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ALoadWindows &a_load_windows, BFlatBlockTensor &b_block_tensor, const BFlatDistribution &)
Definition: block_wp_asmem_breg_creg.hpp:137
static constexpr auto idxM
Definition: block_wp_asmem_breg_creg.hpp:27
static constexpr index_t KPerBlockPerIter
Definition: block_wp_asmem_breg_creg.hpp:51
CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow &a_block_window) const
Definition: block_wp_asmem_breg_creg.hpp:78
static constexpr index_t MIterPerWarp
Definition: block_wp_asmem_breg_creg.hpp:46
static constexpr index_t KIterPerWarp
Definition: block_wp_asmem_breg_creg.hpp:48
static constexpr auto I2
Definition: block_wp_asmem_breg_creg.hpp:26
remove_cvref_t< typename Problem::BDataType > BDataType
Definition: block_wp_asmem_breg_creg.hpp:20
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition: block_wp_asmem_breg_creg.hpp:22
static constexpr auto idxN
Definition: block_wp_asmem_breg_creg.hpp:28
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition: block_wp_asmem_breg_creg.hpp:30
static constexpr auto I1
Definition: block_wp_asmem_breg_creg.hpp:25
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition: block_wp_asmem_breg_creg.hpp:31
static constexpr index_t KPerBlock
Definition: block_wp_asmem_breg_creg.hpp:36
statically_indexed_array< AWarpTensor, m_preload > preloaded_a_warp_tensor
Definition: block_wp_asmem_breg_creg.hpp:60
static constexpr index_t DsReadPreload
Definition: block_wp_asmem_breg_creg.hpp:53
static constexpr CK_TILE_DEVICE auto MakeABlockDistributionEncode()
Definition: block_wp_asmem_breg_creg.hpp:62
remove_cvref_t< typename Problem::CDataType > CDataType
Definition: block_wp_asmem_breg_creg.hpp:21
static constexpr auto idxK
Definition: block_wp_asmem_breg_creg.hpp:29
remove_cvref_t< Problem_ > Problem
Definition: block_wp_asmem_breg_creg.hpp:17
remove_cvref_t< decltype(config.template at< 0 >())> WarpGemm
Definition: block_wp_asmem_breg_creg.hpp:41
static constexpr auto config
Definition: block_wp_asmem_breg_creg.hpp:40
Definition: integral_constant.hpp:13
Definition: sequence.hpp:363
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192