blockwise_gemm_pipeline_wmmaops_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

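// Illustrative mapping (added commentary, not part of the original source): the
// stage counts above correspond one-to-one to the phases of Run() further down.
// For a K-loop of num_loop iterations with PrefetchStages = 2:
//
//   global prefetch 1  ->  local prefill 1          (tile for iteration 0)
//   global prefetch 2                               (tile for iteration 1)
//   local prefetch 1                                (LDS -> VGPR for iteration 0)
//   hot loop           :  num_loop - 2 iterations, each overlapping WMMA compute
//                         with the next global load and LDS fill
//   pre-tail + tail    :  the last two iterations, with no further global reads
//
// Only a single LDS buffer is used (LocalSharedMemoryBuffer: 1), so every hot-loop
// iteration is bracketed by block_sync_lds() calls instead of double buffering.
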
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack,
                                        TransposeC>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack,
                                         TransposeC>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack,
                                                    TransposeC>;
    using Base::I0;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(BlockHasHotloop(num_loop))
        {
            return TailNumber::Full;
        }
        else
        {
            if(num_loop == 1)
            {
                return TailNumber::Odd;
            }
            else
            {
                return TailNumber::Even;
            }
        }
    }
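
    // Illustrative usage sketch (added commentary, not from the original source):
    // the two predicates above are meant to be evaluated by the calling gridwise
    // kernel to pick the Run() template arguments. A hypothetical caller could
    // dispatch roughly like this (argument list elided):
    //
    //   const index_t num_loop = ...; // number of K-block iterations
    //   if(BlockHasHotloop(num_loop))
    //       pipeline.template Run<true, TailNumber::Full>(/* ... */);
    //   else if(BlockLoopTailNum(num_loop) == TailNumber::Even)
    //       pipeline.template Run<false, TailNumber::Even>(/* ... */);
    //   else
    //       pipeline.template Run<false, TailNumber::Odd>(/* ... */);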

    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Calculation of the number of instructions may require changes for WMMA
        /*
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;

        constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_wmma =
            (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
        constexpr auto num_dsread_b_wmma =
            (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
        constexpr auto num_wmma_per_issue =
            num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
        });

        // stage 2
        static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
                         ds_read_a_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
                                                                              ds_read_a_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });

        static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
                         ds_read_b_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
                                                                              ds_read_b_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });
        */
    }
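
    // Added commentary (not part of the original source): in the schedule sketched
    // above, the __builtin_amdgcn_sched_group_barrier mask values select instruction
    // classes as annotated inline: 0x008 = MFMA/WMMA, 0x020 = VMEM read,
    // 0x100 = DS read, 0x200 = DS write.
    //
    // Worked example of the interleave-rate arithmetic, assuming a 16-byte LDS read
    // width and NPerWmma != 16 (so wmma_cycle = 32 and ds_read_a_issue_cycle = 8):
    //
    //   ds_read_a_wmma_rate = (32 - 4 + 2 * 8 - 1) / (2 * 8) = 43 / 16 = 2
    //
    // i.e. stage 2 of the schedule would group two ds_read instructions with each
    // WMMA instruction.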

    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(
                    a_block_desc_k0_m0_m1_m2_k1,
                    make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                    a_block_buf,
                    a_thread_desc_,
                    make_tuple(I0, m0, k0, I0, I0, I0),
                    a_thread_buf);
            });

            if constexpr(ck::is_same_v<BScaleStruct, Empty>)
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(I0, n0, k0, I0, I0, I0),
                        b_thread_buf);
                });
            }
            else
            {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                        b_block_buf,
                        b_scale_struct.b_scale_thread_bufs(
                            I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                       k0 / BScaleStruct::num_scale_krepeat>{}],
                        b_thread_desc_,
                        make_tuple(I0, n0, k0, I0, I0, I0),
                        b_thread_buf);
                });
            }
        });
    }
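
    // Added commentary (not part of the original source): the LDS-side source
    // coordinate of each copy above starts at K0 = k0 * KPack / A_K1 / A_KRow (and the
    // analogous expression for B). As a worked example with assumed (hypothetical)
    // parameters KPack = 16, A_K1 = 8 and A_KRow = 2, the K0 offset is
    // k0 * 16 / 8 / 2 = k0, i.e. every KRepeat step advances exactly one K0 slice of
    // the k0_m0_m1_m2_k1 block descriptor, while the per-thread destination advances
    // along its own k0 dimension via make_tuple(I0, m0, k0, I0, I0, I0).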

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        // BScaleThreadCopy
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2, perform when at least 2 loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        }

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();

        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);

        // Main body, perform when at least 3 loops exist.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   Number<ik % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   Number<ik % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }

        // Pre-tail, perform when at least 2 loops exist.
        if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
        {
            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            // No RunRead or MoveSrcSliceWindow here, already finished them all!

            b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            block_sync_lds();

            LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        }

        // Tail, always perform.
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
            // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }
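
    // Added commentary (not part of the original source): worked trace of the tail
    // handling above for the smallest cases. With TailNumber::Odd (num_loop == 1),
    // only the final "Tail" block runs, on the single prefetched and prefilled tile.
    // With TailNumber::Even (num_loop == 2), Run() executes global prefetch 1 and 2,
    // local prefill and local prefetch, then the "Pre-tail" block computes iteration 0
    // while refilling LDS from the second prefetch, and the "Tail" block computes
    // iteration 1. TailNumber::Full additionally executes the hot loop for the first
    // num_loop - 2 iterations before the same pre-tail/tail sequence.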

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck