Composable Kernel: include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp Source File

transform_contraction_to_gemm_arraybase.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace tensor_operation {
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          device::TensorSpecialization TensorSpec>
__host__ __device__ static auto
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
                       const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{
    // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
    //      gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
    // {
    //     throw std::runtime_error("wrong! dimension must match input lengths");
    // }

    const auto to_tuple = [&](auto& vec, auto start, auto end) {
        return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
    };

    const auto gs_ms_ns_lengths =
        to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
    const auto gs_ms_ns_strides =
        to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

    // dimension Ids for G0, G1, ...
    constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

    // dimension Ids for M0, M1, ...
    constexpr auto mDimIds =
        typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

    // dimension Ids for N0, N1, ...
    constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                              NumDimG + NumDimM + NumDimN,
                                                              1>::type{};

    // lengths for G0, G1, ...
    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);

    // lengths for M0, M1, ...
    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);

    // lengths for N0, N1, ...
    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);

    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
    {
        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(G, M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
    else
    {
        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        // Note: this does not require padding because it is only used for the G offset
        // calculation. Technically only a descriptor for G is needed; G_M_N is returned
        // for backward compatibility.
        const auto grid_desc_g_mraw_nraw =
            transform_tensor_descriptor(grid_desc_gs_ms_ns,
                                        make_tuple(make_merge_transform(gLengths),
                                                   make_merge_transform(mLengths),
                                                   make_merge_transform(nLengths)),
                                        make_tuple(gDimIds, mDimIds, nDimIds),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_ms_ns_lengths = to_tuple(
            gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides = to_tuple(
            gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // transformed tensor C[MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
            grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
}
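
// Illustrative sketch (not part of the original header): how the packed branch above behaves
// for a hypothetical 5-D tensor C[G0, M0, M1, N0, N1]. The first descriptor of the returned
// pair collapses to (G, MRaw, NRaw) and is used only for batch-offset calculation; the second
// collapses to (MRaw, NRaw) and is what the matrix padder later receives. All sizes below are
// made up for illustration.
__host__ __device__ static auto ExampleMakePackedGridDescriptorPair()
{
    // packed, row-major layout: G0 = 4, M0 = 8, M1 = 16, N0 = 8, N1 = 16
    const std::array<index_t, 5> lengths = {4, 8, 16, 8, 16};
    const std::array<index_t, 5> strides = {16384, 2048, 128, 16, 1};

    // .first  -> (G, MRaw, NRaw) = (4, 128, 128), .second -> (MRaw, NRaw) = (128, 128)
    return MakeGridDescriptorPair<1, 2, 2, device::TensorSpecialization::Packed>(lengths, strides);
}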

template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    __host__ __device__ static auto MakeAGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }
    __host__ __device__ static auto MakeAGridDescriptor_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
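
    // Illustrative sketch (not part of the original header): splitting K into (AK0, AK1).
    // With a hypothetical 128 x 64 A descriptor and AK1 = 8, the transform above yields
    // (AK0, M, AK1) = (8, 128, 8): dim 1 (K) of the input is unmerged into output dims {0, 2}
    // and M is passed through as output dim 1, i.e. the K0-M-K1 layout the blockwise GEMM
    // pipelines typically consume.
    //
    //   const auto a_m_k       = make_naive_tensor_descriptor_packed(make_tuple(128, 64));
    //   const auto a_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(a_m_k, Number<8>{});
    //   // lengths: (I0, I1, I2) == (8, 128, 8)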

    template <typename AGridDesc_M_K,
              typename WmmaK,
              typename MRepeat,
              typename MWaves,
              typename MPerWmma,
              typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
        const AGridDesc_M_K& a_grid_desc_m_k,
        const WmmaK&,
        const MRepeat&,
        const MWaves&,
        const MPerWmma&,
        const AK1&)
    {
        const auto M0     = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K      = a_grid_desc_m_k.GetLength(I1);
        const auto AKWmma = K / WmmaK{};
        constexpr auto AKRow      = 2;
        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }
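
    // Illustrative note (not in the original header): the 7-D descriptor produced above factors
    // K into (AKWmma, AK0PerWmma, AKRow, AK1) and M into (MBlockRepeat, MWaves, MPerWmma).
    // With hypothetical values WmmaK = 16, AK1 = 8 and K = 64: AKRow = 2,
    // AK0PerWmma = 16 / 2 / 8 = 1 and AKWmma = 64 / 16 = 4, so
    // K = AKWmma * AK0PerWmma * AKRow * AK1 = 4 * 1 * 2 * 8 = 64. The K factors occupy output
    // dims {0, 3, 4, 6} and the M factors dims {1, 2, 5}, matching the dimension order spelled
    // out in the function name.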

    //
    // B (alias of B0)
    //
    __host__ __device__ static auto MakeB0GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }
    __host__ __device__ static auto MakeB0GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_L_K,
              typename WmmaK,
              typename LRepeat,
              typename LWaves,
              typename LPerWmma,
              typename BK1>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
        const BGridDesc_L_K& b_grid_desc_l_k,
        const WmmaK&,
        const LRepeat&,
        const LWaves&,
        const LPerWmma&,
        const BK1&)
    {
        const auto L0     = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
        const auto K      = b_grid_desc_l_k.GetLength(I1);
        const auto BKWmma = K / WmmaK{};
        constexpr auto BKRow      = 2;
        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

        return transform_tensor_descriptor(
            b_grid_desc_l_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // B1
    //
    __host__ __device__ static auto MakeB1GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }
    __host__ __device__ static auto MakeB1GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_N_L,
              typename WmmaL,
              typename NRepeat,
              typename NWaves,
              typename NPerWmma,
              typename BL1>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
        const BGridDesc_N_L& b_grid_desc_n_l,
        const WmmaL&,
        const NRepeat&,
        const NWaves&,
        const NPerWmma&,
        const BL1&)
    {
        const auto N0     = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
        const auto L      = b_grid_desc_n_l.GetLength(I1);
        const auto BLWmma = L / WmmaL{};
        constexpr auto BLRow      = 2;
        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

        return transform_tensor_descriptor(
            b_grid_desc_n_l,
            make_tuple(make_unmerge_transform(
                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // C
    //
    __host__ __device__ static auto MakeCGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }
    __host__ __device__ static auto MakeCGridDescriptor_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck
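
A minimal usage sketch, assuming the struct name and template-parameter order reconstructed above; every dimension, tile size, and specialization value below is hypothetical rather than a library-mandated choice. The transform is instantiated once per problem shape, and its static members turn the raw length/stride arrays into the batched-GEMM-view descriptors consumed by the device kernels.

    // hypothetical instantiation: 1-D G/M/N/K/O contraction dimensions, illustrative tile sizes
    using Transform = ck::tensor_operation::TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
        ck::Sequence<1, 1, 1, 1, 1>,    // NumDims_G_M_N_K_O
        ck::Sequence<128, 128, 32, 64>, // PerBlock_M_N_K_O
        ck::tensor_operation::device::GemmSpecialization::MNKOPadding,
        ck::tensor_operation::device::TensorSpecialization::Default,
        ck::tensor_operation::device::TensorSpecialization::Default,
        ck::tensor_operation::device::TensorSpecialization::Default,
        ck::tensor_operation::device::TensorSpecialization::Default>;

    // A is G x M x K = 4 x 256 x 64, packed row-major
    const std::array<ck::index_t, 3> a_lengths = {4, 256, 64};
    const std::array<ck::index_t, 3> a_strides = {16384, 64, 1};

    const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(a_lengths, a_strides);
    const auto a_grid_desc_m_k   = Transform::MakeAGridDescriptor_M_K(a_lengths, a_strides); // padded to the tile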