/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp Source File
device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
11 #include "ck/utility/env.hpp"
26 #include "ck/host_utility/io.hpp"
27 
29 
30 namespace ck {
31 namespace tensor_operation {
32 namespace device {
33 
34 namespace {
35 
36 /*
37  * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
38  *
39  * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
40  * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
41  * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
42  * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
43  * limitations.
44  *
45  * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
46  * returns the 2D index of the tile that it computes. \see
47  * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
48  *
49  * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
50  * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
51  * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
52  * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
53  * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
54  * pointer offset into \p ComputePtrOffsetOfStridedBatch.
55  *
56  * MaxGroupedGemmGroupsNum specifies, at compile time, how many GEMM argument sets the kernel
57  * receives. With this implementation we avoid copying the arguments to a workspace before kernel
58  * launch, even though the number of groups is a runtime parameter. If the number of groups is
59  * larger than MaxGroupedGemmGroupsNum, this kernel is launched repeatedly in a loop.
60  *
61  * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
62  * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
63  * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
64  *
65  */
// Grouped-conv backward-data kernel: each workgroup computes one E tile of one of up to
// MaxGroupedGemmGroupsNum GEMMs described by gemm_kernel_args.
// Grid layout, as decoded at the top of the body:
//   blockIdx.x - flat workgroup id over all GEMMs in gemm_kernel_args; the owning GEMM is found by
//                a binary search over the half-open ranges [BlockStart_, BlockEnd_).
//   blockIdx.y - convolution group index g (selects the per-group base pointer offsets).
//   blockIdx.z - n_idx * KBatch + k_idx (split-N chunk and split-K slice).
// Template flags:
//   CTranspose - the GEMM operands are swapped, so A/B offsets are swapped as well.
//   HasMainKBlockLoopInAllGemm / NoMainKBlockLoopInAllGemm - when either is true the main-loop
//     variant is fixed at compile time; otherwise each GEMM's HasMainKBlockLoop_ picks it.
// On non-gfx9 targets the body compiles to nothing (the arguments are only "used" via ignore).
66 template <typename GridwiseGemm,
67  typename ABDataType,
68  typename DsPointer,
69  typename EDataType,
70  index_t MaxGroupedGemmGroupsNum,
71  typename GemmArgs,
72  typename AElementwiseOp,
73  typename BElementwiseOp,
74  typename CDEElementwiseOp,
75  typename ComputePtrOffsetOfBatch,
76  typename ComputePtrOffsetOfN,
77  InMemoryDataOperationEnum OutElementOp,
78  bool HasMainKBlockLoopInAllGemm,
79  bool NoMainKBlockLoopInAllGemm,
80  bool CTranspose>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
84 #endif
85  kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
86  const ABDataType* __restrict__ p_a_grid,
87  const ABDataType* __restrict__ p_b_grid,
88  DsPointer p_ds_grid,
89  EDataType* __restrict__ p_e_grid,
90  const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
91  const index_t gemms_count,
92  const AElementwiseOp a_element_op,
93  const BElementwiseOp b_element_op,
94  const CDEElementwiseOp cde_element_op,
95  const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
96  const ComputePtrOffsetOfN compute_ptr_offset_of_n,
97  const index_t KBatch)
98 {
99 #if defined(__gfx9__)
100  // offset base pointer for each work-group
// Block indices are uniform across the wavefront; readfirstlane keeps them in scalar registers.
101  const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
102  const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
// blockIdx.z packs (n_idx, k_idx) as n_idx * KBatch + k_idx.
103  const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
104  const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
105 
// Per-group (g) offsets. With CTranspose the kernel's A operand is the conv's B tensor and vice
// versa, so the A/B batch offsets are swapped.
106  const long_index_t a_batch_offset =
107  CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
108  : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
109  const long_index_t b_batch_offset =
110  CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
111  : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
112  const long_index_t e_batch_offset =
113  amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
114 
115  const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
116 
// Per-N-chunk offsets: the conv A tensor (output image) is split along N, so its offset goes to
// whichever kernel operand holds it (A normally, B when CTranspose). The weight has no N offset.
117  const long_index_t a_n_offset =
118  CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
119  const long_index_t b_n_offset =
120  CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
121 
122  const long_index_t e_n_offset =
123  amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
124 
// Statically sized LDS scratch for GridwiseGemm::Run.
125  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
126 
127  DsPointer p_ds_grid_grp;
128 
129  static constexpr index_t NumDTensor = DsPointer::Size();
130 
// NOTE(review): Ds pointers receive only the per-group offset, not a per-N-chunk offset like
// A/E do. Confirm that this is intended when N is split across blockIdx.z and NumDTensor > 0.
131  static_for<0, NumDTensor, 1>{}(
132  [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
133 
// Binary search for the GEMM whose [BlockStart_, BlockEnd_) range contains block_args_id.
// NOTE(review): the loop only terminates if some range contains block_args_id; this relies on the
// host launching exactly the summed block count of the GEMMs in gemm_kernel_args.
134  index_t left = 0;
135  index_t right = gemms_count;
136  index_t group_id = index_t((left + right) / 2);
137  while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
138  block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
139  left <= right)
140  {
141  if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
142  {
143  right = group_id;
144  }
145  else
146  {
147  left = group_id;
148  }
149  group_id = index_t((left + right) / 2);
150  }
151 
// Compile-time main-loop choice when it is uniform across all GEMMs; otherwise dispatch at
// runtime on the selected GEMM's HasMainKBlockLoop_ flag.
152  if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
153  {
154  GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
155  p_a_grid + a_batch_offset + a_n_offset,
156  p_b_grid + b_batch_offset + b_n_offset,
157  p_ds_grid_grp,
158  p_e_grid + e_batch_offset + e_n_offset,
159  p_shared,
160  a_element_op,
161  b_element_op,
162  cde_element_op,
163  gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
164  gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
165  gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
166  gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
167  gemm_kernel_args[group_id].block_2_ctile_map_,
168  KBatch,
169  k_idx);
170  }
171  else
172  {
173  if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
174  {
175  GridwiseGemm::template Run<true, OutElementOp>(
176  p_a_grid + a_batch_offset + a_n_offset,
177  p_b_grid + b_batch_offset + b_n_offset,
178  p_ds_grid_grp,
179  p_e_grid + e_batch_offset + e_n_offset,
180  p_shared,
181  a_element_op,
182  b_element_op,
183  cde_element_op,
184  gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
185  gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
186  gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
187  gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
188  gemm_kernel_args[group_id].block_2_ctile_map_,
189  KBatch,
190  k_idx);
191  }
192  else
193  {
194  GridwiseGemm::template Run<false, OutElementOp>(
195  p_a_grid + a_batch_offset + a_n_offset,
196  p_b_grid + b_batch_offset + b_n_offset,
197  p_ds_grid_grp,
198  p_e_grid + e_batch_offset + e_n_offset,
199  p_shared,
200  a_element_op,
201  b_element_op,
202  cde_element_op,
203  gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
204  gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
205  gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
206  gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
207  gemm_kernel_args[group_id].block_2_ctile_map_,
208  KBatch,
209  k_idx);
210  }
211  }
212 #else
// Non-gfx9 targets: no device code; silence unused-parameter warnings.
213  ignore = p_a_grid;
214  ignore = p_b_grid;
215  ignore = p_ds_grid;
216  ignore = p_e_grid;
217  ignore = gemm_kernel_args;
218  ignore = gemms_count;
219  ignore = a_element_op;
220  ignore = b_element_op;
221  ignore = cde_element_op;
222  ignore = compute_ptr_offset_of_batch;
223  ignore = compute_ptr_offset_of_n;
224  ignore = KBatch;
225 #endif
226 }
227 
228 } // namespace
229 
230 // Conv backward data multiple D:
231 // input : output image A: [G, N, K, Ho, Wo]
232 // input : weight B: [G, K, C, Y, X],
233 // input : D0, D1, ... : [G, N, K, Ho, Wo]
234 // output : input image E: [G, N, C, Hi, Wi]
235 // C = a_op(A) * b_op(B)
236 // E = cde_op(C, D0, D1, ...)
237 template <index_t NDimSpatial,
238  typename ALayout, // output image
239  typename BLayout, // weight
240  typename DsLayout, // bias
241  typename ELayout, // input image
242  typename ADataType, // output image
243  typename BDataType, // weight
244  typename AccDataType,
245  typename CShuffleDataType,
246  typename DsDataType, // bias
247  typename EDataType, // input image
248  typename AElementwiseOp, // output image
249  typename BElementwiseOp, // weight
250  typename CDEElementwiseOp, // C, bias, and input image
251  ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
252  bool DoPadGemmM,
253  bool DoPadGemmN,
254  index_t NumGemmKPrefetchStage,
255  index_t BlockSize,
256  index_t MPerBlock,
257  index_t NPerBlock,
258  index_t KPerBlock,
259  index_t AK1,
260  index_t BK1,
261  index_t MPerXDL,
262  index_t NPerXDL,
263  index_t MXdlPerWave,
264  index_t NXdlPerWave,
265  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
266  typename ABlockTransferThreadClusterArrangeOrder,
267  typename ABlockTransferSrcAccessOrder,
268  index_t ABlockTransferSrcVectorDim,
269  index_t ABlockTransferSrcScalarPerVector,
270  index_t ABlockTransferDstScalarPerVector_AK1,
271  index_t ABlockLdsExtraM,
272  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
273  typename BBlockTransferThreadClusterArrangeOrder,
274  typename BBlockTransferSrcAccessOrder,
275  index_t BBlockTransferSrcVectorDim,
276  index_t BBlockTransferSrcScalarPerVector,
277  index_t BBlockTransferDstScalarPerVector_BK1,
278  index_t BBlockLdsExtraN,
279  index_t CShuffleMXdlPerWavePerShuffle,
280  index_t CShuffleNXdlPerWavePerShuffle,
281  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
282  index_t CDEBlockTransferScalarPerVector_NPerBlock,
284  typename AComputeType = ADataType,
285  typename BComputeType = AComputeType,
286  index_t MaxTransposeTransferInScalarPerVector = 1,
287  index_t MaxTransposeTransferOutScalarPerVector = 1>
289  : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
290  ALayout, // output image
291  BLayout, // weight
292  DsLayout, // bias
293  ELayout, // input image
294  ADataType, // output image
295  BDataType, // weight
296  DsDataType, // bias
297  EDataType, // input image
298  AElementwiseOp,
299  BElementwiseOp,
300  CDEElementwiseOp,
301  AComputeType,
302  BComputeType>
303 {
304  // TODO: Extend support for more spatial dimensions.
305  static_assert(NDimSpatial == 2 || NDimSpatial == 3,
306  "wrong! only implemented for 2D and 3D now");
307 
308  // MaxGroupedGemmGroupsNum specifies, at compile time, how many GEMM argument sets are passed to
309  // the kernel. With this implementation we avoid copying the arguments to a workspace before
310  // kernel launch, even though the number of groups is a runtime parameter. If the number of groups
311  // is larger than MaxGroupedGemmGroupsNum, the kernel is launched repeatedly in a loop.
312  static constexpr index_t MaxGroupedGemmGroupsNum =
313  ConvBackwardDataSpecialization ==
315  ? 1
316  : 32;
317 
319 
320  static constexpr index_t NumDTensor = DsDataType::Size();
322  static constexpr bool IsSplitKSupported =
323  (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) &&
324  std::is_same_v<remove_cvref_t<CDEElementwiseOp>, element_wise::PassThrough>;
325 
326  // TODO: Add support for different A and B data types.
327  using ABDataType = ADataType;
328 
329  static constexpr auto I0 = Number<0>{};
330  static constexpr auto I1 = Number<1>{};
331  static constexpr auto I2 = Number<2>{};
332  static constexpr auto I3 = Number<3>{};
333 
334  static constexpr bool isATensorColMajor =
335  (ConvBackwardDataSpecialization ==
337  (ABlockTransferSrcVectorDim == 1) &&
338  (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
339  is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
340 
341  static constexpr bool NeedTransposeKernel =
342  (isATensorColMajor == false) && (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
343  is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
344 
345  static constexpr bool CTranspose =
346  (NeedTransposeKernel == false) && (is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
347  is_same_v<ELayout, tensor_layout::convolution::NGCDHW>);
348 
350  is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
352  std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
354  ALayout>>;
356  is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
358  std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>() &&
361  BLayout>>;
363  is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
365  std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
367  ELayout>>;
368 
370  ConvBackwardDataSpecialization,
371  AK1,
372  BK1,
373  MPerBlock,
374  NPerBlock,
375  KPerBlock,
376  DoPadGemmM,
377  DoPadGemmN,
381  true, /*SplitConvN*/
382  ABDataType,
383  EDataType,
384  1,
385  index_t,
386  CTranspose>;
387 
388  static auto
390  {
391  const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
392 
393  const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
394 
395  const auto ds_grid_desc_m_n = generate_tuple(
396  [&](auto i) {
397  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
398  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
399  using ConvToGemmBwdDataTransformD =
400  TransformConvBwdDataToGemm_v1<NDimSpatial,
401  ConvBackwardDataSpecialization,
402  AK1,
403  BK1,
404  MPerBlock,
405  NPerBlock,
406  KPerBlock,
407  DoPadGemmM,
408  DoPadGemmN,
410  BLayout,
411  DLayout,
412  true, /*SplitConvN*/
413  ABDataType,
414  DDataType,
415  1, /*index_t NumGroupsToMerge = 1,*/
416  index_t, /* typename IndexType = */
417  CTranspose>;
418  return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
419  },
421 
422  const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
423  if constexpr(CTranspose)
424  {
425  return make_tuple(
426  b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n);
427  }
428  else
429  {
430  return make_tuple(
431  a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
432  }
433  }
434 
435 // GridwiseGemm
436 #define GridwiseGemmMultiDTemplateParams \
437  ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
438  AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
439  MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
440  ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
441  ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
442  ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
443  ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
444  BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
445  BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
446  BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
447  CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
448  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
449  CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
450 
451 #define GridwiseGemmCTransposeTemplateParameters \
452  ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
453  BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
454  NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \
455  BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
456  BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
457  BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
458  BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
459  ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
460  ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
461  ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
462  CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
463  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
464  CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
465 
468  CTranspose,
470  GridwiseGemm>;
471 
472  template <typename EGridDesc_M_N>
473  static auto
475  {
476  return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
477  e_grid_desc_m_n);
478  }
479 
// Collapses a (K0, M, K1) grid descriptor into an (M, K) view. Dimension 1 (M) is passed through,
// and dimensions 0 and 1 of the result come from M and the pair (K0, K1). The combining transform
// for K0/K1 is on a line elided from this listing; it is presumably a merge, so K = K0 * K1
// (TODO confirm). The constructor reads GemmK from dimension 1 of the result.
480  template <typename Desc_K0_M_K1>
481  static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
482  {
483  const auto grid_desc_m_k = transform_tensor_descriptor(
484  desc_k0_m_k1,
485  make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)),
487  make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))),
490 
491  return grid_desc_m_k;
492  }
493 
494  // desc
497 
502 
505 
507  decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
508  DsGridDesc_M_N{}));
511 
512  // block-to-e-tile map
514  decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}));
515 
517 
// Per-GEMM argument bundle passed by value to the kernel (inside a std::array of
// MaxGroupedGemmGroupsNum entries). BlockStart_/BlockEnd_ give the half-open range of flat
// workgroup ids (blockIdx.x) owned by this GEMM. HasMainKBlockLoop_ selects the Run<> variant when
// the main-loop choice is not uniform across GEMMs.
// NOTE(review): the member declaration lines are elided from this listing.
518  struct GemmArgs
519  {
520  GemmArgs() = default;
521  GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
522  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
524  ds_grid_desc_mblock_mperblock_nblock_nperblock,
526  e_grid_desc_mblock_mperblock_nblock_nperblock,
527  GroupedGemmBlock2ETileMap block_2_ctile_map,
528  index_t BlockStart,
529  index_t BlockEnd,
530  bool HasMainKBlockLoop)
531  : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
532  b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
533 
535  ds_grid_desc_mblock_mperblock_nblock_nperblock),
536 
538  e_grid_desc_mblock_mperblock_nblock_nperblock),
539 
540  // block-to-e-tile map
541  block_2_ctile_map_(block_2_ctile_map),
542  // [BlockStart_, BlockEnd_): workgroup ids assigned to this GEMM
542  BlockStart_(BlockStart),
543  BlockEnd_(BlockEnd),
544  HasMainKBlockLoop_(HasMainKBlockLoop)
545 
546  {
547  }
548  // tensor descriptors for block/thread-wise copy
554 
555  // block-to-e-tile map
559  };
562 
563  static constexpr index_t ClusterLengthMPerBlock =
564  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
565  static constexpr index_t ClusterLengthNPerBlock =
566  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
567 
568  static constexpr auto conv_ngchw_to_nhwgc_transformer =
570  BLayout,
571  ALayout,
572  NDimSpatial,
573  NPerBlock / ClusterLengthNPerBlock,
574  MPerBlock / ClusterLengthMPerBlock>{};
575 
577  std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector);
579  std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector);
580 
583  .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
586  .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
589  .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
592  .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
593 
595 
604  NPerBlock,
605  MPerBlock,
606  NPerBlock / ClusterLengthNPerBlock,
607  MPerBlock / ClusterLengthMPerBlock,
611  I1,
612  I0>;
613 
622  MPerBlock,
623  NPerBlock,
624  MPerBlock / ClusterLengthMPerBlock,
625  NPerBlock / ClusterLengthNPerBlock,
627  Sequence<1>,
629  I0,
630  I1>;
631 
640  NPerBlock,
641  MPerBlock,
642  NPerBlock / ClusterLengthNPerBlock,
643  MPerBlock / ClusterLengthMPerBlock,
647  I0,
648  I1>;
649  // Argument
650  struct Argument : public BaseArgument
651  {
// Builds the kernel arguments for conv backward-data as a set of GEMMs.
// - One GEMM is emitted per valid (i_ztilde, i_ytilde, i_xtilde) slice.
// - GEMMs are grouped into chunks of at most MaxGroupedGemmGroupsNum, with one grid size recorded
//   per chunk in gemms_grid_size_.
// - The constructor also derives per-group and per-N-chunk pointer strides, the zero-out
//   requirement, and (when NeedTransposeKernel) the NGCHW <-> NHWGC transpose descriptors.
// split_k (default 1) is stored as k_batch_ and forwarded to the A/B/E transform.
652  Argument(const void* p_a, // output image
653  const void* p_b, // weight
654  const std::array<const void*, NumDTensor>& p_ds, // bias
655  void* p_e, // input image
656  const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
657  const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
658  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
659  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
660  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
661  ds_g_n_c_wis_lengths,
662  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
663  ds_g_n_c_wis_strides,
664  const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
665  const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides,
666  const std::array<index_t, NDimSpatial>& conv_filter_strides,
667  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
668  const std::array<index_t, NDimSpatial>& input_left_pads,
669  const std::array<index_t, NDimSpatial>& input_right_pads,
670  const AElementwiseOp& a_element_op,
671  const BElementwiseOp& b_element_op,
672  const CDEElementwiseOp& cde_element_op,
673  ck::index_t split_k = 1)
674  : p_a_grid_{static_cast<const ADataType*>(p_a)},
675  p_b_grid_{static_cast<const BDataType*>(p_b)},
676  p_ds_grid_{},
677  p_e_grid_{static_cast<EDataType*>(p_e)},
678  num_group_{a_g_n_k_wos_lengths[0]},
679  a_element_op_{a_element_op},
680  b_element_op_{b_element_op},
681  cde_element_op_{cde_element_op},
682  a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
683  b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
684  e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
685  conv_filter_strides_{conv_filter_strides},
686  input_left_pads_{input_left_pads},
687  input_right_pads_{input_right_pads},
688  k_batch_{split_k}
689  {
690  bool image_covered_dilation = true;
691  bool image_covered_strides = true;
692  for(index_t d = 0; d < NDimSpatial; d++)
693  {
694  // If both dilation and stride differ from 1, some input-image positions may receive no contribution
695  image_covered_dilation &=
696  conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
697  // If the stride exceeds the filter window size, some input-image positions are never covered
698  image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
699  }
// Detect whether any D tensor aliases the output buffer E (same base pointer).
700  bool if_d_is_output_mem = false;
701  const void* out_mem_void = static_cast<const void*>(p_e);
702  static_for<0, NumDTensor, 1>{}([&](auto i) {
703  if(p_ds[i] == out_mem_void)
704  {
705  if_d_is_output_mem = true;
706  }
707  });
708 
// E must be zeroed before accumulation when split-K is used or the filter does not cover every
// input position.
709  bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
710 
711  // Temporary workaround until the conditions above are proven/fixed: the value computed above
// is discarded, and zero-out is requested whenever no D tensor aliases the output buffer.
712  bwd_needs_zero_out = !if_d_is_output_mem;
// Total size of E in bytes (product of all G, N, C and spatial lengths times sizeof(EDataType)).
// The assignment target is on a line elided from this listing.
714  ck::accumulate_n<long_index_t>(
715  e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
716  sizeof(EDataType);
717 
// Strides used by the GEMM transforms: each is chosen via a ternary whose condition and
// true-branch call are on elided lines (presumably the layout-transposed strides when
// NeedTransposeKernel; TODO confirm). Otherwise the caller's strides are used unchanged.
718  std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
720  a_g_n_k_wos_lengths, a_g_n_k_wos_strides)
721  : a_g_n_k_wos_strides;
722  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
724  b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
725  : b_g_k_c_xs_strides;
726  std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
728  e_g_n_c_wis_lengths, e_g_n_c_wis_strides)
729  : e_g_n_c_wis_strides;
730 
731  // populate Ds pointer
732  static_for<0, NumDTensor, 1>{}([&](auto i) {
733  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
734 
735  p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
736  });
737 
// Per-group (G) stride of each D tensor.
738  static_for<0, NumDTensor, 1>{}([&](auto i) {
739  compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
740  });
741 
// Index of the first spatial dimension in the G,N,C/K,spatial... length arrays (3 leading
// non-spatial dims). Some right-hand sides are on elided lines.
742  static constexpr auto NonSpatialDimsNum = Number<3>{};
743 
744  static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
745  static constexpr auto HIdx =
747  static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
749 
750  static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
751  static constexpr auto YIdx =
753  static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
755 
756  // problem definition
757  const index_t Z = b_g_k_c_xs_lengths[ZIdx];
758  const index_t Y = b_g_k_c_xs_lengths[YIdx];
759  const index_t X = b_g_k_c_xs_lengths[XIdx];
760 
761  const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
762  const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
763  const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
764 
765  const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
766  const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
767  const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
768 
// Number of independent sub-GEMMs per spatial dim: stride / gcd(stride, dilation).
769  const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
770  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
771  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
772 
773  const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
774  const auto YTilde = ConvStrideH / GcdStrideDilationH;
775  const auto XTilde = ConvStrideW / GcdStrideDilationW;
776 
777  index_t grid_size = 0;
778  // Allocate place for sets of gemms
// One entry per chunk of at most MaxGroupedGemmGroupsNum GEMMs (upper bound; trimmed below).
779  gemm_kernel_args_.resize(
780  math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum));
781 
782  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
783  {
784  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
785  {
786  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
787  {
788  // check slice is valid
// A slice with no filter taps contributes nothing and gets no GEMM.
789  const auto ZDotSlice =
790  NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
791  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
792  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
793 
794  if(YDotSlice * XDotSlice * ZDotSlice <= 0)
795  {
796  continue;
797  }
798 
799  std::array<index_t, NDimSpatial> tildes;
800  if constexpr(NDimSpatial == 2)
801  {
802  tildes = {i_ytilde, i_xtilde};
803  }
804  else if constexpr(NDimSpatial == 3)
805  {
806  tildes = {i_ztilde, i_ytilde, i_xtilde};
807  }
808  else
809  {
810  throw std::runtime_error("wrong! only implemented for 2D and 3D now");
811  }
812 
813  ConvToGemmBwdDataTransform conv_to_gemm_transform_{
814  a_g_n_k_wos_lengths,
815  a_g_n_k_wos_strides_transposed,
816  b_g_k_c_xs_lengths,
817  b_g_k_c_xs_strides_transposed,
818  e_g_n_c_wis_lengths,
819  e_g_n_c_wis_strides_transposed,
820  conv_filter_strides,
821  conv_filter_dilations,
822  input_left_pads,
823  input_right_pads,
824  tildes,
825  k_batch_};
826 
// N chunk size chosen by the transform (SplitConvN); it is overwritten on every iteration, so
// it is assumed to be the same for all slices.
827  conv_N_per_block_ = conv_to_gemm_transform_.N_;
828 
// With CTranspose, the kernel's A operand is the conv's weight (B) and vice versa.
829  const auto a_grid_desc_ak0_m_ak1 = [&]() {
830  if constexpr(CTranspose)
831  {
832  return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
833  }
834  else
835  {
836  return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
837  }
838  }();
839 
840  const auto b_grid_desc_bk0_n_bk1 = [&]() {
841  if constexpr(CTranspose)
842  {
843  return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
844  }
845  else
846  {
847  return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
848  }
849  }();
850  DsGridDesc_M_N ds_grid_desc_m_n;
851 
852  // populate Ds desc
// NOTE(review): D descriptors use the caller's ds strides, while E uses
// e_g_n_c_wis_strides_transposed, and k_batch_ is not forwarded. Confirm this is correct
// when NeedTransposeKernel is set.
853  static_for<0, NumDTensor, 1>{}([&](auto i) {
854  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
855  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
856  using ConvToGemmBwdDataTransformD =
857  TransformConvBwdDataToGemm_v1<NDimSpatial,
858  ConvBackwardDataSpecialization,
859  AK1,
860  BK1,
861  MPerBlock,
862  NPerBlock,
863  KPerBlock,
864  DoPadGemmM,
865  DoPadGemmN,
868  DLayout,
869  true, /*SplitConvN*/
870  ABDataType,
871  DDataType,
872  1,
873  index_t,
874  CTranspose>;
875  ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
876  a_g_n_k_wos_lengths,
877  a_g_n_k_wos_strides_transposed,
878  b_g_k_c_xs_lengths,
879  b_g_k_c_xs_strides_transposed,
880  ds_g_n_c_wis_lengths[i],
881  ds_g_n_c_wis_strides[i],
882  conv_filter_strides,
883  conv_filter_dilations,
884  input_left_pads,
885  input_right_pads,
886  tildes};
887 
888  ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
889  });
890 
891  const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
892 
893  // desc for problem definition
894  const auto a_grid_desc_m_k =
895  transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
896  const auto b_grid_desc_n_k =
897  transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
898 
899  a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
900  b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
901  ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
902  e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
903 
// This GEMM owns workgroups [BlockStart, BlockEnd) within the current chunk's launch.
904  const index_t grid_size_grp = Block2ETileMap::CalculateGridSize(
905  e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1));
906 
907  const index_t BlockStart = grid_size;
908  const index_t BlockEnd = grid_size + grid_size_grp;
909 
910  grid_size += grid_size_grp;
911 
912  // block-to-e-tile map
913  const auto block_2_etile_map =
914  GroupedGemmBlock2ETileMap(Block2ETileMap(e_grid_desc_m_n.GetLength(I0),
915  e_grid_desc_m_n.GetLength(I1)),
916  BlockStart);
917 
918  const auto GemmK = a_grid_desc_m_k.GetLength(I1);
919  const bool HasMainKBlockLoop =
920  GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_);
921 
// The assignment target (an entry of gemm_kernel_args_) and the chunk-boundary condition
// before the push_back below are on elided lines. Presumably a chunk is closed when
// gemms_count_ reaches a multiple of MaxGroupedGemmGroupsNum (TODO confirm).
925  GemmArgs{a_grid_desc_ak0_m_ak1,
926  b_grid_desc_bk0_n_bk1,
927  GridwiseGemmCTranspose::
928  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
929  ds_grid_desc_m_n),
931  e_grid_desc_m_n),
932  block_2_etile_map,
933  BlockStart,
934  BlockEnd,
935  HasMainKBlockLoop};
936  gemms_count_++;
938  {
939  gemms_grid_size_.push_back(grid_size);
940  grid_size = 0;
941  }
942  }
943  }
944  }
// Trim gemm_kernel_args_ (size expression elided) and record the last chunk's grid size.
945  gemm_kernel_args_.resize(
947  gemms_grid_size_.push_back(grid_size);
948 
949  // A/B/Ds/E Batch Stride
950  compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
951  compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
952  compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
953 
// Per-N-chunk strides: only A and E are split along N here (no Ds N stride is set).
954  compute_ptr_offset_of_n_.BatchStrideA_ =
955  a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
956  compute_ptr_offset_of_n_.BatchStrideE_ =
957  e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_;
958 
960 
// Descriptors for the separate NGCHW <-> NHWGC layout-transpose kernels. The assignment targets
// are on elided lines.
961  if constexpr(NeedTransposeKernel)
962  {
963  // Use not modified base strides
965  conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
966  a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
968  conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
969  a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
970 
972  conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
973  b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
975  conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
976  b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
977 
979  conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
980  e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
982  conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
983  e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
984 
986  a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
988  b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
990  e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
991 
992  compute_ptr_offset_of_workspace_n_.BatchStrideA_ =
993  a_g_n_k_wos_strides[1] * conv_N_per_block_;
994  compute_ptr_offset_of_workspace_n_.BatchStrideE_ =
995  e_g_n_c_wis_strides[1] * conv_N_per_block_;
996  }
997  }
998 
999  std::size_t GetWorkspaceATensorSizeBytes() const
1000  {
1001  if constexpr(NeedTransposeKernel)
1002  {
1003  const long_index_t a_acum = ck::accumulate_n<long_index_t>(
1004  a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
1005  // Align to 128B
1006  return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
1007  }
1008  else
1009  {
1010  return 0;
1011  }
1012  }
1013 
1014  std::size_t GetWorkspaceBTensorSizeBytes() const
1015  {
1016  if constexpr(NeedTransposeKernel)
1017  {
1018  const long_index_t b_acum = ck::accumulate_n<long_index_t>(
1019  b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
1020  // Align to 128B
1021  return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
1022  }
1023  else
1024  {
1025  return 0;
1026  }
1027  }
1028 
1029  std::size_t GetWorkspaceETensorSizeBytes() const
1030  {
1031  if constexpr(NeedTransposeKernel)
1032  {
1033  const long_index_t e_accum = ck::accumulate_n<long_index_t>(
1034  e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
1035  return sizeof(EDataType) * e_accum;
1036  }
1037  else
1038  {
1039  return 0;
1040  }
1041  }
1042 
1043  std::size_t GetWorkspaceSizeBytes() const
1044  {
1047  }
1048 
1049  void Print() const
1050  {
1051  for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
1052  {
1053  std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i]
1054  << std::endl;
1055 
1056  std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i]
1057  << std::endl;
1058 
1059  static_for<0, NumDTensor, 1>{}([&](auto j) {
1060  std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
1061  << ds_grid_desc_m_n_container_[i][j] << std::endl;
1062  });
1063 
1064  std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
1065  << e_grid_desc_m_n_container_[i] << std::endl;
1066  }
1067  }
1068 
1069  // pointers
1070  const ADataType* p_a_grid_;
1071  const BDataType* p_b_grid_;
1073  EDataType* p_e_grid_;
1074 
1075  // tensor descriptor for problem definition
1078  std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
1079  std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
1080  std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
1081  std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
1082 
1083  // block-to-e-tile map
1087 
1092 
1093  // for computing batch offset
1094  ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
1095  ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
1096  ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_workspace_n_;
1097 
1098  // element-wise op
1099  AElementwiseOp a_element_op_;
1100  BElementwiseOp b_element_op_;
1101  CDEElementwiseOp cde_element_op_;
1102 
1103  std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
1104  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
1105  std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
1106  std::array<index_t, NDimSpatial> conv_filter_strides_;
1107  std::array<index_t, NDimSpatial> input_left_pads_;
1108  std::array<index_t, NDimSpatial> input_right_pads_;
1109 
1112  std::vector<index_t> gemms_grid_size_;
1114  std::vector<std::array<GemmArgs, MaxGroupedGemmGroupsNum>> gemm_kernel_args_;
1115 
1118  };
1119 
1120  // Invoker
1121  struct Invoker : public BaseInvoker
1122  {
1124 
1125  template <InMemoryDataOperationEnum ElementOp>
1126  float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1127  {
1128  float ave_time = 0;
1129 
1130  const index_t gdy = arg.num_group_;
1131  const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_;
1132 
1133  const ADataType* p_a_grid = arg.p_a_grid_;
1134  const BDataType* p_b_grid = arg.p_b_grid_;
1135  EDataType* p_e_grid = arg.p_e_grid_;
1136  if constexpr(NeedTransposeKernel)
1137  {
1138  if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
1139  is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
1140  {
1141  p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
1142  p_e_grid =
1143  type_convert<EDataType*>(arg.p_workspace_) +
1145  sizeof(EDataType);
1146  }
1147 
1148  if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
1149  is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
1150  {
1151  p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
1152  arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1153  }
1154  }
1155  for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
1156  gemm_set_id++)
1157  {
1158  const index_t gdx = arg.gemms_grid_size_[gemm_set_id];
1159  const index_t gemms_count_for_set =
1160  gemm_set_id == arg.gemm_kernel_args_.size() - 1
1161  ? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id
1163  const std::array<GemmArgs, MaxGroupedGemmGroupsNum>& gemm_kernel_args =
1164  arg.gemm_kernel_args_[gemm_set_id];
1165 
1166  const auto clear_workspace = [&]() {
1167  if(arg.bwd_needs_zero_out && gemm_set_id == 0)
1168  {
1169  hip_check_error(hipMemsetAsync(
1170  p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
1171  }
1172  };
1173 
1174  bool has_loop_in_all_gemm = true;
1175  bool no_loop_in_all_gemm = true;
1176  for(auto i = 0; i < gemms_count_for_set; i++)
1177  {
1178  has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_;
1179  no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_;
1180  }
1181 
1182  auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) {
1183  constexpr bool has_main_loop = has_main_k_block_loop.value;
1184  constexpr bool no_main_loop = no_main_k_block_loop.value;
1185  if constexpr(CTranspose)
1186  {
1187  const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
1189  ADataType, // TODO: distiguish A/B datatype
1190  typename GridwiseGemm::DsGridPointer,
1191  EDataType,
1193  GemmArgs,
1194  BElementwiseOp,
1195  AElementwiseOp,
1196  CDEElementwiseOp,
1197  ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
1198  ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1199  ElementOp,
1200  has_main_loop,
1201  no_main_loop,
1202  CTranspose>;
1203 
1205  stream_config,
1206  clear_workspace,
1207  kernel,
1208  dim3(gdx, gdy, gdz),
1209  dim3(BlockSize),
1210  0,
1211  p_b_grid,
1212  p_a_grid,
1213  arg.p_ds_grid_,
1214  p_e_grid,
1215  gemm_kernel_args,
1216  gemms_count_for_set,
1217  arg.b_element_op_,
1218  arg.a_element_op_,
1219  arg.cde_element_op_,
1222  arg.k_batch_);
1223  }
1224  else
1225  {
1226  const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
1227  GridwiseGemm,
1228  ADataType, // TODO: distiguish A/B datatype
1229  typename GridwiseGemm::DsGridPointer,
1230  EDataType,
1232  GemmArgs,
1233  AElementwiseOp,
1234  BElementwiseOp,
1235  CDEElementwiseOp,
1236  ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
1237  ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1238  ElementOp,
1239  has_main_loop,
1240  no_main_loop,
1241  CTranspose>;
1242 
1244  stream_config,
1245  clear_workspace,
1246  kernel,
1247  dim3(gdx, gdy, gdz),
1248  dim3(BlockSize),
1249  0,
1250  p_a_grid,
1251  p_b_grid,
1252  arg.p_ds_grid_,
1253  p_e_grid,
1254  gemm_kernel_args,
1255  gemms_count_for_set,
1256  arg.a_element_op_,
1257  arg.b_element_op_,
1258  arg.cde_element_op_,
1261  arg.k_batch_);
1262  }
1263  };
1264  if(has_loop_in_all_gemm)
1265  {
1266  ave_time += launch_kernel(integral_constant<bool, true>{},
1267  integral_constant<bool, false>{});
1268  }
1269  else if(no_loop_in_all_gemm)
1270  {
1271  ave_time += launch_kernel(integral_constant<bool, false>{},
1272  integral_constant<bool, true>{});
1273  }
1274  else
1275  {
1276  ave_time += launch_kernel(integral_constant<bool, false>{},
1277  integral_constant<bool, false>{});
1278  }
1279  }
1280 
1281  return ave_time;
1282  }
1283 
1284  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1285  {
1286  float ave_time = 0;
1287 
1288  if(stream_config.log_level_ > 0)
1289  {
1290  arg.Print();
1291  }
1292 
1293  // Transpose from NGKHW to NHWGK
1294  if constexpr(NeedTransposeKernel)
1295  {
1296  EDataType* p_e_in_grid =
1297  type_convert<EDataType*>(arg.p_workspace_) +
1299  sizeof(EDataType);
1300 
1301  const auto clear_workspace = [&]() {
1302  hip_check_error(hipMemsetAsync(p_e_in_grid,
1303  0,
1305  stream_config.stream_id_));
1306  };
1307 
1308  const index_t a_grid_size =
1309  arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
1310  arg.a_in_transpose_desc_) *
1312  const index_t b_grid_size =
1313  (is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
1314  is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
1315  ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
1317  : 0; // Dont run transpose B if not needed
1318 
1319  ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
1320  BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
1321  arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1322 
1323  auto kernel_transpose =
1337  I1,
1338  I1,
1339  I1,
1340  I1>;
1341 
1343  stream_config,
1344  clear_workspace,
1345  kernel_transpose,
1346  dim3(a_grid_size + b_grid_size),
1347  dim3(ElementwiseBlocksize),
1348  0,
1353  make_tuple(arg.p_a_grid_),
1354  make_tuple(arg.p_b_grid_),
1355  make_tuple(p_a_out_grid),
1356  make_tuple(p_b_out_grid),
1360  a_grid_size,
1362  I1, // B is not splited per N
1363  std::array<index_t, I1>{
1364  static_cast<index_t>(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)},
1365  std::array<index_t, I1>{0},
1366  std::array<index_t, I1>{
1367  static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)},
1368  std::array<index_t, I1>{0});
1369  }
1370  if(arg.k_batch_ > 1)
1371  {
1372  if constexpr(IsSplitKSupported)
1373  {
1374  ave_time +=
1375  RunMultiDGemm<InMemoryDataOperationEnum::AtomicAdd>(arg, stream_config);
1376  }
1377  }
1378  else
1379  {
1380  ave_time += RunMultiDGemm<InMemoryDataOperationEnum::Set>(arg, stream_config);
1381  }
1382 
1383  // Transpose from NHWGC to NGCHW
1384  if constexpr(NeedTransposeKernel)
1385  {
1386  const index_t grid_size =
1387  arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
1388  arg.e_in_transpose_desc_) *
1390 
1391  const EDataType* p_e_in_grid =
1392  type_convert<EDataType*>(arg.p_workspace_) +
1394  sizeof(EDataType);
1395 
1396  EDataType* p_e_out_grid = arg.p_e_grid_;
1397 
1398  auto kernel_transpose =
1405  element_wise::PassThrough,
1406  I1,
1407  I1>;
1408 
1409  ave_time += launch_and_time_kernel(
1410  stream_config,
1411  kernel_transpose,
1412  dim3(grid_size),
1413  dim3(ElementwiseBlocksize),
1414  0,
1417  make_tuple(p_e_in_grid),
1418  make_tuple(p_e_out_grid),
1420  element_wise::PassThrough{},
1422  std::array<index_t, I1>{
1423  static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideE_)},
1424  std::array<index_t, I1>{static_cast<index_t>(
1425  arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)});
1426  }
1427 
1428  return ave_time;
1429  }
1430 
1431  float Run(const BaseArgument* p_arg,
1432  const StreamConfig& stream_config = StreamConfig{}) override
1433  {
1434  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1435  }
1436  };
1437 
1438  static bool IsSupportedArgument(const Argument& arg)
1439  {
1440  if(!ck::is_xdl_supported())
1441  {
1442  return false;
1443  }
1444 
1445  if(!is_bf16_atomic_supported() && std::is_same_v<EDataType, ck::bhalf_t> &&
1446  arg.k_batch_ > 1)
1447  {
1448  return false;
1449  }
1450 
1451  if constexpr(!IsSplitKSupported)
1452  {
1453  if(arg.k_batch_ != 1)
1454  {
1455  return false;
1456  }
1457  }
1458 
1459  const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
1460  const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
1461  const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
1462  const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1463  arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1464  const index_t input_spatial_acum = ck::accumulate_n<index_t>(
1465  arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1466  // Specifialization
1467  if constexpr(ConvBackwardDataSpecialization ==
1469  {
1470  // check if it's 1x1, stride=1 pad = 0 conv
1471  for(int i = 0; i < NDimSpatial; i++)
1472  {
1473  if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1474  arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1475  {
1476  return false;
1477  }
1478  }
1479  }
1480 
1481  // vector load for A matrix from global memory to LDS
1482  if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
1483  is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
1484  is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
1485  is_same_v<ALayout, tensor_layout::convolution::NDHWGK> || NeedTransposeKernel)
1486  {
1487  if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
1488  {
1489  return false;
1490  }
1491  }
1492  else if(is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
1493  is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
1494  {
1495  static_assert(NeedTransposeKernel == false);
1496 
1497  if constexpr(ABlockTransferSrcScalarPerVector != 1)
1498  {
1499  if(ABlockTransferSrcVectorDim != 1)
1500  {
1501  return false;
1502  }
1503  if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
1504  {
1505  return false;
1506  }
1507  }
1508  }
1509  else
1510  {
1511  return false;
1512  }
1513 
1514  // vector load for B matrix from global memory to LDS
1515  if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
1516  is_same_v<BLayout, tensor_layout::convolution::GKZYXC> ||
1517  is_same_v<BLayout, tensor_layout::convolution::GKCYX> ||
1518  is_same_v<BLayout, tensor_layout::convolution::GKCZYX>)
1519  {
1520  if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
1521  {
1522  return false;
1523  }
1524  }
1525  else
1526  {
1527  return false;
1528  }
1529 
1530  // vector store for Ds
1531  bool ds_valid = true;
1532 
1533  static_for<0, NumDTensor, 1>{}([&](auto i) {
1534  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
1535 
1536  if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
1537  is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
1538  is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
1539  is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
1540  is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
1541  is_same_v<DLayout, tensor_layout::convolution::GC> ||
1542  is_same_v<DLayout, tensor_layout::convolution::G_C>)
1543  {
1544  if(CTranspose == false)
1545  {
1546  // vector load D matrix from global memory
1547  if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1548  {
1549  ds_valid = false;
1550  }
1551  }
1552  else
1553  {
1554  if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1555  {
1556  ds_valid = false;
1557  }
1558  }
1559  }
1560  else
1561  {
1562  ds_valid = false;
1563  }
1564  });
1565 
1566  if(!ds_valid)
1567  {
1568  return false;
1569  }
1570 
1571  // vector store for E
1572  if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
1573  is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
1574  is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
1575  is_same_v<ELayout, tensor_layout::convolution::NDHWGC> ||
1576  is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
1577  is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
1578  {
1579  if(CTranspose == false)
1580  {
1581  // vector store C matrix into global memory
1582  if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1583  {
1584  return false;
1585  }
1586  }
1587  else
1588  {
1589  if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1590  {
1591  return false;
1592  }
1593  }
1594  }
1595  else
1596  {
1597  return false;
1598  }
1599 
1600  // Gridwise GEMM size
1601  for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
1602  {
1603  if(!GridwiseGemmCTranspose::CheckValidity(
1609  .block_2_ctile_map_,
1610  arg.k_batch_))
1611  {
1612  return false;
1613  }
1614  }
1615 
1616  if constexpr(NeedTransposeKernel)
1617  {
1618  if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1619  {
1620  return false;
1621  }
1622 
1623  if((ConvG * ConvK) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1624  {
1625  return false;
1626  }
1627 
1628  const index_t a_spatial_acum = ck::accumulate_n<index_t>(
1629  arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1630  const index_t e_spatial_acum = ck::accumulate_n<index_t>(
1631  arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1632 
1633  if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0)
1634  {
1635  return false;
1636  }
1637 
1638  if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0)
1639  {
1640  return false;
1641  }
1642 
1643  if(!arg.p_workspace_)
1644  {
1645  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1646  {
1647  std::cout
1648  << "Warning: Workspace for "
1649  "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not "
1650  "allocated, use SetWorkSpacePointer."
1651  << std::endl;
1652  }
1653  return false;
1654  }
1655  }
1656 
1657  return true;
1658  }
1659 
1660  bool IsSupportedArgument(const BaseArgument* p_arg) override
1661  {
1662  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1663  }
1664 
1665  static auto
1666  MakeArgument(const void* p_a, // output image
1667  const void* p_b, // weight
1668  const std::array<const void*, NumDTensor>& p_ds, // bias
1669  void* p_e, // input image
1670  const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
1671  const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
1672  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
1673  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
1674  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1675  ds_g_n_c_wis_lengths, // bias
1676  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1677  ds_g_n_c_wis_strides, // bias
1678  const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
1679  const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
1680  const std::array<index_t, NDimSpatial>& conv_filter_strides,
1681  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1682  const std::array<index_t, NDimSpatial>& input_left_pads,
1683  const std::array<index_t, NDimSpatial>& input_right_pads,
1684  const AElementwiseOp& a_element_op,
1685  const BElementwiseOp& b_element_op,
1686  const CDEElementwiseOp& cde_element_op,
1687  const ck::index_t split_k = 1)
1688  {
1689  return Argument{p_a,
1690  p_b,
1691  p_ds,
1692  p_e,
1693  a_g_n_k_wos_lengths,
1694  a_g_n_k_wos_strides,
1695  b_g_k_c_xs_lengths,
1696  b_g_k_c_xs_strides,
1697  ds_g_n_c_wis_lengths,
1698  ds_g_n_c_wis_strides,
1699  e_g_n_c_wis_lengths,
1700  e_g_n_c_wis_strides,
1701  conv_filter_strides,
1702  conv_filter_dilations,
1703  input_left_pads,
1704  input_right_pads,
1705  a_element_op,
1706  b_element_op,
1707  cde_element_op,
1708  split_k};
1709  }
1710 
1711  static auto MakeInvoker() { return Invoker{}; }
1712 
1713  std::unique_ptr<BaseArgument> MakeArgumentPointer(
1714  const void* p_a, // output image
1715  const void* p_b, // weight
1716  const std::array<const void*, NumDTensor>& p_ds, // bias
1717  void* p_e, // input image
1718  const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
1719  const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
1720  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
1721  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
1722  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1723  ds_g_n_c_wis_lengths, // bias
1724  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1725  ds_g_n_c_wis_strides, // bias
1726  const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
1727  const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
1728  const std::array<index_t, NDimSpatial>& conv_filter_strides,
1729  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1730  const std::array<index_t, NDimSpatial>& input_left_pads,
1731  const std::array<index_t, NDimSpatial>& input_right_pads,
1732  const AElementwiseOp& a_element_op,
1733  const BElementwiseOp& b_element_op,
1734  const CDEElementwiseOp& cde_element_op,
1735  const ck::index_t split_k = 1) override
1736  {
1737  return std::make_unique<Argument>(p_a,
1738  p_b,
1739  p_ds,
1740  p_e,
1741  a_g_n_k_wos_lengths,
1742  a_g_n_k_wos_strides,
1743  b_g_k_c_xs_lengths,
1744  b_g_k_c_xs_strides,
1745  ds_g_n_c_wis_lengths,
1746  ds_g_n_c_wis_strides,
1747  e_g_n_c_wis_lengths,
1748  e_g_n_c_wis_strides,
1749  conv_filter_strides,
1750  conv_filter_dilations,
1751  input_left_pads,
1752  input_right_pads,
1753  a_element_op,
1754  b_element_op,
1755  cde_element_op,
1756  split_k);
1757  }
1758 
1759  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1760  {
1761  return std::make_unique<Invoker>(Invoker{});
1762  }
1763 
1764  std::string GetTypeString() const override
1765  {
1766  auto str = std::stringstream();
1767 
1768  // clang-format off
1769  str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"
1770  << "<"
1771  << BlockSize << ", "
1772  << MPerBlock << ", "
1773  << NPerBlock << ", "
1774  << KPerBlock << ", "
1775  << AK1 << ", "
1776  << BK1 << ", "
1777  << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
1778  << MPerXDL << ", "
1779  << NPerXDL << ", "
1780  << MXdlPerWave << ", "
1781  << NXdlPerWave << ", "
1782  << ABlockTransferSrcScalarPerVector << ", "
1783  << BBlockTransferSrcScalarPerVector << ", "
1784  << CShuffleMXdlPerWavePerShuffle << ", "
1785  << CShuffleNXdlPerWavePerShuffle;
1786 
1787  if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
1788  is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>()) {
1789  str << ", TransposeTransferInScalarPerVectorAligned: "
1791  << "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
1792  }
1793 
1794 
1795  str << ">";
1796 
1797  return str.str();
1798  }
1799 
1800  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
1801  {
1802  auto arg = dynamic_cast<const Argument*>(p_arg);
1803  if(arg)
1804  {
1805  return arg->GetWorkspaceSizeBytes();
1806  }
1807  else
1808  throw std::runtime_error(
1809  "The argument pointer is not an object of "
1810  "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
1811  }
1812 
1814  void* p_workspace,
1815  const StreamConfig& = StreamConfig{}) const override
1816  {
1817  auto p_arg_ = dynamic_cast<Argument*>(p_arg);
1818  if(p_arg_)
1819  {
1820  p_arg_->p_workspace_ = p_workspace;
1821  }
1822  else
1823  throw std::runtime_error(
1824  "The argument pointer is not an object of "
1825  "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
1826  }
1827 };
1828 
1829 } // namespace device
1830 } // namespace tensor_operation
1831 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:91
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
GemmSpecialization
Definition: gemm_specialization.hpp:11
std::string getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization &s)
Definition: convolution_backward_data_specialization.hpp:17
ConvolutionBackwardDataSpecialization
Definition: convolution_backward_data_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:144
Definition: ck.hpp:266
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition: gridwise_elementwise_2d.hpp:221
bool is_xdl_supported()
Definition: device_prop.hpp:55
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:275
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:298
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:297
__global__ void kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size, const index_t batch_count_a, const index_t batch_count_b, const std::array< index_t, NumInputsA > input_batch_strides_a, const std::array< index_t, NumInputsB > input_batch_strides_b, const std::array< index_t, NumOutputsA > output_batch_strides_a, const std::array< index_t, NumOutputsB > output_batch_strides_b)
Definition: gridwise_elementwise_2d.hpp:117
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:68
Definition: stream_config.hpp:10
Definition: gridwise_elementwise_2d.hpp:278
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:406
Definition: sequence.hpp:43
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: tensor_layout.hpp:223
Definition: tensor_layout.hpp:228
Definition: tensor_layout.hpp:120
Definition: tensor_layout.hpp:347
Definition: tensor_layout.hpp:115
Definition: tensor_layout.hpp:342
Definition: transform_conv_bwd_data_to_gemm_v1.hpp:36
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
Definition: transform_conv_bwd_data_to_gemm_v1.hpp:651
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
Definition: transform_conv_bwd_data_to_gemm_v1.hpp:897
__host__ __device__ auto MakeCDescriptor_M_N() const
Definition: transform_conv_bwd_data_to_gemm_v1.hpp:1104
Definition: transform_conv_ngchw_to_nhwgc.hpp:31
Definition: device_base.hpp:51
void * p_workspace_
Definition: device_base.hpp:58
Definition: device_base.hpp:62
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:651
std::size_t GetWorkspaceETensorSizeBytes() const
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1029
index_t conv_N_per_block_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1077
index_t num_workgroups_per_Conv_N_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1111
std::vector< DsGridDesc_M_N > ds_grid_desc_m_n_container_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1080
NGCHWTransposeDescType e_out_transpose_desc_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1088
const index_t k_batch_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1110
index_t gemms_count_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1113
std::array< index_t, NDimSpatial+3 > a_g_n_k_wos_lengths_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1103
EDataType * p_e_grid_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1073
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, ck::index_t split_k=1)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:652
GKCYXTransposeDescType b_in_transpose_desc_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1090
long_index_t e_space_size_bytes
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1117
const ADataType * p_a_grid_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1070
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1106
std::array< index_t, NDimSpatial+3 > e_g_n_c_wis_lengths_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1105
std::vector< std::array< GemmArgs, MaxGroupedGemmGroupsNum > > gemm_kernel_args_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1114
std::array< index_t, NDimSpatial > input_left_pads_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1107
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1084
std::vector< AGridDesc_M_K > a_grid_desc_m_k_container_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1078
const BDataType * p_b_grid_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1071
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_e_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1085
CDEElementwiseOp cde_element_op_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1101
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1094
bool bwd_needs_zero_out
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1116
NHWGCTransposeDescType e_in_transpose_desc_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1089
AElementwiseOp a_element_op_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1099
std::vector< BGridDesc_N_K > b_grid_desc_n_k_container_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1079
void Print() const
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1049
std::vector< index_t > gemms_grid_size_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1112
NHWGCTransposeDescType a_out_transpose_desc_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1089
index_t num_group_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1076
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1104
std::array< index_t, NDimSpatial > input_right_pads_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1108
ComputePtrOffsetOfStridedBatch< I1, I1, I0 > compute_ptr_offset_of_workspace_n_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1096
std::vector< EGridDesc_M_N > e_grid_desc_m_n_container_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1081
std::size_t GetWorkspaceSizeBytes() const
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1043
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1014
NGCHWTransposeDescType a_in_transpose_desc_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1088
GKYXCTransposeDescType b_out_transpose_desc_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1091
std::size_t GetWorkspaceATensorSizeBytes() const
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:999
BElementwiseOp b_element_op_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1100
GridwiseGemm::DsGridPointer p_ds_grid_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1072
ComputePtrOffsetOfStridedBatch< I1, I1, I0 > compute_ptr_offset_of_n_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1095
Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1086
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:519
index_t BlockStart_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:557
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:552
index_t BlockEnd_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:557
bool HasMainKBlockLoop_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:558
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:553
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:549
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:556
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:550
GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, GroupedGemmBlock2ETileMap block_2_ctile_map, index_t BlockStart, index_t BlockEnd, bool HasMainKBlockLoop)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:521
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1122
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1431
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1284
float RunMultiDGemm(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1126
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:303
std::conditional_t< is_NGCHW_GKCYX_NGKHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::GKYXC, std::conditional_t< is_NGCDHW_GKCZYX_NGKDHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::GKZYXC, BLayout > > BLayoutAfterTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:361
static constexpr index_t ElementwiseBlocksize
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:594
static constexpr index_t TransposeTransferOutScalarPerVectorAligned
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:578
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1) override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1713
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:568
std::conditional_t< is_NGCHW_NGKHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NHWGC, std::conditional_t< is_NGCDHW_NGKDHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NDHWGC, ELayout > > ELayoutAfterTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:367
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:589
ADataType ABDataType
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:327
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1666
static constexpr index_t ClusterLengthNPerBlock
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:565
GridwiseElementwise< Tuple< GKCYXTransposeDescType >, Tuple< GKYXCTransposeDescType >, Tuple< const BDataType * >, Tuple< BDataType * >, Block2TileMapWeiElementwise, element_wise::PassThrough, ElementwiseBlocksize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< 1 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseWeightTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:630
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1800
decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})) AGridDesc_M_K
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:503
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:510
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:586
GridwiseElementwise< Tuple< NHWGCTransposeDescType >, Tuple< NGCHWTransposeDescType >, Tuple< const EDataType * >, Tuple< EDataType * >, Block2TileMapInOutElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, MPerBlock, NPerBlock/ClusterLengthNPerBlock, MPerBlock/ClusterLengthMPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< TransposeTransferOutScalarPerVectorAligned >, I0, I1 > GridwiseElementwiseOutputTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:648
remove_cvref_t< tuple_element_t< 0, ABDsEGridDesc > > AGridDesc_AK0_M_AK1
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:498
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:495
static constexpr auto I3
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:332
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1660
static constexpr bool NeedTransposeKernel
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:341
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1759
static constexpr auto I1
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:330
static constexpr auto I0
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:329
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:592
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1438
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmMultiDTemplateParams > GridwiseGemm
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:466
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1 &desc_k0_m_k1)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:481
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMapWeiElementwise
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:561
remove_cvref_t< tuple_element_t< 1, ABDsEGridDesc > > BGridDesc_BK0_N_BK1
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:499
std::conditional_t< is_NGCHW_NGKHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NHWGK, std::conditional_t< is_NGCDHW_NGKDHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NDHWGK, ALayout > > ALayoutAfterTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:354
static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform &conv_to_gemm_transform)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:389
static constexpr index_t ClusterLengthMPerBlock
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:563
decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})) Block2ETileMap
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:514
remove_cvref_t< tuple_element_t< 2, ABDsEGridDesc > > DsGridDesc_M_N
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:500
static auto MakeInvoker()
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1711
static constexpr bool isATensorColMajor
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:334
std::string GetTypeString() const override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1764
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1813
OffsettedBlockToCTileMap< Block2ETileMap > GroupedGemmBlock2ETileMap
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:516
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:583
static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n)
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:474
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:312
static constexpr bool CTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:345
std::conditional_t< CTranspose, GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmCTransposeTemplateParameters >, GridwiseGemm > GridwiseGemmCTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:470
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapInOutElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, MPerBlock, NPerBlock/ClusterLengthNPerBlock, MPerBlock/ClusterLengthMPerBlock, Sequence< 1, 0 >, Sequence< TransposeTransferInScalarPerVectorAligned >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I1, I0 > GridwiseElementwiseInputTranspose
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:612
BlockToCTileMap_M00_N0_M01Adapt< NPerBlock, MPerBlock > Block2TileMapInOutElementwise
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:560
remove_cvref_t< tuple_element_t< 3, ABDsEGridDesc > > EGridDesc_M_N
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:501
static constexpr index_t NumDTensor
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:320
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})) DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:508
static constexpr auto I2
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:331
static constexpr index_t TransposeTransferInScalarPerVectorAligned
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:576
static constexpr bool IsSplitKSupported
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:322
decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)) ABDsEGridDesc
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:496
decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})) BGridDesc_N_K
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:504
static constexpr GemmSpecialization GemmSpec
Definition: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:321
Definition: device_grouped_conv_bwd_data_multiple_d.hpp:36
Definition: unary_element_wise_operation.hpp:308
#define CK_ENV(name)
Definition: env.hpp:129