Composable Kernel: device_gemm_xdl_cshuffle_streamk_v3.hpp Source File

1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
 8 
 9 #include "ck/utility/common_header.hpp"
 10 #include "ck/tensor_description/tensor_descriptor.hpp"
 11 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 12 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 13 #include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp"
 14 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 15 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp"
 16 #include "ck/host_utility/device_prop.hpp"
 17 #include "ck/host_utility/kernel_launch.hpp"
 18 #include "ck/host_utility/flush_cache.hpp"
 19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
24 template <typename ALayout,
25  typename BLayout,
26  typename CLayout,
27  typename ADataType,
28  typename BDataType,
29  typename CDataType,
30  typename GemmAccDataType,
31  typename CShuffleDataType,
32  typename AElementwiseOperation,
33  typename BElementwiseOperation,
34  typename CElementwiseOperation,
35  GemmSpecialization GemmSpec,
36  index_t BlockSize,
37  index_t MPerBlock,
38  index_t NPerBlock,
39  index_t KPerBlock,
40  index_t AK1,
41  index_t BK1,
42  index_t MPerXDL,
43  index_t NPerXDL,
44  index_t MXdlPerWave,
45  index_t NXdlPerWave,
46  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
47  typename ABlockTransferThreadClusterArrangeOrder,
48  typename ABlockTransferSrcAccessOrder,
49  index_t ABlockTransferSrcVectorDim,
50  index_t ABlockTransferSrcScalarPerVector,
51  index_t ABlockTransferDstScalarPerVector_AK1,
52  bool ABlockLdsExtraM,
53  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
54  typename BBlockTransferThreadClusterArrangeOrder,
55  typename BBlockTransferSrcAccessOrder,
56  index_t BBlockTransferSrcVectorDim,
57  index_t BBlockTransferSrcScalarPerVector,
58  index_t BBlockTransferDstScalarPerVector_BK1,
59  bool BBlockLdsExtraN,
60  index_t CShuffleMXdlPerWavePerShuffle,
61  index_t CShuffleNXdlPerWavePerShuffle,
62  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 64  BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
 65  BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
 66  typename ComputeTypeA = CDataType,
67  typename ComputeTypeB = ComputeTypeA>
 68 struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemmV2Streamk<ALayout,
 69  BLayout,
70  CLayout,
71  ADataType,
72  BDataType,
73  CDataType,
74  AElementwiseOperation,
75  BElementwiseOperation,
76  CElementwiseOperation>
77 {
78  // GridwiseGemm
 79  using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3<
 80  ALayout,
81  BLayout,
82  CLayout,
83  ADataType,
84  BDataType,
85  GemmAccDataType,
86  CShuffleDataType,
87  CDataType,
88  AElementwiseOperation,
89  BElementwiseOperation,
90  CElementwiseOperation,
91  GemmSpec,
92  BlockSize,
93  MPerBlock,
94  NPerBlock,
95  KPerBlock,
96  AK1,
97  BK1,
98  MPerXDL,
99  NPerXDL,
100  MXdlPerWave,
101  NXdlPerWave,
102  ABlockTransferThreadClusterLengths_AK0_M_AK1,
103  ABlockTransferThreadClusterArrangeOrder,
104  ABlockTransferSrcAccessOrder,
105  ABlockTransferSrcVectorDim,
106  ABlockTransferSrcScalarPerVector,
107  ABlockTransferDstScalarPerVector_AK1,
108  false,
109  ABlockLdsExtraM,
110  BBlockTransferThreadClusterLengths_BK0_N_BK1,
111  BBlockTransferThreadClusterArrangeOrder,
112  BBlockTransferSrcAccessOrder,
113  BBlockTransferSrcVectorDim,
114  BBlockTransferSrcScalarPerVector,
115  BBlockTransferDstScalarPerVector_BK1,
116  false,
117  BBlockLdsExtraN,
118  CShuffleMXdlPerWavePerShuffle,
119  CShuffleNXdlPerWavePerShuffle,
120  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
121  CShuffleBlockTransferScalarPerVector_NPerBlock,
122  BlkGemmPipeSched,
123  BlkGemmPipelineVer,
124  ComputeTypeA,
125  ComputeTypeB>;
 126 
 127  using Argument = typename GridwiseGemm::Argument;
128 
129  // Invoker
130  struct Invoker : public BaseInvoker
131  {
132  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
133  {
134 
135  if(stream_config.log_level_ > 0)
136  {
137  arg.Print();
138  }
 139 
 140  if(!GridwiseGemm::CheckValidity(arg))
141  {
142  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
143  }
144 
145  float ave_time = 0;
146 
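 // Round K up to a whole multiple of KPerBlock; the rounded value decides whether the
 // kernel variant with a main K-block loop has to be selected.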
147  index_t k_grain = KPerBlock;
148  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
149 
150  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
 151 
 152  if constexpr(GridwiseGemm::Block2CTileMapStreamK::ReductionStrategy ==
 153  StreamKReductionStrategy::Atomic)
154  {
155 
156  hip_check_error(hipMemsetAsync(
157  arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
158  }
159 
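 // Launch helper used by every pipeline/tail-number branch below. When the caller left
 // Grid_size negative, the grid is sized as CU count * reported kernel occupancy and
 // written back into the argument; otherwise the caller-supplied grid size is used as-is.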
160  const auto Run = [&](const auto& kernel) {
161  dim3 grid_dim;
162  if(arg.Grid_size < 0)
163  {
164  int occupancy, num_cu;
165  hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
166  &occupancy, kernel, BlockSize, 0));
167  hipDeviceProp_t dev_prop;
168  hipDevice_t dev;
169  hip_check_error(hipGetDevice(&dev));
170  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
171  num_cu = dev_prop.multiProcessorCount;
172  arg.Grid_size = num_cu * occupancy;
173  grid_dim = arg.Grid_size;
174  }
175  else
176  grid_dim = arg.Grid_size;
177 
178  if(stream_config.flush_cache)
179  {
 180  Argument arg_ = arg;
 181  ck::utility::RotatingMemWrapper<Argument> rotating_mem(
182  arg_,
183  stream_config.rotating_count,
184  arg_.M * arg_.K * sizeof(ADataType),
185  arg_.K * arg_.N * sizeof(BDataType));
186  rotating_mem.Print();
187 
188  auto run_flush_cache = [&]() {
 189  // flush icache
 190  ck::utility::flush_icache();
191  // rotating mem
192  rotating_mem.Next();
193  };
194 
195  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
196  stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
197  }
198  else
199  {
 200 
 201  if constexpr(GridwiseGemm::Block2CTileMapStreamK::ReductionStrategy ==
 202  StreamKReductionStrategy::Atomic)
203  {
204  ave_time = launch_and_time_kernel(
205  stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
 206  }
 207  else if constexpr(GridwiseGemm::Block2CTileMapStreamK::ReductionStrategy ==
 208  StreamKReductionStrategy::Reduction)
209  {
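 // The semaphore region sits right after the partial-accumulator workspace and is zeroed
 // in the preprocess callback before every timed launch.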
210  char* workspace_semaphore =
211  reinterpret_cast<char*>(arg.p_workspace_) +
212  arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
213  sizeof(GemmAccDataType));
214  auto preprocess = [&]() {
215  hipMemsetAsync(
216  workspace_semaphore,
217  0,
218  // sizeof(uint32_t),
219  arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
220  stream_config.stream_id_);
221  };
 222 
 223  ave_time = launch_and_time_kernel_with_preprocess(
224  stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
225  }
226  }
227  };
228 
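 // Minimum-occupancy hint forwarded to the kernel template: one block per CU for
 // Intrawave-scheduled pipelines, two otherwise.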
229  constexpr index_t minimum_occupancy =
230  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
231 
232  if(has_main_k_block_loop)
233  {
234  // Tail number always full
235  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
236  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
237  {
238 
239  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
240  true,
242  minimum_occupancy>;
243 
244  Run(kernel);
245  }
246  // Tail number could be One to Seven
247  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
248  {
249 
250  {
252  {
253  const auto kernel =
255  true,
257  minimum_occupancy,
259  Run(kernel);
260  }
261  else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
263  {
264  const auto kernel =
266  true,
268  minimum_occupancy,
270  Run(kernel);
271  }
272 
273  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
274  {
276  {
277  const auto kernel =
279  true,
281  minimum_occupancy,
283  Run(kernel);
284  }
285  }
286 
287  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
288  {
291  {
292  const auto kernel =
294  true,
296  minimum_occupancy,
298  Run(kernel);
299  }
300  }
301 
302  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
303  {
306  {
307  const auto kernel =
309  true,
311  minimum_occupancy,
313  Run(kernel);
314  }
315  }
316 
317  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
318  {
321  {
322  const auto kernel =
324  true,
326  minimum_occupancy,
328  Run(kernel);
329  }
330  }
331 
332  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
333  {
335  {
336  const auto kernel =
338  true,
340  minimum_occupancy,
342  Run(kernel);
343  }
344  }
345 
346  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
347  {
350  {
351  const auto kernel =
353  true,
355  minimum_occupancy,
357  Run(kernel);
358  }
359  }
360  }
361  }
362  // Tail number could be Odd or Even
363  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
364  {
365 
367  {
368  const auto kernel =
370  true,
372  minimum_occupancy,
374  Run(kernel);
375  }
376  else
377  {
378  const auto kernel =
380  true,
382  minimum_occupancy,
384  Run(kernel);
385  }
386  }
387  else
388  {
389 
391  {
392  const auto kernel =
394  true,
396  minimum_occupancy,
398  Run(kernel);
399  }
400  else
401  {
402  const auto kernel =
404  true,
406  minimum_occupancy,
408  Run(kernel);
409  }
410  }
411  }
412  else
413  {
414  // Tail number always 1
415  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
416  {
417 
418  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
419  false,
421  minimum_occupancy>;
422  Run(kernel);
423  }
424  }
425 
426  return ave_time;
427  }
428 
429  // polymorphic
430  float Run(const BaseArgument* p_arg,
431  const StreamConfig& stream_config = StreamConfig{}) override
432  {
433  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
434  }
435  };
436 
437  size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
438  {
439  const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
 440  if constexpr(GridwiseGemm::Block2CTileMapStreamK::ReductionStrategy ==
 441  StreamKReductionStrategy::Reduction)
 442  {
443  return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
444  }
445  else
446  {
447  return 0;
448  }
449  }
 450 
 451  void SetWorkSpacePointer(BaseArgument* pArg,
452  void* p_workspace,
453  const StreamConfig& = StreamConfig{}) const override
454  {
455  Argument* pArg_ = dynamic_cast<Argument*>(pArg);
456 
457  pArg_->p_workspace_ = p_workspace;
458  }
459 
460  static constexpr bool IsValidCompilationParameter()
461  {
462  // TODO: properly implement this check
463  return true;
464  }
465 
466  static bool IsSupportedArgument(const Argument& arg)
467  {
468  if(!ck::is_xdl_supported())
469  {
470  return false;
471  }
472  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
473  arg.Streamk_sel > 0)
474  {
475  return false;
476  }
477  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
478  GemmSpec == GemmSpecialization::NKPadding ||
479  GemmSpec == GemmSpecialization::MNKPadding ||
480  GemmSpec == GemmSpecialization::KPadding))
481  {
482  return false;
483  }
484 
485  return GridwiseGemm::CheckValidity(arg);
486  }
487 
488  // polymorphic
489  bool IsSupportedArgument(const BaseArgument* p_arg) override
490  {
491  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
492  }
493 
494  static auto MakeArgument(const ADataType* p_a,
495  const BDataType* p_b,
496  CDataType* p_c,
497  index_t M,
498  index_t N,
499  index_t K,
500  index_t StrideA,
501  index_t StrideB,
502  index_t StrideC,
503  index_t streamk_sel,
504  index_t Grid_size,
505  AElementwiseOperation,
506  BElementwiseOperation,
507  CElementwiseOperation)
508  {
509 
510  constexpr index_t minimum_occupancy =
511  BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
512  index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;
513  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
514  int occupancy, num_cu;
515  const auto calculate_grid_size = [&](const auto& kernel) {
 516  hip_check_error(
 517  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
518  hipDeviceProp_t dev_prop;
519  hipDevice_t dev;
520  hip_check_error(hipGetDevice(&dev));
521  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
522  num_cu = dev_prop.multiProcessorCount;
523  Grid_size = num_cu * occupancy;
524  };
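 // The pipeline/tail-number dispatch below mirrors Invoker::Run; it is only used to pick
 // the kernel instantiation whose occupancy determines the Grid_size stored in the
 // returned Argument.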
525 
526  if(has_main_k_block_loop)
527  {
528  // Tail number always full
529  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
530  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
531  {
532 
533  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
534  true,
536  minimum_occupancy>;
537  calculate_grid_size(kernel);
538  }
539  // Tail number could be One to Seven
540  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
541  {
542 
544  {
545  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
546  true,
548  minimum_occupancy,
550  calculate_grid_size(kernel);
551  }
553  {
554  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
555  true,
557  minimum_occupancy,
559  calculate_grid_size(kernel);
560  }
561 
562  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
563  {
565  {
566  const auto kernel =
568  true,
570  minimum_occupancy,
572  calculate_grid_size(kernel);
573  }
574  }
575 
576  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
577  {
579  {
580  const auto kernel =
582  true,
584  minimum_occupancy,
586  calculate_grid_size(kernel);
587  }
588  }
589 
590  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
591  {
593  {
594  const auto kernel =
596  true,
598  minimum_occupancy,
600  calculate_grid_size(kernel);
601  }
602  }
603 
604  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
605  {
607  {
608  const auto kernel =
610  true,
612  minimum_occupancy,
614  calculate_grid_size(kernel);
615  }
616  }
617 
618  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
619  {
621  {
622  const auto kernel =
624  true,
626  minimum_occupancy,
628  calculate_grid_size(kernel);
629  }
630  }
631 
632  if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
633  {
635  {
636  const auto kernel =
638  true,
640  minimum_occupancy,
642  calculate_grid_size(kernel);
643  }
644  }
645  }
646  // Tail number could be Odd or Even
647  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
648  {
649 
651  {
652  const auto kernel =
654  true,
656  minimum_occupancy,
658  calculate_grid_size(kernel);
659  }
660  else
661  {
662  const auto kernel =
664  true,
666  minimum_occupancy,
668  calculate_grid_size(kernel);
669  }
670  }
671  else
672  {
673 
675  {
676  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
677  true,
679  minimum_occupancy,
681  calculate_grid_size(kernel);
682  }
683  else
684  {
685  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
686  true,
688  minimum_occupancy,
690  calculate_grid_size(kernel);
691  }
692  }
693  }
694  else
695  {
696  // Tail number always 1
697  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
698  {
699 
700  const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
701  false,
703  minimum_occupancy>;
704  calculate_grid_size(kernel);
705  }
706  }
707 
708  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
709  }
710 
711  static auto MakeInvoker() { return Invoker{}; }
712 
713  // polymorphic
714  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
715  const void* p_b,
716  void* p_c,
717  index_t M,
718  index_t N,
719  index_t K,
720  index_t StrideA,
721  index_t StrideB,
722  index_t StrideC,
723  index_t streamk_sel,
724  index_t Grid_size,
725  AElementwiseOperation,
726  BElementwiseOperation,
727  CElementwiseOperation) override
728  {
729  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
730  static_cast<const BDataType*>(p_b),
731  static_cast<CDataType*>(p_c),
732  M,
733  N,
734  K,
735  StrideA,
736  StrideB,
737  StrideC,
738  streamk_sel,
739  Grid_size);
740  }
741 
742  // polymorphic
743  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
744  {
745  return std::make_unique<Invoker>(Invoker{});
746  }
747 
748  // polymorphic
749  std::string GetTypeString() const override
750  {
751  auto str = std::stringstream();
752 
753  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
 754  {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
 755  {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
 756 
757  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
 758  {BlockGemmPipelineVersion::v1, "v1"},
 759  {BlockGemmPipelineVersion::v2, "v2"},
 760  {BlockGemmPipelineVersion::v3, "v3"},
 761  {BlockGemmPipelineVersion::v4, "v4"},
 762  {BlockGemmPipelineVersion::v5, "v5"}};
 763 
764  // clang-format off
765  str << "DeviceGemmXdlUniversal"
766  << "<"
767  << getGemmSpecializationString(GemmSpec) << ", "
768  << std::string(ALayout::name)[0]
769  << std::string(BLayout::name)[0]
770  << std::string(CLayout::name)[0]
771  << ">"
772  << " BlkSize: "
773  << BlockSize << ", "
774  << "BlkTile: "
775  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
776  << "WaveTile: "
777  << MPerXDL<<"x"<<NPerXDL << ", "
778  << "WaveMap: "
779  << MXdlPerWave<<"x" << NXdlPerWave<<", "
780  << "VmemReadVec: "
781  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
782  << "BlkGemmPipelineScheduler: "
783  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
784  << "BlkGemmPipelineVersion: "
785  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
786  << "BlkGemmPipelinePrefetchStages: "
787  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
788  // clang-format on
789 
790  return str.str();
791  }
792 };
793 
794 } // namespace device
795 } // namespace tensor_operation
796 } // namespace ck
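For orientation, here is a minimal host-side sketch of how a concrete instantiation of the device op above is typically driven through its polymorphic interface (MakeArgumentPointer, IsSupportedArgument, GetWorkSpaceSize, SetWorkSpacePointer, MakeInvokerPointer, Run). It is an illustration only, not code from this header: DeviceOp stands for some fully specialized instantiation of the struct defined above, AOp/BOp/COp stand for its elementwise-operator types, and the buffer names, streamk_sel value, and problem sizes are placeholder assumptions.

// Host-side usage sketch. Assumptions: this header plus <hip/hip_runtime.h>, <iostream>,
// <memory> and <stdexcept> are included; DeviceOp, AOp, BOp, COp are aliases chosen by the
// user; a_dev/b_dev/c_dev are device buffers already sized for the M x N x K problem.
inline void run_streamk_gemm_example(const void* a_dev,
                                     const void* b_dev,
                                     void* c_dev,
                                     ck::index_t M,
                                     ck::index_t N,
                                     ck::index_t K,
                                     ck::index_t StrideA,
                                     ck::index_t StrideB,
                                     ck::index_t StrideC)
{
    auto device_op = DeviceOp{};

    // streamk_sel selects the Stream-K work decomposition; Grid_size = -1 lets the invoker
    // size the grid from CU count * kernel occupancy (see Invoker::Run above).
    auto argument = device_op.MakeArgumentPointer(a_dev,
                                                  b_dev,
                                                  c_dev,
                                                  M,
                                                  N,
                                                  K,
                                                  StrideA,
                                                  StrideB,
                                                  StrideC,
                                                  /*streamk_sel=*/1,
                                                  /*Grid_size=*/-1,
                                                  AOp{},
                                                  BOp{},
                                                  COp{});

    if(!device_op.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("problem is not supported by this DeviceOp configuration");
    }

    // The workspace (partial accumulators followed by semaphores, used by the
    // reduction-based Stream-K path) is only needed when a non-zero size is reported.
    void* p_workspace                = nullptr;
    const std::size_t workspace_size = device_op.GetWorkSpaceSize(argument.get());
    if(workspace_size > 0)
    {
        if(hipMalloc(&p_workspace, workspace_size) != hipSuccess)
        {
            throw std::runtime_error("failed to allocate the stream-k workspace");
        }
        device_op.SetWorkSpacePointer(argument.get(), p_workspace);
    }

    auto invoker   = device_op.MakeInvokerPointer();
    const float ms = invoker->Run(argument.get());
    std::cout << "stream-k gemm finished in " << ms << " ms" << std::endl;

    if(p_workspace != nullptr)
    {
        (void)hipFree(p_workspace);
    }
}
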
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:90
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:215
Definition: ck.hpp:264
@ Atomic
Definition: block_to_ctile_map.hpp:1009
@ Reduction
Definition: block_to_ctile_map.hpp:1010
bool is_xdl_supported()
Definition: device_prop.hpp:54
BlockGemmPipelineVersion
Definition: blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp:13
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:58
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:37
int32_t index_t
Definition: ck.hpp:289
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:69
Definition: stream_config.hpp:10
static constexpr StreamKReductionStrategy ReductionStrategy
Definition: block_to_ctile_map.hpp:1422
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:517
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:126
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:1130
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:1123
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:940
Definition: device_base.hpp:50
void * p_workspace_
Definition: device_base.hpp:57
Definition: device_base.hpp:61
Definition: device_gemm_streamk_v2.hpp:22
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:131
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:430
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:132
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:77
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:437
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:127
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:494
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:460
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:743
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:451
GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemm
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:125
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:466
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:711
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:489
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:749
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_streamk_v3.hpp:714
Definition: flush_cache.hpp:137