/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
18 
19 namespace ck {
20 
27 // operations that could be applied on each tensor respectively. The CDE_op is an
28 // elementwise operation applied to the C and all D tensors.
128 template <typename ALayout,
129  typename BLayout,
130  typename DsLayout,
131  typename ELayout,
132  typename ADataType,
133  typename BDataType,
134  typename AccDataType,
135  typename CShuffleDataType,
136  typename DsDataType,
137  typename EDataType,
138  typename AElementwiseOperation,
139  typename BElementwiseOperation,
140  typename CDEElementwiseOperation,
142  index_t BlockSize,
143  index_t MPerBlock,
144  index_t NPerBlock,
145  index_t KPerBlock,
146  index_t AK1Value,
147  index_t BK1Value,
148  index_t MPerWmma,
149  index_t NPerWmma,
150  index_t MRepeat,
151  index_t NRepeat,
152  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
153  typename ABlockTransferThreadClusterArrangeOrder,
154  typename ABlockTransferSrcAccessOrder,
155  index_t ABlockTransferSrcVectorDim,
156  index_t ABlockTransferSrcScalarPerVector,
157  index_t ABlockTransferDstScalarPerVector_AK1,
158  bool AThreadTransferSrcResetCoordinateAfterRun,
159  index_t ABlockLdsExtraM,
160  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
161  typename BBlockTransferThreadClusterArrangeOrder,
162  typename BBlockTransferSrcAccessOrder,
163  index_t BBlockTransferSrcVectorDim,
164  index_t BBlockTransferSrcScalarPerVector,
165  index_t BBlockTransferDstScalarPerVector_BK1,
166  bool BThreadTransferSrcResetCoordinateAfterRun,
167  index_t BBlockLdsExtraN,
168  index_t CShuffleMRepeatPerShuffle,
169  index_t CShuffleNRepeatPerShuffle,
170  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
171  typename CDEShuffleBlockTransferScalarPerVectors,
172  BlockGemmPipelineScheduler BlkGemmPipeSched,
173  BlockGemmPipelineVersion BlkGemmPipelineVer,
174  typename ComputeTypeA,
175  typename ComputeTypeB,
176  bool PermuteA,
177  bool PermuteB>
180  ALayout,
181  BLayout,
182  DsLayout,
183  ELayout,
184  ADataType,
185  BDataType,
186  AccDataType,
187  CShuffleDataType,
188  DsDataType,
189  EDataType,
190  AElementwiseOperation,
191  BElementwiseOperation,
192  CDEElementwiseOperation,
193  GemmSpec,
194  BlockSize,
195  MPerBlock,
196  NPerBlock,
197  KPerBlock,
198  AK1Value,
199  BK1Value,
200  MPerWmma,
201  NPerWmma,
202  MRepeat,
203  NRepeat,
204  ABlockTransferThreadClusterLengths_AK0_M_AK1,
205  ABlockTransferThreadClusterArrangeOrder,
206  ABlockTransferSrcAccessOrder,
207  ABlockTransferSrcVectorDim,
208  ABlockTransferSrcScalarPerVector,
209  ABlockTransferDstScalarPerVector_AK1,
210  AThreadTransferSrcResetCoordinateAfterRun,
211  ABlockLdsExtraM,
212  BBlockTransferThreadClusterLengths_BK0_N_BK1,
213  BBlockTransferThreadClusterArrangeOrder,
214  BBlockTransferSrcAccessOrder,
215  BBlockTransferSrcVectorDim,
216  BBlockTransferSrcScalarPerVector,
217  BBlockTransferDstScalarPerVector_BK1,
218  BThreadTransferSrcResetCoordinateAfterRun,
219  BBlockLdsExtraN,
220  CShuffleMRepeatPerShuffle,
221  CShuffleNRepeatPerShuffle,
222  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
223  CDEShuffleBlockTransferScalarPerVectors,
224  BlkGemmPipeSched,
225  BlkGemmPipelineVer,
226  ComputeTypeA,
227  ComputeTypeB,
228  PermuteA,
229  PermuteB>
230 {
232  ALayout,
233  BLayout,
234  DsLayout,
235  ELayout,
236  ADataType,
237  BDataType,
238  AccDataType,
239  CShuffleDataType,
240  DsDataType,
241  EDataType,
242  AElementwiseOperation,
243  BElementwiseOperation,
244  CDEElementwiseOperation,
245  GemmSpec,
246  BlockSize,
247  MPerBlock,
248  NPerBlock,
249  KPerBlock,
250  AK1Value,
251  BK1Value,
252  MPerWmma,
253  NPerWmma,
254  MRepeat,
255  NRepeat,
256  ABlockTransferThreadClusterLengths_AK0_M_AK1,
257  ABlockTransferThreadClusterArrangeOrder,
258  ABlockTransferSrcAccessOrder,
259  ABlockTransferSrcVectorDim,
260  ABlockTransferSrcScalarPerVector,
261  ABlockTransferDstScalarPerVector_AK1,
262  AThreadTransferSrcResetCoordinateAfterRun,
263  ABlockLdsExtraM,
264  BBlockTransferThreadClusterLengths_BK0_N_BK1,
265  BBlockTransferThreadClusterArrangeOrder,
266  BBlockTransferSrcAccessOrder,
267  BBlockTransferSrcVectorDim,
268  BBlockTransferSrcScalarPerVector,
269  BBlockTransferDstScalarPerVector_BK1,
270  BThreadTransferSrcResetCoordinateAfterRun,
271  BBlockLdsExtraN,
272  CShuffleMRepeatPerShuffle,
273  CShuffleNRepeatPerShuffle,
274  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
275  CDEShuffleBlockTransferScalarPerVectors,
276  BlkGemmPipeSched,
277  BlkGemmPipelineVer,
278  ComputeTypeA,
279  ComputeTypeB,
280  PermuteA,
281  PermuteB>;
282 
283  using Base::I0;
284  using Base::I1;
285  using Base::I2;
286  using Base::I3;
287  using Base::I4;
288  using Base::I5;
289  using Base::I6;
290  using Base::I7;
291 
292  using Base::AK0Number;
293  using Base::AK1Number;
294  using Base::BK0Number;
295  using Base::BK1Number;
296 
297  using Base::APackedSize;
298  using Base::BPackedSize;
299 
303  using Base::CalculateKRead;
304  using Base::CalculateMBlock;
306  using Base::CalculateNBlock;
313 
315 
317 
319 
322 
323  using Base::NumDTensor;
324  using typename Base::DsGridPointer;
325 
326  struct Problem
327  {
328  __host__ Problem(index_t M_,
329  index_t N_,
330  index_t K_,
331  index_t StrideA_,
332  index_t StrideB_,
333  std::array<index_t, NumDTensor> StrideDs_,
334  index_t StrideE_,
335  index_t KBatch_)
336  : M{M_},
337  N{N_},
338  K{K_},
339  StrideA{StrideA_},
340  StrideB{StrideB_},
341  StrideDs{StrideDs_},
342  StrideE{StrideE_},
343  KBatch{KBatch_},
346  KRead{CalculateKRead(K_, KBatch_)},
347  KPadded{CalculateKPadded(K_, KBatch_)},
348  AK0{CalculateAK0Padded(K_, KBatch_)},
349  BK0{CalculateBK0Padded(K_, KBatch_)},
350  MBlock{CalculateMBlock(M_)},
352  {
353  }
354 
        // Dump every field of the Problem (sizes, strides, derived split-K
        // quantities and block counts) to stdout in a single line, for
        // host-side debugging. Host-only.
        __host__ void Print() const
        {
            std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
            // The D-tensor stride list is only printed when D tensors exist;
            // NumDTensor is a compile-time constant, so the branch is constexpr.
            if constexpr(NumDTensor > 0)
            {
                std::cout << "SDs: { ";
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    // Comma-separate all but the last element.
                    std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
                });
                std::cout << " }, ";
            }
            std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
                      << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                      << ", " << "NBlock: " << NBlock << "}" << std::endl;
        }
372 
378  std::array<index_t, NumDTensor> StrideDs;
389  };
390 
391  // Argument
393  {
        // Kernel argument constructor: extends Problem with the device grid
        // pointers for A, B, the Ds and E, plus the elementwise operations.
        // The D pointers arrive type-erased (const void*) and are cast to
        // their statically-known element types from DsDataType.
        // is_reduce_ selects SplitK+reduction semantics (see IsReduceAdd);
        // it defaults to false, i.e. SplitK+atomicAdd.
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          EDataType* p_e_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          std::array<index_t, NumDTensor> StrideDs_,
                          index_t StrideE_,
                          index_t k_batch_,
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CDEElementwiseOperation cde_element_op_,
                          bool is_reduce_ = false)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_ds_grid{}, // filled below, one typed pointer per D tensor
              p_e_grid{p_e_grid_},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              cde_element_op{cde_element_op_},
              is_reduce(is_reduce_)
        {
            // Compile-time loop: cast each type-erased D pointer to the
            // element type declared at position i of the DsDataType tuple.
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
            });
        }
