gridwise_gemm_xdlops_v2r4r2.hpp Source File
Composable Kernel (docs-6.4.3): include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
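// The __global__ wrapper below is a thin entry point: it sizes a static LDS buffer from
// GridwiseGemm::GetSharedMemoryNumberOfByte() and forwards every argument to
// GridwiseGemm::Run(). On devices outside the guarded gfx908/gfx90a/gfx94 set the body
// collapses to a no-op that only consumes its arguments through `ignore`.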
21 template <typename GridwiseGemm,
22  bool HasMainKBlockLoop,
23  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
24  typename Block2CTileMap,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename CElementwiseOperation>
28 __global__ void
29 #if CK_USE_LAUNCH_BOUNDS
30  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
31 #endif
32  kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
33  const Block2CTileMap& b2c_map,
34  const AElementwiseOperation a_element_op,
35  const BElementwiseOperation b_element_op,
36  const CElementwiseOperation c_element_op)
37 {
38 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
39  defined(__gfx94__))
40  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
41 
42  __shared__ uint8_t p_shared[shared_size];
43 
44  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
45  karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
46 #else
47  ignore = karg;
48  ignore = b2c_map;
49  ignore = a_element_op;
50  ignore = b_element_op;
51  ignore = c_element_op;
52 #endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
53 }
54 
55 template <index_t BlockSize,
56  typename FloatA,
57  typename FloatB,
58  typename FloatAcc,
59  typename FloatC,
60  typename ALayout,
61  typename BLayout,
62  typename CLayout,
63  typename AElementwiseOperation,
64  typename BElementwiseOperation,
65  typename CElementwiseOperation,
67  index_t NumGemmKPrefetchStage,
68  index_t MPerBlock,
69  index_t NPerBlock,
70  index_t K0PerBlock,
71  index_t MPerXDL,
72  index_t NPerXDL,
73  index_t K1Value,
74  index_t MRepeat,
75  index_t NRepeat,
76  typename ABlockTransferThreadClusterLengths_K0_M_K1,
77  typename ABlockTransferThreadClusterArrangeOrder,
78  typename ABlockTransferSrcAccessOrder,
79  index_t ABlockTransferSrcVectorDim,
80  index_t ABlockTransferSrcScalarPerVector,
81  index_t ABlockTransferDstScalarPerVector_K1,
82  bool AThreadTransferSrcResetCoordinateAfterRun,
83  bool ABlockLdsExtraM,
84  typename BBlockTransferThreadClusterLengths_K0_N_K1,
85  typename BBlockTransferThreadClusterArrangeOrder,
86  typename BBlockTransferSrcAccessOrder,
87  index_t BBlockTransferSrcVectorDim,
88  index_t BBlockTransferSrcScalarPerVector,
89  index_t BBlockTransferDstScalarPerVector_K1,
90  bool BThreadTransferSrcResetCoordinateAfterRun,
91  bool BBlockLdsExtraN,
92  index_t CShuffleMRepeatPerShuffle,
93  index_t CShuffleNRepeatPerShuffle,
94  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
95  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
98  typename ComputeTypeA = FloatC,
99  typename ComputeTypeB = ComputeTypeA,
100  typename LDSTypeA = ComputeTypeA,
101  typename LDSTypeB = ComputeTypeB>
103 {
104  static constexpr auto I0 = Number<0>{};
105  static constexpr auto I1 = Number<1>{};
106  static constexpr auto I2 = Number<2>{};
107  static constexpr auto I3 = Number<3>{};
108  static constexpr auto I4 = Number<4>{};
109  static constexpr auto I5 = Number<5>{};
110  static constexpr auto I6 = Number<6>{};
111  static constexpr auto I7 = Number<7>{};
112 
113  // K1 should be Number<...>
114  static constexpr auto K1 = Number<K1Value>{};
115  static constexpr auto M01 = 1;
116  static constexpr auto N01 = 1;
117 
118  static constexpr auto gemm_padder =
120  MPerBlock, NPerBlock, K1* K0PerBlock};
121 
122  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
123 
124  using GridwiseGemmPipe = remove_cvref_t<
125  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
126 
128  {
129  const FloatA* p_a_grid;
130  const FloatB* p_b_grid;
131  FloatC* p_c_grid;
132  index_t M;
133  index_t N;
134  index_t K;
135  index_t StrideA;
136  index_t StrideB;
137  index_t StrideC;
138  index_t MPadded;
139  index_t NPadded;
140  index_t KPadded;
141  index_t K0Padded;
142  index_t k_batch;
143 
144  Argument(const FloatA* p_a_grid_,
145  const FloatB* p_b_grid_,
146  FloatC* p_c_grid_,
147  index_t M_,
148  index_t N_,
149  index_t K_,
150  index_t StrideA_,
151  index_t StrideB_,
152  index_t StrideC_,
153  index_t MPadded_,
154  index_t NPadded_,
155  index_t KPadded_,
156  index_t K0Padded_,
157  index_t k_batch_)
158  : p_a_grid(p_a_grid_),
159  p_b_grid(p_b_grid_),
160  p_c_grid(p_c_grid_),
161  M(M_),
162  N(N_),
163  K(K_),
164  StrideA(StrideA_),
165  StrideB(StrideB_),
166  StrideC(StrideC_),
167  MPadded(MPadded_),
168  NPadded(NPadded_),
169  KPadded(KPadded_),
170  K0Padded(K0Padded_),
171  k_batch(k_batch_)
172  {
173  }
174 
175  void Print() const
176  {
177  std::cout << "arg {"
178  << "M:" << M << ", "
179  << "N:" << N << ", "
180  << "K:" << K << ", "
181  << "SA:" << StrideA << ", "
182  << "SB:" << StrideB << ", "
183  << "SC:" << StrideC << ", "
184  << "MP:" << MPadded << ", "
185  << "NP:" << NPadded << ", "
186  << "KP:" << KPadded << ", "
187  << "K0Padded:" << K0Padded << ", "
188  << "KB:" << k_batch << "}" << std::endl;
189  }
190  };
191 
192  __host__ __device__ static auto CalculateGridSize(const Argument& karg)
193  {
194  return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
195  math::integer_divide_ceil(karg.M, MPerBlock),
196  karg.k_batch);
197  }
198 
199  // prefer this to be called on host
200  __host__ __device__ static auto CalculateMPadded(index_t M)
201  {
202  return math::integer_least_multiple(M, MPerBlock);
203  }
204 
205  __host__ __device__ static auto CalculateNPadded(index_t N)
206  {
207  return math::integer_least_multiple(N, NPerBlock);
208  }
209 
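// The split-K padding helpers below round K up so that it divides evenly into k_batch
// slices of K0PerBlock * K1 elements. Worked example with hypothetical parameters
// K0PerBlock = 4, K1 = 8, K_Batch = 2 and K = 100:
//   K_t      = 2 * 4 * 8          = 64
//   K0Padded = ceil(100 / 64) * 4 = 8
//   KPadded  = 2 * 8 * 8          = 128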
210  __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
211  {
212  // k_batch * k0 * k0_per_block * k1
213  auto K_t = K_Batch * K0PerBlock * K1;
214  return (K + K_t - 1) / K_t * K0PerBlock;
215  }
216 
217  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
218  {
219  auto K0Padded = CalculateK0Padded(K, K_Batch);
220  return K_Batch * K0Padded * K1;
221  }
222 
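// The A and B grid descriptors below reinterpret the row- or column-major 2D matrices as
// 4D views [KBatch, K0Padded, M (or N), K1]: K is first padded to KPad and unmerged into
// (KBatch, K0Padded, K1), and M (or N) is right-padded to MPad (or NPad) when the chosen
// GemmSpecialization requests it.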
223  __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
224  index_t MPad,
225  index_t K,
226  index_t StrideA,
227  index_t KBatch,
228  index_t K0Padded,
229  index_t KPad)
230  {
231  const auto a_grid_desc_m_k = [&]() {
233  {
234  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
235  }
237  {
238  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
239  }
240  }();
241 
246  {
247 
248  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
249  a_grid_desc_m_k,
253 
254  // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
256  a_grid_desc_m_kpad,
257  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
258  make_right_pad_transform(M, MPad - M)),
261  }
262  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
264  {
265  // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
267  a_grid_desc_m_k,
268  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
269  make_right_pad_transform(M, MPad - M)),
272  }
273  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
274  {
275  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
276  a_grid_desc_m_k,
280 
282  a_grid_desc_m_kpad,
283  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
287  }
288  else
289  {
291  a_grid_desc_m_k,
292  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
296  }
297  }
298 
299  __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
300  index_t NPad,
301  index_t N,
302  index_t StrideB,
303  index_t KBatch,
304  index_t K0Padded,
305  index_t KPad)
306  {
307  const auto b_grid_desc_k_n = [&]() {
309  {
310  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
311  }
313  {
314  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
315  }
316  }();
317 
322  {
323 
324  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
325  b_grid_desc_k_n,
329 
330  // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
332  b_grid_desc_kpad_n,
333  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
334  make_right_pad_transform(N, NPad - N)),
337  }
338  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
340  {
341  // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
343  b_grid_desc_k_n,
344  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
345  make_right_pad_transform(N, NPad - N)),
348  }
349  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
350  {
351  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
352  b_grid_desc_k_n,
356 
358  b_grid_desc_kpad_n,
359  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
363  }
364  else
365  {
367  b_grid_desc_k_n,
368  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
372  }
373  }
374 
375  __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
376  {
377  const auto c_grid_desc_m_n = [&]() {
379  {
380  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
381  }
383  {
384  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
385  }
386  }();
387 
388  return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
389  }
390 
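// LDS budget: the main loop keeps an A tile and a B tile resident at the same time, while
// the C-shuffle epilogue reuses the same LDS for a single C tile, so the requirement is
// max(A bytes + B bytes, C bytes), with the A and B tile sizes rounded up to the K1
// alignment.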
391  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
392  {
393  constexpr auto max_lds_align = K1;
394 
395  // A matrix in LDS memory, dst of blockwise copy
396  constexpr auto a_k0_m_k1_block_desc = [&]() {
397  if constexpr(ABlockLdsExtraM)
398  {
402  }
403  else
404  {
406  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
407  }
408  }();
409 
410  // B matrix in LDS memory, dst of blockwise copy
411  constexpr auto b_k0_n_k1_block_desc = [&]() {
412  if constexpr(BBlockLdsExtraN)
413  {
417  }
418  else
419  {
421  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
422  }
423  }();
424 
425  // LDS allocation for A and B: be careful of alignment
426  constexpr auto a_block_space_size =
427  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
428 
429  constexpr auto b_block_space_size =
430  math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
431 
432  constexpr auto c_block_size =
434 
435  return math::max(a_block_space_size * sizeof(LDSTypeA) +
436  b_block_space_size * sizeof(LDSTypeB),
437  c_block_size * sizeof(FloatC));
438  }
439 
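// CheckValidity rejects argument/tile combinations the kernel cannot handle: M, N and K
// must be divisible by the block tile (unless the matching padding GemmSpecialization is
// enabled), the vector widths of the A/B/C transfers must divide the corresponding
// contiguous dimension for the given layouts, and the number of K0 loops must be accepted
// by the selected gridwise GEMM pipeline.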
440  __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
441  {
442  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
446  {
447  if(!(karg.M % MPerBlock == 0))
448  {
449  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
450  {
451  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
452  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
453  << std::endl;
454  }
455  return false;
456  }
457  }
458 
459  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
463  {
464  if(!(karg.N % NPerBlock == 0))
465  {
466  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
467  {
468  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
469  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
470  << std::endl;
471  }
472  return false;
473  }
474  }
475 
476  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
480  {
481 
482  auto K_t = karg.k_batch * K0PerBlock * K1;
483  if(!(karg.K % K_t == 0))
484  {
485  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
486  {
487  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
488  << karg.K << " " << __FILE__ << ":" << __LINE__
489  << ", in function: " << __func__ << std::endl;
490  }
491  return false;
492  }
493  }
494 
496  {
497  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
498  {
499  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
500  {
501  std::cout << "Arg K (" << karg.K
502  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
503  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
504  << __LINE__ << ", in function: " << __func__ << std::endl;
505  }
506  return false;
507  }
508  }
509  else
510  {
511  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
512  {
513  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
514  {
515  std::cout << "Arg M (" << karg.M
516  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
517  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
518  << __LINE__ << ", in function: " << __func__ << std::endl;
519  }
520  return false;
521  }
522  }
523 
525  {
526  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
527  {
528  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
529  {
530  std::cout << "Arg N (" << karg.N
531  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
532  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
533  << __LINE__ << ", in function: " << __func__ << std::endl;
534  }
535  return false;
536  }
537  }
538  else
539  {
540  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
541  {
542  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
543  {
544  std::cout << "Arg K (" << karg.K
545  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
546  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
547  << __LINE__ << ", in function: " << __func__ << std::endl;
548  }
549  return false;
550  }
551  }
552 
554  {
555  if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
556  {
557  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
558  {
559  std::cout << "Arg N (" << karg.N
560  << ") value is not a multiple of "
561  "CBlockTransferScalarPerVector_NWaveNPerXDL ("
562  << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
563  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
564  }
565  return false;
566  }
567  }
568  else
569  {
570  if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
571  {
572  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
573  {
574  std::cout << "Arg M (" << karg.M
575  << ") value is not a multiple of "
576  "CBlockTransferScalarPerVector_NWaveNPerXDL ("
577  << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
578  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
579  }
580  return false;
581  }
582  }
583 
584  const auto num_k_loop = karg.K0Padded / K0PerBlock;
585  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
586  {
587  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
588  {
589  std::cout << "The number of k loops (" << num_k_loop
590  << ") value is not supported by GridwiseGemm Pipeline."
591  << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
592  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
593  << std::endl;
594  }
595  return false;
596  }
597 
598  return true;
599  }
600 
601  __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
602  {
603  const index_t K0Padded =
604  math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
605  const index_t KPad = KBatch * K0Padded * K1;
606  return KPad;
607  }
608 
609  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
610  {
611  const index_t num_loop = K0Padded / K0PerBlock;
612  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
613  }
614 
615  template <typename CGridDesc>
616  __host__ __device__ static constexpr auto
617  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
618  {
619  const auto M = c_m_n_grid_desc.GetLength(I0);
620  const auto N = c_m_n_grid_desc.GetLength(I1);
621 
622  const auto MBlock = M / MPerBlock;
623  const auto NBlock = N / NPerBlock;
624 
626  c_m_n_grid_desc,
631  }
632 
633  // return block_id to C matrix tile idx (m0, n0) mapping
634  template <typename CGridDesc>
635  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
636  const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
637  {
639  c_m_n_grid_desc, 8, KBatch);
640  }
641 
642  __host__ __device__ static constexpr auto
644  {
645  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
646  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
647 
649  make_tuple(I1,
651  I1,
653  }
654 
655  // return block_id to C matrix tile idx (m0, n0, k_split) mapping
656  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
657  {
659  }
660 
661  using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1))>;
662  using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
663 
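// Run() executes one block's share of the split-K GEMM: it rebuilds the grid descriptors
// from the Argument, maps the 1D block id to a (k_batch, m, n) tile, stages A and B tiles
// in LDS with the blockwise copies, runs the xdlops blockwise GEMM pipeline over the K0
// loop, and finally shuffles the accumulators through LDS back to global memory with the
// C block copy.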
664  template <bool HasMainKBlockLoop,
665  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
666  typename Block2CTileMap>
667  __device__ static void Run(const Argument& karg,
668  void* __restrict__ p_shared_block,
669  const Block2CTileMap& block_2_ctile_map,
670  const AElementwiseOperation a_element_op = AElementwiseOperation{},
671  const BElementwiseOperation b_element_op = BElementwiseOperation{},
672  const CElementwiseOperation c_element_op = CElementwiseOperation{})
673  {
674  const FloatA* p_a_grid = karg.p_a_grid;
675  const FloatB* p_b_grid = karg.p_b_grid;
676  FloatC* p_c_grid = karg.p_c_grid;
677  const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
678  karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
679  const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
680  karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
681  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
682 
683  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
685 
686  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
687  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
688  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
689  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
690  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
691  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
692 
693  // divide block work by [KBatch, M, N]
694  const auto block_work_idx =
695  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
696 
697  if(!block_2_ctile_map.ValidCTileIndex(
698  block_work_idx,
699  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
700  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
701  {
702  return;
703  }
704 
705  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
706  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
707  const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
708 
709  // HACK: this force m/n_block_data_idx_on_grid into SGPR
710  const index_t m_block_data_idx_on_grid =
711  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
712 
713  const index_t n_block_data_idx_on_grid =
714  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
715 
716  // lds max alignment
717  constexpr auto max_lds_align = K1;
718 
719  // A matrix in LDS memory, dst of blockwise copy
720  constexpr auto a_k0_m_k1_block_desc = [&]() {
721  if constexpr(ABlockLdsExtraM)
722  {
724  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
725  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
726  }
727  else
728  {
730  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
731  }
732  }();
733 
734  constexpr auto a_b_k0_m_k1_block_desc = [&]() {
735  if constexpr(ABlockLdsExtraM)
736  {
738  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
739  make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
740  Number<MPerBlock + 1>{} * K1,
741  K1,
742  I1));
743  }
744  else
745  {
747  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
748  max_lds_align);
749  }
750  }();
751  // B matrix in LDS memory, dst of blockwise copy
752  constexpr auto b_k0_n_k1_block_desc = [&]() {
753  if constexpr(BBlockLdsExtraN)
754  {
756  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
757  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
758  }
759  else
760  {
762  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
763  }
764  }();
765 
766  constexpr auto b_b_k0_n_k1_block_desc = [&]() {
767  if constexpr(BBlockLdsExtraN)
768  {
770  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
771  make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
772  Number<NPerBlock + 1>{} * K1,
773  K1,
774  I1));
775  }
776  else
777  {
779  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
780  max_lds_align);
781  }
782  }();
783  // A matrix blockwise copy
784  auto a_blockwise_copy =
785  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
786  AElementwiseOperation,
789  Sequence<1, K0PerBlock, MPerBlock, K1>,
790  ABlockTransferThreadClusterLengths_K0_M_K1,
791  ABlockTransferThreadClusterArrangeOrder,
792  FloatA,
793  LDSTypeA,
794  decltype(a_b_k0_m_k1_grid_desc),
795  decltype(a_b_k0_m_k1_block_desc),
796  ABlockTransferSrcAccessOrder,
797  Sequence<0, 2, 1, 3>,
798  ABlockTransferSrcVectorDim,
799  3,
800  ABlockTransferSrcScalarPerVector,
801  ABlockTransferDstScalarPerVector_K1,
802  1,
803  1,
804  AThreadTransferSrcResetCoordinateAfterRun,
805  true>(
806  a_b_k0_m_k1_grid_desc,
807  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
808  a_element_op,
809  a_b_k0_m_k1_block_desc,
810  make_multi_index(0, 0, 0, 0),
812 
813  // B matrix blockwise copy
814  auto b_blockwise_copy =
815  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
816  BElementwiseOperation,
819  Sequence<1, K0PerBlock, NPerBlock, K1>,
820  BBlockTransferThreadClusterLengths_K0_N_K1,
821  BBlockTransferThreadClusterArrangeOrder,
822  FloatB,
823  LDSTypeB,
824  decltype(b_b_k0_n_k1_grid_desc),
825  decltype(b_b_k0_n_k1_block_desc),
826  BBlockTransferSrcAccessOrder,
827  Sequence<0, 2, 1, 3>,
828  BBlockTransferSrcVectorDim,
829  3,
830  BBlockTransferSrcScalarPerVector,
831  BBlockTransferDstScalarPerVector_K1,
832  1,
833  1,
834  BThreadTransferSrcResetCoordinateAfterRun,
835  true>(
836  b_b_k0_n_k1_grid_desc,
837  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
838  b_element_op,
839  b_b_k0_n_k1_block_desc,
840  make_multi_index(0, 0, 0, 0),
842 
843  // GEMM definition
844  // c_mtx += transpose(a_mtx) * b_mtx
845  // a_mtx[K0PerBlock, MPerBlock] is in LDS
846  // b_mtx[K0PerBlock, NPerBlock] is in LDS
847  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
848  // register
849  // sanity check
850 
852  BlockSize,
853  LDSTypeA,
854  LDSTypeB,
855  FloatAcc,
856  decltype(a_k0_m_k1_block_desc),
857  decltype(b_k0_n_k1_block_desc),
858  MPerXDL,
859  NPerXDL,
860  MRepeat,
861  NRepeat,
862  K1,
863  LoopSched,
864  ComputeTypeA,
865  ComputeTypeB>();
866 
867  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
868 
869  // LDS allocation for A and B: be careful of alignment
870  constexpr auto a_block_space_size =
871  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
872 
873  auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
874  auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
875 
876  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
877  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
878 
879  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
880  p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
881  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
882  p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
883 
884  // gridwise GEMM pipeline
885  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
886  (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
887  (K0PerBlock * K1));
888 
889  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
890 
891  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
892  a_b_k0_m_k1_block_desc,
893  a_blockwise_copy,
894  a_grid_buf,
895  a_block_buf,
896  a_block_slice_copy_step,
897  b_b_k0_n_k1_grid_desc,
898  b_b_k0_n_k1_block_desc,
899  b_blockwise_copy,
900  b_grid_buf,
901  b_block_buf,
902  b_block_slice_copy_step,
903  blockwise_gemm,
904  c_thread_buf,
905  num_k_block_main_loop);
906 
907  // output: register to global memory
908  {
909  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
910  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
911 
912  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
913  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
914 
915  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
916  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
917 
918  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
919  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
920  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
921  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
922  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
923  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
924  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
925  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
926 
927  constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
929 
930  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
931  static_cast<FloatC*>(p_shared_block),
932  c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
933 
934  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
935  c_block_desc_mblock_mperblock_nblock_nperblock,
936  make_tuple(
937  make_freeze_transform(I0), // freeze mblock
938  make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
939  M1,
940  M2,
941  M3,
942  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
943  make_freeze_transform(I0), // freeze nblock
944  make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
945  N1,
946  N2))), // N1 = NWave, N2 = NPerXDL
947  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
948  make_tuple(
949  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
950 
951  // calculate origin of thread output tensor on global memory
952  // blockwise GEMM c matrix starting index
953  const auto c_thread_mtx_on_block =
954  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
955 
956  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
957  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
958 
959  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
961  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
962  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
963  make_tuple(Sequence<0>{}));
964 
965  const auto m_thread_data_on_block_idx =
966  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
967  make_multi_index(m_thread_data_on_block));
968 
969  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
972  make_tuple(Sequence<0, 1, 2>{}),
973  make_tuple(Sequence<0>{}));
974 
975  const auto n_thread_data_on_block_idx =
976  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
977  make_multi_index(n_thread_data_on_block));
978 
979  // VGPR to LDS
980  auto c_thread_copy_vgpr_to_lds =
981  ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
982  FloatC,
983  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
984  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
986  Sequence<CShuffleMRepeatPerShuffle,
987  CShuffleNRepeatPerShuffle,
988  I1,
989  I1,
990  M2,
991  I1,
992  M4,
993  I1>,
994  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
995  7,
996  1,
998  1,
999  true>{
1000  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1001  make_multi_index(0,
1002  0,
1003  m_thread_data_on_block_idx[I1],
1004  n_thread_data_on_block_idx[I1],
1005  m_thread_data_on_block_idx[I2],
1006  m_thread_data_on_block_idx[I3],
1007  m_thread_data_on_block_idx[I4],
1008  n_thread_data_on_block_idx[I2]),
1010 
1011  // LDS to global
1012  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1013  ThisThreadBlock, // index_t BlockSize,
1014  CElementwiseOperation, // ElementwiseOperation,
1015  CGlobalMemoryDataOperation, // DstInMemOp,
1016  Sequence<1,
1017  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1018  1,
1019  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
1020  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1021  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1022  FloatC, // typename SrcData,
1023  FloatC, // typename DstData,
1024  decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
1025  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1026  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1027  3, // index_t VectorDim,
1028  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
1029  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1030  false> // bool ThreadTransferDstResetCoordinateAfterRun
1031  {c_block_desc_mblock_mperblock_nblock_nperblock,
1032  make_multi_index(0, 0, 0, 0),
1033  c_grid_desc_mblock_mperblock_nblock_nperblock,
1034  make_multi_index(block_m_id, 0, block_n_id, 0),
1035  c_element_op};
1036 
1037  constexpr auto mxdlperwave_forward_step =
1038  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
1039  constexpr auto nxdlperwave_forward_step =
1040  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1041  constexpr auto nxdlperwave_backward_step =
1042  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1043 
1044  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1045  constexpr auto mxdlperwave = mxdlperwave_iter;
1046 
1047  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1048  constexpr bool nxdlperwave_forward_sweep =
1049  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1050 
1051  constexpr index_t nxdlperwave_value =
1052  nxdlperwave_forward_sweep
1053  ? nxdlperwave_iter
1054  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1055 
1056  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1057 
1058  // make sure it's safe to do ds_write
1059  block_sync_lds();
1060 
1061  // VGPR to LDS
1062  c_thread_copy_vgpr_to_lds.Run(
1063  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1064  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1065  c_thread_buf,
1066  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1067  c_block_buf);
1068 
1069  // make sure it's safe to do ds_read
1070  block_sync_lds();
1071 
1072  // LDS to global
1073  c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
1074  c_block_buf,
1075  c_grid_desc_mblock_mperblock_nblock_nperblock,
1076  c_grid_buf);
1077 
1078  // move on nxdlperwave dimension
1079  if constexpr(nxdlperwave_forward_sweep &&
1080  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1081  {
1082  c_block_copy_lds_to_global.MoveDstSliceWindow(
1083  c_grid_desc_mblock_mperblock_nblock_nperblock,
1084  nxdlperwave_forward_step);
1085  }
1086  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1087  {
1088  c_block_copy_lds_to_global.MoveDstSliceWindow(
1089  c_grid_desc_mblock_mperblock_nblock_nperblock,
1090  nxdlperwave_backward_step);
1091  }
1092  });
1093 
1094  // move on mxdlperwave dimension
1095  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1096  {
1097  c_block_copy_lds_to_global.MoveDstSliceWindow(
1098  c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
1099  }
1100  });
1101  }
1102  }
1103 
1104  static std::string GetTypeString()
1105  {
1106  auto str = std::stringstream();
1107 
1108  // clang-format off
1109  str << "GemmXdlSplitKCShuffle_"
1110  << getGemmSpecializationString(GemmSpec) << "_"
1111  << std::string(ALayout::name)[0]
1112  << std::string(BLayout::name)[0]
1113  << std::string(CLayout::name)[0]
1114  << "_"
1115  << "B" << BlockSize << "_"
1116  << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1117  << BBlockTransferSrcScalarPerVector << "x"
1118  << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1119  << MPerBlock << "x"
1120  << NPerBlock << "x"
1121  << K0PerBlock << "x"
1122  << K1 ;
1123  // clang-format on
1124 
1125  return str.str();
1126  }
1127 };
1128 
1129 } // namespace ck
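// Illustrative host-side sketch (not part of this header): it shows how the split-K
// Argument defined above is typically assembled and validated before instantiating
// kernel_gemm_xdlops_v2r4r2_simplified. `Gemm` stands for a fully specialized instance of
// the gridwise GEMM class in this file, the pointer types are assumed to match its
// FloatA/FloatB/FloatC, and the helper name itself is hypothetical.
template <typename Gemm, typename ADataType, typename BDataType, typename CDataType>
bool is_supported_split_k_problem(const ADataType* p_a,
                                  const BDataType* p_b,
                                  CDataType* p_c,
                                  ck::index_t M,
                                  ck::index_t N,
                                  ck::index_t K,
                                  ck::index_t StrideA,
                                  ck::index_t StrideB,
                                  ck::index_t StrideC,
                                  ck::index_t k_batch)
{
    // Padded extents are computed exactly the way the kernel expects them.
    const typename Gemm::Argument karg{p_a,
                                       p_b,
                                       p_c,
                                       M,
                                       N,
                                       K,
                                       StrideA,
                                       StrideB,
                                       StrideC,
                                       Gemm::CalculateMPadded(M),
                                       Gemm::CalculateNPadded(N),
                                       Gemm::CalculateKPadded(K, k_batch),
                                       Gemm::CalculateK0Padded(K, k_batch),
                                       k_batch};

    // HasMainKBlockLoop would be forwarded as a template argument of the kernel.
    const bool has_main_k0_block_loop = Gemm::CalculateHasMainK0BlockLoop(karg.K0Padded);
    (void)has_main_k0_block_loop;

    // Rejects tile / vector-width combinations the kernel cannot execute.
    return Gemm::CheckValidity(karg);
}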