23 template <
typename LowLengths>
 
   49         static_assert(LowerIndex::Size() == NDimLow, 
"wrong!");
 
   56     __host__ __device__ constexpr 
const auto& 
GetUpperLengths()
 const { 
return up_lengths_; }
 
   58     template <
typename LowIdx, 
typename UpIdx>
 
   60                                                            const UpIdx& idx_up)
 const 
   62         static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
 
   63                       "wrong! inconsistent # of dimension");
 
   69             idx_low(i) = tmp / this->low_lengths_scan_[i];
 
   70             tmp %= this->low_lengths_scan_[i];
 
   76     template <
typename LowIdxDiff,
 
   82                                               const UpIdxDiff& idx_up_diff,
 
   84                                               const UpIdx& idx_up_new,
 
   87         static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
 
   88                           LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
 
   89                       "wrong! inconsistent # of dimension");
 
   92         constexpr 
auto INm1 = 
Number<NDimLow - 1>{};
 
   97         idx_diff_low(INm1) = idx_up_diff[I0];
 
  114     template <
typename UpIdx>
 
  115     __host__ __device__ 
static constexpr 
bool 
  121     __host__ __device__ 
void Print()
 const 
  124         printf(
"Merge_v3_direct_division_mod_wrw, ");
 
  125         printf(
"low_lengths_ ");
 
  127         printf(
"low_lengths_scan_ ");
 
  129         printf(
"up_lengths_ ");
 
  135 template <
typename LowLengths>
 
  141 template <
typename GridwiseGemm,
 
  145           typename AGridDesc_B_K0_M_K1,
 
  146           typename BGridDesc_B_K0_N_K1,
 
  147           typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 
  148           typename AElementwiseOperation,
 
  149           typename BElementwiseOperation,
 
  150           typename CElementwiseOperation,
 
  151           typename CBlockClusterAdaptor,
 
  152           bool HasMainKBlockLoop>
 
  154 #if CK_USE_LAUNCH_BOUNDS 
  158                                   const FloatB* __restrict__ p_b_grid,
 
  159                                   FloatC* __restrict__ p_c_grid,
 
  160                                   const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
 
  161                                   const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
 
  162                                   const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
 
  163                                       c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  164                                   const AElementwiseOperation a_element_op,
 
  165                                   const BElementwiseOperation b_element_op,
 
  166                                   const CElementwiseOperation c_element_op,
 
  167                                   const CBlockClusterAdaptor c_block_cluster_adaptor)
 
  169 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) 
  170     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
  172     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
 
  176                                                   a_b_k0_m_k1_grid_desc,
 
  177                                                   b_b_k0_n_k1_grid_desc,
 
  178                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  182                                                   c_block_cluster_adaptor);
 
  187     ignore = a_b_k0_m_k1_grid_desc;
 
  188     ignore = b_b_k0_n_k1_grid_desc;
 
  189     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
 
  193     ignore = c_block_cluster_adaptor;
 
  203           typename AGridDesc_B_K0_M_K1,
 
  204           typename BGridDesc_B_K0_N_K1,
 
  205           typename CMNGridDesc,
 
  206           typename AElementwiseOperation,
 
  207           typename BElementwiseOperation,
 
  208           typename CElementwiseOperation,
 
  217           typename ABlockTransferThreadClusterLengths_K0_M_K1,
 
  218           typename ABlockTransferThreadClusterArrangeOrder,
 
  219           typename ABlockTransferSrcAccessOrder,
 
  220           index_t ABlockTransferSrcVectorDim,
 
  221           index_t ABlockTransferSrcScalarPerVector,
 
  222           index_t ABlockTransferDstScalarPerVector_K1,
 
  223           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  224           bool ABlockLdsExtraM,
 
  228           typename BBlockTransferThreadClusterLengths_K0_N_K1,
 
  229           typename BBlockTransferThreadClusterArrangeOrder,
 
  230           typename BBlockTransferSrcAccessOrder,
 
  231           index_t BBlockTransferSrcVectorDim,
 
  232           index_t BBlockTransferSrcScalarPerVector,
 
  233           index_t BBlockTransferDstScalarPerVector_K1,
 
  234           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  235           bool BBlockLdsExtraN,
 
  239           index_t CShuffleMRepeatPerShuffle,
 
  240           index_t CShuffleNRepeatPerShuffle,
 
  241           index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
 
  242           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  243           bool ABlockLdsExtraM1Wrw      = 
false,
 
  244           bool BBlockLdsExtraN1Wrw      = 
false,
 
  245           index_t NumGemmKPrefetchStage = 1,
 
  247           typename ComputeTypeA         = FloatA,
 
  248           typename ComputeTypeB         = ComputeTypeA>
 
  266         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
 
  273 #if CK_GFX90A_DENORM_WORKAROUND 
  295         constexpr 
