// Template parameters of the gridwise "DL" GEMM with multiple D tensors
// (gridwise_gemm_dl_multiple_d.hpp); parameters elided from this excerpt are marked "...".
template <// ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          // ...
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          // ...
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
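// Note on the K0/K1 naming used throughout: K is pre-split as K = K0 * K1, with K1 kept
// as the innermost, contiguous dimension so global loads and LDS accesses can be
// vectorized along it. The same convention gives the block-level views
// [K0PerBlock, M0, M1, K1] for A and [K0PerBlock, N0, N1, K1] for B.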
 
    // Inside MakeDsGridPointer(): produce a correctly-typed (initially null) pointer for
    // each of the NumDTensor D tensors; DDataType is the i-th D data type.
    return static_cast<const DDataType*>(nullptr);
 
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // ... (block-tile descriptors built with make_naive_tensor_descriptor_aligned)
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        // Factor 2: A and B are each double-buffered for the pipelined main loop.
        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }
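    // Worked example with hypothetical tile sizes: K0PerBlock = 16, MPerBlock = NPerBlock
    // = 128, K1 = 2, FloatAB = half_t (2 bytes). Each A (or B) tile then holds
    // 16 * 128 * 2 = 4096 elements, so the kernel reserves
    // 2 * (4096 + 4096) * 2 B = 32 KiB of LDS.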
 
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        // Keep every tensor under 2 GB so 32-bit offset arithmetic cannot overflow.
        if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
             b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
             c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
        {
            return false;
        }

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // A, B and C must agree on M, N, K0 and K1, and every extent must tile evenly.
        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
                K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
                K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }
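    // Host-side usage sketch (a minimal sketch; GridwiseGemm is an assumed alias for this
    // struct, and the code that builds the three descriptors is assumed to exist):
    //
    //     if(!GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1,
    //                                     b_grid_desc_k0_n_k1,
    //                                     c_grid_desc_m_n))
    //     {
    //         return false; // reject this problem-size / tuning-parameter combination
    //     }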
 
    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        // One workgroup per MPerBlock x NPerBlock tile of C.
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }
 
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        // The pipelined main loop is only entered when at least three K0PerBlock slices
        // exist; the first and last slices are handled by the prologue and the tail.
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        // An even number of K0PerBlock slices leaves a two-slice tail.
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }
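    // Case analysis, writing K0 = n * K0PerBlock:
    //   has_main_k_block_loop        <=>  (n + 1) / 2 > 1  <=>  n >= 3
    //   has_double_tail_k_block_loop <=>  n % 2 == 0
    // e.g. n = 8: main loop runs 3 times (6 slices), then the two-slice tail;
    //      n = 3: main loop runs once, then a single-slice tail;
    //      n = 2: no main loop, just the two-slice tail;
    //      n = 1: no main loop, single-slice tail on the preloaded buffers.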
 
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
    {
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

        const auto M1 = Number<MPerBlock>{};
        const auto M0 = M / M1;

        // Unmerge M into (M0, M1): [K0, M, K1] -> [K0, M0, M1, K1].
        const auto a_grid_desc_k0_m0_m1_k1 =
            transform_tensor_descriptor(a_grid_desc_k0_m_k1,
                                        make_tuple(make_pass_through_transform(K0),
                                                   make_unmerge_transform(make_tuple(M0, M1)),
                                                   make_pass_through_transform(K1)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_grid_desc_k0_m0_m1_k1;
    }
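    // Example with hypothetical extents: M = 1024 and MPerBlock = 128 give M0 = 8 block
    // rows of M1 = 128 elements each, so a block's A tile is addressed as
    // a_grid_desc_k0_m0_m1_k1(k0, im0, m1, k1) with im0 fixed per workgroup.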
 
    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
    {
        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

        const auto N1 = Number<NPerBlock>{};
        const auto N0 = N / N1;

        // Same unmerge as for A, applied to N: [K0, N, K1] -> [K0, N0, N1, K1].
        const auto b_grid_desc_k0_n0_n1_k1 =
            transform_tensor_descriptor(b_grid_desc_k0_n_k1,
                                        make_tuple(make_pass_through_transform(K0),
                                                   make_unmerge_transform(make_tuple(N0, N1)),
                                                   make_pass_through_transform(K1)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_grid_desc_k0_n0_n1_k1;
    }
 
    template <typename CGridDesc_M_N_>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        // ...
        const auto M0 = M / M1;
        const auto N0 = N / N1;

        // ...
        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        // ... (unmerge M -> (M0, M10, M11) and N -> (N0, N10, N11))
        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }
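    // The three-level split roughly mirrors the compute hierarchy: M0/N0 pick the
    // workgroup tile, M10/N10 the repeats over thread clusters (cf.
    // M11N11ThreadClusterM110Xs / M11N11ThreadClusterN110Xs), and M11/N11 the elements
    // owned by one thread cluster.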
 
    template <typename DsGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N& ds_grid_desc_m_n)
    // ... (one such transformed descriptor per D tensor)

    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
    // ... (cf. block_to_ctile_map.hpp)
    template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
              bool HasMainKBlockLoop,
              bool HasDoubleTailKBlockLoop,
              typename Block2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        DsGridPointer p_ds_grid,
        FloatC* __restrict__ p_c_grid,
        void* __restrict__ p_shared_block,
        const AElementwiseOperation&, // unnamed: A/B elementwise ops are not applied here
        const BElementwiseOperation&,
        const CDEElementwiseOperation& cde_element_op,
        const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
        const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
        const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
        const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
        const Block2CTileMap& block_2_ctile_map,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
 
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // Map the flat workgroup id to this block's (m0, n0) tile coordinate in C.
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // readfirstlane keeps the tile coordinates in scalar registers.
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        // Early-out for workgroups the tile map places outside the C grid.
        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }

        constexpr auto max_lds_align = K1;
 
        // The [K0, M0, M1, K1] copy descriptor and the [K0, M, K1] GEMM descriptor are two
        // views of the same LDS tile, so their element spaces must match.
        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                          a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                      b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                          b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                      // ... (remaining conditions elided)
 
        // Blockwise copy of A: the whole workgroup cooperatively loads a
        // [K0PerBlock, M0, M1, K1] tile from global memory into LDS
        // (blockwise_tensor_slice_transfer_v5r1.hpp).
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            // ...
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(a_block_desc_k0_m0_m1_k1),
            ABlockTransferSrcAccessOrder,
            // ...
            ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
            ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
            ABlockTransferSrcVectorTensorContiguousDimOrder,
            // ...
            true>(a_grid_desc_k0_m0_m1_k1,
                  make_multi_index(0, im0, 0, 0), // source window: this block's M0 slot
                  a_block_desc_k0_m0_m1_k1,
                  make_multi_index(0, 0, 0, 0));
 
        // Blockwise copy of B, mirroring the A copy.
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            // ...
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterArrangeOrder,
            // ...
            decltype(b_block_desc_k0_n0_n1_k1),
            BBlockTransferSrcAccessOrder,
            // ...
            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
            BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
            BBlockTransferSrcVectorTensorContiguousDimOrder,
            // ...
            true>(b_grid_desc_k0_n0_n1_k1,
                  make_multi_index(0, in0, 0, 0), // source window: this block's N0 slot
                  b_block_desc_k0_n0_n1_k1,
                  make_multi_index(0, 0, 0, 0));
 
        // Blockwise GEMM (struct defined in blockwise_gemm_dl_v2r3.hpp): each thread
        // accumulates a [M10, M11, N10, N11]-shaped slice of the block's C tile.
        const auto blockwise_gemm =
            // ... < (type name elided; its template arguments include)
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                // ...
                M11N11ThreadClusterM110Xs,
                M11N11ThreadClusterN110Xs,
                // ... >{};

        // Extents of the C slice one thread accumulates, in (M10, M11, N10, N11) order.
        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
 
        // LDS allocation sizes, aligned to K1 vectors.
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
 
        FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
        FloatAB* p_b_block_double =
            static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;

        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
 
        // Zero the accumulator.
        c_thread_buf.Clear();

        // After each slice, slide both source windows forward by K0PerBlock along K0.
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
 
        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
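        // Resulting LDS layout (offsets in FloatAB elements):
        //   [ A even | A odd | B even | B odd ]
        //     0        a_sz    2*a_sz   2*a_sz + b_sz
        // with a_sz = a_block_aligned_space_size and b_sz = b_block_aligned_space_size,
        // matching the 2 * (a_sz + b_sz) * sizeof(FloatAB) bytes reported by
        // GetSharedMemoryNumberOfByte().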
 
        // Prologue: load the first K0PerBlock slice of A and B into the even buffers.
        {
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
        }
 
        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);

            index_t k_block_data_begin = 0;

            // Ping-pong main loop: each iteration consumes two K0PerBlock slices,
            // computing on one LDS buffer while the blockwise copies refill the other.
            do
            {
                // Even phase: prefetch the next slice, GEMM on the even buffers,
                // then stage the prefetched slice into the odd buffers.
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                block_sync_lds();

                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                   a_block_even_buf,
                                   b_block_even_buf,
                                   c_thread_buf);

                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                // Odd phase: identical schedule with the buffer roles swapped.
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                block_sync_lds();

                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                blockwise_gemm.Run(
                    c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }
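        // Schedule within each phase: RunRead issues the global loads for slice i+1 into
        // registers while blockwise_gemm consumes slice i from LDS; RunWrite then spills
        // the staged registers into the other LDS buffer. The block_sync_lds() at the top
        // of each phase guarantees the previous phase's RunWrite is visible before the
        // buffer it filled is read.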
 
        // Tail: either two slices (even then odd buffer) or one (even buffer) remain.
        if constexpr(HasDoubleTailKBlockLoop) // 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

            block_sync_lds();

            // Load the last slice while computing on the second-to-last.
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // GEMM on the final slice.
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // 1 iteration left
        {
            block_sync_lds();

            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }
 
        // Output stage: view the per-thread accumulator as [M0=1, M10, M11, N0=1, N10, N11].
        constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                           I1,
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
 
        // Where this thread's accumulator slice sits inside the block's C tile.
        const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
            blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                get_thread_local_1d_id());

        // One global buffer per D tensor.
        const auto ds_global_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        // Per-thread staging buffers for the D values; the innermost extent is
        // c_m10_m11_n10_n11_thread_tensor_lengths[I3] (the N11 vector).
        // ...

        // Per-D threadwise copies from global memory into the staging buffers
        // (threadwise_tensor_slice_transfer.hpp).
        auto ds_threadwise_copy = generate_tie(
            [&](auto i) {
                return ThreadwiseTensorSliceTransfer_v2<
                    // ...
                    decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
                    decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                    // ...
                    CThreadTransferSrcDstAccessOrder,
                    CThreadTransferSrcDstVectorDim,
                    CThreadTransferDstScalarPerVector,
                    // ...
                    false>(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                           make_multi_index(im0,
                                            c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                            c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                            in0,
                                            c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                            c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
            },
            Number<NumDTensor>{});
 
        // Walk the thread's tile: for each step, read the D values for the current
        // sub-slice...
                        ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                                                  // ...
                                                  c_thread_desc_m0_m10_m11_n0_n10_n11,
                                                  // ...

        // ...gather per-element references to the accumulator and the D values...
                                [&](auto iSrc) -> const auto& {
                                    return ds_thread_buf[iSrc][i];
                                },
                                // ...
                                c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
                                    // ...

        // ...and apply the fused elementwise op in place: dst = cde_element_op(srcs...).
                            unpack2(cde_element_op, dst_data_refs, src_data_refs);

        // Advance the D windows through the tile; a negative step rewinds a dimension
        // once its loop completes.
                        ds_threadwise_copy(i).MoveSrcSliceWindow(
                            ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                            // ... (step along n10)

                    ds_threadwise_copy(i).MoveSrcSliceWindow(
                        ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                        make_multi_index(
                            0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));

                ds_threadwise_copy(i).MoveSrcSliceWindow(
                    ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                    make_multi_index(
                        0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
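        // Index order is (m0, m10, m11, n0, n10, n11). The step
        // (0, 0, 1, 0, -len[n10], 0) moves to the next m11 row and rewinds n10, and
        // (0, 1, -len[m11], 0, 0, 0) moves to the next m10 repeat and rewinds m11, so the
        // D windows trace the same traversal the accumulator update uses.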
 
        // Final store: copy the fused results from registers into C in global memory
        // (threadwise_tensor_slice_transfer.hpp).
        ThreadwiseTensorSliceTransfer_v1r3<
            // ...
            decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
            decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
            // ...
            Sequence<1,
                     c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                     c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                     1,
                     c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                     c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            CGlobalMemoryDataOperation,
            // ...
            true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                  make_multi_index(im0,
                                   c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                   c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                   in0,
                                   c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                   c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                  // ...
            .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                 make_tuple(I0, I0, I0, I0, I0, I0),
                 c_thread_buf,
                 c_grid_desc_m0_m10_m11_n0_n10_n11,
                 c_grid_buf);
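        // CGlobalMemoryDataOperation is an InMemoryDataOperationEnum: Set writes C
        // directly, while AtomicAdd lets several workgroups accumulate into the same
        // tile (e.g. for a split-K dispatch).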
 