Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp — Source File

blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 // GlobalPrefetchStages: 2
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 1
14 // LocalSharedMemoryBuffer: 1
15 
// Primary template for the MX (microscaling) MoE NBS blockwise GEMM pipeline (v3).
// Deliberately an empty struct: the implementation is provided only by the
// scheduler-specific partial specialization that follows.
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t ThreadBlockSize,
18  index_t ScaleBlockSize,
19  typename ADataType,
20  typename AScaleDataType,
21  typename BDataType,
22  typename BScaleDataType,
23  typename ATileDesc,
24  typename BTileDesc,
25  typename AMmaTileDesc,
26  typename BMmaTileDesc,
27  index_t ABlockTransferSrcScalarPerVector,
28  index_t BBlockTransferSrcScalarPerVector,
29  index_t MPerBlock,
30  index_t NPerBlock,
31  index_t KPerBlock,
32  index_t MPerXDL,
33  index_t NPerXDL,
34  index_t MRepeat, // MXdlPerWave
35  index_t NRepeat, // NXdlPerWave
36  index_t KPack>
// NOTE(review): original line 37 — the `struct <name>` declaration that this
// template parameter list introduces — was dropped by the documentation
// extraction; confirm the struct name against the repository source.
38 {
39 };
40 
41 template <index_t ThreadBlockSize,
42  index_t ScaleBlockSize,
43  typename ADataType,
44  typename AScaleDataType,
45  typename BDataType,
46  typename BScaleDataType,
47  typename ATileDesc,
48  typename BTileDesc,
49  typename AMmaTileDesc,
50  typename BMmaTileDesc,
51  index_t ABlockTransferSrcScalarPerVector,
52  index_t BBlockTransferSrcScalarPerVector,
53  index_t MPerBlock,
54  index_t NPerBlock,
55  index_t KPerBlock,
56  index_t MPerXDL,
57  index_t NPerXDL,
58  index_t MRepeat, // MXdlPerWave
59  index_t NRepeat, // NXdlPerWave
60  index_t KPack>
62  ThreadBlockSize,
63  ScaleBlockSize,
64  ADataType,
65  AScaleDataType,
66  BDataType,
67  BScaleDataType,
68  ATileDesc,
69  BTileDesc,
70  AMmaTileDesc,
71  BMmaTileDesc,
72  ABlockTransferSrcScalarPerVector,
73  BBlockTransferSrcScalarPerVector,
74  MPerBlock,
75  NPerBlock,
76  KPerBlock,
77  MPerXDL,
78  NPerXDL,
79  MRepeat,
80  NRepeat,
81  KPack>
82  : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
83  ADataType,
84  BDataType,
85  ATileDesc,
86  BTileDesc,
87  AMmaTileDesc,
88  BMmaTileDesc,
89  ABlockTransferSrcScalarPerVector,
90  BBlockTransferSrcScalarPerVector,
91  MPerBlock,
92  NPerBlock,
93  KPerBlock,
94  MPerXDL,
95  NPerXDL,
96  MRepeat,
97  NRepeat,
98  KPack>
99 
100 {
101 
103  ADataType,
104  BDataType,
105  ATileDesc,
106  BTileDesc,
107  AMmaTileDesc,
108  BMmaTileDesc,
109  ABlockTransferSrcScalarPerVector,
110  BBlockTransferSrcScalarPerVector,
111  MPerBlock,
112  NPerBlock,
113  KPerBlock,
114  MPerXDL,
115  NPerXDL,
116  MRepeat,
117  NRepeat,
118  KPack>;
119  using Base::I0;
120  using Base::I1;
121  using Base::KRepeat;
122  using Base::MWaves;
123  using Base::NWaves;
124  using Base::WaveSize;
125  using Base::xdlops_gemm;
126  using typename Base::HotLoopInstList;
127 
128  using Base::CalculateCThreadOriginDataIndex;
129  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
130  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
131  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
132  using Base::GetCThreadBuffer;
133  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
134  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
136  using Base::GetWaveIdx;
137  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
138  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
139 
140  using Base::a_block_desc_m0_m1_m2_m3_k;
141  using Base::b_block_desc_n0_n1_n2_n3_k;
142 
143  using Base::AMmaKStride;
144  using Base::APackedSize;
145  using Base::BMmaKStride;
146  using Base::BPackedSize;
147  using Base::KThreadChunk;
148 
149  using Base::KXdlPack;
150  using Base::MXdlPack;
151  using Base::NXdlPack;
152 
153  using AccType = typename Base::AccType;
154  using Tuple5 = typename Base::Tuple5;
157 
// Pipeline shape constants (cf. file-header comment: GlobalPrefetchStages: 2,
// LocalPreFillStages: 1, single global buffer).
158  static constexpr index_t PrefetchStages = 2;
159  static constexpr index_t PrefillStages = 1;
160  static constexpr index_t GlobalBufferNum = 1;
161 
162  static constexpr auto ScalesPerKBlockSize =
163  KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
164 
165  //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
166  static constexpr auto ScalesPerXdlopsRun =
167  (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
168 
169  //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
170  static constexpr auto ScalesPerXdlopsRunPerThread =
171  ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
172 
// How many mx_scale_t elements are packed into one A/B scale storage element
// (ratio of the storage type's size to the unit scale type's size).
174  static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
175  static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// Guard: the per-thread KXdlPack*{M,N}XdlPack scale count must be a whole
// number of packed storage elements, otherwise the vectors below truncate.
176  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
177  "A scale pack data type too large!");
178  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
179  "B scale pack data type too large!");
// Lengths (in packed elements) of the per-thread scale vectors handed to the
// scaled-MFMA call in Run() (see mfma_scale_input_type_a/b below).
180  static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
181  static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
182 
// True iff the k-loop is long enough to enter the software-pipelined main body:
// PrefetchStages (= 2) iterations are consumed by the prologue prefetches, so a
// steady-state "hot loop" exists only when num_loop exceeds that.
183  __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
184  {
185  return num_loop > PrefetchStages;
186  }
187 
// Selects which epilogue variant Run() executes after the main loop. The main
// body retires iterations two at a time (LoopFunc(I0, I1); LoopFunc(I1, I0)),
// so the leftover tail is fully determined by the parity of num_loop.
188  __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
189  {
190  return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
191  }
192 
// Emits compile-time scheduling hints for one hot-loop iteration via
// __builtin_amdgcn_sched_group_barrier, interleaving instruction classes by
// mask: MFMA (0x008), LDS/DS write (0x200), LDS/DS read (0x100), and VMEM /
// buffer read (0x020). Stage 1 spreads the buffer loads (A, B, A-scale,
// B-scale) and DS writes evenly across the MFMA issue slots; stage 2 pairs the
// remaining DS reads with MFMAs at the computed ds_read-per-mfma rates.
// NOTE(review): several static_for loop-header lines are missing from this
// extracted listing (marked below); the bodies reference an induction variable
// `i`, so each marked spot presumably opened `static_for<...>{}([&](auto i) {` —
// verify against the repository source before editing.
193  __device__ static constexpr auto HotLoopScheduler()
194  {
195  // A/B split schedule
196  // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
197  constexpr auto num_ds_read_inst_a =
198  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
199  ? HotLoopInstList::A_LDS_Read_Inst_Num
200  : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
201  constexpr auto num_ds_read_inst_b =
202  HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
203  ? HotLoopInstList::B_LDS_Read_Inst_Num
204  : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
205 
206  constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
207  constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
208 
209  constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
210  constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
211 
// One scale load per (MXdlPack x KXdlPack) / (NXdlPack x KXdlPack) tile group.
212  constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
213  constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
214 
215  constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;
216 
217  constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
218  constexpr auto ds_read_a_issue_cycle =
219  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
220  constexpr auto ds_read_b_issue_cycle =
221  HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
222 
// Ceil-div: how many DS reads can hide behind one MFMA's cycles.
223  constexpr auto ds_read_a_mfma_rate =
224  (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
225  constexpr auto ds_read_b_mfma_rate =
226  (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
227 
228  constexpr auto num_dsread_a_mfma =
229  (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
230  constexpr auto num_dsread_b_mfma =
231  (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
232 
233  // stage 1
// MFMAs not already paired with stage-2 DS reads are distributed across the
// buffer-load issues; the first mfma_stages_more issues get one extra MFMA.
234  constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
235  constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
236  num_buffer_load_a_scale + num_buffer_load_b_scale;
237 
238  constexpr auto mfma_perstage_more =
239  math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
240  constexpr auto mfma_perstage_less =
241  math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
242 
243  constexpr auto mfma_stages_more =
244  num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
245 
246  constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
247  constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
248 
// NOTE(review): original line 249 (loop header over the A buffer loads,
// introducing `i`) is missing from this listing — verify against the repo.
250  if constexpr(i < mfma_stages_more)
251  {
252  static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
253  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
254  if constexpr(imfma < num_dswrite_per_issue_a)
255  {
256  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
257  }
258  });
259  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
260  }
261  else
262  {
263  static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
264  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265  if constexpr(imfma < num_dswrite_per_issue_a)
266  {
267  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
268  }
269  });
270  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
271  }
272  });
273 
// NOTE(review): original line 274 (loop header over the B buffer loads,
// introducing `i`) is missing from this listing — verify against the repo.
275  if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
276  {
277  static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
278  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
279  if constexpr(imfma < num_dswrite_per_issue_a)
280  {
281  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
282  }
283  });
284  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
285  }
286  else
287  {
288  static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
289  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
290  if constexpr(imfma < num_dswrite_per_issue_b)
291  {
292  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
293  }
294  });
295  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
296  }
297  });
298 
// NOTE(review): original line 299 (loop header over the A-scale loads,
// introducing `i`) is missing from this listing — verify against the repo.
300  if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
301  {
302  static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
303  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
304  });
305  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
306  }
307  else
308  {
309  static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
310  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
311  });
312  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
313  }
314  });
315 
// NOTE(review): original line 316 (loop header over the B-scale loads,
// introducing `i`) is missing from this listing — verify against the repo.
317  if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
318  num_buffer_load_a_scale) < mfma_stages_more)
319  {
320  static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
321  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
322  });
323  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
324  }
325  else
326  {
327  static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
328  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
329  });
330  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
331  }
332  });
333 
334  // stage 2
// Pair each remaining MFMA with up to ds_read_a_mfma_rate DS reads; the final
// issue takes whatever DS reads are left over.
// NOTE(review): original line 335 (loop header over num_dsread_a_mfma,
// introducing `i`) is missing from this listing — verify against the repo.
336  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
337  if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
338  ds_read_a_mfma_rate)
339  {
340  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
341  }
342  else
343  {
344  __builtin_amdgcn_sched_group_barrier(0x100,
345  num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
346  ds_read_a_mfma_rate,
347  0); // DS read
348  }
349  });
350 
// NOTE(review): original line 351 (loop header over num_dsread_b_mfma,
// introducing `i`) is missing from this listing — verify against the repo.
352  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
353  if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
354  ds_read_b_mfma_rate)
355  {
356  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
357  }
358  else
359  {
360  __builtin_amdgcn_sched_group_barrier(0x100,
361  num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
362  ds_read_b_mfma_rate,
363  0); // DS read
364  }
365  });
366  }
367 
368  template <bool HasMainLoop,
369  TailNumber TailNum,
370  typename AGridDesc,
371  typename ABlockDesc,
372  typename ABlockTransfer,
373  typename AGridBuffer,
374  typename ABlockBuffer,
375  typename ABlockTransferStep,
376  typename BGridDesc,
377  typename BBlockDesc,
378  typename BBlockTransfer,
379  typename BGridBuffer,
380  typename BBlockBuffer,
381  typename BBlockTransferStep,
382  typename CThreadBuffer,
383  typename AScaleGridBuffer,
384  typename AScaleGridDesc,
385  typename AScaleThreadTransfer,
386  typename BScaleGridBuffer,
387  typename BScaleGridDesc,
388  typename BScaleThreadTransfer>
389  __device__ void Run(
390  // ABlockCopy
391  const AGridDesc& a_grid_desc,
392  const ABlockDesc& a_block_desc,
393  ABlockTransfer& a_blockwise_copy,
394  const AGridBuffer& a_grid_buf,
395  ABlockBuffer& a_block_buf,
396  const ABlockTransferStep& a_block_copy_step,
397  // BBlockCopy
398  const BGridDesc& b_grid_desc,
399  const BBlockDesc& b_block_desc,
400  BBlockTransfer& b_blockwise_copy,
401  const BGridBuffer& b_grid_buf,
402  BBlockBuffer& b_block_buf,
403  const BBlockTransferStep& b_block_copy_step,
404  // CThread
405  CThreadBuffer& c_thread_buf,
406  // A and B scales
407  const AScaleGridDesc& a_scale_grid_desc,
408  AScaleThreadTransfer& a_scale_thread_copy,
409  const AScaleGridBuffer& a_scale_grid_buf,
410  const BScaleGridDesc& b_scale_grid_desc,
411  BScaleThreadTransfer& b_scale_thread_copy,
412  const BScaleGridBuffer& b_scale_grid_buf,
413  index_t num_loop) const
414  {
415  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
416  a_thread_desc_.GetElementSpaceSize());
417  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
418  b_thread_desc_.GetElementSpaceSize());
419 
420  auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
421  a_scale_thread_desc.GetElementSpaceSize());
422 
423  auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
424  b_scale_thread_desc.GetElementSpaceSize());
425 
426  StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
427  StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
428 
429  // Global prefetch 1
430  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
431  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
432 
433  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
434  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
435 
436  // Prefetch a_scales
437  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
438  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
439  a_scale_thread_copy.Run(a_scale_grid_desc,
440  a_scale_grid_buf,
441  a_scale_thread_desc,
442  make_tuple(m0, k0, I0),
443  a_scale_thread_bufs(I0));
444 
445  a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
446  make_multi_index(0, I1, 0));
447  });
448  a_scale_thread_copy.MoveSrcSliceWindow(
449  a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
450  });
451 
452  // restore row id and advance to the next set of scales
453  a_scale_thread_copy.MoveSrcSliceWindow(
454  a_scale_grid_desc,
455  make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
456 
457  // Prefetch b_scales
458  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
459  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
460  b_scale_thread_copy.Run(b_scale_grid_desc,
461  b_scale_grid_buf,
462  b_scale_thread_desc,
463  make_tuple(n0, k0, I0),
464  b_scale_thread_bufs(I0));
465 
466  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
467  make_multi_index(0, I1, 0));
468  });
469  b_scale_thread_copy.MoveSrcSliceWindow(
470  b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
471  });
472 
473  // restore col id and advance to the next set of scales
474  // NWaves * NPerXDL * NRepeat == NPerBlock
475  b_scale_thread_copy.MoveSrcSliceWindow(
476  b_scale_grid_desc,
477  make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
478 
479  // Local prefill 1
480  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
481  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
482 
483  // Global prefetch 2
484  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
485  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
486 
487  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
488  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
489 
490  // Local prefetch 1
491  block_sync_lds();
492  static_for<0, KRepeat, 1>{}([&](auto k) {
493  constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
494  (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
495  static_for<0, MRepeat, 1>{}([&](auto m0) {
496  static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
497  [&](auto chunk) {
498  constexpr auto a_k_step_chunk =
499  k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
500  a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
502  I0,
504  I0,
506  a_block_buf,
507  a_thread_desc_,
509  I0,
511  k,
513  a_thread_buf);
514  });
515  });
516  static_for<0, NRepeat, 1>{}([&](auto n0) {
517  // read block data in chunks to assemble correct thread vectors
518  static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
519  [&](auto chunk) {
520  constexpr auto b_k_step_chunk =
521  k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
522  b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
524  I0,
526  I0,
528  b_block_buf,
529  b_thread_desc_,
531  I0,
533  k,
535  b_thread_buf);
536  });
537  });
538  });
539 
540  // Initialize C
541  c_thread_buf.Clear();
542  __builtin_amdgcn_sched_barrier(0);
543 
544  // main body
545  if constexpr(HasMainLoop)
546  {
547  // loop over k with the step KPerBlock
548  index_t i = 0;
549  do
550  {
551  auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
552  block_sync_lds();
553 
554  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
555  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
556 
557  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
558  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
559 
560  // Prefetch a_scales
561  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
562  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
563  a_scale_thread_copy.Run(a_scale_grid_desc,
564  a_scale_grid_buf,
565  a_scale_thread_desc,
566  make_tuple(m0, k0, I0),
567  a_scale_thread_bufs(scale_mem_buf));
568 
569  a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
570  make_multi_index(0, I1, 0));
571  });
572  a_scale_thread_copy.MoveSrcSliceWindow(
573  a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
574  });
575 
576  // restore row id and advance to the next set of scales
577  a_scale_thread_copy.MoveSrcSliceWindow(
578  a_scale_grid_desc,
579  make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
580 
581  // Prefetch b_scales
582  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
583  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
584  b_scale_thread_copy.Run(b_scale_grid_desc,
585  b_scale_grid_buf,
586  b_scale_thread_desc,
587  make_tuple(n0, k0, I0),
588  b_scale_thread_bufs(scale_mem_buf));
589 
590  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
591  make_multi_index(0, I1, 0));
592  });
593  b_scale_thread_copy.MoveSrcSliceWindow(
594  b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
595  });
596 
597  // restore col id and advance to the next set of scales
598  // NWaves * NPerXDL * NRepeat == NPerBlock
599  b_scale_thread_copy.MoveSrcSliceWindow(
600  b_scale_grid_desc,
601  make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
602 
603  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
604  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
605 
606  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
607  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
608  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
609  constexpr index_t a_scale_offset =
610  a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
611  constexpr index_t b_scale_offset =
612  b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
613 
614  static_assert(0 < ScalesPerXdlopsRunPerThread,
615  "Must have at least one scale per Xdlops "
616  "per Thread.");
617 
619  a_scale_thread_vec;
621  b_scale_thread_vec;
622 
623  // Pack scale_thread_buf into scale_thread_vec
625  a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
626  a_scale_thread_bufs(
627  scale_comp_buf)[Number<a_scale_offset + s>{}];
628  });
629 
631  b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
632  b_scale_thread_bufs(
633  scale_comp_buf)[Number<b_scale_offset + s>{}];
634  });
635 
636  static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
637  static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
638  static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
639  constexpr auto kxdl = ikxdl + k0 * KXdlPack;
640 
643 
644  static_for<0, KPack, 1>{}([&](auto ik) {
645  a_thread_vec.template AsType<ComputeTypeA>()(
646  ik) = a_thread_buf
647  [Number<a_thread_desc_.CalculateOffset(
648  make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
649  b_thread_vec.template AsType<ComputeTypeB>()(
650  ik) = b_thread_buf
651  [Number<b_thread_desc_.CalculateOffset(
652  make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
653  });
654 
655  using mfma_input_type_a =
656  typename vector_type<ComputeTypeA,
657  xdlops_gemm.K1PerXdlops /
658  APackedSize>::type;
659 
660  using mfma_input_type_b =
661  typename vector_type<ComputeTypeB,
662  xdlops_gemm.K1PerXdlops /
663  BPackedSize>::type;
664 
665  using mfma_scale_input_type_a =
666  typename vector_type<AScaleDataType,
667  a_scale_thread_vec_size>::type;
668  using mfma_scale_input_type_b =
669  typename vector_type<BScaleDataType,
670  b_scale_thread_vec_size>::type;
671 
672  constexpr index_t c_offset =
673  c_thread_desc_.CalculateOffset(
674  make_tuple(m0, n0, imxdl, inxdl, 0));
675 
676  // MFMA accumulation
677  xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
678  ikxdl * NXdlPack + inxdl>(
679  a_thread_vec.template AsType<mfma_input_type_a>(),
680  a_scale_thread_vec
681  .template AsType<mfma_scale_input_type_a>(),
682  b_thread_vec.template AsType<mfma_input_type_b>(),
683  b_scale_thread_vec
684  .template AsType<mfma_scale_input_type_b>(),
685  c_thread_buf.GetVectorTypeReference(
686  Number<c_offset>{}));
687  });
688  });
689  });
690  });
691  });
692  });
693 
694  // k indexes mapping to threads for 32x32x64:
695  // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
696  // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
697  // k = 0 k = 1
698 
699  // k indexes mapping to threads for 16x16x128:
700  // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
701  // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
702  // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
703  // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
704  // k = 0 k = 1
705  block_sync_lds();
706  static_for<0, KRepeat, 1>{}([&](auto k) {
707  constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
708  (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
709  static_for<0, MRepeat, 1>{}([&](auto m0) {
710  static_for<0,
711  xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
712  1>{}([&](auto chunk) {
713  constexpr auto a_k_step_chunk =
714  k_step +
715  chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
716  a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
718  I0,
720  I0,
722  a_block_buf,
723  a_thread_desc_,
725  I0,
727  k,
729  a_thread_buf);
730  });
731  });
732  static_for<0, NRepeat, 1>{}([&](auto n0) {
733  // read block data in chunks to assemble correct thread vectors
734  static_for<0,
735  xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
736  1>{}([&](auto chunk) {
737  constexpr auto b_k_step_chunk =
738  k_step +
739  chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
740  b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
742  I0,
744  I0,
746  b_block_buf,
747  b_thread_desc_,
749  I0,
751  k,
753  b_thread_buf);
754  });
755  });
756  });
757 
758  HotLoopScheduler();
759  __builtin_amdgcn_sched_barrier(0);
760  };
761 
762  LoopFunc(I0, I1);
763  LoopFunc(I1, I0);
764 
765  i += 2;
766  } while(i < (num_loop - 2));
767  }
768 
769  // tail
770  if constexpr(TailNum == TailNumber::Even)
771  {
772  // Prefetch a_scales
773  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
774  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
775  a_scale_thread_copy.Run(a_scale_grid_desc,
776  a_scale_grid_buf,
777  a_scale_thread_desc,
778  make_tuple(m0, k0, I0),
779  a_scale_thread_bufs(I1));
780 
781  a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
782  make_multi_index(0, I1, 0));
783  });
784  a_scale_thread_copy.MoveSrcSliceWindow(
785  a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
786  });
787 
788  // Prefetch b_scales
789  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
790  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
791  b_scale_thread_copy.Run(b_scale_grid_desc,
792  b_scale_grid_buf,
793  b_scale_thread_desc,
794  make_tuple(n0, k0, I0),
795  b_scale_thread_bufs(I1));
796 
797  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
798  make_multi_index(0, I1, 0));
799  });
800  b_scale_thread_copy.MoveSrcSliceWindow(
801  b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
802  });
803 
804  block_sync_lds();
805  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
806  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
807 
808  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
809  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
810  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
811  constexpr index_t a_scale_offset =
812  a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
813  constexpr index_t b_scale_offset =
814  b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
815 
816  static_assert(0 < ScalesPerXdlopsRunPerThread,
817  "Must have at least one scale per Xdlops "
818  "per Thread.");
819 
822 
823  // Pack scale_thread_buf into scale_thread_vec
825  a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
826  a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
827  });
828 
830  b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
831  b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
832  });
833 
834  static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
835  static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
836  static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
837  constexpr auto kxdl = ikxdl + k0 * KXdlPack;
838 
841 
842  static_for<0, KPack, 1>{}([&](auto ik) {
843  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
844  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
845  make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
846  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
847  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
848  make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
849  });
850 
851  using mfma_input_type_a =
852  typename vector_type<ComputeTypeA,
853  xdlops_gemm.K1PerXdlops /
854  APackedSize>::type;
855 
856  using mfma_input_type_b =
857  typename vector_type<ComputeTypeB,
858  xdlops_gemm.K1PerXdlops /
859  BPackedSize>::type;
860 
861  using mfma_scale_input_type_a =
862  typename vector_type<AScaleDataType,
863  a_scale_thread_vec_size>::type;
864  using mfma_scale_input_type_b =
865  typename vector_type<BScaleDataType,
866  b_scale_thread_vec_size>::type;
867 
868  constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
869  make_tuple(m0, n0, imxdl, inxdl, 0));
870 
871  // MFMA accumulation
872  xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
873  ikxdl * NXdlPack + inxdl>(
874  a_thread_vec.template AsType<mfma_input_type_a>(),
875  a_scale_thread_vec
876  .template AsType<mfma_scale_input_type_a>(),
877  b_thread_vec.template AsType<mfma_input_type_b>(),
878  b_scale_thread_vec
879  .template AsType<mfma_scale_input_type_b>(),
880  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
881  });
882  });
883  });
884  });
885  });
886  });
887 
888  block_sync_lds();
889 
890  static_for<0, KRepeat, 1>{}([&](auto k) {
891  constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
892  (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
893  static_for<0, MRepeat, 1>{}([&](auto m0) {
894  static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
895  [&](auto chunk) {
896  constexpr auto a_k_step_chunk =
897  k_step +
898  chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
899  a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
901  I0,
903  I0,
905  a_block_buf,
906  a_thread_desc_,
908  I0,
910  k,
912  a_thread_buf);
913  });
914  });
915  static_for<0, NRepeat, 1>{}([&](auto n0) {
916  // read block data in chunks to assemble correct thread vectors
917  static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
918  [&](auto chunk) {
919  constexpr auto b_k_step_chunk =
920  k_step +
921  chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
922  b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
924  I0,
926  I0,
928  b_block_buf,
929  b_thread_desc_,
931  I0,
933  k,
935  b_thread_buf);
936  });
937  });
938  });
939 
940  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
941  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
942  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
943  constexpr index_t a_scale_offset =
944  a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
945  constexpr index_t b_scale_offset =
946  b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
947 
948  static_assert(0 < ScalesPerXdlopsRunPerThread,
949  "Must have at least one scale per Xdlops "
950  "per Thread.");
951 
954 
955  // Pack scale_thread_buf into scale_thread_vec
957  a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
958  a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
959  });
960 
962  b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
963  b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
964  });
965 
966  static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
967  static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
968  static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
969  constexpr auto kxdl = ikxdl + k0 * KXdlPack;
970 
973 
974  static_for<0, KPack, 1>{}([&](auto ik) {
975  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
976  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
977  make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
978  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
979  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
980  make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
981  });
982 
983  using mfma_input_type_a =
984  typename vector_type<ComputeTypeA,
985  xdlops_gemm.K1PerXdlops /
986  APackedSize>::type;
987 
988  using mfma_input_type_b =
989  typename vector_type<ComputeTypeB,
990  xdlops_gemm.K1PerXdlops /
991  BPackedSize>::type;
992 
993  using mfma_scale_input_type_a =
994  typename vector_type<AScaleDataType,
995  a_scale_thread_vec_size>::type;
996  using mfma_scale_input_type_b =
997  typename vector_type<BScaleDataType,
998  b_scale_thread_vec_size>::type;
999 
1000  constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1001  make_tuple(m0, n0, imxdl, inxdl, 0));
1002 
1003  // MFMA accumulation
1004  xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1005  ikxdl * NXdlPack + inxdl>(
1006  a_thread_vec.template AsType<mfma_input_type_a>(),
1007  a_scale_thread_vec
1008  .template AsType<mfma_scale_input_type_a>(),
1009  b_thread_vec.template AsType<mfma_input_type_b>(),
1010  b_scale_thread_vec
1011  .template AsType<mfma_scale_input_type_b>(),
1012  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1013  });
1014  });
1015  });
1016  });
1017  });
1018  });
1019  }
 // Tail epilogue for TailNum == TailNumber::Odd: drains the final K-loop
 // iteration's data with MFMA accumulation. Unlike the Even branch above
 // (which reads the scale ping-pong slot I1), this branch reads slot I0 of
 // a_scale_thread_bufs / b_scale_thread_bufs.
 // NOTE(review): this listing is a doxygen rendering with some source lines
 // elided by extraction (the a/b scale vector declarations around original
 // lines 1034-1035 and the static_for headers over the scale-vector sizes
 // at 1038/1043 are missing here); comments describe only what is visible.
1020  else if constexpr(TailNum == TailNumber::Odd)
1021  {
 // Iterate over the XDL-packed repeat tiles: MRepeat/NRepeat/KRepeat are
 // per-wave repeats, grouped by MXdlPack/NXdlPack/KXdlPack per MFMA issue.
1022  static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
1023  static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
1024  static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
 // Compile-time offsets into the thread-local scale buffers for this
 // (m0,k0) / (n0,k0) tile.
1025  constexpr index_t a_scale_offset =
1026  a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
1027  constexpr index_t b_scale_offset =
1028  b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
1029 
1030  static_assert(0 < ScalesPerXdlopsRunPerThread,
1031  "Must have at least one scale per Xdlops "
1032  "per Thread.");
1033 
1036 
1037  // Pack scale_thread_buf into scale_thread_vec
 // (reads slot I0 — the last prefetched scale buffer for the odd tail)
1039  a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1040  a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1041  });
1042 
1044  b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1045  b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1046  });
1047 
 // Inner packed-XDL loops: one MFMA issue per (ikxdl, imxdl, inxdl).
1048  static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1049  static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1050  static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
 // Absolute K index of this packed xdl within the K repeat group.
1051  constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1052 
1053  vector_type<ComputeTypeA, KPack> a_thread_vec;
1054  vector_type<ComputeTypeB, KPack> b_thread_vec;
1055 
 // Gather KPack A/B elements from the thread-register buffers into
 // contiguous vectors for the MFMA operand registers.
