/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
19 
20 namespace ck {
21 
28 // operations that could be applied on each tensor respectively. The CDE_op is an
29 // elementwise operation applied to the C and all D tensors.
129 template <typename ALayout,
130  typename BLayout,
131  typename DsLayout,
132  typename ELayout,
133  typename AsDataType,
134  typename BsDataType,
135  typename AccDataType,
136  typename CShuffleDataType,
137  typename DsDataType,
138  typename EDataType,
139  typename AElementwiseOperation,
140  typename BElementwiseOperation,
141  typename CDEElementwiseOperation,
143  index_t BlockSize,
144  index_t MPerBlock,
145  index_t NPerBlock,
146  index_t KPerBlock,
147  index_t AK1Value,
148  index_t BK1Value,
149  index_t MPerWmma,
150  index_t NPerWmma,
151  index_t MRepeat,
152  index_t NRepeat,
153  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154  typename ABlockTransferThreadClusterArrangeOrder,
155  typename ABlockTransferSrcAccessOrder,
156  index_t ABlockTransferSrcVectorDim,
157  index_t ABlockTransferSrcScalarPerVector,
158  index_t ABlockTransferDstScalarPerVector_AK1,
159  bool AThreadTransferSrcResetCoordinateAfterRun,
160  index_t ABlockLdsExtraM,
161  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
162  typename BBlockTransferThreadClusterArrangeOrder,
163  typename BBlockTransferSrcAccessOrder,
164  index_t BBlockTransferSrcVectorDim,
165  index_t BBlockTransferSrcScalarPerVector,
166  index_t BBlockTransferDstScalarPerVector_BK1,
167  bool BThreadTransferSrcResetCoordinateAfterRun,
168  index_t BBlockLdsExtraN,
169  index_t CShuffleMRepeatPerShuffle,
170  index_t CShuffleNRepeatPerShuffle,
171  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172  typename CDEShuffleBlockTransferScalarPerVectors,
173  BlockGemmPipelineScheduler BlkGemmPipeSched,
174  BlockGemmPipelineVersion BlkGemmPipelineVer,
175  typename ComputeTypeA,
176  typename ComputeTypeB,
177  bool PermuteA,
178  bool PermuteB,
179  bool ForceThreadTileTransfer = false>
182  ALayout,
183  BLayout,
184  DsLayout,
185  ELayout,
186  AsDataType,
187  BsDataType,
188  AccDataType,
189  CShuffleDataType,
190  DsDataType,
191  EDataType,
192  AElementwiseOperation,
193  BElementwiseOperation,
194  CDEElementwiseOperation,
195  GemmSpec,
196  BlockSize,
197  MPerBlock,
198  NPerBlock,
199  KPerBlock,
200  AK1Value,
201  BK1Value,
202  MPerWmma,
203  NPerWmma,
204  MRepeat,
205  NRepeat,
206  ABlockTransferThreadClusterLengths_AK0_M_AK1,
207  ABlockTransferThreadClusterArrangeOrder,
208  ABlockTransferSrcAccessOrder,
209  ABlockTransferSrcVectorDim,
210  ABlockTransferSrcScalarPerVector,
211  ABlockTransferDstScalarPerVector_AK1,
212  AThreadTransferSrcResetCoordinateAfterRun,
213  ABlockLdsExtraM,
214  BBlockTransferThreadClusterLengths_BK0_N_BK1,
215  BBlockTransferThreadClusterArrangeOrder,
216  BBlockTransferSrcAccessOrder,
217  BBlockTransferSrcVectorDim,
218  BBlockTransferSrcScalarPerVector,
219  BBlockTransferDstScalarPerVector_BK1,
220  BThreadTransferSrcResetCoordinateAfterRun,
221  BBlockLdsExtraN,
222  CShuffleMRepeatPerShuffle,
223  CShuffleNRepeatPerShuffle,
224  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
225  CDEShuffleBlockTransferScalarPerVectors,
226  BlkGemmPipeSched,
227  BlkGemmPipelineVer,
228  ComputeTypeA,
229  ComputeTypeB,
230  PermuteA,
231  PermuteB,
232  ForceThreadTileTransfer>
233 {
235  ALayout,
236  BLayout,
237  DsLayout,
238  ELayout,
239  AsDataType,
240  BsDataType,
241  AccDataType,
242  CShuffleDataType,
243  DsDataType,
244  EDataType,
245  AElementwiseOperation,
246  BElementwiseOperation,
247  CDEElementwiseOperation,
248  GemmSpec,
249  BlockSize,
250  MPerBlock,
251  NPerBlock,
252  KPerBlock,
253  AK1Value,
254  BK1Value,
255  MPerWmma,
256  NPerWmma,
257  MRepeat,
258  NRepeat,
259  ABlockTransferThreadClusterLengths_AK0_M_AK1,
260  ABlockTransferThreadClusterArrangeOrder,
261  ABlockTransferSrcAccessOrder,
262  ABlockTransferSrcVectorDim,
263  ABlockTransferSrcScalarPerVector,
264  ABlockTransferDstScalarPerVector_AK1,
265  AThreadTransferSrcResetCoordinateAfterRun,
266  ABlockLdsExtraM,
267  BBlockTransferThreadClusterLengths_BK0_N_BK1,
268  BBlockTransferThreadClusterArrangeOrder,
269  BBlockTransferSrcAccessOrder,
270  BBlockTransferSrcVectorDim,
271  BBlockTransferSrcScalarPerVector,
272  BBlockTransferDstScalarPerVector_BK1,
273  BThreadTransferSrcResetCoordinateAfterRun,
274  BBlockLdsExtraN,
275  CShuffleMRepeatPerShuffle,
276  CShuffleNRepeatPerShuffle,
277  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
278  CDEShuffleBlockTransferScalarPerVectors,
279  BlkGemmPipeSched,
280  BlkGemmPipelineVer,
281  ComputeTypeA,
282  ComputeTypeB,
283  PermuteA,
284  PermuteB,
285  ForceThreadTileTransfer>;
286 
287  using Base::I0;
288  using Base::I1;
289  using Base::I2;
290  using Base::I3;
291  using Base::I4;
292  using Base::I5;
293  using Base::I6;
294  using Base::I7;
295 
296  using Base::AK0Number;
297  using Base::AK1Number;
298  using Base::BK0Number;
299  using Base::BK1Number;
300 
301  using Base::APackedSize;
302  using Base::BPackedSize;
303 
307  using Base::CalculateKRead;
308  using Base::CalculateMBlock;
310  using Base::CalculateNBlock;
317 
319 
321 
322  using Base::NumATensor;
323  using Base::NumBTensor;
324  using Base::NumDTensor;
325  using typename Base::AsGridPointer;
326  using typename Base::BsGridPointer;
327  using typename Base::DsGridPointer;
328  using AsDataType_ = AsDataType;
329  using BsDataType_ = BsDataType;
330 
// Host-side description of one GEMM problem instance: logical sizes (M, N, K),
// per-tensor strides for the As/Bs/Ds/E operands, the split-K factor (KBatch),
// and quantities derived from them (per-split K read, padded K, AK0/BK0, block counts).
// NOTE(review): this is a rendered listing — several initializer entries (e.g.
// MPadded/NPadded/NBlock) and field declarations are elided here; confirm
// against the original header.
331  struct Problem
332  {
// Derived quantities are computed eagerly in the initializer list so the
// struct is plain data afterwards.
333  __host__ Problem(index_t M_,
334  index_t N_,
335  index_t K_,
336  std::array<index_t, NumATensor> StrideAs_,
337  std::array<index_t, NumBTensor> StrideBs_,
338  std::array<index_t, NumDTensor> StrideDs_,
339  index_t StrideE_,
340  index_t KBatch_)
341  : M{M_},
342  N{N_},
343  K{K_},
344  StrideAs{StrideAs_},
345  StrideBs{StrideBs_},
346  StrideDs{StrideDs_},
347  StrideE{StrideE_},
348  KBatch{KBatch_},
// KRead: K elements consumed by each split-K batch; KPadded/AK0/BK0: K extents
// adjusted for the block tiling of the A and B operands.
351  KRead{CalculateKRead(K_, KBatch_)},
352  KPadded{CalculateKPadded(K_, KBatch_)},
353  AK0{CalculateAK0Padded(K_, KBatch_)},
354  BK0{CalculateBK0Padded(K_, KBatch_)},
355  MBlock{CalculateMBlock(M_)},
357  {
358  }
359 
// Debug helper: prints all problem fields to stdout on one line.
360  __host__ void Print() const
361  {
362  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
363  << "SAs: {";
// static_for is used because NumATensor/NumBTensor/NumDTensor are
// compile-time constants indexing std::array members.
364  static_for<0, NumATensor, 1>{}([&](auto i) {
365  std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
366  });
367  std::cout << "}, " << "SBs: {";
368  static_for<0, NumBTensor, 1>{}([&](auto i) {
369  std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
370  });
371  std::cout << "}, ";
// D strides are only printed when this kernel instantiation has D tensors.
372  if constexpr(NumDTensor > 0)
373  {
374  std::cout << "SDs: { ";
375  static_for<0, NumDTensor, 1>{}([&](auto i) {
376  std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
377  });
378  std::cout << " }, ";
379  }
380  std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
381  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
382  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
383  << ", " << "NBlock: " << NBlock << "}" << std::endl;
384  }
385 
// Raw strides as supplied by the caller, one entry per A/B/D tensor.
389  std::array<index_t, NumATensor> StrideAs;
390  std::array<index_t, NumBTensor> StrideBs;
391  std::array<index_t, NumDTensor> StrideDs;
402  };
403 
404  // Argument
406  {
// Host-side kernel argument: extends Problem with typed device pointers and the
// user-supplied elementwise functors. Operand pointers arrive type-erased
// (const void*) and are cast here to the per-tensor element types taken from
// the AsDataType/BsDataType/DsDataType type lists.
// is_reduce_ chooses between SplitK+reduction and SplitK+atomicAdd output modes
// (see IsReduceAdd / IsAtomicAdd below).
407  __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
408  std::array<const void*, NumBTensor> p_bs_grid_,
409  std::array<const void*, NumDTensor> p_ds_grid_,
410  EDataType* p_e_grid_,
411  index_t M_,
412  index_t N_,
413  index_t K_,
414  std::array<index_t, NumATensor> StrideAs_,
415  std::array<index_t, NumBTensor> StrideBs_,
416  std::array<index_t, NumDTensor> StrideDs_,
417  index_t StrideE_,
418  index_t k_batch_,
419  AElementwiseOperation a_element_op_,
420  BElementwiseOperation b_element_op_,
421  CDEElementwiseOperation cde_element_op_,
422  bool is_reduce_ = false)
423  : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
424  p_as_grid{},
425  p_bs_grid{},
426  p_ds_grid{},
427  p_e_grid{p_e_grid_},
428  a_element_op{a_element_op_},
429  b_element_op{b_element_op_},
430  cde_element_op{cde_element_op_},
431  is_reduce(is_reduce_)
432  {
433  // populate pointer, desc for As
// Each As entry gets its own element type from the AsDataType tuple; the
// void* is cast accordingly and stored in the typed pointer tuple.
434  static_for<0, NumATensor, 1>{}([&](auto i) {
435  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
436 
437  // A pointer
438  p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
439  });
440 
441  // populate pointer, desc for Bs
442  static_for<0, NumBTensor, 1>{}([&](auto i) {
443  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
444 
445  // B pointer
446  p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
447  });
448 
449  // populate pointer, desc for Ds
450  static_for<0, NumDTensor, 1>{}([&](auto i) {
451  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
452 
453  // D pointer
454  p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
455  });
456  }
457 
458  __host__ __device__ inline bool IsReduceAdd() const
459  {
460  return (Problem::KBatch > 1) && is_reduce;
461  }
462 
463  __host__ __device__ inline bool IsAtomicAdd() const
464  {
465  return (Problem::KBatch > 1) && (!is_reduce);
466  }
467 
// Output pointer for the E tensor. NOTE(review): the typed As/Bs/Ds pointer
// tuple members (p_as_grid/p_bs_grid/p_ds_grid) are declared on lines elided
// from this rendered listing.
471  EDataType* p_e_grid;
472 
// User-supplied elementwise functors applied to A, B, and the C/D tensors.
473  AElementwiseOperation a_element_op;
474  BElementwiseOperation b_element_op;
475  CDEElementwiseOperation cde_element_op;
476 
477  // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
478  bool is_reduce;
479  };
480 
482  {
483 
// Computes per-split-K-batch pointer offsets for the A/B operands and the
// reduction offset for E, and shrinks karg.K to the K slice owned by batch
// k_id. NOTE(review): rendered listing — the static_for<...>{}( invocation
// lines preceding several lambdas below are elided; confirm against the
// original header.
484  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
485  {
486  // Note: in xdl implementation multiple AB supports one layout
487  // but multiple strides, so we create an array of offsets with
488  // the same values.
489  // It should be fixed later on. Once we will have a thread transfer
490  // more flexible.
// Row-major A: stepping along K is contiguous, so the offset is an element
// count (divided by APackedSize for packed sub-byte element types).
491  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
492  {
494  [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
495  }
// Column-major A: stepping along K advances by the per-tensor stride.
496  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
497  {
499  [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
500  }
501 
502  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
503  {
505  [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
506  }
507  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
508  {
509  if constexpr(!PermuteB)
510  {
512  [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
513  }
514  else
515  {
// Permuted B: each split-K batch owns a contiguous KRead*N chunk.
516  const int k0_offset = karg.KRead * karg.N;
518  [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
519  }
520  }
521 
// Every batch except the last reads exactly KRead; the last batch takes the
// remainder so the slices cover K without overlap.
522  if(k_id < karg.KBatch - 1)
523  {
524  karg.K = karg.KRead;
525  }
526  else
527  {
528  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
529  }
530 
// SplitK+reduction: each batch writes its partial result to a distinct M*N
// slab of the output workspace; atomicAdd mode writes in place (offset 0).
531  if(karg.IsReduceAdd())
532  {
533  c_reduce_offset = k_id * karg.M * karg.N;
534  }
535  else
536  {
537  c_reduce_offset = 0;
538  }
539  }
540 
// Per-tensor element offsets applied to the A/B base pointers for this batch.
541  std::array<index_t, NumATensor> a_k_split_offset;
542  std::array<index_t, NumBTensor> b_k_split_offset;
544  };
545 
547 
548  // return block_id to C matrix tile idx (m0, n0) mapping
549  // if arch = gfx942
551  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
552 
553  __device__ static index_t GetKBlockPerScale() { return 1; }
554 
// Primary device entry point: builds the grid descriptors from the runtime
// Problem, maps this workgroup to an (m, n) C tile via block_2_ctile_map,
// and forwards everything to the shared Base::Run implementation.
// NOTE(review): rendered listing — a few lines are elided below (the
// MakeDsGridDescriptor_MBlock_.../MakeDEGridDescriptor_MBlock_... factory
// names and the CalculateBottomIndex call producing block_work_idx); confirm
// against the original header.
555  template <bool HasMainKBlockLoop,
556  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
557  TailNumber TailNum,
558  typename Block2CTileMap,
559  typename EpilogueArgument,
560  int BlockMapMBlockIndex = 0,
561  int BlockMapNBlockIndex = 1>
562  __device__ static void Run(AsGridPointer& p_as_grid,
563  BsGridPointer& p_bs_grid,
564  DsGridPointer& p_ds_grid,
565  EDataType* p_e_grid,
566  void* p_shared,
567  const Problem& problem,
568  const Block2CTileMap& block_2_ctile_map,
569  AElementwiseOperation a_element_op,
570  BElementwiseOperation b_element_op,
571  CDEElementwiseOperation cde_element_op,
572  EpilogueArgument& epilogue_args)
573  {
// Descriptors for the (possibly padded) A/B/D/E tensors in their blocked forms.
574  const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
575  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
576  const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
577  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
578  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
579  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
580  const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
581  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
582  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
584  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
585  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
587  e_grid_desc_m_n, problem.MBlock, problem.NBlock);
588 
// Map this workgroup's block id to a C tile index; workgroups that fall
// outside the valid tile grid (from grid-size rounding) exit early.
589  const auto block_work_idx =
591 
592  if(!block_2_ctile_map.ValidCTileIndex(
593  block_work_idx,
594  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
595  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
596  {
597  return;
598  }
599 
// readfirstlane: the tile index is uniform across the wavefront, so keep it
// in a scalar register.
600  const index_t block_m_id =
601  __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
602  const index_t block_n_id =
603  __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
604 
605  // BScale struct (Empty)
// This variant does not use B scaling; an empty scale struct satisfies the
// shared Base::Run interface.
606  using BScale = typename BlockwiseGemmPipe::Empty;
607  auto b_scale_struct = BScale{};
608 
609  const index_t num_k_block_per_scale = GetKBlockPerScale();
610 
// Delegate the actual GEMM pipeline + C-shuffle epilogue to the common base.
611  Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
612  decltype(bs_grid_desc_bk0_n_bk1),
613  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
614  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
615  decltype(b_scale_struct),
616  decltype(epilogue_args),
617  HasMainKBlockLoop,
618  EGlobalMemoryDataOperation,
619  TailNum>(p_as_grid,
620  p_bs_grid,
621  p_ds_grid,
622  p_e_grid,
623  p_shared,
624  as_grid_desc_ak0_m_ak1,
625  bs_grid_desc_bk0_n_bk1,
626  ds_grid_desc_mblock_mperblock_nblock_nperblock,
627  e_grid_desc_mblock_mperblock_nblock_nperblock,
628  a_element_op,
629  b_element_op,
630  cde_element_op,
631  block_m_id,
632  block_n_id,
633  num_k_block_per_scale,
634  b_scale_struct,
635  epilogue_args);
636  }
637 
// Convenience overload: identical to the primary Run, but constructs the
// default block-to-C-tile map from the problem instead of taking one from the
// caller. NOTE(review): rendered listing — the elided template line here is
// the Block2CTileMap type argument; confirm against the original header.
638  template <bool HasMainKBlockLoop,
639  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
640  TailNumber TailNum,
641  typename EpilogueArgument>
642  __device__ static void Run(AsGridPointer& p_as_grid,
643  BsGridPointer& p_bs_grid,
644  DsGridPointer& p_ds_grid,
645  EDataType* p_e_grid,
646  void* p_shared,
647  const Problem& problem,
648  AElementwiseOperation a_element_op,
649  BElementwiseOperation b_element_op,
650  CDEElementwiseOperation cde_element_op,
651  EpilogueArgument& epilogue_args)
652  {
653  Run<HasMainKBlockLoop,
654  EGlobalMemoryDataOperation,
655  TailNum,
657  EpilogueArgument>(p_as_grid,
658  p_bs_grid,
659  p_ds_grid,
660  p_e_grid,
661  p_shared,
662  problem,
663  DefaultBlock2CTileMap(problem),
664  a_element_op,
665  b_element_op,
666  cde_element_op,
667  epilogue_args);
668  }
669 
670  // Wrapper function to have __global__ function in common
671  // between gemm_universal, b_scale, ab_scale, etc.
// Applies the precomputed split-K offsets to each A/B base pointer and to the
// E output pointer, then dispatches to the descriptor-building Run overload.
672  template <bool HasMainKBlockLoop,
673  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
674  TailNumber TailNum,
675  typename Block2CTileMap,
676  typename EpilogueArgument,
677  int BlockMapMBlockIndex = 0,
678  int BlockMapNBlockIndex = 1>
679  __device__ static void Run(void* p_shared,
680  const SplitKBatchOffset& splitk_batch_offset,
681  Argument& karg,
682  const Block2CTileMap& block_2_ctile_map,
683  EpilogueArgument& epilogue_args)
684  {
685  // shift A matrices pointer for splitk
// Pointer arithmetic must happen on the concrete element type, so each tuple
// entry is re-cast to its AsDataType element type before the offset is added.
686  AsGridPointer p_as_grid_splitk;
687  static_for<0, NumATensor, 1>{}([&](auto i) {
688  using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
689  p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
690  splitk_batch_offset.a_k_split_offset[i];
691  });
692 
693  // shift B matrices pointer for splitk
694  BsGridPointer p_bs_grid_splitk;
695  static_for<0, NumBTensor, 1>{}([&](auto i) {
696  using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
697  p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
698  splitk_batch_offset.b_k_split_offset[i];
699  });
700 
// c_reduce_offset points E at this batch's workspace slab in reduction mode
// (0 in atomicAdd mode). NOTE(review): the elided template line below is the
// Block2CTileMap type argument; confirm against the original header.
701  Run<HasMainKBlockLoop,
702  EGlobalMemoryDataOperation,
703  TailNum,
705  EpilogueArgument,
706  BlockMapMBlockIndex,
707  BlockMapNBlockIndex>(p_as_grid_splitk,
708  p_bs_grid_splitk,
709  karg.p_ds_grid,
710  karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
711  p_shared,
712  karg,
713  block_2_ctile_map,
714  karg.a_element_op,
715  karg.b_element_op,
716  karg.cde_element_op,
717  epilogue_args);
718  }
719 
720  // Wrapper function to have __global__ function in common
721  // between gemm_universal, b_scale, ab_scale, etc.
// Same as the split-K Run above, but constructs the default block-to-C-tile
// map from the kernel argument (karg derives from Problem). NOTE(review):
// the elided template line below is the Block2CTileMap type argument.
722  template <bool HasMainKBlockLoop,
723  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
724  TailNumber TailNum,
725  typename EpilogueArgument>
726  __device__ static void Run(void* p_shared,
727  const SplitKBatchOffset& splitk_batch_offset,
728  Argument& karg,
729  EpilogueArgument& epilogue_args)
730  {
731  Run<HasMainKBlockLoop,
732  EGlobalMemoryDataOperation,
733  TailNum,
735  EpilogueArgument>(
736  p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args);
737  }
738 
739  __device__ static auto DefaultBlock2CTileMap(const Problem& problem)
740  {
741  return Block2CTileMap{problem.M, problem.N, 4};
742  }
743 };
744 
745 } // namespace ck
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
InMemoryDataOperationEnum
Definition: ck.hpp:279
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:301
Definition: block_to_ctile_map.hpp:271
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:298
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:384
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
AElementwiseOperation a_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:473
CDEElementwiseOperation cde_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:475
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:407
BElementwiseOperation b_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:474
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:478
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:471
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:463
BsGridPointer p_bs_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:469
AsGridPointer p_as_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:468
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:458
DsGridPointer p_ds_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:470
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:332
std::array< index_t, NumBTensor > StrideBs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:390
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:398
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:399
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:395
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:397
index_t StrideE
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:392
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:391
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:401
std::array< index_t, NumATensor > StrideAs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:389
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:360
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:386
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:333
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:400
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:394
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:396
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:482
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:543
std::array< index_t, NumATensor > a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:541
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:484
std::array< index_t, NumBTensor > b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:542
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:123
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:305
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:127
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:650
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:584
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:552
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:340
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:310
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:335
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:367
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:328
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:538
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:422
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:463
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:395
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:131
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:150
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:135
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:540
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:132
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:129
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:179
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:153
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:151
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:368
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:295
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:525
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:300
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:130
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:316
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:233
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, const Block2CTileMap &block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:562
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:305
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:127
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:650
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:552
AsDataType AsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:328
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:340
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:310
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:335
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:367
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:328
BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:550
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:538
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:422
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:726
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:395
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:135
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
BsDataType BsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:329
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, const Block2CTileMap &block_2_ctile_map, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:679
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:540
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:320
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:179
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:642
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:553
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:368
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:295
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:525
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:300
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:546
static __device__ auto DefaultBlock2CTileMap(const Problem &problem)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:739
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:316
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: device_base.hpp:197