Composable Kernel: blockwise_gemm_pipeline_wmmaops_v3.hpp Source File
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Compute optimized pipeline
11 // GlobalPrefetchStages: 2
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 1
14 // LocalSharedMemoryBuffer: 1
15 
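// The stage counts above roughly translate into the following schedule inside Run()
// (informal pseudocode; the names are illustrative only, not identifiers from this file):
//
//   global_read(0); lds_write(0); global_read(1);   // GlobalPrefetchStages = 2,
//   lds_read_to_vgpr(0);                            // LocalPreFillStages = LocalPreFetchStages = 1
//   for(i = 0; i < num_loop - 2; ++i)               // hot loop, one shared LDS buffer
//   {
//       lds_write(i + 1); global_read(i + 2);
//       wmma(i);          lds_read_to_vgpr(i + 1);
//   }
//   // the last one or two iterations are drained by the pre-tail / tail blocks below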
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeTypeA,
21  typename ComputeTypeB,
22  typename AccDataType,
23  typename AWmmaTileDesc,
24  typename BWmmaTileDesc,
25  index_t ABlockTransferSrcScalarPerVector,
26  index_t BBlockTransferSrcScalarPerVector,
27  index_t MPerBlock,
28  index_t NPerBlock,
29  index_t KPerBlock,
30  index_t MPerWmma,
31  index_t NPerWmma,
32  index_t MRepeat,
33  index_t NRepeat,
34  index_t KPack,
35  index_t KInner,
36  bool TransposeC = false,
37  bool BSkipLDS = false>
38 struct BlockwiseGemmWmmaops_pipeline_v3
39 {
40 };
41 
42 template <index_t BlockSize,
43  typename ADataType,
44  typename BDataType,
45  typename ComputeTypeA,
46  typename ComputeTypeB,
47  typename AccDataType,
48  typename AWmmaTileDesc,
49  typename BWmmaTileDesc,
50  index_t ABlockTransferSrcScalarPerVector,
51  index_t BBlockTransferSrcScalarPerVector,
52  index_t MPerBlock,
53  index_t NPerBlock,
54  index_t KPerBlock,
55  index_t MPerWmma,
56  index_t NPerWmma,
57  index_t MRepeat,
58  index_t NRepeat,
59  index_t KPack,
60  index_t KInner,
61  bool TransposeC>
62 struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
63  BlockSize,
64  ADataType,
65  BDataType,
66  ComputeTypeA,
67  ComputeTypeB,
68  AccDataType,
69  AWmmaTileDesc,
70  BWmmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerWmma,
77  NPerWmma,
78  MRepeat,
79  NRepeat,
80  KPack,
81  KInner,
82  TransposeC,
83  false>
84  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
85  ADataType,
86  BDataType,
87  ComputeTypeA,
88  ComputeTypeB,
89  AccDataType,
90  AWmmaTileDesc,
91  BWmmaTileDesc,
92  ABlockTransferSrcScalarPerVector,
93  BBlockTransferSrcScalarPerVector,
94  MPerBlock,
95  NPerBlock,
96  KPerBlock,
97  MPerWmma,
98  NPerWmma,
99  MRepeat,
100  NRepeat,
101  KPack,
102  KInner,
103  TransposeC>
104 {
105  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
106  ADataType,
107  BDataType,
108  ComputeTypeA,
109  ComputeTypeB,
110  AccDataType,
111  AWmmaTileDesc,
112  BWmmaTileDesc,
113  ABlockTransferSrcScalarPerVector,
114  BBlockTransferSrcScalarPerVector,
115  MPerBlock,
116  NPerBlock,
117  KPerBlock,
118  MPerWmma,
119  NPerWmma,
120  MRepeat,
121  NRepeat,
122  KPack,
123  KInner,
124  TransposeC>;
125  using Base::I0;
126 
127  using Base::A_K1;
128  using Base::A_KRow;
129  using Base::B_K1;
130  using Base::B_KRow;
131  using Base::KRepeat;
132  using Base::WmmaK;
133 
134  using Base::wmma_gemm;
135  using typename Base::HotLoopInstList;
136 
137  using Base::CalculateCThreadOriginDataIndex;
138  using Base::
139  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
140  using Base::GetCThreadBuffer;
141  using Base::
142  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
143  using Base::
144  GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;
145 
146  using Base::a_block_desc_k0_m0_m1_m2_k1;
147  using Base::b_block_desc_k0_n0_n1_n2_k1;
148 
149  using typename Base::Empty;
150 
151  static constexpr index_t PrefetchStages = 2;
152  static constexpr index_t PrefillStages = 1;
153  static constexpr index_t GlobalBufferNum = 1;
154 
155  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
156  {
157  return num_loop > PrefetchStages;
158  }
159 
160  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
161  {
162  if(BlockHasHotloop(num_loop))
163  {
164  return TailNumber::Full;
165  }
166  else
167  {
168  if(num_loop == 1)
169  {
170  return TailNumber::Odd;
171  }
172  else
173  {
174  return TailNumber::Even;
175  }
176  }
177  }
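// For example, with PrefetchStages == 2 the mapping is:
//   num_loop == 1 -> TailNumber::Odd  (no hot loop, only the tail block runs)
//   num_loop == 2 -> TailNumber::Even (no hot loop, pre-tail + tail run)
//   num_loop >= 3 -> TailNumber::Full (hot loop runs num_loop - 2 times)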
178 
179  __device__ static constexpr auto HotLoopScheduler()
180  {
181  // TODO: Calculation of the number of instructions may require changes for WMMA
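// The masks passed to __builtin_amdgcn_sched_group_barrier in the schedule below select
// instruction classes (0x008 = MFMA/WMMA, 0x020 = VMEM read, 0x100 = DS read,
// 0x200 = DS write); the second argument is the number of instructions in the group and
// the third is a sync ID.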
182  /*
183  // A/B split schedule
184  // the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
185  constexpr auto num_ds_read_inst_a =
186  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
187  ? HotLoopInstList::A_LDS_Read_Inst_Num
188  : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
189  constexpr auto num_ds_read_inst_b =
190  HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
191  ? HotLoopInstList::B_LDS_Read_Inst_Num
192  : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
193 
194  constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
195  constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
196 
197  constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
198  constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
199 
200  constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;
201 
202  constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
203  constexpr auto ds_read_a_issue_cycle =
204  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
205  constexpr auto ds_read_b_issue_cycle =
206  HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
207  constexpr auto ds_read_a_wmma_rate =
208  (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
209  constexpr auto ds_read_b_wmma_rate =
210  (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
211 
212  constexpr auto num_dsread_a_wmma =
213  (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
214  constexpr auto num_dsread_b_wmma =
215  (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;
216 
217  // stage 1
218  // Separate this part?
219  // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
220  // sizeof(ComputeDataType) / sizeof(BDataType)
221  // ? sizeof(ComputeDataType) / sizeof(ADataType)
222  // : sizeof(ComputeDataType) / sizeof(BDataType);
223  constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
224  constexpr auto num_wmma_per_issue =
225  num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
226  constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
227  constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
228 
229  static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
230  ignore = i;
231  static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
232  ignore = idswrite;
233  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
234  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
235  });
236  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
237  __builtin_amdgcn_sched_group_barrier(
238  0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
239  });
240  static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
241  ignore = i;
242  static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
243  ignore = idswrite;
244  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
245  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
246  });
247  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
248  __builtin_amdgcn_sched_group_barrier(
249  0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
250  });
251 
252  // stage 2
253  static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
254  if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
255  ds_read_a_wmma_rate)
256  {
257  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
258  }
259  else
260  {
261  __builtin_amdgcn_sched_group_barrier(0x100,
262  num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
263  ds_read_a_wmma_rate,
264  0); // DS read
265  }
266  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
267  });
268 
269  static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
270  if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
271  ds_read_b_wmma_rate)
272  {
273  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
274  }
275  else
276  {
277  __builtin_amdgcn_sched_group_barrier(0x100,
278  num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
279  ds_read_b_wmma_rate,
280  0); // DS read
281  }
282  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
283  });
284  */
285  }
286 
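// LocalLoad copies one KPerBlock slice of the A and B tiles from LDS into the per-thread
// VGPR buffers (KRepeat x MRepeat reads for A, KRepeat x NRepeat reads for B). When
// BScaleStruct is not Empty, the B copy additionally applies the per-block scale selected
// by (n0, k0) from b_scale_struct.b_scale_thread_bufs.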
287  template <typename ABlockBuffer,
288  typename AThreadBuffer,
289  typename BBlockBuffer,
290  typename BThreadBuffer,
291  typename BScaleStruct>
292  __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
293  AThreadBuffer& a_thread_buf,
294  BBlockBuffer& b_block_buf,
295  BThreadBuffer& b_thread_buf,
296  BScaleStruct& b_scale_struct) const
297  {
298  static_for<0, KRepeat, 1>{}([&](auto k0) {
299  static_for<0, MRepeat, 1>{}([&](auto m0) {
300  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
301  make_tuple(I0, m0, k0, I0, I0, I0, I0),
302  a_block_buf,
303  a_thread_desc_,
304  make_tuple(I0, m0, k0, I0, I0, I0, I0),
305  a_thread_buf);
306  });
307 
308  if constexpr(ck::is_same_v<BScaleStruct, Empty>)
309  {
310  static_for<0, NRepeat, 1>{}([&](auto n0) {
311  b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
312  make_tuple(I0, n0, k0, I0, I0, I0, I0),
313  b_block_buf,
314  b_thread_desc_,
315  make_tuple(I0, n0, k0, I0, I0, I0, I0),
316  b_thread_buf);
317  });
318  }
319  else
320  {
321  static_for<0, NRepeat, 1>{}([&](auto n0) {
322  b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
323  make_tuple(I0, n0, k0, I0, I0, I0, I0),
324  b_block_buf,
325  b_scale_struct.b_scale_thread_bufs(
326  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
327  k0 / BScaleStruct::num_scale_krepeat>{}],
328  b_thread_desc_,
329  make_tuple(I0, n0, k0, I0, I0, I0, I0),
330  b_thread_buf);
331  });
332  }
333  });
334  }
335 
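// Run() implements the pipeline sketched above: a prologue (global prefetch 1, LDS
// prefill, global prefetch 2, local prefetch into VGPRs), a software-pipelined hot loop
// that overlaps LDS writes, global reads and WMMA compute under HotLoopScheduler(), a
// pre-tail that drains the second global prefetch, and a compute-only tail.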
336  template <bool HasMainLoop,
337  TailNumber TailNum,
338  typename AGridDesc,
339  typename ABlockDesc,
340  typename ABlockTransfer,
341  typename AGridBuffer,
342  typename ABlockBuffer,
343  typename ABlockTransferStep,
344  typename BGridDesc,
345  typename BBlockDesc,
346  typename BBlockTransfer,
347  typename BGridBuffer,
348  typename BBlockBuffer,
349  typename BBlockTransferStep,
350  typename CThreadBuffer,
351  typename BScaleStruct>
352  __device__ void Run(const AGridDesc& a_grid_desc,
353  const ABlockDesc& a_block_desc,
354  ABlockTransfer& a_blockwise_copy,
355  const AGridBuffer& a_grid_buf,
356  ABlockBuffer& a_block_buf,
357  const ABlockTransferStep& a_block_copy_step,
358  const BGridDesc& b_grid_desc,
359  const BBlockDesc& b_block_desc,
360  BBlockTransfer& b_blockwise_copy,
361  const BGridBuffer& b_grid_buf,
362  BBlockBuffer& b_block_buf,
363  const BBlockTransferStep& b_block_copy_step,
364  CThreadBuffer& c_thread_buf,
365  // BScaleThreadCopy
366  BScaleStruct& b_scale_struct,
367  index_t num_loop,
368  index_t num_loop_per_scale) const
369  {
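// __builtin_amdgcn_sched_barrier(0) is a full scheduling fence: a zero mask means the
// compiler's scheduler may not move any instruction across it.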
370  __builtin_amdgcn_sched_barrier(0);
371 
372  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
373 
374  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
375  a_thread_desc_.GetElementSpaceSize());
376  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
377  b_thread_desc_.GetElementSpaceSize());
378 
379  // Global prefetch 1
380  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
381  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
382 
383  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
384  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
385 
386  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
387 
388  // Local prefill 1
389  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
390  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
391 
392  // Global prefetch 2, performed only when at least 2 loops exist.
393  if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
394  {
395  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
396  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
397 
398  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
399  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
400  }
401 
402  // Initialize C
403  c_thread_buf.Clear();
404 
405  // Local prefetch 1
406  block_sync_lds();
407 
408  LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
409 
410  __builtin_amdgcn_sched_barrier(0);
411 
412  // Main body, executed only when at least 3 loops exist.
413  if constexpr(HasMainLoop)
414  {
415  index_t i = 0;
416  do
417  {
418  block_sync_lds();
419 
420  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
421  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
422 
423  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
424  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
425 
426  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
427  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
428 
429  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
430 
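// The nest below gathers KPack / {A,B}_KRow / KInner elements from the VGPR thread
// buffers into contiguous vectors and issues one wmma_gemm.Run per (k0, m0, n0, k_inner)
// iteration, accumulating into c_thread_buf. The same nest is repeated in the pre-tail
// and tail blocks further down.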
431  static_for<0, KRepeat, 1>{}([&](auto k0) {
432  static_for<0, MRepeat, 1>{}([&](auto m0) {
433  static_for<0, NRepeat, 1>{}([&](auto n0) {
434  static_for<0, KInner, 1>{}([&](auto k_inner) {
435  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
436  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
437 
438  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
439  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
440  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
441  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
442  make_tuple(Number<kk / A_K1>{},
443  m0,
444  k0,
445  I0,
446  I0,
447  I0,
448  Number<kk % A_K1>{}))>{}];
449  });
450  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
451  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
452  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
453  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
454  make_tuple(Number<kk / B_K1>{},
455  n0,
456  k0,
457  I0,
458  I0,
459  I0,
460  Number<kk % B_K1>{}))>{}];
461  });
462 
463  using wmma_input_type_a =
464  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
465  using wmma_input_type_b =
466  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
467 
468  constexpr index_t c_offset =
469  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
470 
471  wmma_gemm.Run(
472  a_thread_vec.template AsType<wmma_input_type_a>(),
473  b_thread_vec.template AsType<wmma_input_type_b>(),
474  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
475  });
476  });
477  });
478  });
479 
480  block_sync_lds();
481 
482  LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
483 
484  HotLoopScheduler();
485  __builtin_amdgcn_sched_barrier(0);
486 
487  i += 1;
488  } while(i < (num_loop - 2));
489  }
490 
491  // Pre-tail, executed only when at least 2 loops exist.
492  if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
493  {
494  block_sync_lds();
495 
496  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
497  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
498 
499  // No RunRead or MoveSrcSliceWindow here; all global reads have already been issued.
500 
501  b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
502 
503  static_for<0, KRepeat, 1>{}([&](auto k0) {
504  static_for<0, MRepeat, 1>{}([&](auto m0) {
505  static_for<0, NRepeat, 1>{}([&](auto n0) {
506  static_for<0, KInner, 1>{}([&](auto k_inner) {
507  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
508  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
509 
510  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
511  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
512  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
513  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
514  make_tuple(Number<kk / A_K1>{},
515  m0,
516  k0,
517  I0,
518  I0,
519  I0,
520  Number<kk % A_K1>{}))>{}];
521  });
522  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
523  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
524  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
525  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
526  make_tuple(Number<kk / B_K1>{},
527  n0,
528  k0,
529  I0,
530  I0,
531  I0,
532  Number<kk % B_K1>{}))>{}];
533  });
534 
535  using wmma_input_type_a =
536  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
537  using wmma_input_type_b =
538  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
539 
540  constexpr index_t c_offset =
541  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
542 
543  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
544  b_thread_vec.template AsType<wmma_input_type_b>(),
545  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
546  });
547  });
548  });
549  });
550 
551  block_sync_lds();
552 
553  LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
554 
555  HotLoopScheduler();
556  __builtin_amdgcn_sched_barrier(0);
557  }
558 
559  // Tail, always executed.
560  {
561  static_for<0, KRepeat, 1>{}([&](auto k0) {
562  static_for<0, MRepeat, 1>{}([&](auto m0) {
563  static_for<0, NRepeat, 1>{}([&](auto n0) {
564  static_for<0, KInner, 1>{}([&](auto k_inner) {
565  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
566  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
567 
568  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
569  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
570  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
571  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
572  make_tuple(Number<kk / A_K1>{},
573  m0,
574  k0,
575  I0,
576  I0,
577  I0,
578  Number<kk % A_K1>{}))>{}];
579  });
580  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
581  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
582  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
583  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
584  make_tuple(Number<kk / B_K1>{},
585  n0,
586  k0,
587  I0,
588  I0,
589  I0,
590  Number<kk % B_K1>{}))>{}];
591  });
592 
593  using wmma_input_type_a =
594  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
595  using wmma_input_type_b =
596  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
597 
598  constexpr index_t c_offset =
599  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
600 
601  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
602  b_thread_vec.template AsType<wmma_input_type_b>(),
603  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
604  });
605  });
606  });
607  });
608  // Intentionally let the last WMMA block spill into the epilogue region to cover
609  // the potential LDS-shuffle latency.
610  // __builtin_amdgcn_sched_barrier(0);
611  }
612  }
613 
614  protected:
615  using Base::a_thread_copy_;
616  using Base::a_thread_desc_;
617  using Base::b_thread_copy_;
618  using Base::b_thread_desc_;
619  using Base::c_thread_desc_;
620 };
621 
622 } // namespace ck