426 
427  __host__ __device__ inline bool IsReduceAdd() const
428  {
429  return (Problem::KBatch > 1) && is_reduce;
430  }
431 
432  __host__ __device__ inline bool IsAtomicAdd() const
433  {
434  return (Problem::KBatch > 1) && (!is_reduce);
435  }
436 
437  const ADataType* p_a_grid;
438  const BDataType* p_b_grid;
440  EDataType* p_e_grid;
441 
442  const AElementwiseOperation a_element_op;
443  const BElementwiseOperation b_element_op;
444  const CDEElementwiseOperation cde_element_op;
445 
446  // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
447  bool is_reduce;
448  };
449 
451  {
452 
        // Computes the per-split-K-batch element offsets into the A, B and E
        // grids for batch k_id, and trims karg.K down to the K extent this
        // batch actually processes (NOTE: mutates karg).
        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
        {
            // A offset: row-major A is contiguous along K, so step by KRead
            // elements (divided by APackedSize — presumably >1 element per
            // addressable unit for packed types; TODO confirm). Column-major A
            // steps by StrideA per K row instead.
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = k_id * karg.KRead / APackedSize;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = k_id * karg.KRead * karg.StrideA;
            }

            // B offset: mirror of A with the roles of the layouts swapped.
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = k_id * karg.KRead * karg.StrideB;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                if constexpr(!PermuteB)
                {
                    b_k_split_offset = k_id * karg.KRead / BPackedSize;
                }
                else
                {
                    // Permuted B: each K chunk spans KRead * N elements.
                    const int k0_offset = karg.KRead * karg.N;
                    b_k_split_offset = k_id * k0_offset / BPackedSize;
                }
            }

            // All batches except the last process exactly KRead; the last one
            // takes the remainder so the total covers the original K.
            if(k_id < karg.KBatch - 1)
            {
                karg.K = karg.KRead;
            }
            else
            {
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            }

            // SplitK+reduction writes each batch to its own M*N slice of the
            // output; SplitK+atomicAdd accumulates in place (offset 0).
            if(karg.IsReduceAdd())
            {
                c_reduce_offset = k_id * karg.M * karg.N;
            }
            else
            {
                c_reduce_offset = 0;
            }
        }
