11 template <
typename ReduceAccDataType,
12 typename ReducePtrsGlobal,
13 typename ReduceOperations,
14 typename ReduceInElementwiseOperations,
15 typename ReduceAccElementwiseOperations,
16 typename ReduceGlobalMemoryDataOperation,
17 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
18 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
19 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
29 CReduceThreadClusterLengths_MPerBlock_NPerBlock;
31 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock;
33 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock;
36 template <
typename DsDataType,
39 typename CShuffleDataType,
46 index_t CShuffleMRepeatPerShuffle,
47 index_t CShuffleNRepeatPerShuffle,
48 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
49 typename CDEShuffleBlockTransferScalarPerVectors,
50 typename CDEElementwiseOperation,
52 typename BlockwiseGemmPipe,
67 CShuffleMRepeatPerShuffle,
68 CShuffleNRepeatPerShuffle,
69 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
70 CDEShuffleBlockTransferScalarPerVectors,
71 CDEElementwiseOperation,
86 CShuffleMRepeatPerShuffle,
87 CShuffleNRepeatPerShuffle,
88 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
89 CDEShuffleBlockTransferScalarPerVectors,
90 CDEElementwiseOperation,
110 const auto MPad = M -
MRaw;
112 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
113 GemmSpec == GemmSpecialization::MNPadding ||
114 GemmSpec == GemmSpecialization::MKPadding ||
115 GemmSpec == GemmSpecialization::MNKPadding)
126 return d_grid_desc_mraw;
132 __device__
static constexpr
auto
135 const auto M = d_grid_desc_m.GetLength(
I0);
136 const auto MBlock = M / MPerBlock;
144 return reduce_grid_desc_mblock_mperblock;
148 typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
149 const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
150 const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
162 typename DsGridPointer,
163 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
164 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
165 __device__
void Run(CThreadBuf& c_thread_buf,
166 DsGridPointer p_ds_grid,
169 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
170 ds_grid_desc_mblock_mperblock_nblock_nperblock,
171 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
172 e_grid_desc_mblock_mperblock_nblock_nperblock,
173 CDEElementwiseOperation& cde_element_op,
177 auto reduce_grid_desc_mblock_mperblock =
182 return make_dynamic_buffer<AddressSpaceEnum::Global>(
184 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
188 auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
189 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
192 constexpr
auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
194 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
197 constexpr
auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
200 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
201 static_cast<CShuffleDataType*
>(p_shared),
202 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
203 .GetElementSpaceSize());
216 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
221 tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
223 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
227 auto cde_shuffle_block_copy_lds_and_global =
228 Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
230 e_grid_desc_mblock_mperblock_nblock_nperblock,
237 tie(c_shuffle_block_buf),
239 {
return ds_grid_buf[i]; },
244 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
248 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
252 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
258 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I0) *
259 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I1) ==
264 (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) %
265 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I0) ==
267 (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) %
268 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I1) ==
272 constexpr
index_t mreduce_per_thread =
273 (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
274 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I0);
276 constexpr
index_t nreduce_per_thread =
277 (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
278 ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(
I1);
280 static constexpr
index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size();
282 constexpr
auto c_reduce_thread_lengths_mperblock_nperblock =
286 constexpr
auto c_reduce_thread_desc_mperblock_nperblock =
291 constexpr
auto reduce_thread_desc_mperblock =
295 constexpr
auto reduce_thread_desc_mblock_mperblock =
298 auto c_reduce_thread_buf =
299 make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
300 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
304 typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{},
307 const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex(
310 const auto c_reduce_thread_data_idx_begin =
311 c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
315 typename ReduceTrait::ReduceAccDataType_,
316 decltype(c_reduce_block_desc_mperblock_nperblock),
317 decltype(c_reduce_thread_desc_mperblock_nperblock),
318 decltype(c_reduce_thread_lengths_mperblock_nperblock),
321 ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
323 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
331 typename ReduceTrait::ReduceAccDataType_,
333 decltype(reduce_thread_desc_mblock_mperblock),
334 decltype(reduce_grid_desc_mblock_mperblock),
335 decltype(reduce_acc_element_op),
339 ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_,
340 ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I),
342 false>{reduce_grid_desc_mblock_mperblock,
344 c_reduce_thread_data_idx_begin[
I0]),
345 reduce_acc_element_op};
349 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
351 static_assert(num_access == sfc_cde_global.GetNumOfAccess(),
"wrong!");
359 c_thread_copy_vgpr_to_lds.Run(
360 c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
361 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
363 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
364 c_shuffle_block_buf);
371 cde_shuffle_block_copy_lds_and_global.Run(
374 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
378 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
380 c_reduce_thread_desc_mperblock_nperblock,
382 c_reduce_thread_buf);
387 auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
388 p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
390 auto reduce_thread_buf =
392 typename ReduceTrait::ReduceAccDataType_>(
393 reduce_thread_desc_mperblock.GetElementSpaceSize());
397 auto& reduce_thread_copy_vgpr_to_global =
398 reduce_tuple_thread_copy_vgpr_to_global(In);
400 using ReduceOperation =
401 remove_cvref_t<decltype(
typename ReduceTrait::ReduceOperations_{}[In])>;
402 using ThreadwiseReduce =
404 decltype(c_reduce_thread_desc_mperblock_nperblock),
405 decltype(reduce_thread_desc_mperblock),
410 const auto reduce_identityVal = ReduceOperation::template GetIdentityValue<
411 typename ReduceTrait::ReduceAccDataType_>();
414 [&](
auto I) { reduce_thread_buf(I) = reduce_identityVal; });
419 constexpr
auto offset =
420 Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
423 reduce_in_element_op(c_reduce_thread_buf(offset),
424 c_reduce_thread_buf(offset));
428 ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
431 reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
434 reduce_grid_desc_mblock_mperblock,
437 if constexpr(access_id < num_access - 1)
439 constexpr
auto c_global_step = sfc_cde_global.GetForwardStep(access_id);
440 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
441 reduce_grid_desc_mblock_mperblock,
447 if constexpr(access_id < num_access - 1)
449 constexpr
auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
452 cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
453 c_ds_desc_refs, i +
I1, cde_global_step);
457 cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
458 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
GemmSpecialization
Definition: gemm_specialization.hpp:11
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:279
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
__host__ constexpr __device__ auto make_static_buffer(Number< N >)
Definition: static_buffer.hpp:186
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr index_t NumDTensor
Definition: epilogue_cshuffle_v3_wmma_base.hpp:39
static constexpr auto I0
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
static constexpr auto I3
Definition: epilogue_cshuffle_v3_wmma_base.hpp:34
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:64
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:32
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:119
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:79
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:74
ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:465
index_t MRaw
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:466
static constexpr __device__ auto MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M &d_grid_desc_m)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:133
ReduceTrait::ReducePtrsGlobal_ p_reduces_grid
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:463
ReduceGridDesc_M reduce_grid_desc_m
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:467
__device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:165
static constexpr auto I0
Definition: epilogue_cshuffle_v3_wmma_base.hpp:31
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:130
static constexpr auto I3
Definition: epilogue_cshuffle_v3_wmma_base.hpp:34
static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:64
__device__ EpilogueReduceCShuffle(typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_, const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_, const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_, const index_t MRaw_)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:147
ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:464
static __device__ auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:103
static constexpr auto I1
Definition: epilogue_cshuffle_v3_wmma_base.hpp:32
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:119
static constexpr __device__ auto GetCShuffleLDSDescriptor()
Definition: epilogue_cshuffle_v3_wmma_base.hpp:79
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:21
CReduceThreadClusterLengths_MPerBlock_NPerBlock CReduceThreadClusterLengths_MPerBlock_NPerBlock_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:29
ReduceGlobalMemoryDataOperation ReduceGlobalMemoryDataOperation_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:27
ReduceInElementwiseOperations ReduceInElementwiseOperations_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:25
static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:32
ReducePtrsGlobal ReducePtrsGlobal_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:23
ReduceAccElementwiseOperations ReduceAccElementwiseOperations_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:26
static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:30
ReduceOperations ReduceOperations_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:24
ReduceAccDataType ReduceAccDataType_
Definition: epilogue_cshuffle_v3_reduce_wmma.hpp:22
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Definition: reduction_functions_threadwise.hpp:23
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33