template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch,
          typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
                            const InputDataType* __restrict__ p_in_global,
                            const OutputGridDesc out_grid_desc,
                            OutputDataType* __restrict__ p_out_global,
                            const index_t batch_count,
                            const Block2ETileMap block_2_tile_map,
                            const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
// Compile the body only for the host pass or for supported device targets;
// otherwise discard the arguments.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
    defined(__gfx12__)) // assumed final macro; the original continuation line was lost in extraction
    GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                       p_in_global,
                                       out_grid_desc,
                                       p_out_global,
                                       batch_count,
                                       block_2_tile_map,
                                       compute_ptr_offset_of_batch);
#else
    ignore = in_grid_desc;
    ignore = p_in_global;
    ignore = out_grid_desc;
    ignore = p_out_global;
    ignore = batch_count;
    ignore = block_2_tile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}
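// --------------------------------------------------------------------------------
// Illustrative sketch, not part of the original header: kernel_tensor_rearrange is a
// thin "trampoline" -- a __global__ function that only carries the launch bounds and
// the architecture guard, then forwards every argument to the static Run() of the
// gridwise struct supplied as a template parameter. The toy ScaleKernel, the problem
// size, and the launch configuration below are assumptions chosen for illustration;
// they do not come from Composable Kernel. Compile with hipcc.
// --------------------------------------------------------------------------------
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

struct ScaleKernel // stand-in for a gridwise kernel policy struct
{
    __device__ static void Run(const float* p_in, float* p_out, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            p_out[i] = 2.0f * p_in[i];
    }
};

template <typename GridwiseKernel>
__global__ void __launch_bounds__(256) kernel_trampoline(const float* p_in, float* p_out, int n)
{
    GridwiseKernel::Run(p_in, p_out, n); // same dispatch shape as kernel_tensor_rearrange
}

int main()
{
    constexpr int n = 1024;
    std::vector<float> h_in(n, 1.0f), h_out(n, 0.0f);

    float *d_in = nullptr, *d_out = nullptr;
    hipMalloc(reinterpret_cast<void**>(&d_in), n * sizeof(float));
    hipMalloc(reinterpret_cast<void**>(&d_out), n * sizeof(float));
    hipMemcpy(d_in, h_in.data(), n * sizeof(float), hipMemcpyHostToDevice);

    kernel_trampoline<ScaleKernel><<<dim3(n / 256), dim3(256), 0, nullptr>>>(d_in, d_out, n);

    hipMemcpy(h_out.data(), d_out, n * sizeof(float), hipMemcpyDeviceToHost);
    std::printf("out[0] = %f\n", h_out[0]); // expected: 2.0

    hipFree(d_in);
    hipFree(d_out);
    return 0;
}
// ------------------------------- end of sketch ----------------------------------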
template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          index_t BlockSize, // inferred from ThisThreadBlock<BlockSize> below
          index_t MPerBlock, // inferred from the tile indexing below
          index_t KPerBlock, // inferred from the tile indexing below
          typename ThreadClusterLengths,
          // two further template parameters (e.g. the per-thread vector width and the
          // destination InMemoryDataOperationEnum) were lost in extraction
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __device__ static void Run(const InputGridDesc& in_grid_desc,
                               const InputDataType* __restrict__ p_in_global,
                               const OutputGridDesc& out_grid_desc,
                               OutputDataType* __restrict__ p_out_global,
                               const index_t batch_count,
                               const Block2ETileMap& block_2_tile_map,
                               const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
    {
        // Map the 1D block id onto an (M, K) tile coordinate.
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);

        // Block-wide global-to-global tile copy (a ThreadGroupTensorSliceTransfer_v7).
        // Most of its template and constructor arguments (slice lengths, vector widths,
        // the element-wise pass-through op, and the tile origin built from
        // m_/k_block_data_idx_on_grid) were lost in extraction.
        auto copy_global_to_global =
            ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
                                              /* ... */
                                              decltype(tie(in_grid_desc)),
                                              decltype(tie(out_grid_desc)),
                                              /* ... */
                                              ThreadClusterLengths,
                                              /* ... */>{/* constructor arguments elided */};

        // Each batch owns an equal, contiguous share of the grid.
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

        // Per-batch pointer offsets for the strided-batched layout.
        const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
        const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());

        copy_global_to_global.Run(
            tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
    }

    __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
                                                 const OutputGridDesc& out_grid_desc)
    {
        // Both descriptors must tile evenly into MPerBlock x KPerBlock blocks.
        if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           in_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;

        if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           out_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;

        return true;
    }
};
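// --------------------------------------------------------------------------------
// Illustrative sketch, not part of the original header: the arithmetic implied by
// Run() and CheckValidity() above, modeled on the host with plain integers. The grid
// is assumed to hold batch_count equal shares of blocks; each block first recovers
// its batch index, then its (M, K) tile origin. All sizes (M, K, MPerBlock,
// KPerBlock, batch_count) are hypothetical, and the row-major tile order below is an
// assumption -- in the real kernel the order is decided by Block2ETileMap.
// --------------------------------------------------------------------------------
#include <cstdio>

int main()
{
    constexpr int M = 256, K = 128; // tensor extents (hypothetical)
    constexpr int MPerBlock = 64, KPerBlock = 32;
    constexpr int batch_count = 2;

    // Mirror of CheckValidity: every extent must tile evenly.
    const bool valid = (M % MPerBlock == 0) && (K % KPerBlock == 0);
    std::printf("valid = %d\n", valid);

    // One block per (MPerBlock x KPerBlock) tile, replicated per batch.
    const int m_blocks             = M / MPerBlock;
    const int k_blocks             = K / KPerBlock;
    const int num_blocks_per_batch = m_blocks * k_blocks;
    const int grid_size            = num_blocks_per_batch * batch_count;

    // Mirror of the index math in Run(): batch index, then tile origin within it.
    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx    = block_id / num_blocks_per_batch;   // which batch
        const int local_id = block_id % num_blocks_per_batch;   // block within the batch
        const int m_origin = (local_id / k_blocks) * MPerBlock; // block_work_idx[I0] * MPerBlock
        const int k_origin = (local_id % k_blocks) * KPerBlock; // block_work_idx[I1] * KPerBlock
        if(block_id < 4 || block_id == num_blocks_per_batch)    // print a few samples
            std::printf("block %3d -> batch %d, tile origin (%d, %d)\n",
                        block_id, g_idx, m_origin, k_origin);
    }
    return 0;
}
// ------------------------------- end of sketch ----------------------------------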