Composable Kernel: include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Include list reconstructed from the entities this header uses
// (make_cluster_descriptor, ThreadwiseTensorSliceTransfer_v3r1).
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp"

namespace ck {

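// Blockwise data transfer.
//
// Cooperatively copies a BlockSliceLengths-sized slicing window from a source
// tensor to a destination tensor. Threads are arranged into a cluster of shape
// ThreadClusterLengths (ordered by ThreadClusterArrangeOrder), and each thread
// copies its own BlockSliceLengths / ThreadClusterLengths sub-slice by
// delegating to a ThreadwiseTensorSliceTransfer_v3r1.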
template <typename ThreadGroup,
          typename SrcElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    // Each thread owns an equal tile of the block slicing window.
    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

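    // The constructor computes each participating thread's coordinate in the
    // cluster and offsets the block-level source/destination slice origins by
    // that thread's starting index.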
    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               src_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               dst_element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() &&
                          nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
        }
    }

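    // Reads stage data into one of NumThreadScratch per-thread scratch buffers;
    // thread_scratch_id selects the buffer, e.g. so a pipelined caller can read
    // iteration i+1 while still holding the data written for iteration i.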
    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

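    // Run is a convenience wrapper: a read immediately followed by the matching
    // write through the same scratch buffer.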
    template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        Number<ThreadScratchId> thread_scratch_id)
    {
        RunRead(src_desc, src_buf, thread_scratch_id);
        RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }

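    // The slice windows can be advanced between iterations by a block-level
    // step index, e.g. along the reduction dimension of a tiled loop.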
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    // Compile-time map from a flat thread id to a coordinate in the thread cluster.
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
                                           SrcElementwiseOperation,
                                           DstElementwiseOperation,
                                           DstInMemOp,
                                           SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           SrcDimAccessOrder,
                                           DstDimAccessOrder,
                                           SrcVectorDim,
                                           DstVectorDim,
                                           SrcScalarPerVector,
                                           DstScalarPerVector,
                                           SrcScalarStrideInVector,
                                           DstScalarStrideInVector,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun,
                                           NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
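
// Illustrative usage sketch (not part of the original header): a block-level
// global-to-LDS copy inside a tiled loop. The descriptor, buffer, and step
// variables, the 256-thread block size, and the elided template parameters are
// assumptions for the example only.
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//
//   using Transfer = ck::ThreadGroupTensorSliceTransfer_v4r1<
//       ck::ThisThreadBlock<256>,            // ThreadGroup
//       PassThrough,                         // SrcElementwiseOperation
//       PassThrough,                         // DstElementwiseOperation
//       ck::InMemoryDataOperationEnum::Set,  // DstInMemOp
//       ...>;                                // slice/cluster/vector parameters
//
//   Transfer transfer(src_desc, src_block_origin, PassThrough{},
//                     dst_desc, dst_block_origin, PassThrough{});
//
//   for(ck::index_t i = 0; i < num_loop; ++i)
//   {
//       transfer.RunRead(src_desc, src_global_buf);  // global -> thread scratch
//       transfer.RunWrite(dst_desc, dst_lds_buf);    // thread scratch -> LDS
//       transfer.MoveSrcSliceWindow(src_desc, src_step);
//   }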