1056  static_for<0, KPack, 1>{}([&](auto ik) {
1057  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1058  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1059  make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1060  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1061  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1062  make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1063  });
1064 
 // Operand vector types; division by A/BPackedSize accounts for
 // sub-byte element packing (e.g. FP4 pairs per byte).
1065  using mfma_input_type_a =
1066  typename vector_type<ComputeTypeA,
1067  xdlops_gemm.K1PerXdlops /
1068  APackedSize>::type;
1069 
1070  using mfma_input_type_b =
1071  typename vector_type<ComputeTypeB,
1072  xdlops_gemm.K1PerXdlops /
1073  BPackedSize>::type;
1074 
 // Per-operand microscaling (MX) scale vector types fed alongside
 // the data operands to the scaled-MFMA instruction.
1075  using mfma_scale_input_type_a =
1076  typename vector_type<AScaleDataType,
1077  a_scale_thread_vec_size>::type;
1078  using mfma_scale_input_type_b =
1079  typename vector_type<BScaleDataType,
1080  b_scale_thread_vec_size>::type;
1081 
 // Destination offset in the accumulator (C) thread buffer.
1082  constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1083  make_tuple(m0, n0, imxdl, inxdl, 0));
1084 
1085  // MFMA accumulation
1086  xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1087  ikxdl * NXdlPack + inxdl>(
1088  a_thread_vec.template AsType<mfma_input_type_a>(),
1089  a_scale_thread_vec
1090  .template AsType<mfma_scale_input_type_a>(),
1091  b_thread_vec.template AsType<mfma_input_type_b>(),
1092  b_scale_thread_vec
1093  .template AsType<mfma_scale_input_type_b>(),
1094  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1095  });
1096  });
1097  });
1098  });
1099  });
1100  });
1101  }
1102  }
1103 
 // Compile-time packed descriptors laying out the per-thread A/B scale
 // buffers indexed as (repeat-group, K-repeat-group, scale-slot); used by
 // CalculateOffset in the MFMA loops above.
 // NOTE(review): the doxygen extraction elided the leading make_tuple
 // dimension line of each declaration (original lines 1107 and 1114);
 // presumably Number<MRepeat / MXdlPack> and Number<NRepeat / NXdlPack>
 // respectively — verify against the real header.
1104  // TODO: make this field protected when a_scale_thread_copy_ is moved
1105  // here
1106  static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1108  Number<KRepeat / KXdlPack>{},
1109  Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));
1110 
1111  // TODO: make this field protected when b_scale_thread_copy_ is moved
1112  // here
1113  static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1115  Number<KRepeat / KXdlPack>{},
1116  Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
1117 
 // Re-export the base pipeline's thread-copy objects and thread-local
 // tensor descriptors so the loops above can reference them unqualified.
1118  protected:
1119  using Base::a_thread_copy_;
1120  using Base::a_thread_desc_;
1121  using Base::b_thread_copy_;
1122  using Base::b_thread_desc_;
1123  using Base::c_thread_desc_;
1124 };
1125 
1126 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
Definition: ck.hpp:269
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:300
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: blockwise_gemm_mx_pipeline_xdlops_base.hpp:33
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:389
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10