device_gemm_xdl_cshuffle_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <sstream>

// CK headers required by the definitions below (this include list was dropped from
// the extracted listing and is reconstructed from the symbols used in this file).
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

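/// @brief "Universal" GEMM operation with SplitK support.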
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          bool PermuteA = false,
          bool PermuteB = false>
struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                       BLayout,
                                                       CLayout,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation>
{
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
        ALayout,
        BLayout,
        CLayout,
        ADataType,
        BDataType,
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
        PermuteB>;

    using Argument = typename GridwiseGemm::Argument;

    // Packed low-precision types (e.g. pk_i4_t) store two values per element of storage,
    // so byte sizes computed from element counts are divided by the factors below.
    static constexpr index_t APackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    static constexpr index_t BPackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

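    /// @brief Helper structure responsible for kernel invocation.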
    struct Invoker : public BaseInvoker
    {
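        /// @brief This function issues GPU kernel execution.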
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
                GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    Argument arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
                                         sizeof(ADataType) / APackedSize;
                    auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
                                         sizeof(BDataType) / BPackedSize;

                    ck::utility::RotatingMemWrapper<Argument> rotating_mem(
                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg_);
                }
                else
                {
                    if(arg.KBatch > 1)
                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                         0,
                                                         arg.M * arg.N * sizeof(CDataType),
                                                         stream_config.stream_id_));

                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

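            // Minimum occupancy (blocks per CU) forwarded to the kernel's launch bounds:
            // interwave scheduling and small per-thread output tiles request two resident
            // blocks, otherwise one.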
            constexpr index_t minimum_occupancy = []() {
                if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
                {
                    return 2;
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
                }
                else
                {
                    return 1;
                }
            }();

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::AtomicAdd,
                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
                    if(arg.KBatch > 1)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
                                                            TailNumber::One>;
                            Run(kernel);
                        }
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                                TailNumber::Full)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
                                                            TailNumber::Full>;
                            Run(kernel);
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Two)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm,
                                    true,
                                    InMemoryDataOperationEnum::AtomicAdd,
                                    minimum_occupancy,
                                    TailNumber::Two>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Three)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm,
                                    true,
                                    InMemoryDataOperationEnum::AtomicAdd,
                                    minimum_occupancy,
                                    TailNumber::Three>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Four)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm,
                                    true,
                                    InMemoryDataOperationEnum::AtomicAdd,
                                    minimum_occupancy,
                                    TailNumber::Four>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Five)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm,
                                    true,
                                    InMemoryDataOperationEnum::AtomicAdd,
                                    minimum_occupancy,
                                    TailNumber::Five>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Six)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm,
                                    true,
                                    InMemoryDataOperationEnum::AtomicAdd,
                                    minimum_occupancy,
                                    TailNumber::Six>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Seven)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm,
                                    true,
                                    InMemoryDataOperationEnum::AtomicAdd,
                                    minimum_occupancy,
                                    TailNumber::Seven>;
                                Run(kernel);
                            }
                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::One>;
                            Run(kernel);
                        }
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                                TailNumber::Full)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Full>;
                            Run(kernel);
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Two)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                true,
                                                                InMemoryDataOperationEnum::Set,
                                                                minimum_occupancy,
                                                                TailNumber::Two>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Three)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                true,
                                                                InMemoryDataOperationEnum::Set,
                                                                minimum_occupancy,
                                                                TailNumber::Three>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Four)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                true,
                                                                InMemoryDataOperationEnum::Set,
                                                                minimum_occupancy,
                                                                TailNumber::Four>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Five)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                true,
                                                                InMemoryDataOperationEnum::Set,
                                                                minimum_occupancy,
                                                                TailNumber::Five>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Six)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                true,
                                                                InMemoryDataOperationEnum::Set,
                                                                minimum_occupancy,
                                                                TailNumber::Six>;
                                Run(kernel);
                            }
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Seven)
                            {
                                const auto kernel =
                                    kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                true,
                                                                InMemoryDataOperationEnum::Set,
                                                                minimum_occupancy,
                                                                TailNumber::Seven>;
                                Run(kernel);
                            }
                        }
                    }
                }
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
                    if(arg.KBatch > 1)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
                                                                 true,
                                                                 InMemoryDataOperationEnum::Set,
                                                                 minimum_occupancy,
                                                                 TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
                                                                 true,
                                                                 InMemoryDataOperationEnum::Set,
                                                                 minimum_occupancy,
                                                                 TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
                else
                {
                    if(arg.KBatch > 1)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
                                                            TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                            minimum_occupancy,
                                                            TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel =
                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                            true,
                                                            InMemoryDataOperationEnum::Set,
                                                            minimum_occupancy,
                                                            TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        false,
                                                        InMemoryDataOperationEnum::AtomicAdd,
                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        false,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    index_t GetKPerBlock() override { return KPerBlock; }

    bool GetPermuteA() override { return PermuteA; }
    bool GetPermuteB() override { return PermuteB; }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          KBatch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemmXdlUniversal"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL<<"x"<<NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave<<"x" << NXdlPerWave<<", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
            << "Kpack: "
            << GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
        // clang-format on

        return str.str();
    }

    REGISTER_EXTRA_PRINTING_METHODS
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
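The header above only defines the device operation; a caller typically drives it through the polymorphic base-class interface. The sketch below illustrates that flow. It is not part of the header: DeviceGemmInstance stands for some fully specified DeviceGemm_Xdl_CShuffleV3<...> instantiation chosen elsewhere, PassThrough is the usual no-op element-wise operator, and the pointers are assumed to reference device memory of the matching data types.

// Usage sketch (illustrative only, under the assumptions stated above).
#include <memory>
#include <stdexcept>

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

template <typename DeviceGemmInstance>
float run_gemm(DeviceGemmInstance& gemm,
               const void* p_a, const void* p_b, void* p_c,
               ck::index_t M, ck::index_t N, ck::index_t K,
               ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
               ck::index_t KBatch)
{
    // Build argument and invoker through the polymorphic BaseArgument/BaseInvoker interface.
    auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                             StrideA, StrideB, StrideC, KBatch,
                                             PassThrough{}, PassThrough{}, PassThrough{});
    auto invoker  = gemm.MakeInvokerPointer();

    // Reject unsupported shapes/layouts before launching anything.
    if(!gemm.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("DeviceGemm_Xdl_CShuffleV3: unsupported argument");
    }

    // With time_kernel enabled, Run returns the averaged kernel time in milliseconds.
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}

Note that KBatch > 1 selects the SplitK path: the Invoker clears the C buffer and the kernels accumulate with atomic adds, which is why IsSupportedArgument rejects bf16 output with KBatch > 1 on devices without bf16 atomics.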