epilogue_cshuffle_v3_reduce_wmma.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename ReduceAccDataType,
          typename ReducePtrsGlobal,
          typename ReduceOperations,
          typename ReduceInElementwiseOperations,
          typename ReduceAccElementwiseOperations,
          typename ReduceGlobalMemoryDataOperation,
          typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
struct EpilogueReduceTraits
{
    using ReduceAccDataType_               = ReduceAccDataType;
    using ReducePtrsGlobal_                = ReducePtrsGlobal;
    using ReduceOperations_                = ReduceOperations;
    using ReduceInElementwiseOperations_   = ReduceInElementwiseOperations;
    using ReduceAccElementwiseOperations_  = ReduceAccElementwiseOperations;
    using ReduceGlobalMemoryDataOperation_ = ReduceGlobalMemoryDataOperation;
    using CReduceThreadClusterLengths_MPerBlock_NPerBlock_ =
        CReduceThreadClusterLengths_MPerBlock_NPerBlock;
    static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_ =
        CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock;
    static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_ =
        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock;
};
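
// A hypothetical instantiation of the traits bundle above (the names below are
// illustrative, not part of this header): a single per-row sum of the GEMM output,
// accumulated in float and combined across workgroups with atomic adds, could be
// described as
//
//   using MyReduceTraits = EpilogueReduceTraits<
//       float,                                               // ReduceAccDataType
//       Tuple<float*>,                                       // ReducePtrsGlobal
//       Tuple<reduce::Add>,                                  // ReduceOperations
//       Tuple<tensor_operation::element_wise::PassThrough>,  // in elementwise ops
//       Tuple<tensor_operation::element_wise::PassThrough>,  // acc elementwise ops
//       InMemoryDataOperationEnumSequence<InMemoryDataOperationEnum::AtomicAdd>,
//       Sequence<32, 4>, // thread cluster: 32 threads along M, 4 along N
//       4,               // LDS->VGPR vector width along N
//       1>;              // VGPR->global vector width along M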

template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          typename ReduceTrait>
struct EpilogueReduceCShuffle
    : EpilogueCShuffleBase<DsDataType,
                           EDataType,
                           AccDataType,
                           CShuffleDataType,
                           MPerBlock,
                           NPerBlock,
                           MPerWmma,
                           NPerWmma,
                           MRepeat,
                           NRepeat,
                           CShuffleMRepeatPerShuffle,
                           CShuffleNRepeatPerShuffle,
                           CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                           CDEShuffleBlockTransferScalarPerVectors,
                           CDEElementwiseOperation,
                           ThisThreadBlock,
                           BlockwiseGemmPipe>
{
    using Base = EpilogueCShuffleBase<
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        MPerBlock,
        NPerBlock,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        CDEElementwiseOperation,
        ThisThreadBlock,
        BlockwiseGemmPipe>;

    using Base::I0;
    using Base::I1;
    using Base::I3;
    using Base::NumDTensor;

    // assume the Reduce output is a packed tensor
    __device__ static auto MakeReduceGridDescriptor_M(index_t MRaw)
    {
        using tensor_operation::device::GemmSpecialization;

        const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto MPad = M - MRaw;

        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(d_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // do not pad M
            return d_grid_desc_mraw;
        }
    }
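
    // Worked example: with MRaw = 1000 and MPerBlock = 128,
    // M = integer_divide_ceil(1000, 128) * 128 = 1024 and MPad = 24, so an
    // M-padding GemmSpec right-pads the descriptor from length 1000 to 1024;
    // any other GemmSpec leaves it at its raw length of 1000.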

    using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));

    __device__ static constexpr auto
    MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
    {
        const auto M      = d_grid_desc_m.GetLength(I0);
        const auto MBlock = M / MPerBlock;

        const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
            d_grid_desc_m,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1>{}));

        return reduce_grid_desc_mblock_mperblock;
    }
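
    // Worked example: with M = 1024 and MPerBlock = 128, MBlock = 8 and the unmerge
    // splits a flat index m into (m / 128, m % 128); e.g. m = 300 maps to
    // (mblock, mperblock) = (2, 44).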

    __device__ EpilogueReduceCShuffle(
        typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
        const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
        const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
        const index_t MRaw_)
        : p_reduces_grid(p_reduces_grid_),
          reduce_in_element_ops(reduce_in_element_ops_),
          reduce_out_element_ops(reduce_out_element_ops_),
          MRaw(MRaw_),
          reduce_grid_desc_m(MakeReduceGridDescriptor_M(MRaw_))
    {
    }

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ void Run(CThreadBuf& c_thread_buf,
                        DsGridPointer p_ds_grid,
                        EDataType* p_e_grid,
                        void* p_shared,
                        const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            ds_grid_desc_mblock_mperblock_nblock_nperblock,
                        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            e_grid_desc_mblock_mperblock_nblock_nperblock,
                        CDEElementwiseOperation& cde_element_op,
                        const index_t& block_m_id,
                        const index_t& block_n_id)
    {
        auto reduce_grid_desc_mblock_mperblock =
            MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // C mapping in a single thread
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // LDS buffer
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize());

        // thread transfer VGPR to LDS
        auto c_thread_copy_vgpr_to_lds = Base::GetVgprToLDSEpilogueDescriptor();

        // space-filling curve over the VGPR tile
        constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};

        // space-filling curve over the global (vmem) tile
        constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};

        // block descriptor
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                Base::GetCShuffleLDSDescriptor();

        // tuple of references to the C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
            generate_tie([&](auto i) -> const auto& // return type should be a reference
                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                         Number<NumDTensor>{}));

        // thread transfer LDS to vmem
        auto cde_shuffle_block_copy_lds_and_global =
            Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
                c_ds_desc_refs,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                cde_element_op,
                block_m_id,
                block_n_id);

        // tuple of references to the C/Ds tensor buffers
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie([&](auto i) -> const auto& // return type should be a reference
                         { return ds_grid_buf[i]; },
                         Number<NumDTensor>{}));

241 
242  // LDS c_reduce_block_desc_mperblock_nperblock
243  constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
244  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
245  make_tuple(
248  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
249  I1)),
252  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
253  I3))),
256 
        static_assert(
            ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) *
                    ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) ==
                BlockSize,
            "wrong! the CReduce thread cluster lengths must multiply to BlockSize");

        static_assert(
            (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) %
                        ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) ==
                    0 &&
                (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) %
                        ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) ==
                    0,
            "wrong! the CReduce thread cluster lengths must evenly divide the shuffle tile");

        constexpr index_t mreduce_per_thread =
            (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
            ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0);

        constexpr index_t nreduce_per_thread =
            (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
            ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1);
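
        // Worked example (hypothetical tile shape): with CShuffleMRepeatPerShuffle = 2,
        // MWaves = 2, and MPerWmma = 16 the shuffle slab covers 64 rows, and with
        // CShuffleNRepeatPerShuffle = 2, NWaves = 2, and NPerWmma = 16 it covers 64
        // columns. A 32 x 4 thread cluster (BlockSize = 128) then satisfies both
        // asserts and yields mreduce_per_thread = 64 / 32 = 2 and
        // nreduce_per_thread = 64 / 4 = 16.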

        static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size();

        constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
            Sequence<mreduce_per_thread, nreduce_per_thread>{};

        // VGPR c_reduce_thread_desc_mperblock_nperblock
        constexpr auto c_reduce_thread_desc_mperblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

        // VGPR reduce_thread_desc_mperblock
        constexpr auto reduce_thread_desc_mperblock =
            make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

        // VGPR reduce_thread_desc_mblock_mperblock
        constexpr auto reduce_thread_desc_mblock_mperblock =
            make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

        auto c_reduce_thread_buf =
            make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
                c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

        // reduce: threadwise copy from LDS to VGPR
        constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
            typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{},
            Sequence<1, 0>{});

        const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex(
            make_multi_index(get_thread_local_1d_id()));

        const auto c_reduce_thread_data_idx_begin =
            c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;

        auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
            CShuffleDataType,
            typename ReduceTrait::ReduceAccDataType_,
            decltype(c_reduce_block_desc_mperblock_nperblock),
            decltype(c_reduce_thread_desc_mperblock_nperblock),
            decltype(c_reduce_thread_lengths_mperblock_nperblock),
            Sequence<0, 1>,
            1,
            ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
            1,
            true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

        auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
            [&](auto I) {
                auto p_reduce_grid         = p_reduces_grid[I];
                auto reduce_acc_element_op = reduce_out_element_ops[I];

                return ThreadwiseTensorSliceTransfer_v1r3<
                    typename ReduceTrait::ReduceAccDataType_,
                    remove_pointer_t<decltype(p_reduce_grid)>,
                    decltype(reduce_thread_desc_mblock_mperblock),
                    decltype(reduce_grid_desc_mblock_mperblock),
                    decltype(reduce_acc_element_op),
                    Sequence<1, mreduce_per_thread>,
                    Sequence<0, 1>,
                    1,
                    ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_,
                    ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I),
                    1,
                    false>{reduce_grid_desc_mblock_mperblock,
                           make_multi_index(block_m_id,                          // mblock
                                            c_reduce_thread_data_idx_begin[I0]), // mperblock
                           reduce_acc_element_op};
            },
            Number<NumReduce>{});

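        // One transfer is generated per reduction buffer, each with its own global
        // memory data operation taken from ReduceGlobalMemoryDataOperation_::At(I);
        // e.g. (hypothetically) one output could be written with
        // InMemoryDataOperationEnum::Set while another accumulates partial results
        // across workgroups with InMemoryDataOperationEnum::AtomicAdd.
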
        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_global.GetNumOfAccess(),
                      "wrong! the VGPR and global space-filling curves must have the same "
                      "number of accesses");

        // CShuffle and store
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread writes its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_buf,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // the block reads its C tile from LDS and the Ds from global, applies the
            // elementwise operation, and stores the result E to global
            cde_shuffle_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e_grid_buf));

            // reduction: reuse the shuffled C tile already in LDS
            {
                c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
                                                     c_shuffle_block_buf,
                                                     c_reduce_thread_desc_mperblock_nperblock,
                                                     make_tuple(I0, I0),
                                                     c_reduce_thread_buf);

                static_for<0, NumReduce, 1>{}([&](auto In) {
                    auto& p_reduce_grid = p_reduces_grid[In];

                    auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());

                    auto reduce_thread_buf =
                        make_static_buffer<AddressSpaceEnum::Vgpr,
                                           typename ReduceTrait::ReduceAccDataType_>(
                            reduce_thread_desc_mperblock.GetElementSpaceSize());

                    auto& reduce_in_element_op = reduce_in_element_ops[In];

                    auto& reduce_thread_copy_vgpr_to_global =
                        reduce_tuple_thread_copy_vgpr_to_global(In);

                    using ReduceOperation =
                        remove_cvref_t<decltype(typename ReduceTrait::ReduceOperations_{}[In])>;
                    using ThreadwiseReduce =
                        ThreadwiseReduction<typename ReduceTrait::ReduceAccDataType_,
                                            decltype(c_reduce_thread_desc_mperblock_nperblock),
                                            decltype(reduce_thread_desc_mperblock),
                                            ReduceOperation,
                                            false>;

                    // global write: GEMM shuffle + reduction
                    const auto reduce_identityVal = ReduceOperation::template GetIdentityValue<
                        typename ReduceTrait::ReduceAccDataType_>();

                    static_for<0, mreduce_per_thread, 1>{}(
                        [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });

                    // apply the reduce input elementwise operation in place in VGPR
                    static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
                        static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
                            constexpr auto offset =
                                Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                                    make_tuple(im, in))>{};

                            reduce_in_element_op(c_reduce_thread_buf(offset),
                                                 c_reduce_thread_buf(offset));
                        });
                    });
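
                    // Illustration (hypothetical op choices): with two reductions both
                    // using reduce::Add, in-element ops of PassThrough and UnarySquare
                    // yield per-row sum(E) and sum(E*E) respectively, from which mean
                    // and variance follow; the in-element op above runs in place before
                    // the threadwise reduction below.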

                    ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);

                    // copy from VGPR to global
                    reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
                                                          make_tuple(I0, I0),
                                                          reduce_thread_buf,
                                                          reduce_grid_desc_mblock_mperblock,
                                                          reduce_grid_buf);

                    if constexpr(access_id < num_access - 1)
                    {
                        constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id);
                        reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                            reduce_grid_desc_mblock_mperblock,
                            make_tuple(c_global_step[I0], c_global_step[I1]));
                    }
                });
            }

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                // move the source windows on the Ds
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_global_step);
                });

                // move the destination window on E
                cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
            }
        });
    }

    typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid;
    typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
    typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
    index_t MRaw;
    ReduceGridDesc_M reduce_grid_desc_m;
};

} // namespace ck