499 
503  };
504 
506 
507  // return block_id to C matrix tile idx (m0, n0) mapping
508  // if arch = gfx942
510  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
511 
512  __device__ static index_t GetKBlockPerScale() { return 1; }
513 
514  template <bool HasMainKBlockLoop,
515  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
516  TailNumber TailNum>
517  __device__ static void Run(const ADataType* p_a_grid,
518  const BDataType* p_b_grid,
519  DsGridPointer& p_ds_grid,
520  EDataType* p_e_grid,
521  void* p_shared,
522  const Problem& problem,
523  AElementwiseOperation a_element_op,
524  BElementwiseOperation b_element_op,
525  CDEElementwiseOperation cde_element_op)
526  {
527  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
528  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
529  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
530  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
531  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
532  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
533  const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
534  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
535  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
537  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
538  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
540  e_grid_desc_m_n, problem.MBlock, problem.NBlock);
541 
542  // divide block work by [M, N]
543  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
544 
545  const auto block_work_idx =
546  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
547 
548  if(!block_2_ctile_map.ValidCTileIndex(
549  block_work_idx,
550  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
551  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
552  {
553  return;
554  }
555 
556  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
557  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
558 
559  // BScale struct (Empty)
560  using BScale = typename BlockwiseGemmPipe::Empty;
561  auto b_scale_struct = BScale{};
562 
563  const index_t num_k_block_per_scale = GetKBlockPerScale();
564 
565  Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
566  decltype(b_grid_desc_bk0_n_bk1),
567  decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
568  decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
569  decltype(b_scale_struct),
570  HasMainKBlockLoop,
571  EGlobalMemoryDataOperation,
572  TailNum>(p_a_grid,
573  p_b_grid,
574  p_ds_grid,
575  p_e_grid,
576  p_shared,
577  a_grid_desc_ak0_m_ak1,
578  b_grid_desc_bk0_n_bk1,
579  ds_grid_desc_mblock_mperblock_nblock_nperblock,
580  e_grid_desc_mblock_mperblock_nblock_nperblock,
581  a_element_op,
582  b_element_op,
583  cde_element_op,
584  block_m_id,
585  block_n_id,
586  num_k_block_per_scale,
587  b_scale_struct);
588  }
589 
    // Wrapper function to have __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    // Unpacks the Argument and the precomputed SplitKBatchOffset, then
    // forwards to the main Run overload with the grid pointers shifted
    // to this split-K batch.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void
    Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
    {
        Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            // NOTE(review): c_reduce_offset is intentionally not applied to
            // the D pointers here — confirm Ds are read whole per batch.
            karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
            karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
            p_shared,
            karg, // Argument derives from Problem; passed as the Problem
            karg.a_element_op,
            karg.b_element_op,
            karg.cde_element_op);
    }
609 };
610 
611 } // namespace ck
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
InMemoryDataOperationEnum
Definition: ck.hpp:276
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:298
Definition: block_to_ctile_map.hpp:270
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
const AElementwiseOperation a_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:442
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:432
const CDEElementwiseOperation cde_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:444
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:394
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:440
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:447
const BDataType * p_b_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:438
const BElementwiseOperation b_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:443
const ADataType * p_a_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:437
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:427
DsGridPointer p_ds_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:439
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:327
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:384
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:378
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:328
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:386
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:383
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:375
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:373
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:374
index_t StrideB
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:377
index_t StrideE
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:379
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:385
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:381
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:382
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:380
index_t StrideA
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:376
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:355
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:451
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:502
index_t b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:501
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:453
index_t a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:500
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:106
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:678
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:853
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:121
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:178
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:228
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:112
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:440
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:202
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:814
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:109
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:115
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:540
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:123
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:515
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:850
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:111
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:140
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:113
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:312
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:114
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:110
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:529
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:517
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:190
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:502
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:230
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:596
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:853
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:178
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:228
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:202
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:109
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:505
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:140
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:517
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:312
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:110
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:529
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:517
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:318
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:190
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:512
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:502
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
Definition: functional2.hpp:33
Definition: device_base.hpp:51