31           bool AEnableLds = 
true,
 
   32           bool BEnableLds = 
true,
 
   33           bool TransposeC = 
false>
 
   52 struct BlockwiseGemmWMMA
 
   54     static constexpr 
auto I0    = Number<0>{};
 
   55     static constexpr 
auto I1    = Number<1>{};
 
   56     static constexpr 
auto I2    = Number<2>{};
 
   57     static constexpr 
auto I3    = Number<3>{};
 
   58     static constexpr 
auto I4    = Number<4>{};
 
   59     static constexpr 
auto I5    = Number<5>{};
 
   60     static constexpr 
auto WmmaK = Number<16>{};
 
   77         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
 
   79     static constexpr 
index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
 
   80     static constexpr 
index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
 
  100         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
 
  106         if constexpr(AEnableLds)
 
  109             const auto waveId_m   = wave_idx[
I0];
 
  110             const auto WMMA_a_idx = 
wmma_gemm.CalculateAThreadOriginDataIndex();
 
  123         if constexpr(BEnableLds)
 
  126             const auto waveId_n   = wave_idx[
I1];
 
  127             const auto WMMA_b_idx = 
wmma_gemm.CalculateBThreadOriginDataIndex();
 
  138     template <index_t m0, index_t n0>
 
  143         const auto waveId_m = wave_idx[
I0];
 
  144         const auto waveId_n = wave_idx[
I1];
 
  146         const auto blk_idx = 
wmma_gemm.GetBeginOfThreadBlk();
 
  158         const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
 
  160         const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
 
  166     template <index_t m0, index_t n0>
 
  171         const auto waveId_m = wave_idx[
I0];
 
  172         const auto waveId_n = wave_idx[
I1];
 
  174         const auto blk_idx = 
wmma_gemm.GetBeginOfThreadBlk3D();
 
  177             Number<m0>{}, waveId_m, blk_idx[
I0], Number<n0>{}, waveId_n, blk_idx[
I1], blk_idx[
I2]);
 
  185         static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
 
  186                       "wrong! Desc should be known at compile-time");
 
  189                       "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
 
  191         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
 
  192                           NPerBlock % (NPerWMMA * NRepeat) == 0,
 
  197     __host__ __device__ 
static constexpr 
auto 
  200         constexpr 
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
 
  201             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
  203         constexpr 
auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
 
  212     __host__ __device__ 
static constexpr 
auto 
  215         constexpr 
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
 
  216             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
  218         constexpr 
auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
 
  219         constexpr 
auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I3];
 
  224             make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
 
  225                        Number<NRepeat>{} * MAccVgprs * AccStride,
 
  226                        Number<NRepeat>{} * MAccVgprs * AccStride,
 
  227                        MAccVgprs * AccStride,
 
  228                        MAccVgprs * AccStride,
 
  229                        MAccVgprs * AccStride,
 
  233     template <
typename CGr
idDesc_M_N>
 
  234     __host__ __device__ 
static constexpr 
auto 
  236         const CGridDesc_M_N& c_grid_desc_m_n)
 
  238         const auto M = c_grid_desc_m_n.GetLength(
I0);
 
  239         const auto N = c_grid_desc_m_n.GetLength(
I1);
 
  241         const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
 
  248                 make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
 
  251             .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
 
  252                 c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
 
  256     __host__ __device__ 
static constexpr 
auto 
  259         constexpr 
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
 
  265                                                            Number<NPerWMMA>{}));
 
  268             .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
 
  269                 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
 
  273     __host__ __device__ 
static constexpr 
auto 
  276         constexpr 
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
 
  282                                                            Number<NPerWMMA>{}));
 
  285             .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
 
  286                 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
 
  294     template <
typename ABlockBuffer, 
typename BBlockBuffer, 
typename CThreadBuffer>
 
  295     __device__ 
