Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// ... (#include directives elided in this listing)

namespace ck {

template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch,
          typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
                                const InputDataType* __restrict__ p_in_global,
                                const OutputGridDesc out_grid_desc,
                                OutputDataType* __restrict__ p_out_global,
                                const index_t batch_count,
                                const Block2ETileMap block_2_tile_map,
                                const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
// Dispatch to the gridwise implementation only for supported GPU targets (or during
// the host pass); otherwise consume the arguments to avoid unused-parameter warnings.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
    defined(__gfx12__))
    GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                       p_in_global,
                                       out_grid_desc,
                                       p_out_global,
                                       batch_count,
                                       block_2_tile_map,
                                       compute_ptr_offset_of_batch);
#else
    ignore = in_grid_desc;
    ignore = p_in_global;
    ignore = out_grid_desc;
    ignore = p_out_global;
    ignore = batch_count;
    ignore = block_2_tile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}

template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          index_t BlockSize,
          index_t MPerBlock,
          index_t KPerBlock,
          typename ThreadClusterLengths,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum DstInMemOp,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __device__ static void Run(const InputGridDesc& in_grid_desc,
                               const InputDataType* __restrict__ p_in_global,
                               const OutputGridDesc& out_grid_desc,
                               OutputDataType* __restrict__ p_out_global,
                               const index_t batch_count,
                               const Block2ETileMap& block_2_tile_map,
                               const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
    {
        // Map this workgroup's 1-D block id to the (M, K) tile it is responsible for.
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);

        // Blockwise global-to-global copy of one MPerBlock x KPerBlock tile.
        // NOTE: several template arguments are elided in this listing; they are filled in
        // here following the ThreadGroupTensorSliceTransfer_v7 parameter order
        // (element-wise op, slice lengths, cluster/access orders, reset-coordinate flags).
        auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7<
            ThisThreadBlock,
            Tuple<InputDataType>,
            Tuple<OutputDataType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(DstInMemOp)>,
            Sequence<MPerBlock, KPerBlock>, // slice lengths
            ThreadClusterLengths,
            Sequence<0, 1>, // thread cluster arrange order
            Sequence<0, 1>, // dim access order
            I1,             // vector dim
            ScalarPerVector,
            Sequence<true>,  // src reset-coordinate flag
            Sequence<true>>{ // dst reset-coordinate flag
            in_grid_desc,
            make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
            out_grid_desc,
            make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
            tensor_operation::element_wise::PassThrough{}};

        // Work out which batch this block belongs to and the corresponding
        // pointer offsets into the input and output buffers.
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

        // Global Memory
        const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
        const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());

        copy_global_to_global.Run(
            tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
    }

    __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
                                                 const OutputGridDesc& out_grid_desc)
    {
        // Both descriptors must tile evenly into MPerBlock x KPerBlock blocks.
        if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           in_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;
        if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           out_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;
        return true;
    }
};

} // namespace ck
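
The header above only defines the per-block copy and the thin kernel wrapper; a device-level operation is expected to build the two grid descriptors, the block-to-tile map, and the batch pointer-offset functor, validate them with CheckValidity, and then launch kernel_tensor_rearrange. The fragment below is a minimal usage sketch under stated assumptions: InDesc, OutDesc, TileMap, BatchOffsets, grid_size, and the chosen tile/block sizes are placeholders chosen for illustration, not definitions from this file.

// Usage sketch (assumptions): InDesc/OutDesc describe 2-D (M, K) tensors, TileMap maps a
// 1-D block id to an (M, K) tile, and BatchOffsets provides GetAPtrOffset/GetCPtrOffset.
using GridwiseKernel = ck::GridwiseTensorRearrange<InDesc,
                                                   InputDataType,
                                                   OutDesc,
                                                   OutputDataType,
                                                   256,                 // BlockSize
                                                   32,                  // MPerBlock
                                                   64,                  // KPerBlock
                                                   ck::Sequence<32, 8>, // ThreadClusterLengths
                                                   8,                   // ScalarPerVector
                                                   ck::InMemoryDataOperationEnum::Set,
                                                   TileMap,
                                                   BatchOffsets>;

if(!GridwiseKernel::CheckValidity(in_desc, out_desc))
    return; // M and K must be divisible by MPerBlock and KPerBlock

const auto kernel = ck::kernel_tensor_rearrange<InDesc,
                                                InputDataType,
                                                OutDesc,
                                                OutputDataType,
                                                TileMap,
                                                BatchOffsets,
                                                GridwiseKernel>;

// grid_size = (tiles per batch) * batch_count; stream is a hipStream_t.
kernel<<<dim3(grid_size), dim3(256 /*BlockSize*/), 0, stream>>>(
    in_desc, p_in, out_desc, p_out, batch_count, tile_map, batch_offsets);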