18 template <
typename GridwisePermute,
 
   23           typename ElementwiseOperation,
 
   24           typename Block2TileMap>
 
   26                                   const OutGridDesc out_grid_desc,
 
   27                                   const InDataType* p_in_global,
 
   28                                   OutDataType* p_out_global,
 
   29                                   const ElementwiseOperation elementwise_op,
 
   30                                   const Block2TileMap block_2_tile_map)
 
   43 template <
typename InGridDesc,
 
   47           typename ElementwiseOperation,
 
   53           typename InBlockTransferThreadClusterLengths,
 
   54           typename InBlockTransferThreadClusterArrangeOrder,
 
   61     static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
 
   62     static_assert(3 <= InGridDesc::GetNumOfDimension());
 
   63     static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
 
   64                   SrcVectorDim < InGridDesc::GetNumOfDimension());
 
   65     static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
 
   66                   DstVectorDim < OutGridDesc::GetNumOfDimension());
 
   67     static_assert(SrcVectorDim != DstVectorDim);
 
   78         static_assert(3 <= 
NumDim);
 
  102             const index_t grid_size = N0 * H0 * W0;
 
  107         template <
typename TopIdx>
 
  110             static_assert(TopIdx::Size() == 1);
 
  112             auto block_1d_id = idx_top[
I0];
 
  121             block_1d_id = block_1d_id % (N0 * H0 * W0);
 
  123             index_t idx_N0 = block_1d_id / (H0 * W0);
 
  124             index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
 
  125             index_t idx_W0 = block_1d_id % W0;
 
  131         const InGridDesc desc_;
 
  149     template <
typename Gr
idDesc>
 
  150     __host__ __device__ 
static constexpr 
auto GetMergedDesc(
const GridDesc& desc)
 
  152         constexpr 
index_t NumDim = GridDesc::GetNumOfDimension();
 
  153         static_assert(3 <= NumDim);
 
  158                            [&](
auto I) { 
return desc.GetLength(I); }, 
Number<NumDim - 2>{})),
 
  170         constexpr 
auto in_block_desc_nperblock_hperblock_wperblock =
 
  173         return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
 
  182     __host__ __device__ 
static constexpr 
bool CheckValidity(
const InGridDesc& in_grid_desc,
 
  183                                                             const OutGridDesc& out_grid_desc)
 
  185         constexpr 
index_t NumDim = InGridDesc::GetNumOfDimension();
 
  190             if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
 
  200                 out_grid_desc.GetLength(Number<NumDim - 1>{}));
 
  203     template <
typename Block2TileMap>
 
  204     __device__ 
static void Run(
const InGridDesc in_grid_desc,
 
  205                                const OutGridDesc out_grid_desc,
 
  206                                const InDataType* p_in_global,
 
  207                                OutDataType* p_out_global,
 
  208                                void* __restrict__ p_shared,
 
  209                                const ElementwiseOperation elementwise_op,
 
  212         auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  213             p_in_global, in_grid_desc.GetElementSpaceSize());
 
  215         auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  216             p_out_global, out_grid_desc.GetElementSpaceSize());
 
  219         const auto block_work_idx =
 
  222         const index_t n_block_data_idx_on_grid =
 
  223             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
 
  225         const index_t h_block_data_idx_on_grid =
 
  226             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);
 
  228         const index_t w_block_data_idx_on_grid =
 
  229             __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);
 
  232         constexpr 
auto in_block_desc_nperblock_hperblock_wperblock =
 
  233             GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
 
  235         auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  236             static_cast<InDataType*
>(p_shared),
 
  237             in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());
 
  242         constexpr 
index_t SrcVectorDimAfterMerge =
 
  243             SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
 
  244         constexpr 
index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;
 
  250         const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
 
  255             ElementwiseOperation,
 
  257             InMemoryDataOperationEnum::Set,
 
  259             InBlockTransferThreadClusterLengths,
 
  260             InBlockTransferThreadClusterArrangeOrder,
 
  263             decltype(in_grid_desc_n_h_w),
 
  264             decltype(in_block_desc_nperblock_hperblock_wperblock),
 
  265             InBlockTransferAccessOrder,
 
  266             InBlockTransferAccessOrder,
 
  267             SrcVectorDimAfterMerge,
 
  274             true>(in_grid_desc_n_h_w,
 
  276                       n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
 
  278                   in_block_desc_nperblock_hperblock_wperblock,
 
  284         const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
 
  298             ElementwiseOperation,
 
  300             InMemoryDataOperationEnum::Set,
 
  302             InBlockTransferThreadClusterLengths,
 
  303             InBlockTransferThreadClusterArrangeOrder,
 
  306             decltype(in_block_desc_nperblock_hperblock_wperblock),
 
  307             decltype(out_grid_desc_n_h_w),
 
  308             InBlockTransferAccessOrder,
 
  309             InBlockTransferAccessOrder,
 
  311             DstVectorDimAfterMerge,
 
  317             true>(in_block_desc_nperblock_hperblock_wperblock,
 
  322                       n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
 
  325         in_global_load.
Run(in_grid_desc_n_h_w,
 
  327                            in_block_desc_nperblock_hperblock_wperblock,
 
  331         out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc, const InDataType *p_in_global, OutDataType *p_out_global, const ElementwiseOperation elementwise_op, const Block2TileMap block_2_tile_map)
Definition: gridwise_permute.hpp:25
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:297
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
Definition: gridwise_permute.hpp:76
 
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: gridwise_permute.hpp:108
 
Block2TileMap(const InGridDesc &desc)
Definition: gridwise_permute.hpp:91
 
constexpr __host__ index_t CalculateGridSize(const InGridDesc &desc) const
Definition: gridwise_permute.hpp:93
 
static constexpr index_t NumDim
Definition: gridwise_permute.hpp:77
 
Block2TileMap & operator=(const Block2TileMap &)=delete
 
static constexpr auto I0
Definition: gridwise_permute.hpp:80
 
Block2TileMap(Block2TileMap &&)=delete
 
Block2TileMap & operator=(Block2TileMap &&)=delete
 
Block2TileMap(const Block2TileMap &)=default
 
Definition: gridwise_permute.hpp:60
 
__host__ static constexpr __device__ auto MakeDefaultBlock2TileMap(const InGridDesc &desc)
Definition: gridwise_permute.hpp:177
 
static constexpr auto I2
Definition: gridwise_permute.hpp:71
 
static constexpr auto I0
Definition: gridwise_permute.hpp:69
 
__host__ static constexpr __device__ auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
Definition: gridwise_permute.hpp:137
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_permute.hpp:73
 
__host__ static constexpr __device__ auto GetMergedDesc(const GridDesc &desc)
Definition: gridwise_permute.hpp:150
 
static constexpr auto I1
Definition: gridwise_permute.hpp:70
 
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_permute.hpp:168
 
__host__ static constexpr __device__ bool CheckValidity(const InGridDesc &in_grid_desc, const OutGridDesc &out_grid_desc)
Definition: gridwise_permute.hpp:182
 
static __device__ void Run(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc, const InDataType *p_in_global, OutDataType *p_out_global, void *__restrict__ p_shared, const ElementwiseOperation elementwise_op, const Block2TileMap &block_2_tile_map)
Definition: gridwise_permute.hpp:204
 
Definition: multi_index_transform.hpp:13
 
Definition: sequence.hpp:43
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id)
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:143
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: unary_element_wise_operation.hpp:308