void Run(
const ABlockBuffer& a_block_buf,
 
  296                         const BBlockBuffer& b_block_buf,
 
  297                         CThreadBuffer& c_thread_buf)
 const 
  299         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
 
  301         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
 
  304         static_assert(KPack % (
A_K1 * 
A_KRow) == 0, 
"");
 
  305         static_assert(KPack % (
B_K1 * 
B_KRow) == 0, 
"");
 
  308         if constexpr(MRepeat < NRepeat)
 
  310             static_for<0, KPerBlock / KPack, 1>{}(
 
  312                     static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  322                         static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  332                             vector_type<FloatA, KPack / 
A_KRow> a_thread_vec;
 
  333                             vector_type<FloatB, KPack / 
B_KRow> b_thread_vec;
 
  335                             static_for<0, KPack / 
A_KRow, 1>{}([&](
auto i) {
 
  336                                 a_thread_vec.template AsType<FloatA>()(i) =
 
  341                             static_for<0, KPack / 
B_KRow, 1>{}([&](
auto i) {
 
  342                                 b_thread_vec.template AsType<FloatB>()(i) =
 
  347                             using wmma_input_type_a =
 
  348                                 typename vector_type<FloatA, 
WmmaK / 
A_KRow>::type;
 
  349                             using wmma_input_type_b =
 
  350                                 typename vector_type<FloatB, 
WmmaK / 
B_KRow>::type;
 
  356                                 a_thread_vec.template AsType<wmma_input_type_a>(),
 
  357                                 b_thread_vec.template AsType<wmma_input_type_b>(),
 
  358                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  365             static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  366                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  367                     static_for<0, KPerBlock / KPack, 1>{}([&](
auto k) { 
 
  386                         vector_type<FloatA, KPack / 
A_KRow> a_thread_vec;
 
  387                         vector_type<FloatB, KPack / 
B_KRow> b_thread_vec;
 
  389                         static_for<0, KPack / 
A_KRow, 1>{}([&](
auto i) {
 
  390                             a_thread_vec.template AsType<FloatA>()(i) =
 
  395                         static_for<0, KPack / 
B_KRow, 1>{}([&](
auto i) {
 
  396                             b_thread_vec.template AsType<FloatB>()(i) =
 
  401                         using wmma_input_type_a =
 
  402                             typename vector_type<FloatA, 
WmmaK / 
A_KRow>::type;
 
  403                         using wmma_input_type_b =
 
  404                             typename vector_type<FloatB, 
WmmaK / 
B_KRow>::type;
 
  410                             a_thread_vec.template AsType<wmma_input_type_a>(),
 
  411                             b_thread_vec.template AsType<wmma_input_type_b>(),
 
  412                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  421         make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, 
I1, 
I1, 
I1, Number<A_K1>{}),
 
  430         make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, 
I1, 
I1, 
I1, Number<B_K1>{}),
 
  442     template <
bool EnableLds>
 
  443     struct AThreadCopySelector;
 
  446     struct AThreadCopySelector<true>
 
  449             ThreadwiseTensorSliceTransfer_v4<FloatA,
 
  454                                              Sequence<0, 1, 2, 3, 4, 5>,
 
  461     struct AThreadCopySelector<false>
 
  463         using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
 
  468             tensor_operation::element_wise::PassThrough,
 
  470             Sequence<0, 1, 2, 3, 4, 5>,
 
  476     template <
bool EnableLds>
 
  477     struct BThreadCopySelector;
 
  480     struct BThreadCopySelector<true>
 
  483             ThreadwiseTensorSliceTransfer_v4<FloatB,
 
  488                                              Sequence<0, 1, 2, 3, 4, 5>,
 
  495     struct BThreadCopySelector<false>
 
  497         using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
 
  502             tensor_operation::element_wise::PassThrough,
 
  504             Sequence<0, 1, 2, 3, 4, 5>,
 
  528           bool AEnableLds = 
true,
 
  529           bool BEnableLds = 
true,
 
  530           bool TransposeC = 
false>
 
  596         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
 
  602         if constexpr(AEnableLds)
 
  605             const auto waveId_m   = wave_idx[
I0];
 
  606             const auto WMMA_a_idx = 
wmma_gemm.CalculateAThreadOriginDataIndex();
 
  609             return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
 
  619         if constexpr(BEnableLds)
 
  622             const auto waveId_n   = wave_idx[
I1];
 
  623             const auto WMMA_b_idx = 
wmma_gemm.CalculateBThreadOriginDataIndex();
 
  626             return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
 
  634     template <index_t m0, index_t n0>
 
  639         const auto waveId_m = wave_idx[
I0];
 
  640         const auto waveId_n = wave_idx[
I1];
 
  642         const auto blk_idx = 
wmma_gemm.GetBeginOfThreadBlk();
 
  654         const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
 
  656         const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
 
  662     template <index_t m0, index_t n0>
 
  667         const auto waveId_m = wave_idx[
I0];
 
  668         const auto waveId_n = wave_idx[
I1];
 
  670         const auto blk_idx = 
wmma_gemm.GetBeginOfThreadBlk3D();
 
  681         static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
 
  682                       "wrong! Desc should be known at compile-time");
 
  685                       "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
 
  687         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
 
  688                           NPerBlock % (NPerWMMA * NRepeat) == 0,
 
  693     __host__ __device__ 
static constexpr 
auto 
  696         constexpr 
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
 
  697             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
  699         constexpr 
auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
 
  708     __host__ __device__ 
static constexpr 
auto 
  711         constexpr 
auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
 
  712             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
  714         constexpr 
auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I2];
 
  715         constexpr 
auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[
I3];
 
  723                        MAccVgprs * AccStride,
 
  724                        MAccVgprs * AccStride,
 
  725                        MAccVgprs * AccStride,
 
  729     template <
typename CGr
idDesc_M_N>
 
  730     __host__ __device__ 
static constexpr 
auto 
  732         const CGridDesc_M_N& c_grid_desc_m_n)
 
  734         const auto M = c_grid_desc_m_n.GetLength(
I0);
 
  735         const auto N = c_grid_desc_m_n.GetLength(
I1);
 
  737         const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
 
  747             .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
 
  748                 c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
 
  752     __host__ __device__ 
static constexpr 
auto 
  755         constexpr 
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
 
  764             .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
 
  765                 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
 
  769     __host__ __device__ 
static constexpr 
auto 
  772         constexpr 
auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
 
  781             .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
 
  782                 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
 
  790     template <
typename ABlockBuffer, 
typename BBlockBuffer, 
typename CThreadBuffer>
 
  791     __device__ 
void Run(
const ABlockBuffer& a_block_buf,
 
  792                         const BBlockBuffer& b_block_buf,
 
  793                         CThreadBuffer& c_thread_buf)
 const 
  795         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
 
  797         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
 
  801         if constexpr(MRepeat < NRepeat)
 
  829                                 a_thread_vec.template AsType<FloatA>()(i) =
 
  837                                 b_thread_vec.template AsType<FloatB>()(i) =
 
  853                             wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  854                                           b_thread_vec.template AsType<wmma_input_type_b>(),
 
  864                     static_for<0, KPerBlock / KPack, 1>{}([&](
auto k) { 
 
  887                             b_thread_vec.template AsType<FloatB>()(i) =
 
  895                             a_thread_vec.template AsType<FloatA>()(i) =
 
  911                         wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  912                                       b_thread_vec.template AsType<wmma_input_type_b>(),
 
  930                                                 Number<A_K1 * A_KRow>{},
 
  944                                                 Number<B_K1 * B_KRow>{},
 
  953     template <
bool EnableLds>
 
  986             TransposeC ? false : 
true>;
 
  989     template <
bool EnableLds>
 
 1022             TransposeC ? true : 
false>;
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
integral_constant< index_t, N > Number
Definition: number.hpp:12
 
Definition: blockwise_gemm_wmma.hpp:954
 
Definition: blockwise_gemm_wmma.hpp:990
 
Definition: blockwise_gemm_wmma.hpp:550
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_wmma.hpp:731
 
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_wmma.hpp:709
 
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, wmma_gemm.GetRegSizePerWmma(), true > c_thread_buf_
Definition: blockwise_gemm_wmma.hpp:583
 
static constexpr index_t NWaves
Definition: blockwise_gemm_wmma.hpp:576
 
static constexpr index_t A_KRow
Definition: blockwise_gemm_wmma.hpp:567
 
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_wmma.hpp:935
 
static constexpr index_t B_K1
Definition: blockwise_gemm_wmma.hpp:570
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_wmma.hpp:559
 
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition: blockwise_gemm_wmma.hpp:791
 
static constexpr index_t A_K1
Definition: blockwise_gemm_wmma.hpp:569
 
static constexpr auto I0
Definition: blockwise_gemm_wmma.hpp:551
 
static constexpr auto I5
Definition: blockwise_gemm_wmma.hpp:556
 
static constexpr index_t B_KRow
Definition: blockwise_gemm_wmma.hpp:568
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_wmma.hpp:770
 
BThreadCopySelector< BEnableLds >::type b_thread_copy_
Definition: blockwise_gemm_wmma.hpp:1026
 
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_wmma.hpp:585
 
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_wmma.hpp:635
 
static constexpr auto I1
Definition: blockwise_gemm_wmma.hpp:552
 
static constexpr auto I3
Definition: blockwise_gemm_wmma.hpp:554
 
static constexpr index_t MWaves
Definition: blockwise_gemm_wmma.hpp:575
 
decltype(CalculateAThreadOriginDataIndex()) Tuple6
Definition: blockwise_gemm_wmma.hpp:676
 
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_wmma.hpp:921
 
static constexpr index_t WaveSize
Definition: blockwise_gemm_wmma.hpp:562
 
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_wmma.hpp:600
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition: blockwise_gemm_wmma.hpp:753
 
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_wmma.hpp:950
 
static constexpr auto WmmaK
Definition: blockwise_gemm_wmma.hpp:557
 
static constexpr auto I4
Definition: blockwise_gemm_wmma.hpp:555
 
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1
Definition: blockwise_gemm_wmma.hpp:787
 
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1
Definition: blockwise_gemm_wmma.hpp:788
 
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_wmma.hpp:587
 
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition: blockwise_gemm_wmma.hpp:694
 
static constexpr auto wmma_gemm
Definition: blockwise_gemm_wmma.hpp:572
 
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_wmma.hpp:617
 
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin=CalculateAThreadOriginDataIndex(), Tuple6 b_origin=CalculateBThreadOriginDataIndex())
Definition: blockwise_gemm_wmma.hpp:677
 
static constexpr auto I2
Definition: blockwise_gemm_wmma.hpp:553
 
AThreadCopySelector< AEnableLds >::type a_thread_copy_
Definition: blockwise_gemm_wmma.hpp:1025
 
static __device__ auto CalculateCThreadOriginDataIndex7D(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_wmma.hpp:663
 
Definition: sequence.hpp:43
 
Definition: static_buffer.hpp:75
 
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
 
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
 
Definition: threadwise_tensor_slice_transfer.hpp:1881
 
Definition: threadwise_tensor_slice_transfer.hpp:1264
 
Definition: wmma_gemm.hpp:663
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: unary_element_wise_operation.hpp:308
 
Definition: dtype_vector.hpp:10