/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp Source File
device_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
// DeviceGemm_Wmma_CShuffleV3: device-level GEMM operation implementing the
// DeviceGemmV2 interface on top of GridwiseGemm_wmma_cshuffle_v3. The symbol
// index of this page describes it as a "Universal" GEMM operation with
// SplitK support.
// NOTE(review): this rendered listing omits some source lines, e.g. listing
// lines 164-165 between CShuffleBlockTransferScalarPerVector_NPerBlock and
// ComputeTypeA. BlkGemmPipeSched and BlkGemmPipelineVer are used below but
// are not declared in the visible lines; presumably they are the elided
// template parameters (of type BlockGemmPipelineScheduler /
// BlockGemmPipelineVersion) - confirm against the header.
124 template <typename ALayout,
125  typename BLayout,
126  typename CLayout,
127  typename ADataType,
128  typename BDataType,
129  typename CDataType,
130  typename AccDataType,
131  typename CShuffleDataType,
132  typename AElementwiseOperation,
133  typename BElementwiseOperation,
134  typename CElementwiseOperation,
135  GemmSpecialization GemmSpec,
136  index_t BlockSize,
137  index_t MPerBlock,
138  index_t NPerBlock,
139  index_t KPerBlock,
140  index_t AK1,
141  index_t BK1,
142  index_t MPerWmma,
143  index_t NPerWmma,
144  index_t MRepeat,
145  index_t NRepeat,
146  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
147  typename ABlockTransferThreadClusterArrangeOrder,
148  typename ABlockTransferSrcAccessOrder,
149  index_t ABlockTransferSrcVectorDim,
150  index_t ABlockTransferSrcScalarPerVector,
151  index_t ABlockTransferDstScalarPerVector_AK1,
152  bool ABlockLdsExtraM,
153  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154  typename BBlockTransferThreadClusterArrangeOrder,
155  typename BBlockTransferSrcAccessOrder,
156  index_t BBlockTransferSrcVectorDim,
157  index_t BBlockTransferSrcScalarPerVector,
158  index_t BBlockTransferDstScalarPerVector_BK1,
159  bool BBlockLdsExtraN,
160  index_t CShuffleMRepeatPerShuffle,
161  index_t CShuffleNRepeatPerShuffle,
162  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
163  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
166  typename ComputeTypeA = CDataType,
167  typename ComputeTypeB = ComputeTypeA,
168  bool PermuteA = false,
169  bool PermuteB = false>
170 struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
171  BLayout,
172  CLayout,
173  ADataType,
174  BDataType,
175  CDataType,
176  AElementwiseOperation,
177  BElementwiseOperation,
178  CElementwiseOperation>
179 {
// GridwiseGemm: per the symbol index, an alias for
// GridwiseGemm_wmma_cshuffle_v3<...> instantiated with the arguments below
// (the alias head, listing line 180, is not shown). DsLayout and DsDataType
// are Tuple<>, i.e. no auxiliary D tensors; MakeArgument therefore passes
// empty D pointer/stride arrays.
181  ALayout,
182  BLayout,
183  Tuple<>, // DsLayout
184  CLayout,
185  ADataType,
186  BDataType,
187  AccDataType,
188  CShuffleDataType,
189  Tuple<>, // DsDataType
190  CDataType,
191  AElementwiseOperation,
192  BElementwiseOperation,
193  CElementwiseOperation,
194  GemmSpec,
195  BlockSize,
196  MPerBlock,
197  NPerBlock,
198  KPerBlock,
199  AK1,
200  BK1,
201  MPerWmma,
202  NPerWmma,
203  MRepeat,
204  NRepeat,
205  ABlockTransferThreadClusterLengths_AK0_M_AK1,
206  ABlockTransferThreadClusterArrangeOrder,
207  ABlockTransferSrcAccessOrder,
208  ABlockTransferSrcVectorDim,
209  ABlockTransferSrcScalarPerVector,
210  ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): the meaning of this literal `false` (and the matching one
// on the B side) is defined by GridwiseGemm_wmma_cshuffle_v3, not here.
211  false,
212  ABlockLdsExtraM,
213  BBlockTransferThreadClusterLengths_BK0_N_BK1,
214  BBlockTransferThreadClusterArrangeOrder,
215  BBlockTransferSrcAccessOrder,
216  BBlockTransferSrcVectorDim,
217  BBlockTransferSrcScalarPerVector,
218  BBlockTransferDstScalarPerVector_BK1,
219  false,
220  BBlockLdsExtraN,
221  CShuffleMRepeatPerShuffle,
222  CShuffleNRepeatPerShuffle,
223  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
// Listing line 224 is not shown; the symbol index gives this argument as
// Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>.
225  BlkGemmPipeSched,
226  BlkGemmPipelineVer,
227  ComputeTypeA,
228  ComputeTypeB,
229  PermuteA,
230  PermuteB>;
231 
// Argument: per the symbol index, `typename GridwiseGemm::Argument`
// (listing line 232, not shown).
233 
// The head of the next alias (listing lines 234-235) is not shown. Since
// Invoker below is `typename DeviceGemmCommon::Invoker`, this is presumably
// the DeviceGemmCommon alias built from these parameters - confirm.
// Listing line 247 (between GemmSpec and BlkGemmPipeSched) is also elided.
236  ADataType,
237  BDataType,
238  Tuple<>,
239  CDataType,
240  MPerBlock,
241  NPerBlock,
242  KPerBlock,
243  BlockSize,
244  AK1,
245  BK1,
246  GemmSpec,
248  BlkGemmPipeSched,
249  BlkGemmPipelineVer,
250  ComputeTypeA,
251  ComputeTypeB>;
252 
253  // Invoker
// Per the symbol index: `typename DeviceGemmCommon::Invoker`
// (listing line 254, not shown).
255 
// Checks whether this instance can run the given typed Argument.
// Its body (listing line 258) is not shown; the symbol index lists a static
// DeviceGemmCommon::IsSupportedArgument(const Argument&), which it
// presumably forwards to - confirm.
256  static bool IsSupportedArgument(const Argument& arg)
257  {
259  }
260 
261  // polymorphic
// Type-erased overload required by the base interface: downcasts p_arg and
// delegates to the static overload above.
// NOTE(review): dynamic_cast yields nullptr when p_arg is null or does not
// point to an Argument, and the result is dereferenced unconditionally,
// which is undefined behavior. Returning false on a failed cast would be
// safer.
262  bool IsSupportedArgument(const BaseArgument* p_arg) override
263  {
264  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
265  }
266 
// Expose compile-time template parameters through the virtual interface.
267  index_t GetKPerBlock() override { return KPerBlock; }
268 
269  bool GetPermuteA() override { return PermuteA; }
270  bool GetPermuteB() override { return PermuteB; }
271 
// Builds an Argument by value from typed device pointers, problem sizes,
// strides and element-wise operators. D-tensor pointers and strides are
// empty std::arrays because DsDataType is Tuple<>; the C operator is passed
// in the CDE position.
// KBatch is forwarded unchanged (presumably the split-K factor - confirm
// against GridwiseGemm::Argument).
272  static auto MakeArgument(const ADataType* p_a,
273  const BDataType* p_b,
274  CDataType* p_c,
275  index_t M,
276  index_t N,
277  index_t K,
278  index_t StrideA,
279  index_t StrideB,
280  index_t StrideC,
281  index_t KBatch,
282  AElementwiseOperation a_element_op,
283  BElementwiseOperation b_element_op,
284  CElementwiseOperation cde_element_op)
285  {
286  return Argument{p_a,
287  p_b,
288  std::array<const void*, 0>{}, // p_ds_grid_
289  p_c,
290  M,
291  N,
292  K,
293  StrideA,
294  StrideB,
295  std::array<index_t, 0>{}, // StrideDs_
296  StrideC,
297  KBatch,
298  a_element_op,
299  b_element_op,
300  cde_element_op};
301  }
302 
// Returns a default-constructed Invoker.
303  static auto MakeInvoker() { return Invoker{}; }
304 
305  // polymorphic
// Type-erased factory: static_casts the void pointers to ADataType /
// BDataType / CDataType and heap-allocates the same Argument layout that
// MakeArgument builds. No runtime type check is possible here, so callers
// must pass buffers of the matching element types.
306  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
307  const void* p_b,
308  void* p_c,
309  index_t M,
310  index_t N,
311  index_t K,
312  index_t StrideA,
313  index_t StrideB,
314  index_t StrideC,
315  index_t KBatch,
316  AElementwiseOperation a_element_op,
317  BElementwiseOperation b_element_op,
318  CElementwiseOperation c_element_op) override
319  {
320  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
321  static_cast<const BDataType*>(p_b),
322  std::array<const void*, 0>{}, // p_ds_grid_
323  static_cast<CDataType*>(p_c),
324  M,
325  N,
326  K,
327  StrideA,
328  StrideB,
329  std::array<index_t, 0>{}, // StrideDs_
330  StrideC,
331  KBatch,
332  a_element_op,
333  b_element_op,
334  c_element_op);
335  }
336 
337  // polymorphic
// Heap-allocated, type-erased counterpart of MakeInvoker().
338  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
339  {
340  return std::make_unique<Invoker>(Invoker{});
341  }
342 
343  // polymorphic
// Returns a one-line, human-readable summary of this instance's
// compile-time configuration: specialization, the first character of each
// layout name, block/wave tile sizes, read vector widths, pipeline
// scheduler/version, prefetch stages and KPack.
// NOTE(review): the map initializer lists (listing lines 349-350, 353-357)
// and the final KPack operand (listing line 384) are not shown; the symbol
// index lists KPack as a `static constexpr index_t` in the gridwise common
// header. std::map::operator[] default-inserts an empty string for an enum
// value with no entry, and both maps are rebuilt on every call; a static
// const table queried with find()/at() would avoid both.
344  std::string GetTypeString() const override
345  {
346  auto str = std::stringstream();
347 
348  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
351 
352  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
358 
359  // clang-format off
360  str << "DeviceGemm_Wmma_CShuffleV3"
361  << "<"
362  << getGemmSpecializationString(GemmSpec) << ", "
363  << std::string(ALayout::name)[0]
364  << std::string(BLayout::name)[0]
365  << std::string(CLayout::name)[0]
366  << ">"
367  << " BlkSize: "
368  << BlockSize << ", "
369  << "BlkTile: "
370  << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
371  << "WaveTile: "
372  << MPerWmma << "x"<<NPerWmma << ", "
373  << "WaveMap: "
374  << MRepeat << "x" << NRepeat << ", "
375  << "VmemReadVec: "
376  << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
377  << "BlkGemmPipelineScheduler: "
378  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
379  << "BlkGemmPipelineVersion: "
380  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
381  << "BlkGemmPipelinePrefetchStages: "
382  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
383  << "KPack: "
385  // clang-format on
386 
387  return str.str();
388  }
// Listing line 389 is not shown; the symbol index lists a
// REGISTER_EXTRA_PRINTING_METHODS macro, presumably expanded here - confirm.
390 };
391 
392 } // namespace device
393 } // namespace tensor_operation
394 } // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:46
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
int32_t index_t
Definition: ck.hpp:298
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
static constexpr index_t KPack
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:230
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: device_base.hpp:51
Helper structure responsible for kernel invocation.
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:56
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:42
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:257
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_wmma_cshuffle_v3.hpp:179
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3.hpp:230
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation cde_element_op)
Definition: device_gemm_wmma_cshuffle_v3.hpp:272
std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3.hpp:344
typename DeviceGemmCommon::Invoker Invoker
Definition: device_gemm_wmma_cshuffle_v3.hpp:254
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3.hpp:303
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:262
bool GetPermuteA() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:269
typename GridwiseGemm::Argument Argument
Definition: device_gemm_wmma_cshuffle_v3.hpp:232
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:338
bool GetPermuteB() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:270
index_t GetKPerBlock() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:267
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:306
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3.hpp:256
Definition: device_gemm_v2.hpp:22