Composable Kernel: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File

device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
25 template <typename GridwiseGemm,
26  typename FloatAB,
27  typename FloatC,
28  typename D0sPointer,
29  typename AElementwiseOperation,
30  typename BElementwiseOperation,
31  typename C0DEElementwiseOperation,
32  typename B1ElementwiseOperation,
33  typename C1DEElementwiseOperation,
34  typename AGridDesc_AK0_M_AK1,
35  typename BGridDesc_BK0_N_BK1,
36  typename B1GridDesc_BK0_N_BK1,
37  typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
38  typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
39  typename Block2CTileMap,
40  typename ComputeBasePtrOfStridedBatch,
41  typename C0MatrixMask,
42  bool HasMainKBlockLoop>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
46 #endif
47  kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
48  const FloatAB* __restrict__ p_a_grid,
49  const FloatAB* __restrict__ p_b_grid,
50  const FloatAB* __restrict__ p_b1_grid,
51  FloatC* __restrict__ p_c_grid,
52  D0sPointer p_d0s_grid,
53  const AElementwiseOperation a_element_op,
54  const BElementwiseOperation b_element_op,
55  const C0DEElementwiseOperation c0de_element_op,
56  const B1ElementwiseOperation b1_element_op,
57  const C1DEElementwiseOperation c1de_element_op,
58  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
59  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
60  const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
61  const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
62  c1_grid_desc_mblock_mperblock_nblock_nperblock,
63  const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
64  d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
65  const Block2CTileMap block_2_ctile_map,
66  const index_t batch_count,
67  const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
68  const C0MatrixMask c0_matrix_mask)
69 {
70 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
71  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
72  const index_t num_blocks_per_batch =
73  __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
74  const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
75 
76  const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
77  static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
78  const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
79  static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
80  const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
81  static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
82  const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
83  static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
84 
85  static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
86  const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
87  static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
88  p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
89  });
90 
91  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
92  p_b_grid + b_batch_offset,
93  p_b1_grid + b1_batch_offset,
94  p_c_grid + c_batch_offset,
95  p_d0s_grid,
96  p_shared,
97  a_element_op,
98  b_element_op,
99  c0de_element_op,
100  b1_element_op,
101  c1de_element_op,
102  a_grid_desc_ak0_m_ak1,
103  b_grid_desc_bk0_n_bk1,
104  b1_grid_desc_bk0_n_bk1,
105  c1_grid_desc_mblock_mperblock_nblock_nperblock,
106  d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
107  block_2_ctile_map,
108  c0_matrix_mask);
109 #else
110  ignore = p_a_grid;
111  ignore = p_b_grid;
112  ignore = p_b1_grid;
113  ignore = p_c_grid;
114  ignore = p_d0s_grid;
115  ignore = a_element_op;
116  ignore = b_element_op;
117  ignore = c0de_element_op;
118  ignore = b1_element_op;
119  ignore = c1de_element_op;
120  ignore = a_grid_desc_ak0_m_ak1;
121  ignore = b_grid_desc_bk0_n_bk1;
122  ignore = b1_grid_desc_bk0_n_bk1;
123  ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
124  ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
125  ignore = block_2_ctile_map;
126  ignore = batch_count;
127  ignore = compute_base_ptr_of_batch;
128  ignore = c0_matrix_mask;
129 #endif // end of if (defined(__gfx9__))
130 }
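// Editorial sketch (not part of this header): the batch dimension is flattened into the
// launch grid, so mapping a flat block id back to its batch index is plain integer
// arithmetic; __builtin_amdgcn_readfirstlane merely keeps the result in a scalar register.
//
//   ck::index_t num_blocks_per_batch = grid_size / batch_count;          // e.g. 240 / 16 = 15
//   ck::index_t g_idx                = block_id / num_blocks_per_batch;  // block 37 -> batch 2
//
// Each batch's A/B/B1/C/D0 pointers are then shifted by the per-batch offsets that
// ComputeBasePtrOfStridedBatch derives from the G-strided descriptors.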
131 
132 // Computes C = A * B0 * B1
133 //              ^^^^^^ (Acc0)
134 //              ^^^^^^^^^^^ (Acc1)
135 template <index_t NumDimG,
136  index_t NumDimM,
137  index_t NumDimN,
138  index_t NumDimK,
139  index_t NumDimO, // NumDimGemm1N
140  typename ADataType,
141  typename BDataType,
142  typename B1DataType,
143  typename CDataType,
144  typename D0sDataType,
145  typename D1sDataType,
146  typename GemmAccDataType,
147  typename CShuffleDataType,
148  typename AElementwiseOperation,
149  typename BElementwiseOperation,
150  typename C0DEElementwiseOperation,
151  typename B1ElementwiseOperation,
152  typename C1DEElementwiseOperation,
153  GemmSpecialization GemmSpec,
154  TensorSpecialization ASpec,
155  TensorSpecialization BSpec,
156  TensorSpecialization B1Spec,
157  TensorSpecialization CSpec,
158  index_t NumGemmKPrefetchStage,
159  index_t BlockSize,
160  index_t MPerBlock,
161  index_t NPerBlock, // Gemm0NPerBlock
162  index_t KPerBlock, // Gemm0KPerBlock
163  index_t Gemm1NPerBlock,
164  index_t Gemm1KPerBlock,
165  index_t AK1,
166  index_t BK1,
167  index_t B1K1,
168  index_t MPerXDL,
169  index_t NPerXDL,
170  index_t MXdlPerWave,
171  index_t NXdlPerWave,
172  index_t Gemm1NXdlPerWave,
173  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
174  typename ABlockTransferThreadClusterArrangeOrder,
175  typename ABlockTransferSrcAccessOrder,
176  index_t ABlockTransferSrcVectorDim,
177  index_t ABlockTransferSrcScalarPerVector,
178  index_t ABlockTransferDstScalarPerVector_AK1,
179  bool ABlockLdsExtraM,
180  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
181  typename BBlockTransferThreadClusterArrangeOrder,
182  typename BBlockTransferSrcAccessOrder,
183  index_t BBlockTransferSrcVectorDim,
184  index_t BBlockTransferSrcScalarPerVector,
185  index_t BBlockTransferDstScalarPerVector_BK1,
186  bool BBlockLdsExtraN,
187  typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
188  typename B1BlockTransferThreadClusterArrangeOrder,
189  typename B1BlockTransferSrcAccessOrder,
190  index_t B1BlockTransferSrcVectorDim,
191  index_t B1BlockTransferSrcScalarPerVector,
192  index_t B1BlockTransferDstScalarPerVector_BK1,
193  bool B1BlockLdsExtraN,
194  index_t CShuffleMXdlPerWavePerShuffle,
195  index_t CShuffleNXdlPerWavePerShuffle,
196  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
197  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
198  MaskingSpecialization MaskingSpec,
199  int D0sTransferSrcScalarPerVector = 4,
200  LoopScheduler LoopSched = make_default_loop_scheduler()>
201 struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
202  : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
203  NumDimM,
204  NumDimN,
205  NumDimK,
206  NumDimO,
207  ADataType,
208  BDataType,
209  B1DataType,
210  CDataType,
211  D0sDataType,
212  D1sDataType,
213  AElementwiseOperation,
214  BElementwiseOperation,
215  C0DEElementwiseOperation,
216  B1ElementwiseOperation,
217  C1DEElementwiseOperation,
218  MaskingSpec>
219 {
220  static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
221  "Number of dimensions must be greater than 0");
222 
223  static constexpr index_t NumD0Tensor = D0sDataType::Size();
224  static constexpr index_t NumD1Tensor = D1sDataType::Size();
225 
226  // TODO ANT: implement bias combination
227  static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented");
228 
229 #if 0
230  // TODO ANT: use alias
231  static constexpr index_t NumDimGemm0M = NumDimM;
232  static constexpr index_t NumDimGemm0N = NumDimN;
233  static constexpr index_t NumDimGemm0K = NumDimK;
234  static constexpr index_t NumDimGemm1M = NumDimM;
235  static constexpr index_t NumDimGemm1N = NumDimO;
236  static constexpr index_t NumDimGemm1K = NumDimN;
237 #endif
238 
239  using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
240 
241  static constexpr auto I0 = Number<0>{};
242  static constexpr auto I1 = Number<1>{};
243  static constexpr auto I2 = Number<2>{};
244 
245  using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
246  Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
247  Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
248  GemmSpec,
249  ASpec,
250  BSpec,
251  B1Spec,
252  CSpec>;
253 
254  static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
255  const std::vector<index_t>& a_gs_ms_ks_strides_vec)
256  {
257  return Transform::MakeAGridDescriptor_AK0_M_AK1(
258  Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
259  Number<AK1>{});
260  }
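// Editorial note (not part of this header): the AK0/M/AK1 layout splits the contraction
// length into K = AK0 * AK1, with AK1 a compile-time vector width. For instance, assuming
// K = 64 and AK1 = 8, A is addressed as [AK0 = 8, M, AK1 = 8], letting each thread issue
// 8-wide vector loads along the fastest-varying AK1 axis. The B0/B1 helpers below apply
// the same split with BK1 and B1K1.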
261 
262  static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
263  const std::vector<index_t>& b_gs_ns_ks_strides_vec)
264  {
265  return Transform::MakeB0GridDescriptor_BK0_N_BK1(
266  Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
267  Number<BK1>{});
268  }
269 
270  static auto
271  MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
272  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
273  {
274  return Transform::MakeB1GridDescriptor_BK0_N_BK1(
275  Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
276  b1_gs_gemm1ns_gemm1ks_strides_vec),
277  Number<B1K1>{});
278  }
279 
280  static auto MakeD0sGridDescriptor_M_N(
281  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
282  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
283  {
284  return generate_tuple(
285  [&](auto i) {
286  return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
287  acc0_biases_gs_ms_ns_strides[i]);
288  },
289  Number<NumD0Tensor>{});
290  }
291 
292  static auto MakeD0sGridDescriptor_G_M_N(
293  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
294  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
295  {
296  return generate_tuple(
297  [&](auto i) {
298  return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
299  acc0_biases_gs_ms_ns_strides[i]);
300  },
301  Number<NumD0Tensor>{});
302  }
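// Editorial sketch (not part of this header): generate_tuple(f, Number<N>{}) invokes f
// with Number<0> ... Number<N-1> and packs the results into a tuple, which is how the two
// helpers above produce one descriptor per D0 bias tensor. A minimal analog (lengths is a
// hypothetical array):
//
//   auto lens = ck::generate_tuple([&](auto i) { return lengths[i.value]; },
//                                  ck::Number<NumD0Tensor>{});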
303 
304  using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
305  using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
306  using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
307  using C1GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
308  using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
309  using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
310  using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
311  using C1GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
312  using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
313  using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
314 
315  constexpr static auto make_MaskOutPredicate()
316  {
317  if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
318  {
319  return MaskDisabledPredicate{};
320  }
321  else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
322  {
323  return MaskOutUpperTrianglePredicate{};
324  }
325  }
326  using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
327 
328  struct ComputeBasePtrOfStridedBatch
329  {
330  ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
331  const BGridDesc_G_N_K& b_grid_desc_g_n_k,
332  const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
333  const C1GridDesc_G_M_N& c1_grid_desc_g_m_n,
334  const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n)
335  : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
336  b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
337  b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
338  c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n),
339  d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n)
340  {
341  }
342 
343  __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
344  {
345  return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
346  }
347 
348  __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
349  {
350  return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
351  }
352 
353  __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
354  {
355  return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
356  }
357 
358  __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
359  {
360  return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
361  }
362 
363  template <index_t I>
364  __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
365  Number<I> d0_idx) const
366  {
367  return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
368  }
369 
370  private:
371  AGridDesc_G_M_K a_grid_desc_g_m_k_;
372  BGridDesc_G_N_K b_grid_desc_g_n_k_;
373  B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
374  C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
375  D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
376  };
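// Editorial note (not part of this header): because each base pointer comes from
// CalculateOffset on a G-strided descriptor, a densely packed batch reduces to
// GetABasePtr(g) == g * M * K, GetBBasePtr(g) == g * N * K, GetB1BasePtr(g) == g * O * N
// and GetCBasePtr(g) == g * M * O, while permuted or interleaved batch layouts are
// handled by the same code path with no kernel changes.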
377 
378  // GridwiseGemm
379  using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle<
380  ADataType, // TODO: distinguish A/B datatype
381  GemmAccDataType,
382  CShuffleDataType,
383  CDataType,
384  D0sDataType,
385  AElementwiseOperation,
386  BElementwiseOperation,
387  C0DEElementwiseOperation,
388  B1ElementwiseOperation,
389  C1DEElementwiseOperation,
390  InMemoryDataOperationEnum::Set,
391  AGridDesc_AK0_M_AK1,
392  BGridDesc_BK0_N_BK1,
393  B1GridDesc_BK0_N_BK1,
394  C1GridDesc_M_N,
395  D0sGridDesc_M_N,
396  NumGemmKPrefetchStage,
397  BlockSize,
398  MPerBlock,
399  NPerBlock,
400  KPerBlock,
401  Gemm1NPerBlock,
402  Gemm1KPerBlock,
403  AK1,
404  BK1,
405  B1K1,
406  MPerXDL,
407  NPerXDL,
408  MXdlPerWave,
409  NXdlPerWave,
410  Gemm1NXdlPerWave,
411  ABlockTransferThreadClusterLengths_AK0_M_AK1,
412  ABlockTransferThreadClusterArrangeOrder,
413  ABlockTransferSrcAccessOrder,
414  ABlockTransferSrcVectorDim,
415  ABlockTransferSrcScalarPerVector,
416  ABlockTransferDstScalarPerVector_AK1,
417  true,
418  ABlockLdsExtraM,
419  BBlockTransferThreadClusterLengths_BK0_N_BK1,
420  BBlockTransferThreadClusterArrangeOrder,
421  BBlockTransferSrcAccessOrder,
422  BBlockTransferSrcVectorDim,
423  BBlockTransferSrcScalarPerVector,
424  BBlockTransferDstScalarPerVector_BK1,
425  true,
426  BBlockLdsExtraN,
427  B1BlockTransferThreadClusterLengths_BK0_N_BK1,
428  B1BlockTransferThreadClusterArrangeOrder,
429  B1BlockTransferSrcAccessOrder,
430  B1BlockTransferSrcVectorDim,
431  B1BlockTransferSrcScalarPerVector,
432  B1BlockTransferDstScalarPerVector_BK1,
433  false,
434  B1BlockLdsExtraN,
435  CShuffleMXdlPerWavePerShuffle,
436  CShuffleNXdlPerWavePerShuffle,
437  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
438  CShuffleBlockTransferScalarPerVector_NPerBlock,
439  LoopSched,
440  Transform::matrix_padder.PadN,
441  MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
442  D0sTransferSrcScalarPerVector>;
443 
444  // Argument
445  // FIXME: constness
446  struct Argument : public BaseArgument
447  {
448  Argument(
449  const ADataType* p_a_grid,
450  const BDataType* p_b_grid,
451  const B1DataType* p_b1_grid,
452  CDataType* p_c_grid,
453  const std::array<void*, NumD0Tensor> p_acc0_biases,
454  const std::array<void*, NumD1Tensor> p_acc1_biases,
455  const std::vector<index_t>& a_gs_ms_ks_lengths,
456  const std::vector<index_t>& a_gs_ms_ks_strides,
457  const std::vector<index_t>& b_gs_ns_ks_lengths,
458  const std::vector<index_t>& b_gs_ns_ks_strides,
459  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
460  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
461  const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
462  const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
463  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
464  const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
465  const std::array<std::vector<ck::index_t>, NumD1Tensor>&
466  acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
467  const std::array<std::vector<ck::index_t>, NumD1Tensor>&
468  acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
469  AElementwiseOperation a_element_op,
470  BElementwiseOperation b_element_op,
471  C0DEElementwiseOperation c0de_element_op,
472  B1ElementwiseOperation b1_element_op,
473  C1DEElementwiseOperation c1de_element_op)
474  : p_a_grid_{p_a_grid},
475  p_b_grid_{p_b_grid},
476  p_b1_grid_{p_b1_grid},
477  p_c_grid_{p_c_grid},
478  p_d0s_grid_{},
479  a_grid_desc_ak0_m_ak1_{
480  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
481  b_grid_desc_bk0_n_bk1_{
482  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
483  b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
484  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
485  c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
486  c_gs_ms_gemm1ns_strides)},
487  a_grid_desc_g_m_k_{
488  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
489  b_grid_desc_g_n_k_{
490  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
491  b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
492  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
493  c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
494  c_gs_ms_gemm1ns_strides)},
495  d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
496  acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
497  c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
498  d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
499  block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)},
500  a_element_op_{a_element_op},
501  b_element_op_{b_element_op},
502  c0de_element_op_{c0de_element_op},
503  b1_element_op_{b1_element_op},
504  c1de_element_op_{c1de_element_op},
505  c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
506  raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
507  b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
508  b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
509  b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
510  a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
511  a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
512  b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
513  b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
514  b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
515  b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
516  c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
517  c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
518  batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)},
519  compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
520  b_grid_desc_g_n_k_,
521  b1_grid_desc_g_n_k_,
522  c1_grid_desc_g_m_n_,
523  d0s_grid_desc_g_m_n_}
524  {
525  // TODO ANT: implement bias addition
526  ignore = p_acc1_biases;
527  ignore = acc1_biases_gs_ms_gemm1ns_lengths;
528  ignore = acc1_biases_gs_ms_gemm1ns_strides;
529 
530  static_for<0, NumD0Tensor, 1>{}([&](auto i) {
531  using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
532  // D0 pointer
533  p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
534  // for check
535  d0s_nl_ns_lengths_strides_[i].push_back(
536  acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
537  d0s_nl_ns_lengths_strides_[i].push_back(
538  acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
539  });
540 
541  if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
542  b_grid_desc_bk0_n_bk1_,
543  b1_grid_desc_bk0_n_bk1_,
544  c1_grid_desc_m_n_,
545  block_2_ctile_map_))
546  {
547  c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
548  GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
549  c1_grid_desc_m_n_);
550 
551  const auto d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
552  acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
553  d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
554  GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
555  d0s_grid_desc_m_n);
556  }
557  }
558 
559  void Print() const
560  {
561  std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
562  << a_grid_desc_g_m_k_.GetLength(I1) << ", "
563  << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
564  std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
565  << b_grid_desc_g_n_k_.GetLength(I1) << ", "
566  << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
567  std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
568  << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
569  << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
570  std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", "
571  << c1_grid_desc_g_m_n_.GetLength(I1) << ", "
572  << c1_grid_desc_g_m_n_.GetLength(I2) << '\n';
573  }
574 
575  // pointers
576  const ADataType* p_a_grid_;
577  const BDataType* p_b_grid_;
578  const B1DataType* p_b1_grid_;
579  CDataType* p_c_grid_;
580  typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
581 
582  // tensor descriptor
583  AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
584  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
585  B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
586  C1GridDesc_M_N c1_grid_desc_m_n_;
587  AGridDesc_G_M_K a_grid_desc_g_m_k_;
588  BGridDesc_G_N_K b_grid_desc_g_n_k_;
589  B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
590  C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
591  D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
592 
593  typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
594  c1_grid_desc_mblock_mperblock_nblock_nperblock_;
595  typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
596  d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
597 
598  // block-to-c-tile map
599  typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
600 
601  // element-wise op
602  AElementwiseOperation a_element_op_;
603  BElementwiseOperation b_element_op_;
604  C0DEElementwiseOperation c0de_element_op_;
605  B1ElementwiseOperation b1_element_op_;
606  C1DEElementwiseOperation c1de_element_op_;
607 
608  // check C0 masking and padding
609  C0MatrixMask c0_matrix_mask_;
610 
611  // For robust IsSupportedArgument() check
612  std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
613  std::vector<index_t> a_mz_kz_strides_;
614  std::vector<index_t> b_nz_kz_strides_;
615  std::vector<index_t> b1_nz_kz_strides_;
616  std::vector<index_t> c_mz_gemm1nz_strides_;
617  std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
618 
619  index_t batch_count_;
620  ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
621  };
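// Editorial example (not part of this header): the length/stride vectors are flat and
// ordered [G..., M..., K...] (and analogously for the other tensors). Assuming an
// attention-style problem with NumDimG = 2 (batch, head) and NumDimM/N/K/O = 1:
//
//   a_gs_ms_ks_lengths  = {G0, G1, M, K};
//   b_gs_ns_ks_lengths  = {G0, G1, N, K};
//   b1_gs_os_ns_lengths = {G0, G1, O, N};
//   c_gs_ms_os_lengths  = {G0, G1, M, O};
//
// raw_lengths_mz_nz_kz_gemm1nz_ then captures the last (fastest-varying) M/N/K/O lengths
// that IsSupportedArgument uses for its vector-access checks.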
622 
623  // Invoker
624  struct Invoker : public BaseInvoker
625  {
626  using Argument = DeviceOp::Argument;
627 
628  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
629  {
630  if(!DeviceOp::IsSupportedArgument(arg))
631  {
632  throw std::runtime_error("wrong! unsupported argument");
633  }
634 
635  const index_t grid_size =
636  arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_;
637 
638  // Gemm0_K
639  const auto K =
640  arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
641 
642  float ave_time = 0;
643 
644  auto launch_kernel = [&](auto has_main_k_block_loop_) {
645  const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
646  GridwiseGemm,
647  ADataType, // TODO: distinguish A/B datatype
648  CDataType,
649  typename GridwiseGemm::D0sGridPointer,
650  AElementwiseOperation,
651  BElementwiseOperation,
652  C0DEElementwiseOperation,
653  B1ElementwiseOperation,
654  C1DEElementwiseOperation,
655  DeviceOp::AGridDesc_AK0_M_AK1,
656  DeviceOp::BGridDesc_BK0_N_BK1,
657  DeviceOp::B1GridDesc_BK0_N_BK1,
658  typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
659  typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
660  typename GridwiseGemm::DefaultBlock2CTileMap,
661  ComputeBasePtrOfStridedBatch,
662  C0MatrixMask,
663  has_main_k_block_loop_>;
664 
665  return launch_and_time_kernel(stream_config,
666  kernel,
667  dim3(grid_size),
668  dim3(BlockSize),
669  0,
670  arg.p_a_grid_,
671  arg.p_b_grid_,
672  arg.p_b1_grid_,
673  arg.p_c_grid_,
674  arg.p_d0s_grid_,
675  arg.a_element_op_,
676  arg.b_element_op_,
677  arg.c0de_element_op_,
678  arg.b1_element_op_,
679  arg.c1de_element_op_,
680  arg.a_grid_desc_ak0_m_ak1_,
681  arg.b_grid_desc_bk0_n_bk1_,
682  arg.b1_grid_desc_bk0_n_bk1_,
683  arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
684  arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
685  arg.block_2_ctile_map_,
686  arg.batch_count_,
687  arg.compute_base_ptr_of_batch_,
688  arg.c0_matrix_mask_);
689  };
690 
691  // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
692  // to consider Gemm0's loop
693  if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
694  {
695  ave_time = launch_kernel(integral_constant<bool, true>{});
696  }
697  else
698  {
699  ave_time = launch_kernel(integral_constant<bool, false>{});
700  }
701 
702  return ave_time;
703  }
704 
705  // polymorphic
706  float Run(const BaseArgument* p_arg,
707  const StreamConfig& stream_config = StreamConfig{}) override
708  {
709  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
710  }
711  };
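// Editorial sketch (not part of this header): launch_kernel receives an
// integral_constant<bool, B>, so the runtime answer of CalculateHasMainKBlockLoop selects
// between two compile-time instantiations of the kernel. The pattern in miniature
// (my_kernel is a hypothetical kernel template):
//
//   auto launch = [&](auto has_loop) { // has_loop: ck::integral_constant<bool, B>
//       constexpr bool B = decltype(has_loop)::value;
//       my_kernel<B><<<grid, block>>>(/* args */);
//   };
//   if(has_main_loop) launch(ck::integral_constant<bool, true>{});
//   else              launch(ck::integral_constant<bool, false>{});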
712 
713  static constexpr bool IsValidCompilationParameter()
714  {
715  // TODO: properly implement this check
716  return true;
717  }
718 
719  static bool IsSupportedArgument(const Argument& arg)
720  {
721  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
722  {
723  arg.Print();
724  }
725 
726  if(!ck::is_xdl_supported())
727  {
728  return false;
729  }
730 
731  // TODO ANT: Check if tensor specialization & strides mismatch
732 
733  // Check if C permute dimension matches GEMM + GEMM shape
734  const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
735  const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
736  const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
737  const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
738  const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
739 
740  if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
741  {
742  return false;
743  }
744 
745  // Note: we need raw lengths since threadwise copy cannot handle vector loads when
746  // part of the vector is out of bounds
747  // Note: we need the lowest dim in Ms/Ns/Ks/Os, not the merged M/N/K/O
748  const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
749  const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
750  const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
751  const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
752 
753  // Check scalar per vector requirement
754  const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
755  const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
756  const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
757  const auto c_extent_lowest = Gemm1NzRaw;
758 
759  if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
760  b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
761  b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
762  c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
763  {
764  return false;
765  }
766 
767  // Check vector load/store requirement
768  const auto a_stride_lowest =
769  ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
770  const auto b_stride_lowest =
771  BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
772  const auto b1_stride_lowest =
773  B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
774  const auto c_stride_lowest =
775  arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
776 
777  if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
778  c_stride_lowest == 1))
779  {
780  return false;
781  }
782  for(int i = 0; i < NumD0Tensor; i++)
783  {
784  if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
785  arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
786  {
787  return false;
788  }
789  if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
790  {
791  return false;
792  }
793  }
794 
795  return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
796  arg.b_grid_desc_bk0_n_bk1_,
797  arg.b1_grid_desc_bk0_n_bk1_,
798  arg.c1_grid_desc_m_n_,
799  arg.block_2_ctile_map_);
800  }
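// Editorial note (not part of this header): the checks above pair each tensor's vector
// width with its raw lowest-dim extent (e.g. ABlockTransferSrcVectorDim == 2 with
// ABlockTransferSrcScalarPerVector == 8 requires KzRaw % 8 == 0), require at least one of
// A/B/B1/C to be unit-stride in its lowest dimension, and permit a vectorized D0 read
// (D0sTransferSrcScalarPerVector > 1) only when that D0's lowest dimension is unit-stride
// with a divisible length.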
801 
802  // polymorphic
803  bool IsSupportedArgument(const BaseArgument* p_arg) override
804  {
805  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
806  }
807 
808  static auto MakeArgument(
809  const ADataType* p_a,
810  const BDataType* p_b,
811  const B1DataType* p_b1,
812  CDataType* p_c,
813  const std::array<void*, NumD0Tensor> p_acc0_biases,
814  const std::array<void*, NumD1Tensor> p_acc1_biases,
815  const std::vector<index_t>& a_gs_ms_ks_lengths,
816  const std::vector<index_t>& a_gs_ms_ks_strides,
817  const std::vector<index_t>& b_gs_ns_ks_lengths,
818  const std::vector<index_t>& b_gs_ns_ks_strides,
819  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
820  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
821  const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
822  const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
823  const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
824  const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
825  const std::array<std::vector<ck::index_t>, NumD1Tensor>
826  acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
827  const std::array<std::vector<ck::index_t>, NumD1Tensor>
828  acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
829  AElementwiseOperation a_element_op,
830  BElementwiseOperation b_element_op,
831  C0DEElementwiseOperation c0de_element_op,
832  B1ElementwiseOperation b1_element_op,
833  C1DEElementwiseOperation c1de_element_op)
834  {
835  return Argument{p_a,
836  p_b,
837  p_b1,
838  p_c,
839  p_acc0_biases,
840  p_acc1_biases,
841  a_gs_ms_ks_lengths,
842  a_gs_ms_ks_strides,
843  b_gs_ns_ks_lengths,
844  b_gs_ns_ks_strides,
845  b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
846  b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
847  c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
848  c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
849  acc0_biases_gs_ms_ns_lengths,
850  acc0_biases_gs_ms_ns_strides,
851  acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
852  acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
853  a_element_op,
854  b_element_op,
855  c0de_element_op,
856  b1_element_op,
857  c1de_element_op};
858  }
859 
860  static auto MakeInvoker() { return Invoker{}; }
861 
862  // polymorphic
863  // FIXME: constness
864  std::unique_ptr<BaseArgument> MakeArgumentPointer(
865  const void* p_a,
866  const void* p_b,
867  const void* p_b1,
868  void* p_c,
869  const std::array<void*, NumD0Tensor> p_acc0_biases,
870  const std::array<void*, NumD1Tensor> p_acc1_biases,
871  const std::vector<index_t>& a_gs_ms_ks_lengths,
872  const std::vector<index_t>& a_gs_ms_ks_strides,
873  const std::vector<index_t>& b_gs_ns_ks_lengths,
874  const std::vector<index_t>& b_gs_ns_ks_strides,
875  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
876  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
877  const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
878  const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
879  const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
880  const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
881  const std::array<std::vector<ck::index_t>, NumD1Tensor>
882  acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
883  const std::array<std::vector<ck::index_t>, NumD1Tensor>
884  acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
885  AElementwiseOperation a_element_op,
886  BElementwiseOperation b_element_op,
887  C0DEElementwiseOperation c0de_element_op,
888  B1ElementwiseOperation b1_element_op,
889  C1DEElementwiseOperation c1de_element_op) override
890  {
891  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
892  static_cast<const BDataType*>(p_b),
893  static_cast<const B1DataType*>(p_b1),
894  static_cast<CDataType*>(p_c),
895  p_acc0_biases, // cast in struct Argument
896  p_acc1_biases, // cast in struct Argument
897  a_gs_ms_ks_lengths,
898  a_gs_ms_ks_strides,
899  b_gs_ns_ks_lengths,
900  b_gs_ns_ks_strides,
901  b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
902  b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
903  c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
904  c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
905  acc0_biases_gs_ms_ns_lengths,
906  acc0_biases_gs_ms_ns_strides,
907  acc1_biases_gs_ms_gemm1ns_lengths,
908  acc1_biases_gs_ms_gemm1ns_strides,
909  a_element_op,
910  b_element_op,
911  c0de_element_op,
912  b1_element_op,
913  c1de_element_op);
914  }
915 
916  // polymorphic
917  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
918  {
919  return std::make_unique<Invoker>(Invoker{});
920  }
921 
922  // polymorphic
923  std::string GetTypeString() const override
924  {
925  auto str = std::stringstream();
926 
927  // clang-format off
928  str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
929  << "<"
930  << BlockSize << ", "
931  << MPerBlock << ", "
932  << NPerBlock << ", "
933  << KPerBlock << ", "
934  << AK1 << ", "
935  << BK1 << ", "
936  << MPerBlock << ", "
937  << Gemm1NPerBlock << ", "
938  << Gemm1KPerBlock << ", "
939  << B1K1 << ", "
940  << getGemmSpecializationString(GemmSpec) << ", "
941  << "ASpec" << getTensorSpecializationString(ASpec) << ", "
942  << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
943  << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
944  << "CSpec" << getTensorSpecializationString(CSpec) << ", "
945  << getMaskingSpecializationString(MaskingSpec) << ">";
946  // clang-format on
947 
948  return str.str();
949  }
950 };
951 
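// Editorial sketch (not part of this header): per batch, the op computes an
// attention-style fusion C1 = softmax(c0de_element_op(A * B0, D0s...)) * B1, with the
// output written through the CShuffle stage into a permuted G/M/O layout. Reference
// semantics for one batch (row-major, float; shapes A: M x K, B0: N x K, B1: O x N):
//
//   for(int m = 0; m < M; ++m) {
//       std::vector<float> acc0(N, 0.f);
//       for(int n = 0; n < N; ++n)
//           for(int k = 0; k < K; ++k)
//               acc0[n] += a[m * K + k] * b0[n * K + k]; // Gemm0 (+ element op / bias / mask)
//       float mx = *std::max_element(acc0.begin(), acc0.end());
//       float sum = 0.f;
//       for(float& x : acc0) { x = std::exp(x - mx); sum += x; }
//       for(float& x : acc0) x /= sum;                   // row softmax over Gemm0's N
//       for(int o = 0; o < O; ++o) {
//           float acc1 = 0.f;
//           for(int n = 0; n < N; ++n)
//               acc1 += acc0[n] * b1[o * N + n];         // Gemm1
//           c[m * O + o] = acc1;                         // written via CShuffle + permute
//       }
//   }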
952 } // namespace device
953 } // namespace tensor_operation
954 } // namespace ck