Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp Source File

device_grouped_gemm_xdl.hpp
1 #pragma once
2 // SPDX-License-Identifier: MIT
3 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
4 
5 #pragma once
6 
7 #include <iostream>
8 #include <sstream>
9 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
25 template <typename GridwiseGemm,
26  typename GemmDesc,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CDEElementwiseOperation,
30  bool HasMainKBlockLoop>
31 __global__ void
32 #if CK_USE_LAUNCH_BOUNDS
33  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
34 #endif
35  kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
36  const index_t group_count,
37  const AElementwiseOperation a_element_op,
38  const BElementwiseOperation b_element_op,
39  const CDEElementwiseOperation c_element_op)
40 {
41 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
42  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
43 
44  const index_t block_id = get_block_1d_id();
45 
46  const auto gemm_desc_ptr =
47  reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
48 
49  index_t left = 0;
50  index_t right = group_count;
51  index_t group_id = index_t((left + right) / 2);
52  while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
53  block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
54  left <= right)
55  {
56  if(block_id < gemm_desc_ptr[group_id].BlockStart_)
57  {
58  right = group_id;
59  }
60  else
61  {
62  left = group_id;
63  }
64  group_id = index_t((left + right) / 2);
65  }
66 
67  GridwiseGemm::template Run<HasMainKBlockLoop>(
68  gemm_desc_ptr[group_id].a_ptr_,
69  gemm_desc_ptr[group_id].b_ptr_,
70  gemm_desc_ptr[group_id].ds_ptr_,
71  gemm_desc_ptr[group_id].e_ptr_,
72  p_shared,
73  a_element_op,
74  b_element_op,
75  c_element_op,
76  gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
77  gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
78  gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
79  gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
80  gemm_desc_ptr[group_id].block_2_etile_map_);
81 #else
82  ignore = gemm_descs_const;
83  ignore = group_count;
84  ignore = a_element_op;
85  ignore = b_element_op;
86  ignore = c_element_op;
87 #endif
88 }
89 
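The kernel above locates the GEMM group a workgroup belongs to by binary-searching the half-open [BlockStart_, BlockEnd_) block ranges recorded in the device-side descriptor array. The stand-alone host sketch below reproduces that lookup; the GroupRange struct and the sample ranges are illustrative only and not part of the library.

// Host-side sketch of the block-id-to-group lookup used by kernel_grouped_gemm_xdl.
// GroupRange and the sample ranges are hypothetical; the kernel performs the same
// search over the BlockStart_/BlockEnd_ fields of the per-group kernel arguments.
#include <cassert>
#include <vector>

struct GroupRange
{
    int BlockStart_; // first workgroup id owned by this group
    int BlockEnd_;   // one past the last workgroup id owned by this group
};

// Returns the index of the group whose [BlockStart_, BlockEnd_) range contains block_id.
inline int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;

    while(!(block_id >= groups[group_id].BlockStart_ && block_id < groups[group_id].BlockEnd_) &&
          left <= right)
    {
        if(block_id < groups[group_id].BlockStart_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = (left + right) / 2;
    }

    return group_id;
}

int main()
{
    // Three groups covering workgroups [0, 4), [4, 10) and [10, 12).
    const std::vector<GroupRange> groups{{0, 4}, {4, 10}, {10, 12}};

    assert(find_group(groups, 0) == 0);
    assert(find_group(groups, 7) == 1);
    assert(find_group(groups, 11) == 2);
    return 0;
}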
90 template <typename ALayout,
91  typename BLayout,
92  typename DsLayout,
93  typename ELayout,
94  typename ADataType,
95  typename BDataType,
96  typename AccDataType,
97  typename CShuffleDataType,
98  typename DsDataType,
99  typename EDataType,
100  typename AElementwiseOperation,
101  typename BElementwiseOperation,
102  typename CDEElementwiseOperation,
103  GemmSpecialization GemmSpec,
104  ck::index_t NumPrefetch,
105  ck::index_t BlockSize,
106  ck::index_t MPerBlock,
107  ck::index_t NPerBlock,
108  ck::index_t KPerBlock,
109  ck::index_t AK1,
110  ck::index_t BK1,
111  ck::index_t MPerXDL,
112  ck::index_t NPerXDL,
113  ck::index_t MXdlPerWave,
114  ck::index_t NXdlPerWave,
115  typename ABlockTransferThreadClusterLengths_K0_M_K1,
116  typename ABlockTransferThreadClusterArrangeOrder,
117  typename ABlockTransferSrcAccessOrder,
118  ck::index_t ABlockTransferSrcVectorDim,
119  ck::index_t ABlockTransferSrcScalarPerVector,
120  ck::index_t ABlockTransferDstScalarPerVector_K1,
121  bool ABlockLdsExtraM,
122  typename BBlockTransferThreadClusterLengths_K0_N_K1,
123  typename BBlockTransferThreadClusterArrangeOrder,
124  typename BBlockTransferSrcAccessOrder,
125  ck::index_t BBlockTransferSrcVectorDim,
126  ck::index_t BBlockTransferSrcScalarPerVector,
127  ck::index_t BBlockTransferDstScalarPerVector_K1,
128  bool BBlockLdsExtraN,
129  index_t CShuffleMXdlPerWavePerShuffle,
130  index_t CShuffleNXdlPerWavePerShuffle,
131  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
132  index_t CDEBlockTransferScalarPerVector_NPerBlock,
133  LoopScheduler LoopSched = make_default_loop_scheduler()>
134 struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
135  BLayout,
136  DsLayout,
137  ELayout,
138  ADataType,
139  BDataType,
140  DsDataType,
141  EDataType,
142  AElementwiseOperation,
143  BElementwiseOperation,
144  CDEElementwiseOperation>
145 {
146  using DeviceOp = DeviceGroupedGemm_Xdl;
147 
148  static constexpr index_t NumDTensor = DsDataType::Size();
149 
150  static constexpr auto I0 = Number<0>{};
151  static constexpr auto I1 = Number<1>{};
152  static constexpr auto I2 = Number<2>{};
153 
154  static constexpr auto matrix_padder =
155  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
156 
157  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
158  {
159  const auto a_grid_desc_mraw_kraw = [&]() {
160  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
161  {
162  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
163  make_tuple(StrideA, I1));
164  }
165  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
166  {
167  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
168  make_tuple(I1, StrideA));
169  }
170  }();
171 
172  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
173  }
174 
175  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
176  {
177  const auto b_grid_desc_nraw_kraw = [&]() {
178  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
179  {
180  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
181  make_tuple(I1, StrideB));
182  }
183  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
184  {
185  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
186  make_tuple(StrideB, I1));
187  }
188  }();
189 
190  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
191  }
192 
193  template <typename ELay>
194  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
195  {
196  const auto e_grid_desc_mraw_nraw = [&]() {
197  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELay>)
198  {
199  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
200  make_tuple(StrideE, I1));
201  }
202  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELay>)
203  {
204  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
205  make_tuple(I1, StrideE));
206  }
207  }();
208 
209  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
210  }
211 
212  static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
213  const std::array<index_t, NumDTensor>& NRaws,
214  const std::array<index_t, NumDTensor>& DsStride)
215  {
216  return generate_tuple(
217  [&](auto i) {
218  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
219 
220  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
221  },
222  Number<NumDTensor>{});
223  }
224 
225  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
226  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
227  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
228  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
229 
230  using ComputeDataType = ADataType;
231 
232  // GridwiseGemm
233  using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
234  ADataType, // TODO: distinguish A/B datatype
235  BDataType,
236  ComputeDataType,
237  AccDataType,
238  CShuffleDataType,
239  DsDataType,
240  EDataType,
241  AElementwiseOperation,
242  BElementwiseOperation,
243  CDEElementwiseOperation,
244  InMemoryDataOperationEnum::Set,
245  NumPrefetch, // NumGemmKPrefetchStage
246  BlockSize,
247  MPerBlock,
248  NPerBlock,
249  KPerBlock,
250  AK1,
251  BK1,
252  MPerXDL,
253  NPerXDL,
254  MXdlPerWave,
255  NXdlPerWave,
256  ABlockTransferThreadClusterLengths_K0_M_K1,
257  ABlockTransferThreadClusterArrangeOrder,
258  ABlockTransferSrcAccessOrder,
259  ABlockTransferSrcVectorDim,
260  ABlockTransferSrcScalarPerVector,
261  ABlockTransferDstScalarPerVector_K1,
262  false, // AThreadTransferSrcResetCoordinateAfterRun,
263  ABlockLdsExtraM,
264  BBlockTransferThreadClusterLengths_K0_N_K1,
265  BBlockTransferThreadClusterArrangeOrder,
266  BBlockTransferSrcAccessOrder,
267  BBlockTransferSrcVectorDim,
268  BBlockTransferSrcScalarPerVector,
269  BBlockTransferDstScalarPerVector_K1,
270  false, // BThreadTransferSrcResetCoordinateAfterRun,
271  BBlockLdsExtraN,
272  CShuffleMXdlPerWavePerShuffle,
273  CShuffleNXdlPerWavePerShuffle,
274  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
275  CDEBlockTransferScalarPerVector_NPerBlock,
276  LoopSched>;
277 
278  using AGridDesc_AK0_M_AK1 =
279  remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
280  AGridDesc_M_K{}))>;
281  using BGridDesc_BK0_N_BK1 =
282  remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
283  BGridDesc_N_K{}))>;
284  using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
285  remove_cvref_t<decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
286  DsGridDesc_M_N{}))>;
287  using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
288  remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
289  EGridDesc_M_N{}))>;
290 
291  struct GroupedGemmBlock2ETileMap
292  {
293  using Block2ETileMap =
294  remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
295 
296  GroupedGemmBlock2ETileMap()
297  {
298  block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
299  BlockStart_ = -1;
300  }
301 
302  GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
303  {
304  block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
305  BlockStart_ = BlockStart;
306  }
307 
308  template <typename TopIdx>
309  __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
310  {
311  return block_2_etile_map_.CalculateBottomIndex(
312  make_multi_index(idx_top[I0] - BlockStart_));
313  }
314 
315  // it's actually E-Tile
316  template <typename CTileIdx, typename CTileDim>
317  __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
318  const CTileDim& c_tile_dim) const
319  {
320  return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
321  }
322 
323  __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
324  {
325  return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
326  }
327 
328  Block2ETileMap block_2_etile_map_;
329  ck::index_t BlockStart_;
330  };
331 
332  struct GemmBiasTransKernelArg
333  {
334  // pointers
335  const ADataType* a_ptr_;
336  const BDataType* b_ptr_;
337  typename GridwiseGemm::DsGridPointer ds_ptr_;
338  EDataType* e_ptr_;
339 
340  // tensor descriptors for problem definition
341  AGridDesc_M_K a_grid_desc_m_k_;
342  BGridDesc_N_K b_grid_desc_n_k_;
343  DsGridDesc_M_N ds_grid_desc_m_n_;
344  EGridDesc_M_N e_grid_desc_m_n_;
345 
346  // tensor descriptors for block/thread-wise copy
347  AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
348  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
349  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
350  ds_grid_desc_mblock_mperblock_nblock_nperblock_;
351  EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
352 
353  // block-to-e-tile map
354  GroupedGemmBlock2ETileMap block_2_etile_map_;
355  ck::index_t BlockStart_, BlockEnd_;
356  };
357 
358  // Argument
359  struct Argument : public BaseArgument
360  {
361  Argument(std::vector<const void*>& p_As,
362  std::vector<const void*>& p_Bs,
363  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
364  std::vector<void*>& p_Es,
365  std::vector<GemmDesc>& gemm_descs,
366  AElementwiseOperation a_element_op,
367  BElementwiseOperation b_element_op,
368  CDEElementwiseOperation c_element_op)
369  : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
370  {
371  grid_size_ = 0;
372 
373  group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
374 
375  if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
376  group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
377  group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
378  {
379  throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
380  }
381 
383 
385 
386  for(std::size_t i = 0; i < gemm_descs.size(); i++)
387  {
388  const index_t M = gemm_descs[i].M_;
389  const index_t N = gemm_descs[i].N_;
390  const index_t K = gemm_descs[i].K_;
391 
392  a_mtx_mraw_kraw_.emplace_back(M, K);
393  b_mtx_nraw_kraw_.emplace_back(N, K);
394 
395  if(M == 0)
396  {
397  skipped_group_count_++;
398  continue;
399  }
400 
401  const index_t StrideA = gemm_descs[i].stride_A_;
402  const index_t StrideB = gemm_descs[i].stride_B_;
403  const index_t StrideC = gemm_descs[i].stride_C_;
404 
405  // pointer
406  typename GridwiseGemm::DsGridPointer p_ds_grid{};
407 
408  static_for<0, NumDTensor, 1>{}([&](auto j) {
409  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
410 
411  p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
412  });
413 
414  // tensor descriptors for problem definition
415  const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
416  const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
417 
418  DsGridDesc_M_N ds_grid_desc_m_n;
419 
420  static_for<0, NumDTensor, 1>{}([&](auto j) {
421  using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
422 
423  ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
424  M, N, gemm_descs[i].stride_Ds_[j]);
425  });
426 
427  const auto e_grid_desc_m_n =
428  DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
429 
430  // tensor descriptors for block/thread-wise copy
431  const auto a_grid_desc_ak0_m_ak1 =
432  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
433 
434  const auto b_grid_desc_bk0_n_bk1 =
435  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
436 
437  const index_t grid_size_grp =
438  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
439  .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
440 
441  const index_t BlockStart = grid_size_;
442  const index_t BlockEnd = grid_size_ + grid_size_grp;
443 
444  grid_size_ += grid_size_grp;
445 
446  // block-to-e-tile map
447  const auto block_2_etile_map =
448  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
449 
450  if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
451  b_grid_desc_n_k,
452  ds_grid_desc_m_n,
453  e_grid_desc_m_n,
454  block_2_etile_map))
455  {
456  // tensor descriptors for block/thread-wise copy
457  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
458  ds_grid_desc_mblock_mperblock_nblock_nperblock;
459 
460  static_for<0, NumDTensor, 1>{}([&](auto j) {
461  ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
462  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
463  ds_grid_desc_m_n[j]);
464  });
465 
466  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
467  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
468  e_grid_desc_m_n);
469 
470  gemm_desc_kernel_arg_.push_back(
471  GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
472  static_cast<const BDataType*>(p_Bs[i]),
473  p_ds_grid,
474  static_cast<EDataType*>(p_Es[i]),
475  a_grid_desc_m_k,
476  b_grid_desc_n_k,
477  ds_grid_desc_m_n,
478  e_grid_desc_m_n,
479  a_grid_desc_ak0_m_ak1,
480  b_grid_desc_bk0_n_bk1,
481  ds_grid_desc_mblock_mperblock_nblock_nperblock,
482  e_grid_desc_mblock_mperblock_nblock_nperblock,
483  block_2_etile_map,
484  BlockStart,
485  BlockEnd});
486  }
487  }
488  }
489 
490  // private:
491  index_t group_count_;
492  index_t skipped_group_count_;
493 
494  AElementwiseOperation a_element_op_;
495  BElementwiseOperation b_element_op_;
496  CDEElementwiseOperation c_element_op_;
497 
498  std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
499  std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
500  std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
501 
502  index_t grid_size_;
503  };
504 
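Each group's grid contribution is accumulated into grid_size_ while its [BlockStart, BlockEnd) range is recorded in the per-group kernel argument, so a single launch of dim3(grid_size_) workgroups covers all groups back to back. A minimal host sketch of that bookkeeping follows; Range and grid_sizes are illustrative names, not library types.

// Sketch of how Argument assigns each group a contiguous block range.
// "grid_sizes" stands for the per-group values returned by
// block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n).
#include <cstdio>
#include <vector>

struct Range
{
    int BlockStart_;
    int BlockEnd_;
};

int main()
{
    const std::vector<int> grid_sizes{64, 16, 128}; // e-tile counts per group

    std::vector<Range> ranges;
    int grid_size = 0;

    for(int grid_size_grp : grid_sizes)
    {
        const int block_start = grid_size;
        const int block_end   = grid_size + grid_size_grp;

        grid_size += grid_size_grp;

        ranges.push_back({block_start, block_end});
    }

    // The grouped kernel is launched once with dim3(grid_size) workgroups.
    std::printf("total grid size: %d\n", grid_size);
    for(std::size_t i = 0; i < ranges.size(); ++i)
    {
        std::printf("group %zu -> blocks [%d, %d)\n", i, ranges[i].BlockStart_, ranges[i].BlockEnd_);
    }
    return 0;
}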
505  // Invoker
506  struct Invoker : public BaseInvoker
507  {
508  using Argument = DeviceOp::Argument;
509 
510  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
511  {
512  bool has_main_k_block_loop = true;
513 
514  for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
515  {
516  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
517  {
518  std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
519  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
520  << ", "
521  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
522  << ", "
523  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
524  << "}";
525 
526  std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
527  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
528  << ", "
529  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
530  << ", "
531  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
532  << "}";
533 
534  std::cout << ", arg.e_grid_desc_m_n_{ "
535  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
536  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
537  << std::endl;
538  }
539 
540  if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
541  arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
542  arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
543  arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
544  arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
545  {
546  throw std::runtime_error(
547  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
548  }
549 
550  const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
551  arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
552 
553  if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
554  {
555  throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
556  }
557  }
558 
559  hipGetErrorString(
560  hipMemcpyAsync(arg.p_workspace_,
561  arg.gemm_desc_kernel_arg_.data(),
562  arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
563  hipMemcpyHostToDevice,
564  stream_config.stream_id_));
565 
566  float ave_time = 0;
567 
568  auto launch_kernel = [&](auto has_main_k_block_loop_) {
569  const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
570  GemmBiasTransKernelArg,
571  AElementwiseOperation,
572  BElementwiseOperation,
573  CDEElementwiseOperation,
574  has_main_k_block_loop_>;
575 
576  return launch_and_time_kernel(
577  stream_config,
578  kernel,
579  dim3(arg.grid_size_),
580  dim3(BlockSize),
581  0,
582  cast_pointer_to_constant_address_space(arg.p_workspace_),
583  arg.gemm_desc_kernel_arg_.size(),
584  arg.a_element_op_,
585  arg.b_element_op_,
586  arg.c_element_op_);
587  };
588 
589  if(has_main_k_block_loop)
590  {
591  ave_time = launch_kernel(integral_constant<bool, true>{});
592  }
593  else
594  {
595  ave_time = launch_kernel(integral_constant<bool, false>{});
596  }
597 
598  return ave_time;
599  }
600 
601  // polymorphic
602  float Run(const BaseArgument* p_arg,
603  const StreamConfig& stream_config = StreamConfig{}) override
604  {
605  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
606  }
607  };
608 
609  static bool IsSupportedArgument(const Argument& arg)
610  {
611  if(!ck::is_xdl_supported())
612  {
613  return false;
614  }
615 
616  if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
617  arg.skipped_group_count_) != arg.group_count_)
618  {
619  return false;
620  }
621 
622  bool supported = true;
623 
624  // If we use padding we do not support vector loads for dimensions not divisible by vector
625  // load size.
626  if constexpr(GemmSpec != GemmSpecialization::Default)
627  {
628  // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
629  // thus we have to adapt it to the {M,K} or {N,K} layout.
630  const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
631  const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
632 
633  for(index_t i = 0; i < arg.group_count_; ++i)
634  {
635  const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
636  const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
637 
638  supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
639  supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
640  }
641  }
642 
643  return supported;
644  }
645 
646  // polymorphic
647  bool IsSupportedArgument(const BaseArgument* p_arg) override
648  {
649  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
650  }
651 
652  static auto MakeArgument(std::vector<const void*>& p_As,
653  std::vector<const void*>& p_Bs,
654  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
655  std::vector<void*>& p_Es,
656  std::vector<GemmDesc> gemm_descs,
657  AElementwiseOperation a_element_op,
658  BElementwiseOperation b_element_op,
659  CDEElementwiseOperation c_element_op)
660  {
661  return Argument{
662  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
663  }
664 
665  static auto MakeInvoker() { return Invoker{}; }
666 
667  // polymorphic
668  std::unique_ptr<BaseArgument>
669  MakeArgumentPointer(std::vector<const void*>& p_As,
670  std::vector<const void*>& p_Bs,
671  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
672  std::vector<void*>& p_Es,
673  std::vector<GemmDesc>& gemm_descs,
674  AElementwiseOperation a_element_op,
675  BElementwiseOperation b_element_op,
676  CDEElementwiseOperation c_element_op) override
677  {
678  return std::make_unique<Argument>(
679  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
680  }
681 
682  // polymorphic
683  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
684  {
685  return std::make_unique<Invoker>(Invoker{});
686  }
687 
688  // polymorphic
689  std::string GetTypeString() const override
690  {
691  auto str = std::stringstream();
692 
693  // clang-format off
694  str << "DeviceGroupedGemm_Xdl"
695  << "<"
696  << BlockSize << ", "
697  << MPerBlock << ", "
698  << NPerBlock << ", "
699  << KPerBlock << ", "
700  << AK1 << ", "
701  << BK1 << ", "
702  << MPerXDL << ", "
703  << NPerXDL << ", "
704  << MXdlPerWave << ", "
705  << NXdlPerWave << ", "
706  << ABlockTransferSrcScalarPerVector << ", "
707  << BBlockTransferSrcScalarPerVector << ", "
708  << CShuffleMXdlPerWavePerShuffle << ", "
709  << CShuffleNXdlPerWavePerShuffle << ", "
710  << getGemmSpecializationString(GemmSpec)
711  << ">";
712  // clang-format on
713 
714  return str.str();
715  }
716 
717  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
718  {
719  auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
720  if(p_arg_)
721  {
722  return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
723  }
724  else
725  throw std::runtime_error("The argument pointer is not an object of "
726  "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
727  }
728 
729  size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
730  {
731  return GetWorkSpaceSize(p_arg);
732  }
733 
734  void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
735  {
736  return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
737  }
738 };
739 
740 } // namespace device
741 } // namespace tensor_operation
742 } // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition: ck.hpp:26
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
#define CK_ENV(name)
Definition: env.hpp:128
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
__global__ void kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:35
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
bool is_xdl_supported()
Definition: device_prop.hpp:54
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:22
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:289
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:221
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:396
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:204
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:403
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:254
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, [[maybe_unused]] const Block2ETileMap &)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:329
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:187
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:242
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: device_base.hpp:50
void * p_workspace_
Definition: device_base.hpp:57
Definition: device_base.hpp:61
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:101
Definition: device_grouped_gemm_xdl.hpp:360
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition: device_grouped_gemm_xdl.hpp:498
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:499
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:361
AElementwiseOperation a_element_op_
Definition: device_grouped_gemm_xdl.hpp:494
CDEElementwiseOperation c_element_op_
Definition: device_grouped_gemm_xdl.hpp:496
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:500
index_t skipped_group_count_
Definition: device_grouped_gemm_xdl.hpp:492
index_t grid_size_
Definition: device_grouped_gemm_xdl.hpp:502
index_t group_count_
Definition: device_grouped_gemm_xdl.hpp:491
BElementwiseOperation b_element_op_
Definition: device_grouped_gemm_xdl.hpp:495
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:344
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:343
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_grouped_gemm_xdl.hpp:348
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:351
const ADataType * a_ptr_
Definition: device_grouped_gemm_xdl.hpp:335
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_grouped_gemm_xdl.hpp:342
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_grouped_gemm_xdl.hpp:341
GroupedGemmBlock2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:354
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_grouped_gemm_xdl.hpp:347
EDataType * e_ptr_
Definition: device_grouped_gemm_xdl.hpp:338
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:350
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:355
GridwiseGemm::DsGridPointer ds_ptr_
Definition: device_grouped_gemm_xdl.hpp:337
ck::index_t BlockEnd_
Definition: device_grouped_gemm_xdl.hpp:355
const BDataType * b_ptr_
Definition: device_grouped_gemm_xdl.hpp:336
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: device_grouped_gemm_xdl.hpp:317
GroupedGemmBlock2ETileMap()
Definition: device_grouped_gemm_xdl.hpp:296
Block2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:328
GroupedGemmBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n, ck::index_t BlockStart)
Definition: device_grouped_gemm_xdl.hpp:302
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:329
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: device_grouped_gemm_xdl.hpp:309
__host__ bool CheckValidity(const EGridDesc_M_N &e_grid_desc_m_n) const
Definition: device_grouped_gemm_xdl.hpp:323
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition: device_grouped_gemm_xdl.hpp:294
Definition: device_grouped_gemm_xdl.hpp:507
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_grouped_gemm_xdl.hpp:510
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_gemm_xdl.hpp:602
Definition: device_grouped_gemm_xdl.hpp:145
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_grouped_gemm_xdl.hpp:283
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_grouped_gemm_xdl.hpp:194
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_grouped_gemm_xdl.hpp:226
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:652
std::string GetTypeString() const override
Definition: device_grouped_gemm_xdl.hpp:689
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_grouped_gemm_xdl.hpp:212
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm_xdl.hpp:734
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_grouped_gemm_xdl.hpp:225
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_grouped_gemm_xdl.hpp:175
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) override
Definition: device_grouped_gemm_xdl.hpp:669
ADataType ComputeDataType
Definition: device_grouped_gemm_xdl.hpp:230
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:286
static constexpr auto I1
Definition: device_grouped_gemm_xdl.hpp:151
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_grouped_gemm_xdl.hpp:280
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:289
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_gemm_xdl.hpp:647
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition: device_grouped_gemm_xdl.hpp:729
static constexpr auto I2
Definition: device_grouped_gemm_xdl.hpp:152
static constexpr auto matrix_padder
Definition: device_grouped_gemm_xdl.hpp:154
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:227
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_gemm_xdl.hpp:609
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_gemm_xdl.hpp:683
static auto MakeInvoker()
Definition: device_grouped_gemm_xdl.hpp:665
static constexpr index_t NumDTensor
Definition: device_grouped_gemm_xdl.hpp:148
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:228
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemm
Definition: device_grouped_gemm_xdl.hpp:276
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_grouped_gemm_xdl.hpp:157
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_grouped_gemm_xdl.hpp:717
static constexpr auto I0
Definition: device_grouped_gemm_xdl.hpp:150
Definition: device_grouped_gemm.hpp:105
Definition: device_grouped_gemm.hpp:86
Definition: matrix_padder.hpp:180
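For reference, the device op is driven through the MakeArgument/MakeInvoker interface shown in the listing. The sketch below is a minimal host-side driver, assuming a concrete DeviceGroupedGemm_Xdl specialization is supplied as DeviceOpT and that the pointer vectors and GEMM descriptors are already populated; run_grouped_gemm, DeviceOpT and GemmDescT are illustrative names, not part of the library, and HIP error checking is omitted.

// Hedged usage sketch for a DeviceGroupedGemm_Xdl specialization.
#include <array>
#include <vector>

#include <hip/hip_runtime.h>

#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"

template <typename DeviceOpT, typename GemmDescT, typename AOp, typename BOp, typename CDEOp>
float run_grouped_gemm(std::vector<const void*>& p_As,
                       std::vector<const void*>& p_Bs,
                       std::vector<std::array<const void*, DeviceOpT::NumDTensor>>& p_Ds,
                       std::vector<void*>& p_Es,
                       std::vector<GemmDescT>& gemm_descs,
                       AOp a_op,
                       BOp b_op,
                       CDEOp cde_op)
{
    DeviceOpT device_op;

    auto argument =
        device_op.MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_op, b_op, cde_op);

    if(!DeviceOpT::IsSupportedArgument(argument))
    {
        return -1.f; // caller should fall back to another instance
    }

    // Invoker::Run copies the per-group kernel arguments into device memory, so a
    // workspace of GetDeviceKernelArgSize() bytes has to be attached beforehand.
    void* p_dev_kernel_args = nullptr;
    (void)hipMalloc(&p_dev_kernel_args, device_op.GetDeviceKernelArgSize(&argument));
    device_op.SetDeviceKernelArgs(&argument, p_dev_kernel_args);

    auto invoker = DeviceOpT::MakeInvoker();
    return invoker.Run(argument, StreamConfig{nullptr, true});
}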