Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"

namespace ck {

// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visible to the whole block; C is distributed among the threads.
// Assume:
//   1. A:
//      1. ABlockDesc_BK0_BM_BK1 is known at compile-time
//      2. ABlockBuffer is DynamicBuffer
//   2. B:
//      1. BBlockDesc_BK0_BN_BK1 is known at compile-time
//      2. BBlockBuffer is DynamicBuffer
//   3. C:
//      1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
//      2. CThreadBuffer is StaticBuffer
// Also assume:
//   BM10BN10ThreadClusterBM10Xs::Size() == BM10BN10ThreadClusterBN10Xs::Size() == 2
//   BM0 == BN0 == 2, so the class does a 2x2 pipelined read and fma (ABBA optimization).
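// "ABBA" refers to the read order used in Run() below: the four thread-local sub-tiles are
// fetched as A_sub_0, B_sub_0, B_sub_1, A_sub_1, and the four fma steps
// (C_sub_00, C_sub_01, C_sub_10, C_sub_11) are interleaved with the reads for the next
// BK0 slice, so that memory reads and math can overlap.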
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc_BK0_BM_BK1,
          typename BBlockDesc_BK0_BN_BK1,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
          typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
                                                //          BM10BN10ThreadClusterBM101, ...>
          typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
                                                //          BM10BN10ThreadClusterBN101, ...>
          index_t AThreadCopyScalarPerVector_BM11,
          index_t BThreadCopyScalarPerVector_BN11,
          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                                 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
    static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
    static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
    static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);

    static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
    static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];

    static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
    static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];

    static constexpr index_t BM11 = BM1PerThreadBM11;
    static constexpr index_t BN11 = BN1PerThreadBN11;

    static constexpr index_t BM1 = BM100 * BM101 * BM11;
    static constexpr index_t BN1 = BN100 * BN101 * BN11;

    static constexpr index_t BM0 = BM / BM1;
    static constexpr index_t BN0 = BN / BN1;
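    // Worked example (hypothetical sizes, not mandated by this file): with BM = 128,
    // BM10BN10ThreadClusterBM10Xs = Sequence<8, 2> and BM1PerThreadBM11 = 4, the block tile
    // decomposes as BM1 = BM100 * BM101 * BM11 = 8 * 2 * 4 = 64 and BM0 = BM / BM1 = 2,
    // which satisfies the BM0 == 2 restriction asserted in the constructor below.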

    __host__ __device__ static constexpr auto
    MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
    {
        // [BK0, BM, BK1] -> [BK0, BM0, BM1, BK1]: BM is unmerged into BM0 x BM1
        const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
            a_block_desc_bk0_bm_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
                       make_pass_through_transform(Number<BK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_block_bk0_bm0_bm1_bk1;
    }

    __host__ __device__ static constexpr auto
    MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
    {
        // [BK0, BN, BK1] -> [BK0, BN0, BN1, BK1]: BN is unmerged into BN0 x BN1
        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
            b_block_desc_bk0_bn_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
                       make_pass_through_transform(Number<BK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_block_desc_bk0_bn0_bn1_bk1;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
    {
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // lower: [BM, BN]
        constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
    {
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // lower: [BM0, BM1, BN0, BN1]
        constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<BM0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
                           make_pass_through_transform(Number<BN0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
    }

    __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
    {
        return Sequence<BM0, BM11, BN0, BN11>{};
    }

    static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
        MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});

    static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
        MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});

    public:
    __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
        : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
    {
        static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                          BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");

        static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
                          BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO remove this restriction
        static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
                          BM10BN10ThreadClusterBN10Xs::Size() == 2,
                      "wrong!");

        // TODO: remove this restriction
        static_assert(BM0 == 2, "wrong");
        static_assert(BM0 == 2 && BN0 == 2, "wrong");
    }

    static __device__ CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
    {
        // lower: [BM0, BM1, BN0, BN1]
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        constexpr auto adaptor0 =
            MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();

        // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // upper: [Tid, BM0, BM11, BN0, BN11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(
                           Number<BM100>{}, Number<BN100>{}, Number<BM101>{}, Number<BN101>{})),
                       make_pass_through_transform(Number<BM0>{}),
                       make_pass_through_transform(Number<BM11>{}),
                       make_pass_through_transform(Number<BN0>{}),
                       make_pass_through_transform(Number<BN11>{})),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }

    template <typename CThreadDesc_BM0_BM11_BN0_BN11,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: remove this restriction
        static_assert(BM0 == 2 && BN0 == 2 &&
                          CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
                          CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
                      "wrong");

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
            b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

        constexpr auto threadwise_contraction =
            ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                FloatA,
                FloatB,
                FloatC,
                decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
                decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
                CThreadDesc_BM0_BM11_BN0_BN11,
                Sequence<BK0PerThread, BK1>,
                Sequence<1, BM1PerThreadBM11>,
                Sequence<1, BN1PerThreadBN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           a_block_buf,
                           a_thread_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           b_block_buf,
                           b_thread_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           b_block_buf,
                           b_thread_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           a_block_buf,
                           a_thread_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I0, I0, I1, I0));

        // loop over rest of bk0
        static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
            // read A_sub_0
            a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(bk0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(bk0, I0, I0, I0),
                               b_block_buf,
                               b_thread_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(I0, I0, I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(bk0, I1, I0, I0),
                               b_block_buf,
                               b_thread_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(I0, I1, I0, I0),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(bk0, I1, I0, I0),
                               a_block_buf,
                               a_thread_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(I0, I1, I0, I0),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I1, I0, I1, I0));
    }

    private:
    // A[BK0, BM0, BM1, BK1]
    static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));

    // B[BK0, BN0, BN1, BK1]
    static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatA,
        FloatA,
        decltype(a_block_desc_bk0_bm0_bm1_bk1_),
        decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
        Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                             // DimAccessOrder
        Sequence<1, 1, BM1PerThreadBM11, BK1>,            // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                            // SrcVectorTensorContiguousDimOrder

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatB,
        FloatB,
        decltype(b_block_desc_bk0_bn0_bn1_bk1_),
        decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
        Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                             // DimAccessOrder
        Sequence<1, 1, BN1PerThreadBN11, BK1>,            // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                            // SrcVectorTensorContiguousDimOrder

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
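For orientation, the sketch below shows one way the blockwise GEMM above might be instantiated from a kernel. It is a minimal sketch, assuming the code sits inside namespace ck; all concrete numbers (block size 256, a 128x128 block tile, 8x2 thread clusters, fp16 inputs with fp32 accumulation) are illustrative assumptions chosen to satisfy the static_asserts in the constructor, not values taken from this file.

// Hypothetical configuration, for illustration only.
// BM = BN = 128, BK0 = 8, BK1 = 2; thread clusters Sequence<8, 2> give
// BM1 = BN1 = 8 * 2 * 4 = 64, hence BM0 = BN0 = 2 and BlockSize = 8 * 2 * 8 * 2 = 256.
using ABlockDesc = decltype(make_naive_tensor_descriptor_packed(
    make_tuple(Number<8>{}, Number<128>{}, Number<2>{})));
using BBlockDesc = decltype(make_naive_tensor_descriptor_packed(
    make_tuple(Number<8>{}, Number<128>{}, Number<2>{})));

using BlockwiseGemm =
    BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
        256,            // BlockSize
        half_t,         // FloatA
        half_t,         // FloatB
        float,          // FloatC
        ABlockDesc,
        BBlockDesc,
        4,              // BM1PerThreadBM11
        4,              // BN1PerThreadBN11
        1,              // BK0PerThread
        Sequence<8, 2>, // BM10BN10ThreadClusterBM10Xs: {BM100, BM101}
        Sequence<8, 2>, // BM10BN10ThreadClusterBN10Xs: {BN100, BN101}
        2,              // AThreadCopyScalarPerVector_BM11
        2>;             // BThreadCopyScalarPerVector_BN11

A kernel would then default-construct a BlockwiseGemm object and call its Run() once per K-loop iteration, passing the per-thread C descriptor, the A and B block buffers in LDS, and the per-thread C accumulator buffer.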