Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp Source File
#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

/*
 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
 *
 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, and C
 * matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch computes the
 * offsets of an evenly strided batch, but this can easily be extended to other layouts. The
 * returned offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we
 * are not subject to the 2GB limitations.
 *
 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes the id of a workgroup and
 * returns the 2D index of the tile that it computes. \see
 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
 *
 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that two workgroups can compute
 * two tiles from different matrices. Keep in mind that these two matrices can share the same grid
 * descriptor (as in BatchedGEMM), or use their own grid descriptors (as in GroupedGemm). \link
 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
 * computing of the pointer offset into \p ComputePtrOffsetOfStridedBatch.
 *
 * \note \p Block2ETileMap allows a customized mapping between a workgroup and the C-tile it
 * computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
 * fusion) to realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusions).
 */
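// Illustrative sketch (comments only, not part of the kernel below): the 1-D grid is sized as
// batch_count equal slices of tile-workgroups, so the kernel recovers the batch index of a
// workgroup by dividing its flat block id by the number of blocks per batch:
//
//   const index_t num_blocks_per_batch = get_grid_size() / batch_count;
//   const index_t g_idx                = get_block_1d_id() / num_blocks_per_batch;
//
// The per-batch base pointers then become p_a_grid + GetAPtrOffset(g_idx) (and likewise for B
// and E), while Block2ETileMap maps the workgroup onto its C/E tile.
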
template <typename GridwiseGemm,
          typename ABDataType,
          typename EDataType,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename ComputePtrOffsetOfBatch,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
                                      const ABDataType* __restrict__ p_b_grid,
                                      EDataType* __restrict__ p_e_grid,
                                      const index_t batch_count,
                                      const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
                                      const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
                                      const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                          e_grid_desc_mblock_mperblock_nblock_nperblock,
                                      const AElementwiseOperation a_element_op,
                                      const BElementwiseOperation b_element_op,
                                      const CDEElementwiseOperation cde_element_op,
                                      const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                                      const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  ck::Tuple<>{},
                                                  p_e_grid + e_batch_offset,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ck::Tuple<>{},
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = compute_ptr_offset_of_batch;
    ignore = block_2_etile_map;
#endif
}

template <typename ALayout,
          typename BLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumPrefetch,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
                                                                       BLayout,
                                                                       ELayout,
                                                                       ADataType,
                                                                       BDataType,
                                                                       EDataType,
                                                                       AElementwiseOperation,
                                                                       BElementwiseOperation,
                                                                       CDEElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmEPermuteXdl;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
    }

    static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
    }

    static auto
    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
    {
        const auto e_grid_desc_mraw_nraw =
            make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));

        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
    }

    static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
                                              index_t G1,
                                              index_t MRaw,
                                              index_t NRaw,
                                              index_t stride_G0,
                                              index_t stride_G1,
                                              index_t stride_M,
                                              index_t stride_N)
    {
        const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
            return make_naive_tensor_descriptor(
                make_tuple(G0, G1, MRaw, NRaw),
                make_tuple(stride_G0, stride_G1, stride_M, stride_N));
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(
                e_grid_desc_g0_g1_mraw_nraw,
                make_tuple(make_pass_through_transform(G0),
                           make_pass_through_transform(G1),
                           make_right_pad_transform(MRaw, MPad),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                e_grid_desc_g0_g1_mraw_nraw,
                make_tuple(make_pass_through_transform(G0),
                           make_pass_through_transform(G1),
                           make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                e_grid_desc_g0_g1_mraw_nraw,
                make_tuple(make_pass_through_transform(G0),
                           make_pass_through_transform(G1),
                           make_pass_through_transform(MRaw),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
        }
        else
        {
            // do not pad M or N
            return e_grid_desc_g0_g1_mraw_nraw;
        }
    }
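    // Worked example (hypothetical numbers, illustration only): with MRaw = 1000, NRaw = 500,
    // MPerBlock = NPerBlock = 256, and GemmSpec == GemmSpecialization::MNPadding:
    //   M = integer_divide_ceil(1000, 256) * 256 = 4 * 256 = 1024, so MPad = 24
    //   N = integer_divide_ceil(500, 256)  * 256 = 2 * 256 = 512,  so NPad = 12
    // The right-pad transforms then extend the M and N dimensions of the E descriptor to 1024
    // and 512, so the grid can be tiled evenly by MPerBlock x NPerBlock workgroup tiles.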

    using AGridDesc_M_K       = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
    using BGridDesc_N_K       = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
    using EGridDesc_M_N       = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
    using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));

    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
                                       index_t Batchstride_B,
                                       EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
            : Batchstride_A_(Batchstride_A),
              Batchstride_B_(Batchstride_B),
              e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(Batchstride_A_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(Batchstride_B_);
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
            index_t b0       = g_idx / G1;
            index_t b1       = g_idx - b0 * G1; // g_idx % G1
            return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
        }
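        // Worked example (hypothetical strides, illustration only): for an unpadded E descriptor
        // with lengths (G0, G1, M, N) = (2, 3, 128, 64) and strides
        // (stride_G0, stride_G1, stride_M, stride_N) = (24576, 8192, 64, 1), batch g_idx = 4
        // decomposes into b0 = 4 / 3 = 1 and b1 = 4 - 1 * 3 = 1, so the offset is
        // 1 * 24576 + 1 * 8192 = 32768 elements. Because the offset goes through
        // CalculateOffset, any G0/G1 permutation encoded in the strides is honored.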

      private:
        index_t Batchstride_A_;
        index_t Batchstride_B_;
        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
    };

    using ComputeDataType = ADataType;

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType,
        BDataType,
        ComputeDataType,
        AccDataType,
        CShuffleDataType,
        ck::Tuple<>, // DsDataType
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_M_K,
        BGridDesc_N_K,
        Tuple<>,
        EGridDesc_M_N,
        NumPrefetch,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    using AGridDesc_AK0_M_AK1 =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
            AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
            BGridDesc_N_K{}))>;

    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            EGridDesc_M_N{}));
    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 EDataType* p_e_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t stride_A,
                 index_t stride_B,
                 index_t batch_stride_A,
                 index_t batch_stride_B,
                 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
                 index_t BatchCount,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_e_grid_{p_e_grid},
              BatchCount_(BatchCount),
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
              e_grid_desc_m_n_{
                  DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
                                                    batched_gemm_e_permute_desc.N_,
                                                    batched_gemm_e_permute_desc.stride_M_,
                                                    batched_gemm_e_permute_desc.stride_N_)},
              a_grid_desc_ak0_m_ak1_{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
              e_grid_desc_g0_g1_m_n_{
                  DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
                                                          batched_gemm_e_permute_desc.G1_,
                                                          batched_gemm_e_permute_desc.M_,
                                                          batched_gemm_e_permute_desc.N_,
                                                          batched_gemm_e_permute_desc.stride_G0_,
                                                          batched_gemm_e_permute_desc.stride_G1_,
                                                          batched_gemm_e_permute_desc.stride_M_,
                                                          batched_gemm_e_permute_desc.stride_N_)},
              compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
                                           b_grid_desc_n_k_,
                                           ck::Tuple<>{},
                                           e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                e_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);
            }
        }

        void Print() const
        {
            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
            std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
        }

        // private:
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        EDataType* p_e_grid_;

        // batch count
        index_t BatchCount_;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;

        // for calculating batch offset
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // element-wise ops
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            ck::Tuple<>{},
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
                    "setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel = kernel_batched_gemm_e_permute_xdl<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    EDataType,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    ComputePtrOffsetOfStridedBatch,
                    Block2ETileMap,
                    has_main_k_block_loop_>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_e_grid_,
                                              arg.BatchCount_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.block_2_etile_map_);
            };

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                           arg.b_grid_desc_n_k_,
                                           ck::Tuple<>{},
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             EDataType* p_e,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t stride_A,
                             index_t stride_B,
                             index_t batch_stride_A,
                             index_t batch_stride_B,
                             BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
                             index_t BatchCount,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_e,
                        M,
                        N,
                        K,
                        stride_A,
                        stride_B,
                        batch_stride_A,
                        batch_stride_B,
                        batched_gemm_e_permute_desc,
                        BatchCount,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_e,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_A,
                        index_t stride_B,
                        index_t batch_stride_A,
                        index_t batch_stride_B,
                        BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
                        index_t BatchCount,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<EDataType*>(p_e),
                                          M,
                                          N,
                                          K,
                                          stride_A,
                                          stride_B,
                                          batch_stride_A,
                                          batch_stride_B,
                                          batched_gemm_e_permute_desc,
                                          BatchCount,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchedGemmEPermuteXdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
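
The header above only defines the operator. For context, the following host-side sketch shows the calling sequence this interface expects (MakeArgument, IsSupportedArgument, MakeInvoker, Invoker::Run). It is illustrative, not part of the header: the data types, layouts, tuning parameters, and the aggregate field order of BatchedGemmEPermuteDesc are assumptions in the style of Composable Kernel's example code, and the device pointers are placeholders.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using F16         = ck::half_t;
using F32         = float;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ck::index_t;

// Hypothetical tuning parameters (256-thread blocks on 256 x 128 x 32 tiles).
using DeviceOpInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl<
    Row, Col, Row,                          // ALayout, BLayout, ELayout
    F16, F16, F32, F16, F16,                // A, B, Acc, CShuffle, E data types
    PassThrough, PassThrough, PassThrough,  // elementwise ops
    ck::tensor_operation::device::GemmSpecialization::MNKPadding,
    1, 256,                                 // NumPrefetch, BlockSize
    256, 128, 32,                           // M/N/KPerBlock
    8, 8, 32, 32, 4, 2,                     // AK1, BK1, M/NPerXDL, M/NXdlPerWave
    ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
    ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
    1, 1,                                   // CShuffle M/NXdlPerWavePerShuffle
    ck::Sequence<1, 32, 1, 8>, 8>;          // CDE cluster lengths, scalar per vector

int main()
{
    const index_t G0 = 2, G1 = 3, M = 1024, N = 512, K = 256;

    // E is stored as [G0, M, G1, N]; the permutation lives entirely in the strides
    // (assumed field order: G0_, G1_, M_, N_, stride_G0_, stride_G1_, stride_M_, stride_N_).
    ck::tensor_operation::device::BatchedGemmEPermuteDesc e_desc{
        G0, G1, M, N, M * G1 * N, N, G1 * N, 1};

    const F16* p_a = nullptr; // placeholder: device buffer, G0*G1 batches of M x K
    const F16* p_b = nullptr; // placeholder: device buffer, G0*G1 batches of K x N
    F16* p_e       = nullptr; // placeholder: device buffer for the permuted output

    auto device_op = DeviceOpInstance{};
    auto argument  = device_op.MakeArgument(p_a, p_b, p_e, M, N, K,
                                            /*stride_A=*/K, /*stride_B=*/K,
                                            /*batch_stride_A=*/M * K,
                                            /*batch_stride_B=*/K * N,
                                            e_desc, /*BatchCount=*/G0 * G1,
                                            PassThrough{}, PassThrough{}, PassThrough{});

    if(!device_op.IsSupportedArgument(argument))
    {
        return 1; // this tuning-parameter/problem combination is not runnable
    }

    auto invoker = device_op.MakeInvoker();
    invoker.Run(argument, StreamConfig{nullptr, /*time_kernel=*/true});
    return 0;
}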