/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Source File
device_gemm_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ALayout,
25  typename BLayout,
26  typename CLayout,
27  typename ADataType,
28  typename BDataType,
29  typename CDataType,
30  typename GemmAccDataType,
31  typename CShuffleDataType,
32  typename AElementwiseOperation,
33  typename BElementwiseOperation,
34  typename CElementwiseOperation,
35  GemmSpecialization GemmSpec,
36  index_t BlockSize,
37  index_t MPerBlock,
38  index_t NPerBlock,
39  index_t KPerBlock,
40  index_t AK1,
41  index_t BK1,
42  index_t MPerXDL,
43  index_t NPerXDL,
44  index_t MXdlPerWave,
45  index_t NXdlPerWave,
46  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
47  typename ABlockTransferThreadClusterArrangeOrder,
48  typename ABlockTransferSrcAccessOrder,
49  index_t ABlockTransferSrcVectorDim,
50  index_t ABlockTransferSrcScalarPerVector,
51  index_t ABlockTransferDstScalarPerVector_AK1,
52  bool ABlockLdsExtraM,
53  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
54  typename BBlockTransferThreadClusterArrangeOrder,
55  typename BBlockTransferSrcAccessOrder,
56  index_t BBlockTransferSrcVectorDim,
57  index_t BBlockTransferSrcScalarPerVector,
58  index_t BBlockTransferDstScalarPerVector_BK1,
59  bool BBlockLdsExtraN,
60  index_t CShuffleMXdlPerWavePerShuffle,
61  index_t CShuffleNXdlPerWavePerShuffle,
62  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
66  typename ComputeTypeA = CDataType,
67  typename ComputeTypeB = ComputeTypeA,
68  bool PermuteA = false,
69  bool PermuteB = false>
70 struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
71  BLayout,
72  CLayout,
73  ADataType,
74  BDataType,
75  CDataType,
76  AElementwiseOperation,
77  BElementwiseOperation,
78  CElementwiseOperation>
79 {
80  // GridwiseGemm
82  ALayout,
83  BLayout,
84  CLayout,
85  ADataType,
86  BDataType,
87  GemmAccDataType,
88  CShuffleDataType,
89  CDataType,
90  AElementwiseOperation,
91  BElementwiseOperation,
92  CElementwiseOperation,
93  GemmSpec,
94  BlockSize,
95  MPerBlock,
96  NPerBlock,
97  KPerBlock,
98  AK1,
99  BK1,
100  MPerXDL,
101  NPerXDL,
102  MXdlPerWave,
103  NXdlPerWave,
104  ABlockTransferThreadClusterLengths_AK0_M_AK1,
105  ABlockTransferThreadClusterArrangeOrder,
106  ABlockTransferSrcAccessOrder,
107  ABlockTransferSrcVectorDim,
108  ABlockTransferSrcScalarPerVector,
109  ABlockTransferDstScalarPerVector_AK1,
110  false,
111  ABlockLdsExtraM,
112  BBlockTransferThreadClusterLengths_BK0_N_BK1,
113  BBlockTransferThreadClusterArrangeOrder,
114  BBlockTransferSrcAccessOrder,
115  BBlockTransferSrcVectorDim,
116  BBlockTransferSrcScalarPerVector,
117  BBlockTransferDstScalarPerVector_BK1,
118  false,
119  BBlockLdsExtraN,
120  CShuffleMXdlPerWavePerShuffle,
121  CShuffleNXdlPerWavePerShuffle,
122  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
123  CShuffleBlockTransferScalarPerVector_NPerBlock,
124  BlkGemmPipeSched,
125  BlkGemmPipelineVer,
126  ComputeTypeA,
127  ComputeTypeB,
128  PermuteA,
129  PermuteB>;
130 
132 
133  // Invoker
134  struct Invoker : public BaseInvoker
135  {
136  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
137  {
138  if(stream_config.log_level_ > 0)
139  {
140  arg.Print();
141  GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
142  }
143 
145  {
146  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
147  }
148 
149  index_t gdx, gdy, gdz;
150  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
151 
152  float ave_time = 0;
153 
154  index_t k_grain = arg.KBatch * KPerBlock;
155  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
156 
157  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
158 
159  const auto Run = [&](const auto& kernel) {
160  if(stream_config.flush_cache)
161  {
162  Argument arg_ = arg;
163 
164  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
165  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
166  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
167  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
168 
169  auto size_a_buffer =
170  a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
171  auto size_b_buffer =
172  b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
173 
175  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
176  rotating_mem.Print();
177 
178  auto run_flush_cache = [&]() {
179  // flush icache
181  // rotating mem
182  rotating_mem.Next();
183  // clear c mem
184  if(arg_.KBatch > 1)
185  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
186  0,
187  arg_.M * arg_.N * sizeof(CDataType),
188  stream_config.stream_id_));
189  };
190 
191  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
192  stream_config,
193  run_flush_cache,
194  kernel,
195  dim3(gdx, gdy, gdz),
196  dim3(BlockSize),
197  0,
198  arg_);
199  }
200  else
201  {
202  if(arg.KBatch > 1)
203  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
204  0,
205  arg.M * arg.N * sizeof(CDataType),
206  stream_config.stream_id_));
207 
208  ave_time = launch_and_time_kernel(
209  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
210  }
211  };
212 
213  constexpr index_t minimum_occupancy =
214  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
215 
216  if(has_main_k_block_loop)
217  {
218  // Tail number always full
219  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
220  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
221  {
222  if(arg.KBatch > 1)
223  {
224  const auto kernel =
226  true,
228  minimum_occupancy>;
229  Run(kernel);
230  }
231  else
232  {
233  const auto kernel =
235  true,
237  minimum_occupancy>;
238  Run(kernel);
239  }
240  }
241  // Tail number could be One to Seven
242  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
243  {
244  if(arg.KBatch > 1)
245  {
247  {
248  const auto kernel =
250  true,
252  minimum_occupancy,
254  Run(kernel);
255  }
256  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
258  {
259  const auto kernel =
261  true,
263  minimum_occupancy,
265  Run(kernel);
266  }
267 
268  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
269  {
271  {
272  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
273  GridwiseGemm,
274  true,
276  minimum_occupancy,
278  Run(kernel);
279  }
280  }
281 
282  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
283  {
286  {
287  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
288  GridwiseGemm,
289  true,
291  minimum_occupancy,
293  Run(kernel);
294  }
295  }
296 
297  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
298  {
301  {
302  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
303  GridwiseGemm,
304  true,
306  minimum_occupancy,
308  Run(kernel);
309  }
310  }
311 
312  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
313  {
316  {
317  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
318  GridwiseGemm,
319  true,
321  minimum_occupancy,
323  Run(kernel);
324  }
325  }
326 
327  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
328  {
330  {
331  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
332  GridwiseGemm,
333  true,
335  minimum_occupancy,
337  Run(kernel);
338  }
339  }
340 
341  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
342  {
345  {
346  const auto kernel = kernel_gemm_xdl_cshuffle_v3<
347  GridwiseGemm,
348  true,
350  minimum_occupancy,
352  Run(kernel);
353  }
354  }
355  }
356  else
357  {
359  {
360  const auto kernel =
362  true,
364  minimum_occupancy,
366  Run(kernel);
367  }
368  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
370  {
371  const auto kernel =
373  true,
375  minimum_occupancy,
377  Run(kernel);
378  }
379 
380  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
381  {
383  {
384  const auto kernel =
386  true,
388  minimum_occupancy,
390  Run(kernel);
391  }
392  }
393 
394  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
395  {
398  {
399  const auto kernel =
401  true,
403  minimum_occupancy,
405  Run(kernel);
406  }
407  }
408 
409  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
410  {
413  {
414  const auto kernel =
416  true,
418  minimum_occupancy,
420  Run(kernel);
421  }
422  }
423 
424  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
425  {
428  {
429  const auto kernel =
431  true,
433  minimum_occupancy,
435  Run(kernel);
436  }
437  }
438 
439  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
440  {
442  {
443  const auto kernel =
445  true,
447  minimum_occupancy,
449  Run(kernel);
450  }
451  }
452 
453  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
454  {
457  {
458  const auto kernel =
460  true,
462  minimum_occupancy,
464  Run(kernel);
465  }
466  }
467  }
468  }
469  // Tail number could be Odd or Even
470  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
471  {
472  if(arg.KBatch > 1)
473  {
475  {
476  const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
477  GridwiseGemm,
478  true,
480  minimum_occupancy,
482  Run(kernel);
483  }
484  else
485  {
486  const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
487  GridwiseGemm,
488  true,
490  minimum_occupancy,
492  Run(kernel);
493  }
494  }
495  else
496  {
498  {
499  const auto kernel =
501  true,
503  minimum_occupancy,
505  Run(kernel);
506  }
507  else
508  {
509  const auto kernel =
511  true,
513  minimum_occupancy,
515  Run(kernel);
516  }
517  }
518  }
519  else
520  {
521  if(arg.KBatch > 1)
522  {
524  {
525  const auto kernel =
527  true,
529  minimum_occupancy,
531  Run(kernel);
532  }
533  else
534  {
535  const auto kernel =
537  true,
539  minimum_occupancy,
541  Run(kernel);
542  }
543  }
544  else
545  {
547  {
548  const auto kernel =
550  true,
552  minimum_occupancy,
554  Run(kernel);
555  }
556  else
557  {
558  const auto kernel =
560  true,
562  minimum_occupancy,
564  Run(kernel);
565  }
566  }
567  }
568  }
569  else
570  {
571  // Tail number always 1
572  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
573  {
574  if(arg.KBatch > 1)
575  {
576  const auto kernel =
578  false,
580  minimum_occupancy>;
581  Run(kernel);
582  }
583  else
584  {
585  const auto kernel =
587  false,
589  minimum_occupancy>;
590  Run(kernel);
591  }
592  }
593  }
594 
595  return ave_time;
596  }
597 
598  // polymorphic
599  float Run(const BaseArgument* p_arg,
600  const StreamConfig& stream_config = StreamConfig{}) override
601  {
602  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
603  }
604  };
605 
/// Compile-time hook reporting whether the chosen template parameters form a
/// valid instance. It currently accepts every configuration unconditionally.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepted = true;
    return accepted;
}
611 
612  static bool IsSupportedArgument(const Argument& arg)
613  {
614  if(!ck::is_xdl_supported())
615  {
616  return false;
617  }
618 
619  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
620  {
621  return false;
622  }
623 
624  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
625  GemmSpec == GemmSpecialization::NKPadding ||
626  GemmSpec == GemmSpecialization::MNKPadding ||
627  GemmSpec == GemmSpecialization::KPadding))
628  {
629  return false;
630  }
631 
632  return GridwiseGemm::CheckValidity(arg);
633  }
634 
635  // polymorphic
636  bool IsSupportedArgument(const BaseArgument* p_arg) override
637  {
638  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
639  }
640 
641  index_t GetKPerBlock() override { return KPerBlock; }
642 
643  bool GetPermuteA() override { return PermuteA; }
644  bool GetPermuteB() override { return PermuteB; }
645 
646  static auto MakeArgument(const ADataType* p_a,
647  const BDataType* p_b,
648  CDataType* p_c,
649  index_t M,
650  index_t N,
651  index_t K,
652  index_t StrideA,
653  index_t StrideB,
654  index_t StrideC,
655  index_t KBatch,
656  AElementwiseOperation,
657  BElementwiseOperation,
658  CElementwiseOperation)
659  {
660  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
661  }
662 
663  static auto MakeInvoker() { return Invoker{}; }
664 
665  // polymorphic
666  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
667  const void* p_b,
668  void* p_c,
669  index_t M,
670  index_t N,
671  index_t K,
672  index_t StrideA,
673  index_t StrideB,
674  index_t StrideC,
675  index_t KBatch,
676  AElementwiseOperation,
677  BElementwiseOperation,
678  CElementwiseOperation) override
679  {
680  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
681  static_cast<const BDataType*>(p_b),
682  static_cast<CDataType*>(p_c),
683  M,
684  N,
685  K,
686  StrideA,
687  StrideB,
688  StrideC,
689  KBatch);
690  }
691 
692  // polymorphic
693  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
694  {
695  return std::make_unique<Invoker>(Invoker{});
696  }
697 
698  // polymorphic
699  std::string GetTypeString() const override // Describes this instance's compile-time configuration as a single line of text.
700  {
701  auto str = std::stringstream();
702 
703  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{ // scheduler enum -> display name, looked up below
706 
707  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{ // pipeline-version enum -> display name, looked up below
713 
714  // clang-format off
715  str << "DeviceGemmXdlUniversal"
716  << "<"
717  << getGemmSpecializationString(GemmSpec) << ", "
718  << std::string(ALayout::name)[0] // only the first character of each layout name is printed
719  << std::string(BLayout::name)[0]
720  << std::string(CLayout::name)[0]
721  << ">"
722  << " BlkSize: "
723  << BlockSize << ", "
724  << "BlkTile: "
725  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
726  << "WaveTile: "
727  << MPerXDL<<"x"<<NPerXDL << ", "
728  << "WaveMap: "
729  << MXdlPerWave<<"x" << NXdlPerWave<<", "
730  << "VmemReadVec: "
731  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
732  << "BlkGemmPipelineScheduler: "
733  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " // std::map::operator[] inserts and prints "" for an unmapped value
734  << "BlkGemmPipelineVersion: "
735  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
736  << "BlkGemmPipelinePrefetchStages: "
737  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
738  << "Kpack: "
739  << GridwiseGemm::BlockwiseGemmPipe::AMmaKStride; // value printed under the "Kpack" label is AMmaKStride
740  // clang-format on
741 
742  return str.str();
743  }
745 };
746 
747 } // namespace device
748 } // namespace tensor_operation
749 } // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:45
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:215
Definition: ck.hpp:264
bool is_xdl_supported()
Definition: device_prop.hpp:54
BlockGemmPipelineVersion
Definition: blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp:13
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:58
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:37
int32_t index_t
Definition: ck.hpp:289
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:69
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:241
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:66
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:610
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:322
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1004
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:603
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:88
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:240
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_xdl_cshuffle_v3.hpp:135
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:599
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_v3.hpp:136
Definition: device_gemm_xdl_cshuffle_v3.hpp:79
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:636
bool GetPermuteA() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:643
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:693
index_t GetKPerBlock() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:641
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v3.hpp:612
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v3.hpp:606
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v3.hpp:699
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition: device_gemm_xdl_cshuffle_v3.hpp:129
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v3.hpp:646
bool GetPermuteB() override
Definition: device_gemm_xdl_cshuffle_v3.hpp:644
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v3.hpp:663
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v3.hpp:666
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl_cshuffle_v3.hpp:131
Definition: device_gemm_v2.hpp:22
Definition: flush_cache.hpp:137