template <typename GridwisePermute,
          typename InGridDesc,
          typename OutGridDesc,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
                                  const OutGridDesc out_grid_desc,
                                  const InDataType* p_in_global,
                                  OutDataType* p_out_global,
                                  const ElementwiseOperation elementwise_op,
                                  const Block2TileMap block_2_tile_map)
{
    // allocates GetSharedMemoryNumberOfByte() bytes of LDS and forwards all
    // arguments, plus the LDS pointer, to GridwisePermute::Run
}
template <typename InGridDesc,
          typename OutGridDesc,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          /* ... BlockSize and the NPerBlock/HPerBlock/WPerBlock tile sizes ... */
          typename InBlockTransferThreadClusterLengths,
          typename InBlockTransferThreadClusterArrangeOrder,
          /* ... access order, SrcVectorDim, DstVectorDim and scalar-per-vector parameters ... */>
struct GridwisePermute
{
    static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
    static_assert(3 <= InGridDesc::GetNumOfDimension());
    static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
                  SrcVectorDim < InGridDesc::GetNumOfDimension());
    static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
                  DstVectorDim < OutGridDesc::GetNumOfDimension());
    static_assert(SrcVectorDim != DstVectorDim);
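Read concretely: for rank-4 descriptors the vectorized dimension on each side must be one of the two innermost dimensions (index 2 or 3), and the input and output must vectorize along different ones. A standalone compile-time restatement with hypothetical example values (not taken from the library):

#include <cstdint>

int main()
{
    using index_t = std::int32_t;

    constexpr index_t NumDim       = 4; // e.g. an NCHW tensor
    constexpr index_t SrcVectorDim = 3; // input is contiguous along its last dimension
    constexpr index_t DstVectorDim = 2; // output is contiguous along the swapped dimension

    static_assert(NumDim - 2 <= SrcVectorDim && SrcVectorDim < NumDim);
    static_assert(NumDim - 2 <= DstVectorDim && DstVectorDim < NumDim);
    static_assert(SrcVectorDim != DstVectorDim);
    return 0;
}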
    struct Block2TileMap
    {
        static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        Block2TileMap(const InGridDesc& desc) : desc_(desc) {}

        __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
        {
            // N0, H0, W0: number of NPerBlock x HPerBlock x WPerBlock tiles along the
            // merged N, H and W axes (ceil-divided lengths)
            const index_t grid_size = N0 * H0 * W0;

            return grid_size;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            static_assert(TopIdx::Size() == 1);

            auto block_1d_id = idx_top[I0];
            block_1d_id      = block_1d_id % (N0 * H0 * W0);

            index_t idx_N0 = block_1d_id / (H0 * W0);
            index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
            index_t idx_W0 = block_1d_id % W0;

            return make_tuple(idx_N0, idx_H0, idx_W0);
        }

        private:
        const InGridDesc desc_;
    };
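The index arithmetic in CalculateBottomIndex is a plain row-major decomposition of the flat workgroup id into a (N, H, W) tile coordinate. A standalone sketch of the same computation, with hypothetical tile counts in place of the N0/H0/W0 derived from the descriptor:

#include <cassert>
#include <cstdint>

using index_t = std::int32_t;

// Standalone illustration (not part of the header) of the arithmetic in
// Block2TileMap::CalculateBottomIndex: decompose a flat block id into a
// (N, H, W) tile coordinate, where N0/H0/W0 are the per-axis tile counts.
struct TileIdx { index_t n0, h0, w0; };

TileIdx decompose_block_id(index_t block_1d_id, index_t N0, index_t H0, index_t W0)
{
    block_1d_id = block_1d_id % (N0 * H0 * W0);
    return {block_1d_id / (H0 * W0),        // idx_N0
            (block_1d_id % (H0 * W0)) / W0, // idx_H0
            block_1d_id % W0};              // idx_W0
}

int main()
{
    // e.g. a 4 x 3 x 2 tile grid: block 17 lands on tile (2, 2, 1)
    const TileIdx t = decompose_block_id(17, 4, 3, 2);
    assert(t.n0 == 2 && t.h0 == 2 && t.w0 == 1);
    return 0;
}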
    template <typename GridDesc>
    __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
    {
        constexpr index_t NumDim = GridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        // collapse the leading (NumDim - 2) lengths into a single dimension,
        // leaving a 3-D (merged-N, H, W) view of the tensor
        const auto merged_desc = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(generate_tuple(
                           [&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
                       /* ... pass-through transforms for the last two dimensions ... */),
            /* ... lower/upper dimension id mappings ... */);

        return merged_desc;
    }
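GetMergedDesc is what lets an arbitrary-rank permutation run as a 3-D problem: all but the last two lengths are folded into one. A shape-level illustration of that fold (plain C++, not the CK descriptor machinery): a rank-5 tensor of lengths {2, 3, 4, 8, 16} is viewed as {2*3*4, 8, 16} = {24, 8, 16}.

#include <array>
#include <cassert>
#include <cstddef>

// Shape-level illustration of GetMergedDesc: merge all but the last two lengths.
template <std::size_t NumDim>
std::array<long, 3> merge_leading_dims(const std::array<long, NumDim>& lengths)
{
    static_assert(3 <= NumDim);
    long merged = 1;
    for(std::size_t i = 0; i + 2 < NumDim; ++i)
        merged *= lengths[i];
    return {merged, lengths[NumDim - 2], lengths[NumDim - 1]};
}

int main()
{
    const auto merged = merge_leading_dims<5>({2, 3, 4, 8, 16});
    assert(merged[0] == 24 && merged[1] == 8 && merged[2] == 16);
    return 0;
}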
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        // one (NPerBlock, HPerBlock, WPerBlock) tile of InDataType elements in LDS
        return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
               sizeof(InDataType);
    }
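The LDS requirement is simply the element count of that block descriptor times the element size. With hypothetical tile sizes NPerBlock = 16, HPerBlock = 32, WPerBlock = 32 and 2-byte elements (and ignoring any padding the real block descriptor may add to the W stride), that is 32 KiB per workgroup:

#include <cstdio>

int main()
{
    // Hypothetical tile configuration; the real block descriptor may pad the W stride.
    constexpr int NPerBlock = 16, HPerBlock = 32, WPerBlock = 32;
    constexpr int element_size = 2; // e.g. fp16
    constexpr int lds_bytes = NPerBlock * HPerBlock * WPerBlock * element_size;
    std::printf("LDS per workgroup: %d bytes\n", lds_bytes); // 32768
    return 0;
}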
    __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
                                                            const OutGridDesc& out_grid_desc)
    {
        constexpr index_t NumDim = InGridDesc::GetNumOfDimension();

        // every leading length must match; only the last two dimensions are permuted
        bool valid = true;
        static_for<0, NumDim - 2, 1>{}([&](auto I) {
            if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
            {
                valid = false;
            }
        });

        return valid && /* ... */ (in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
                                   out_grid_desc.GetLength(Number<NumDim - 1>{}));
    }
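The surviving lines suggest the validity rule is: all leading lengths equal, and the last two lengths exchanged between input and output, i.e. this gridwise kernel handles a transpose of the two innermost dimensions. A shape-level restatement of that reading (my interpretation of the elided code, not the header itself):

#include <cassert>
#include <cstddef>
#include <vector>

// Shape-level restatement of CheckValidity: leading lengths equal,
// last two lengths swapped between input and output.
bool shapes_are_valid(const std::vector<long>& in, const std::vector<long>& out)
{
    const std::size_t n = in.size();
    if(n != out.size() || n < 3) return false;
    for(std::size_t i = 0; i + 2 < n; ++i)
        if(in[i] != out[i]) return false;
    return in[n - 2] == out[n - 1] && in[n - 1] == out[n - 2];
}

int main()
{
    assert(shapes_are_valid({8, 16, 32, 64}, {8, 16, 64, 32}));  // H and W swapped
    assert(!shapes_are_valid({8, 16, 32, 64}, {8, 16, 32, 64})); // nothing swapped
    return 0;
}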
    template <typename Block2TileMap>
    __device__ static void Run(const InGridDesc in_grid_desc,
                               const OutGridDesc out_grid_desc,
                               const InDataType* p_in_global,
                               OutDataType* p_out_global,
                               void* __restrict__ p_shared,
                               const ElementwiseOperation elementwise_op,
                               const Block2TileMap& block_2_tile_map)
    {
        auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc.GetElementSpaceSize());

        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc.GetElementSpaceSize());

        // map the flat workgroup id to a (N, H, W) tile coordinate
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);

        const index_t h_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);

        const index_t w_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);

        // LDS tile descriptor and buffer
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<InDataType*>(p_shared),
            in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());

        // vector dimensions expressed in the merged 3-D (N, H, W) view
        constexpr index_t SrcVectorDimAfterMerge =
            SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
        constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;

        // 3-D view of the input tensor
        const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);

        // blockwise copy: input tile from global memory into LDS
        auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            ElementwiseOperation,
            /* ... */
            InMemoryDataOperationEnum::Set,
            /* ... */
            InBlockTransferThreadClusterLengths,
            InBlockTransferThreadClusterArrangeOrder,
            /* ... */
            decltype(in_grid_desc_n_h_w),
            decltype(in_block_desc_nperblock_hperblock_wperblock),
            InBlockTransferAccessOrder,
            InBlockTransferAccessOrder,
            SrcVectorDimAfterMerge,
            /* ... */
            true>(in_grid_desc_n_h_w,
                  make_multi_index(
                      n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
                  /* ... */
                  in_block_desc_nperblock_hperblock_wperblock,
                  /* ... */);

        // 3-D view of the output tensor; the merged output descriptor is re-indexed as
        // (N, H, W) so that it lines up with the LDS tile
        const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
        /* ... out_grid_desc_n_h_w built from out_grid_desc_n_w_h ... */

        // blockwise copy: LDS tile back out to global memory in the permuted layout
        auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            ElementwiseOperation,
            /* ... */
            InMemoryDataOperationEnum::Set,
            /* ... */
            InBlockTransferThreadClusterLengths,
            InBlockTransferThreadClusterArrangeOrder,
            /* ... */
            decltype(in_block_desc_nperblock_hperblock_wperblock),
            decltype(out_grid_desc_n_h_w),
            InBlockTransferAccessOrder,
            InBlockTransferAccessOrder,
            /* ... */
            DstVectorDimAfterMerge,
            /* ... */
            true>(in_block_desc_nperblock_hperblock_wperblock,
                  /* ... */
                  make_multi_index(
                      n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
                  /* ... */);

        in_global_load.Run(in_grid_desc_n_h_w,
                           in_global_buf,
                           in_block_desc_nperblock_hperblock_wperblock,
                           in_block_buf,
                           I0);

        out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
                             in_block_buf,
                             out_grid_desc_n_h_w,
                             out_global_buf,
                             I0);
    }
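The SrcVectorDim remapping used above follows from the merge: folding the leading dimensions into one removes NumDim - 3 axes, so a vector dimension that indexed one of the last two axes shifts down by exactly that amount. For instance, with a hypothetical rank-5 input vectorized along its last axis (SrcVectorDim = 4), the merged 3-D view is vectorized along dimension 4 - (5 - 3) = 2:

#include <cstdint>

int main()
{
    using index_t = std::int32_t;

    constexpr index_t NumDim       = 5; // hypothetical rank-5 problem
    constexpr index_t SrcVectorDim = 4; // innermost dimension of the input

    constexpr index_t SrcVectorDimAfterMerge = SrcVectorDim - (NumDim - 3);
    static_assert(SrcVectorDimAfterMerge == 2); // last axis of the merged (N, H, W) view
    return 0;
}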