/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
Go to the documentation of this file.
Line 16: #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition: masking_specialization.hpp:16
MaskingSpecialization
Definition: masking_specialization.hpp:11
@ MaskOutUpperTriangle
TensorSpecialization
Definition: tensor_specialization.hpp:11
GemmSpecialization
Definition: gemm_specialization.hpp:11
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition: tensor_specialization.hpp:16
__global__ void kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, D0sPointer p_d0s_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const C0DEElementwiseOperation c0de_element_op, const B1ElementwiseOperation b1_element_op, const C1DEElementwiseOperation c1de_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:47
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
@ Default
Definition: stream_config.hpp:10
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:86
remove_cvref_t< decltype(MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))> C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:378
decltype(MakeD0sGridPointer()) D0sGridPointer
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:371
__host__ static constexpr __device__ auto MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:362
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const C1GridDesc_M_N &c1_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:216
__host__ static constexpr __device__ auto MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C1GridDesc_M_N &c1_grid_desc_m_n)
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:278
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))> DefaultBlock2CTileMap
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:381
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:270
remove_cvref_t< decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))> D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
Definition: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:374
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: transform_contraction_to_gemm.hpp:121
__host__ static constexpr __device__ auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition: transform_contraction_to_gemm.hpp:208
__host__ static constexpr __device__ auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition: transform_contraction_to_gemm.hpp:168
static auto MakeB0GridDescriptor_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:198
static auto MakeAGridDescriptor_G_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:154
__host__ static constexpr __device__ auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition: transform_contraction_to_gemm.hpp:248
static auto MakeB0GridDescriptor_G_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:193
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition: transform_contraction_to_gemm.hpp:159
static constexpr auto matrix_padder
Definition: transform_contraction_to_gemm.hpp:139
static auto MakeCGridDescriptor_G_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition: transform_contraction_to_gemm.hpp:274
static auto MakeB1GridDescriptor_G_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition: transform_contraction_to_gemm.hpp:233
static auto MakeB1GridDescriptor_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition: transform_contraction_to_gemm.hpp:238
static auto MakeCGridDescriptor_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition: transform_contraction_to_gemm.hpp:279
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:329
__host__ constexpr __device__ long_index_t GetBBasePtr(index_t g_idx) const
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:348
__host__ constexpr __device__ long_index_t GetD0BasePtr(index_t g_idx, Number< I > d0_idx) const
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:364
__host__ constexpr __device__ long_index_t GetB1BasePtr(index_t g_idx) const
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:353
__host__ constexpr __device__ long_index_t GetABasePtr(index_t g_idx) const
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:343
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const BGridDesc_G_N_K &b_grid_desc_g_n_k, const B1GridDesc_G_N_K &b1_grid_desc_g_n_k, const C1GridDesc_G_M_N &c1_grid_desc_g_m_n, const D0sGridDesc_G_M_N &d0s_grid_desc_g_m_n)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:330
__host__ constexpr __device__ long_index_t GetCBasePtr(index_t g_idx) const
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:358
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:447
GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:594
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle::Argument::p_b1_grid_
const B1DataType * p_b1_grid_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:578
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:599
AElementwiseOperation a_element_op_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:602
C1DEElementwiseOperation c1de_element_op_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:606
std::vector< index_t > c_mz_gemm1nz_strides_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:616
std::vector< index_t > b_nz_kz_strides_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:614
BElementwiseOperation b_element_op_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:603
B1GridDesc_G_N_K b1_grid_desc_g_n_k_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:589
index_t batch_count_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:619
C0DEElementwiseOperation c0de_element_op_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:604
C1GridDesc_M_N c1_grid_desc_m_n_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:586
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:583
std::vector< index_t > raw_lengths_mz_nz_kz_gemm1nz_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:612
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:591
std::vector< index_t > b1_nz_kz_strides_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:615
GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:596
B1ElementwiseOperation b1_element_op_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:605
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:620
GridwiseGemm::D0sGridPointer p_d0s_grid_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:580
C1GridDesc_G_M_N c1_grid_desc_g_m_n_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:590
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:587
const BDataType * p_b_grid_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:577
std::vector< index_t > a_mz_kz_strides_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:613
BGridDesc_G_N_K b_grid_desc_g_n_k_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:588
void Print() const
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:559
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumD0Tensor > p_acc0_biases, const std::array< void *, NumD1Tensor > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumD1Tensor > &acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< ck::index_t >, NumD1Tensor > &acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:448
const ADataType * p_a_grid_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:576
C0MatrixMask c0_matrix_mask_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:609
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:584
CDataType * p_c_grid_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:579
std::array< std::vector< ck::index_t >, NumD0Tensor > d0s_nl_ns_lengths_strides_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:617
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:585
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:625
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:706
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:628
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:219
decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})) BGridDesc_BK0_N_BK1
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:305
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:803
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_K
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:310
static auto MakeD0sGridDescriptor_G_M_N(const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_strides)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:292
static constexpr bool IsValidCompilationParameter()
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:713
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_b1, void *p_c, const std::array< void *, NumD0Tensor > p_acc0_biases, const std::array< void *, NumD1Tensor > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op) override
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:864
static constexpr index_t NumD1Tensor
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:224
static constexpr index_t NumD0Tensor
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:223
static auto MakeB1GridDescriptor_BK0_N_BK1(const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths_vec, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides_vec)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:271
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:917
constexpr static auto make_MaskOutPredicate()
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:315
std::string GetTypeString() const override
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:923
static constexpr auto I1
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:242
decltype(MakeD0sGridDescriptor_M_N({}, {})) D0sGridDesc_M_N
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:312
static auto MakeInvoker()
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:860
static constexpr auto I2
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:243
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const B1DataType *p_b1, CDataType *p_c, const std::array< void *, NumD0Tensor > p_acc0_biases, const std::array< void *, NumD1Tensor > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides, const std::vector< index_t > &c_gs_ms_gemm1ns_lengths, const std::vector< index_t > &c_gs_ms_gemm1ns_strides, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > acc0_biases_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_lengths, const std::array< std::vector< ck::index_t >, NumD1Tensor > acc1_biases_gs_ms_gemm1ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, C1DEElementwiseOperation c1de_element_op)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:808
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:254
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:262
static constexpr auto I0
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:241
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:308
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) BGridDesc_G_N_K
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:309
static auto MakeD0sGridDescriptor_M_N(const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumD0Tensor > &acc0_biases_gs_ms_ns_strides)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:280
static bool IsSupportedArgument(const Argument &arg)
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:719
decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})) AGridDesc_AK0_M_AK1
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:304
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle::B1GridDesc_BK0_N_BK1
decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})) B1GridDesc_BK0_N_BK1
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:306
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) C1GridDesc_G_M_N
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:311
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:326
GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector > GridwiseGemm
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:442
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) C1GridDesc_M_N
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:307
decltype(MakeD0sGridDescriptor_G_M_N({}, {})) D0sGridDesc_G_M_N
Definition: device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:313
Definition: device_batched_gemm_softmax_gemm_permute.hpp:34
Definition: masking_specialization.hpp:27
Definition: masking_specialization.hpp:41