/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp Source File
epilogue_cshuffle_v3_wmma_base.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
9 
10 namespace ck {
11 
12 template <typename DsDataType,
13  typename EDataType,
14  typename AccDataType,
15  typename CShuffleDataType,
16  index_t MPerBlock,
17  index_t NPerBlock,
18  index_t MPerWmma,
19  index_t NPerWmma,
20  index_t MRepeat,
21  index_t NRepeat,
22  index_t CShuffleMRepeatPerShuffle,
23  index_t CShuffleNRepeatPerShuffle,
24  typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
25  typename CDEShuffleBlockTransferScalarPerVectors,
26  typename CDEElementwiseOperation,
27  typename ThisThreadBlock,
28  typename BlockwiseGemmPipe>
30 {
// Compile-time index constants Number<0> .. Number<6>. The code below passes
// them to GetLength(...) and operator[] to pick a dimension or tuple element.
31  static constexpr auto I0 = Number<0>{};
32  static constexpr auto I1 = Number<1>{};
33  static constexpr auto I2 = Number<2>{};
34  static constexpr auto I3 = Number<3>{};
35  static constexpr auto I4 = Number<4>{};
36  static constexpr auto I5 = Number<5>{};
37  static constexpr auto I6 = Number<6>{};
38 
// Number of D tensors, as reported by DsDataType::Size().
39  static constexpr index_t NumDTensor = DsDataType::Size();
// Element I0 of CDEShuffleBlockTransferScalarPerVectors. It is passed as the
// "DstScalarPerVector" argument in GetLDSToVmemEpilogueDescriptor.
40  static constexpr auto EShuffleBlockTransferScalarPerVector =
41  CDEShuffleBlockTransferScalarPerVectors{}[I0];
42 
// NOTE(review): listing lines 43-45 are absent from this extract, so the
// declaration that the following Sequence<...> closes is not visible here.
// Visible part: a 7-element Sequence with CShuffleMRepeatPerShuffle,
// CShuffleNRepeatPerShuffle and BlockwiseGemmPipe::MAccVgprs; all other entries are 1.
46  Sequence<CShuffleMRepeatPerShuffle,
47  1,
48  1,
49  CShuffleNRepeatPerShuffle,
50  1,
51  1,
52  BlockwiseGemmPipe::MAccVgprs>>;
53 
// NOTE(review): listing lines 54-56 are absent from this extract as well.
// Visible part: a 4-element Sequence whose 2nd and 4th entries are
// (shuffle repeat) * (wave count) * (per-WMMA extent) along M and N respectively.
57  Sequence<1,
58  CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
59  1,
60  CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
61 
// Builds a compile-time descriptor object and returns it unchanged.
// NOTE(review): listing lines 64, 70, 72 and 74 are absent from this extract.
// The function name (GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat,
// per the page index) and most of the descriptor's length arguments are
// therefore not visible; confirm against the real header before relying on this.
62  // Caution: "repeat" here refers to the shuffle repeat.
63  __device__ static constexpr auto
65  {
    // Wave counts along M and N: block extent divided by
    // (repeat count * per-WMMA extent).
66  constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
67  constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
68 
    // Only the I1 entries of the length tuple are visible in this extract;
    // presumably the hidden entries use MWaves / NWaves -- TODO confirm.
69  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
71  make_tuple(I1,
73  I1,
75 
76  return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
77  }
78 
// Compile-time helper whose result (type and value) is consumed by
// GetVgprToLDSEpilogueDescriptor via decltype(GetCShuffleLDSDescriptor()) and
// as a constructor argument.
// NOTE(review): listing lines 102-105, 110-111 and 115-116 are absent from this
// extract, so the transform that produces the returned descriptor (and the
// return statement itself) is not visible; only its length arguments are.
79  __device__ static constexpr auto GetCShuffleLDSDescriptor()
80  {
81  // C mapping in single block
82  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
83  BlockwiseGemmPipe::
84  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
85 
    // Read individual dimension lengths from the block descriptor. Going by
    // the getter's name, dims 0..6 are MRepeat, MWave, MSubGroup, NRepeat,
    // NWave, NThreadPerSubGroup, MAccVgprs; dims 0 and 3 are not read here.
86  constexpr auto MWave =
87  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
88  .GetLength(I1);
89  constexpr auto MSubGroup =
90  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
91  .GetLength(I2);
92  constexpr auto NWave =
93  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
94  .GetLength(I4);
95  constexpr auto NThreadPerSubGroup =
96  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
97  .GetLength(I5);
98  constexpr auto MAccVgprs =
99  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
100  .GetLength(I6);
101 
    // Two groups of lengths are visible below: an M-side group of four
    // (shuffle MRepeat, MWave, MSubGroup, MAccVgprs) and an N-side group of
    // three (shuffle NRepeat, NWave, NThreadPerSubGroup). The calls wrapping
    // them are in the missing lines -- TODO confirm against the real header.
106  Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
107  MWave, // MWave
108  MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
109  MAccVgprs)),
112  Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
113  NWave, // NWave
114  NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
117  }
118 
// Constructs an object parameterised on AccDataType, CShuffleDataType, the
// pipeline's C thread descriptor type and decltype(GetCShuffleLDSDescriptor()),
// initialised with GetCShuffleLDSDescriptor() and an index derived from this
// thread's origin inside the block.
// NOTE(review): listing lines 151, 153-154, 161, 163-164, 170, 176, 184, 187,
// 190 and 197 are absent from this extract. The class being constructed, the
// adaptor-building calls and the return statement are therefore not visible;
// presumably this is the per-thread VGPR -> LDS copy -- TODO confirm.
119  __device__ static auto GetVgprToLDSEpilogueDescriptor()
120  {
121  // C mapping in single block
122  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
123  BlockwiseGemmPipe::
124  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
125 
    // Dimension lengths read from the block descriptor (dims 1, 2, 4, 5, 6).
126  constexpr auto MWave =
127  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
128  .GetLength(I1);
129  constexpr auto MSubGroup =
130  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
131  .GetLength(I2);
132  constexpr auto NWave =
133  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
134  .GetLength(I4);
135  constexpr auto NThreadPerSubGroup =
136  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
137  .GetLength(I5);
138  constexpr auto MAccVgprs =
139  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
140  .GetLength(I6);
141 
142  // calculate origin of thread output tensor on global memory
143  // blockwise GEMM c matrix starting index
144  const auto c_thread_mtx_on_block =
145  BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
146 
    // Element I0 of the origin is used as the M offset, element I1 as the N offset.
147  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
148  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
149 
    // Adaptor with a merge transform over (MRepeat, MWave, MSubGroup, MAccVgprs);
    // CalculateBottomIndex below converts the scalar M offset into a
    // multi-index with one component per merged length.
150  const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
152  make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
155 
156  const auto m_thread_data_on_block_idx =
157  m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
158  .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
159 
    // Same decomposition for the N offset over (NRepeat, NWave, NThreadPerSubGroup).
160  const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
162  make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
165 
166  const auto n_thread_data_on_block_idx =
167  n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
168  make_multi_index(n_thread_data_on_block));
169 
    // Template arguments of the constructed object (class name not visible).
171  AccDataType,
172  CShuffleDataType,
173  decltype(BlockwiseGemmPipe::
174  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()),
175  decltype(GetCShuffleLDSDescriptor()),
    // 7-element Sequence: shuffle repeat counts in M and N, MAccVgprs last,
    // I1 everywhere else.
177  Sequence<CShuffleMRepeatPerShuffle,
178  I1,
179  I1,
180  CShuffleNRepeatPerShuffle,
181  I1,
182  I1,
183  MAccVgprs>,
185  6,
186  1,
188  1,
    // Constructor arguments: the LDS descriptor, then an index built from
    // components I1, I2, I3 of the M multi-index and I1, I2 of the N
    // multi-index. The repeat components (I0) are not used in the visible
    // lines; a literal 0 appears instead.
189  true>{GetCShuffleLDSDescriptor(),
191  m_thread_data_on_block_idx[I1],
192  m_thread_data_on_block_idx[I2],
193  0,
194  n_thread_data_on_block_idx[I1],
195  n_thread_data_on_block_idx[I2],
196  m_thread_data_on_block_idx[I3]),
198  }
199 
// Constructs the blockwise copy object described by the in-body comment below
// (C from LDS, Ds from global, elementwise op, E stored to global).
//
// Template parameters:
//   EGlobalMemoryDataOperation - cast to index_t and wrapped in a one-element
//                                Sequence (the "DstInMemOps" argument).
//   InterDataType              - wrapped as Tuple<InterDataType>, the last
//                                template argument of the constructed object.
//   CDsDescRefs                - type of c_ds_desc_refs.
//   EGridDesc                  - type of the E grid descriptor.
// Parameters:
//   c_ds_desc_refs   - forwarded as the first constructor argument.
//   e_grid_desc_mblock_mperblock_nblock_nperblock - passed by reference via tie(...).
//   cde_element_op   - forwarded as the last constructor argument.
//   block_m_id, block_n_id - used as the 1st and 3rd components of the
//                            4-component start indices built below.
//
// NOTE(review): listing lines 219, 222 and 240-242 are absent from this
// extract, so the name of the constructed class, one template argument and the
// opening of the "ResetCoordinateAfterRun" source flags are not visible. The
// page index references thread_group_tensor_slice_transfer_v7r3.hpp --
// presumably that is the class; TODO confirm.
200  template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
201  typename InterDataType,
202  typename CDsDescRefs,
203  typename EGridDesc>
204  __device__ static auto
205  GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
206  EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
207  CDEElementwiseOperation& cde_element_op,
208  const index_t& block_m_id,
209  const index_t& block_n_id)
210  {
211  // tuple of starting index of C/Ds blockwise copy
    // First entry is (0, 0, 0, 0); it is followed by NumDTensor copies of
    // (block_m_id, 0, block_n_id, 0).
212  const auto idx_c_ds_block_begin = container_concat(
213  make_tuple(make_multi_index(0, 0, 0, 0)),
214  generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
215  Number<NumDTensor>{}));
216 
217  // blockwise copy which loads C from LDS, D from global, applies elementwise
218  // operation and stores result E to global
220  ThisThreadBlock, // ThreadGroup
    // Source data types: CShuffleDataType followed by the entries of DsDataType.
