Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M0_M1_K1,
24  typename BGridDesc_K0_N0_N1_K1,
25  typename CGridDesc_M0_M10_M11_N0_N10_N11,
26  typename Block2CTileMap,
27  bool HasMainKBlockLoop,
28  bool HasDoubleTailKBlockLoop>
29 __global__ void
30 #if CK_USE_LAUNCH_BOUNDS
31  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
32 #endif
33  kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
34  const FloatAB* __restrict__ p_b_grid,
35  FloatC* __restrict__ p_c_grid,
36  const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
37  const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
38  const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
39  const Block2CTileMap block_2_ctile_map)
40 {
41  constexpr index_t shared_block_size =
42  GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
43 
44  __shared__ FloatAB p_shared_block[shared_block_size];
45 
46  GridwiseGemm::Run(p_a_grid,
47  p_b_grid,
48  p_c_grid,
49  p_shared_block,
50  a_grid_desc_k0_m0_m1_k1,
51  b_grid_desc_k0_n0_n1_k1,
52  c_grid_desc_m0_m10_m11_n0_n10_n11,
53  block_2_ctile_map,
54  integral_constant<bool, HasMainKBlockLoop>{},
55  integral_constant<bool, HasDoubleTailKBlockLoop>{});
56 }
57 
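The kernel above only sizes the LDS scratch and forwards its arguments to GridwiseGemm::Run; the HasMainKBlockLoop / HasDoubleTailKBlockLoop specialization is chosen on the host. The following is a minimal, hypothetical host-side dispatch sketch, assuming HIP compilation, <type_traits>, and already-constructed descriptors; grid_size, BlockSize, stream, and the pointer/descriptor names are illustrative, not taken from this file:

    const bool has_main_loop   = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
    const bool has_double_tail = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

    const auto launch = [&](auto has_main, auto has_tail) {
        constexpr bool kHasMain = decltype(has_main)::value;
        constexpr bool kHasTail = decltype(has_tail)::value;
        const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
                                                FloatAB,
                                                FloatC,
                                                AGridDesc_K0_M0_M1_K1,
                                                BGridDesc_K0_N0_N1_K1,
                                                CGridDesc_M0_M10_M11_N0_N10_N11,
                                                Block2CTileMap,
                                                kHasMain,
                                                kHasTail>;
        kernel<<<dim3(grid_size), dim3(BlockSize), 0, stream>>>(
            p_a_grid, p_b_grid, p_c_grid,
            a_grid_desc_k0_m0_m1_k1, b_grid_desc_k0_n0_n1_k1,
            c_grid_desc_m0_m10_m11_n0_n10_n11, block_2_ctile_map);
    };

    // One compile-time specialization per (main-loop, tail) combination.
    if(has_main_loop && has_double_tail)       launch(std::true_type{},  std::true_type{});
    else if(has_main_loop && !has_double_tail) launch(std::true_type{},  std::false_type{});
    else if(!has_main_loop && has_double_tail) launch(std::false_type{}, std::true_type{});
    else                                       launch(std::false_type{}, std::false_type{});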
58 template <index_t BlockSize,
59  typename FloatAB,
60  typename FloatAcc,
61  typename FloatC,
62  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
63  typename AGridDesc_K0_M_K1,
64  typename BGridDesc_K0_N_K1,
65  typename CGridDesc_M_N,
66  index_t MPerBlock,
67  index_t NPerBlock,
68  index_t K0PerBlock,
69  index_t K1Value,
70  index_t M1PerThreadM111,
71  index_t N1PerThreadN111,
72  index_t KPerThread,
73  typename M11N11ThreadClusterM110Xs,
74  typename M11N11ThreadClusterN110Xs,
75  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
76  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
77  typename ABlockTransferThreadClusterArrangeOrder,
78  typename ABlockTransferSrcAccessOrder,
79  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
80  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
81  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
82  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
83  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
84  typename BBlockTransferThreadClusterArrangeOrder,
85  typename BBlockTransferSrcAccessOrder,
86  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
87  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
88  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
89  typename CThreadTransferSrcDstAccessOrder,
90  index_t CThreadTransferSrcDstVectorDim,
91  index_t CThreadTransferDstScalarPerVector>
93 {
94  static constexpr auto I0 = Number<0>{};
95  static constexpr auto I1 = Number<1>{};
96  static constexpr auto I2 = Number<2>{};
97  static constexpr auto I3 = Number<3>{};
98 
99  // K1 should be Number<...>
100  static constexpr auto K1 = Number<K1Value>{};
101 
102  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
103  {
104  // TODO: change this. I think it needs multi-dimensional alignment
105  constexpr auto max_lds_align = K1;
106 
107  // TODO: check alignment
108  // A matrix in LDS memory, dst of blockwise copy
109  constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
110  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
111 
112  // TODO: check alignment
113  // B matrix in LDS memory, dst of blockwise copy
114  constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
115  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
116 
117  // TODO: check alignment
118  // LDS allocation for A and B: be careful of alignment
119  constexpr auto a_block_aligned_space_size =
120  math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
121 
122  constexpr auto b_block_aligned_space_size =
123  math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
124 
125  return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
126  }
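For intuition, the returned byte count is simply the double-buffered A and B tile footprint. A small worked check with assumed, not file-provided, tile parameters:

    // Assumed illustration: K0PerBlock = 8, MPerBlock = NPerBlock = 128, K1 = 2,
    // FloatAB = 16-bit float. Each tile size is already a multiple of K1, so
    // integer_least_multiple() leaves the element counts unchanged here.
    constexpr int a_tile_elems = 8 * 128 * 2;   // 2048 elements
    constexpr int b_tile_elems = 8 * 128 * 2;   // 2048 elements
    constexpr int lds_bytes    = 2 * (a_tile_elems + b_tile_elems) * 2 /*sizeof(fp16)*/;
    static_assert(lds_bytes == 16384, "16 KiB of LDS for the double-buffered tiles");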
127 
128  __host__ __device__ static constexpr bool
129  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
130  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
131  const CGridDesc_M_N& c_grid_desc_m_n)
132  {
133  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
134  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
135  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
136 
137  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
138 
139  return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
140  K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
141  K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
142  K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
143  (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
144  }
145 
146  __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
147  {
148  const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
149 
150  return grid_size;
151  }
152 
153  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
154  {
155  const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
156 
157  return has_main_k_block_loop;
158  }
159 
160  __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
161  {
162  const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
163 
164  return has_double_tail_k_block_loop;
165  }
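These two predicates select the kernel specialization: the main loop consumes two K0PerBlock slices per double-buffered iteration, and the tail handles the remaining one or two slices. A worked example with assumed values:

    // Assumed illustration: K0 = 32, K0PerBlock = 8 -> four K slices,
    // processed two per double-buffered iteration.
    constexpr int K0 = 32, K0PerBlock = 8;
    static_assert((K0 + K0PerBlock) / (2 * K0PerBlock) > 1, "main loop needed");     // 40/16 = 2 > 1
    static_assert((K0 / K0PerBlock) % 2 == 0, "even slice count -> two-slice tail"); // 4 % 2 == 0
    // With K0 = 24 (three slices) the second predicate is false, so the
    // single-slice tail path is taken instead.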
166 
167  __host__ __device__ static constexpr auto
168  MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
169  {
170  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
171  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
172 
173  const auto M1 = Number<MPerBlock>{};
174  const auto M0 = M / M1;
175 
176  const auto a_grid_desc_k0_m0_m1_k1 =
177  transform_tensor_descriptor(a_grid_desc_k0_m_k1,
178  make_tuple(make_pass_through_transform(K0),
179  make_unmerge_transform(make_tuple(M0, M1)),
180  make_pass_through_transform(K1)),
181  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
182  make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
183 
184  return a_grid_desc_k0_m0_m1_k1;
185  }
186 
187  __host__ __device__ static constexpr auto
188  MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
189  {
190  const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
191  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
192 
193  const auto N1 = Number<NPerBlock>{};
194  const auto N0 = N / N1;
195 
196  const auto b_grid_desc_k0_n0_n1_k1 =
197  transform_tensor_descriptor(b_grid_desc_k0_n_k1,
198  make_tuple(make_pass_through_transform(K0),
199  make_unmerge_transform(make_tuple(N0, N1)),
200  make_pass_through_transform(K1)),
201  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
202  make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
203 
204  return b_grid_desc_k0_n0_n1_k1;
205  }
206 
207  __host__ __device__ static constexpr auto
208  MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
209  {
210  const auto M = c_grid_desc_m_n.GetLength(I0);
211  const auto N = c_grid_desc_m_n.GetLength(I1);
212 
213  constexpr auto M1 = Number<MPerBlock>{};
214  constexpr auto N1 = Number<NPerBlock>{};
215 
216  const auto M0 = M / M1;
217  const auto N0 = N / N1;
218 
219  constexpr auto M11 =
220  Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
221  M1PerThreadM111>{};
222  constexpr auto N11 =
223  Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
224  N1PerThreadN111>{};
225 
226  constexpr auto M10 = M1 / M11;
227  constexpr auto N10 = N1 / N11;
228 
229  const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
230  c_grid_desc_m_n,
231  make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
232  make_unmerge_transform(make_tuple(N0, N10, N11))),
233  make_tuple(Sequence<0>{}, Sequence<1>{}),
234  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
235 
236  return c_grid_desc_m0_m10_m11_n0_n10_n11;
237  }
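The hierarchy M = M0 * M1 = M0 * (M10 * M11), and likewise for N, maps block tiles, intra-block repeats, and per-thread work onto the C grid. A concrete decomposition with assumed numbers:

    // Assumed illustration: M = 1024, MPerBlock = 128, thread-cluster product of 8
    // along M, M1PerThreadM111 = 4.
    constexpr int M   = 1024, M1 = 128; // M1 = MPerBlock
    constexpr int M0  = M / M1;         // 8 block tiles along M
    constexpr int M11 = 8 * 4;          // threads per M1 * values per thread = 32
    constexpr int M10 = M1 / M11;       // 4 repeats inside one block tile
    static_assert(M0 * M10 * M11 == M, "the (M0, M10, M11) split covers M exactly");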
238 
239  // return block_id to C matrix tile idx (m0, n0) mapping
240  __host__ __device__ static constexpr auto
241  MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
242  {
244  c_grid_desc_m_n);
245  }
246 
247  using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
248  using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
249  using CGridDesc_M0_M10_M11_N0_N10_N11 =
250  decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
251  using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
252 
253  template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
254  __device__ static void
255  Run(const FloatAB* __restrict__ p_a_grid,
256  const FloatAB* __restrict__ p_b_grid,
257  FloatC* __restrict__ p_c_grid,
258  FloatAB* __restrict__ p_shared_block,
259  const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
260  const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
261  const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
262  const Block2CTileMap& block_2_ctile_map,
263  integral_constant<bool, HasMainKBlockLoop>,
264  integral_constant<bool, HasDoubleTailKBlockLoop>)
265  {
266  const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
267  p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
268  const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
269  p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
270  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
271  p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
272 
273  // divide block work by [M, N]
274  const auto c_m0_n0_block_cluster_idx =
275  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
276 
277  // HACK: this forces index data into SGPR
278  const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
279  const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
280 
281  if(!block_2_ctile_map.ValidCTileIndex(
282  make_tuple(im0, in0),
283  make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
284  c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
285  {
286  return;
287  }
288 
289  // TODO: change this. I think it needs multi-dimensional alignment
290  constexpr auto max_lds_align = K1;
291 
292  // TODO: check alignment
293  // A matrix in LDS memory, dst of blockwise copy
294  // be careful of LDS alignment
295  constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
296  make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
297 
298  // TODO: check alignment
299  // B matrix in LDS memory, dst of blockwise copy
300  // be careful of LDS alignment
301  constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
302  make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
303 
304  // TODO: check alignment
305  // A matrix in LDS memory, for blockwise GEMM
306  constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
307  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
308 
309  // TODO: check alignment
310  // B matrix in LDS memory, for blockwise GEMM
311  constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
312  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
313 
314  static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
315  a_k0_m_k1_block_desc.GetElementSpaceSize() &&
316  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
317  b_k0_n_k1_block_desc.GetElementSpaceSize() &&
318  "wrong!");
319 
320  // A matrix blockwise copy
321  auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
322  BlockSize,
324  Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
325  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
326  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
327  ABlockTransferThreadClusterArrangeOrder,
328  FloatAB,
329  FloatAB,
330  remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
331  decltype(a_block_desc_k0_m0_m1_k1),
332  ABlockTransferSrcAccessOrder,
334  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
335  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
336  ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
337  Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
338  false,
339  true>(a_grid_desc_k0_m0_m1_k1,
340  make_multi_index(0, im0, 0, 0),
341  a_block_desc_k0_m0_m1_k1,
342  make_multi_index(0, 0, 0, 0));
343 
344  // B matrix blockwise copy
345  auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
346  BlockSize,
348  Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
349  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
350  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
351  BBlockTransferThreadClusterArrangeOrder,
352  FloatAB,
353  FloatAB,
354  remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
355  decltype(b_block_desc_k0_n0_n1_k1),
356  BBlockTransferSrcAccessOrder,
358  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
359  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
360  BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
361  Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
362  false,
363  true>(b_grid_desc_k0_n0_n1_k1,
364  make_multi_index(0, in0, 0, 0),
365  b_block_desc_k0_n0_n1_k1,
366  make_multi_index(0, 0, 0, 0));
367 
368  // GEMM definition
369  // c_mtx += transpose(a_mtx) * b_mtx
370  // a_mtx[K0PerBlock, MPerBlock] is in LDS
371  // b_mtx[K0PerBlock, NPerBlock] is in LDS
372  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
373  // register
374  const auto blockwise_gemm =
376  BlockSize,
377  FloatAB,
378  FloatAB,
379  FloatAcc,
380  decltype(a_k0_m_k1_block_desc),
381  decltype(b_k0_n_k1_block_desc),
382  M1PerThreadM111,
383  N1PerThreadN111,
384  KPerThread,
385  M11N11ThreadClusterM110Xs,
386  M11N11ThreadClusterN110Xs,
387  M1PerThreadM111,
388  N1PerThreadN111>{};
389 
390  constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
391  decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
392 
393  constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
394  sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
395 
396  // LDS allocation for A and B: be careful of alignment
397  constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
398  a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
399 
400  constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
401  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
402 
403  FloatAB* p_a_block_double = p_shared_block;
404  FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
405 
406  // register allocation for output
407  auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
408  c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
409 
410  // Initialize C
411  c_thread_buf.Clear();
412 
413  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
414  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
415 
416  auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
417  p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
418  auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
419  p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
420 
421  auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
422  p_a_block_double + a_block_aligned_space_size,
423  a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
424  auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
425  p_b_block_double + b_block_aligned_space_size,
426  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
427 
428  // LDS double buffer: preload data into LDS
429  {
430  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
431  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
432 
433  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
434  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
435  }
436 
437  if constexpr(HasMainKBlockLoop)
438  {
439  const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
440 
441  index_t k_block_data_begin = 0;
442 
443  // LDS double buffer: main body
444  // use Do-While loop instead of For loop to simplify control flow
445  do
446  {
447  // even iteration
448  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
449  a_block_slice_copy_step);
450  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
451  b_block_slice_copy_step);
452 
453  // LDS double buffer: load next data from device mem
454  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
455  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
456 
457  block_sync_lds();
458 
459  // LDS double buffer: GEMM on current data
460  blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
461  a_block_even_buf,
462  b_block_even_buf,
463  c_thread_buf);
464 
465  // LDS double buffer: store next data to LDS
466  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
467  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
468 
469  // odd iteration
470  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
471  a_block_slice_copy_step);
472  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
473  b_block_slice_copy_step);
474 
475  // LDS double buffer: load next data from device mem
476  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
477  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
478 
479  block_sync_lds();
480 
481  // LDS double buffer: GEMM on current data
482  blockwise_gemm.Run(
483  c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
484 
485  // LDS double buffer: store next data to LDS
486  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
487  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
488 
489  k_block_data_begin += 2 * K0PerBlock;
490  } while(k_block_data_begin < K0 - 2 * K0PerBlock);
491  }
492 
493  // LDS double buffer: tail
494  if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
495  {
496  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
497  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
498 
499  block_sync_lds();
500 
501  // LDS double buffer: load last data from device mem
502  a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
503  b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
504 
505  // LDS double buffer: GEMM on 2nd-last data
506  blockwise_gemm.Run(
507  c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
508 
509  // LDS double buffer: store last data to LDS
510  a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
511  b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
512 
513  block_sync_lds();
514 
515  // LDS double buffer: GEMM on last data
516  blockwise_gemm.Run(
517  c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
518  }
519  else // if there is 1 iteration left
520  {
521  __syncthreads();
522 
523  // LDS double buffer: GEMM on last data
524  blockwise_gemm.Run(
525  c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
526  }
527 
528  // output: register to global memory
529  {
530  constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
531  make_naive_tensor_descriptor_packed(
532  make_tuple(I1,
533  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
534  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
535  I1,
536  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
537  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
538 
539  const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
540  blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
541  get_thread_local_1d_id());
542 
544  FloatAcc,
545  FloatC,
546  decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
547  decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
549  Sequence<1,
550  c_m10_m11_n10_n11_thread_tensor_lengths[I0],
551  c_m10_m11_n10_n11_thread_tensor_lengths[I1],
552  1,
553  c_m10_m11_n10_n11_thread_tensor_lengths[I2],
554  c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
555  CThreadTransferSrcDstAccessOrder,
556  CThreadTransferSrcDstVectorDim,
557  CThreadTransferDstScalarPerVector,
558  CGlobalMemoryDataOperation,
559  1,
560  true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
561  make_multi_index(im0,
562  c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
563  c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
564  in0,
565  c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
566  c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
568  .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
569  make_tuple(I0, I0, I0, I0, I0, I0),
570  c_thread_buf,
571  c_grid_desc_m0_m10_m11_n0_n10_n11,
572  c_grid_buf);
573  }
574  }
575 };
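A typical, hypothetical host-side preparation of the arguments expected by kernel_gemm_dl_v1r3 uses the helpers above; the raw K0 x M x K1, K0 x N x K1, and M x N descriptors are assumed to come from the surrounding CK device-op code:

    if(!GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n))
    {
        // the problem does not match this tile configuration
        return false;
    }

    const auto a_grid_desc_k0_m0_m1_k1 =
        GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
    const auto b_grid_desc_k0_n0_n1_k1 =
        GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
    const auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
        GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
    const auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
    const index_t grid_size      = GridwiseGemm::CalculateGridSize(M, N);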
576 
577 template <index_t BlockSize,
578  typename FloatAB,
579  typename FloatAcc,
580  typename FloatC,
581  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
582  typename AGridDesc_B_K0_M_K1,
583  typename BGridDesc_B_K0_N_K1,
584  typename CGridDesc_M_N,
585  index_t MPerBlock,
586  index_t NPerBlock,
587  index_t K0PerBlock,
588  index_t K1Value,
589  index_t M1PerThreadM111,
590  index_t N1PerThreadN111,
591  index_t KPerThread,
592  typename M11N11ThreadClusterM110Xs,
593  typename M11N11ThreadClusterN110Xs,
594  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
595  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
596  typename ABlockTransferThreadClusterArrangeOrder,
597  typename ABlockTransferSrcAccessOrder,
598  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
599  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
600  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
601  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
602  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
603  typename BBlockTransferThreadClusterArrangeOrder,
604  typename BBlockTransferSrcAccessOrder,
605  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
606  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
607  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
608  typename CThreadTransferSrcDstAccessOrder,
609  index_t CThreadTransferSrcDstVectorDim,
610  index_t CThreadTransferDstScalarPerVector>
612 {
613  static constexpr auto I0 = Number<0>{};
614  static constexpr auto I1 = Number<1>{};
615  static constexpr auto I2 = Number<2>{};
616  static constexpr auto I3 = Number<3>{};
617 
618  // K1 should be Number<...>
619  static constexpr auto K1 = Number<K1Value>{};
620 
621  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
622  {
623  // TODO: change this. I think it needs multi-dimensional alignment
624  constexpr auto max_lds_align = K1;
625 
626  // TODO: check alignment
627  // A matrix in LDS memory, dst of blockwise copy
628  constexpr auto a_block_desc_b_k0_m_k1 = make_naive_tensor_descriptor_aligned(
629  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
630 
631  // TODO: check alignment
632  // B matrix in LDS memory, dst of blockwise copy
633  constexpr auto b_block_desc_b_k0_n_k1 = make_naive_tensor_descriptor_aligned(
634  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
635 
636  // TODO: check alignment
637  // LDS allocation for A and B: be careful of alignment
638  constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
639  a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align);
640 
641  constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
642  b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align);
643 
644  return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
645  }
646 
647  __host__ __device__ static constexpr bool
648  CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1,
649  const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
650  const CGridDesc_M_N& c_grid_desc_m_n)
651  {
652  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
653 
654  if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
655  b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
656  c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
657  {
658  return false;
659  }
660 
661  const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
662  const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
663  const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
664  const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
665 
666  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
667 
668  return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
669  K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) &&
670  K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) &&
671  K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) &&
672  KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) &&
673  (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
674  }
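The extra guard at the top of this CheckValidity rejects tensors whose element space exceeds 2 GB, presumably because offsets in the kernel are computed with the 32-bit index_t. A quick illustrative check with an assumed problem size:

    constexpr long long TwoGB   = 1LL << 31;          // 2147483648 bytes
    constexpr long long a_elems = 4LL * 1024 * 8192;  // hypothetical A element count
    static_assert(a_elems * 2 /*sizeof(fp16)*/ <= TwoGB, // 67108864 bytes, well inside the limit
                  "A stays addressable with 32-bit offsets");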
675 
676  __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
677  {
678  const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
679 
680  return grid_size;
681  }
682 
683  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
684  {
685  const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
686 
687  return has_main_k_block_loop;
688  }
689 
690  __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
691  {
692  const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
693 
694  return has_double_tail_k_block_loop;
695  }
696 
697  __host__ __device__ static constexpr auto
698  MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1)
699  {
700  const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
701  const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
702  const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
703 
704  const auto M1 = Number<MPerBlock>{};
705  const auto M0 = M / M1;
706 
707  const auto a_grid_desc_b_k0_m0_m1_k1 = transform_tensor_descriptor(
708  a_grid_desc_b_k0_m_k1,
709  make_tuple(make_pass_through_transform(KBatch),
710  make_pass_through_transform(K0),
711  make_unmerge_transform(make_tuple(M0, M1)),
712  make_pass_through_transform(K1)),
713  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
714  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
715 
716  return a_grid_desc_b_k0_m0_m1_k1;
717  }
718 
719  __host__ __device__ static constexpr auto
720  MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1)
721  {
722  const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0);
723  const auto K0 = b_grid_desc_b_k0_n_k1.GetLength(I1);
724  const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
725 
726  const auto N1 = Number<NPerBlock>{};
727  const auto N0 = N / N1;
728 
729  const auto b_grid_desc_b_k0_n0_n1_k1 = transform_tensor_descriptor(
730  b_grid_desc_b_k0_n_k1,
731  make_tuple(make_pass_through_transform(KBatch),
732  make_pass_through_transform(K0),
733  make_unmerge_transform(make_tuple(N0, N1)),
734  make_pass_through_transform(K1)),
735  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
736  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
737 
738  return b_grid_desc_b_k0_n0_n1_k1;
739  }
740 
741  __host__ __device__ static constexpr auto
742  MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
743  {
744  const auto M = c_grid_desc_m_n.GetLength(I0);
745  const auto N = c_grid_desc_m_n.GetLength(I1);
746 
747  constexpr auto M1 = Number<MPerBlock>{};
748  constexpr auto N1 = Number<NPerBlock>{};
749 
750  const auto M0 = M / M1;
751  const auto N0 = N / N1;
752 
753  constexpr auto M11 =
754  Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
755  M1PerThreadM111>{};
756  constexpr auto N11 =
757  Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
758  N1PerThreadN111>{};
759 
760  constexpr auto M10 = M1 / M11;
761  constexpr auto N10 = N1 / N11;
762 
763  const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
764  c_grid_desc_m_n,
765  make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
766  make_unmerge_transform(make_tuple(N0, N10, N11))),
767  make_tuple(Sequence<0>{}, Sequence<1>{}),
768  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
769 
770  return c_grid_desc_m0_m10_m11_n0_n10_n11;
771  }
772 
773  // return block_id to C matrix tile idx (m0, n0) mapping
774  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
775  const CGridDesc_M_N& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
776  {
778  c_m_n_grid_desc, M01, N01, KBatch);
779  }
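The KBatch argument is what distinguishes this variant from the first struct: the block-to-tile mapping additionally splits the grid along K, and each k_batch_id in Run reads its own slice of the batched A/B descriptors. A hypothetical host-side call, with illustrative M01, N01, and KBatch values:

    const auto c_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, /*M01*/ 1, /*N01*/ 1, /*KBatch*/ 4);
    // CheckValidity() above requires the A and B descriptors to expose the same
    // KBatch extent as their leading (I0) dimension.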
780 
781  using AGridDesc_B_K0_M0_M1_K1 =
782  decltype(MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
783  using BGridDesc_B_K0_N0_N1_K1 =
784  decltype(MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
785  using CGridDesc_M0_M10_M11_N0_N10_N11 =
786  decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
787  using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
788 
789  template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
790  __device__ static void
791  Run(const FloatAB* __restrict__ p_a_grid,
792  const FloatAB* __restrict__ p_b_grid,
793  FloatC* __restrict__ p_c_grid,
794  FloatAB* __restrict__ p_shared_block,
795  const AGridDesc_B_K0_M0_M1_K1& a_grid_desc_b_k0_m0_m1_k1,
796  const BGridDesc_B_K0_N0_N1_K1& b_grid_desc_b_k0_n0_n1_k1,
797  const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
798  const CBlockClusterAdaptor& c_block_cluster_adaptor,
799  integral_constant<bool, HasMainKBlockLoop>,
800  integral_constant<bool, HasDoubleTailKBlockLoop>)
801  {
802  const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
803  p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize());
804  const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
805  p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize());
806  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
807  p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
808 
809  // divide block work by [M, N]
810  const auto block_work_idx =
811  c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
812 
813  const index_t k_batch_id = block_work_idx[I0];
814 
815  if(!c_block_cluster_adaptor.ValidCTileIndex(
816  make_tuple(block_work_idx[I1], block_work_idx[I2]),
817  make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
818  c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
819  {
820  return;
821  }
822 
823  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
824  const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
825 
826  const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
827 
828  // TODO: change this. I think it needs multi-dimensional alignment
829  constexpr auto max_lds_align = K1;
830 
831  // TODO: check alignment
832  // A matrix in LDS memory, dst of blockwise copy
833  // be careful of LDS alignment
834  constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
835  make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
836 
837  // TODO: check alignment
838  // B matrix in LDS memory, dst of blockwise copy
839  // be careful of LDS alignment
840  constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
841  make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
842 
843  // TODO: check alignment
844  // A matrix in LDS memory, dst of blockwise copy
845  // be careful of LDS alignment
846  constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
847  make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
848 
849  // TODO: check alignment
850  // B matrix in LDS memory, dst of blockwise copy
851  // be careful of LDS alignment
852  constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
853  make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
854 
855  // TODO: check alignment
856  // A matrix in LDS memory, for blockwise GEMM
857  constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
858  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
859 
860  // TODO: check alignment
861  // B matrix in LDS memory, for blockwise GEMM
862  constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
863  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
864 
865  static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
866  a_k0_m_k1_block_desc.GetElementSpaceSize() &&
867  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
868  b_k0_n_k1_block_desc.GetElementSpaceSize() &&
869  "wrong!");
870 
871  // A matrix blockwise copy
872  auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
873  BlockSize,
875  Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>,
876  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
877  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
878  ABlockTransferThreadClusterArrangeOrder,
879  FloatAB,
880  FloatAB,
881  remove_reference_t<decltype(a_grid_desc_b_k0_m0_m1_k1)>,
882  decltype(a_block_desc_b_k0_m0_m1_k1),
883  ABlockTransferSrcAccessOrder,
885  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
886  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
887  ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
888  Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
889  false,
890  true>(a_grid_desc_b_k0_m0_m1_k1,
891  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0, 0),
892  a_block_desc_b_k0_m0_m1_k1,
893  make_multi_index(0, 0, 0, 0, 0));
894 
895  // B matrix blockwise copy
896  auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
897  BlockSize,
899  Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>,
900  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
901  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
902  BBlockTransferThreadClusterArrangeOrder,
903  FloatAB,
904  FloatAB,
905  remove_reference_t<decltype(b_grid_desc_b_k0_n0_n1_k1)>,
906  decltype(b_block_desc_b_k0_n0_n1_k1),
907  BBlockTransferSrcAccessOrder,
909  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
910  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
911  BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
912  Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
913  false,
914  true>(b_grid_desc_b_k0_n0_n1_k1,
915  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0, 0),
916  b_block_desc_b_k0_n0_n1_k1,
917  make_multi_index(0, 0, 0, 0, 0));
918 
919  // GEMM definition
920  // c_mtx += transpose(a_mtx) * b_mtx
921  // a_mtx[K0PerBlock, MPerBlock] is in LDS
922  // b_mtx[K0PerBlock, NPerBlock] is in LDS
923  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
924  // register
925  const auto blockwise_gemm =
927  BlockSize,
928  FloatAB,
929  FloatAB,
930  FloatAcc,
931  decltype(a_k0_m_k1_block_desc),
932  decltype(b_k0_n_k1_block_desc),
933  M1PerThreadM111,
934  N1PerThreadN111,
935  KPerThread,
936  M11N11ThreadClusterM110Xs,
937  M11N11ThreadClusterN110Xs,
938  M1PerThreadM111,
939  N1PerThreadN111>{};
940 
941  constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
942  decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
943 
944  constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
945  sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
946 
947  // LDS allocation for A and B: be careful of alignment
948  constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
949  a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
950 
951  constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
952  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
953 
954  FloatAB* p_a_block_double = p_shared_block;
955  FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
956 
957  // register allocation for output
958  auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
959  c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
960 
961  // Initialize C
962  c_thread_buf.Clear();
963 
964  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
965  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
966 
967  auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
968  p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
969  auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
970  p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
971 
972  auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
973  p_a_block_double + a_block_aligned_space_size,
974  a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
975  auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
976  p_b_block_double + b_block_aligned_space_size,
977  b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
978 
979  // LDS double buffer: preload data into LDS
980  {
981  a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
982  b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
983 
984  a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
985  b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
986  }
987 
988  if constexpr(HasMainKBlockLoop)
989  {
990  const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1);
991 
992  index_t k_block_data_begin = 0;
993 
994  // LDS double buffer: main body
995  // use Do-While loop instead of For loop to simplify control flow
996  do
997  {
998  // even iteration
999  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
1000  a_block_slice_copy_step);
1001  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
1002  b_block_slice_copy_step);
1003 
1004  // LDS double buffer: load next data from device mem
1005  a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
1006  b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
1007 
1008  block_sync_lds();
1009 
1010  // LDS double buffer: GEMM on current data
1011  blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
1012  a_block_even_buf,
1013  b_block_even_buf,
1014  c_thread_buf);
1015 
1016  // LDS double buffer: store next data to LDS
1017  a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
1018  b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
1019 
1020  // odd iteration
1021  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
1022  a_block_slice_copy_step);
1023  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
1024  b_block_slice_copy_step);
1025 
1026  // LDS double buffer: load next data from device mem
1027  a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
1028  b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
1029 
1030  block_sync_lds();
1031 
1032  // LDS double buffer: GEMM on current data
1033  blockwise_gemm.Run(
1034  c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
1035 
1036  // LDS double buffer: store next data to LDS
1037  a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
1038  b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
1039 
1040  k_block_data_begin += 2 * K0PerBlock;
1041  } while(k_block_data_begin < K0 - 2 * K0PerBlock);
1042  }
1043 
1044  // LDS double buffer: tail
1045  if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
1046  {
1047  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
1048  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);
1049 
1050  block_sync_lds();
1051 
1052  // LDS double buffer: load last data from device mem
1053  a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
1054  b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
1055 
1056  // LDS double buffer: GEMM on 2nd-last data
1057  blockwise_gemm.Run(
1058  c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
1059 
1060  // LDS double buffer: store last data to LDS
1061  a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
1062  b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
1063 
1064  block_sync_lds();
1065 
1066  // LDS double buffer: GEMM on last data
1067  blockwise_gemm.Run(
1068  c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
1069  }
1070  else // if there is 1 iteration left
1071  {
1072  __syncthreads();
1073 
1074  // LDS double buffer: GEMM on last data
1075  blockwise_gemm.Run(
1076  c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
1077  }
1078 
1079  // output: register to global memory
1080  {
1081  constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
1082  make_naive_tensor_descriptor_packed(
1083  make_tuple(I1,
1084  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
1085  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
1086  I1,
1087  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
1088  Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
1089 
1090  const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
1091  blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
1092  get_thread_local_1d_id());
1093 
1095  FloatAcc,
1096  FloatC,
1097  decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
1098  decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
1100  Sequence<1,
1101  c_m10_m11_n10_n11_thread_tensor_lengths[I0],
1102  c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1103  1,
1104  c_m10_m11_n10_n11_thread_tensor_lengths[I2],
1105  c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
1106  CThreadTransferSrcDstAccessOrder,
1107  CThreadTransferSrcDstVectorDim,
1108  CThreadTransferDstScalarPerVector,
1109  CGlobalMemoryDataOperation,
1110  1,
1111  true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
1112  make_multi_index(m_block_data_idx_on_grid,
1113  c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
1114  c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
1115  n_block_data_idx_on_grid,
1116  c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
1117  c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
1119  .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
1120  make_tuple(I0, I0, I0, I0, I0, I0),
1121  c_thread_buf,
1122  c_grid_desc_m0_m10_m11_n0_n10_n11,
1123  c_grid_buf);
1124  }
1125  }
1126 };
1127 
1128 } // namespace ck