Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp Source File

gridwise_gemm_xdlops_v2r4.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck {
18 
19 template <typename GridwiseGemm,
20  typename FloatAB,
21  typename FloatC,
22  typename ABK0MK1GridDesc,
23  typename BBK0NK1GridDesc,
24  typename CM0N0M1N1M2M3M4N2GridDesc,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename CElementwiseOperation,
28  typename CBlockClusterAdaptor,
29  bool HasMainKBlockLoop>
30 __global__ void
31 #if CK_USE_LAUNCH_BOUNDS
32  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
33 #endif
34  kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
35  const FloatAB* __restrict__ p_b_grid,
36  FloatC* __restrict__ p_c_grid,
37  const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
38  const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
39  const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
40  const AElementwiseOperation a_element_op,
41  const BElementwiseOperation b_element_op,
42  const CElementwiseOperation c_element_op,
43  const CBlockClusterAdaptor c_block_cluster_adaptor)
44 {
45 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
46  defined(__gfx94__))
47  constexpr index_t shared_block_size =
48  GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
49 
50  __shared__ FloatAB p_shared_block[shared_block_size];
51 
52  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
53  p_b_grid,
54  p_c_grid,
55  p_shared_block,
56  a_b_k0_m_k1_grid_desc,
57  b_b_k0_n_k1_grid_desc,
58  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
59  a_element_op,
60  b_element_op,
61  c_element_op,
62  c_block_cluster_adaptor);
63 #else
64  ignore = p_a_grid;
65  ignore = p_b_grid;
66  ignore = p_c_grid;
67  ignore = a_b_k0_m_k1_grid_desc;
68  ignore = b_b_k0_n_k1_grid_desc;
69  ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc;
70  ignore = a_element_op;
71  ignore = b_element_op;
72  ignore = c_element_op;
73  ignore = c_block_cluster_adaptor;
74 #endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
75 }
76 
77 template <index_t BlockSize,
78  typename FloatAB,
79  typename FloatAcc,
80  typename FloatC,
81  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
82  typename ABK0MK1GridDesc,
83  typename BBK0NK1GridDesc,
84  typename CMNGridDesc,
85  typename AElementwiseOperation,
86  typename BElementwiseOperation,
87  typename CElementwiseOperation,
88  index_t MPerBlock,
89  index_t NPerBlock,
90  index_t K0PerBlock,
91  index_t MPerXDL,
92  index_t NPerXDL,
93  index_t K1Value,
94  index_t MRepeat,
95  index_t NRepeat,
96  typename ABlockTransferThreadClusterLengths_K0_M_K1,
97  typename ABlockTransferThreadClusterArrangeOrder,
98  typename ABlockTransferSrcAccessOrder,
99  index_t ABlockTransferSrcVectorDim,
100  index_t ABlockTransferSrcScalarPerVector,
101  index_t ABlockTransferDstScalarPerVector_K1,
102  bool AThreadTransferSrcResetCoordinateAfterRun,
103  bool ABlockLdsExtraM,
104  typename BBlockTransferThreadClusterLengths_K0_N_K1,
105  typename BBlockTransferThreadClusterArrangeOrder,
106  typename BBlockTransferSrcAccessOrder,
107  index_t BBlockTransferSrcVectorDim,
108  index_t BBlockTransferSrcScalarPerVector,
109  index_t BBlockTransferDstScalarPerVector_K1,
110  bool BThreadTransferSrcResetCoordinateAfterRun,
111  bool BBlockLdsExtraN,
112  typename CThreadTransferSrcDstAccessOrder,
113  index_t CThreadTransferSrcDstVectorDim,
114  index_t CThreadTransferDstScalarPerVector>
116 {
117  static constexpr auto I0 = Number<0>{};
118  static constexpr auto I1 = Number<1>{};
119  static constexpr auto I2 = Number<2>{};
120  static constexpr auto I3 = Number<3>{};
121  static constexpr auto I4 = Number<4>{};
122  static constexpr auto I5 = Number<5>{};
123  static constexpr auto I6 = Number<6>{};
124  static constexpr auto I7 = Number<7>{};
125 
126  // K1 should be Number<...>
127  static constexpr auto K1 = Number<K1Value>{};
128 
129  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
130 
131  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
132  {
133  constexpr auto max_lds_align = K1;
134 
135  // A matrix in LDS memory, dst of blockwise copy
136  constexpr auto a_k0_m_k1_block_desc = [&]() {
137  if constexpr(ABlockLdsExtraM)
138  {
142  }
143  else
144  {
145  return make_naive_tensor_descriptor_aligned(
146  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
147  }
148  }();
149 
150  // B matrix in LDS memory, dst of blockwise copy
151  constexpr auto b_k0_n_k1_block_desc = [&]() {
152  if constexpr(BBlockLdsExtraN)
153  {
157  }
158  else
159  {
160  return make_naive_tensor_descriptor_aligned(
161  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
162  }
163  }();
164 
165  // LDS allocation for A and B: be careful of alignment
166  constexpr auto a_block_space_size =
167  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
168 
169  constexpr auto b_block_space_size =
170  math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
171 
172  return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
173  }
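 // Worked example (hypothetical tuning parameters, not values taken from this file):
 // with K0PerBlock = 4, MPerBlock = NPerBlock = 128, K1 = 8, FloatAB = half (2 bytes)
 // and no extra LDS padding, both block descriptors are packed, so
 //   a_block_space_size = 4 * 128 * 8 = 4096 elements
 //   b_block_space_size = 4 * 128 * 8 = 4096 elements
 // and GetSharedMemoryNumberOfByte() returns (4096 + 4096) * 2 = 16384 bytes of LDS.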
174 
175  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
176  template <typename Block2CTileMap>
177  __host__ __device__ static constexpr bool
178  CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
179  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
180  const CMNGridDesc& c_m_n_grid_desc,
181  const Block2CTileMap& block_2_ctile_map)
182  {
183  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
184  "wrong! K1 need to be known at compile-time");
185 
186  static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
187  (NPerBlock % (NRepeat * NPerXDL)) == 0,
188  "Invalid tuning param!");
189 
190  const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
191  const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
192  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
193  const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
194 
195  if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
196  K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
197  K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
198  K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
199  KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
200  return false;
201 
202  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
203  return false;
204 
205  if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
206  {
207  return false;
208  }
209 
210  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
211  return true;
212  }
213 
214  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
215  {
216  const bool has_main_k0_block_loop = K0 > K0PerBlock;
217 
218  return has_main_k0_block_loop;
219  }
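 // Example (hypothetical values): for K0 = 32 and K0PerBlock = 4 this returns true,
 // and Run() below executes 32/4 - 1 = 7 main-loop iterations plus the tail GEMM,
 // i.e. one blockwise GEMM per K0PerBlock slice of K0.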
220 
221  __host__ __device__ static constexpr auto
222  MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
223  {
224  constexpr auto max_lds_align = K1;
225 
226  // A matrix in LDS memory, dst of blockwise copy
227  constexpr auto a_k0_m_k1_block_desc = [&]() {
228  if constexpr(ABlockLdsExtraM)
229  {
233  }
234  else
235  {
236  return make_naive_tensor_descriptor_aligned(
237  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
238  }
239  }();
240 
241  // B matrix in LDS memory, dst of blockwise copy
242  constexpr auto b_k0_n_k1_block_desc = [&]() {
243  if constexpr(BBlockLdsExtraN)
244  {
248  }
249  else
250  {
251  return make_naive_tensor_descriptor_aligned(
252  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
253  }
254  }();
255 
256  using BlockwiseGemm =
258  FloatAB,
259  FloatAcc,
260  decltype(a_k0_m_k1_block_desc),
261  decltype(b_k0_n_k1_block_desc),
262  MPerXDL,
263  NPerXDL,
264  MRepeat,
265  NRepeat,
266  K1>;
267 
268  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc);
269  }
270 
271  // return block_id to C matrix tile idx (m0, n0) mapping
272  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
273  const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
274  {
276  c_m_n_grid_desc, 8, KBatch);
277  }
278 
279  using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
280  using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
281 
282  template <bool HasMainKBlockLoop>
283  __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
284  const FloatAB* __restrict__ p_b_grid,
285  FloatC* __restrict__ p_c_grid,
286  FloatAB* __restrict__ p_shared_block,
287  const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
288  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
289  const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
290  const AElementwiseOperation& a_element_op,
291  const BElementwiseOperation& b_element_op,
292  const CElementwiseOperation& c_element_op,
293  const CBlockClusterAdaptor& c_block_cluster_adaptor)
294  {
295  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
296  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
297  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
298  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
299  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
300  p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
301 
302  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
303 
304  // divide block work by [M, N]
305  const auto block_work_idx =
306  c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
307 
308  if(!c_block_cluster_adaptor.ValidCTileIndex(
309  make_tuple(block_work_idx[I1], block_work_idx[I2]),
310  make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
311  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
312  {
313  return;
314  }
315 
316  const index_t k_batch_id = block_work_idx[I0];
317 
318  // HACK: this force m/n_block_data_idx_on_grid into SGPR
319  const index_t m_block_data_idx_on_grid =
320  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
321 
322  const index_t n_block_data_idx_on_grid =
323  __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
324 
325  // lds max alignment
326  constexpr auto max_lds_align = K1;
327 
328  // A matrix in LDS memory, dst of blockwise copy
329  constexpr auto a_k0_m_k1_block_desc = [&]() {
330  if constexpr(ABlockLdsExtraM)
331  {
335  }
336  else
337  {
338  return make_naive_tensor_descriptor_aligned(
339  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
340  }
341  }();
342 
343  constexpr auto a_b_k0_m_k1_block_desc = [&]() {
344  if constexpr(ABlockLdsExtraM)
345  {
350  K1,
351  I1));
352  }
353  else
354  {
357  max_lds_align);
358  }
359  }();
360  // B matrix in LDS memory, dst of blockwise copy
361  constexpr auto b_k0_n_k1_block_desc = [&]() {
362  if constexpr(BBlockLdsExtraN)
363  {
367  }
368  else
369  {
370  return make_naive_tensor_descriptor_aligned(
371  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
372  }
373  }();
374 
375  constexpr auto b_b_k0_n_k1_block_desc = [&]() {
376  if constexpr(BBlockLdsExtraN)
377  {
382  K1,
383  I1));
384  }
385  else
386  {
389  max_lds_align);
390  }
391  }();
392  // A matrix blockwise copy
393  auto a_blockwise_copy =
395  AElementwiseOperation,
399  ABlockTransferThreadClusterLengths_K0_M_K1,
400  ABlockTransferThreadClusterArrangeOrder,
401  FloatAB,
402  FloatAB,
403  decltype(a_b_k0_m_k1_grid_desc),
404  decltype(a_b_k0_m_k1_block_desc),
405  ABlockTransferSrcAccessOrder,
407  ABlockTransferSrcVectorDim,
408  3,
409  ABlockTransferSrcScalarPerVector,
410  ABlockTransferDstScalarPerVector_K1,
411  1,
412  1,
413  AThreadTransferSrcResetCoordinateAfterRun,
414  true>(
415  a_b_k0_m_k1_grid_desc,
416  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
417  a_element_op,
418  a_b_k0_m_k1_block_desc,
419  make_multi_index(0, 0, 0, 0),
421 
422  // B matrix blockwise copy
423  auto b_blockwise_copy =
425  BElementwiseOperation,
429  BBlockTransferThreadClusterLengths_K0_N_K1,
430  BBlockTransferThreadClusterArrangeOrder,
431  FloatAB,
432  FloatAB,
433  decltype(b_b_k0_n_k1_grid_desc),
434  decltype(b_b_k0_n_k1_block_desc),
435  BBlockTransferSrcAccessOrder,
437  BBlockTransferSrcVectorDim,
438  3,
439  BBlockTransferSrcScalarPerVector,
440  BBlockTransferDstScalarPerVector_K1,
441  1,
442  1,
443  BThreadTransferSrcResetCoordinateAfterRun,
444  true>(
445  b_b_k0_n_k1_grid_desc,
446  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
447  b_element_op,
448  b_b_k0_n_k1_block_desc,
449  make_multi_index(0, 0, 0, 0),
451 
452  // GEMM definition
453  // c_mtx += transpose(a_mtx) * b_mtx
454  // a_mtx[K0PerBlock, MPerBlock] is in LDS
455  // b_mtx[K0PerBlock, NPerBlock] is in LDS
456  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
457  // register
458  // sanity check
459 
460  auto blockwise_gemm =
462  FloatAB,
463  FloatAcc,
464  decltype(a_k0_m_k1_block_desc),
465  decltype(b_k0_n_k1_block_desc),
466  MPerXDL,
467  NPerXDL,
468  MRepeat,
469  NRepeat,
470  K1>{};
471 
472  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
473 
474  // LDS allocation for A and B: be careful of alignment
475  constexpr auto a_block_space_size =
476  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
477 
478  FloatAB* p_a_block = p_shared_block;
479  FloatAB* p_b_block = p_shared_block + a_block_space_size;
480 
481  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
482  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
483 
484  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
485  p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
486  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
487  p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
488 
489  // preload data into LDS
490  {
491  a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
492  b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
493 
494  a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
495  b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
496  }
497 
498  // Initialize C
499  c_thread_buf.Clear();
500 
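 // Pipeline note: the main loop below overlaps global-memory traffic with math. Each
 // iteration first moves the A/B source windows by K0PerBlock and issues RunRead for
 // the next K slice into thread-private scratch, then runs the blockwise GEMM on the
 // slice already resident in LDS, and only after a block_sync_lds() writes the
 // prefetched slice into LDS for the next iteration. The last slice is consumed by
 // the tail GEMM.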
501  // main body
502  if constexpr(HasMainKBlockLoop)
503  {
504  index_t k0_block_data_begin = 0;
505 
506  do
507  {
508  a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
509  b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
510 
511  a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
512 
513  block_sync_lds();
514 
515  b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
516 
517  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
518 
519  block_sync_lds();
520 
521  a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
522  b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
523 
524  k0_block_data_begin += K0PerBlock;
525  } while(k0_block_data_begin < (K0 - K0PerBlock));
526  }
527 
528  // tail
529  {
530  block_sync_lds();
531 
532  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
533  }
534 
535  // output: register to global memory
536  {
537  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
538  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
539 
540  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
541  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
542  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
543  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
544  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
545  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
546  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
547  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
548 
549  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
550  make_naive_tensor_descriptor_packed(make_tuple(
551  Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
552 
553  // calculate origin of thread output tensor on global memory
554  // blockwise GEMM c matrix starting index
555  const auto c_thread_mtx_on_block =
556  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
557 
558  const index_t m_thread_data_on_grid =
559  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
560 
561  const index_t n_thread_data_on_grid =
562  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
563 
564  const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
566  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
569 
570  const auto m_thread_data_on_grid_idx =
571  m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
572  make_multi_index(m_thread_data_on_grid));
573 
574  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
578 
579  const auto n_thread_data_on_grid_idx =
580  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
581  make_multi_index(n_thread_data_on_grid));
582 
583  auto c_thread_copy =
585  FloatC,
586  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
587  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
588  CElementwiseOperation,
590  CThreadTransferSrcDstAccessOrder,
591  CThreadTransferSrcDstVectorDim,
592  CThreadTransferDstScalarPerVector,
593  CGlobalMemoryDataOperation,
594  1,
595  true>{
596 
597  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
598  make_multi_index(m_thread_data_on_grid_idx[I0],
599  n_thread_data_on_grid_idx[I0],
600  m_thread_data_on_grid_idx[I1],
601  n_thread_data_on_grid_idx[I1],
602  m_thread_data_on_grid_idx[I2],
603  m_thread_data_on_grid_idx[I3],
604  m_thread_data_on_grid_idx[I4],
605  n_thread_data_on_grid_idx[I2]),
606  c_element_op};
607 
608  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
609  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
610  c_thread_buf,
611  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
612  c_grid_buf);
613  }
614  }
615 };
616 
617 } // namespace ck
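
The listing above is normally driven by Composable Kernel's device-level GEMM operations rather than used directly. As a rough, self-contained illustration of the divisibility rules that CheckValidity enforces and of how a split-K launch might be sized, the host-side sketch below recomputes them for hypothetical problem and tuning parameters. The tile sizes, the KBatch factor, and the one-workgroup-per-(k_batch, m0, n0)-tile grid-size formula are assumptions made for illustration; they are not values or APIs taken from this header.

// lds_and_grid_size_sketch.cpp -- standalone illustration only; it does not use the CK API.
#include <cstdint>
#include <cstdio>

int main()
{
    // hypothetical problem size
    const std::int64_t M = 3840, N = 4096, K = 4096;

    // hypothetical tuning parameters mirroring the template arguments above
    const std::int64_t MPerBlock = 256, NPerBlock = 128, K0PerBlock = 4, K1 = 8;
    const std::int64_t KBatch    = 2; // split-K factor, the leading "B" dimension of A/B

    // A is viewed as [KBatch, K0, M, K1] and B as [KBatch, K0, N, K1]
    const std::int64_t K0 = K / (KBatch * K1);

    // divisibility rules corresponding to CheckValidity() (plus K splitting evenly)
    const bool valid = (K % (KBatch * K1) == 0) && (M % MPerBlock == 0) &&
                       (N % NPerBlock == 0) && (K0 % K0PerBlock == 0);

    // assumed mapping: one workgroup per (k_batch_id, m0, n0) output tile
    const std::int64_t grid_size = KBatch * (M / MPerBlock) * (N / NPerBlock);

    std::printf("valid=%d K0=%lld grid_size=%lld has_main_k0_block_loop=%d\n",
                valid ? 1 : 0,
                static_cast<long long>(K0),
                static_cast<long long>(grid_size),
                (K0 > K0PerBlock) ? 1 : 0);
    return 0;
}

For the parameters shown this prints valid=1 K0=256 grid_size=960 has_main_k0_block_loop=1.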