Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp Source File

gridwise_gemm_wmma.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  typename ADataType,
22  typename BDataType,
23  typename CDataType,
24  typename AGridDesc,
25  typename BGridDesc,
26  typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
30  typename Block2CTileMap,
31  bool HasMainKBlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
34  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
35 #endif
36  kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
37  const BDataType* __restrict__ p_b_grid,
38  CDataType* __restrict__ p_c_grid,
39  const AGridDesc a_grid_desc,
40  const BGridDesc b_grid_desc,
41  const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
42  c_grid_desc_mblock_mperblock_nblock_nperblock,
43  const AElementwiseOperation a_element_op,
44  const BElementwiseOperation b_element_op,
45  const CElementwiseOperation c_element_op,
46  const Block2CTileMap block_2_ctile_map)
47 {
48 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
49  __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
50 
51  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
52  p_b_grid,
53  p_c_grid,
54  p_shared,
55  a_grid_desc,
56  b_grid_desc,
57  c_grid_desc_mblock_mperblock_nblock_nperblock,
58  a_element_op,
59  b_element_op,
60  c_element_op,
61  block_2_ctile_map);
62 #else
63  ignore = p_a_grid;
64  ignore = p_b_grid;
65  ignore = p_c_grid;
66  ignore = a_grid_desc;
67  ignore = b_grid_desc;
68  ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
69  ignore = a_element_op;
70  ignore = b_element_op;
71  ignore = c_element_op;
72  ignore = block_2_ctile_map;
73 #endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
74 }
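// (Editorial note) kernel_gemm_wmma is a thin __global__ wrapper: it allocates the LDS workspace
// sized by GridwiseGemm::SharedMemTrait::lds_size and forwards every argument to GridwiseGemm::Run.
// During device compilation for targets other than gfx11/gfx12 the body is compiled out and the
// parameters are consumed via `ignore`.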
75 
76 template <index_t BlockSize,
77  typename ADataType,
78  typename BDataType,
79  typename AccDataType,
80  typename CShuffleDataType,
81  typename CDataType,
82  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
83  typename AGridDesc,
84  typename BGridDesc,
85  typename CGridDesc_M_N,
86  typename AElementwiseOperation,
87  typename BElementwiseOperation,
88  typename CElementwiseOperation,
89  index_t MPerBlock,
90  index_t NPerBlock,
91  index_t KPerBlock,
92  index_t MPerWmma,
93  index_t NPerWmma,
94  index_t K1Value,
95  index_t MRepeat,
96  index_t NRepeat,
97  typename ABlockTransferThreadClusterLengths_K0_M_K1,
98  typename ABlockTransferThreadClusterArrangeOrder,
99  typename ABlockTransferSrcAccessOrder,
100  index_t ABlockTransferSrcVectorDim,
101  index_t ABlockTransferSrcScalarPerVector,
102  index_t ABlockTransferDstScalarPerVector_K1,
103  bool AThreadTransferSrcResetCoordinateAfterRun,
104  bool AEnableLds,
105  bool ABlockLdsExtraM,
106  typename BBlockTransferThreadClusterLengths_K0_N_K1,
107  typename BBlockTransferThreadClusterArrangeOrder,
108  typename BBlockTransferSrcAccessOrder,
109  index_t BBlockTransferSrcVectorDim,
110  index_t BBlockTransferSrcScalarPerVector,
111  index_t BBlockTransferDstScalarPerVector_K1,
112  bool BThreadTransferSrcResetCoordinateAfterRun,
113  bool BEnableLds,
114  bool BBlockLdsExtraN,
115  index_t CShuffleMRepeatPerShuffle,
116  index_t CShuffleNRepeatPerShuffle,
117  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
118  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
119  index_t NumGemmKPrefetchStage = 1,
120  LoopScheduler LoopSched = make_default_loop_scheduler(),
121  PipelineVersion PipelineVer = PipelineVersion::v1>
122 struct GridwiseGemm_Wmma
123 {
124  static constexpr auto I0 = Number<0>{};
125  static constexpr auto I1 = Number<1>{};
126  static constexpr auto I2 = Number<2>{};
127  static constexpr auto I3 = Number<3>{};
128  static constexpr auto I4 = Number<4>{};
129  static constexpr auto I5 = Number<5>{};
130  static constexpr auto I6 = Number<6>{};
131  static constexpr auto I7 = Number<7>{};
132 
133  // FIX ME: To be deprecated
134  static constexpr auto K1 = Number<K1Value>{};
135 
136  static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
137  static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
138  static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
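// (Editorial example) With MPerBlock = NPerBlock = 128, MPerWmma = NPerWmma = 16 and
// MRepeat = NRepeat = 2, this gives MWaves = NWaves = 128 / (2 * 16) = 4, i.e. a 4 x 4 grid of
// waves per workgroup, each wave covering a 2 x 2 repeat of 16 x 16 WMMA tiles; K1 = 8 selects
// WmmaK = 16, while K1 = 16 selects WmmaK = 32.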
139 
141 
143  remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
144  NumGemmKPrefetchStage,
145  LoopSched,
146  AEnableLds,
147  BEnableLds>())>;
148 
149  // Describe how data store to (LDS/VGPR) buffer from Global memory
150  __host__ __device__ static constexpr auto MakeABlockDescriptor()
151  {
152  constexpr auto a_block_desc = [&]() {
153  if constexpr(AEnableLds)
154  {
155  // K0->M->K1 Per Block
156  constexpr auto K0PerBlock = KPerBlock / K1;
157  constexpr auto max_lds_align = K1;
158 
159  if constexpr(ABlockLdsExtraM)
160  {
164  }
165  else
166  {
168  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
169  }
170  }
171  else
172  {
173  constexpr auto A_KRow = I2;
174  constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
175  constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
176  // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
179  Number<MRepeat>{},
180  I1,
182  I1,
183  I1,
184  K1),
186  Number<K0PerWmma>{} * K1,
187  Number<K0PerWmma>{} * K1,
188  K1,
189  K1,
190  K1,
191  I1));
192  }
193  }();
194 
195  return a_block_desc;
196  }
197 
198  __host__ __device__ static constexpr auto MakeBBlockDescriptor()
199  {
200  constexpr auto b_block_desc = [&]() {
201  if constexpr(BEnableLds)
202  {
203  // K0->N->K1 Per Block
204  constexpr auto K0PerBlock = KPerBlock / K1;
205  constexpr auto max_lds_align = K1;
206 
207  if constexpr(BBlockLdsExtraN)
208  {
212  }
213  else
214  {
216  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
217  }
218  }
219  else
220  {
221 
222  constexpr auto B_KRow = I2;
223  constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
224  constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
225  // KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
228  Number<NRepeat>{},
229  I1,
231  I1,
232  I1,
233  K1),
235  Number<K0PerWmma>{} * K1,
236  Number<K0PerWmma>{} * K1,
237  K1,
238  K1,
239  K1,
240  I1));
241  }
242  }();
243 
244  return b_block_desc;
245  }
246 
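// (Editorial note) The slice copy steps below advance the A/B copy's source window once per main
// K loop iteration: by K0PerBlock along K0 when the operand is staged through LDS, or by
// KWmmaPerBlock along the leading KWmma dimension on the direct-to-VGPR path.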
247  __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
248  {
249  constexpr auto a_block_copy_step = [&]() {
250  if constexpr(AEnableLds)
251  {
252  constexpr auto K0PerBlock = KPerBlock / K1;
253 
254  return make_multi_index(K0PerBlock, 0, 0);
255  }
256  else
257  {
258  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
259 
260  return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
261  }
262  }();
263 
264  return a_block_copy_step;
265  }
266 
267  __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
268  {
269  constexpr auto b_block_copy_step = [&]() {
270  if constexpr(BEnableLds)
271  {
272  constexpr auto K0PerBlock = KPerBlock / K1;
273 
274  return make_multi_index(K0PerBlock, 0, 0);
275  }
276  else
277  {
278  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
279 
280  return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
281  }
282  }();
283 
284  return b_block_copy_step;
285  }
286 
287  // Describe how data read from (LDS/VGPR) buffer
288  template <typename ABlockDesc_>
289  __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
290  {
291 
292  constexpr auto a_wave_desc = [&]() {
293  if constexpr(AEnableLds)
294  {
295  // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
296  constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
297  constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
298 #ifdef __gfx12__
299  constexpr auto A_KRow = I2;
300 #else
301  constexpr auto A_KRow = I1;
302 #endif
303 
305  ABlockDesc_{},
312  }
313  else
314  {
315  // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
316  constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
317  constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
318  constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
319  constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
320 
321  // Err: merge transform causes a non-constexpr issue
322 
323  // return transform_tensor_descriptor(
324  // ABlockDesc_{},
325  // make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
326  // make_pass_through_transform(Number<MRepeat>{}),
327  // make_pass_through_transform(I1),
328  // make_pass_through_transform(I1),
329  // make_pass_through_transform(Number<A_K1>{})),
330  // make_tuple(Sequence<0, 3>{},
331  // Sequence<1>{},
332  // Sequence<2>{},
333  // Sequence<4>{},
334  // Sequence<5>{}),
335  // make_tuple(
336  // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
337  // Sequence<4>{}));
338 
339  // Workaround, Freeze transform
341  Number<MRepeat>{},
342  I1,
343  Number<A_KRow>{},
344  I1,
345  Number<A_K1>{}));
346  }
347  }();
348 
349  return a_wave_desc;
350  }
351 
352  template <typename BBlockDesc_>
353  __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
354  {
355  constexpr auto b_wave_desc = [&]() {
356  if constexpr(BEnableLds)
357  {
358  // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
359  constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
360  constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
361 #ifdef __gfx12__
362  constexpr auto B_KRow = I2;
363 #else
364  constexpr auto B_KRow = I1;
365 #endif
367  BBlockDesc_{},
374  }
375  else
376  {
377  // KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
378  constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
379  constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
380  constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
381  constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
382 
383  // Workaround, Freeze transform
385  Number<NRepeat>{},
386  I1,
387  Number<B_KRow>{},
388  I1,
389  Number<B_K1>{}));
390  }
391  }();
392 
393  return b_wave_desc;
394  }
395 
396  __host__ __device__ static constexpr auto
397  // *Caution Here repeat is shuffle repeat
398  GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
399  {
400  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
402  make_tuple(I1,
404  I1,
406 
407  return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
408  }
409 
410  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
411  template <typename Block2CTileMap>
412  __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
413  const BGridDesc& b_grid_desc,
414  const CGridDesc_M_N& c_grid_desc_m_n,
415  const Block2CTileMap& block_2_ctile_map)
416  {
417  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
418  "wrong! K1 need to be known at compile-time");
419 
420  static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
421  (NPerBlock % (NRepeat * NPerWmma)) == 0,
422  "Invalid tuning param!");
423 
424  const auto GetAProblemsizeMK = [&]() {
425  if constexpr(AEnableLds)
426  {
427  return make_tuple(a_grid_desc.GetLength(I1),
428  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
429  }
430  else
431  {
432  return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
433  a_grid_desc.GetLength(I5),
434  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
435  a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
436  }
437  };
438 
439  const auto GetBProblemsizeNK = [&]() {
440  if constexpr(BEnableLds)
441  {
442  return make_tuple(b_grid_desc.GetLength(I1),
443  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
444  }
445  else
446  {
447  return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
448  b_grid_desc.GetLength(I5),
449  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
450  b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
451  }
452  };
453 
454  const auto M = GetAProblemsizeMK()[I0];
455  const auto N = GetBProblemsizeNK()[I0];
456  const auto K = GetAProblemsizeMK()[I1];
457 
458  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
459  K == GetBProblemsizeNK()[I1]))
460  {
461  printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
462  GetAProblemsizeMK()[I0],
463  GetAProblemsizeMK()[I1],
464  GetBProblemsizeNK()[I0],
465  GetBProblemsizeNK()[I1],
466  c_grid_desc_m_n.GetLength(I0),
467  c_grid_desc_m_n.GetLength(I1));
468  printf("GridwiseOp err: ProblemSize check\n");
469  return false;
470  }
471 
472  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
473  {
474  printf("GridwiseOp err: ProblemSize division\n");
475  return false;
476  }
477 
478  // check gridwise gemm pipeline
479  const auto num_k_loop = K / KPerBlock;
480 
481  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
482  {
483  printf("GridwiseOp err: Pipeline does not support this k_loop\n");
484  return false;
485  }
486 
487  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
488  {
489  return false;
490  }
491 
492  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
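// (Editorial note) The 2 GiB bound below keeps byte offsets into the A and B buffers
// representable in a signed 32-bit index.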
493  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
494 
495  if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
496  b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
497  {
498  return false;
499  }
500  return true;
501  }
502 
503  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
504  {
505  const index_t num_loop = K / KPerBlock;
506 
507  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
508  }
509 
510  __host__ __device__ static constexpr auto
511  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
512  {
513  const auto M = c_grid_desc_m_n.GetLength(I0);
514  const auto N = c_grid_desc_m_n.GetLength(I1);
515 
516  const auto MBlock = M / MPerBlock;
517  const auto NBlock = N / NPerBlock;
518 
519  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
520  c_grid_desc_m_n,
525 
526  return c_grid_desc_mblock_mperblock_nblock_nperblock;
527  }
528 
529  // return block_id to C matrix tile idx (m0, n0) mapping
530  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
531  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
532  {
534  c_grid_desc_m_n);
535  }
536 
537  struct SharedMemTrait
538  {
539  // LDS allocation for A and B: be careful of alignment
540 
541  static constexpr auto max_lds_align = K1;
542 
543  static constexpr auto a_block_space_size_aligned =
544  AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
545  max_lds_align)
546  : 0;
547  static constexpr auto b_block_space_size_aligned =
548  BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
549  max_lds_align)
550  : 0;
551 
552  static constexpr auto a_block_space_offset = 0;
554 
555  // LDS allocation for C shuffle in LDS
556  static constexpr auto c_shuffle_block_space_size =
557  GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
558  .GetElementSpaceSize();
559 
560  static constexpr auto c_shuffle_block_space_offset = 0;
561 
562  static constexpr auto lds_size =
563  math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
564  a_block_space_size_aligned * sizeof(ADataType) +
565  b_block_space_size_aligned * sizeof(BDataType));
566  };
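// (Editorial example) For fp16 A and B both staged through LDS with KPerBlock = 64, K1 = 8,
// MPerBlock = NPerBlock = 128 and no ExtraM/ExtraN padding, each operand descriptor spans
// (64 / 8) * 128 * 8 = 8192 elements, i.e. 16 KiB per operand, so lds_size becomes
// max(c_shuffle_block_space_size * sizeof(CShuffleDataType), 32 KiB).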
567 
568  using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
569  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
570  CGridDesc_M_N{}))>;
571  using DefaultBlock2CTileMap =
572  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
573 
574  template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
575  __device__ static void Run(const ADataType* __restrict__ p_a_grid,
576  const BDataType* __restrict__ p_b_grid,
577  CDataType* __restrict__ p_c_grid,
578  void* __restrict__ p_shared,
579  const AGridDesc& a_grid_desc,
580  const BGridDesc& b_grid_desc,
582  c_grid_desc_mblock_mperblock_nblock_nperblock,
583  const AElementwiseOperation& a_element_op,
584  const BElementwiseOperation& b_element_op,
585  const CElementwiseOperation& c_element_op,
586  const Block2CTileMap& block_2_ctile_map)
587  {
588  // clang-format off
589 /*******************************************************************************/
590 // Memory buffer zone.
591  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
592  p_a_grid, a_grid_desc.GetElementSpaceSize());
593  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
594  p_b_grid, b_grid_desc.GetElementSpaceSize());
595  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
596  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
597 
598 /*******************************************************************************/
599 // BlockIdx.x -> [BlockId.m, BlockId.n]
600  const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
601  if(!block_2_ctile_map.ValidCTileIndex(
602  block_work_idx,
603  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
604  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
605  { return; }
606 
607  // Store BlockId into SGPR
608  const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
609  const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
610 
611 /*******************************************************************************/
612 // Block level: A/B matrix thread mapping in the WMMA source buffer, as destination of the blockwise copy
613  const auto K = [&](){
614  if constexpr(AEnableLds){
615  return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
616  }
617  else{
618  return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
619  * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
620  }
621  }();
622 
623  constexpr auto a_block_desc = MakeABlockDescriptor();
624  constexpr auto b_block_desc = MakeBBlockDescriptor();
625 
626  auto a_block_trait = [&](){
627  // A matrix blockwise copy
628  if constexpr(AEnableLds)
629  {
630  constexpr auto K0PerBlock = KPerBlock/ K1;
631  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
632  static_cast<ADataType*>(p_shared),
634 
635  auto a_blockwise_copy =
637 /* typename SrcElementwiseOperation, */ AElementwiseOperation,
638 /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
639 /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
640 /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
641 /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
642 /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
643 /* typename SrcData, */ ADataType,
644 /* typename DstData, */ ADataType,
645 /* typename SrcDesc, */ decltype(a_grid_desc),
646 /* typename DstDesc, */ decltype(a_block_desc),
647 /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
648 /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
649 /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
650 /* index_t DstVectorDim, */ 2,
651 /* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
652 /* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
653 /* index_t SrcScalarStrideInVector, */ 1,
654 /* index_t DstScalarStrideInVector, */ 1,
655 /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
656 /* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
657  NumGemmKPrefetchStage>(
658  a_grid_desc,
659  make_multi_index(0, m_block_data_idx_on_grid, 0),
660  a_element_op,
661  a_block_desc,
662  make_multi_index(0, 0, 0),
664 
665  return make_tuple(a_block_buf, a_blockwise_copy);
666  }
667  else
668  {
669  // Thread-wise copy
670  // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
671  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
672  constexpr auto K0PerWmma = WmmaK/2/K1Value;
673  auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
674  a_block_desc.GetElementSpaceSize());
675 
676  // Limitation: NumDim of Src and Dst descriptor should be identical
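// (Editorial note) The copy origin below (and its B counterpart) decomposes the flat thread id for
// the direct-to-VGPR read: tid / 32 selects the wave, (tid % 32) / 16 selects the K row within the
// wave, and tid % 16 the lane position along the 16-wide MPerWmma (resp. NPerWmma) dimension.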
677  auto a_blockwise_copy =
679  ADataType,
680  decltype(a_grid_desc),
681  decltype(a_block_desc),
683  Number<MRepeat>{},
684  I1,
686  I1,
687  I1,
688  Number<K1Value>{}>,
690  6,
691  ABlockTransferSrcScalarPerVector,
692  AThreadTransferSrcResetCoordinateAfterRun,
693  true>(
694  a_grid_desc,
695  make_multi_index(0,
696  m_block_data_idx_on_grid/(MWaves * MPerWmma),
697  get_thread_local_1d_id() / 32,
698  0,
699  (get_thread_local_1d_id() % 32 )/ 16,
700  get_thread_local_1d_id() % 16,
701  0));
702 
703  return make_tuple(a_block_buf, a_blockwise_copy);
704  }
705  };
706 
707  auto b_block_trait = [&](){
708  if constexpr(BEnableLds)
709  {
710  constexpr auto K0PerBlock = KPerBlock/ K1;
711  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
712  static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
714 
715  auto b_blockwise_copy =
717  BElementwiseOperation,
721  BBlockTransferThreadClusterLengths_K0_N_K1,
722  BBlockTransferThreadClusterArrangeOrder,
723  BDataType,
724  BDataType,
725  decltype(b_grid_desc),
726  decltype(b_block_desc),
727  BBlockTransferSrcAccessOrder,
729  BBlockTransferSrcVectorDim,
730  2,
731  BBlockTransferSrcScalarPerVector,
732  BBlockTransferDstScalarPerVector_K1,
733  1,
734  1,
735  BThreadTransferSrcResetCoordinateAfterRun,
736  true,
737  NumGemmKPrefetchStage>(
738  b_grid_desc,
739  make_multi_index(0, n_block_data_idx_on_grid, 0),
740  b_element_op,
741  b_block_desc,
742  make_multi_index(0, 0, 0),
744 
745  return make_tuple(b_block_buf, b_blockwise_copy);
746  }
747  else
748  {
749  // Thread-wise copy
750  // KPerBlock/WmmaK -> NRepeat -> NWaves -> K0PerWmma -> KRow -> NPerWmma -> K1
751  constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
752  constexpr auto K0PerWmma = WmmaK/2/K1Value;
753  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
754  b_block_desc.GetElementSpaceSize());
755 
756  // Limitation: NumDim of Src and Dst descriptor should be identical
757  auto b_blockwise_copy =
759  BDataType,
760  decltype(b_grid_desc),
761  decltype(b_block_desc),
763  Number<NRepeat>{},
764  I1,
766  I1,
767  I1,
768  Number<K1Value>{}>,
770  6,
771  BBlockTransferSrcScalarPerVector,
772  BThreadTransferSrcResetCoordinateAfterRun,
773  true>(
774  b_grid_desc,
775  make_multi_index(0,
776  n_block_data_idx_on_grid/(NWaves * NPerWmma),
777  get_thread_local_1d_id() / 32,
778  0,
779  (get_thread_local_1d_id() % 32 )/ 16,
780  get_thread_local_1d_id() % 16,
781  0));
782 
783  return make_tuple(b_block_buf, b_blockwise_copy);
784  }
785  };
786 
787  auto a_block_buf = a_block_trait()[I0];
788  auto a_blockwise_copy = a_block_trait()[I1];
789 
790  auto b_block_buf = b_block_trait()[I0];
791  auto b_blockwise_copy = b_block_trait()[I1];
792 /*******************************************************************************/
793  // GEMM
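// (Editorial note) KPack rounds K1 up to a multiple of WmmaK; for K1 <= 16 this evaluates to
// WmmaK itself, the K extent consumed by a single WMMA instruction.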
794  constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
795 
796  auto blockwise_gemm =
797  BlockwiseGemmWMMA<BlockSize,
798  ADataType,
799  BDataType,
800  AccDataType,
801  decltype(MakeAWaveDescriptor(a_block_desc)),
802  decltype(MakeBWaveDescriptor(b_block_desc)),
803  MPerBlock,
804  NPerBlock,
805  KPerBlock,
806  MPerWmma,
807  NPerWmma,
808  MRepeat,
809  NRepeat,
810  KPack,
811  AEnableLds,
812  BEnableLds>{};
813 
814  // Prepare Register for C matrix
815  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
816 
817 /*******************************************************************************/
818  // Shift Per SUB_K
819  constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
820  constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
821 
822  // gridwise GEMM pipeline
823  const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
824  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
825  a_block_desc,
826  a_blockwise_copy,
827  a_grid_buf,
828  a_block_buf,
829  a_block_slice_copy_step,
830  b_grid_desc,
831  b_block_desc,
832  b_blockwise_copy,
833  b_grid_buf,
834  b_block_buf,
835  b_block_slice_copy_step,
836  blockwise_gemm,
837  c_thread_buf,
838  KBlockMainLoop);
839 /*******************************************************************************/
840  // write out to C, implement shuffle
841  {
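// (Editorial note) Two-stage epilogue: each thread first copies a CShuffleMRepeatPerShuffle x
// CShuffleNRepeatPerShuffle portion of its accumulator fragments from VGPR into the shared
// C-shuffle tile, then the whole workgroup copies that tile from LDS to global memory, with both
// sides walked by matching space-filling curves and LDS barriers in between.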
842  // C mapping in single thread.
843  constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
844  blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
845 
846  // C mapping in single block
847  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
848  blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
849 
850  constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
851  constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
852  constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
853  constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
854  constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
855 
856  // LDS descriptor; shuffle and write out C in MRepeat x NRepeat pieces
857  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
859 
860  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
861  static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
863 
864  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
865  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
866  make_tuple(
869  Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
870  MWave, // MWave
871  MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
872  MAccVgprs)),
875  Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
876  NWave, // NWave
877  NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
880 
881  // calculate origin of thread output tensor on global memory
882  // blockwise GEMM c matrix starting index
883  const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
884 
885  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
886  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
887 
888  const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
890  make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
893 
894  const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
896  make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
899 
900  const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
901  make_multi_index(m_thread_data_on_block));
902 
903  const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
904  make_multi_index(n_thread_data_on_block));
905 
906  // shuffle: threadwise copy C from VGPR to LDS
907  auto c_thread_copy_vgpr_to_lds =
909  CShuffleDataType,
910  decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
911  decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
913  Sequence<CShuffleMRepeatPerShuffle,
914  I1,
915  I1,
916  CShuffleNRepeatPerShuffle,
917  I1,
918  I1,
919  MAccVgprs>,
921  6,
922  1, // vector write pixel
924  1,
925  true>{
926  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
928  m_thread_data_on_block_idx[I1],
929  m_thread_data_on_block_idx[I2],
930  0,
931  n_thread_data_on_block_idx[I1],
932  n_thread_data_on_block_idx[I2],
933  m_thread_data_on_block_idx[I3]),
935 
936  // shuffle: blockwise copy C from LDS to global
937  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
938  ThisThreadBlock, // ThreadGroup
939  CElementwiseOperation, // ElementwiseOperation,
940  CGlobalMemoryDataOperation, // DstInMemOp,
941  Sequence<1,
942  CShuffleMRepeatPerShuffle * MWave * MPerWmma,
943  1,
944  CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
945  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
946  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
947  CShuffleDataType, // typename SrcData,
948  CDataType, // typename DstData,
949  decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
950  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
951  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
952  3, // index_t VectorDim,
953  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
954  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
955  false> // bool ThreadTransferDstResetCoordinateAfterRun>
956  {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
957  make_multi_index(0, 0, 0, 0),
958  c_grid_desc_mblock_mperblock_nblock_nperblock,
959  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
960  c_element_op};
961 
962  // space filling curve for local reg & global memory
963  // space filling curve for threadwise C in VGPR
964  constexpr auto sfc_c_vgpr =
967  Sequence<CShuffleMRepeatPerShuffle,
968  1,
969  1,
970  CShuffleNRepeatPerShuffle,
971  1,
972  1,
973  MAccVgprs>>{};
974 
975  // space filling curve for shuffled blockwise C in global mem
976  constexpr auto sfc_c_global =
979  Sequence<1,
980  CShuffleMRepeatPerShuffle * MWave * MPerWmma,
981  1,
982  CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
983 
984  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
985 
986  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
987 
988  static_for<0, num_access, 1>{}([&](auto access_id) {
989  // make sure it's safe to write to LDS
990  block_sync_lds();
991 
992  // each thread write its data from VGPR to LDS
993  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
994  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
995  c_thread_buf,
996  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
997  c_shuffle_block_buf);
998 
999  // make sure it's safe to read from LDS
1000  block_sync_lds();
1001 
1002  // each block copy its data from LDS to global
1003  c_shuffle_block_copy_lds_to_global.Run(
1004  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1005  c_shuffle_block_buf,
1006  c_grid_desc_mblock_mperblock_nblock_nperblock,
1007  c_grid_buf);
1008 
1009  if constexpr(access_id < num_access - 1)
1010  {
1011  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1012 
1013  // move on C
1014  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1015  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1016  }
1017  });
1018  }
1019  // clang-format on
1020  }
1021 };
1022 
1023 } // namespace ck
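The following is a minimal standalone sketch (editorial addition, not part of the header) of the
tile bookkeeping that CheckValidity, MWaves/NWaves and SharedMemTrait perform above; the tuning
parameters and problem size are illustrative assumptions only.

#include <cstdio>

int main()
{
    // Assumed tuning parameters (illustration only).
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 64;
    constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 2, NRepeat = 2, K1 = 8;
    constexpr int WmmaK  = (K1 == 16) ? 32 : 16;
    constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 4
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 4

    // Assumed problem size; it must tile evenly, as CheckValidity requires.
    constexpr int M = 3840, N = 4096, K = 4096;
    static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0, "bad tiling");

    // One workgroup per MPerBlock x NPerBlock tile of C.
    constexpr int grid_size = (M / MPerBlock) * (N / NPerBlock);

    // LDS footprint for A and B when both are staged through LDS (fp16, 2 bytes per element).
    constexpr int K0PerBlock   = KPerBlock / K1;
    constexpr int a_elems      = K0PerBlock * MPerBlock * K1;
    constexpr int b_elems      = K0PerBlock * NPerBlock * K1;
    constexpr int ab_lds_bytes = 2 * (a_elems + b_elems);

    std::printf("waves per block: %d x %d, WmmaK: %d, grid: %d blocks, A+B LDS: %d bytes\n",
                MWaves, NWaves, WmmaK, grid_size, ab_lds_bytes);
    return 0;
}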