auto max_lds_align = K1;
 
  298         constexpr 
auto a_block_desc_k0_m_k1 = [&]() {
 
  299             if constexpr(ABlockLdsExtraM)
 
  301                 if constexpr(ABlockLdsExtraM1Wrw)
 
  312                         a_block_desc_k0_m0_m1_k1,
 
  320                     return a_block_desc_k0_m_k1_tmp;
 
  336         return a_block_desc_k0_m_k1;
 
  341         constexpr 
auto max_lds_align = K1;
 
  344         constexpr 
auto a_block_desc_b_k0_m_k1 = [&]() {
 
  345             if constexpr(ABlockLdsExtraM)
 
  347                 if constexpr(ABlockLdsExtraM1Wrw)
 
  363                         a_block_desc_b_k0_m0_m1_k1,
 
  372                     return a_block_desc_b_k0_m_k1_tmp;
 
  392         return a_block_desc_b_k0_m_k1;
 
  397         constexpr 
auto max_lds_align = K1;
 
  400         constexpr 
auto b_block_desc_k0_n_k1 = [&]() {
 
  401             if constexpr(BBlockLdsExtraN)
 
  403                 if constexpr(BBlockLdsExtraN1Wrw)
 
  414                         b_block_desc_k0_n0_n1_k1,
 
  422                     return b_block_desc_k0_n_k1_tmp;
 
  438         return b_block_desc_k0_n_k1;
 
  443         constexpr 
auto max_lds_align = K1;
 
  446         constexpr 
auto b_block_desc_b_k0_n_k1 = [&]() {
 
  447             if constexpr(BBlockLdsExtraN)
 
  449                 if constexpr(BBlockLdsExtraN1Wrw)
 
  465                         b_block_desc_b_k0_n0_n1_k1,
 
  474                     return b_block_desc_b_k0_n_k1_tmp;
 
  494         return b_block_desc_b_k0_n_k1;
 
  499         constexpr 
auto max_lds_align = K1;
 
  502         constexpr 
auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
 
  505         constexpr 
auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
 
  509             a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
 
  512             b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
 
  514         constexpr 
auto c_block_size =
 
  515             GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
 
  519                          c_block_size * 
sizeof(FloatC));
 
  523     template <
typename Block2CTileMap>
 
  524     __host__ __device__ 
static constexpr 
bool 
  526                   const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
 
  527                   const CMNGridDesc& c_m_n_grid_desc,
 
  528                   const Block2CTileMap& block_2_ctile_map)
 
  531                       "wrong! K1 need to be known at compile-time");
 
  533         static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
 
  534                           (NPerBlock % (NRepeat * NPerXDL)) == 0,
 
  535                       "Invalid tuning param!");
 
  537         const auto M      = a_b_k0_m_k1_grid_desc.GetLength(I2);
 
  538         const auto N      = b_b_k0_n_k1_grid_desc.GetLength(I2);
 
  539         const auto K0     = a_b_k0_m_k1_grid_desc.GetLength(I1);
 
  540         const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
 
  543         const auto num_k_loop = K0 / K0PerBlock;
 
  545         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 
  550         if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
 
  551              K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
 
  552              K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
 
  553              K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
 
  554              KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
 
  557         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
 
  560         if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
 
  572         const index_t num_loop = K0 / K0PerBlock;
 
  574         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
 
  579     __host__ __device__ 
static constexpr 
auto 
  582         const auto M = c_m_n_grid_desc.GetLength(I0);
 
  583         const auto N = c_m_n_grid_desc.GetLength(I1);
 
  585         const auto MBlock = M / MPerBlock;
 
  586         const auto NBlock = N / NPerBlock;
 
  601             c_m_n_grid_desc, M01, N01, KBatch);
 
  604     __host__ __device__ 
static constexpr 
auto 
  607         constexpr 
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
 
  608         constexpr 
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
 
  618         decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
 
  621     template <
bool HasMainKBlockLoop>
 
  622     __device__ 
