Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/common_header.hpp"
7 #include "ck/utility/env.hpp"
8 #include "ck/tensor_description/multi_index_transform_helper.hpp"
9 #include "ck/tensor_description/tensor_descriptor.hpp"
10 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
11 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
12 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
13 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
14 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
15 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
16 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
17 #include "ck/tensor_operation/gpu/device/device_base.hpp"
18 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
19 
20 namespace ck {
21 
22 template <typename GridwiseGemm,
23  bool HasMainKBlockLoop,
24  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
25  typename Block2CTileMap,
26  typename AElementwiseOperation,
27  typename BElementwiseOperation,
28  typename CElementwiseOperation>
29 __global__ void
30 #if CK_USE_LAUNCH_BOUNDS
31  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
32 #endif
33  kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
34  const Block2CTileMap& b2c_map,
35  const AElementwiseOperation a_element_op,
36  const BElementwiseOperation b_element_op,
37  const CElementwiseOperation c_element_op)
38 {
39 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
40  defined(__gfx94__))
41  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
42 
43  __shared__ uint8_t p_shared[shared_size];
44 
45  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
46  karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
47 #else
48  ignore = karg;
49  ignore = b2c_map;
50  ignore = a_element_op;
51  ignore = b_element_op;
52  ignore = c_element_op;
53 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
54 }
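 // Usage sketch (an assumption about the surrounding host code, not something this file
 // defines): a host wrapper is expected to size the grid from
 // GridwiseGemm::CalculateGridSize(karg), launch BlockSize threads per workgroup, and only
 // dispatch once GridwiseGemm::CheckValidity(karg) has returned true.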
55 
56 template <index_t BlockSize,
57  typename FloatA,
58  typename FloatB,
59  typename FloatAcc,
60  typename FloatC,
61  typename ALayout,
62  typename BLayout,
63  typename CLayout,
64  typename AElementwiseOperation,
65  typename BElementwiseOperation,
66  typename CElementwiseOperation,
67  tensor_operation::device::GemmSpecialization GemmSpec,
68  index_t NumGemmKPrefetchStage,
69  index_t MPerBlock,
70  index_t NPerBlock,
71  index_t K0PerBlock,
72  index_t MPerXDL,
73  index_t NPerXDL,
74  index_t K1Value,
75  index_t MRepeat,
76  index_t NRepeat,
77  typename ABlockTransferThreadClusterLengths_K0_M_K1,
78  typename ABlockTransferThreadClusterArrangeOrder,
79  typename ABlockTransferSrcAccessOrder,
80  index_t ABlockTransferSrcVectorDim,
81  index_t ABlockTransferSrcScalarPerVector,
82  index_t ABlockTransferDstScalarPerVector_K1,
83  bool AThreadTransferSrcResetCoordinateAfterRun,
84  bool ABlockLdsExtraM,
85  typename BBlockTransferThreadClusterLengths_K0_N_K1,
86  typename BBlockTransferThreadClusterArrangeOrder,
87  typename BBlockTransferSrcAccessOrder,
88  index_t BBlockTransferSrcVectorDim,
89  index_t BBlockTransferSrcScalarPerVector,
90  index_t BBlockTransferDstScalarPerVector_K1,
91  bool BThreadTransferSrcResetCoordinateAfterRun,
92  bool BBlockLdsExtraN,
93  index_t CShuffleMRepeatPerShuffle,
94  index_t CShuffleNRepeatPerShuffle,
95  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
96  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
97  LoopScheduler LoopSched = make_default_loop_scheduler(),
98  PipelineVersion PipelineVer = PipelineVersion::v1,
99  typename ComputeTypeA = FloatC,
100  typename ComputeTypeB = ComputeTypeA,
101  typename LDSTypeA = ComputeTypeA,
102  typename LDSTypeB = ComputeTypeB>
103 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
104 {
105  static constexpr auto I0 = Number<0>{};
106  static constexpr auto I1 = Number<1>{};
107  static constexpr auto I2 = Number<2>{};
108  static constexpr auto I3 = Number<3>{};
109  static constexpr auto I4 = Number<4>{};
110  static constexpr auto I5 = Number<5>{};
111  static constexpr auto I6 = Number<6>{};
112  static constexpr auto I7 = Number<7>{};
113 
114  // K1 should be Number<...>
115  static constexpr auto K1 = Number<K1Value>{};
116  static constexpr auto M01 = 1;
117  static constexpr auto N01 = 1;
118 
119  static constexpr auto gemm_padder =
120  tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
121  MPerBlock, NPerBlock, K1* K0PerBlock};
122 
123  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
124 
125  using GridwiseGemmPipe = remove_cvref_t<
126  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
127 
128  struct Argument : public tensor_operation::device::BaseArgument
129  {
130  const FloatA* p_a_grid;
131  const FloatB* p_b_grid;
132  FloatC* p_c_grid;
133  index_t M;
134  index_t N;
135  index_t K;
136  index_t StrideA;
137  index_t StrideB;
138  index_t StrideC;
139  index_t MPadded;
140  index_t NPadded;
141  index_t KPadded;
142  index_t K0Padded;
143  index_t k_batch;
144 
145  Argument(const FloatA* p_a_grid_,
146  const FloatB* p_b_grid_,
147  FloatC* p_c_grid_,
148  index_t M_,
149  index_t N_,
150  index_t K_,
151  index_t StrideA_,
152  index_t StrideB_,
153  index_t StrideC_,
154  index_t MPadded_,
155  index_t NPadded_,
156  index_t KPadded_,
157  index_t K0Padded_,
158  index_t k_batch_)
159  : p_a_grid(p_a_grid_),
160  p_b_grid(p_b_grid_),
161  p_c_grid(p_c_grid_),
162  M(M_),
163  N(N_),
164  K(K_),
165  StrideA(StrideA_),
166  StrideB(StrideB_),
167  StrideC(StrideC_),
168  MPadded(MPadded_),
169  NPadded(NPadded_),
170  KPadded(KPadded_),
171  K0Padded(K0Padded_),
172  k_batch(k_batch_)
173  {
174  }
175 
176  void Print() const
177  {
178  std::cout << "arg {"
179  << "M:" << M << ", "
180  << "N:" << N << ", "
181  << "K:" << K << ", "
182  << "SA:" << StrideA << ", "
183  << "SB:" << StrideB << ", "
184  << "SC:" << StrideC << ", "
185  << "MP:" << MPadded << ", "
186  << "NP:" << NPadded << ", "
187  << "KP:" << KPadded << ", "
188  << "K0Padded:" << K0Padded << ", "
189  << "KB:" << k_batch << "}" << std::endl;
190  }
191  };
192 
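 // Grid size is (N tiles, M tiles, k_batch): one workgroup per C tile per split-K batch.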
193  __host__ __device__ static auto CalculateGridSize(const Argument& karg)
194  {
195  return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
196  math::integer_divide_ceil(karg.M, MPerBlock),
197  karg.k_batch);
198  }
199 
200  // prefer this to be called on host
201  __host__ __device__ static auto CalculateMPadded(index_t M)
202  {
203  return math::integer_least_multiple(M, MPerBlock);
204  }
205 
206  __host__ __device__ static auto CalculateNPadded(index_t N)
207  {
208  return math::integer_least_multiple(N, NPerBlock);
209  }
210 
211  __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
212  {
213  // k_batch * k0 * k0_per_block * k1
214  auto K_t = K_Batch * K0PerBlock * K1;
215  return (K + K_t - 1) / K_t * K0PerBlock;
216  }
217 
218  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
219  {
220  auto K0Padded = CalculateK0Padded(K, K_Batch);
221  return K_Batch * K0Padded * K1;
222  }
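 // Worked example with hypothetical sizes (not taken from this file): K = 1000, K_Batch = 2,
 // K0PerBlock = 4, K1 = 8 give K_t = 2 * 4 * 8 = 64, so K0Padded = ceil(1000 / 64) * 4 = 64
 // and KPadded = 2 * 64 * 8 = 1024; K is rounded up so every split-K batch owns a whole
 // number of K0PerBlock * K1 slices.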
223 
224  __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
225  index_t MPad,
226  index_t K,
227  index_t StrideA,
228  index_t KBatch,
229  index_t K0Padded,
230  index_t KPad)
231  {
232  const auto a_grid_desc_m_k = [&]() {
233  if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
234  {
235  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
236  }
237  else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
238  {
239  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
240  }
241  }();
242 
243  if constexpr(GemmSpec ==
244  tensor_operation::device::GemmSpecialization::MKPadding ||
245  GemmSpec ==
246  tensor_operation::device::GemmSpecialization::MNKPadding)
247  {
248 
249  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
250  a_grid_desc_m_k,
251  make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
252  make_tuple(Sequence<0>{}, Sequence<1>{}),
253  make_tuple(Sequence<0>{}, Sequence<1>{}));
254 
255  // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
256  return transform_tensor_descriptor(
257  a_grid_desc_m_kpad,
258  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
259  make_right_pad_transform(M, MPad - M)),
260  make_tuple(Sequence<1>{}, Sequence<0>{}),
261  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
262  }
263  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
264  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
265  {
266  // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
267  return transform_tensor_descriptor(
268  a_grid_desc_m_k,
269  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
270  make_right_pad_transform(M, MPad - M)),
271  make_tuple(Sequence<1>{}, Sequence<0>{}),
272  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
273  }
274  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
275  {
276  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
277  a_grid_desc_m_k,
278  make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
279  make_tuple(Sequence<0>{}, Sequence<1>{}),
280  make_tuple(Sequence<0>{}, Sequence<1>{}));
281 
282  return transform_tensor_descriptor(
283  a_grid_desc_m_kpad,
284  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
285  make_pass_through_transform(M)),
286  make_tuple(Sequence<1>{}, Sequence<0>{}),
287  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
288  }
289  else
290  {
291  return transform_tensor_descriptor(
292  a_grid_desc_m_k,
293  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
294  make_pass_through_transform(M)),
295  make_tuple(Sequence<1>{}, Sequence<0>{}),
296  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
297  }
298  }
299 
300  __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
301  index_t NPad,
302  index_t N,
303  index_t StrideB,
304  index_t KBatch,
305  index_t K0Padded,
306  index_t KPad)
307  {
308  const auto b_grid_desc_k_n = [&]() {
309  if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
310  {
311  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
312  }
313  else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
314  {
315  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
316  }
317  }();
318 
319  if constexpr(GemmSpec ==
320  tensor_operation::device::GemmSpecialization::NKPadding ||
321  GemmSpec ==
322  tensor_operation::device::GemmSpecialization::MNKPadding)
323  {
324 
325  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
326  b_grid_desc_k_n,
327  make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
328  make_tuple(Sequence<0>{}, Sequence<1>{}),
329  make_tuple(Sequence<0>{}, Sequence<1>{}));
330 
331  // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
332  return transform_tensor_descriptor(
333  b_grid_desc_kpad_n,
334  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
335  make_right_pad_transform(N, NPad - N)),
336  make_tuple(Sequence<0>{}, Sequence<1>{}),
337  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
338  }
339  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
340  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
341  {
342  // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
343  return transform_tensor_descriptor(
344  b_grid_desc_k_n,
345  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
346  make_right_pad_transform(N, NPad - N)),
347  make_tuple(Sequence<0>{}, Sequence<1>{}),
348  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
349  }
350  else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
351  {
352  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
353  b_grid_desc_k_n,
354  make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
355  make_tuple(Sequence<0>{}, Sequence<1>{}),
356  make_tuple(Sequence<0>{}, Sequence<1>{}));
357 
358  return transform_tensor_descriptor(
359  b_grid_desc_kpad_n,
360  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
361  make_pass_through_transform(N)),
362  make_tuple(Sequence<0>{}, Sequence<1>{}),
363  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
364  }
365  else
366  {
367  return transform_tensor_descriptor(
368  b_grid_desc_k_n,
369  make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
370  make_pass_through_transform(N)),
371  make_tuple(Sequence<0>{}, Sequence<1>{}),
372  make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
373  }
374  }
375 
376  __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
377  {
378  const auto c_grid_desc_m_n = [&]() {
379  if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
380  {
381  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
382  }
383  else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
384  {
385  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
386  }
387  }();
388 
389  return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
390  }
391 
392  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
393  {
394  constexpr auto max_lds_align = K1;
395 
396  // A matrix in LDS memory, dst of blockwise copy
397  constexpr auto a_k0_m_k1_block_desc = [&]() {
398  if constexpr(ABlockLdsExtraM)
399  {
400  return make_naive_tensor_descriptor(
401  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
402  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
403  }
404  else
405  {
406  return make_naive_tensor_descriptor_aligned(
407  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
408  }
409  }();
410 
411  // B matrix in LDS memory, dst of blockwise copy
412  constexpr auto b_k0_n_k1_block_desc = [&]() {
413  if constexpr(BBlockLdsExtraN)
414  {
415  return make_naive_tensor_descriptor(
416  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
417  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
418  }
419  else
420  {
421  return make_naive_tensor_descriptor_aligned(
422  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
423  }
424  }();
425 
426  // LDS allocation for A and B: be careful of alignment
427  constexpr auto a_block_space_size =
428  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
429 
430  constexpr auto b_block_space_size =
431  math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
432 
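 // The same LDS block is reused as the C-shuffle staging buffer in Run(), so the
 // requirement is the larger of the A+B tile storage and the C shuffle tile.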
433  constexpr auto c_block_size =
434  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
435 
436  return math::max(a_block_space_size * sizeof(LDSTypeA) +
437  b_block_space_size * sizeof(LDSTypeB),
438  c_block_size * sizeof(FloatC));
439  }
440 
441  __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
442  {
443  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
444  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
445  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
446  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
447  {
448  if(!(karg.M % MPerBlock == 0))
449  {
450  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
451  {
452  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
453  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
454  << std::endl;
455  }
456  return false;
457  }
458  }
459 
460  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
461  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
462  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
463  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
464  {
465  if(!(karg.N % NPerBlock == 0))
466  {
467  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
468  {
469  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
470  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
471  << std::endl;
472  }
473  return false;
474  }
475  }
476 
477  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
478  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
479  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
480  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
481  {
482 
483  auto K_t = karg.k_batch * K0PerBlock * K1;
484  if(!(karg.K % K_t == 0))
485  {
486  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
487  {
488  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
489  << karg.K << " " << __FILE__ << ":" << __LINE__
490  << ", in function: " << __func__ << std::endl;
491  }
492  return false;
493  }
494  }
495 
496  if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
497  {
498  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
499  {
500  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
501  {
502  std::cout << "Arg K (" << karg.K
503  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
504  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
505  << __LINE__ << ", in function: " << __func__ << std::endl;
506  }
507  return false;
508  }
509  }
510  else
511  {
512  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
513  {
514  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
515  {
516  std::cout << "Arg M (" << karg.M
517  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
518  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
519  << __LINE__ << ", in function: " << __func__ << std::endl;
520  }
521  return false;
522  }
523  }
524 
525  if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
526  {
527  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
528  {
529  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
530  {
531  std::cout << "Arg N (" << karg.N
532  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
533  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
534  << __LINE__ << ", in function: " << __func__ << std::endl;
535  }
536  return false;
537  }
538  }
539  else
540  {
541  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
542  {
543  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
544  {
545  std::cout << "Arg K (" << karg.K
546  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
547  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
548  << __LINE__ << ", in function: " << __func__ << std::endl;
549  }
550  return false;
551  }
552  }
553 
554  if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
555  {
556  if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
557  {
558  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
559  {
560  std::cout << "Arg N (" << karg.N
561  << ") value is not a multiple of "
562  "CBlockTransferScalarPerVector_NWaveNPerXDL ("
563  << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
564  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
565  }
566  return false;
567  }
568  }
569  else
570  {
571  if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
572  {
573  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
574  {
575  std::cout << "Arg M (" << karg.M
576  << ") value is not a multiple of "
577  "CBlockTransferScalarPerVector_NWaveNPerXDL ("
578  << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
579  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
580  }
581  return false;
582  }
583  }
584 
585  const auto num_k_loop = karg.K0Padded / K0PerBlock;
586  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
587  {
588  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
589  {
590  std::cout << "The number of k loops (" << num_k_loop
591  << ") value is not supported by GridwiseGemm Pipeline."
592  << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
593  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
594  << std::endl;
595  }
596  return false;
597  }
598 
599  return true;
600  }
601 
602  __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
603  {
604  const index_t K0Padded =
605  math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
606  const index_t KPad = KBatch * K0Padded * K1;
607  return KPad;
608  }
609 
610  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
611  {
612  const index_t num_loop = K0Padded / K0PerBlock;
613  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
614  }
615 
616  template <typename CGridDesc>
617  __host__ __device__ static constexpr auto
618  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
619  {
620  const auto M = c_m_n_grid_desc.GetLength(I0);
621  const auto N = c_m_n_grid_desc.GetLength(I1);
622 
623  const auto MBlock = M / MPerBlock;
624  const auto NBlock = N / NPerBlock;
625 
626  return transform_tensor_descriptor(
627  c_m_n_grid_desc,
628  make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
629  make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
630  make_tuple(Sequence<0>{}, Sequence<1>{}),
631  make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
632  }
633 
634  // return block_id to C matrix tile idx (m0, n0) mapping
635  template <typename CGridDesc>
636  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
637  const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
638  {
639  return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
640  c_m_n_grid_desc, 8, KBatch);
641  }
642 
643  __host__ __device__ static constexpr auto
644  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
645  {
646  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
647  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
648 
649  return make_naive_tensor_descriptor_packed(
650  make_tuple(I1,
651  Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
652  I1,
653  Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
654  }
655 
656  // return block_id to C matrix tile idx (m0, n0, k_split) mapping
657  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
658  {
659  return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>{};
660  }
661 
662  using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1))>;
663  using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
664 
665  template <bool HasMainKBlockLoop,
666  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
667  typename Block2CTileMap>
668  __device__ static void Run(const Argument& karg,
669  void* __restrict__ p_shared_block,
670  const Block2CTileMap& block_2_ctile_map,
671  const AElementwiseOperation a_element_op = AElementwiseOperation{},
672  const BElementwiseOperation b_element_op = BElementwiseOperation{},
673  const CElementwiseOperation c_element_op = CElementwiseOperation{})
674  {
675  const FloatA* p_a_grid = karg.p_a_grid;
676  const FloatB* p_b_grid = karg.p_b_grid;
677  FloatC* p_c_grid = karg.p_c_grid;
678  const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
679  karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
680  const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
681  karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
682  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
683 
684  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
685  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
686 
687  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
688  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
689  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
690  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
691  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
692  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
693 
694  // divide block work by [KBatch, M, N]
695  const auto block_work_idx =
696  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
697 
698  if(!block_2_ctile_map.ValidCTileIndex(
699  block_work_idx,
700  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
701  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
702  {
703  return;
704  }
705 
706  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
707  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
708  const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
709 
710  // HACK: this force m/n_block_data_idx_on_grid into SGPR
711  const index_t m_block_data_idx_on_grid =
712  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
713 
714  const index_t n_block_data_idx_on_grid =
715  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
716 
717  // lds max alignment
718  constexpr auto max_lds_align = K1;
719 
720  // A matrix in LDS memory, dst of blockwise copy
721  constexpr auto a_k0_m_k1_block_desc = [&]() {
722  if constexpr(ABlockLdsExtraM)
723  {
724  return make_naive_tensor_descriptor(
725  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
726  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
727  }
728  else
729  {
730  return make_naive_tensor_descriptor_aligned(
731  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
732  }
733  }();
734 
735  constexpr auto a_b_k0_m_k1_block_desc = [&]() {
736  if constexpr(ABlockLdsExtraM)
737  {
738  return make_naive_tensor_descriptor(
739  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
740  make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
741  Number<MPerBlock + 1>{} * K1,
742  K1,
743  I1));
744  }
745  else
746  {
747  return make_naive_tensor_descriptor_aligned(
748  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
749  max_lds_align);
750  }
751  }();
752  // B matrix in LDS memory, dst of blockwise copy
753  constexpr auto b_k0_n_k1_block_desc = [&]() {
754  if constexpr(BBlockLdsExtraN)
755  {
756  return make_naive_tensor_descriptor(
757  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
758  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
759  }
760  else
761  {
762  return make_naive_tensor_descriptor_aligned(
763  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
764  }
765  }();
766 
767  constexpr auto b_b_k0_n_k1_block_desc = [&]() {
768  if constexpr(BBlockLdsExtraN)
769  {
770  return make_naive_tensor_descriptor(
771  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
772  make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
773  Number<NPerBlock + 1>{} * K1,
774  K1,
775  I1));
776  }
777  else
778  {
779  return make_naive_tensor_descriptor_aligned(
780  make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
781  max_lds_align);
782  }
783  }();
784  // A matrix blockwise copy
785  auto a_blockwise_copy =
786  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
787  AElementwiseOperation,
788  ck::tensor_operation::element_wise::PassThrough,
789  InMemoryDataOperationEnum::Set,
790  Sequence<1, K0PerBlock, MPerBlock, K1>,
791  ABlockTransferThreadClusterLengths_K0_M_K1,
792  ABlockTransferThreadClusterArrangeOrder,
793  FloatA,
794  LDSTypeA,
795  decltype(a_b_k0_m_k1_grid_desc),
796  decltype(a_b_k0_m_k1_block_desc),
797  ABlockTransferSrcAccessOrder,
798  Sequence<0, 2, 1, 3>,
799  ABlockTransferSrcVectorDim,
800  3,
801  ABlockTransferSrcScalarPerVector,
802  ABlockTransferDstScalarPerVector_K1,
803  1,
804  1,
805  AThreadTransferSrcResetCoordinateAfterRun,
806  true>(
807  a_b_k0_m_k1_grid_desc,
808  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
809  a_element_op,
810  a_b_k0_m_k1_block_desc,
811  make_multi_index(0, 0, 0, 0),
812  ck::tensor_operation::element_wise::PassThrough{});
813 
814  // B matrix blockwise copy
815  auto b_blockwise_copy =
816  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
817  BElementwiseOperation,
818  ck::tensor_operation::element_wise::PassThrough,
819  InMemoryDataOperationEnum::Set,
820  Sequence<1, K0PerBlock, NPerBlock, K1>,
821  BBlockTransferThreadClusterLengths_K0_N_K1,
822  BBlockTransferThreadClusterArrangeOrder,
823  FloatB,
824  LDSTypeB,
825  decltype(b_b_k0_n_k1_grid_desc),
826  decltype(b_b_k0_n_k1_block_desc),
827  BBlockTransferSrcAccessOrder,
828  Sequence<0, 2, 1, 3>,
829  BBlockTransferSrcVectorDim,
830  3,
831  BBlockTransferSrcScalarPerVector,
832  BBlockTransferDstScalarPerVector_K1,
833  1,
834  1,
835  BThreadTransferSrcResetCoordinateAfterRun,
836  true>(
837  b_b_k0_n_k1_grid_desc,
838  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
839  b_element_op,
840  b_b_k0_n_k1_block_desc,
841  make_multi_index(0, 0, 0, 0),
842  ck::tensor_operation::element_wise::PassThrough{});
843 
844  // GEMM definition
845  // c_mtx += transpose(a_mtx) * b_mtx
846  // a_mtx[K0PerBlock, MPerBlock] is in LDS
847  // b_mtx[K0PerBlock, NPerBlock] is in LDS
848  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
849  // register
850  // sanity check
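 // Each block therefore produces an MPerBlock x NPerBlock tile of C, organised as
 // MRepeat x MWave x MPerXDL along M and NRepeat x NWave x NPerXDL along N
 // (MWave/NWave as computed in the epilogue below).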
851 
852  auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
853  BlockSize,
854  LDSTypeA,
855  LDSTypeB,
856  FloatAcc,
857  decltype(a_k0_m_k1_block_desc),
858  decltype(b_k0_n_k1_block_desc),
859  MPerXDL,
860  NPerXDL,
861  MRepeat,
862  NRepeat,
863  K1,
864  LoopSched,
865  ComputeTypeA,
866  ComputeTypeB>();
867 
868  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
869 
870  // LDS allocation for A and B: be careful of alignment
871  constexpr auto a_block_space_size =
872  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
873 
874  auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
875  auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
876 
877  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
878  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
879 
880  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
881  p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
882  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
883  p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
884 
885  // gridwise GEMM pipeline
886  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
887  (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
888  (K0PerBlock * K1));
889 
890  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
891 
892  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
893  a_b_k0_m_k1_block_desc,
894  a_blockwise_copy,
895  a_grid_buf,
896  a_block_buf,
897  a_block_slice_copy_step,
898  b_b_k0_n_k1_grid_desc,
899  b_b_k0_n_k1_block_desc,
900  b_blockwise_copy,
901  b_grid_buf,
902  b_block_buf,
903  b_block_slice_copy_step,
904  blockwise_gemm,
905  c_thread_buf,
906  num_k_block_main_loop);
907 
908  // output: register to global memory
909  {
910  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
911  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
912 
913  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
914  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
915 
916  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
917  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
918 
919  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
920  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
921  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
922  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
923  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
924  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
925  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
926  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
927 
928  constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
929  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
930 
931  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
932  static_cast<FloatC*>(p_shared_block),
933  c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
934 
935  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
936  c_block_desc_mblock_mperblock_nblock_nperblock,
937  make_tuple(
938  make_freeze_transform(I0), // freeze mblock
939  make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
940  M1,
941  M2,
942  M3,
943  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
944  make_freeze_transform(I0), // freeze nblock
945  make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
946  N1,
947  N2))), // N1 = NWave, N2 = NPerXDL
948  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
949  make_tuple(
950  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
951 
952  // calculate origin of thread output tensor on global memory
953  // blockwise GEMM c matrix starting index
954  const auto c_thread_mtx_on_block =
955  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
956 
957  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
958  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
959 
960  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
961  make_single_stage_tensor_adaptor(
962  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
963  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
964  make_tuple(Sequence<0>{}));
965 
966  const auto m_thread_data_on_block_idx =
967  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
968  make_multi_index(m_thread_data_on_block));
969 
970  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
971  make_single_stage_tensor_adaptor(
972  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
973  make_tuple(Sequence<0, 1, 2>{}),
974  make_tuple(Sequence<0>{}));
975 
976  const auto n_thread_data_on_block_idx =
977  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
978  make_multi_index(n_thread_data_on_block));
979 
980  // VGPR to LDS
981  auto c_thread_copy_vgpr_to_lds =
982  ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
983  FloatC,
984  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
985  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
986  ck::tensor_operation::element_wise::PassThrough,
987  Sequence<CShuffleMRepeatPerShuffle,
988  CShuffleNRepeatPerShuffle,
989  I1,
990  I1,
991  M2,
992  I1,
993  M4,
994  I1>,
995  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
996  7,
997  1,
998  InMemoryDataOperationEnum::Set,
999  1,
1000  true>{
1001  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1002  make_multi_index(0,
1003  0,
1004  m_thread_data_on_block_idx[I1],
1005  n_thread_data_on_block_idx[I1],
1006  m_thread_data_on_block_idx[I2],
1007  m_thread_data_on_block_idx[I3],
1008  m_thread_data_on_block_idx[I4],
1009  n_thread_data_on_block_idx[I2]),
1010  ck::tensor_operation::element_wise::PassThrough{}};
1011 
1012  // LDS to global
1013  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1014  ThisThreadBlock, // index_t BlockSize,
1015  CElementwiseOperation, // ElementwiseOperation,
1016  CGlobalMemoryDataOperation, // DstInMemOp,
1017  Sequence<1,
1018  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1019  1,
1020  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
1021  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1022  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1023  FloatC, // typename SrcData,
1024  FloatC, // typename DstData,
1025  decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
1026  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1027  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1028  3, // index_t VectorDim,
1029  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
1030  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1031  false> // bool ThreadTransferDstResetCoordinateAfterRun
1032  {c_block_desc_mblock_mperblock_nblock_nperblock,
1033  make_multi_index(0, 0, 0, 0),
1034  c_grid_desc_mblock_mperblock_nblock_nperblock,
1035  make_multi_index(block_m_id, 0, block_n_id, 0),
1036  c_element_op};
1037 
1038  constexpr auto mxdlperwave_forward_step =
1039  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
1040  constexpr auto nxdlperwave_forward_step =
1041  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1042  constexpr auto nxdlperwave_backward_step =
1043  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1044 
1045  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1046  constexpr auto mxdlperwave = mxdlperwave_iter;
1047 
1048  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1049  constexpr bool nxdlperwave_forward_sweep =
1050  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1051 
1052  constexpr index_t nxdlperwave_value =
1053  nxdlperwave_forward_sweep
1054  ? nxdlperwave_iter
1055  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1056 
1057  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1058 
1059  // make sure it's safe to do ds_write
1060  block_sync_lds();
1061 
1062  // VGPR to LDS
1063  c_thread_copy_vgpr_to_lds.Run(
1064  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1065  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1066  c_thread_buf,
1067  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1068  c_block_buf);
1069 
1070  // make sure it's safe to do ds_read
1071  block_sync_lds();
1072 
1073  // LDS to global
1074  c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
1075  c_block_buf,
1076  c_grid_desc_mblock_mperblock_nblock_nperblock,
1077  c_grid_buf);
1078 
1079  // move on nxdlperwave dimension
1080  if constexpr(nxdlperwave_forward_sweep &&
1081  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1082  {
1083  c_block_copy_lds_to_global.MoveDstSliceWindow(
1084  c_grid_desc_mblock_mperblock_nblock_nperblock,
1085  nxdlperwave_forward_step);
1086  }
1087  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1088  {
1089  c_block_copy_lds_to_global.MoveDstSliceWindow(
1090  c_grid_desc_mblock_mperblock_nblock_nperblock,
1091  nxdlperwave_backward_step);
1092  }
1093  });
1094 
1095  // move on mxdlperwave dimension
1096  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1097  {
1098  c_block_copy_lds_to_global.MoveDstSliceWindow(
1099  c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
1100  }
1101  });
1102  }
1103  }
1104 
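 // For illustration (hypothetical template arguments, not taken from this file): a Default
 // specialization with Row/Col/Row layouts, BlockSize 256, vector widths 8/8/8 and a
 // 256x128x4 tile with K1 = 8 would stringify as
 // "GemmXdlSplitKCShuffle_Default_RCR_B256_Vec8x8x8_256x128x4x8".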
1105  static std::string GetTypeString()
1106  {
1107  auto str = std::stringstream();
1108 
1109  // clang-format off
1110  str << "GemmXdlSplitKCShuffle_"
1111  << getGemmSpecializationString(GemmSpec) << "_"
1112  << std::string(ALayout::name)[0]
1113  << std::string(BLayout::name)[0]
1114  << std::string(CLayout::name)[0]
1115  << "_"
1116  << "B" << BlockSize << "_"
1117  << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1118  << BBlockTransferSrcScalarPerVector << "x"
1119  << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1120  << MPerBlock << "x"
1121  << NPerBlock << "x"
1122  << K0PerBlock << "x"
1123  << K1 ;
1124  // clang-format on
1125 
1126  return str.str();
1127  }
1128 };
1129 
1130 } // namespace ck