30 namespace tensor_operation {
33 template <
typename ALayout,
41 typename GemmAccDataType,
42 typename CShuffleDataType,
43 typename AElementwiseOperation,
44 typename BElementwiseOperation,
45 typename CElementwiseOperation,
57 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
58 typename ABlockTransferThreadClusterArrangeOrder,
59 typename ABlockTransferSrcAccessOrder,
60 index_t ABlockTransferSrcVectorDim,
61 index_t ABlockTransferSrcScalarPerVector,
62 index_t ABlockTransferDstScalarPerVector_AK1,
64 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
65 typename BBlockTransferThreadClusterArrangeOrder,
66 typename BBlockTransferSrcAccessOrder,
67 index_t BBlockTransferSrcVectorDim,
68 index_t BBlockTransferSrcScalarPerVector,
69 index_t BBlockTransferDstScalarPerVector_BK1,
71 index_t CShuffleMRepeatPerShuffle,
72 index_t CShuffleNRepeatPerShuffle,
73 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
74 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
77 typename ReduceDataType = CDataType,
78 typename ComputeTypeA = CDataType,
79 typename ComputeTypeB = ComputeTypeA>
88 AElementwiseOperation,
89 BElementwiseOperation,
90 CElementwiseOperation>
107 AElementwiseOperation,
108 BElementwiseOperation,
121 ABlockTransferThreadClusterLengths_AK0_M_AK1,
122 ABlockTransferThreadClusterArrangeOrder,
123 ABlockTransferSrcAccessOrder,
124 ABlockTransferSrcVectorDim,
125 ABlockTransferSrcScalarPerVector,
126 ABlockTransferDstScalarPerVector_AK1,
129 BBlockTransferThreadClusterLengths_BK0_N_BK1,
130 BBlockTransferThreadClusterArrangeOrder,
131 BBlockTransferSrcAccessOrder,
132 BBlockTransferSrcVectorDim,
133 BBlockTransferSrcScalarPerVector,
134 BBlockTransferDstScalarPerVector_BK1,
137 CShuffleMRepeatPerShuffle,
138 CShuffleNRepeatPerShuffle,
139 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151 std::array<const void*, 1> p_b_grid_,
152 const ::std::array<const void*, NumDTensor> p_ds_,
153 CDataType* p_c_grid_,
157 std::array<index_t, 1> StrideA_,
158 std::array<index_t, 1> StrideB_,
159 const ::std::array<index_t, NumDTensor> stride_ds_,
162 AElementwiseOperation a_element_op_,
163 BElementwiseOperation b_element_op_,
164 CElementwiseOperation c_element_op_)
167 ::std::array<const void*, 0>{},
168 reinterpret_cast<ReduceDataType*
>(p_c_grid_),
174 std::array<index_t, 0>{},
190 const ::std::array<const void*, NumDTensor>
p_ds;
218 CShuffleBlockTransferScalarPerVector_NPerBlock,
221 CShuffleBlockTransferScalarPerVector_NPerBlock,
222 CShuffleBlockTransferScalarPerVector_NPerBlock,
229 static constexpr
index_t NumInDim = 3;
230 static constexpr
index_t NumOutDim = 2;
232 ::std::array<index_t, NumInDim> in_lengths = {arg.
KBatch, arg.
M, arg.
N};
233 ::std::array<index_t, NumOutDim> out_lengths = {arg.
M, arg.
N};
235 ::std::array<index_t, NumInDim> in_strides;
236 ::std::array<index_t, NumOutDim> out_strides;
239 in_strides = {arg.
M * arg.
N, arg.
N, 1};
240 out_strides = {arg.
N, 1};
244 in_strides = {arg.
M * arg.
N, 1, arg.
M};
245 out_strides = {1, arg.
M};
248 ::std::array<int, 1> reduce_dims{0};
250 ::std::array<::std::array<index_t, NumOutDim>,
NumDTensor> DsLengths;
251 ::std::array<::std::array<index_t, NumOutDim>,
NumDTensor> DsStrides;
253 static_for<0, NumDTensor, 1>{}([&](
auto i) {
254 DsLengths[i] = out_lengths;
259 DsStrides[i] = {arg.
StrideDs[i], 1};
263 DsStrides[i] = {1, arg.
StrideDs[i]};
269 auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
282 auto invoker_ptr = reduce.MakeInvokerPointer();
286 if(reduce.IsSupportedArgument(argument_ptr.get()))
288 ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
292 throw ::std::runtime_error(
293 "The runtime parameters are not supported by the device instance.");
310 throw ::std::runtime_error(
"using reduce, but empty workspace!");
315 if(stream_config.log_level_ > 0)
322 throw ::std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
331 index_t K_split = (arg.
K + k_grain - 1) / k_grain * KPerBlock;
335 constexpr
index_t minimum_occupancy =
338 if(has_main_k_block_loop)
346 stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
356 stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
361 ave_time += RunReduce(arg_, stream_config);
371 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
396 return GridwiseGemm::CheckValidity(
402 return IsSupportedArgument(*
dynamic_cast<const Argument*
>(p_arg));
407 return GridwiseGemm::CalculateGridSize(M, N, KBatch);
414 return GridwiseGemm::GetSharedMemoryNumberOfByte();
418 const BDataType* p_b,
419 const ::std::array<const void*, NumDTensor> p_ds,
426 const ::std::array<index_t, NumDTensor> stride_ds,
429 AElementwiseOperation a_element_op,
430 BElementwiseOperation b_element_op,
431 CElementwiseOperation c_element_op)
433 return Argument{std::array<const void*, 1>{p_a},
434 std::array<const void*, 1>{p_b},
440 std::array<index_t, 1>{StrideA},
441 std::array<index_t, 1>{StrideB},
455 return ::std::make_unique<Invoker>(
Invoker{});
461 ::std::array<const void*, NumDTensor> p_ds,
468 ::std::array<index_t, NumDTensor> DsStrides,
471 AElementwiseOperation a_element_op,
472 BElementwiseOperation b_element_op,
473 CElementwiseOperation c_element_op)
override
475 return ::std::make_unique<Argument>(std::array<const void*, 1>{p_a},
476 std::array<const void*, 1>{p_b},
478 static_cast<CDataType*
>(p_c),
482 std::array<index_t, 1>{StrideA},
483 std::array<index_t, 1>{StrideB},
494 auto str = ::std::stringstream();
502 return ::std::string(
"?");
514 return ::std::string(
"v?");
518 str <<
"DeviceGemmWmmaUniversalReduce"
521 << ::std::string(ALayout::name)[0]
522 << ::std::string(BLayout::name)[0]
523 << ::std::string(CLayout::name)[0]
528 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
530 << MPerWmma<<
"x"<<NPerWmma <<
", "
532 << MRepeat<<
"x" << NRepeat<<
", "
534 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
535 <<
"BlkGemmPipelineScheduler: "
536 << BlkGemmPipelineSchedulerToString(BlkGemmPipeSched) <<
", "
537 <<
"BlkGemmPipelineVersion: "
538 << BlkGemmPipelineVersionToString(BlkGemmPipelineVer) <<
", "
539 <<
"BlkGemmPipelinePrefetchStages: "
540 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
548 auto arg = *
dynamic_cast<const Argument*
>(p_arg);
553 return arg.
M * arg.
N * arg.
KBatch *
sizeof(ReduceDataType);
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
bool is_wmma_supported()
Definition: device_prop.hpp:127
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:40
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:299
Definition: stream_config.hpp:10
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:471
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:458
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:360
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:386
ck::GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false >::CalculateHasMainKBlockLoop static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
ck::GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
ck::GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false >::CheckValidity static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition: multi_index_transform.hpp:13
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
static constexpr bool value
Definition: integral_constant.hpp:21
Definition: reduction_operator.hpp:37
Definition: device_base.hpp:197
void * p_workspace_
Definition: device_base.hpp:204
Definition: device_base.hpp:208
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:149
Argument(std::array< const void *, 1 > p_a_grid_, std::array< const void *, 1 > p_b_grid_, const ::std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, 1 > StrideA_, std::array< index_t, 1 > StrideB_, const ::std::array< index_t, NumDTensor > stride_ds_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:150
CDataType * p_c_grid
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:188
const ::std::array< const void *, NumDTensor > p_ds
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:190
CElementwiseOperation c_element_op
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:189
::std::array< index_t, NumDTensor > StrideDs
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:191
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:226
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:227
float Run(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:299
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:368
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:91
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:94
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:223
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:375
static size_t GetSharedMemoryNumberOfByte()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:412
static constexpr index_t NumDTensor
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:92
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ::std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, const ::std::array< index_t, NumDTensor > stride_ds, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:417
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:405
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:381
static constexpr auto DsVectorLengthSequence
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:197
::std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:453
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:450
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, GemmAccDataType, ReduceDataType, Tuple<>, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, false, false > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:146
ck::reduce::Add ReduceAdd
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:194
CElementwiseOperation OutElementwiseOperation
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:195
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:400
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:546
static constexpr index_t GetBlockSize()
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:410
::std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, ::std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, ::std::array< index_t, NumDTensor > DsStrides, index_t StrideC, index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:459
::std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3r1.hpp:492
Definition: device_gemm_v2.hpp:57
Definition: device_reduce_threadwise_multi_d.hpp:47
Definition: unary_element_wise_operation.hpp:340