static void Run(
const FloatA* __restrict__ p_a_grid,
 
  623                                const FloatB* __restrict__ p_b_grid,
 
  624                                FloatC* __restrict__ p_c_grid,
 
  625                                void* __restrict__ p_shared,
 
  626                                const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
 
  627                                const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
 
  629                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  630                                const AElementwiseOperation& a_element_op,
 
  631                                const BElementwiseOperation& b_element_op,
 
  632                                const CElementwiseOperation& c_element_op,
 
  635         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  636             p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
 
  637         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  638             p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
 
  639         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  640             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
  642         const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
 
  645         const auto block_work_idx =
 
  648         const index_t k_batch_id = block_work_idx[I0];
 
  650         if(!c_block_cluster_adaptor.ValidCTileIndex(
 
  651                make_tuple(block_work_idx[I1], block_work_idx[I2]),
 
  652                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
 
  653                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
 
  659         const index_t m_block_data_idx_on_grid =
 
  660             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
 
  662         const index_t n_block_data_idx_on_grid =
 
  663             __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
 
  666         constexpr 
auto max_lds_align = K1;
 
  669         constexpr 
auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
 
  671         constexpr 
auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
 
  673         constexpr 
auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
 
  675         constexpr 
auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
 
  677         auto a_blockwise_copy =
 
  679                                                 AElementwiseOperation,
 
  681                                                 InMemoryDataOperationEnum::Set,
 
  683                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
 
  684                                                 ABlockTransferThreadClusterArrangeOrder,
 
  687                                                 decltype(a_b_k0_m_k1_grid_desc),
 
  688                                                 decltype(a_b_k0_m_k1_block_desc),
 
  689                                                 ABlockTransferSrcAccessOrder,
 
  691                                                 ABlockTransferSrcVectorDim,
 
  693                                                 ABlockTransferSrcScalarPerVector,
 
  694                                                 ABlockTransferDstScalarPerVector_K1,
 
  697                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
  699                 a_b_k0_m_k1_grid_desc,
 
  702                 a_b_k0_m_k1_block_desc,
 
  707         auto b_blockwise_copy =
 
  709                                                 BElementwiseOperation,
 
  711                                                 InMemoryDataOperationEnum::Set,
 
  713                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
 
  714                                                 BBlockTransferThreadClusterArrangeOrder,
 
  717                                                 decltype(b_b_k0_n_k1_grid_desc),
 
  718                                                 decltype(b_b_k0_n_k1_block_desc),
 
  719                                                 BBlockTransferSrcAccessOrder,
 
  721                                                 BBlockTransferSrcVectorDim,
 
  723                                                 BBlockTransferSrcScalarPerVector,
 
  724                                                 BBlockTransferDstScalarPerVector_K1,
 
  727                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
  729                 b_b_k0_n_k1_grid_desc,
 
  732                 b_b_k0_n_k1_block_desc,
 
  743         constexpr 
bool is_single_rate_mfma =
 
  751         constexpr 
auto is_scale_mfma = 
false;
 
  758                                                               is_scale_mfma>::selected_mfma.k_per_blk);
 
  760         auto blockwise_gemm =
 
  765                                                                 decltype(a_k0_m_k1_block_desc),
 
  766                                                                 decltype(b_k0_n_k1_block_desc),
 
  776         constexpr 
auto a_block_space_size =
 
  779         constexpr 
auto a_block_slice_copy_step = 
make_multi_index(0, K0PerBlock, 0, 0);
 
  780         constexpr 
auto b_block_slice_copy_step = 
make_multi_index(0, K0PerBlock, 0, 0);
 
  782         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  783             static_cast<FloatAAdjusted*
>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
 
  785         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  787             b_k0_n_k1_block_desc.GetElementSpaceSize());
 
  790         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
 
  792         GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
 
  793                                                           a_b_k0_m_k1_block_desc,
 
  797                                                           a_block_slice_copy_step,
 
  798                                                           b_b_k0_n_k1_grid_desc,
 
  799                                                           b_b_k0_n_k1_block_desc,
 
  803                                                           b_block_slice_copy_step,
 
  810             constexpr 
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
 
  811             constexpr 
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
 
  813             constexpr 
auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
 
  814                 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  816             constexpr 
auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
 
  817                 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  819             constexpr 
auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
 
  820             constexpr 
auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
 
  821             constexpr 
auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
 
  822             constexpr 
auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
 
  823             constexpr 
auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
 
  824             constexpr 
auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
 
  825             constexpr 
auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
 
  826             constexpr 
auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
 
  828             constexpr 
auto c_block_desc_mblock_mperblock_nblock_nperblock =
 
  829                 GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 
  831             auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  832                 static_cast<FloatC*
>(p_shared),
 
  833                 c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
  835             static_assert(M1 == MWave, 
"");
 
  836             static_assert(N1 == NWave, 
"");
 
  837             static_assert(M2 * M3 * M4 == MPerXDL, 
"");
 
  838             static_assert(N2 == NPerXDL, 
"");
 
  841                 c_block_desc_mblock_mperblock_nblock_nperblock,
 
  859             const auto c_thread_mtx_on_block =
 
  860                 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
 
  862             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
 
  863             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
 
  865             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
  871             const auto m_thread_data_on_block_idx =
 
  872                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
  875             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
  881             const auto n_thread_data_on_block_idx =
 
  882                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
  886             auto c_thread_copy_vgpr_to_lds =
 
  889                                                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
 
  890                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  893                                                             CShuffleNRepeatPerShuffle,
 
  903                                                    InMemoryDataOperationEnum::Set,
 
  906                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  909                                      m_thread_data_on_block_idx[I1],
 
  910                                      n_thread_data_on_block_idx[I1],
 
  911                                      m_thread_data_on_block_idx[I2],
 
  912                                      m_thread_data_on_block_idx[I3],
 
  913                                      m_thread_data_on_block_idx[I4],
 
  914                                      n_thread_data_on_block_idx[I2]),
 
  920                 CElementwiseOperation,      
 
  921                 CGlobalMemoryDataOperation, 
 
  923                          CShuffleMRepeatPerShuffle * MWave * MPerXDL,
 
  925                          CShuffleNRepeatPerShuffle * NWave * NPerXDL>, 
 
  926                 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  930                 decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
 
  931                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
  934                 CBlockTransferScalarPerVector_NWaveNPerXDL, 
 
  937                 {c_block_desc_mblock_mperblock_nblock_nperblock,
 
  939                  c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  943             constexpr 
auto mxdlperwave_forward_step =
 
  945             constexpr 
auto nxdlperwave_forward_step =
 
  947             constexpr 
auto nxdlperwave_backward_step =
 
  951                 constexpr 
auto mxdlperwave = mxdlperwave_iter;
 
  954                     constexpr 
bool nxdlperwave_forward_sweep =
 
  955                         (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
 
  957                     constexpr 
index_t nxdlperwave_value =
 
  958                         nxdlperwave_forward_sweep
 
  960                             : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
 
  968                     c_thread_copy_vgpr_to_lds.Run(
 
  969                         c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
 
  970                         make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
 
  972                         c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  979                     c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
 
  981                                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  985                     if constexpr(nxdlperwave_forward_sweep &&
 
  986                                  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
 
  988                         c_block_copy_lds_to_global.MoveDstSliceWindow(
 
  989                             c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  990                             nxdlperwave_forward_step);
 
  992                     else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
 
  994                         c_block_copy_lds_to_global.MoveDstSliceWindow(
 
  995                             c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  996                             nxdlperwave_backward_step);
 
 1001                 if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
 
 1003                     c_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1004                         c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
 
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__global__ void kernel_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:157
 
InMemoryDataOperationEnum
Definition: ck.hpp:275
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
 
ushort bhalf_t
Definition: data_type.hpp:29
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
 
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:297
 
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
 
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition: statically_indexed_array_multi_index.hpp:147
 
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
 
__host__ constexpr __device__ auto make_merge_transform_v4_no_carry(const LowLengths &low_lengths)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:136
 
Definition: block_to_ctile_map.hpp:719
 
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
 
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:79
 
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:250
 
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:266
 
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:395
 
ComputeTypeB FloatBAdjusted
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:280
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:263
 
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:497
 
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:618
 
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:293
 
ComputeTypeA FloatAAdjusted
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:279
 
__host__ static constexpr __device__ auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:339
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:605
 
__host__ static constexpr __device__ bool CalculateHasMainK0BlockLoop(index_t K0)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:569
 
__host__ static constexpr __device__ auto MakeCBlockClusterAdaptor(const CMNGridDesc &c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:597
 
static __device__ void Run(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_B_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const CBlockClusterAdaptor &c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:622
 
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_B_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CMNGridDesc &c_m_n_grid_desc, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:525
 
__host__ static constexpr __device__ auto MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc &c_m_n_grid_desc)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:580
 
decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)) CBlockClusterAdaptor
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:619
 
__host__ static constexpr __device__ auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:441
 
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:25
 
__host__ constexpr __device__ Merge_v4_no_carry(const LowLengths &low_lengths)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:43
 
LowLengthsScan low_lengths_scan_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:38
 
__host__ constexpr __device__ Merge_v4_no_carry()=default
 
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:35
 
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:116
 
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_up_diff, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:81
 
static constexpr index_t NDimLow
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:26
 
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:52
 
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:56
 
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:59
 
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:107
 
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:100
 
UpLengths up_lengths_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:39
 
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:32
 
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:54
 
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:102
 
LowLengths low_lengths_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:37
 
__host__ __device__ void Print() const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:121
 
Definition: xdlops_gemm.hpp:942
 
Definition: sequence.hpp:43
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: integral_constant.hpp:20
 
Definition: is_known_at_compile_time.hpp:14
 
Definition: functional2.hpp:33
 
Definition: unary_element_wise_operation.hpp:308