221  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
223  CDsDescRefs,
224  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
225  CDEElementwiseOperation, // ElementwiseOperation,
226  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
227  Sequence<1,
228  CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
229  1,
230  CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves *
231  NPerWmma>, // BlockSliceLengths,
232  CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
233  Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
234  Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
235  Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
236  3, // SrcVectorDim,
237  3, // DstVectorDim,
238  CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
239  EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
243  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
244  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
245  1,
    // Constructor arguments: source descriptors and start indices, then the
    // destination descriptor (by reference) and its start index, then the
    // elementwise operation.
246  Tuple<InterDataType>>{c_ds_desc_refs,
247  idx_c_ds_block_begin,
248  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
249  make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
250  cde_element_op};
251  }
252 };
253 
254 } // namespace ck
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:279
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I6
Definition: epilogue_cshuffle_v3_wmma_base.hpp:37
static constexpr auto I2
Definition: epilogue_cshuffle_v3_wmma_base.hpp:33
static constexpr auto I4
Definition: epilogue_cshuffle_v3_wmma_base.hpp:35
static constexpr auto I5
Definition: epilogue_cshuffle_v3_wmma_base.hpp:36
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:39
static constexpr auto I0
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static constexpr auto I3
Definition: epilogue_cshuffle_v3_wmma_base.hpp:34
static __device__ auto GetLDSToVmemEpilogueDescriptor(CDsDescRefs &c_ds_desc_refs, EGridDesc &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_cshuffle_v3_wmma_base.hpp:205
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:64
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:32
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:119
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:79
static constexpr auto EShuffleBlockTransferScalarPerVector
Definition: epilogue_cshuffle_v3_wmma_base.hpp:40
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Definition: thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: unary_element_wise_operation.hpp:340