Composable Kernel: device_gemm_wmma_cshuffle_v3.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
/// @brief "Universal" GEMM operation with SplitK support.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          bool PermuteA = false,
          bool PermuteB = false>
struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                        BLayout,
                                                        CLayout,
                                                        ADataType,
                                                        BDataType,
                                                        CDataType,
                                                        AElementwiseOperation,
                                                        BElementwiseOperation,
                                                        CElementwiseOperation>
{
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
        ALayout,
        BLayout,
        CLayout,
        ADataType,
        BDataType,
        AccDataType,
        CShuffleDataType,
        CDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
        PermuteB>;

    using Argument = typename GridwiseGemm::Argument;
    /// @brief Helper structure responsible for kernel invocation.
    struct Invoker : public BaseInvoker
    {
        /// @brief This function issues GPU kernel execution.
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
                GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
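            // Illustrative example (values assumed): with K = 4096, KBatch = 4 and
            // KPerBlock = 64, k_grain = 256 and K_split = ceil(4096 / 256) * 64 = 1024,
            // i.e. the K extent each split works on, rounded up to a multiple of KPerBlock.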

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    Argument arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
                                         sizeof(ADataType) / GridwiseGemm::APackedSize;
                    auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
                                         sizeof(BDataType) / GridwiseGemm::BPackedSize;

                    ck::utility::RotatingMemWrapper<Argument> rotating_mem(
                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
                                                           0,
                                                           arg_.M * arg_.N * sizeof(CDataType),
                                                           stream_config.stream_id_));
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg_);
                }
                else
                {
                    if(arg.KBatch > 1)
                        HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
                                                       0,
                                                       arg.M * arg.N * sizeof(CDataType),
                                                       stream_config.stream_id_));

                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };

            constexpr index_t minimum_occupancy = []() {
                if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
                {
                    return 2;
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
                }
                else
                {
                    return 1;
                }
            }();

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel =
                            kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
                                                         true,
                                                         InMemoryDataOperationEnum::AtomicAdd,
                                                         minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
                                                         true,
                                                         InMemoryDataOperationEnum::Set,
                                                         minimum_occupancy>;
                        Run(kernel);
                    }
                }
                else
                {
                    // TODO: Implement
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel =
                            kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
                                                         false,
                                                         InMemoryDataOperationEnum::AtomicAdd,
                                                         minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
                                                         false,
                                                         InMemoryDataOperationEnum::Set,
                                                         minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
        {
            return false;
        }

        if constexpr(std::is_same_v<CDataType, ck::half_t> ||
                     std::is_same_v<CDataType, ck::bhalf_t>)
        {
            if(arg.KBatch > 1 && ck::is_gfx11_supported())
            {
                // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
                return false;
            }
        }

        if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
                     std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
        {
            if(!ck::is_gfx12_supported())
            {
                return false;
            }
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    index_t GetKPerBlock() override { return KPerBlock; }

    bool GetPermuteA() override { return PermuteA; }
    bool GetPermuteB() override { return PermuteB; }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          KBatch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemm_Wmma_CShuffleV3"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: "
            << MPerWmma << "x" << NPerWmma << ", "
            << "WaveMap: "
            << MRepeat << "x" << NRepeat << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
            << "KPack: "
            << GridwiseGemm::KPack;
        // clang-format on

        return str.str();
    }

    REGISTER_EXTRA_PRINTING_METHODS
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
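
A minimal usage sketch follows. It assumes `DeviceOp` is a fully specialized DeviceGemm_Wmma_CShuffleV3 (for example one of the instances shipped with Composable Kernel) whose elementwise operations are PassThrough; the helper name run_gemm_splitk and its parameters are illustrative only and not part of this header.

#include <memory>
#include <stdexcept>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

// DeviceOp: an assumed, fully specialized DeviceGemm_Wmma_CShuffleV3 instance using
// PassThrough elementwise operations. p_a/p_b/p_c are device buffers already sized
// for the given problem.
template <typename DeviceOp>
float run_gemm_splitk(const void* p_a, const void* p_b, void* p_c,
                      ck::index_t M, ck::index_t N, ck::index_t K,
                      ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                      ck::index_t KBatch)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    DeviceOp device_op;

    // Type-erased argument; KBatch > 1 selects the split-K path.
    auto argument = device_op.MakeArgumentPointer(p_a, p_b, p_c,
                                                  M, N, K,
                                                  StrideA, StrideB, StrideC,
                                                  KBatch,
                                                  PassThrough{}, PassThrough{}, PassThrough{});

    if(!device_op.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error(device_op.GetTypeString() + " does not support this problem");
    }

    // Launch on the default stream and time the kernel (StreamConfig{stream, time_kernel}).
    auto invoker = device_op.MakeInvokerPointer();
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}

Split-K is requested by passing KBatch > 1; as IsSupportedArgument above shows, that path relies on atomic C updates and is rejected for f16/bf16 outputs on gfx11.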