/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp File Reference
device_batched_gemm_gemm_xdl_cshuffle.hpp File Reference
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
Go to the source code of this file.
Namespaces
ck
ck::tensor_operation
ck::tensor_operation::device
Functions
template<typename GridwiseGemm, typename FloatAB, typename FloatC, typename AElementwiseOperation, typename BElementwiseOperation, typename AccElementwiseOperation, typename B1ElementwiseOperation, typename CElementwiseOperation, typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename Block2CTileMap, typename ComputeBasePtrOfStridedBatch, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)