Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp Source File

gridwise_moe_mx_gemm_bpreshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are required to make sure data accesses
28 // operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
30 // reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
74 
75 template <typename GridwiseGemm,
76  bool HasMainKBlockLoop,
77  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78  index_t MinimumOccupancy = 1,
79  TailNumber TailNum = TailNumber::Even>
80 __global__ void
81 #if CK_USE_LAUNCH_BOUNDS
82 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
83 #endif
84  // __attribute__((amdgpu_waves_per_eu(1, 1)))
85  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
86 {
87 #if defined(__gfx9__)
88  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
89  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90 
91  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
92 
93  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
94  karg.p_sorted_token_ids,
95  karg.p_sorted_expert_ids,
96  karg.p_max_token_id,
97  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
98  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
99  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
100  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
101  karg.p_ds_grid,
102  karg.p_c_grid,
103  p_shared_0,
104  p_shared_1,
105  karg,
106  karg.a_element_op,
107  karg.b_element_op,
108  karg.c_element_op);
109 #else
110  ignore = karg;
111 #endif // end of if (defined(__gfx9__))
112 }
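The two kernel entry points above are normally driven by CK's device-operation layer. As a rough orientation only, the sketch below is not part of this header (the helper name and the explicit block-size parameter are hypothetical); it only makes the grid mapping explicit on a gfx9 target: grid.x and grid.y come from GridwiseGemm::CalculateGridSize, grid.z carries the split-K batch index that SplitKBatchOffset reads from blockIdx.z, and HasMainKBlockLoop plus the tail number would be chosen with CalculateHasMainKBlockLoop / CalculateKBlockLoopTailNum before dispatch.

// Illustrative sketch only -- assumes <hip/hip_runtime.h> and this header are included.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
void launch_moe_mxgemm_example(const typename GridwiseGemm::Argument& karg,
                               int block_size, // must match the BlockSize template argument
                               hipStream_t stream)
{
    // Tile grid in x/y (optionally N-swizzled into x only); split-K batches go into z.
    const auto [grid_x, grid_y, grid_z] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
    (void)grid_z; // CalculateGridSize always returns 1 as the third component
    const dim3 grid(grid_x, grid_y, karg.KBatch);

    kernel_moe_mxgemm<GridwiseGemm, HasMainKBlockLoop, CGlobalMemoryDataOperation>
        <<<grid, dim3(block_size), 0, stream>>>(karg);
}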
113 
114 template <typename ALayout,
115  typename BLayout,
116  typename DsLayout,
117  typename CLayout,
118  typename ADataType,
119  typename AScaleDataType,
120  typename BDataType,
121  typename BScaleDataType,
122  typename AccDataType,
123  typename CShuffleDataType,
124  typename DsDataType,
125  typename CDataType,
126  typename AElementwiseOperation,
127  typename BElementwiseOperation,
128  typename CElementwiseOperation,
130  index_t ScaleBlockSize, // Scaling block size
131  index_t BlockSize, // Thread block size
132  index_t MPerBlock,
133  index_t NPerBlock,
134  index_t KPerBlock,
135  index_t AK1Value,
136  index_t BK1Value,
137  index_t MPerXdl,
138  index_t NPerXdl,
139  index_t MXdlPerWave,
140  index_t NXdlPerWave,
141  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
142  typename ABlockTransferThreadClusterArrangeOrder,
143  typename ABlockTransferSrcAccessOrder,
144  index_t ABlockTransferSrcVectorDim,
145  index_t ABlockTransferSrcScalarPerVector,
146  index_t ABlockTransferDstScalarPerVector_AK1,
147  bool AThreadTransferSrcResetCoordinateAfterRun,
148  index_t ABlockLdsExtraM,
149  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
150  typename BBlockTransferThreadClusterArrangeOrder,
151  typename BBlockTransferSrcAccessOrder,
152  index_t BBlockTransferSrcVectorDim,
153  index_t BBlockTransferSrcScalarPerVector,
154  index_t BBlockTransferDstScalarPerVector_BK1,
155  bool BThreadTransferSrcResetCoordinateAfterRun,
156  index_t BBlockLdsExtraN,
157  index_t CShuffleMXdlPerWavePerShuffle,
158  index_t CShuffleNXdlPerWavePerShuffle,
159  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
160  typename CDEShuffleBlockTransferScalarPerVectors,
163  index_t ActivationOperation = 0,
164  bool NSwizzle = false,
165  bool IsInputGemm = true,
166  bool MulRoutedWeight = true,
167  typename IndexType = index_t,
168  typename ComputeTypeA = ADataType,
169  typename ComputeTypeB = BDataType>
171 {
172  using LDSTypeA = ADataType;
173  using LDSTypeB = BDataType;
174 
175  static constexpr auto I0 = Number<0>{};
176  static constexpr auto I1 = Number<1>{};
177  static constexpr auto I2 = Number<2>{};
178  static constexpr auto I3 = Number<3>{};
179  static constexpr auto I4 = Number<4>{};
180  static constexpr auto I5 = Number<5>{};
181  static constexpr auto I6 = Number<6>{};
182  static constexpr auto I7 = Number<7>{};
183  static constexpr auto I8 = Number<8>{};
184  static constexpr auto I9 = Number<9>{};
185 
186  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
187  CDEShuffleBlockTransferScalarPerVectors{}[I0];
188  // K1 should be Number<...>
189  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
190  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
191  static constexpr auto AK1Number = Number<AK1Value>{};
192  static constexpr auto BK1Number = Number<BK1Value>{};
193 
194  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
195  static constexpr bool is_single_rate_mfma = false;
196  static constexpr auto is_scale_mfma = true;
197 
198  static constexpr index_t NumDTensor = DsDataType::Size();
199 
200  static constexpr auto MXdlPack = 2;
201  static constexpr auto NXdlPack = 2;
202  static constexpr auto KXdlPack = 2;
203 
204  //> KPack is at least the k_per_blk of selected mfma
205  //
206  // Should be a multiple of k_per_blk.
207  // TODO: Move this to blockwise pipeline base
208  // KPack in packed data types for pk A/B
209 
210  static constexpr index_t APackedSize = packed_size_v<ADataType>;
211  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
212 
213  using mfma_selector = MfmaSelector<ComputeTypeA,
214  MPerXdl,
215  NPerXdl,
216  ComputeTypeB,
218  is_scale_mfma>;
219  static constexpr index_t KPack =
221 
222  static constexpr index_t NLane = NPerXdl;
223  static constexpr index_t KLane = 64 / NLane;
224  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
225  static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
226 
227  // static constexpr index_t NumTokens = 1;
228  static constexpr index_t SortedTileSize = MPerBlock;
229 
231  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
232  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
233  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
234  "A scale pack data type too large!");
235  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
236  "B scale pack data type too large!");
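// Illustrative note (not in the original header): with the hard-coded KXdlPack = MXdlPack =
// NXdlPack = 2 above, each thread handles 2 x 2 = 4 MX scales per XDL pack. If, for example,
// AScaleDataType packs four 1-byte e8m0 exponents into a 4-byte word, scale_pack_size_a = 4 and
// the assert holds (4 % 4 == 0); a plain 1-byte scale type gives scale_pack_size_a = 1, which
// also divides 4. A scale type wider than 4 bytes would trip the static_asserts above.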
237 
238  static constexpr auto MakeDsGridPointer()
239  {
240  return generate_tuple(
241  [&](auto i) {
242  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
243 
244  return static_cast<const DDataType*>(nullptr);
245  },
247  }
248 
249  using DsGridPointer = decltype(MakeDsGridPointer());
250 
252 
253  __host__ static auto CalculateGridSize(index_t M, index_t N)
254  {
255  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
256  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
257  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
258  const index_t gridy = NSwizzle ? 1 : mblock;
259 
260  return std::make_tuple(gridx, gridy, 1);
261  }
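// Illustrative note (not in the original header), with hypothetical tile sizes MPerBlock =
// NPerBlock = 128 and a sorted problem of M = 512 (padded token rows), N = 1024:
//   mblock = ceil(512 / 128) = 4, nblock = ceil(1024 / 128) = 8
//   NSwizzle == false -> grid = (8, 4, 1): blockIdx.x walks N tiles, blockIdx.y walks M tiles
//   NSwizzle == true  -> grid = (32, 1, 1): both tile indices are folded into blockIdx.x and
//                        unpicked per expert inside the kernel (see the NSwizzle branch in Run)
// The split-K dimension is added separately as gridDim.z by the caller.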
262 
263  __host__ static auto CalculateMPadded(index_t M)
264  {
265  return math::integer_least_multiple(M, MPerBlock);
266  }
267 
268  __host__ static auto CalculateNPadded(index_t N)
269  {
270  return math::integer_least_multiple(N, NPerBlock);
271  }
272 
273  __host__ static auto CalculateBN0Shuffled(index_t N)
274  {
275  return math::integer_divide_ceil(N, NLane);
276  }
277  __host__ static auto CalculateBK0Shuffled(index_t K)
278  {
280  }
281 
282  __host__ static auto CalculateKPadded(index_t K)
283  {
284  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
285  }
286 
287  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
288  {
289  auto K_t = K_Batch * KPerBlock;
290  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
291  }
292 
293  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
294  {
295  auto K_t = K_Batch * KPerBlock;
296  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
297  }
298 
299  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
300  {
301  auto K_t = K_Batch * KPerBlock;
302  return (K + K_t - 1) / K_t * KPerBlock;
303  }
304 
305  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
306  {
307  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
308  auto K_t = K_Batch * KReadVec;
309  return (K + K_t - 1) / K_t * KReadVec;
310  }
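// Illustrative note (not in the original header), with hypothetical AK1 = BK1 = 16,
// KPerBlock = 256, K = 4096 and KBatch = 4:
//   KReadVec = lcm(16, 16) = 16, so CalculateKRead  = ceil(4096 / (4 * 16))  * 16  = 1024
//   CalculateKPadded(K, KBatch)                     = ceil(4096 / (4 * 256)) * 256 = 1024
//   CalculateAK0Padded = ceil(4096 / (4 * 256)) * (256 / 16) = 64
// i.e. each of the 4 split-K batches consumes a 1024-wide, block-aligned slice of K; the last
// batch's K is trimmed back in SplitKBatchOffset when K does not divide evenly.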
311 
312  __host__ static auto CalculateMBlock(index_t M)
313  {
314  return math::integer_divide_ceil(M, MPerBlock);
315  }
316 
317  __host__ static auto CalculateNBlock(index_t N)
318  {
319  return math::integer_divide_ceil(N, NPerBlock);
320  }
321 
322  template <index_t MNXdlPerWave,
323  index_t MNWaves,
324  index_t MNXdlPack,
325  index_t MNPerXdl,
326  bool IsXor,
327  typename TileDesc_K0_MN_K1>
328  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
329  {
330  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
331  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
332  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
333 
334  if constexpr(IsXor)
335  {
336  constexpr auto permuted_desc = transform_tensor_descriptor(
337  TileDesc_K0_MN_K1{},
342 
344  permuted_desc,
345  make_tuple(
348  Number<MNWaves>{},
350  Number<MNPerXdl>{}))),
353  }
354  else
355  {
357  TileDesc_K0_MN_K1{},
358  make_tuple(
361  Number<MNWaves>{},
363  Number<MNPerXdl>{}))),
366  }
367  }
368 
369  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
370  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
371  {
372  const auto a_grid_desc_mraw_kraw = [&]() {
373  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
375  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
376  }
377  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
380  }
381  }();
382 
384 
385  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386  GemmSpec == GemmSpecialization::MNKPadding)
387  {
388  // pad both M and K
389  const auto a_grid_desc_m_k =
390  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
392  make_right_pad_transform(K, KPad - K)),
395 
396  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397  a_grid_desc_m_k,
402 
403  return a_grid_desc_ak0_m_ak1;
404  }
405  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406  GemmSpec == GemmSpecialization::MNPadding)
407  {
408  // pad M, but not K
409  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
410  a_grid_desc_mraw_kraw,
412  make_right_pad_transform(M, MPad - M)),
415 
416  return a_grid_desc_ak0_m_ak1;
417  }
418  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419  GemmSpec == GemmSpecialization::NKPadding)
420  {
421  // pad K, but not M
422  const auto a_grid_desc_m_k = transform_tensor_descriptor(
423  a_grid_desc_mraw_kraw,
427 
428  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
429  a_grid_desc_m_k,
434 
435  return a_grid_desc_ak0_m_ak1;
436  }
437  else
438  {
439  // not pad M or K
440  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
441  a_grid_desc_mraw_kraw,
442  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
446 
447  const auto a_grid_desc_permuted = transform_tensor_descriptor(
448  a_grid_desc_ak0_m_ak1,
451  make_pass_through_transform(AK1Value)),
454 
455  const auto a_grid_desc = transform_tensor_descriptor(
456  a_grid_desc_permuted,
457  make_tuple(
460  make_pass_through_transform(AK1Value)),
463 
464  return a_grid_desc;
465  }
466  }
467 
468  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
469  {
470  constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
472  make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
473  }
474 
475  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
476  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
477  {
478  const auto b_grid_desc_nraw_kraw = [&]() {
480  {
481  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
482  }
484  {
485  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
486  }
487  }();
488 
490 
491  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
492  GemmSpec != GemmSpecialization::Default),
493  "pk_i4_t does not support padding");
494  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
495  GemmSpec != GemmSpecialization::Default),
496  "f4x2_pk_t does not support padding");
497 
498  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
499  GemmSpec == GemmSpecialization::MNKPadding)
500  {
501  // pad both N and K
502  const auto b_grid_desc_n_k =
503  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
505  make_right_pad_transform(K, KPad - K)),
508 
509  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
510  b_grid_desc_n_k,
515 
516  return b_grid_desc_bk0_n_bk1;
517  }
518  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
519  GemmSpec == GemmSpecialization::MNPadding)
520  {
521  // pad N, but not K
522  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
523  b_grid_desc_nraw_kraw,
525  make_right_pad_transform(N, NPad - N)),
528 
529  return b_grid_desc_bk0_n_bk1;
530  }
531  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
532  GemmSpec == GemmSpecialization::MKPadding)
533  {
534  // pad K, but not N
535  const auto b_grid_desc_n_k = transform_tensor_descriptor(
536  b_grid_desc_nraw_kraw,
540 
541  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
542  b_grid_desc_n_k,
547 
548  return b_grid_desc_bk0_n_bk1;
549  }
550  else
551  {
552  // not pad N or K
553  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
554  b_grid_desc_nraw_kraw,
555  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
559 
560  const auto b_grid_desc_permuted = transform_tensor_descriptor(
561  b_grid_desc_bk0_n_bk1,
564  make_pass_through_transform(BK1Value)),
567 
568  const auto b_grid_desc = transform_tensor_descriptor(
569  b_grid_desc_permuted,
570  make_tuple(
573  make_pass_through_transform(BK1Value)),
576 
577  return b_grid_desc;
578  }
579  }
580 
581  template <typename ABlockDesc_AK0_M_AK1>
582  __host__ __device__ static constexpr auto
583  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
584  {
585  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
586 
587  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
588  ABlockDesc_AK0_M_AK1{});
589  }
590 
591  template <typename BBlockDesc_BK0_N_BK1>
592  __host__ __device__ static constexpr auto
593  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
594  {
595  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
596 
597  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
598  BBlockDesc_BK0_N_BK1{});
599  }
600 
601  template <typename ELayout>
602  __host__ __device__ static auto MakeCGridDescriptor_M_N(
603  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
604  {
605  const auto c_grid_desc_mraw_nraw = [&]() {
607  {
608  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
609  }
611  {
612  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
613  }
614  }();
615 
616  // pad M and N
617  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
619  make_right_pad_transform(N, NPad - N)),
622  }
623 
624  template <typename DLayout>
625  __host__ __device__ static auto
627  {
628  const auto c_grid_desc_mraw_nraw = [&]() {
630  {
631  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
632  }
634  {
635  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
636  }
637  }();
638 
639  // pad M and N
640  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
642  make_right_pad_transform(N, NPad - N)),
645  }
646 
647  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
648  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
649  {
650  return generate_tuple(
651  [&](auto i) {
652  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
653  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
654  },
656  }
657 
658  template <typename DsGridDesc>
660  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
661  {
662  return generate_tuple(
663  [&](auto i) {
665  ds_grid_desc_m_n[i], MBlock, NBlock);
666  },
668  }
669 
670  struct Problem
671  {
672  __host__ Problem(index_t NumTokens_,
673  index_t TopK_,
674  index_t M_,
675  index_t N_,
676  index_t K_,
677  index_t StrideA_,
678  index_t StrideScaleA_,
679  index_t StrideB_,
680  index_t StrideScaleB_,
681  std::array<index_t, NumDTensor> StrideDs_,
682  index_t StrideC_,
683  index_t KBatch_)
684  : NumTokens{NumTokens_},
685  TopK{TopK_},
686  M{M_},
687  N{N_},
688  K{K_},
689  StrideA{StrideA_},
690  StrideScaleA{StrideScaleA_},
691  StrideB{StrideB_},
692  StrideScaleB{StrideScaleB_},
693  StrideDs{StrideDs_},
694  StrideC{StrideC_},
695  KBatch{KBatch_},
698  KRead{CalculateKRead(K_, KBatch_)},
699  KPadded{CalculateKPadded(K_, KBatch_)},
700  AK0{CalculateAK0Padded(K_, KBatch_)},
701  BK0{CalculateBK0Padded(K_, KBatch_)},
702  MBlock{CalculateMBlock(M_)},
703  NBlock{CalculateNBlock(N_)},
706  {
707  }
708 
709  __host__ void Print() const
710  {
711  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
712  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
713  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
714  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
715  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
716  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
717  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
718  << ", " << "NBlock: " << NBlock << "}" << std::endl;
719  }
720 
730  std::array<index_t, NumDTensor> StrideDs;
741  // FOR PRESHUFFLE ONLY
744  };
745 
746  // Argument
747  struct Argument : public Problem
748  {
749  __host__ Argument(const index_t* p_sorted_token_ids_,
750  const index_t* p_sorted_expert_ids_,
751  const index_t* p_max_token_id_,
752  const ADataType* p_a_grid_,
753  const AScaleDataType* p_a_scale_grid_,
754  const BDataType* p_b_grid_,
755  const BScaleDataType* p_b_scale_grid_,
756  std::array<const void*, NumDTensor> p_ds_grid_,
757  CDataType* p_c_grid_,
758  index_t NumTokens_,
759  index_t TopK_,
760  index_t M_,
761  index_t N_,
762  index_t K_,
763  index_t StrideA_,
764  index_t StrideScaleA_,
765  index_t StrideB_,
766  index_t StrideScaleB_,
767  std::array<index_t, NumDTensor> StrideDs_,
768  index_t StrideC_,
769  index_t k_batch_,
770  AElementwiseOperation a_element_op_,
771  BElementwiseOperation b_element_op_,
772  CElementwiseOperation c_element_op_)
773  : Problem{NumTokens_,
774  TopK_,
775  M_,
776  N_,
777  K_ / APackedSize,
778  StrideA_ / APackedSize,
779  StrideScaleA_,
780  StrideB_ / BPackedSize,
781  StrideScaleB_,
782  StrideDs_,
783  StrideC_,
784  k_batch_},
785  p_sorted_token_ids{p_sorted_token_ids_},
786  p_sorted_expert_ids{p_sorted_expert_ids_},
787  p_max_token_id{p_max_token_id_},
788  p_a_grid{p_a_grid_},
789  p_a_scale_grid{p_a_scale_grid_},
790  p_b_grid{p_b_grid_},
791  p_b_scale_grid{p_b_scale_grid_},
792  p_ds_grid{},
793  p_c_grid{p_c_grid_},
794  a_element_op{a_element_op_},
795  b_element_op{b_element_op_},
796  c_element_op{c_element_op_}
797  {
798 
799  // populate pointer, desc for Ds
800  static_for<0, NumDTensor, 1>{}([&](auto i) {
801  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
802 
803  // D pointer
804  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
805  });
806  }
807 
808  const index_t* p_sorted_token_ids;
809  const index_t* p_sorted_expert_ids;
810  const index_t* p_max_token_id;
811  const ADataType* p_a_grid;
812  const AScaleDataType* p_a_scale_grid;
813  const BDataType* p_b_grid;
814  const BScaleDataType* p_b_scale_grid;
815  DsGridPointer p_ds_grid;
816  CDataType* p_c_grid;
817 
818  const AElementwiseOperation a_element_op;
819  const BElementwiseOperation b_element_op;
820  const CElementwiseOperation c_element_op;
821  };
822 
823  struct SplitKBatchOffset
824  {
825  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
826  {
827  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
828  {
829  a_k_split_offset = k_id * karg.KRead;
830  }
831  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
832  {
833  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
834  }
835 
836  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
837  {
838  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
839  }
840  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
841  {
842  // KPack * NLane * KLane * K0 * N0
843  b_k_split_offset = k_id * karg.KRead * NPerXdl;
844  }
845 
846  // Calculate A scale offset
847  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
848  MPerXdl / scale_pack_size_a;
849 
850  // Calculate B scale offset
851  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
852  NPerXdl / scale_pack_size_b;
853 
854  if(k_id < karg.KBatch - 1)
855  {
856  karg.K = karg.KRead;
857  }
858  else
859  {
860  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
861  }
862  }
863 
864  index_t a_k_split_offset;
865  index_t b_k_split_offset;
866  index_t a_scale_k_split_offset;
867  index_t b_scale_k_split_offset;
868  };
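// Illustrative note (not in the original header), continuing the hypothetical numbers used above
// (K = 4096, KBatch = 4, KRead = 1024, row-major A, preshuffled column-major B, NPerXdl = 16,
// ScaleBlockSize = 32, APackedSize = 2, MXdlPack = 2, MPerXdl = 16, scale_pack_size_a = 4):
//   k_id = 2  ->  a_k_split_offset = 2 * 1024 = 2048 elements into the packed A rows
//                 b_k_split_offset = 2 * 1024 * 16 (preshuffled B advances by NPerXdl per K)
//                 a_scale_k_split_offset = 2 * 1024 / (32 / 2) * 2 * 16 / 4 = 1024
//   k_id = 3 (last batch) -> karg.K is shrunk to 4096 - 1024 * 3 = 1024, so the tail batch only
//   loops over the remaining K.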
869 
870  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
871  {
872  // A matrix in LDS memory, dst of blockwise copy
873  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
874  {
875  // contiguous in LDS
879  }
880  // The xor-based tensor transformation requires more VGPRs than necessary and can cause
881  // register spills in some cases.
883  {
884  constexpr auto a_lds_block_desc =
887 
888  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
889  a_lds_block_desc,
895 
896  return a_lds_block_desc_permuted;
897  }
898  else // ColumnMajor A
899  {
900  // The kfold and mpair dimensions are not always required.
901  // More dimensions in the merge_transform increase the difficulty of generating
902  // immediate-argument offsets for the compiler.
903  constexpr auto WaveSize = 64;
904  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
905  constexpr auto M1 = MPerBlock / M0;
906 
907  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
908  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
909  constexpr auto KThreadRead = WaveSize / MPerXdl;
910  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
911 
912  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
913  ? 1
914  : 128 / (AK1Number * M0 * sizeof(ADataType));
915  constexpr auto KThreadReadPerm =
916  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
917  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
918  : KThreadRead;
919 
920  // 1 <= mpair <= M0
921  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
922  ? 1
923  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
924  ? M0
925  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
926 
927  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
931  Number<kfold * M0 / mpair>{},
932  Number<mpair>{},
933  AK1Number));
934 
935  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
936  a_lds_block_desc,
937  make_tuple(
941  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
944  make_tuple(
946  make_tuple(
948 
949  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
950  a_lds_block_desc_permuted,
951  make_tuple(
959  Sequence<1>{},
960  Sequence<2>{},
961  Sequence<3>{},
962  Sequence<4>{},
963  Sequence<5>{}),
965  Sequence<2>{},
966  Sequence<0, 3>{},
967  Sequence<4, 5>{},
968  Sequence<6>{},
969  Sequence<7>{}));
970 
971  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
972  a_lds_block_desc_unmerged,
975  Number<KThreadWrite / kfold / KThreadReadPerm>{},
976  Number<kfold>{},
983 
984  return a_lds_block_desc_ak0_m_ak1;
985  }
986  }
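// Illustrative note (not in the original header) on the kfold / mpair math above, using
// hypothetical values AK1 = 16, a 1-byte ADataType, M0 = 4 (cluster M length), MPerXdl = 16:
//   AK1 * M0 * sizeof(ADataType) = 64 <= 128       ->  kfold = 128 / 64 = 2
//   AK1 * MPerXdl * sizeof(ADataType) = 256 > 128  ->  mpair = 1
// Roughly, the K0 dimension is folded so that contiguous LDS accesses still cover the 128-byte
// granularity used in these formulas, while rows are not paired once a single MPerXdl row
// already exceeds 128 bytes.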
987 
988  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
989  {
990  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
992  I1,
994  Number<KRepeat>{},
995  Number<BK1Value>{}));
996  }
997 
999  {
1000  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1001 
1002  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1004  make_tuple(I1,
1006  I1,
1008 
1009  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1010  }
1011 
1014  BlkGemmPipelineVer,
1015  BlkGemmPipeSched,
1016  BlockSize,
1017  ScaleBlockSize,
1018  ADataType,
1019  AScaleDataType,
1020  BDataType,
1021  BScaleDataType,
1022  ComputeTypeA,
1023  AccDataType,
1030  ABlockTransferSrcScalarPerVector,
1031  BBlockTransferSrcScalarPerVector,
1032  MPerBlock,
1033  NPerBlock,
1034  KPerBlock,
1035  MPerXdl,
1036  NPerXdl,
1037  MXdlPerWave,
1038  NXdlPerWave,
1039  KPack,
1040  IsInputGemm>())>;
1041 
1042  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1043  {
1044  // LDS allocation for A and B: be careful of alignment
1045  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1046  // lds max alignment
1047  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1048 
1049  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1050  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1051 
1052  // LDS allocation for C shuffle in LDS
1053  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1055 
1056  constexpr auto c_block_size =
1057  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1058 
1059  return math::max(a_block_space_size_aligned * sizeof(ADataType),
1060  c_block_size * sizeof(CShuffleDataType));
1061  }
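// Illustrative note (not in the original header): only A and the C-shuffle staging buffer live
// in LDS here; the preshuffled B tile is loaded straight into VGPRs (see the static Vgpr buffer
// in Run), so the allocation is max(A-tile bytes, C-shuffle bytes) rather than their sum. With
// hypothetical MPerBlock = 128, KPerBlock = 256 and a 1-byte ADataType, the A tile needs
// 128 * 256 = 32 KiB; a 64 x 64 C-shuffle tile of 4-byte CShuffleDataType needs 16 KiB, so the
// kernel would request 32 KiB of LDS per workgroup.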
1062 
1063  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1064  __host__ static constexpr bool CheckValidity(const Argument& karg)
1065  {
1066  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1067  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1068  "Invalid tuning param!");
1069 
1070  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1071  "KPerBlock should be multiple of ScaleBlockSize");
1072 
1073  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1078  {
1079  if(!(karg.M % MPerBlock == 0))
1080  {
1081  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1082  {
1083  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1084  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1085  << std::endl;
1086  }
1087  return false;
1088  }
1089  }
1090 
1091  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1096  {
1097  if(!(karg.N % NPerBlock == 0))
1098  {
1099  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1100  {
1101  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1102  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1103  << std::endl;
1104  }
1105  return false;
1106  }
1107  }
1108 
1109  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1113  {
1114  auto K_t = karg.KBatch * KPerBlock;
1115  if(!(karg.K % K_t == 0))
1116  {
1117  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1118  {
1119  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1120  << karg.K << " " << __FILE__ << ":" << __LINE__
1121  << ", in function: " << __func__ << std::endl;
1122  }
1123  return false;
1124  }
1125  }
1126  else
1127  {
1128  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1129  auto K_t = karg.KBatch * KReadVec;
1130  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1131  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1132  {
1133  return false;
1134  }
1135  }
1136 
1138  {
1139  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1140  {
1141  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1142  {
1143  std::cout << "Arg K (" << karg.K
1144  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1145  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1146  << __LINE__ << ", in function: " << __func__ << std::endl;
1147  }
1148  return false;
1149  }
1150  }
1151  else
1152  {
1153  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1154  {
1155  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1156  {
1157  std::cout << "Arg M (" << karg.M
1158  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1159  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1160  << __LINE__ << ", in function: " << __func__ << std::endl;
1161  }
1162  return false;
1163  }
1164  }
1165 
1167  {
1168  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1169  {
1170  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1171  {
1172  std::cout << "Arg N (" << karg.N
1173  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1174  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1175  << __LINE__ << ", in function: " << __func__ << std::endl;
1176  }
1177  return false;
1178  }
1179  }
1180  else
1181  {
1182  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1183  {
1184  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1185  {
1186  std::cout << "Arg K (" << karg.K
1187  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1188  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1189  << __LINE__ << ", in function: " << __func__ << std::endl;
1190  }
1191  return false;
1192  }
1193  }
1194 
1196  {
1198  {
1199  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1200  {
1201  std::cout << "Arg N (" << karg.N
1202  << ") value is not a multiple of "
1203  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1205  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1206  << std::endl;
1207  }
1208  return false;
1209  }
1210  }
1211  else
1212  {
1214  {
1215  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1216  {
1217  std::cout << "Arg M (" << karg.M
1218  << ") value is not a multiple of "
1219  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1221  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1222  << std::endl;
1223 
1224  }
1225  return false;
1226  }
1227  }
1228 
1229  // check gridwise gemm pipeline
1230 #if 0
1231  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1232 
1233  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1234  {
1235  return false;
1236  }
1237 #endif
1238  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1239  return true;
1240  }
1241 
1242  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1243  {
1244  const index_t num_loop = K / KPerBlock;
1245 
1246  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1247  }
1248 
1249  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1250  {
1251  const index_t num_loop = K / KPerBlock;
1252 
1253  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1254  }
1255 
1256  template <typename CGridDesc>
1257  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1258  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1259  {
1260  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1261  c_grid_desc_m_n,
1266 
1267  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1268  }
1269 
1270  // return block_id to C matrix tile idx (m0, n0) mapping
1271  // if arch = gfx942
1272  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1273  // NPerBlock>;
1274 
1275 #if 0
1276  template <bool HasMainKBlockLoop,
1277  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1278  TailNumber TailNum = TailNumber::Odd>
1279  __device__ static void Run(const index_t* p_sorted_token_ids,
1280  const index_t* p_sorted_expert_ids,
1281  const index_t* p_max_token_id,
1282  const ADataType* p_a_grid,
1283  const AScaleDataType* p_a_scale_grid,
1284  const BDataType* p_b_grid,
1285  const BScaleDataType* p_b_scale_grid,
1286  DsGridPointer& p_ds_grid,
1287  CDataType* p_c_grid,
1288  void* p_shared,
1289  const Problem& problem,
1290  AElementwiseOperation a_element_op,
1291  BElementwiseOperation b_element_op,
1292  CElementwiseOperation c_element_op)
1293  {
1294  ignore = b_element_op;
1295  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1296  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1297  problem.MPadded,
1298  problem.K,
1299  problem.KPadded,
1300  problem.StrideA,
1301  problem.AK0);
1302  const auto b_grid_desc_bpreshuffled =
1303  MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
1304  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1305  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1306  problem.MPadded,
1307  problem.N,
1308  problem.NPadded,
1309  problem.StrideC);
1310 
1311  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1312  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
1313  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1314  (KXdlPack * 64 / MPerXdl),
1315  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1316 
1317  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1318  make_tuple(problem.N / (NXdlPack * NPerXdl),
1319  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1320  (KXdlPack * 64 / NPerXdl),
1321  64 * KXdlPack * NXdlPack / scale_pack_size_b));
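// Illustrative note (not in the original header): the MX scale tensors carry one exponent per
// ScaleBlockSize input values along K, pre-arranged so that the 64 lanes of a wave can fetch the
// KXdlPack x MXdlPack (resp. NXdlPack) scales they need with one packed read; hence the
// descriptors above are (M- or N-tile, K-scale-tile, 64 * KXdlPack * {M,N}XdlPack / pack_size).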
1322 
1323  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1325  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1326  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1327  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1328  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1329  if(expert_block_id * MPerBlock >= max_token_id)
1330  return;
1331  const index_t expert_id =
1332  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1333 
1334  const auto block_mn = [&]() -> std::pair<int, int> {
1335  if constexpr(NSwizzle)
1336  {
1337  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1338  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1339  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1340  const index_t expert_swizzle =
1341  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1342  const index_t bid_new = blockIdx.x - prefix_block;
1343  const index_t nid = __builtin_amdgcn_readfirstlane(
1344  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1345  const index_t mid =
1346  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1347  return {nid, mid};
1348  }
1349  else
1350  {
1351  return {blockIdx.x, blockIdx.y};
1352  }
1353  }();
1354 
1355  const index_t block_n_id = block_mn.first;
1356  const index_t block_m_id = block_mn.second;
1357  const index_t token0 =
1358  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1359 
1360  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1361  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1362  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1363  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1364  constexpr auto AKThreads = AK0Threads * AK1Threads;
1365  constexpr auto AMRepeats = MPerBlock / AMThreads;
1366  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1367 
1368  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1369  return;
1370  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1371  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1372  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1373  index_t token_offset = fused_token & 0xffffff;
1374  if constexpr(!IsInputGemm)
1375  {
1376  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1377  }
1378  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
1379  });
1380  const index_t expert_stride =
1381  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1382  const index_t expert_scale_stride =
1383  __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
1384  math::integer_divide_ceil(problem.K, ScaleBlockSize));
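// Illustrative note (not in the original header): for the gate/up (input) GEMM each expert's
// weight slab stores both the gate and the up projection, hence the factor of 2 in
// expert_stride and expert_scale_stride; p_sorted_expert_ids selects which expert's weight and
// scale slabs this block reads, and the "up" half is addressed at expert_stride / 2 further on
// (see the IsInputGemm branch below).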
1385 
1386  // N0, K0, Blocksize*KPack
1387  const index_t n_block_data_idx_on_grid =
1388  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1389 
1390  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1391  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1392  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1393  p_b_grid + expert_id * expert_stride / BPackedSize,
1394  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1395 
1396  // A, B scale buffer
1397  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1398  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1399  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1400  p_b_scale_grid + expert_id * expert_scale_stride,
1401  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1402 
1403  // A matrix in LDS memory, dst of blockwise copy
1404  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1405 
1406  // B matrix in LDS memory, dst of blockwise copy
1407  // dummy
1408  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1409  // A matrix blockwise copy
1410  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1412  AElementwiseOperation,
1415  Sequence<AK0Number, MPerBlock, AK1Number>,
1416  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1417  ABlockTransferThreadClusterArrangeOrder,
1418  ADataType,
1419  LDSTypeA,
1420  decltype(a_grid_desc_ak0_m_ak1),
1421  decltype(a_block_desc_ak0_m_ak1),
1422  ABlockTransferSrcAccessOrder,
1423  Sequence<0, 1, 2>,
1424  ABlockTransferSrcVectorDim,
1425  2,
1426  ABlockTransferSrcScalarPerVector,
1427  ABlockTransferDstScalarPerVector_AK1,
1428  1,
1429  1,
1430  AThreadTransferSrcResetCoordinateAfterRun,
1431  true,
1432  IndexType,
1433  1,
1434  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1435  make_multi_index(0, 0, 0),
1436  a_element_op,
1437  a_block_desc_ak0_m_ak1,
1438  make_multi_index(0, 0, 0),
1440  gather_offsets);
1441 
1442  // Thread-wise copy
1443  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1444  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1445  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1446 
1447  auto b_blockwise_copy =
1448  ThreadwiseTensorSliceTransfer_v2<BDataType,
1449  BDataType,
1450  decltype(b_grid_desc_bpreshuffled),
1451  decltype(b_block_desc_bk0_n_bk1),
1452  Sequence<Number<NXdlPerWave / NXdlPack>{},
1453  I1,
1454  Number<NXdlPack>{},
1455  Number<KRepeat>{},
1456  Number<BK1Value>{}>,
1457  Sequence<1, 2, 0, 3>,
1458  4,
1459  BBlockTransferSrcScalarPerVector,
1460  BThreadTransferSrcResetCoordinateAfterRun,
1461  true>(
1462  b_grid_desc_bpreshuffled,
1463  make_multi_index(n_block_data_idx_on_grid,
1465  0,
1466  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1467 
1468  // LDS allocation for A and B: be careful of alignment
1469  // Cast after lds
1470  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1471  static_cast<LDSTypeA*>(p_shared),
1472  a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
1473 
1474  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1475  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1476 
1477  // Blockwise GEMM pipeline
1478  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1479  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1480  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1481  decltype(c_thread_buf) c_thread_buf_up;
1482 
1483  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1484  float,
1485  c_thread_buf.num_of_v_,
1486  c_thread_buf.s_per_v,
1487  true>
1488  c_thread_buf_fp32;
1489 
1490  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1491  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1492  KPerBlock);
1493 
1494  // a and b scale processing
1495  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1496  const auto waveId_m = wave_idx[I0];
1497  const auto waveId_n = wave_idx[I1];
1498 
1499  static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1500 
1501  auto thread_offset_shuffled =
1502  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1503 
1504  auto a_thread_offset_m = waveId_m;
1505 
1506  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1507  AScaleDataType,
1508  AScaleDataType,
1509  decltype(a_scale_grid_desc_am_ak),
1510  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1511  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1512  Sequence<0, 1, 2>, // DimAccessOrder
1513  2, // SrcVectorDim
1514  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1515  1, // SrcScalarStrideInVector
1516  true>(a_scale_grid_desc_am_ak,
1517  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1518  0,
1519  thread_offset_shuffled / scale_pack_size_a));
1520 
1521  // B scale load
1522  auto b_thread_offset_n = waveId_n;
1523 
1524  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1525  BScaleDataType,
1526  BScaleDataType,
1527  decltype(b_scale_grid_desc_bn_ak),
1528  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1529  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1530  Sequence<0, 1, 2>, // DimAccessOrder
1531  2, // SrcVectorDim
1532  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1533  1, // SrcScalarStrideInVector
1534  true>(b_scale_grid_desc_bn_ak,
1535  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1536  0,
1537  thread_offset_shuffled / scale_pack_size_b));
1538 
1539  if constexpr(IsInputGemm)
1540  {
1541  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1542  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1543  p_b_grid_up + expert_id * expert_stride / BPackedSize,
1544  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1545  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1546  BDataType,
1547  BDataType,
1548  decltype(b_grid_desc_bpreshuffled),
1549  decltype(b_block_desc_bk0_n_bk1),
1550  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
1551  Sequence<1, 2, 0, 3>,
1552  3,
1553  BBlockTransferSrcScalarPerVector,
1554  BThreadTransferSrcResetCoordinateAfterRun,
1555  true>(b_grid_desc_bpreshuffled,
1556  make_multi_index(n_block_data_idx_on_grid,
1558  0,
1559  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1560  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
1561  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1562  p_b_scale_grid_up + expert_id * expert_scale_stride,
1563  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1564  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1565  BScaleDataType,
1566  BScaleDataType,
1567  decltype(b_scale_grid_desc_bn_ak),
1568  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1569  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1570  Sequence<0, 1, 2>, // DimAccessOrder
1571  2, // SrcVectorDim
1572  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1573  1, // SrcScalarStrideInVector
1574  true>(
1575  b_scale_grid_desc_bn_ak,
1576  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1577  0,
1578  thread_offset_shuffled / scale_pack_size_b));
1579 
1580  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1581  a_grid_desc_ak0_m_ak1,
1582  a_block_desc_ak0_m_ak1,
1583  a_blockwise_copy,
1584  a_grid_buf,
1585  a_block_buf,
1586  a_block_slice_copy_step,
1587  b_grid_desc_bpreshuffled,
1588  b_block_desc_bk0_n_bk1,
1589  b_blockwise_copy,
1590  b_blockwise_copy_up,
1591  b_grid_buf,
1592  b_grid_buf_up,
1593  b_block_buf,
1594  b_block_slice_copy_step,
1595  c_thread_buf,
1596  c_thread_buf_up,
1597  a_scale_grid_desc_am_ak,
1598  a_scale_thread_copy,
1599  a_scale_grid_buf,
1600  b_scale_grid_desc_bn_ak,
1601  b_scale_thread_copy,
1602  b_scale_thread_copy_up,
1603  b_scale_grid_buf,
1604  b_scale_grid_buf_up,
1605  num_k_block_main_loop);
1606  }
1607  else
1608  {
1609  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1610  a_grid_desc_ak0_m_ak1,
1611  a_block_desc_ak0_m_ak1,
1612  a_blockwise_copy,
1613  a_grid_buf,
1614  a_block_buf,
1615  a_block_slice_copy_step,
1616  b_grid_desc_bpreshuffled,
1617  b_block_desc_bk0_n_bk1,
1618  b_blockwise_copy,
1619  b_grid_buf,
1620  b_block_buf,
1621  b_block_slice_copy_step,
1622  c_thread_buf,
1623  a_scale_grid_desc_am_ak,
1624  a_scale_thread_copy,
1625  a_scale_grid_buf,
1626  b_scale_grid_desc_bn_ak,
1627  b_scale_thread_copy,
1628  b_scale_grid_buf,
1629  num_k_block_main_loop);
1630  }
1631 
1632  // shuffle C and write out
1633  {
1634  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1635  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1636  "wrong!");
1637 
1638  // TODO: hacky, fix it!
1639  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1640  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1641 
1642  // TODO: hacky, fix it!
1643  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1644  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1645  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1646 
1647  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1648  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1649  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1650  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1651  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1652  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1653  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1654  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1655 
1656  // mul scales
1657  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1658  static_assert(M4 == 4);
1659  const index_t m1 = get_warp_local_1d_id() / NWave;
1660  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1661 
1662  vector_type<float, 4> topk_weights; // for gemm2 only
1663  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1664  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1665  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1666  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1667  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1668  if constexpr(MulRoutedWeight)
1669  {
1670  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1671  p_ds_grid[I2] + m_pos);
1672  }
1673  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1674  constexpr index_t c_offset =
1675  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1676  make_tuple(m0, n0, m2 * M4 + m4));
1677  constexpr auto cidx = Number<c_offset>{};
1678 
1679  if constexpr(IsInputGemm) // gu fusion
1680  {
1681  if constexpr(ActivationOperation == Activation::silu_and_mul)
1682  {
1683  float gate = c_thread_buf[cidx];
1684  float up = c_thread_buf_up[cidx];
1685  if constexpr(MulRoutedWeight)
1686  {
1687  gate = gate * topk_weights.AsType<float>()[m4];
1688  up = up * topk_weights.AsType<float>()[m4];
1689  }
1690  tensor_operation::element_wise::Silu{}(gate, gate);
1691  c_thread_buf_fp32(cidx) = gate * up;
1692  }
1693  else if(ActivationOperation == Activation::gelu_and_mul)
1694  {
1695  float gate = c_thread_buf[cidx];
1696  float up = c_thread_buf_up[cidx];
1697  if constexpr(MulRoutedWeight)
1698  {
1699  gate = gate * topk_weights.AsType<float>()[m4];
1700  up = up * topk_weights.AsType<float>()[m4];
1701  }
1702  tensor_operation::element_wise::Gelu{}(gate, gate);
1703  c_thread_buf_fp32(cidx) = gate * up;
1704  }
1705  }
1706  else
1707  {
1708  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1709  if constexpr(MulRoutedWeight)
1710  {
1711  c_thread_buf_fp32(cidx) =
1712  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
1713  }
1714  }
1715  });
1716  });
1717  });
1718  });
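// Illustrative note (not in the original header): the loop nest above fuses the MoE gated
// activation into the epilogue. For the gate/up GEMM, c_thread_buf holds the gate half and
// c_thread_buf_up the up half; with MulRoutedWeight both are first scaled by the token's
// routing weight w read from the third D tensor (p_ds_grid[I2]), so the produced value is
// Act(w * gate) * (w * up) with Act = SiLU or GELU. For the down GEMM (IsInputGemm == false)
// the accumulator is simply scaled by w. For example, with w = 0.5, gate = 2.0, up = 3.0 and
// SiLU: silu(1.0) * 1.5 = 0.7311 * 1.5 ~= 1.10.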
1719 
1720  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1722 
1723  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1724  static_cast<CShuffleDataType*>(p_shared),
1725  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1726 
1727  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1728  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1729  make_tuple(
1732  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1733  M1, // M1 = MWave
1734  M2, // M2 * M3 * M4 = MPerXdl
1735  M3,
1736  M4)),
1739  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1740  N1, // N1 = NWave
1741  N2))), // N2 = NPerXdl
1742  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1743  make_tuple(
1744  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
1745 
1746  // calculate origin of thread output tensor on global memory
1747  // blockwise GEMM c matrix starting index
1748  const auto c_thread_mtx_on_block =
1749  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1750 
1751  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1752  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1753 
1754  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1756  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1757  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
1758  make_tuple(Sequence<0>{}));
1759 
1760  const auto m_thread_data_on_block_idx =
1761  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1762  make_multi_index(m_thread_data_on_block));
1763 
1764  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1767  make_tuple(Sequence<0, 1, 2>{}),
1768  make_tuple(Sequence<0>{}));
1769 
1770  const auto n_thread_data_on_block_idx =
1771  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1772  make_multi_index(n_thread_data_on_block));
1773 
1774  // shuffle: threadwise copy C from VGPR to LDS
1775  auto c_thread_copy_vgpr_to_lds =
1776  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
1777  CShuffleDataType,
1778  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1779  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1781  Sequence<CShuffleMXdlPerWavePerShuffle,
1782  CShuffleNXdlPerWavePerShuffle,
1783  I1,
1784  I1,
1785  M2,
1786  I1,
1787  M4,
1788  I1>,
1789  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1790  7,
1791  1,
1793  1,
1794  true>{
1795  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1796  make_multi_index(0,
1797  0,
1798  m_thread_data_on_block_idx[I1],
1799  n_thread_data_on_block_idx[I1],
1800  m_thread_data_on_block_idx[I2],
1801  m_thread_data_on_block_idx[I3],
1802  m_thread_data_on_block_idx[I4],
1803  n_thread_data_on_block_idx[I2]),
1805 
1806  using EDataType = CDataType;
1807 
1808  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1809  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1810 
1811  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1813  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1814 
1815  const auto ds_grid_buf = generate_tuple(
1816  [&](auto i) {
1817  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1818  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1819  },
1820  Number<NumDTensor>{});
1821 
1822  // tuple of reference to C/Ds tensor descriptors
1823  const auto c_ds_desc_refs = concat_tuple_of_reference(
1824  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1825  generate_tie([&](auto i) -> const auto& // return type should be reference
1826  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1827  Number<NumDTensor>{}));
1828 
1829  // tuple of reference to C/Ds tensor descriptors
1830  const auto c_ds_buf_refs = concat_tuple_of_reference(
1831  tie(c_shuffle_block_buf),
1832  generate_tie([&](auto i) -> const auto& // return type should be reference
1833  { return ds_grid_buf[i]; },
1834  Number<NumDTensor>{}));
1835 
1836  // tuple of starting index of C/Ds blockwise copy
1837  const auto idx_c_ds_block_begin =
1840  [&](auto) {
1841  return make_multi_index(block_m_id, 0, block_n_id, 0);
1842  // return make_multi_index(block_work_idx[I0], 0,
1843  // block_work_idx[I1], 0);
1844  },
1845  Number<NumDTensor>{}));
1846 
1847  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1848  c_grid_desc_mblock_mperblock_nblock_nperblock;
1849 
1850  using CDEBlockTransferCluster =
1851  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1852  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1853  constexpr index_t scatter_weight_idx = 1; // hack fix felix
1854  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1856  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1857  Tuple<EDataType>,
1858  decltype(c_ds_desc_refs),
1859  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1860  CElementwiseOperation,
1861  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1862  // support arbitrary type
1863  Sequence<1,
1864  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1865  1,
1866  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1867  CDEBlockTransferCluster,
1868  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1869  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1870  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1871  3, // index_t SrcVectorDim,
1872  3, // index_t DstVectorDim,
1873  CDEShuffleBlockTransferScalarPerVectors,
1874  CShuffleBlockTransferScalarPerVector_NPerBlock,
1875  sequence_merge_t<
1876  Sequence<true>,
1877  uniform_sequence_gen_t<NumDTensor,
1878  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1879  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1880  IndexType,
1881  1, // ScatterDim
1882  true, // OutputScatter: false, only use scatter weights
1883  scatter_weight_idx // ScatterWeightIdx: ascale
1884  >{c_ds_desc_refs,
1885  idx_c_ds_block_begin,
1886  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1887  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1888  c_element_op};
1889 
1890  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1891  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1892  constexpr auto sfc_c_vgpr =
1893  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
1894  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1895  Sequence<CShuffleMXdlPerWavePerShuffle,
1896  CShuffleNXdlPerWavePerShuffle,
1897  1,
1898  1,
1899  M2,
1900  1,
1901  M4,
1902  1>>{};
1903 
1904  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1905 
1906  // space filling curve for shuffled blockwise C/D/E
1907  constexpr auto sfc_cde_block =
1908  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
1909  Sequence<0, 2, 1, 3>,
1910  Sequence<1,
1911  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1912  1,
1913  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1914 
1915  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1916  constexpr auto EMThreads =
1917  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1918  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1919  constexpr auto ENThreads =
1920  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1921  static_for<0, num_access, 1>{}([&](auto access_id) {
1922  // make sure it's safe to write to LDS
1923  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1924 
1925  auto dstidx = sfc_cde_block.GetIndex(access_id);
1926  const index_t c_token_pos =
1927  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1928  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1929  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1930  IndexType token_offset = fused_token & 0xffffff;
1931  if constexpr(IsInputGemm)
1932  {
1933  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1934  }
1935  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1936  });
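// Illustrative sketch (editor's annotation, not part of this header): each entry of
// p_sorted_token_ids packs a token index in its low 24 bits and a top-k slot in its high
// 8 bits. When the addressed row belongs to the TopK-expanded tensor, the row id becomes
// token * TopK + slot; the element offset is that row id times the row stride (problem.N
// for the C scatter above, problem.K for the A gather). A minimal sketch of the decoding:
auto decode_row_offset_example = [](unsigned fused_token, long top_k, long row_stride) {
    const long token_id  = fused_token & 0xffffff; // low 24 bits: token index
    const long topk_slot = fused_token >> 24;      // high 8 bits: top-k slot
    return (token_id * top_k + topk_slot) * row_stride; // element offset of the row
};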
1937 
1938  block_sync_lds();
1939 
1940  // each thread write its data from VGPR to LDS
1941  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1942  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1943  c_thread_buf_fp32,
1944  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1945  c_shuffle_block_buf);
1946 
1947  // make sure it's safe to read from LDS
1948  block_sync_lds();
1949 
1950  // each block copy its data from LDS to global
1951  cde_block_copy_lds_and_global.Run(
1952  c_ds_desc_refs,
1953  c_ds_buf_refs,
1954  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1955  tie(c_grid_buf),
1956  scatter_offsets);
1957 
1958  if constexpr(access_id < num_access - 1)
1959  {
1960  constexpr auto cde_lds_and_global_step =
1961  sfc_cde_block.GetForwardStep(access_id);
1962 
1963  // move on Ds
1964  static_for<0, NumDTensor, 1>{}([&](auto i) {
1965  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1966  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1967  });
1968 
1969  // move on E
1970  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1971  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1972  I0,
1973  cde_lds_and_global_step);
1974  }
1975  });
1976  }
1977  }
1978 #endif
1979 
1980  template <bool HasMainKBlockLoop,
1981  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1982  TailNumber TailNum = TailNumber::Odd>
1983  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1984  const index_t* p_sorted_expert_ids,
1985  const index_t* p_max_token_id,
1986  const ADataType* p_a_grid,
1987  const AScaleDataType* p_a_scale_grid,
1988  const BDataType* p_b_grid,
1989  const BScaleDataType* p_b_scale_grid,
1990  DsGridPointer& p_ds_grid,
1991  CDataType* p_c_grid,
1992  void* p_shared_0,
1993  void* p_shared_1,
1994  const Problem& problem,
1995  AElementwiseOperation a_element_op,
1996  BElementwiseOperation b_element_op,
1997  CElementwiseOperation c_element_op)
1998  {
1999  ignore = a_element_op;
2000  ignore = b_element_op;
2001  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2002  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2003  problem.MPadded,
2004  problem.K,
2005  problem.KPadded,
2006  problem.StrideA,
2007  problem.AK0);
2008  const auto b_grid_desc_bpreshuffled =
2009  MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
2010  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2011  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2012  problem.MPadded,
2013  problem.N,
2014  problem.NPadded,
2015  problem.StrideC);
2016 
2017  // We pad M unconditionally for the scale tensor
2018  const auto Padded_Scale_M =
2019  math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
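// Worked example (editor's annotation, not part of this header): integer_divide_ceil rounds
// M up to a whole number of scale blocks, so a partially filled block still gets a scale
// entry. A minimal sketch of the same arithmetic with made-up sizes:
constexpr auto pad_to_scale_block_example = [](index_t m, index_t scale_block_size) {
    return (m + scale_block_size - 1) / scale_block_size * scale_block_size;
};
static_assert(pad_to_scale_block_example(1000, 32) == 1024,
              "e.g. 1000 rows with a 32-row scale block pad up to 1024");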
2020  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2021  make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2022  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2023  (KXdlPack * 64 / MPerXdl),
2025  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2026  (ScaleBlockSize / APackedSize)) *
2027  MPerXdl * MXdlPack / scale_pack_size_a,
2029  1));
2030 
2031  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2032  make_tuple(problem.N / (NXdlPack * NPerXdl),
2033  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2034  (KXdlPack * 64 / NPerXdl),
2036  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2037  (ScaleBlockSize / BPackedSize)) *
2038  NPerXdl * NXdlPack / scale_pack_size_b,
2040  1));
2041 
2042  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2043  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2044  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2045 
2046  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2047  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2048  if(expert_block_id * MPerBlock >= max_token_id)
2049  return;
2050  const index_t expert_id =
2051  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2052  const auto block_mn = [&]() -> std::pair<int, int> {
2053  if constexpr(NSwizzle)
2054  {
2055  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2056  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2057  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2058  const index_t expert_swizzle =
2059  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2060  const index_t bid_new = blockIdx.x - prefix_block;
2061  const index_t nid = __builtin_amdgcn_readfirstlane(
2062  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2063  const index_t mid =
2064  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2065  return {nid, mid};
2066  }
2067  else
2068  {
2069  return {blockIdx.x, blockIdx.y};
2070  }
2071  }();
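// Illustrative sketch (editor's annotation, not part of this header): with NSwizzle enabled,
// the flat blockIdx.x is reduced by the expert's prefix of N-blocks and then re-mapped so that
// a group of eight N tiles is walked for every sorted M tile of that expert before the next
// group of eight N tiles starts. A standalone sketch of the same index math, where bid_new,
// ecnt_prefix and expert_swizzle stand for the hypothetical per-expert values used above:
auto swizzle_example = [](index_t bid_new, index_t ecnt_prefix, index_t expert_swizzle) {
    const index_t nid = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8; // N tile id
    const index_t mid = ecnt_prefix + bid_new / 8 % expert_swizzle;       // sorted M tile id
    return std::make_pair(nid, mid);
};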
2072 
2073  const index_t block_n_id = block_mn.first;
2074  const index_t block_m_id = block_mn.second;
2075  const index_t token0 =
2076  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2077 
2078  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2079  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2080  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2081  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2082  constexpr auto AKThreads = AK0Threads * AK1Threads;
2083  constexpr auto AMRepeats = MPerBlock / AMThreads;
2084  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2085 
2086  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2087  return;
2088  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2089  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2090  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2091  index_t token_offset = fused_token & 0xffffff;
2092  if constexpr(!IsInputGemm)
2093  {
2094  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2095  }
2096  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2097  });
2098 
2099  const index_t expert_stride =
2100  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2101  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2102  problem.N * (IsInputGemm ? 2 : 1) *
2103  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
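// Worked example (editor's annotation, not part of this header): each expert owns a contiguous
// N x K weight block (doubled when the input GEMM fuses the gate and up projections) and one
// e8m0 scale per (ScaleBlockSize / BPackedSize) K elements per output channel. A minimal sketch
// of the same stride arithmetic with hypothetical sizes:
auto expert_stride_example = [](long n, long k, bool fused_gate_up,
                                long scale_block_size, long packed_size) {
    const long k_scale_blocks = (k + scale_block_size / packed_size - 1) /
                                (scale_block_size / packed_size);
    const long weight_stride = n * k * (fused_gate_up ? 2 : 1);
    const long scale_stride  = n * (fused_gate_up ? 2 : 1) * k_scale_blocks;
    return std::make_pair(weight_stride, scale_stride);
};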
2104 
2105  // N0, K0, Blocksize*KPack
2106  const index_t n_block_data_idx_on_grid =
2107  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
2108 
2109  // Grid buffer creation
2110  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2111  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2112  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2113  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2114 
2115  // A, B scale buffer
2116  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2117  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2118  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2119  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2120  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2121 
2122  // A matrix in LDS memory, dst of blockwise copy
2123  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2124 
2125  // B matrix in LDS memory, dst of blockwise copy
2126  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2127 
2128  // A matrix blockwise direct to LDS copy
2129  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2132  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2133  ABlockTransferThreadClusterArrangeOrder,
2134  ADataType,
2135  ADataType,
2136  decltype(a_grid_desc_ak0_m_ak1),
2137  decltype(a_block_desc_ak0_m_ak1),
2138  ABlockTransferSrcAccessOrder,
2139  ABlockTransferSrcVectorDim,
2140  2,
2141  ABlockTransferSrcScalarPerVector,
2142  IndexType,
2143  1>(a_grid_desc_ak0_m_ak1,
2144  make_multi_index(0, 0, 0),
2145  a_block_desc_ak0_m_ak1,
2146  make_multi_index(0, 0, 0),
2147  gather_offsets);
2148 
2149  // Thread-wise copy
2150  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2151  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2152  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2153  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2154  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2155  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2156 
2157  auto b_blockwise_copy =
2159  BDataType,
2160  decltype(b_grid_desc_bpreshuffled),
2161  decltype(b_block_desc_bk0_n_bk1),
2162  Sequence<Number<NXdlPerWave / NXdlPack>{},
2163  I1,
2164  Number<NXdlPack>{},
2165  Number<KRepeat>{},
2166  Number<BK1Value>{}>,
2168  4,
2169  BBlockTransferSrcScalarPerVector,
2170  BThreadTransferSrcResetCoordinateAfterRun,
2171  true>(
2172  b_grid_desc_bpreshuffled,
2173  make_multi_index(n_block_data_idx_on_grid,
2175  0,
2176  0,
2177  KPack * (get_thread_local_1d_id() % WarpSize)));
2178 
2179  // LDS allocation for A and B: be careful of alignment
2180  // Cast after lds
2181  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2182  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2183  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2184  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2185  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2186 
2187  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2188  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2189 
2190  // Blockwise GEMM pipeline
2191  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2192  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2193  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2194  decltype(c_thread_buf) c_thread_buf_up;
2195 
2196  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2197  float,
2198  c_thread_buf.num_of_v_,
2199  c_thread_buf.s_per_v,
2200  true>
2201  c_thread_buf_fp32;
2202 
2203  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2204  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2205  KPerBlock);
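// Worked example (editor's annotation, not part of this header): the product of the AK0 and
// AK1 lengths recovers the K extent handled by this split-K slice, so the main loop runs
// K / KPerBlock times. A minimal sketch of that arithmetic with made-up sizes:
constexpr auto main_loop_count_example = [](index_t ak0, index_t ak1, index_t k_per_block) {
    return (ak0 * ak1) / k_per_block;
};
static_assert(main_loop_count_example(512, 8, 256) == 16,
              "e.g. K = 4096 with KPerBlock = 256 gives 16 main-loop iterations");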
2206 
2207  // a and b scale processing
2208  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2209  const auto waveId_m = wave_idx[I0];
2210  const auto waveId_n = wave_idx[I1];
2211 
2212  auto thread_offset_shuffled =
2213  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2214 
2215  auto a_thread_offset_m = waveId_m;
2216 
2217  // get each thread's offset into the scale tensor
2218  const index_t token_scale_pos = block_m_id * MPerBlock;
2219  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2220  return;
2221 
2222  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2223  AScaleDataType,
2224  AScaleDataType,
2225  decltype(a_scale_grid_desc_am_ak),
2226  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2227  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2228  Sequence<0, 1, 2>, // DimAccessOrder
2229  2, // SrcVectorDim
2230  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2231  1, // SrcScalarStrideInVector
2232  true>(a_scale_grid_desc_am_ak,
2233  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2234  0,
2235  thread_offset_shuffled / scale_pack_size_a));
2236 
2237  // B scale load
2238  auto b_thread_offset_n = waveId_n;
2239 
2240  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2241  BScaleDataType,
2242  BScaleDataType,
2243  decltype(b_scale_grid_desc_bn_ak),
2244  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2245  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2246  Sequence<0, 1, 2>, // DimAccessOrder
2247  2, // SrcVectorDim
2248  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2249  1, // SrcScalarStrideInVector
2250  true>(b_scale_grid_desc_bn_ak,
2251  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2252  0,
2253  thread_offset_shuffled / scale_pack_size_b));
2254 
2255  if constexpr(IsInputGemm)
2256  {
2257  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2258  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2259  p_b_grid_up + expert_id * expert_stride,
2260  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2261  auto b_blockwise_copy_up =
2263  BDataType,
2264  decltype(b_grid_desc_bpreshuffled),
2265  decltype(b_block_desc_bk0_n_bk1),
2266  Sequence<Number<NXdlPerWave / NXdlPack>{},
2267  I1,
2268  Number<NXdlPack>{},
2269  Number<KRepeat>{},
2270  Number<BK1Value>{}>,
2272  4,
2273  BBlockTransferSrcScalarPerVector,
2274  BThreadTransferSrcResetCoordinateAfterRun,
2275  true>(
2276  b_grid_desc_bpreshuffled,
2277  make_multi_index(n_block_data_idx_on_grid,
2279  0,
2280  0,
2281  KPack * (get_thread_local_1d_id() % WarpSize)));
2282  const BScaleDataType* p_b_scale_grid_up =
2283  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2284  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2285  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2286  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2287 
2288  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2289  BScaleDataType,
2290  BScaleDataType,
2291  decltype(b_scale_grid_desc_bn_ak),
2292  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2293  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2294  Sequence<0, 1, 2>, // DimAccessOrder
2295  2, // SrcVectorDim
2296  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2297  1, // SrcScalarStrideInVector
2298  true>(
2299  b_scale_grid_desc_bn_ak,
2300  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2301  0,
2302  thread_offset_shuffled / scale_pack_size_b));
2303 
2304  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2305  // A
2306  a_grid_desc_ak0_m_ak1,
2307  a_block_desc_ak0_m_ak1,
2308  a_blockwise_copy,
2309  a_grid_buf,
2310  a_block_bufs,
2311  a_block_slice_copy_step,
2312  // Gate and Up
2313  b_grid_desc_bpreshuffled,
2314  b_block_desc_bk0_n_bk1,
2315  b_blockwise_copy,
2316  b_blockwise_copy_up,
2317  b_grid_buf,
2318  b_grid_buf_up,
2319  b_block_bufs,
2320  b_block_slice_copy_step,
2321  // C
2322  c_thread_buf,
2323  c_thread_buf_up,
2324  // A scale
2325  a_scale_grid_desc_am_ak,
2326  a_scale_thread_copy,
2327  a_scale_grid_buf,
2328  // B scale
2329  b_scale_grid_desc_bn_ak,
2330  b_scale_thread_copy,
2331  b_scale_thread_copy_up,
2332  b_scale_grid_buf,
2333  b_scale_grid_buf_up,
2334  num_k_block_main_loop);
2335  }
2336  else
2337  {
2338  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2339  a_grid_desc_ak0_m_ak1, // A
2340  a_block_desc_ak0_m_ak1,
2341  a_blockwise_copy,
2342  a_grid_buf,
2343  a_block_bufs,
2344  a_block_slice_copy_step,
2345  b_grid_desc_bpreshuffled, // B
2346  b_block_desc_bk0_n_bk1,
2347  b_blockwise_copy,
2348  b_grid_buf,
2349  b_block_bufs,
2350  b_block_slice_copy_step,
2351  c_thread_buf, // C
2352  a_scale_grid_desc_am_ak, // A scale
2353  a_scale_thread_copy,
2354  a_scale_grid_buf,
2355  b_scale_grid_desc_bn_ak, // B scale
2356  b_scale_thread_copy,
2357  b_scale_grid_buf,
2358  num_k_block_main_loop);
2359  }
2360 
2361  // shuffle C and write out
2362  {
2363  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2364  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2365  "wrong!");
2366  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2367  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2368  "wrong!");
2369 
2370  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2371 
2372  // TODO: hacky, fix it!
2373  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2374  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2375 
2376  // TODO: hacky, fix it!
2377  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2378  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2379  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2380 
2381  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2382  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2383  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2384  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2385  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2386  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2387  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2388  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2389  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2390  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2391 
2392  // mul scales
2393 
2394  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2395  static_assert(M5 == 4);
2396  const index_t m1 = get_warp_local_1d_id() / NWave;
2397  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2398 
2399  vector_type<float, 4> topk_weights; // for gemm2 only
2400  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2401  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2402  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2403  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2404  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2405  const index_t m_pos = block_m_id * MPerBlock +
2406  m0 * M2 * M1 * M3 * M4 * M5 +
2407  m1 * M2 * M3 * M4 * M5 +
2408  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2409  if constexpr(MulRoutedWeight)
2410  {
2411  topk_weights =
2412  *c_style_pointer_cast<const vector_type<float, M5>*>(
2413  p_ds_grid[I2] + m_pos);
2414  }
2415  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2416  constexpr index_t c_offset =
2417  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2418  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2419  constexpr auto cidx = Number<c_offset>{};
2420 
2421  if constexpr(IsInputGemm) // gate-up fusion
2422  {
2423  if constexpr(ActivationOperation ==
2424  Activation::silu_and_mul)
2425  {
2426  float gate = c_thread_buf[cidx];
2427  float up = c_thread_buf_up[cidx];
2428  if constexpr(MulRoutedWeight)
2429  {
2430  gate = gate * topk_weights.AsType<float>()[m5];
2431  up = up * topk_weights.AsType<float>()[m5];
2432  }
2433  tensor_operation::element_wise::Silu{}(gate, gate);
2434  c_thread_buf_fp32(cidx) = gate * up;
2435  }
2436  else if(ActivationOperation == Activation::gelu_and_mul)
2437  {
2438  float gate = c_thread_buf[cidx];
2439  float up = c_thread_buf_up[cidx];
2440  if constexpr(MulRoutedWeight)
2441  {
2442  gate = gate * topk_weights.AsType<float>()[m5];
2443  up = up * topk_weights.AsType<float>()[m5];
2444  }
2445  tensor_operation::element_wise::Gelu{}(gate, gate);
2446  c_thread_buf_fp32(cidx) = gate * up;
2447  }
2448  }
2449  else
2450  {
2451  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2452  if constexpr(MulRoutedWeight)
2453  {
2454  c_thread_buf_fp32(cidx) =
2455  topk_weights.AsType<float>()[m5] *
2456  c_thread_buf_fp32[cidx];
2457  }
2458  }
2459  });
2460  });
2461  });
2462  });
2463  });
2464  });
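// Illustrative sketch (editor's annotation, not part of this header): for the fused gate/up
// GEMM each accumulator pair is reduced to Act(gate) * up, with the routed top-k weight applied
// to both operands first when MulRoutedWeight is set. The scalar math below mirrors that flow;
// the activation bodies are common SiLU/GELU formulas, not necessarily the exact element-wise
// functors this header instantiates:
auto fused_gate_up_example = [](float gate, float up, float routed_weight,
                                bool use_silu, bool mul_routed_weight) {
    if(mul_routed_weight)
    {
        gate *= routed_weight;
        up *= routed_weight;
    }
    const float act = use_silu
                          ? gate / (1.0f + expf(-gate))                        // SiLU: x * sigmoid(x)
                          : 0.5f * gate * (1.0f + erff(gate * 0.7071067811f)); // erf-based GELU
    return act * up;
};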
2465 
2466  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2467  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2468 
2469  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2470  static_cast<CShuffleDataType*>(p_shared_0),
2471  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2472 
2473  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2474  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2475  make_tuple(
2478  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2479  // shuffle
2480  M1, // M1 = MWave
2481  M2, // M2 * M3 * M4 = MPerXdl
2482  M3,
2483  M4,
2484  M5)),
2488  // per shuffle
2489  N1, // N1 = NWave
2490  N2, // N2 = NXdlPack
2491  N3))), // N3 = NPerXdl
2495  Sequence<>{},
2497 
2498  // calculate origin of thread output tensor on global memory
2499  // blockwise GEMM c matrix starting index
2500  const auto c_thread_mtx_on_block =
2501  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2502 
2503  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2504  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2505 
2506  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2507  make_single_stage_tensor_adaptor(
2508  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2509  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2510  make_tuple(Sequence<0>{}));
2511 
2512  const auto m_thread_data_on_block_idx =
2513  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2514  make_multi_index(m_thread_data_on_block));
2515 
2516  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2517  make_single_stage_tensor_adaptor(
2518  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2519  make_tuple(Sequence<0, 1, 2, 3>{}),
2520  make_tuple(Sequence<0>{}));
2521 
2522  const auto n_thread_data_on_block_idx =
2523  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2524  make_multi_index(n_thread_data_on_block));
2525 
2526  // shuffle: threadwise copy C from VGPR to LDS
2527  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2528  AccDataType,
2529  CShuffleDataType,
2530  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2531  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2532  ck::tensor_operation::element_wise::PassThrough,
2533  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2534  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2535  I1,
2536  I1,
2537  M2,
2538  N2,
2539  M3,
2540  I1,
2541  M5,
2542  I1>,
2543  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2544  9,
2545  1,
2546  InMemoryDataOperationEnum::Set,
2547  1,
2548  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2549  make_multi_index(0,
2550  0,
2551  m_thread_data_on_block_idx[I1],
2552  n_thread_data_on_block_idx[I1],
2553  m_thread_data_on_block_idx[I2],
2554  n_thread_data_on_block_idx[I2],
2555  m_thread_data_on_block_idx[I3],
2556  m_thread_data_on_block_idx[I4],
2557  m_thread_data_on_block_idx[I5],
2558  n_thread_data_on_block_idx[I3]),
2559  ck::tensor_operation::element_wise::PassThrough{}};
2560 
2561  using EDataType = CDataType;
2562 
2563  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2564  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2565 
2566  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2567  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2568  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2569 
2570  const auto ds_grid_buf = generate_tuple(
2571  [&](auto i) {
2572  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2573  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2574  },
2575  Number<NumDTensor>{});
2576 
2577  // tuple of reference to C/Ds tensor descriptors
2578  const auto c_ds_desc_refs = concat_tuple_of_reference(
2579  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2580  generate_tie([&](auto i) -> const auto& // return type should be reference
2581  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2582  Number<NumDTensor>{}));
2583 
2584  // tuple of reference to C/Ds tensor descriptors
2585  const auto c_ds_buf_refs = concat_tuple_of_reference(
2586  tie(c_shuffle_block_buf),
2587  generate_tie([&](auto i) -> const auto& // return type should be reference
2588  { return ds_grid_buf[i]; },
2589  Number<NumDTensor>{}));
2590 
2591  // tuple of starting index of C/Ds blockwise copy
2592  const auto idx_c_ds_block_begin =
2593  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
2594  generate_tuple(
2595  [&](auto) {
2596  return make_multi_index(block_m_id, 0, block_n_id, 0);
2597  // return make_multi_index(block_work_idx[I0], 0,
2598  // block_work_idx[I1], 0);
2599  },
2600  Number<NumDTensor>{}));
2601 
2602  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2603  c_grid_desc_mblock_mperblock_nblock_nperblock;
2604 
2605  using CDEBlockTransferCluster =
2606  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2607  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2608  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2609  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2610  ThisThreadBlock,
2611  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2612  Tuple<EDataType>,
2613  decltype(c_ds_desc_refs),
2614  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2615  CElementwiseOperation,
2616  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2617  // Sequence support
2618  // arbitrary type
2619  Sequence<1,
2620  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2621  1,
2622  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2623  CDEBlockTransferCluster,
2624  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2625  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2626  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2627  3, // index_t SrcVectorDim,
2628  3, // index_t DstVectorDim,
2629  CDEShuffleBlockTransferScalarPerVectors,
2630  CShuffleBlockTransferScalarPerVector_NPerBlock,
2631  sequence_merge_t<
2632  Sequence<true>,
2633  uniform_sequence_gen_t<NumDTensor,
2634  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2635  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2636  IndexType,
2637  1, // ScatterDim
2638  true, // OutputScatter: false, only use scatter weights
2639  scatter_weight_idx // ScatterWeightIdx: ascale
2640  >{c_ds_desc_refs,
2641  idx_c_ds_block_begin,
2642  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2643  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2644  c_element_op};
2645 
2646  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2647  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2648 
2649  constexpr auto sfc_c_vgpr =
2650  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2651  NXdlPerWave / NXdlPack,
2652  1,
2653  1,
2654  MXdlPack,
2655  NXdlPack,
2656  M2,
2657  1,
2658  M4,
2659  1>,
2660  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2661  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2662  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2663  1,
2664  1,
2665  MXdlPack,
2666  NXdlPack,
2667  M2,
2668  1,
2669  M4,
2670  1>>{};
2671 
2672  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2673 
2674  // space filling curve for shuffled blockwise C/D/E
2675  constexpr auto sfc_cde_block =
2676  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2677  Sequence<0, 2, 1, 3>,
2678  Sequence<1,
2679  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2680  1,
2681  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2682 
2683  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2684  constexpr auto EMThreads =
2685  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2686  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2687  constexpr auto ENThreads =
2688  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2689  static_for<0, num_access, 1>{}([&](auto access_id) {
2690  // make sure it's safe to write to LDS
2691  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2692 
2693  auto dstidx = sfc_cde_block.GetIndex(access_id);
2694  const index_t c_token_pos =
2695  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2696  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2697  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2698  IndexType token_offset = fused_token & 0xffffff;
2699  if constexpr(IsInputGemm)
2700  {
2701  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2702  }
2703  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2704  });
2705 
2706  block_sync_lds();
2707 
2708  // each thread write its data from VGPR to LDS
2709  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2710  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2711  c_thread_buf_fp32,
2712  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2713  c_shuffle_block_buf);
2714 
2715  // make sure it's safe to read from LDS
2716  block_sync_lds();
2717 
2718  // each block copy its data from LDS to global
2719  cde_block_copy_lds_and_global.Run(
2720  c_ds_desc_refs,
2721  c_ds_buf_refs,
2722  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2723  tie(c_grid_buf),
2724  scatter_offsets);
2725 
2726  if constexpr(access_id < num_access - 1)
2727  {
2728  constexpr auto cde_lds_and_global_step =
2729  sfc_cde_block.GetForwardStep(access_id);
2730 
2731  // move on Ds
2732  static_for<0, NumDTensor, 1>{}([&](auto i) {
2733  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2734  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2735  });
2736 
2737  // move on E
2738  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2739  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2740  I0,
2741  cde_lds_and_global_step);
2742  }
2743  });
2744  }
2745  }
2746 };
2747 
2748 } // namespace ck