Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp Source File

device_grouped_gemm_xdl_splitk_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
9 #include "ck/ck.hpp"
14 #include "ck/utility/tuple.hpp"
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 template <typename GridwiseGemm,
27  typename GemmDesc,
28  bool HasMainKBlockLoop,
29  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
30  typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
31  typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
32  typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
35  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
36 #endif
37  kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
38  const index_t group_count,
39  const AElementwiseOperation a_element_op,
40  const BElementwiseOperation b_element_op,
41  const CElementwiseOperation c_element_op)
42 {
43 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
44  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
45  __shared__ uint8_t p_shared[shared_size];
46 
47  const index_t block_id = get_block_1d_id();
48  const auto gemm_desc_ptr =
49  reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
50 
51  index_t left = 0;
52  index_t right = group_count;
53  index_t group_id = index_t((left + right) / 2);
54  while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
55  block_id < gemm_desc_ptr[group_id].block_end_)) &&
56  left <= right)
57  {
58  if(block_id < gemm_desc_ptr[group_id].block_start_)
59  {
60  right = group_id;
61  }
62  else
63  {
64  left = group_id;
65  }
66  group_id = index_t((left + right) / 2);
67  }
68 
69  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
70  gemm_desc_ptr[group_id].karg_,
71  static_cast<void*>(p_shared),
72  gemm_desc_ptr[group_id].block_2_ctile_map_,
73  a_element_op,
74  b_element_op,
75  c_element_op);
76 #else
77  ignore = gemm_descs_const;
78  ignore = group_count;
79  ignore = a_element_op;
80  ignore = b_element_op;
81  ignore = c_element_op;
82 #endif // end of if (defined(__gfx9__))
83 }
84 
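// --- Illustration (not part of this header) --------------------------------
// The kernel above resolves which group owns the current workgroup by
// binary-searching the per-group half-open [block_start_, block_end_) ranges.
// Below is a minimal host-side sketch of that same search, assuming only a
// hypothetical GroupRange struct in place of GemmTransKernelArg.

#include <cstdint>
#include <vector>

struct GroupRange // hypothetical stand-in for GemmTransKernelArg's range fields
{
    int32_t block_start_;
    int32_t block_end_;
};

// Same search as the kernel: find the group whose half-open
// [block_start_, block_end_) range contains block_id.
int32_t find_group(const std::vector<GroupRange>& groups, int32_t block_id)
{
    int32_t left     = 0;
    int32_t right    = static_cast<int32_t>(groups.size());
    int32_t group_id = (left + right) / 2;
    while(!(block_id >= groups[group_id].block_start_ &&
            block_id < groups[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < groups[group_id].block_start_)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

// E.g. with ranges {[0,4), [4,10), [10,12)}, block 7 resolves to group 1.
// ---------------------------------------------------------------------------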
85 template <typename ALayout,
86  typename BLayout,
87  typename DsLayout,
88  typename ELayout,
89  typename ADataType,
90  typename BDataType,
91  typename AccDataType,
92  typename CShuffleDataType,
93  typename DsDataType,
94  typename EDataType,
95  typename AElementwiseOperation,
96  typename BElementwiseOperation,
97  typename CDEElementwiseOperation,
98  GemmSpecialization GemmSpec,
99  ck::index_t NumGemmKPrefetchStage,
100  ck::index_t BlockSize,
101  ck::index_t MPerBlock,
102  ck::index_t NPerBlock,
103  ck::index_t KPerBlock,
104  ck::index_t AK1,
105  ck::index_t BK1,
106  ck::index_t MPerXDL,
107  ck::index_t NPerXDL,
108  ck::index_t MXdlPerWave,
109  ck::index_t NXdlPerWave,
110  typename ABlockTransferThreadClusterLengths_K0_M_K1,
111  typename ABlockTransferThreadClusterArrangeOrder,
112  typename ABlockTransferSrcAccessOrder,
113  ck::index_t ABlockTransferSrcVectorDim,
114  ck::index_t ABlockTransferSrcScalarPerVector,
115  ck::index_t ABlockTransferDstScalarPerVector_K1,
116  bool ABlockLdsExtraM,
117  typename BBlockTransferThreadClusterLengths_K0_N_K1,
118  typename BBlockTransferThreadClusterArrangeOrder,
119  typename BBlockTransferSrcAccessOrder,
120  ck::index_t BBlockTransferSrcVectorDim,
121  ck::index_t BBlockTransferSrcScalarPerVector,
122  ck::index_t BBlockTransferDstScalarPerVector_K1,
123  bool BBlockLdsExtraN,
124  index_t CShuffleMXdlPerWavePerShuffle,
125  index_t CShuffleNXdlPerWavePerShuffle,
126  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
127  index_t CDEBlockTransferScalarPerVector_NPerBlock,
128  PipelineVersion PipelineVer = PipelineVersion::v1,
129  LoopScheduler LoopSched = make_default_loop_scheduler(),
130  // Current implementation does not support multiple D fusions.
131  enable_if_t<is_same_v<DsLayout, ck::Tuple<>> &&
132  is_same_v<DsDataType, ck::Tuple<>>,
133  bool> = false>
134 struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
135  BLayout,
136  DsLayout,
137  ELayout,
138  ADataType,
139  BDataType,
140  DsDataType,
141  EDataType,
142  AElementwiseOperation,
143  BElementwiseOperation,
144  CDEElementwiseOperation>
145 {
146  static constexpr index_t NumDTensor = DsDataType::Size();
147 
148  static constexpr auto I0 = Number<0>{};
149  static constexpr auto I1 = Number<1>{};
150  static constexpr auto I2 = Number<2>{};
151  static constexpr auto I3 = Number<3>{};
152  static_assert(KPerBlock % AK1 == 0);
153  static constexpr index_t K0PerBlock = KPerBlock / AK1;
154 
155  using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
156  BlockSize,
157  ADataType,
158  BDataType,
159  AccDataType,
160  EDataType,
161  ALayout,
162  BLayout,
163  ELayout,
164  AElementwiseOperation,
165  BElementwiseOperation,
166  CDEElementwiseOperation,
167  GemmSpec,
168  NumGemmKPrefetchStage,
169  MPerBlock,
170  NPerBlock,
171  K0PerBlock,
172  MPerXDL,
173  NPerXDL,
174  AK1,
175  MXdlPerWave,
176  NXdlPerWave,
177  ABlockTransferThreadClusterLengths_K0_M_K1,
178  ABlockTransferThreadClusterArrangeOrder,
179  ABlockTransferSrcAccessOrder,
180  ABlockTransferSrcVectorDim,
181  ABlockTransferSrcScalarPerVector,
182  ABlockTransferDstScalarPerVector_K1,
183  false, // AThreadTransferSrcResetCoordinateAfterRun,
184  ABlockLdsExtraM,
185  BBlockTransferThreadClusterLengths_K0_N_K1,
186  BBlockTransferThreadClusterArrangeOrder,
187  BBlockTransferSrcAccessOrder,
188  BBlockTransferSrcVectorDim,
189  BBlockTransferSrcScalarPerVector,
190  BBlockTransferDstScalarPerVector_K1,
191  false, // BThreadTransferSrcResetCoordinateAfterRun,
192  BBlockLdsExtraN,
193  CShuffleMXdlPerWavePerShuffle,
194  CShuffleNXdlPerWavePerShuffle,
195  CDEBlockTransferScalarPerVector_NPerBlock,
196  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
197  LoopSched,
198  PipelineVer>;
199 
200  using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
201  using Block2ETileMapKSplit =
202  BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
203  // Block2CTileMap configuration parameter.
204  static constexpr index_t B2E_M01 = 8;
205  using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
206  using KernelArgument = typename GridwiseGemm::Argument;
207  using PassThrough = ck::tensor_operation::element_wise::PassThrough;
208  struct GemmTransKernelArg
209  {
210  KernelArgument karg_;
211  GroupedGemmBlock2ETileMap block_2_ctile_map_;
212  index_t block_start_, block_end_;
213 
214  GemmTransKernelArg() = default;
215  GemmTransKernelArg(KernelArgument&& karg,
216  GroupedGemmBlock2ETileMap&& b2c_map,
217  index_t block_start,
218  index_t block_end)
219  : karg_{karg},
220  block_2_ctile_map_{b2c_map},
221  block_start_{block_start},
222  block_end_{block_end}
223  {
224  }
225  };
226 
227  static constexpr index_t DefaultKBatch = 1;
228 
229  // Argument
230  struct Argument : public BaseArgument
231  {
232 
233  Argument(std::vector<const void*>& p_As,
234  std::vector<const void*>& p_Bs,
235  std::vector<void*>& p_Es,
236  std::vector<GemmDesc>& gemm_descs)
237  : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
238  {
239  // TODO: use occupancy api to calculate appropriate batch size.
240  }
241 
242  Argument(std::vector<const void*>& p_As,
243  std::vector<const void*>& p_Bs,
244  std::vector<void*>& p_Es,
245  std::vector<GemmDesc>& gemm_descs,
246  index_t kbatch)
247  : K_BATCH{kbatch}
248  {
249  grid_size_ = 0;
250  group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
251 
252  if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
253  group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
254  group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
255  {
256  throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
257  }
258 
259  gemm_kernel_args_.reserve(group_count_);
260 
261  skipped_group_count_ = 0;
262 
263  for(std::size_t i = 0; i < gemm_descs.size(); ++i)
264  {
265  const index_t M = gemm_descs[i].M_;
266  const index_t N = gemm_descs[i].N_;
267  const index_t K = gemm_descs[i].K_;
268 
269  if(M == 0)
270  {
271  skipped_group_count_++;
272  continue;
273  }
274 
275  const index_t stride_a = gemm_descs[i].stride_A_;
276  const index_t stride_b = gemm_descs[i].stride_B_;
277  const index_t stride_c = gemm_descs[i].stride_C_;
278 
279  const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
280  const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
281  const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
282  const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);
283 
284  const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
285 
286  const auto local_b2c_tile_map =
287  Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
288  const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
289 
290  const index_t block_start = grid_size_;
291  const index_t block_end = grid_size_ + grid_size_grp;
292 
293  grid_size_ += grid_size_grp;
294 
295  // block-to-e-tile map
296  auto grouped_block_2_ctile_map =
297  GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
298 
299  auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
300  type_convert<const BDataType*>(p_Bs[i]),
301  type_convert<EDataType*>(p_Es[i]),
302  M,
303  N,
304  K,
305  stride_a,
306  stride_b,
307  stride_c,
308  m_padded,
309  n_padded,
310  k_padded,
311  k0_padded,
312  K_BATCH};
313 
314  gemm_kernel_args_.emplace_back(
315  std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
316  }
317  }
318 
319  /**
320  * @brief Recalculate group grid size for all gemms and update B2C maps.
321  *
322  * @param kbatch The new splitK parameter value.
323  */
324  void UpdateKBatch(index_t kbatch)
325  {
326  K_BATCH = kbatch;
327  grid_size_ = 0;
328 
329  for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
330  {
331 
332  auto& karg = gemm_kernel_args_[i].karg_;
333 
334  const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
335  const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
336 
337  const auto c_grid_desc_m_n =
338  GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
339 
340  const auto local_b2c_tile_map =
341  Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
342  const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
343 
344  const index_t block_start = grid_size_;
345  const index_t block_end = grid_size_ + grid_size_grp;
346 
347  grid_size_ += grid_size_grp;
348 
349  // block-to-e-tile map
350  auto grouped_block_2_ctile_map =
351  GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
352 
353  karg.KPadded = k_padded;
354  karg.K0Padded = k0_padded;
355  karg.k_batch = K_BATCH;
356  gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
357  gemm_kernel_args_[i].block_start_ = block_start;
358  gemm_kernel_args_[i].block_end_ = block_end;
359  }
360  }
361 
362  // private:
363  index_t K_BATCH;
364  index_t group_count_;
365  index_t skipped_group_count_;
366 
367  std::vector<GemmTransKernelArg> gemm_kernel_args_;
368  index_t grid_size_;
369  };
370 
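// --- Illustration (not part of this header) --------------------------------
// The Argument constructor above concatenates every group's tile grid into one
// flat launch grid: each group contributes CalculateGridSize blocks and records
// the half-open block range it owns. A host sketch of that bookkeeping follows;
// the per-group tile count used here is a hypothetical model (ceil-divided C
// tile grid replicated kbatch times), the real value comes from
// Block2ETileMapKSplit::CalculateGridSize.

#include <cstdint>
#include <vector>

struct Dims { int32_t M, N; };

int32_t group_grid_size(Dims d, int32_t MPerBlock, int32_t NPerBlock, int32_t kbatch)
{
    const int32_t m0 = (d.M + MPerBlock - 1) / MPerBlock; // tiles along M
    const int32_t n0 = (d.N + NPerBlock - 1) / NPerBlock; // tiles along N
    return m0 * n0 * kbatch;                              // k-split replicates the grid
}

int main()
{
    const std::vector<Dims> groups{{256, 256}, {128, 512}, {64, 64}};
    int32_t grid_size = 0;
    for(const Dims& g : groups)
    {
        const int32_t grp         = group_grid_size(g, 128, 128, /*kbatch=*/2);
        const int32_t block_start = grid_size;       // first block this group owns
        const int32_t block_end   = grid_size + grp; // one past its last block
        grid_size += grp;
        // (block_start, block_end) is what the kernel's binary search consumes.
        static_cast<void>(block_start);
        static_cast<void>(block_end);
    }
    return 0; // grid_size becomes dim3(arg.grid_size_) at launch
}
// ---------------------------------------------------------------------------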
371  // Invoker
372  struct Invoker : public BaseInvoker
373  {
374  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
375  {
376  index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
377  bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
378  bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
379 
380  for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
381  {
382  const auto& karg = arg.gemm_kernel_args_[i].karg_;
383  if(stream_config.log_level_ > 0)
384  {
385  karg.Print();
386  }
387 
388  auto kbatch = karg.k_batch;
389 
390  if(!GridwiseGemm::CheckValidity(karg))
391  {
392  std::ostringstream err;
393  err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
394  << ":" << __LINE__ << ", in function: " << __func__;
395  throw std::runtime_error(err.str());
396  }
397 
398  K0 = karg.K0Padded;
399  bool not_all_have_main_k0_block_loop_same =
400  all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
401  bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
402 
403  if(not_all_have_main_k0_block_loop_same)
404  {
405  std::ostringstream err;
406  err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
407  << ":" << __LINE__ << ", in function: " << __func__;
408  throw std::runtime_error(err.str());
409  }
410 
411  if(not_all_have_kbatch_value_same)
412  {
413  std::ostringstream err;
414  err << "Not all gemms have same kbatch value (=1 or >1)! "
415  << "group [" << i << "], kbatch: " << kbatch
416  << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
417  << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
418  throw std::runtime_error(err.str());
419  }
420  }
421 
422  hip_check_error(
423  hipMemcpyAsync(arg.p_workspace_,
424  arg.gemm_kernel_args_.data(),
425  arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
426  hipMemcpyHostToDevice,
427  stream_config.stream_id_));
428 
429  float ave_time = 0;
430 
431  const auto Run = [&](const auto& kernel) {
432  if(all_have_kbatch_gt_one)
433  {
434  for(const auto& trans_arg : arg.gemm_kernel_args_)
435  {
436  const auto& karg = trans_arg.karg_;
437  hip_check_error(hipMemsetAsync(karg.p_c_grid,
438  0,
439  karg.M * karg.N * sizeof(EDataType),
440  stream_config.stream_id_));
441  }
442  }
443 
444  ave_time =
445  launch_and_time_kernel(stream_config,
446  kernel,
447  dim3(arg.grid_size_),
448  dim3(BlockSize),
449  0,
450  cast_pointer_to_constant_address_space(arg.p_workspace_),
451  arg.gemm_kernel_args_.size(),
452  PassThrough{},
453  PassThrough{},
454  PassThrough{});
455  };
456 
457  if(all_have_main_k0_block_loop)
458  {
459  if(all_have_kbatch_gt_one)
460  {
461  const auto kernel =
462  kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
463  GemmTransKernelArg,
464  true,
465  InMemoryDataOperationEnum::AtomicAdd>;
466 
467  Run(kernel);
468  }
469  else
470  {
471  const auto kernel =
472  kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
473  GemmTransKernelArg,
474  true,
475  InMemoryDataOperationEnum::Set>;
476 
477  Run(kernel);
478  }
479  }
480  else
481  {
482  if(all_have_kbatch_gt_one)
483  {
484  const auto kernel =
485  kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
486  GemmTransKernelArg,
487  false,
488  InMemoryDataOperationEnum::AtomicAdd>;
489 
490  Run(kernel);
491  }
492  else
493  {
494  const auto kernel =
495  kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
496  GemmTransKernelArg,
497  false,
498  InMemoryDataOperationEnum::Set>;
499 
500  Run(kernel);
501  }
502  }
503 
504  return ave_time;
505  }
506 
507  // polymorphic
508  float Run(const BaseArgument* p_arg,
509  const StreamConfig& stream_config = StreamConfig{}) override
510  {
511  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
512  }
513  };
514 
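// --- Illustration (not part of this header) --------------------------------
// A possible host-side call sequence for this device op; a sketch, not a
// verified example. Assumptions: DeviceOp is any valid instantiation of this
// template, the A/B/E pointers are device buffers that already exist, the
// GemmDesc entries are filled in, and the element-ops are PassThrough so that
// value-initialized {} arguments suffice.

#include <array>
#include <vector>
#include <hip/hip_runtime.h>

template <typename DeviceOp, typename GemmDescT>
float run_grouped_gemm_splitk(DeviceOp& op,
                              std::vector<const void*>& p_As,
                              std::vector<const void*>& p_Bs,
                              std::vector<void*>& p_Es,
                              std::vector<GemmDescT>& gemm_descs,
                              ck::index_t kbatch)
{
    // This operator supports no D-tensor fusions (NumDTensor == 0).
    std::vector<std::array<const void*, 0>> p_Ds(gemm_descs.size());

    auto arg = op.MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Es, gemm_descs, {}, {}, {});
    op.SetKBatchSize(arg.get(), kbatch); // rebuilds block ranges via UpdateKBatch

    // Stage storage for the per-group GemmTransKernelArg array;
    // Invoker::Run copies the host-side args into it before launching.
    void* dev_kargs = nullptr;
    hip_check_error(hipMalloc(&dev_kargs, op.GetDeviceKernelArgSize(arg.get())));
    op.SetDeviceKernelArgs(arg.get(), dev_kargs);

    auto invoker = op.MakeInvoker();
    return invoker.Run(arg.get(), StreamConfig{nullptr, true}); // avg time in ms
}
// ---------------------------------------------------------------------------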
515  static constexpr bool IsValidCompilationParameter()
516  {
517  // TODO: properly implement this check
518  return true;
519  }
520 
521  static bool IsSupportedArgument(const Argument& arg)
522  {
523  if(!ck::is_xdl_supported())
524  {
525  return false;
526  }
527 
528  if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
529  arg.skipped_group_count_) != arg.group_count_)
530  {
531  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
532  {
533  std::cout << "The group count is not equal to sum of skipped groups "
534  "and kernel args size!"
535  << std::endl;
536  }
537  return false;
538  }
539 
540  if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
541  {
542  return false;
543  }
544 
545  bool supported = true;
546  for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
547  {
548  const auto& a = arg.gemm_kernel_args_[i].karg_;
549 
550  bool group_arg_valid = GridwiseGemm::CheckValidity(a);
551  if(not group_arg_valid)
552  {
553  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
554  {
555  std::cout << "[" << __func__ << "] group id: " << i
556  << " has invalid GridwiseGemm settings!" << std::endl;
557  a.Print();
558  }
559  }
560  supported = supported && group_arg_valid;
561  }
562  return supported;
563  }
564 
565  // polymorphic
566  bool IsSupportedArgument(const BaseArgument* p_arg) override
567  {
568  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
569  }
570 
571  static auto MakeArgument(std::vector<const void*>& p_As,
572  std::vector<const void*>& p_Bs,
573  std::vector<std::array<const void*, NumDTensor>>&,
574  std::vector<void*>& p_Es,
575  std::vector<GemmDesc> gemm_descs,
576  AElementwiseOperation,
577  BElementwiseOperation,
578  CDEElementwiseOperation)
579  {
580  return Argument{p_As, p_Bs, p_Es, gemm_descs};
581  }
582 
583  static auto MakeInvoker() { return Invoker{}; }
584 
585  // polymorphic
586  std::unique_ptr<BaseArgument>
587  MakeArgumentPointer(std::vector<const void*>& p_As,
588  std::vector<const void*>& p_Bs,
589  std::vector<std::array<const void*, NumDTensor>>&,
590  std::vector<void*>& p_Es,
591  std::vector<GemmDesc>& gemm_descs,
592  AElementwiseOperation,
593  BElementwiseOperation,
594  CDEElementwiseOperation) override
595  {
596  return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
597  }
598 
599  // polymorphic
600  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
601  {
602  return std::make_unique<Invoker>(Invoker{});
603  }
604 
605  // polymorphic
606  std::string GetTypeString() const override
607  {
608  auto str = std::stringstream();
609 
610  // clang-format off
611  str << "DeviceGroupedGemm_XdlSplitK"
612  << "<"
613  << std::string(ALayout::name)[0] << ","
614  << std::string(BLayout::name)[0] << ","
615  << std::string(ELayout::name)[0] << ","
616  << BlockSize << ", "
617  << MPerBlock << ", "
618  << NPerBlock << ", "
619  << KPerBlock << ", "
620  << AK1 << ", "
621  << BK1 << ", "
622  << MPerXDL << ", "
623  << NPerXDL << ", "
624  << MXdlPerWave << ", "
625  << NXdlPerWave << ", "
626  << ABlockTransferSrcScalarPerVector << ", "
627  << BBlockTransferSrcScalarPerVector << ", "
628  << CShuffleMXdlPerWavePerShuffle << ", "
629  << CShuffleNXdlPerWavePerShuffle << ", "
630  << getGemmSpecializationString(GemmSpec)
631  << ">";
632  // clang-format on
633 
634  return str.str();
635  }
636 
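// --- Illustration (not part of this header) --------------------------------
// For hypothetical template parameters (Row/Col/Row layouts, BlockSize 256,
// 128x128x32 tiles, AK1 = BK1 = 8, 32x32 XDL, 2x2 waves per Xdl, vector widths
// 8/8, 1x1 shuffle, default specialization), GetTypeString above would yield
// roughly:
//
//   DeviceGroupedGemm_XdlSplitK<R,C,R,256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 8, 8, 1, 1, Default>
// ---------------------------------------------------------------------------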
637  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
638  {
639  auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
640  if(p_arg_)
641  {
642  return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
643  }
644  else
645  throw std::runtime_error(
646  "The argument pointer is not an object of "
647  "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
648  }
649 
650  size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
651  {
652  return GetWorkSpaceSize(p_arg);
653  }
654 
655  // TODO: deprecation notice.
656  static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
657 
658  // polymorphic
659  void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
660  {
661  auto p_arg_ = dynamic_cast<Argument*>(p_arg);
662  if(p_arg_)
663  {
664  p_arg_->UpdateKBatch(kbatch);
665  }
666  else
667  throw std::runtime_error(
668  "The argument pointer is not an object of "
669  "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
670  }
671 
672  void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
673  {
674  return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
675  }
676 };
677 
678 } // namespace device
679 } // namespace tensor_operation
680 } // namespace ck
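// --- Illustration (not part of this header) --------------------------------
// Self-contained numerical model of what the split-K path computes; names are
// local to this example. With kbatch slices, E is zero-initialized (as the
// invoker's hipMemsetAsync does) and each slice's partial product is added in
// (mirroring InMemoryDataOperationEnum::AtomicAdd). The result matches a
// single-pass GEMM.

#include <cstdio>

int main()
{
    const int M = 2, N = 2, K = 8, kbatch = 2;
    float A[M][K], B[K][N], E[M][N] = {}; // E zeroed, as the invoker does
    for(int i = 0; i < M; ++i)
        for(int k = 0; k < K; ++k)
            A[i][k] = 0.5f * (i + k);
    for(int k = 0; k < K; ++k)
        for(int j = 0; j < N; ++j)
            B[k][j] = 1.0f - 0.25f * j;

    const int k_per_batch = K / kbatch;
    for(int kb = 0; kb < kbatch; ++kb) // each k-slice accumulates into E
        for(int i = 0; i < M; ++i)
            for(int j = 0; j < N; ++j)
                for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                    E[i][j] += A[i][k] * B[k][j];

    std::printf("E[0][0] = %f\n", E[0][0]); // identical to an unsplit GEMM
    return 0;
}
// ---------------------------------------------------------------------------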