10 namespace tensor_operation {
14 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
17 return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
18 is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
19 is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
22 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
25 return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
26 is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
27 is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
30 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
33 return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
34 is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
35 is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
39 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
42 return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
43 is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
44 is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
47 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
50 return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
51 is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
52 is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
55 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
58 return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
59 is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
60 is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
63 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
66 return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
67 is_same_v<WeiLayout, tensor_layout::convolution::GKCYX> &&
68 is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
71 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
74 return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
75 is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
79 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
82 return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
83 is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
84 is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
87 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
90 return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
91 is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
92 is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
95 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
98 return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
99 is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
100 is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
103 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
106 return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
107 is_same_v<WeiLayout, tensor_layout::convolution::GKCZYX> &&
108 is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
111 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
114 return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
115 is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
118 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
121 return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
122 is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
123 is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
126 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
129 return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
130 is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
131 is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
134 template <
typename InLayout,
typename WeiLayout,
typename OutLayout>
137 return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
138 is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
139 is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
142 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0,
typename =
void>
143 struct ComputePtrOffsetOfStridedBatch
147 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
148 struct ComputePtrOffsetOfStridedBatch<NumATensor,
208 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
209 struct ComputePtrOffsetOfStridedBatch<NumATensor,
263 template <
bool isTuple,
typename Tensors>
264 constexpr
static auto GetNumABTensors()
266 if constexpr(isTuple)
268 return Number<Tensors::Size()>{};
276 template <
bool isTuple,
typename GridwiseGemm,
typename DataType>
277 constexpr
static auto GetAGridPointer()
279 if constexpr(isTuple)
281 return typename GridwiseGemm::AsGridPointer{};
285 return Tuple<const DataType*>{};
289 template <
bool isTuple,
typename GridwiseGemm,
typename DataType>
290 constexpr
static auto GetBGridPointer()
292 if constexpr(isTuple)
294 return typename GridwiseGemm::BsGridPointer{};
298 return Tuple<const DataType*>{};
302 template <
bool isTuple,
typename Id,
typename Type>
303 constexpr
static auto UnpackDataType()
305 if constexpr(isTuple)
index_t BatchStrideC_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:56
index_t BatchStrideB_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:55
index_t BatchStrideA_
Definition: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:54
Array< ck::index_t, NumDTensor > BatchStrideDs_
Definition: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:84
index_t BatchStrideE_
Definition: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:85
constexpr bool is_NWGC_GKXC_NWGK()
Definition: device_grouped_conv_utils.hpp:15
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition: device_grouped_conv_utils.hpp:119
constexpr bool is_GNWC_GKXC_GNWK()
Definition: device_grouped_conv_utils.hpp:23
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition: device_grouped_conv_utils.hpp:88
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
Definition: device_grouped_conv_utils.hpp:135
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition: device_grouped_conv_utils.hpp:40
constexpr bool is_NGCHW_GKYXC_NGKHW()
Definition: device_grouped_conv_utils.hpp:56
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition: device_grouped_conv_utils.hpp:80
constexpr bool is_NGCDHW_NGKDHW()
Definition: device_grouped_conv_utils.hpp:112
constexpr bool is_NGCW_GKXC_NGKW()
Definition: device_grouped_conv_utils.hpp:31
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition: device_grouped_conv_utils.hpp:64
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
Definition: device_grouped_conv_utils.hpp:127
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
Definition: device_grouped_conv_utils.hpp:96
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition: device_grouped_conv_utils.hpp:48
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition: device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition: device_grouped_conv_utils.hpp:72
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
int64_t long_index_t
Definition: ck.hpp:301
int32_t index_t
Definition: ck.hpp:300
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
__host__ constexpr __device__ long_index_t GetEPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:245
Array< long_index_t, NumDTensor > BatchStrideDs_
Definition: device_grouped_conv_utils.hpp:258
long_index_t BatchStrideA_
Definition: device_grouped_conv_utils.hpp:256
__host__ constexpr __device__ long_index_t GetBPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:232
long_index_t BatchStrideB_
Definition: device_grouped_conv_utils.hpp:257
long_index_t BatchStrideE_
Definition: device_grouped_conv_utils.hpp:259
__host__ constexpr __device__ long_index_t GetAPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:227
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:251
ComputePtrOffsetOfStridedBatch()=default
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, long_index_t BatchStrideB, Array< long_index_t, NumDTensor > BatchStrideDs, long_index_t BatchStrideE)
Definition: device_grouped_conv_utils.hpp:216
__host__ constexpr __device__ auto GetDsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:237
ComputePtrOffsetOfStridedBatch(Array< long_index_t, NumATensor > &BatchStrideAs, Array< long_index_t, NumBTensor > &BatchStrideBs, Array< long_index_t, NumDTensor > &BatchStrideDs, long_index_t BatchStrideE)
Definition: device_grouped_conv_utils.hpp:155
__host__ constexpr __device__ long_index_t GetEPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:190
ComputePtrOffsetOfStridedBatch()=default
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:196
__host__ constexpr __device__ auto GetAsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:166
Array< long_index_t, NumATensor > BatchStrideA_
Definition: device_grouped_conv_utils.hpp:201
Array< long_index_t, NumBTensor > BatchStrideB_
Definition: device_grouped_conv_utils.hpp:202
Array< long_index_t, NumDTensor > BatchStrideDs_
Definition: device_grouped_conv_utils.hpp:203
__host__ constexpr __device__ auto GetBsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:174
long_index_t BatchStrideE_
Definition: device_grouped_conv_utils.hpp:204
__host__ constexpr __device__ auto GetDsPtrOffset(index_t g_idx) const
Definition: device_grouped_conv_utils.hpp:182