Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp Source File
gridwise_gemm_dpp.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
21 template <typename GridwiseGemm, bool HasMainKBlockLoop>
22 __global__ void
23 #if CK_USE_LAUNCH_BOUNDS
24  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
25 #endif
26 #if CK_USE_WAVES_PER_EU
27  __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
28 #endif
29  kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
30 {
31 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__))
32  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
33 
34  const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
35  GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
36  const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
37  GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
38  const auto c_grid_desc_m_n = amd_wave_read_first_lane(
39  GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));
40 
41  GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
42  karg.p_b_grid,
43  karg.p_c_grid,
44  p_shared,
45  a_grid_desc_ak0_m_ak1,
46  b_grid_desc_bk0_n_bk1,
47  c_grid_desc_m_n);
48 #else
49  ignore = karg;
50 #endif
51 }
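 // Note: the kernel body above is only compiled for the host pass and for gfx103*/gfx11*
 // device targets; elsewhere the argument is swallowed by `ignore` and the kernel is a
 // no-op. Each workgroup rebuilds the three grid descriptors from the plain scalar
 // members of Argument, with amd_wave_read_first_lane marking the results as
 // wavefront-uniform so they can live in scalar registers.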
52 
53 template <index_t BlockSize,
54  typename ABDataType,
55  typename AccDataType,
56  typename CDataType,
57  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
58  typename ALayout,
59  typename BLayout,
60  typename CLayout,
61  typename AElementwiseOperation,
62  typename BElementwiseOperation,
63  typename CElementwiseOperation,
64  tensor_operation::device::GemmSpecialization GemmSpec,
65  index_t MPerBlock,
66  index_t NPerBlock,
67  index_t KPerBlock,
68  index_t MPerDpp,
69  index_t NPerDpp,
70  index_t AK1Value,
71  index_t BK1Value,
72  index_t MDppPerWave,
73  index_t NDppPerWave,
74  typename ABlockTransferThreadClusterLengths_K0_M_K1,
75  typename ABlockTransferThreadClusterArrangeOrder,
76  typename ABlockTransferSrcAccessOrder,
77  index_t ABlockTransferSrcVectorDim,
78  index_t ABlockTransferSrcScalarPerVector,
79  index_t ABlockTransferDstScalarPerVector_K1,
80  bool AThreadTransferSrcResetCoordinateAfterRun,
81  bool ABlockLdsExtraM,
82  typename BBlockTransferThreadClusterLengths_K0_N_K1,
83  typename BBlockTransferThreadClusterArrangeOrder,
84  typename BBlockTransferSrcAccessOrder,
85  index_t BBlockTransferSrcVectorDim,
86  index_t BBlockTransferSrcScalarPerVector,
87  index_t BBlockTransferDstScalarPerVector_K1,
88  bool BThreadTransferSrcResetCoordinateAfterRun,
89  bool BBlockLdsExtraN,
90  typename CThreadTransferSrcDstAccessOrder,
91  index_t CThreadTransferSrcDstVectorDim,
92  index_t CThreadTransferDstScalarPerVector,
93  index_t NumGemmKPrefetchStage = 1,
96 {
97  static constexpr auto I0 = Number<0>{};
98  static constexpr auto I1 = Number<1>{};
99  static constexpr auto I2 = Number<2>{};
100  static constexpr auto I3 = Number<3>{};
101  static constexpr auto I4 = Number<4>{};
102  static constexpr auto I5 = Number<5>{};
103 
104  static constexpr auto AK1 = Number<AK1Value>{};
105  static constexpr auto BK1 = Number<BK1Value>{};
106  static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
107  static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
108 
109  static constexpr auto max_lds_align = math::lcm(AK1, BK1);
110 
112  // return block_id to C matrix tile idx (m0, n0) mapping
114 
115  __host__ static auto CalculateGridSize(index_t M, index_t N)
116  {
117  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
118  }
119 
120  __host__ static auto CalculateMPadded(index_t M)
121  {
122  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
123  }
124 
125  __host__ static auto CalculateNPadded(index_t N)
126  {
127  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
128  }
129 
130  __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); }
131  __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); }
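 // Illustrative numbers only (the tile sizes are template parameters): with
 // MPerBlock = NPerBlock = 128, AK1Value = BK1Value = 8 and a problem of
 // M = 1000, N = 512, K = 320:
 //   MPadded = ceil(1000 / 128) * 128 = 1024,   NPadded = ceil(512 / 128) * 128 = 512,
 //   AK0     = floor(320 / 8)         = 40,     BK0     = floor(320 / 8)         = 40,
 // and CalculateGridSize typically yields (1024 / 128) * (512 / 128) = 32 workgroups,
 // depending on the Block2CTileMap in use.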
132 
133  // Argument
134  struct Problem
135  {
136  __host__ Problem(index_t M_,
137  index_t N_,
138  index_t K_,
139  index_t StrideA_,
140  index_t StrideB_,
141  index_t StrideC_)
142  : M{M_},
143  N{N_},
144  K{K_},
145  StrideA{StrideA_},
146  StrideB{StrideB_},
147  StrideC{StrideC_},
148  MPadded{CalculateMPadded(M)},
149  NPadded{CalculateNPadded(N)},
150  AK0{CalculateAK0(K)},
151  BK0{CalculateBK0(K)}
152  {
153  }
154 
155  __host__ void Print() const
156  {
157  std::cout << "problem {"
158  << "M:" << M << ", "
159  << "N:" << N << ", "
160  << "K:" << K << ", "
161  << "SA:" << StrideA << ", "
162  << "SB:" << StrideB << ", "
163  << "SC:" << StrideC << ", "
164  << "MP:" << MPadded << ", "
165  << "NP:" << NPadded << ", "
166  << "AK0:" << AK0 << ", "
167  << "BK0:" << BK0 << "}" << std::endl;
168  }
169 
170  index_t M;
171  index_t N;
172  index_t K;
173  index_t StrideA;
174  index_t StrideB;
175  index_t StrideC;
176  index_t MPadded;
177  index_t NPadded;
178  index_t AK0;
179  index_t BK0;
180  };
181 
182  // Argument
183  struct Argument : public Problem
184  {
185  __host__ Argument(const ABDataType* p_a_grid_,
186  const ABDataType* p_b_grid_,
187  CDataType* p_c_grid_,
188  index_t M_,
189  index_t N_,
190  index_t K_,
191  index_t StrideA_,
192  index_t StrideB_,
193  index_t StrideC_)
194  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
195  p_a_grid{p_a_grid_},
196  p_b_grid{p_b_grid_},
197  p_c_grid{p_c_grid_}
198  {
199  }
200 
201  const ABDataType* p_a_grid;
202  const ABDataType* p_b_grid;
203  CDataType* p_c_grid;
204  };
205 
206  using GridwiseGemmPipe = remove_cvref_t<
207  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
208 
209  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
210  {
211  // A matrix in LDS memory, dst of blockwise copy
212  constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
213  if constexpr(ABlockLdsExtraM)
214  {
218  }
219  else
220  {
223  }
224  }();
225 
226  return a_block_desc_ak0_m_ak1;
227  }
228 
229  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
230  {
231  // B matrix in LDS memory, dst of blockwise copy
232  constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
233  if constexpr(BBlockLdsExtraN)
234  {
238  }
239  else
240  {
243  }
244  }();
245 
246  return b_block_desc_bk0_n_bk1;
247  }
248 
249  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
250  {
251  // LDS allocation for A and B: be careful of alignment
252  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
253  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
254 
255  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
256  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
257  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
258  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
259 
260  return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
261  }
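 // Illustrative LDS budget (example values only): with KPerBlock = 32, AK1 = BK1 = 8,
 // MPerBlock = NPerBlock = 128 and no extra LDS padding, each block descriptor covers
 // (32 / 8) * 128 * 8 = 4096 elements; max_lds_align = lcm(8, 8) = 8, so both sizes are
 // already aligned, and for ABDataType = half_t (2 bytes) the total comes to
 // (4096 + 4096) * 2 = 16384 bytes of shared memory per workgroup.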
262 
263  __host__ static constexpr bool CheckValidity(const Problem& problem)
264  {
265  static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
266  "Wrong! AK1 must be known at the time of compilation.");
267  static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
268  "Wrong! BK1 must be known at the time of compilation.");
269 
270  static_assert(
271  MPerBlock % (MPerDpp * MDppPerWave) == 0,
272  "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
273  static_assert(
274  NPerBlock % (NPerDpp * NDppPerWave) == 0,
275  "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");
276 
277  static_assert(
278  KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
279  "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");
280 
281  static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
282  "Invalid tuning parameters! AK1Value must be divisible by "
283  "ABlockTransferDstScalarPerVector_K1");
284 
285  static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
286  "Invalid tuning parameters! BK1Value must be divisible by "
287  "BBlockTransferDstScalarPerVector_K1");
288 
289  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
293  {
294  if(!(problem.M % MPerBlock == 0))
295  {
296  return false;
297  }
298  }
299 
300  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
304  {
305  if(!(problem.N % NPerBlock == 0))
306  {
307  return false;
308  }
309  }
310 
312  {
313  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
314  {
315  return false;
316  }
317  }
318  else
319  {
320  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
321  {
322  return false;
323  }
324  }
325 
327  {
328  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
329  {
330  return false;
331  }
332  }
333  else
334  {
335  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
336  {
337  return false;
338  }
339  }
340 
341  if(problem.K % KPerBlock != 0)
342  {
343  return false;
344  }
345 
346  // check gridwise gemm pipeline
347  const auto num_k_loop = problem.K / KPerBlock;
348  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
349  {
350  return false;
351  }
352 
353  return true;
354  }
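 // Example (assuming KPerBlock = 32): a problem with K = 100 is rejected by the
 // `K % KPerBlock` check above, while K = 320 passes and yields 320 / 32 = 10 K-block
 // iterations, which GridwiseGemmPipe::IsSupported then validates for the chosen pipeline.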
355 
356  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
357  {
358  const auto num_loop = K / KPerBlock;
359 
360  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
361  }
362 
363  template <typename CGridDesc>
364  __host__ __device__ static constexpr auto
365  MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
366  {
367  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
368  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
369 
370  constexpr index_t KPack = math::max(
372 
373  using BlockwiseGemm =
375  ABDataType,
376  AccDataType,
377  decltype(a_block_desc_ak0_m_ak1),
378  decltype(b_block_desc_bk0_n_bk1),
379  MPerDpp,
380  NPerDpp,
381  MDppPerWave,
382  NDppPerWave,
383  KPack>;
384 
385  return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
386  }
387 
388  static constexpr auto matrix_padder =
390  MPerBlock, NPerBlock, KPerBlock};
391 
392  __device__ static auto
393  MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
394  {
395  const auto a_grid_desc_mraw_kraw = [&]() {
397  {
398  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
399  }
401  {
402  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
403  }
404  }();
405 
406  const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
408  a_grid_desc_m_k,
413  }
414 
415  __device__ static auto
416  MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
417  {
418  const auto b_grid_desc_nraw_kraw = [&]() {
420  {
421  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
422  }
424  {
425  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
426  }
427  }();
428 
429  const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
431  b_grid_desc_n_k,
433  make_unmerge_transform(make_tuple(BK0, BK1Value))),
436  }
437 
438  __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
439  {
440  const auto c_grid_desc_mraw_nraw = [&]() {
442  {
443  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
444  }
446  {
447  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
448  }
449  }();
450 
451  return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
452  }
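 // Layout convention visible from the strides above: for a row-major C, element (m, n)
 // sits at offset m * StrideC + n, so StrideC is normally the leading dimension (>= N);
 // for a column-major C the offset is n * StrideC + m. StrideA and StrideB play the same
 // role for their layouts, and matrix_padder then pads M, N and K up to multiples of
 // MPerBlock, NPerBlock and KPerBlock as dictated by GemmSpec.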
453 
454  template <bool HasMainKBlockLoop,
455  typename AGridDesc_AK0_M_AK1,
456  typename BGridDesc_BK0_N_BK1,
457  typename CGridDesc_M_N>
458  __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
459  const ABDataType* __restrict__ p_b_grid,
460  CDataType* __restrict__ p_c_grid,
461  void* __restrict__ p_shared,
462  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
463  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
464  const CGridDesc_M_N& c_grid_desc_m_n)
465  {
466  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
467  MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
468 
469  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
470  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
471  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
472  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
473  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
474  p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());
475 
476  const AElementwiseOperation a_element_op{};
477  const BElementwiseOperation b_element_op{};
478  const CElementwiseOperation c_element_op{};
479 
480  const auto block_2_ctile_map =
481  Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
482 
483  // divide block work by [M, N]
484  const auto block_work_idx =
485  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
486 
487  if(!block_2_ctile_map.ValidCTileIndex(
488  block_work_idx,
489  make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
490  c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
491  {
492  return;
493  }
494 
495  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
496  const index_t m_block_data_idx_on_grid =
497  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
498  const index_t n_block_data_idx_on_grid =
499  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
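 // Every lane of the wavefront computes the same block offsets, so broadcasting lane 0
 // via readfirstlane lets the compiler keep them in scalar registers rather than
 // replicating them across VGPRs.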
500 
501  // A matrix in LDS memory, dst of blockwise copy
502  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
503  // B matrix in LDS memory, dst of blockwise copy
504  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
505 
506  auto a_blockwise_copy =
508  AElementwiseOperation,
512  ABlockTransferThreadClusterLengths_K0_M_K1,
513  ABlockTransferThreadClusterArrangeOrder,
514  ABDataType,
515  ABDataType,
516  decltype(a_grid_desc_ak0_m_ak1),
517  decltype(a_block_desc_ak0_m_ak1),
518  ABlockTransferSrcAccessOrder,
520  ABlockTransferSrcVectorDim,
521  2,
522  ABlockTransferSrcScalarPerVector,
523  ABlockTransferDstScalarPerVector_K1,
524  1,
525  1,
526  AThreadTransferSrcResetCoordinateAfterRun,
527  true,
528  NumGemmKPrefetchStage>(
529  a_grid_desc_ak0_m_ak1,
530  make_multi_index(0, m_block_data_idx_on_grid, 0),
531  a_element_op,
532  a_block_desc_ak0_m_ak1,
533  make_multi_index(0, 0, 0),
535 
536  auto b_blockwise_copy =
538  BElementwiseOperation,
542  BBlockTransferThreadClusterLengths_K0_N_K1,
543  BBlockTransferThreadClusterArrangeOrder,
544  ABDataType,
545  ABDataType,
546  decltype(b_grid_desc_bk0_n_bk1),
547  decltype(b_block_desc_bk0_n_bk1),
548  BBlockTransferSrcAccessOrder,
550  BBlockTransferSrcVectorDim,
551  2,
552  BBlockTransferSrcScalarPerVector,
553  BBlockTransferDstScalarPerVector_K1,
554  1,
555  1,
556  BThreadTransferSrcResetCoordinateAfterRun,
557  true,
558  NumGemmKPrefetchStage>(
559  b_grid_desc_bk0_n_bk1,
560  make_multi_index(0, n_block_data_idx_on_grid, 0),
561  b_element_op,
562  b_block_desc_bk0_n_bk1,
563  make_multi_index(0, 0, 0),
565 
566  // GEMM definition
567  // c_mtx += transpose(a_mtx) * b_mtx
568  // a_mtx[AK0PerBlock, MPerBlock] is in LDS
569  // b_mtx[BK0PerBlock, NPerBlock] is in LDS
570  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
571  // register
572  constexpr index_t KPack = math::max(
574  auto blockwise_gemm =
576  ABDataType,
577  AccDataType,
578  decltype(a_block_desc_ak0_m_ak1),
579  decltype(b_block_desc_bk0_n_bk1),
580  MPerDpp,
581  NPerDpp,
582  MDppPerWave,
583  NDppPerWave,
584  KPack>();
585 
586  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
587 
588  // LDS allocation for A and B: be careful of alignment
589  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
590  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
591 
592  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
593  static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
594 
595  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
596  static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
597  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
598 
599  constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
600  constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);
601 
602  // gridwise GEMM pipeline
603  const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);
604  // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock)
605  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);
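 // Illustrative count (example values): with K = 320, KPerBlock = 32 and AK1 = 8 the grid
 // descriptor has AK0 = 40 and AK0PerBlock = 32 / 8 = 4, i.e. 40 / 4 = 10 main-loop iterations.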
606 
607  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
608  a_block_desc_ak0_m_ak1,
609  a_blockwise_copy,
610  a_grid_buf,
611  a_block_buf,
612  a_block_slice_copy_step,
613  b_grid_desc_bk0_n_bk1,
614  b_block_desc_bk0_n_bk1,
615  b_blockwise_copy,
616  b_grid_buf,
617  b_block_buf,
618  b_block_slice_copy_step,
619  blockwise_gemm,
620  c_thread_buf,
621  num_k_block_main_loop);
622 
623  // output: register to global memory
624  {
625  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
626  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();
627 
628  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
629  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();
630 
631  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
632  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
633  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
634  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
635  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
636  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
637 
638  constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
639  constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
640 
641  // calculate origin of thread output tensor on global memory
642  // blockwise GEMM c matrix starting index
643  const auto c_thread_mtx_on_block =
644  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
645 
646  const index_t m_thread_data_on_grid =
647  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
648 
649  const index_t n_thread_data_on_grid =
650  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
651 
652  const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor(
656 
657  const auto m_thread_data_on_grid_idx =
658  m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
659  make_multi_index(m_thread_data_on_grid));
660 
661  const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
665 
666  const auto n_thread_data_on_grid_idx =
667  n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
668  make_multi_index(n_thread_data_on_grid));
669 
670  auto c_thread_copy =
672  CDataType,
673  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
674  decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
675  CElementwiseOperation,
677  CThreadTransferSrcDstAccessOrder,
678  CThreadTransferSrcDstVectorDim,
679  CThreadTransferDstScalarPerVector,
680  CGlobalMemoryDataOperation,
681  1,
682  true>{
683  c_grid_desc_m0_n0_m1_n1_m2_n2,
684  make_multi_index(m_thread_data_on_grid_idx[I0],
685  n_thread_data_on_grid_idx[I0],
686  m_thread_data_on_grid_idx[I1],
687  n_thread_data_on_grid_idx[I1],
688  m_thread_data_on_grid_idx[I2],
689  n_thread_data_on_grid_idx[I2]),
690  c_element_op};
691 
692  c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
693  make_tuple(I0, I0, I0, I0, I0, I0),
694  c_thread_buf,
695  c_grid_desc_m0_n0_m1_n1_m2_n2,
696  c_grid_buf);
697  }
698  }
699 };
700 
701 } // namespace ck
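The static interface above (Argument, CheckValidity, CalculateGridSize, CalculateHasMainKBlockLoop) is normally driven by a device-level operation that also owns the concrete template instantiation. The snippet below is only a minimal host-side sketch of that flow: GridwiseGemm stands for some fully instantiated grid-wise struct from this header, launch_gemm_dpp and block_size are illustrative names, and block_size must equal the BlockSize template argument used for the instantiation.

#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp"

// Sketch: dispatch kernel_gemm_dpp for an already-instantiated GridwiseGemm.
template <typename GridwiseGemm>
void launch_gemm_dpp(const typename GridwiseGemm::Argument& karg,
                     int block_size,
                     hipStream_t stream = nullptr)
{
    // Reject shapes this instantiation cannot handle (divisibility, vector widths, ...).
    if(!GridwiseGemm::CheckValidity(karg))
        return;

    // CalculateGridSize returns (number of C tiles, 1, 1).
    const auto [grid_x, grid_y, grid_z] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
    const dim3 grid(grid_x, grid_y, grid_z);
    const dim3 block(block_size);

    // HasMainKBlockLoop selects a compile-time branch of the pipeline, so dispatch on it here.
    if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
        ck::kernel_gemm_dpp<GridwiseGemm, true><<<grid, block, 0, stream>>>(karg);
    else
        ck::kernel_gemm_dpp<GridwiseGemm, false><<<grid, block, 0, stream>>>(karg);
}

In the library this role is typically filled by the corresponding DPP device GEMM operation, which builds the Argument from user buffers and strides and supplies the tuning parameters.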
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:267
__global__ void kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_dpp.hpp:29
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:22
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:298
Definition: blockwise_gemm_dpp.hpp:33
Definition: dpp_gemm.hpp:322
Definition: gridwise_gemm_dpp.hpp:184
const ABDataType * p_a_grid
Definition: gridwise_gemm_dpp.hpp:201
const ABDataType * p_b_grid
Definition: gridwise_gemm_dpp.hpp:202
CDataType * p_c_grid
Definition: gridwise_gemm_dpp.hpp:203
__host__ Argument(const ABDataType *p_a_grid_, const ABDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition: gridwise_gemm_dpp.hpp:185
Definition: gridwise_gemm_dpp.hpp:135
index_t NPadded
Definition: gridwise_gemm_dpp.hpp:177
index_t BK0
Definition: gridwise_gemm_dpp.hpp:179
index_t StrideB
Definition: gridwise_gemm_dpp.hpp:174
index_t N
Definition: gridwise_gemm_dpp.hpp:171
index_t K
Definition: gridwise_gemm_dpp.hpp:172
index_t StrideC
Definition: gridwise_gemm_dpp.hpp:175
index_t M
Definition: gridwise_gemm_dpp.hpp:170
index_t AK0
Definition: gridwise_gemm_dpp.hpp:178
index_t MPadded
Definition: gridwise_gemm_dpp.hpp:176
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition: gridwise_gemm_dpp.hpp:136
__host__ void Print() const
Definition: gridwise_gemm_dpp.hpp:155
index_t StrideA
Definition: gridwise_gemm_dpp.hpp:173
Definition: gridwise_gemm_dpp.hpp:96
static __host__ auto CalculateAK0(index_t K)
Definition: gridwise_gemm_dpp.hpp:130
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dpp.hpp:458
static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition: gridwise_gemm_dpp.hpp:438
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc &c_grid_desc_m_n)
Definition: gridwise_gemm_dpp.hpp:365
static constexpr auto BK0PerBlock
Definition: gridwise_gemm_dpp.hpp:107
static __host__ auto CalculateBK0(index_t K)
Definition: gridwise_gemm_dpp.hpp:131
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_dpp.hpp:356
static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
Definition: gridwise_gemm_dpp.hpp:416
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_dpp.hpp:111
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_dpp.hpp:115
static constexpr auto I4
Definition: gridwise_gemm_dpp.hpp:101
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_dpp.hpp:263
static constexpr auto matrix_padder
Definition: gridwise_gemm_dpp.hpp:388
static constexpr auto I5
Definition: gridwise_gemm_dpp.hpp:102
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_dpp.hpp:120
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_dpp.hpp:207
static constexpr auto AK0PerBlock
Definition: gridwise_gemm_dpp.hpp:106
static constexpr auto I3
Definition: gridwise_gemm_dpp.hpp:100
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_dpp.hpp:229
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_dpp.hpp:125
static constexpr auto BK1
Definition: gridwise_gemm_dpp.hpp:105
static constexpr auto I2
Definition: gridwise_gemm_dpp.hpp:99
static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
Definition: gridwise_gemm_dpp.hpp:393
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_dpp.hpp:249
static constexpr auto I1
Definition: gridwise_gemm_dpp.hpp:98
static constexpr auto I0
Definition: gridwise_gemm_dpp.hpp:97
static constexpr auto AK1
Definition: gridwise_gemm_dpp.hpp:104
static constexpr auto max_lds_align
Definition: gridwise_gemm_dpp.hpp:109
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_dpp.hpp:209
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: matrix_padder.hpp:180
Definition: unary_element_wise_operation.hpp:241