   20 template <typename GridwiseGemm,
 
   23           typename AGridDesc_K0_M0_M1_K1,
 
   24           typename BGridDesc_K0_N0_N1_K1,
 
   25           typename CGridDesc_M0_M10_M11_N0_N10_N11,
 
   26           typename Block2CTileMap,
 
   27           bool HasMainKBlockLoop,
 
   28           bool HasDoubleTailKBlockLoop>
 
   29 __global__ void

   30 #if CK_USE_LAUNCH_BOUNDS

   31     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

   32 #endif

   33         kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
   34                             const FloatAB* __restrict__ p_b_grid,
 
   35                             FloatC* __restrict__ p_c_grid,
 
   36                             const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
 
   37                             const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
 
   38                             const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
 
   39                             const Block2CTileMap block_2_ctile_map)
 
   40 {

   41     constexpr index_t shared_block_size =

   42         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
 
   44     __shared__ FloatAB p_shared_block[shared_block_size];
 
   46     GridwiseGemm::Run(p_a_grid,

   47                       p_b_grid,

   48                       p_c_grid,

   49                       p_shared_block,

   50                       a_grid_desc_k0_m0_m1_k1,

   51                       b_grid_desc_k0_n0_n1_k1,

   52                       c_grid_desc_m0_m10_m11_n0_n10_n11,

   53                       block_2_ctile_map,

   54                       integral_constant<bool, HasMainKBlockLoop>{},

   55                       integral_constant<bool, HasDoubleTailKBlockLoop>{});

   56 }
 
   63           typename AGridDesc_K0_M_K1,
 
   64           typename BGridDesc_K0_N_K1,
 
   65           typename CGridDesc_M_N,
 
   73           typename M11N11ThreadClusterM110Xs,
 
   74           typename M11N11ThreadClusterN110Xs,
 
   75           typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
 
   76           typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
 
   77           typename ABlockTransferThreadClusterArrangeOrder,
 
   78           typename ABlockTransferSrcAccessOrder,
 
   79           typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
 
   80           typename ABlockTransferSrcVectorTensorContiguousDimOrder,
 
   81           typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
 
   82           typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
 
   83           typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
 
   84           typename BBlockTransferThreadClusterArrangeOrder,
 
   85           typename BBlockTransferSrcAccessOrder,
 
   86           typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
 
   87           typename BBlockTransferSrcVectorTensorContiguousDimOrder,
 
   88           typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
 
   89           typename CThreadTransferSrcDstAccessOrder,
 
   90           index_t CThreadTransferSrcDstVectorDim,
 
   91           index_t CThreadTransferDstScalarPerVector>
 
  105         constexpr auto max_lds_align = K1;
 
  119         constexpr auto a_block_aligned_space_size =

  122         constexpr auto b_block_aligned_space_size =
 
  125         return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
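
The return value above is the whole LDS budget: one A tile plus one B tile, each padded to max_lds_align, and everything doubled for ping-pong buffering. A minimal host-side sketch of that arithmetic, using illustrative tile extents (K0PerBlock = 8, MPerBlock = NPerBlock = 128, K1 = 4) and FloatAB = float rather than this header's actual template arguments:

#include <cstddef>
#include <cstdio>

// Round x up to the next multiple of y (mirrors CK's integer_least_multiple).
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t y)
{
    return ((x + y - 1) / y) * y;
}

int main()
{
    constexpr std::size_t K0PerBlock = 8, MPerBlock = 128, NPerBlock = 128, K1 = 4;

    // Element counts of one A tile and one B tile in LDS, padded to K1 alignment.
    constexpr std::size_t a_space = integer_least_multiple(K0PerBlock * MPerBlock * K1, K1);
    constexpr std::size_t b_space = integer_least_multiple(K0PerBlock * NPerBlock * K1, K1);

    // The factor 2 pays for the even/odd (ping-pong) copies of both tiles.
    constexpr std::size_t lds_bytes = 2 * (a_space + b_space) * sizeof(float);
    std::printf("LDS bytes: %zu\n", lds_bytes);
    return 0;
}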
 
  128     __host__ __device__ static constexpr bool
  129     CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
  130                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
 
  131                   const CGridDesc_M_N& c_grid_desc_m_n)
 
  133         const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

  134         const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

  135         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

  139         return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&

  140                 K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&

  141                 K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&

  142                 K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&

  143                (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
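
CheckValidity requires the descriptors to agree on M, N, K0 and K1 across A, B and C, and the problem to tile evenly by the block sizes. The divisibility part of the predicate can be exercised on plain integers; the sizes below are illustrative, not defaults:

#include <cstdio>

int main()
{
    constexpr int M = 1024, N = 768, K0 = 64;
    constexpr int MPerBlock = 128, NPerBlock = 128, K0PerBlock = 8;

    // Same shape as the last conjunct of the return statement above.
    const bool valid =
        (M % MPerBlock == 0) && (N % NPerBlock == 0) && (K0 % K0PerBlock == 0);
    std::printf("tile divisibility holds: %s\n", valid ? "yes" : "no");
    return 0;
}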
 
  148         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
 
  155         const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
 
  157         return has_main_k_block_loop;
 
  162         const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
 
  164         return has_double_tail_k_block_loop;
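
These two predicates select which Run instantiation gets launched. A worked example with illustrative values K0 = 40, K0PerBlock = 8 (five K-blocks: the main loop eats two even/odd pairs, and the odd block count leaves a single tail block):

#include <cstdio>

int main()
{
    constexpr int K0 = 40, K0PerBlock = 8; // five K-blocks in total

    // (40 + 8) / 16 = 3 > 1: at least one full even/odd pair before the tail.
    const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

    // 40 / 8 = 5 is odd, so the tail here is single, not double.
    const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

    std::printf("main loop: %d, double tail: %d\n",
                has_main_k_block_loop, has_double_tail_k_block_loop);
    return 0;
}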
 
  167     __host__ __device__ static constexpr auto
  168     MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
  170         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

  171         const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
 
  174         const auto M0 = M / M1;
 
  176         const auto a_grid_desc_k0_m0_m1_k1 =
 
  184         return a_grid_desc_k0_m0_m1_k1;
 
  187     __host__ __device__ static constexpr auto
  188     MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
  190         const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);

  191         const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
 
  194         const auto N0 = N / N1;
 
  196         const auto b_grid_desc_k0_n0_n1_k1 =
 
  204         return b_grid_desc_k0_n0_n1_k1;
 
  207     __host__ __device__ static constexpr auto
  208     MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
  210         const auto M = c_grid_desc_m_n.GetLength(I0);

  211         const auto N = c_grid_desc_m_n.GetLength(I1);
 
  216         const auto M0 = M / M1;
 
  217         const auto N0 = N / N1;
 
  226         constexpr auto M10 = M1 / M11;

  227         constexpr auto N10 = N1 / N11;
 
  236         return c_grid_desc_m0_m10_m11_n0_n10_n11;
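
The transform above splits M into (M0, M1) and M1 further into (M10, M11), with N handled symmetrically, so a C coordinate becomes (m0, m10, m11, n0, n10, n11). A sketch of the index arithmetic with illustrative sizes; treating M1 as MPerBlock and M11 = 32 as a per-thread-cluster tile length is an assumption for the example only:

#include <cstdio>

int main()
{
    constexpr int M = 1024, MPerBlock = 128, M11 = 32;

    constexpr int M1  = MPerBlock; // block tile along M (assumed)
    constexpr int M0  = M / M1;    // number of block tiles along M
    constexpr int M10 = M1 / M11;  // sub-tiles inside one block tile

    std::printf("M = %d -> (M0, M10, M11) = (%d, %d, %d)\n", M, M0, M10, M11);
    return 0;
}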
 
  240     __host__ __device__ static constexpr auto
  241     MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
  253     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
 
  254     __device__ static void
  255     Run(const FloatAB* __restrict__ p_a_grid,
 
  256         const FloatAB* __restrict__ p_b_grid,
 
  257         FloatC* __restrict__ p_c_grid,
 
  258         FloatAB* __restrict__ p_shared_block,
 
  266         const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  267             p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
 
  268         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  269             p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
  270         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  271             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
 
  274         const auto c_m0_n0_block_cluster_idx =
 
  278         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);

  279         const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
 
  281         if(!block_2_ctile_map.ValidCTileIndex(
 
  283                make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),

  284                           c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
 
  290         constexpr auto max_lds_align = K1;
 
  314         static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
 
  315                           a_k0_m_k1_block_desc.GetElementSpaceSize() &&
 
  316                       b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
 
  317                           b_k0_n_k1_block_desc.GetElementSpaceSize() &&
 
  325             ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
 
  326             ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
 
  327             ABlockTransferThreadClusterArrangeOrder,
 
  331             decltype(a_block_desc_k0_m0_m1_k1),
 
  332             ABlockTransferSrcAccessOrder,
 
  334             ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, 
 
  335             ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, 
 
  336             ABlockTransferSrcVectorTensorContiguousDimOrder,  
 
  339             true>(a_grid_desc_k0_m0_m1_k1,
 
  341                   a_block_desc_k0_m0_m1_k1,
 
  349             BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
 
  350             BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
 
  351             BBlockTransferThreadClusterArrangeOrder,
 
  355             decltype(b_block_desc_k0_n0_n1_k1),
 
  356             BBlockTransferSrcAccessOrder,
 
  358             BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, 
 
  359             BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, 
 
  360             BBlockTransferSrcVectorTensorContiguousDimOrder,  
 
  363             true>(b_grid_desc_k0_n0_n1_k1,
 
  365                   b_block_desc_k0_n0_n1_k1,
 
  374         const auto blockwise_gemm =
 
  380                 decltype(a_k0_m_k1_block_desc),
 
  381                 decltype(b_k0_n_k1_block_desc),
 
  385                 M11N11ThreadClusterM110Xs,
 
  386                 M11N11ThreadClusterN110Xs,
 
  390         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
 
  391             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
 
  398             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
 
  401             b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
 
  403         FloatAB* p_a_block_double = p_shared_block;
 
  404         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
 
  407         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
 
  408             c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
 
  411         c_thread_buf.Clear();
 
  413         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

  414         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
 
  416         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  417             p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
 
  418         auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  419             p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
  421         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  422             p_a_block_double + a_block_aligned_space_size,
 
  423             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
 
  424         auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  425             p_b_block_double + b_block_aligned_space_size,
 
  426             b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
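
The four dynamic buffers above carve one contiguous LDS allocation into [A even | A odd | B even | B odd]: the B region starts at p_shared_block + 2 * a_block_aligned_space_size, and each odd buffer sits one aligned tile past its even twin. A sketch of the resulting element offsets, with illustrative tile sizes:

#include <cstdio>

int main()
{
    constexpr long a_space = 4096, b_space = 4096; // aligned tile sizes (elements)

    const long a_even = 0;                // p_a_block_double
    const long a_odd  = a_even + a_space; // p_a_block_double + a_block_aligned_space_size
    const long b_even = 2 * a_space;      // p_b_block_double
    const long b_odd  = b_even + b_space; // p_b_block_double + b_block_aligned_space_size

    std::printf("A even/odd at %ld/%ld, B even/odd at %ld/%ld\n",
                a_even, a_odd, b_even, b_odd);
    return 0;
}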
 
  430             a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
 
  431             b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
 
  433             a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
 
  434             b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
 
  437         if constexpr(HasMainKBlockLoop)
 
  439             const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
 
  441             index_t k_block_data_begin = 0;
 
  448                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
 
  449                                                     a_block_slice_copy_step);
 
  450                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
 
  451                                                     b_block_slice_copy_step);
 
  454                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
 
  455                 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
 
  460                 blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
 
  466                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
 
  467                 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
 
  470                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
 
  471                                                     a_block_slice_copy_step);
 
  472                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
 
  473                                                     b_block_slice_copy_step);
 
  476                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
 
  477                 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
 
  483                     c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
 
  486                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
 
  487                 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
 
  489                 k_block_data_begin += 2 * K0PerBlock;
 
  490             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
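
The do-while above is the ping-pong pipeline: each trip runs the blockwise GEMM on one LDS pair while the blockwise copies prefetch the next K-slice into the other pair, advancing 2 * K0PerBlock per trip and exiting early enough to leave data for the tail code below. A dummy-data sketch of that schedule (K0 = 64 is illustrative):

#include <cstdio>

int main()
{
    constexpr int K0 = 64, K0PerBlock = 8;

    int k_block_data_begin = 0;
    do
    {
        std::printf("gemm on even buffers, prefetch into odd  (k = %d)\n",
                    k_block_data_begin);
        std::printf("gemm on odd buffers,  prefetch into even (k = %d)\n",
                    k_block_data_begin + K0PerBlock);
        k_block_data_begin += 2 * K0PerBlock;
    } while(k_block_data_begin < K0 - 2 * K0PerBlock);
    // Here 48 of the 64 K0 indices are consumed; the final two K-blocks
    // fall to the double-tail path.
    return 0;
}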
 
  494         if constexpr(HasDoubleTailKBlockLoop) 
 
  496             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
 
  497             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
 
  502             a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
 
  503             b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
 
  507                 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
 
  510             a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
 
  511             b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
 
  517                 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
 
  525                 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
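
The tail mirrors at most one trip of the main loop: with a double tail the even buffers are consumed while the odd buffers finish loading, then the odd buffers are consumed; with a single tail only the even buffers remain. A plain-C++ sketch of that control flow on dummy operations:

#include <cstdio>

void tail(bool has_double_tail)
{
    if(has_double_tail)
    {
        std::printf("gemm(even) while odd buffers finish loading\n");
        std::printf("gemm(odd)\n");
    }
    else
    {
        std::printf("gemm(even)\n");
    }
}

int main()
{
    tail(true);  // K0 / K0PerBlock even
    tail(false); // K0 / K0PerBlock odd
    return 0;
}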
 
  530             constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =

  533                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
 
  539             const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
 
  540                 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
 
  546                 decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
 
  547                 decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
 
  550                          c_m10_m11_n10_n11_thread_tensor_lengths[I0],

  551                          c_m10_m11_n10_n11_thread_tensor_lengths[I1],

  553                          c_m10_m11_n10_n11_thread_tensor_lengths[I2],

  554                          c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
 
  555                 CThreadTransferSrcDstAccessOrder,
 
  556                 CThreadTransferSrcDstVectorDim,
 
  557                 CThreadTransferDstScalarPerVector,
 
  558                 CGlobalMemoryDataOperation,
 
  560                 true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
 
  562                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],

  563                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],

  565                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],

  566                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
 
  568                 .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
 
  571                      c_grid_desc_m0_m10_m11_n0_n10_n11,
 
  582           typename AGridDesc_B_K0_M_K1,
 
  583           typename BGridDesc_B_K0_N_K1,
 
  584           typename CGridDesc_M_N,
 
  592           typename M11N11ThreadClusterM110Xs,
 
  593           typename M11N11ThreadClusterN110Xs,
 
  594           typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
 
  595           typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
 
  596           typename ABlockTransferThreadClusterArrangeOrder,
 
  597           typename ABlockTransferSrcAccessOrder,
 
  598           typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
 
  599           typename ABlockTransferSrcVectorTensorContiguousDimOrder,
 
  600           typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
 
  601           typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
 
  602           typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
 
  603           typename BBlockTransferThreadClusterArrangeOrder,
 
  604           typename BBlockTransferSrcAccessOrder,
 
  605           typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
 
  606           typename BBlockTransferSrcVectorTensorContiguousDimOrder,
 
  607           typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
 
  608           typename CThreadTransferSrcDstAccessOrder,
 
  609           index_t CThreadTransferSrcDstVectorDim,
 
  610           index_t CThreadTransferDstScalarPerVector>
 
  624         constexpr auto max_lds_align = K1;
 
  639             a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align);
 
  642             b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align);
 
  644         return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
 
  647     __host__ __device__ static constexpr bool
  648     CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1,
  649                   const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
 
  650                   const CGridDesc_M_N& c_grid_desc_m_n)
 
  654         if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&

  655              b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&

  656              c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
 
  661         const auto M      = a_grid_desc_b_k0_m_k1.GetLength(I2);

  662         const auto N      = b_grid_desc_b_k0_n_k1.GetLength(I2);

  663         const auto K0     = a_grid_desc_b_k0_m_k1.GetLength(I1);

  664         const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
 
  668         return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&

  669                 K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) &&

  670                 K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) &&

  671                 K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) &&

  672                KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) &&

  673                (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
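
The TwoGB guard at the top of this CheckValidity keeps each grid buffer's element space addressable with a 31-bit byte offset. A sketch of the same comparison on an illustrative element count:

#include <cstdint>
#include <cstdio>

int main()
{
    constexpr std::int64_t TwoGB = std::int64_t{1} << 31;

    // Illustrative element-space size for one grid buffer of float.
    constexpr std::int64_t elements = std::int64_t{500} * 1024 * 1024;

    const bool fits = elements * static_cast<std::int64_t>(sizeof(float)) <= TwoGB;
    std::printf("buffer fits in the 2 GB window: %s\n", fits ? "yes" : "no");
    return 0;
}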
 
  678         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
 
  685         const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
 
  687         return has_main_k_block_loop;
 
  692         const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
 
  694         return has_double_tail_k_block_loop;
 
  697     __host__ __device__ static constexpr auto
  698     MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1)
  700         const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);

  701         const auto K0     = a_grid_desc_b_k0_m_k1.GetLength(I1);

  702         const auto M      = a_grid_desc_b_k0_m_k1.GetLength(I2);
 
  705         const auto M0 = M / M1;
 
  708             a_grid_desc_b_k0_m_k1,
 
  716         return a_grid_desc_b_k0_m0_m1_k1;
 
  719     __host__ __device__ static constexpr auto
  720     MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1)
  722         const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0);

  723         const auto K0     = b_grid_desc_b_k0_n_k1.GetLength(I1);

  724         const auto N      = b_grid_desc_b_k0_n_k1.GetLength(I2);
 
  727         const auto N0 = N / N1;
 
  730             b_grid_desc_b_k0_n_k1,
 
  738         return b_grid_desc_b_k0_n0_n1_k1;
 
  741     __host__ __device__ static constexpr auto
  742     MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
  744         const auto M = c_grid_desc_m_n.GetLength(I0);

  745         const auto N = c_grid_desc_m_n.GetLength(I1);
 
  750         const auto M0 = M / M1;
 
  751         const auto N0 = N / N1;
 
  760         constexpr auto M10 = M1 / M11;

  761         constexpr auto N10 = N1 / N11;
 
  770         return c_grid_desc_m0_m10_m11_n0_n10_n11;
 
  778             c_m_n_grid_desc, M01, N01, KBatch);
 
  789     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
 
  790     __device__ static void
  791     Run(const FloatAB* __restrict__ p_a_grid,
 
  792         const FloatAB* __restrict__ p_b_grid,
 
  793         FloatC* __restrict__ p_c_grid,
 
  794         FloatAB* __restrict__ p_shared_block,
 
  802         const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  803             p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize());
 
  804         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  805             p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize());
 
  806         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  807             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
 
  810         const auto block_work_idx =
 
  813         const index_t k_batch_id = block_work_idx[I0];
 
  815         if(!c_block_cluster_adaptor.ValidCTileIndex(
 
  817                make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),

  818                           c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
 
  824         const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

  826         const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
 
  829         constexpr auto max_lds_align = K1;
 
  865         static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
 
  866                           a_k0_m_k1_block_desc.GetElementSpaceSize() &&
 
  867                       b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
 
  868                           b_k0_n_k1_block_desc.GetElementSpaceSize() &&
 
  875             Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>,
 
  876             ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
 
  877             ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
 
  878             ABlockTransferThreadClusterArrangeOrder,
 
  882             decltype(a_block_desc_b_k0_m0_m1_k1),
 
  883             ABlockTransferSrcAccessOrder,
 
  885             ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, 
 
  886             ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, 
 
  887             ABlockTransferSrcVectorTensorContiguousDimOrder,  
 
  890             true>(a_grid_desc_b_k0_m0_m1_k1,
 
  892                   a_block_desc_b_k0_m0_m1_k1,
 
  899             Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>,
 
  900             BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
 
  901             BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
 
  902             BBlockTransferThreadClusterArrangeOrder,
 
  906             decltype(b_block_desc_b_k0_n0_n1_k1),
 
  907             BBlockTransferSrcAccessOrder,
 
  909             BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, 
 
  910             BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, 
 
  911             BBlockTransferSrcVectorTensorContiguousDimOrder,  
 
  914             true>(b_grid_desc_b_k0_n0_n1_k1,
 
  916                   b_block_desc_b_k0_n0_n1_k1,
 
  925         const auto blockwise_gemm =
 
  931                 decltype(a_k0_m_k1_block_desc),
 
  932                 decltype(b_k0_n_k1_block_desc),
 
  936                 M11N11ThreadClusterM110Xs,
 
  937                 M11N11ThreadClusterN110Xs,
 
  941         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
 
  942             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
 
  949             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
 
  952             b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
 
  954         FloatAB* p_a_block_double = p_shared_block;
 
  955         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
 
  958         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
 
  959             c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
 
  962         c_thread_buf.Clear();
 
  964         constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);

  965         constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
 
  967         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  968             p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
 
  969         auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  970             p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
  972         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  973             p_a_block_double + a_block_aligned_space_size,
 
  974             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
 
  975         auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  976             p_b_block_double + b_block_aligned_space_size,
 
  977             b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
  981             a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
 
  982             b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
 
  984             a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
 
  985             b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
 
  988         if constexpr(HasMainKBlockLoop)
 
  990             const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1);
 
  992             index_t k_block_data_begin = 0;
 
  999                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
 
 1000                                                     a_block_slice_copy_step);
 
 1001                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
 
 1002                                                     b_block_slice_copy_step);
 
 1005                 a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
 
 1006                 b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
 
 1011                 blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
 
 1017                 a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
 
 1018                 b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
 
 1021                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
 
 1022                                                     a_block_slice_copy_step);
 
 1023                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
 
 1024                                                     b_block_slice_copy_step);
 
 1027                 a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
 
 1028                 b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
 
 1034                     c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
 
 1037                 a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
 
 1038                 b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
 
 1040                 k_block_data_begin += 2 * K0PerBlock;
 
 1041             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
 
 1045         if constexpr(HasDoubleTailKBlockLoop) 
 
 1047             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
 
 1048             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);
 
 1053             a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
 
 1054             b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
 
 1058                 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
 
 1061             a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
 
 1062             b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
 
 1068                 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
 
 1076                 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
 
 1081             constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =

 1084                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
 
 1090             const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
 
 1091                 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
 
 1097                 decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
 
 1098                 decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
 
 1101                          c_m10_m11_n10_n11_thread_tensor_lengths[I0],

 1102                          c_m10_m11_n10_n11_thread_tensor_lengths[I1],

 1104                          c_m10_m11_n10_n11_thread_tensor_lengths[I2],

 1105                          c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
 
 1106                 CThreadTransferSrcDstAccessOrder,
 
 1107                 CThreadTransferSrcDstVectorDim,
 
 1108                 CThreadTransferDstScalarPerVector,
 
 1109                 CGlobalMemoryDataOperation,
 
 1111                 true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
 
 1113                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],

 1114                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],

 1115                                        n_block_data_idx_on_grid,

 1116                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],

 1117                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
 
 1119                 .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
 
 1122                      c_grid_desc_m0_m10_m11_n0_n10_n11,
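
The listing ends with the batched Run writing its C tile back through the threadwise transfer. For completeness, a hedged host-side dispatch sketch (not CK's actual launcher; launch<>() is a stand-in for instantiating and launching the kernel) showing how a caller picks the Run instantiation from the runtime K0, mirroring CalculateHasMainKBlockLoop and CalculateHasDoubleTailKBlockLoop:

#include <cstdio>

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void launch() // stand-in for launching kernel_gemm_dl_v1r3<..., HasMain..., HasDoubleTail...>
{
    std::printf("Run<%d, %d>\n", HasMainKBlockLoop, HasDoubleTailKBlockLoop);
}

void dispatch(int K0, int K0PerBlock)
{
    const bool main_loop   = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
    const bool double_tail = (K0 / K0PerBlock) % 2 == 0;

    if(main_loop && double_tail)
        launch<true, true>();
    else if(main_loop)
        launch<true, false>();
    else if(double_tail)
        launch<false, true>();
    else
        launch<false, false>();
}

int main()
{
    dispatch(64, 8); // Run<true, true>
    dispatch(40, 8); // Run<true, false>
    return 0;
}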
 