/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp Source File
device_grouped_conv_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 namespace tensor_operation {
11 namespace device {
12 
13 // 1d
14 template <typename InLayout, typename WeiLayout, typename OutLayout>
15 constexpr bool is_NWGC_GKXC_NWGK()
16 {
17  return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
18  is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
19  is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
20 }
21 
22 template <typename InLayout, typename WeiLayout, typename OutLayout>
23 constexpr bool is_GNWC_GKXC_GNWK()
24 {
25  return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
26  is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
27  is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
28 }
29 
30 template <typename InLayout, typename WeiLayout, typename OutLayout>
31 constexpr bool is_NGCW_GKXC_NGKW()
32 {
33  return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
34  is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
35  is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
36 }
37 
38 // 2d
39 template <typename InLayout, typename WeiLayout, typename OutLayout>
40 constexpr bool is_NHWGC_GKYXC_NHWGK()
41 {
42  return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
43  is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
44  is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
45 }
46 
47 template <typename InLayout, typename WeiLayout, typename OutLayout>
48 constexpr bool is_GNHWC_GKYXC_GNHWK()
49 {
50  return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
51  is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
52  is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
53 }
54 
55 template <typename InLayout, typename WeiLayout, typename OutLayout>
56 constexpr bool is_NGCHW_GKYXC_NGKHW()
57 {
58  return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
59  is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
60  is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
61 }
62 // 3d
63 template <typename InLayout, typename WeiLayout, typename OutLayout>
64 constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
65 {
66  return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
67  is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
68  is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
69 }
70 
71 template <typename InLayout, typename WeiLayout, typename OutLayout>
72 constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
73 {
74  return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
75  is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
76  is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
77 }
78 
79 template <typename InLayout, typename WeiLayout, typename OutLayout>
80 constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
81 {
82  return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
83  is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
84  is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
85 }
86 
87 template <typename InLayout, typename WeiLayout, typename OutLayout>
89 {
90  return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
91  is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
92  is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
93 }
94 
95 template <typename InLayout, typename WeiLayout, typename OutLayout>
97 {
98  return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
99  is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
100  is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
101 }
102 
103 template <typename InLayout, typename WeiLayout, typename OutLayout>
105 {
106  return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
107  is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
108  is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
109 }
110 
// Computes per-group (batch) pointer offsets for grouped-conv kernels.
// Primary template is intentionally empty; the partial specializations below
// select an implementation based on whether A/B are tensor tuples
// (NumATensor > 1 || NumBTensor > 1) or single tensors.
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
struct ComputePtrOffsetOfStridedBatch
{
};
115 
116 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
117 struct ComputePtrOffsetOfStridedBatch<NumATensor,
118  NumBTensor,
119  NumDTensor,
120  enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
121 {
123 
125  Array<long_index_t, NumBTensor>& BatchStrideBs,
126  Array<long_index_t, NumDTensor>& BatchStrideDs,
127  long_index_t BatchStrideE)
128  : BatchStrideA_(BatchStrideAs),
129  BatchStrideB_(BatchStrideBs),
130  BatchStrideDs_(BatchStrideDs),
131  BatchStrideE_(BatchStrideE)
132  {
133  }
134 
135  __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
136  {
139  [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
140  return as_offset;
141  }
142 
143  __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
144  {
147  [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
148  return bs_offset;
149  }
150 
151  __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
152  {
155  [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
156  return ds_offset;
157  }
158 
159  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
160  {
161  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
162  }
163 
164  // alias for kernels without multiple D
165  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
166  {
167  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
168  }
169 
174  long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
175 };
176 
177 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
178 struct ComputePtrOffsetOfStridedBatch<NumATensor,
179  NumBTensor,
180  NumDTensor,
181  enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
182 {
184 
186  long_index_t BatchStrideB,
187  Array<long_index_t, NumDTensor> BatchStrideDs,
188  long_index_t BatchStrideE)
189  : BatchStrideA_(BatchStrideA),
190  BatchStrideB_(BatchStrideB),
191  BatchStrideDs_(BatchStrideDs),
192  BatchStrideE_(BatchStrideE)
193  {
194  }
195 
196  __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
197  {
198  return static_cast<long_index_t>(g_idx) * BatchStrideA_;
199  }
200 
201  __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
202  {
203  return static_cast<long_index_t>(g_idx) * BatchStrideB_;
204  }
205 
206  __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
207  {
210  [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
211  return ds_offset;
212  }
213 
214  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
215  {
216  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
217  }
218 
219  // alias for kernels without multiple D
220  [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
221  {
222  return static_cast<long_index_t>(g_idx) * BatchStrideE_;
223  }
224 
229  long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
230 };
231 
232 template <bool isTuple, typename Tensors>
233 constexpr static auto GetNumABTensors()
234 {
235  if constexpr(isTuple)
236  {
237  return Number<Tensors::Size()>{};
238  }
239  else
240  {
241  return Number<1>{};
242  }
243 }
244 
245 template <bool isTuple, typename GridwiseGemm, typename DataType>
246 constexpr static auto GetAGridPointer()
247 {
248  if constexpr(isTuple)
249  {
250  return typename GridwiseGemm::AsGridPointer{};
251  }
252  else
253  {
254  return Tuple<const DataType*>{};
255  }
256 }
257 
258 template <bool isTuple, typename GridwiseGemm, typename DataType>
259 constexpr static auto GetBGridPointer()
260 {
261  if constexpr(isTuple)
262  {
263  return typename GridwiseGemm::BsGridPointer{};
264  }
265  else
266  {
267  return Tuple<const DataType*>{};
268  }
269 }
270 
271 template <bool isTuple, typename Id, typename Type>
272 constexpr static auto UnpackDataType()
273 {
274  if constexpr(isTuple)
275  {
276  // unpack if tuple
277  return tuple_element_t<Id{}, Type>{};
278  }
279  else
280  {
281  // if no, return Type
282  return Type{};
283  }
284 }
285 
286 } // namespace device
287 } // namespace tensor_operation
288 } // namespace ck
index_t BatchStrideC_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:56
index_t BatchStrideB_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:55
index_t BatchStrideA_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:54
Array< ck::index_t, NumDTensor > BatchStrideDs_
Definition: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:84
index_t BatchStrideE_
Definition: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:85
constexpr bool is_NWGC_GKXC_NWGK()
Definition: device_grouped_conv_utils.hpp:15
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition: device_grouped_conv_utils.hpp:88
constexpr bool is_GNWC_GKXC_GNWK()
Definition: device_grouped_conv_utils.hpp:23
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition: device_grouped_conv_utils.hpp:72
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
Definition: device_grouped_conv_utils.hpp:104
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition: device_grouped_conv_utils.hpp:40
constexpr bool is_NGCHW_GKYXC_NGKHW()
Definition: device_grouped_conv_utils.hpp:56
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition: device_grouped_conv_utils.hpp:64
constexpr bool is_NGCW_GKXC_NGKW()
Definition: device_grouped_conv_utils.hpp:31
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
Definition: device_grouped_conv_utils.hpp:96
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
Definition: device_grouped_conv_utils.hpp:80
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition: device_grouped_conv_utils.hpp:48
Definition: ck.hpp:264
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
int64_t long_index_t
Definition: ck.hpp:290
int32_t index_t
Definition: ck.hpp:289
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
__host__ constexpr __device__ long_index_t GetEPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:214
__host__ constexpr __device__ long_index_t GetBPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:201
__host__ constexpr __device__ long_index_t GetAPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:196
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:220
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, long_index_t BatchStrideB, Array< long_index_t, NumDTensor > BatchStrideDs, long_index_t BatchStrideE)
Definition: device_grouped_conv_utils.hpp:185
__host__ constexpr __device__ auto GetDsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:206
ComputePtrOffsetOfStridedBatch(Array< long_index_t, NumATensor > &BatchStrideAs, Array< long_index_t, NumBTensor > &BatchStrideBs, Array< long_index_t, NumDTensor > &BatchStrideDs, long_index_t BatchStrideE)
Definition: device_grouped_conv_utils.hpp:124
__host__ constexpr __device__ long_index_t GetEPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:159
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:165
__host__ constexpr __device__ auto GetAsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:135
__host__ constexpr __device__ auto GetBsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:143
__host__ constexpr __device__ auto GetDsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:151