blockwise_gemm_pipeline_wmmaops_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

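// The pipeline keeps a single LDS buffer per operand (LocalSharedMemoryBuffer: 1) and hides
// global-memory latency with two in-flight global prefetches: prefetch 1 is prefilled into LDS
// before the hot loop, prefetch 2 is issued when the loop count allows it, and the remaining
// iterations are drained through the pre-tail and tail blocks of Run() below.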
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          index_t KInner,
          bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};

// Specialization for the intra-wave (compute-optimized) scheduler
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          index_t KInner,
          bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack,
                                        KInner,
                                        TransposeC>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack,
                                         KInner,
                                         TransposeC>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack,
                                                    KInner,
                                                    TransposeC>;
    using Base::I0;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

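    // Map the loop count to a tail flavor: a single iteration is handled entirely by the tail
    // (Odd), two iterations also go through the pre-tail (Even), and anything beyond the two
    // prefetch stages takes the full hot-loop path (Full).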
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(BlockHasHotloop(num_loop))
        {
            return TailNumber::Full;
        }
        else
        {
            if(num_loop == 1)
            {
                return TailNumber::Odd;
            }
            else
            {
                return TailNumber::Even;
            }
        }
    }

    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Calculation of the number of instructions may require changes for WMMA
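        // The disabled schedule below uses __builtin_amdgcn_sched_group_barrier to interleave
        // instruction classes within one hot-loop iteration: stage 1 pairs each buffer load
        // (VMEM read, mask 0x020) and DS write (mask 0x200) with a slice of the WMMA issue
        // slots (mask 0x008), and stage 2 spreads the remaining DS reads (mask 0x100) between
        // WMMAs.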
        /*
        // A/B split schedule
        // The compiler is likely to use ds_read2 when the instruction width is smaller than
        // 16 bytes.
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;

        constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_wmma =
            (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
        constexpr auto num_dsread_b_wmma =
            (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
        constexpr auto num_wmma_per_issue =
            num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
        });

        // stage 2
        static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
                         ds_read_a_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
                                                                              ds_read_a_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });

        static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
                         ds_read_b_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
                                                                              ds_read_b_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });
        */
    }

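    // Copy the A and B tiles of the current K slice from LDS into per-thread VGPR buffers for
    // every (KRepeat, MRepeat/NRepeat) sub-tile. When BScaleStruct is not the Empty tag, the B
    // copy additionally applies the per-block scale selected from b_scale_thread_bufs.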
    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(I0, m0, k0, I0, I0, I0, I0),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, m0, k0, I0, I0, I0, I0),
                                   a_thread_buf);
            });

            if constexpr(ck::is_same_v<BScaleStruct, Empty>)
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_thread_buf);
                });
            }
            else
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_block_buf,
                                       b_scale_struct.b_scale_thread_bufs(
                                           I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                                      k0 / BScaleStruct::num_scale_krepeat>{}],
                                       b_thread_desc_,
                                       make_tuple(I0, n0, k0, I0, I0, I0, I0),
                                       b_thread_buf);
                });
            }
        });
    }

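    // Hot-loop GEMM driver. Phase order: global prefetch 1 -> local prefill 1 -> global
    // prefetch 2 -> local prefetch 1 -> main body (num_loop - 2 iterations) -> pre-tail ->
    // tail, with block_sync_lds() separating LDS producers from consumers and
    // __builtin_amdgcn_sched_barrier(0) fencing the compiler scheduler between phases.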
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        // BScaleThreadCopy
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);

        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2, performed when at least two loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        }

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();

        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);

        // Main body, performed when at least three loops exist.
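        // Each iteration writes the tile fetched in the previous iteration into LDS, issues the
        // next global read, computes on the tile already resident in VGPRs, and then reloads
        // from LDS. Two iterations are peeled off into the pre-tail and tail blocks, hence the
        // num_loop - 2 trip count.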
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

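                // Accumulate into C: for every (k0, m0, n0, k_inner) sub-tile, gather
                // KPack / A_KRow / KInner A elements and KPack / B_KRow / KInner B elements
                // into contiguous vectors and feed them to a single wmma_gemm.Run call.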
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            static_for<0, KInner, 1>{}([&](auto k_inner) {
                                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / A_K1>{},
                                                       m0,
                                                       k0,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % A_K1>{}))>{}];
                                });
                                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(Number<kk / B_K1>{},
                                                       n0,
                                                       k0,
                                                       I0,
                                                       I0,
                                                       I0,
                                                       Number<kk % B_K1>{}))>{}];
                                });

                                using wmma_input_type_a =
                                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                                using wmma_input_type_b =
                                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                                wmma_gemm.Run(
                                    a_thread_vec.template AsType<wmma_input_type_a>(),
                                    b_thread_vec.template AsType<wmma_input_type_b>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });
                });

                block_sync_lds();

                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }

        // Pre-tail, performed when at least two loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            // No RunRead or MoveSrcSliceWindow here; all global reads have already been issued.

            b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KInner, 1>{}([&](auto k_inner) {
                            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
            });

            block_sync_lds();

            LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        }

        // Tail, always performed.
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KInner, 1>{}([&](auto k_inner) {
                            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<kk / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   I0,
                                                   Number<kk % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
            });
            // Let the last WMMA block spill into the epilogue region to cover the potential
            // LDS-shuffle latency.
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck