22 namespace tensor_operation {
124 template <
typename ALayout,
130 typename AccDataType,
131 typename CShuffleDataType,
132 typename AElementwiseOperation,
133 typename BElementwiseOperation,
134 typename CElementwiseOperation,
146 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
147 typename ABlockTransferThreadClusterArrangeOrder,
148 typename ABlockTransferSrcAccessOrder,
149 index_t ABlockTransferSrcVectorDim,
150 index_t ABlockTransferSrcScalarPerVector,
151 index_t ABlockTransferDstScalarPerVector_AK1,
152 bool ABlockLdsExtraM,
153 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154 typename BBlockTransferThreadClusterArrangeOrder,
155 typename BBlockTransferSrcAccessOrder,
156 index_t BBlockTransferSrcVectorDim,
157 index_t BBlockTransferSrcScalarPerVector,
158 index_t BBlockTransferDstScalarPerVector_BK1,
159 bool BBlockLdsExtraN,
160 index_t CShuffleMRepeatPerShuffle,
161 index_t CShuffleNRepeatPerShuffle,
162 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
163 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
166 typename ComputeTypeA = CDataType,
167 typename ComputeTypeB = ComputeTypeA,
168 bool PermuteA =
false,
169 bool PermuteB =
false>
176 AElementwiseOperation,
177 BElementwiseOperation,
178 CElementwiseOperation>
191 AElementwiseOperation,
192 BElementwiseOperation,
193 CElementwiseOperation,
205 ABlockTransferThreadClusterLengths_AK0_M_AK1,
206 ABlockTransferThreadClusterArrangeOrder,
207 ABlockTransferSrcAccessOrder,
208 ABlockTransferSrcVectorDim,
209 ABlockTransferSrcScalarPerVector,
210 ABlockTransferDstScalarPerVector_AK1,
213 BBlockTransferThreadClusterLengths_BK0_N_BK1,
214 BBlockTransferThreadClusterArrangeOrder,
215 BBlockTransferSrcAccessOrder,
216 BBlockTransferSrcVectorDim,
217 BBlockTransferSrcScalarPerVector,
218 BBlockTransferDstScalarPerVector_BK1,
221 CShuffleMRepeatPerShuffle,
222 CShuffleNRepeatPerShuffle,
223 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
273 const BDataType* p_b,
282 AElementwiseOperation a_element_op,
283 BElementwiseOperation b_element_op,
284 CElementwiseOperation cde_element_op)
286 return Argument{std::array<const void*, 1>{p_a},
287 std::array<const void*, 1>{p_b},
288 std::array<const void*, 0>{},
293 std::array<index_t, 1>{StrideA},
294 std::array<index_t, 1>{StrideB},
295 std::array<index_t, 0>{},
316 AElementwiseOperation a_element_op,
317 BElementwiseOperation b_element_op,
318 CElementwiseOperation c_element_op)
override
320 return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
321 std::array<const void*, 1>{p_b},
322 std::array<const void*, 0>{},
323 static_cast<CDataType*
>(p_c),
327 std::array<index_t, 1>{StrideA},
328 std::array<index_t, 1>{StrideB},
329 std::array<index_t, 0>{},
340 return std::make_unique<Invoker>(
Invoker{});
346 auto str = std::stringstream();
348 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
352 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
360 str <<
"DeviceGemm_Wmma_CShuffleV3"
363 << std::string(ALayout::name)[0]
364 << std::string(BLayout::name)[0]
365 << std::string(CLayout::name)[0]
370 << MPerBlock <<
"x" << NPerBlock <<
"x" << KPerBlock <<
", "
372 << MPerWmma <<
"x"<<NPerWmma <<
", "
374 << MRepeat <<
"x" << NRepeat <<
", "
376 << ABlockTransferSrcScalarPerVector <<
"x" << BBlockTransferSrcScalarPerVector <<
", "
377 <<
"BlkGemmPipelineScheduler: "
378 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
379 <<
"BlkGemmPipelineVersion: "
380 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
381 <<
"BlkGemmPipelinePrefetchStages: "
382 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages <<
", "
#define REGISTER_EXTRA_PRINTING_METHODS
Definition: device_base.hpp:47
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
int32_t index_t
Definition: ck.hpp:299
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
ck::GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false >::KPack static constexpr index_t KPack
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: tuple.hpp:117
Definition: device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:57
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:43
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3_common.hpp:268
"Universal" GEMM operation with SplitK support.
Definition: device_gemm_wmma_cshuffle_v3.hpp:179
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation cde_element_op)
Definition: device_gemm_wmma_cshuffle_v3.hpp:272
std::string GetTypeString() const override
Definition: device_gemm_wmma_cshuffle_v3.hpp:344
typename DeviceGemmCommon::Invoker Invoker
Definition: device_gemm_wmma_cshuffle_v3.hpp:254
static auto MakeInvoker()
Definition: device_gemm_wmma_cshuffle_v3.hpp:303
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:262
bool GetPermuteA() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:269
typename GridwiseGemm::Argument Argument
Definition: device_gemm_wmma_cshuffle_v3.hpp:232
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:338
bool GetPermuteB() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:270
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition: device_gemm_wmma_cshuffle_v3.hpp:230
index_t GetKPerBlock() override
Definition: device_gemm_wmma_cshuffle_v3.hpp:267
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma_cshuffle_v3.hpp:306
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma_cshuffle_v3.hpp:256
Definition: device_gemm_v2.hpp:22