/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp Source File
thread_group_tensor_slice_transfer_v6r2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
// Blockwise (thread-group) two-source elementwise tensor slice transfer:
// each thread of the group copies its own sub-slice of the block-level
// slicing window, applying ElementwiseOperation(src0, src1) -> dst.
//
// This version does the following things to avoid the scratch-memory issue:
// 1. Uses StaticallyIndexedArray instead of a C array for the thread buffer
// 2. Does not keep a reference to the tensor descriptor
// 3. Run() does not construct a new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r2
{
    // number of dimensions of the slicing window
    // NOTE(review): reconstructed from the documentation index — confirm against upstream
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    // per-thread portion of the block-level slicing window
    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    // Constructs the per-thread transfer object and positions each
    // participating thread's slice origins inside the block-level window.
    // Threads beyond the cluster size (if any) are left at the zero origin
    // and never take part in Run()/Move*() (guarded there as well).
    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
                                                             const Index& src0_block_slice_origin,
                                                             const Src1Desc& src1_desc,
                                                             const Index& src1_block_slice_origin,
                                                             const DstDesc& dst_desc,
                                                             const Index& dst_block_slice_origin,
                                                             const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == SliceLengths::Size() && nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        // only threads that map into the cluster participate in the transfer
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            // map the flat thread id onto the (arranged) thread cluster
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            // each thread owns a thread_slice_lengths-sized tile at this offset
            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(
                src0_desc, src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(
                src1_desc, src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    // Performs this thread's part of the block transfer:
    // dst = element_op(src0, src1) over the thread's sub-slice.
    template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
        }
    }

    // Shifts this thread's src0 slicing window by "step" (no-op for
    // threads outside the cluster).
    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    // Shifts this thread's src1 slicing window by "step" (no-op for
    // threads outside the cluster).
    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    // Shifts this thread's dst slicing window by "step" (no-op for
    // threads outside the cluster).
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    // compile-time descriptor mapping flat thread ids to cluster coordinates
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
                                           Src1Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};
157 
158 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
InMemoryDataOperationEnum
Definition: ck.hpp:267
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto make_zero_multi_index()
Definition: array_multi_index.hpp:21
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: array.hpp:14
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:37
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:125
static constexpr index_t nDim
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:38
constexpr __device__ ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc &src0_desc, const Index &src0_block_slice_origin, const Src1Desc &src1_desc, const Index &src1_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:44
static constexpr auto thread_slice_lengths
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:40
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &step)
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:116
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &step)
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:107
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition: thread_group_tensor_slice_transfer_v6r2.hpp:93
__device__ void SetSrc0SliceOrigin(const Src0Desc &src0_desc, const Index &src0_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:68
__device__ void SetSrc1SliceOrigin(const Src1Desc &src1_desc, const Index &src1_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:74
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:80
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &src0_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:209
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:86
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &src1_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:224
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer_v6r2.hpp:239
Definition: type.hpp:177