Composable Kernel (ROCm, docs-6.4.3) — source listing of
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
(extracted from the Read the Docs build at
/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/).
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
14 // this version does following things to avoid scratch memory issue
15 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
16 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
17 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
18 template <typename ThreadGroup,
19  typename ElementwiseOperation,
20  InMemoryDataOperationEnum DstInMemOp,
21  typename SliceLengths,
22  typename ThreadClusterLengths,
23  typename ThreadClusterArrangeOrder,
24  typename SrcData,
25  typename DstData,
26  typename SrcDesc,
27  typename DstDesc,
28  typename DimAccessOrder,
29  index_t VectorDim,
30  index_t ScalarPerVector,
31  bool ThreadTransferSrcResetCoordinateAfterRun,
32  bool ThreadTransferDstResetCoordinateAfterRun>
34 {
36 
37  static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
38 
40 
41  __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
42  const Index& src_block_slice_origin,
43  const DstDesc& dst_desc,
44  const Index& dst_block_slice_origin,
45  const ElementwiseOperation& element_op)
46  : threadwise_transfer_(src_desc,
48  dst_desc,
50  element_op)
51 
52  {
55  nDim == ThreadClusterLengths::Size() &&
56  nDim == ThreadClusterArrangeOrder::Size() &&
57  nDim == DimAccessOrder::Size(),
58  "wrong! nDim not consistent");
59 
60  static_assert(
61  is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
62  "wrong! threads should be mapped to cover entire slicing window");
63 
64  static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
65  "wrong! ThreadGroup::GetNumOfThread() too small");
66 
67  if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
68  ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
69  {
70  const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
71  make_multi_index(ThreadGroup::GetThreadId()));
72 
73  const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
74 
75  threadwise_transfer_.SetSrcSliceOrigin(src_desc,
76  src_block_slice_origin + thread_data_idx_begin);
77  threadwise_transfer_.SetDstSliceOrigin(dst_desc,
78  dst_block_slice_origin + thread_data_idx_begin);
79  }
80  }
81 
82  template <typename SrcBuffer, typename DstBuffer>
83  __device__ void Run(const SrcDesc& src_desc,
84  const SrcBuffer& src_buf,
85  const DstDesc& dst_desc,
86  DstBuffer& dst_buf)
87  {
88  if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
89  ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
90  {
91  threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
92  }
93  }
94 
95  __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
96  {
97  if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
98  ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
99  {
100  threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
101  }
102  }
103 
104  __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
105  {
106  if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
107  ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
108  {
109  threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
110  }
111  }
112 
113  private:
114  static constexpr auto thread_cluster_desc_ =
115  make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
116 
117  using ThreadwiseTransfer =
118  ThreadwiseTensorSliceTransfer_v6r1<SrcData,
119  DstData,
120  SrcDesc,
121  DstDesc,
122  ElementwiseOperation,
123  decltype(thread_slice_lengths),
124  DimAccessOrder,
125  VectorDim,
126  ScalarPerVector,
127  DstInMemOp,
128  ThreadTransferSrcResetCoordinateAfterRun,
129  ThreadTransferDstResetCoordinateAfterRun>;
130 
131  ThreadwiseTransfer threadwise_transfer_;
132 };
133 
134 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
InMemoryDataOperationEnum
Definition: ck.hpp:267
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto make_zero_multi_index()
Definition: array_multi_index.hpp:21
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: array.hpp:14
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
static constexpr auto thread_slice_lengths
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:37
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:95
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:104
static constexpr index_t nDim
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:35
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:83
constexpr __device__ ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:41
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer_v6r1.hpp:61
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer_v6r1.hpp:66
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer_v6r1.hpp:72
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer_v6r1.hpp:193
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer_v6r1.hpp:178
Definition: type.hpp:177