Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File

gridwise_moe_gemm_blockscale.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines in the same
24 // kernel function. Blockers:
25 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses operate on
26 // two distinct LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not use the same LDS
28 // buffer if we declare __shared__ inside the block GEMM pipeline.
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
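// These values select the fused gated-activation epilogue used when IsInputGemm is true
// (gate/up fusion): the gate GEMM result is combined with the up GEMM result, and
// optionally with the routed weight from Ds, inside the epilogue of Run()/Run_2Lds().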
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__)
49  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50 
51  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54  karg.p_sorted_token_ids,
55  karg.p_sorted_expert_ids,
56  karg.p_max_token_id,
57  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
59  karg.p_ds_grid,
60  karg.p_c_grid,
61  karg.p_a_scale_grid,
62  karg.p_b_scale_grid,
63  p_shared,
64  karg,
65  karg.a_element_op,
66  karg.b_element_op,
67  karg.c_element_op);
68 #else
69  ignore = karg;
70 #endif // end of if (defined(__gfx9__))
71 }
72 
73 template <typename GridwiseGemm,
74  bool HasMainKBlockLoop,
75  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
76  index_t MinimumOccupancy = 1,
77  TailNumber TailNum = TailNumber::Even>
78 __global__ void
79 #if CK_USE_LAUNCH_BOUNDS
80 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
81 #endif
82  // __attribute__((amdgpu_waves_per_eu(1, 1)))
83  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
84 {
85 #if defined(__gfx9__)
86  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
88 
89  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
90 
91  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
92  karg.p_sorted_token_ids,
93  karg.p_sorted_expert_ids,
94  karg.p_max_token_id,
95  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
96  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
97  karg.p_ds_grid,
98  karg.p_c_grid,
99  karg.p_a_scale_grid,
100  karg.p_b_scale_grid,
101  p_shared,
102  p_shared1,
103  karg,
104  karg.a_element_op,
105  karg.b_element_op,
106  karg.c_element_op);
107 #else
108  ignore = karg;
109 #endif // end of if (defined(__gfx9__))
110 }
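// The _2lds variant allocates two separate LDS buffers so that the double-LDS-buffer
// blockwise GEMM pipelines (see the blocker note above) can alternate between them via
// Run_2Lds(); the single-buffer kernel above shares one LDS allocation for A staging and
// the C shuffle.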
111 
112 template <typename ALayout,
113  typename BLayout,
114  typename DsLayout,
115  typename CLayout,
116  typename ADataType,
117  typename BDataType,
118  typename AccDataType,
119  typename CShuffleDataType,
120  typename DsDataType,
121  typename CDataType,
122  typename AElementwiseOperation,
123  typename BElementwiseOperation,
124  typename CElementwiseOperation,
126  index_t BlockSize,
127  index_t ScaleBlockM,
128  index_t ScaleBlockN,
129  index_t ScaleBlockK,
130  index_t MPerBlock,
131  index_t NPerBlock,
132  index_t KPerBlock,
133  index_t AK1Value,
134  index_t BK1Value,
135  index_t MPerXdl,
136  index_t NPerXdl,
137  index_t MXdlPerWave,
138  index_t NXdlPerWave,
139  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
140  typename ABlockTransferThreadClusterArrangeOrder,
141  typename ABlockTransferSrcAccessOrder,
142  index_t ABlockTransferSrcVectorDim,
143  index_t ABlockTransferSrcScalarPerVector,
144  index_t ABlockTransferDstScalarPerVector_AK1,
145  bool AThreadTransferSrcResetCoordinateAfterRun,
146  index_t ABlockLdsExtraM,
147  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
148  typename BBlockTransferThreadClusterArrangeOrder,
149  typename BBlockTransferSrcAccessOrder,
150  index_t BBlockTransferSrcVectorDim,
151  index_t BBlockTransferSrcScalarPerVector,
152  index_t BBlockTransferDstScalarPerVector_BK1,
153  bool BThreadTransferSrcResetCoordinateAfterRun,
154  index_t BBlockLdsExtraN,
155  index_t CShuffleMXdlPerWavePerShuffle,
156  index_t CShuffleNXdlPerWavePerShuffle,
157  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
158  typename CDEShuffleBlockTransferScalarPerVectors,
161  index_t ActivationOperation = 0,
162  bool NSwizzle = false,
163  bool IsInputGemm = true,
164  bool MulRoutedWeight = true,
165  typename IndexType = index_t,
166  typename ComputeTypeA = CDataType,
167  typename ComputeTypeB = ComputeTypeA,
168  typename LDSTypeA = ADataType,
169  typename LDSTypeB = BDataType>
171 {
172  using AScaleType = float;
173  using BScaleType = float;
174 
175  static constexpr auto I0 = Number<0>{};
176  static constexpr auto I1 = Number<1>{};
177  static constexpr auto I2 = Number<2>{};
178  static constexpr auto I3 = Number<3>{};
179  static constexpr auto I4 = Number<4>{};
180  static constexpr auto I5 = Number<5>{};
181  static constexpr auto I6 = Number<6>{};
182  static constexpr auto I7 = Number<7>{};
183 
185  CDEShuffleBlockTransferScalarPerVectors{}[I0];
186  // K1 should be Number<...>
187  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
188  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
189  static constexpr auto AK1Number = Number<AK1Value>{};
190  static constexpr auto BK1Number = Number<BK1Value>{};
191  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
192 
193  static constexpr index_t NumDTensor = DsDataType::Size();
194 
196  static constexpr index_t KPack =
198  static constexpr index_t KGroup = []() {
200  // On gfx950, there is an mfma instruction that requires 32 f8 elements as input,
201  // split into 2 groups of 16 f8 elements.
202  // The 2 groups are not contiguous in the B preshuffled layout,
203  // and we do not want them to be contiguous in the B preshuffled layout,
204  // because a memory instruction can only read 16 f8 elements at a time.
205  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
206  else
207  return 1;
208  }();
209  static constexpr index_t KLane =
211  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
212  static constexpr index_t NLane = NPerXdl;
213  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
214  // static constexpr index_t NumTokens = 1;
215  static constexpr index_t SortedTileSize = MPerBlock;
216 
217  static constexpr auto MakeDsGridPointer()
218  {
219  return generate_tuple(
220  [&](auto i) {
221  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
222 
223  return static_cast<const DDataType*>(nullptr);
224  },
226  }
227 
228  using DsGridPointer = decltype(MakeDsGridPointer());
229 
231 
232  static constexpr index_t APackedSize = []() {
234  return 2;
235  else
236  return 1;
237  }();
238 
239  static constexpr index_t BPackedSize = []() {
241  return 2;
242  else
243  return 1;
244  }();
245 
246  __host__ static auto CalculateGridSize(index_t M, index_t N)
247  {
248  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
249  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
250  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
251  const index_t gridy = NSwizzle ? 1 : mblock;
252  return std::make_tuple(gridx, gridy, 1);
253  }
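// Illustrative example (values assumed): with MPerBlock = 128, NPerBlock = 128, M = 512
// and N = 1024, mblock = 4 and nblock = 8, so the grid is (8, 4, 1) without NSwizzle or
// (32, 1, 1) with NSwizzle (M- and N-blocks folded into gridx).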
254 
255  __host__ __device__ static auto CalculateMPadded(index_t M)
256  {
257  return math::integer_least_multiple(M, MPerBlock);
258  }
259 
260  __host__ __device__ static auto CalculateNPadded(index_t N)
261  {
262  return math::integer_least_multiple(N, NPerBlock);
263  }
264 
265  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
266  {
267  return math::integer_divide_ceil(N, NLane);
268  }
269  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
270  {
272  }
273 
274  __host__ __device__ static auto CalculateKPadded(index_t K)
275  {
276  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
277  }
278 
279  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
283  }
284 
285  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
286  {
287  auto K_t = K_Batch * KPerBlock;
288  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
289  }
290 
291  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
292  {
293  auto K_t = K_Batch * KPerBlock;
294  return (K + K_t - 1) / K_t * KPerBlock;
295  }
296 
297  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
298  {
299  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
300  auto K_t = K_Batch * KReadVec;
301  return (K + K_t - 1) / K_t * KReadVec;
302  }
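// Illustrative example (values assumed): with AK1 = BK1 = 8 (so KReadVec = lcm(8, 8) = 8),
// K = 1000 and K_Batch = 2, K_t = 16 and KRead = ceil(1000 / 16) * 8 = 63 * 8 = 504
// elements of K per split-K batch.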
303 
304  __host__ __device__ static auto CalculateMBlock(index_t M)
305  {
306  return math::integer_divide_ceil(M, MPerBlock);
307  }
308 
309  __host__ __device__ static auto CalculateNBlock(index_t N)
310  {
311  return math::integer_divide_ceil(N, NPerBlock);
312  }
313 
314  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
315  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
316  {
317  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
318  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
319 
321  TileDesc_K0_MN_K1{},
327  }
328 
329  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
330  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
331  {
332  const auto a_grid_desc_mraw_kraw = [&]() {
333  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
334  {
335  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
336  }
337  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
338  {
339  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
340  }
341  }();
342 
344 
345  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
346  GemmSpec == GemmSpecialization::MNKPadding)
347  {
348  // pad both M and K
349  const auto a_grid_desc_m_k =
350  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
352  make_right_pad_transform(K, KPad - K)),
355 
356  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
357  a_grid_desc_m_k,
362 
363  return a_grid_desc_ak0_m_ak1;
364  }
365  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
366  GemmSpec == GemmSpecialization::MNPadding)
367  {
368  // pad M, but not K
369  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
370  a_grid_desc_mraw_kraw,
372  make_right_pad_transform(M, MPad - M)),
375 
376  return a_grid_desc_ak0_m_ak1;
377  }
378  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
379  GemmSpec == GemmSpecialization::NKPadding)
380  {
381  // pad K, but not M
382  const auto a_grid_desc_m_k = transform_tensor_descriptor(
383  a_grid_desc_mraw_kraw,
387 
388  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
389  a_grid_desc_m_k,
394 
395  return a_grid_desc_ak0_m_ak1;
396  }
397  else
398  {
399  // not pad M or K
400  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
401  a_grid_desc_mraw_kraw,
406 
407  return a_grid_desc_ak0_m_ak1;
408  }
409  }
410 
411  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
412  {
413  constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
415  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
416  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
417  }
418 
419  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
420  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
421  {
422  const auto b_grid_desc_nraw_kraw = [&]() {
424  {
425  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
426  }
428  {
429  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
430  }
431  }();
432 
434 
435  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
436  GemmSpec != GemmSpecialization::Default),
437  "pk_i4_t does not support padding");
438 
439  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
440  GemmSpec == GemmSpecialization::MNKPadding)
441  {
442  // pad both N and K
443  const auto b_grid_desc_n_k =
444  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
446  make_right_pad_transform(K, KPad - K)),
449 
450  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
451  b_grid_desc_n_k,
456 
457  return b_grid_desc_bk0_n_bk1;
458  }
459  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
460  GemmSpec == GemmSpecialization::MNPadding)
461  {
462  // pad N, but not K
463  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
464  b_grid_desc_nraw_kraw,
466  make_right_pad_transform(N, NPad - N)),
469 
470  return b_grid_desc_bk0_n_bk1;
471  }
472  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
473  GemmSpec == GemmSpecialization::MKPadding)
474  {
475  // pad K, but not N
476  const auto b_grid_desc_n_k = transform_tensor_descriptor(
477  b_grid_desc_nraw_kraw,
481 
482  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
483  b_grid_desc_n_k,
488 
489  return b_grid_desc_bk0_n_bk1;
490  }
491  else
492  {
493  // not pad N or K
494  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
495  b_grid_desc_nraw_kraw,
500 
501  return b_grid_desc_bk0_n_bk1;
502  }
503  }
504 
505  template <typename ABlockDesc_AK0_M_AK1>
506  __host__ __device__ static constexpr auto
507  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
508  {
509  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
510 
511  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
512  }
513 
514  template <typename BBlockDesc_BK0_N_BK1>
515  __host__ __device__ static constexpr auto
516  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
517  {
518  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
519  }
520 
521  template <typename ELayout>
522  __host__ __device__ static auto MakeCGridDescriptor_M_N(
523  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
524  {
525  const auto c_grid_desc_mraw_nraw = [&]() {
527  {
528  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
529  }
531  {
532  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
533  }
534  }();
535 
536  // pad M and N
537  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
539  make_right_pad_transform(N, NPad - N)),
542  }
543 
544  template <typename DLayout>
545  __host__ __device__ static auto
547  {
548  const auto c_grid_desc_mraw_nraw = [&]() {
550  {
551  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
552  }
554  {
555  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
556  }
557  }();
558 
559  // pad M and N
560  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
562  make_right_pad_transform(N, NPad - N)),
565  }
566 
567  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
568  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
569  {
570  return generate_tuple(
571  [&](auto i) {
572  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
573  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
574  },
576  }
577 
578  template <typename DsGridDesc>
580  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
581  {
582  return generate_tuple(
583  [&](auto i) {
585  ds_grid_desc_m_n[i], MBlock, NBlock);
586  },
588  }
589 
590  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
591 
592  struct Problem
593  {
594  __host__ __device__ Problem(index_t NumTokens_,
595  index_t TopK_,
596  index_t M_,
597  index_t N_,
598  index_t K_,
599  index_t StrideA_,
600  index_t StrideB_,
601  std::array<index_t, NumDTensor> StrideDs_,
602  index_t StrideC_,
603  index_t KBatch_)
604  : NumTokens{NumTokens_},
605  TopK{TopK_},
606  M{M_},
607  N{N_},
608  K{K_},
609  StrideA{StrideA_},
610  StrideB{StrideB_},
611  StrideDs{StrideDs_},
612  StrideC{StrideC_},
613  KBatch{KBatch_},
616  KRead{CalculateKRead(K_, KBatch_)},
617  KPadded{CalculateKPadded(K_, KBatch_)},
618  AK0{CalculateAK0Padded(K_, KBatch_)},
619  BK0{CalculateBK0Padded(K_, KBatch_)},
620  MBlock{CalculateMBlock(M_)},
621  NBlock{CalculateNBlock(N_)},
624  {
625  }
626 
627  __host__ void Print() const
628  {
629  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
630  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
631  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
632  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
633  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
634  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
635  << "NBlock: " << NBlock << "}" << std::endl;
636  }
637 
645  std::array<index_t, NumDTensor> StrideDs;
656  // FOR PRESHUFFLE ONLY
659  };
660 
661  // Argument
663  {
664  __host__ Argument(const index_t* p_sorted_token_ids_,
665  const index_t* p_sorted_expert_ids_,
666  const index_t* p_max_token_id_,
667  const ADataType* p_a_grid_,
668  const BDataType* p_b_grid_,
669  std::array<const void*, NumDTensor> p_ds_grid_,
670  CDataType* p_c_grid_,
671  index_t NumTokens_,
672  index_t TopK_,
673  index_t M_,
674  index_t N_,
675  index_t K_,
676  index_t StrideA_,
677  index_t StrideB_,
678  std::array<index_t, NumDTensor> StrideDs_,
679  index_t StrideC_,
680  const AScaleType* p_a_scale_grid_,
681  const BScaleType* p_b_scale_grid_,
682  index_t k_batch_,
683  AElementwiseOperation a_element_op_,
684  BElementwiseOperation b_element_op_,
685  CElementwiseOperation c_element_op_)
686  : Problem{NumTokens_,
687  TopK_,
688  M_,
689  N_,
690  K_,
691  StrideA_,
692  StrideB_,
693  StrideDs_,
694  StrideC_,
695  k_batch_},
696  p_sorted_token_ids{p_sorted_token_ids_},
697  p_sorted_expert_ids{p_sorted_expert_ids_},
698  p_max_token_id{p_max_token_id_},
699  p_a_grid{p_a_grid_},
700  p_b_grid{p_b_grid_},
701  p_ds_grid{},
702  p_c_grid{p_c_grid_},
703  p_a_scale_grid{p_a_scale_grid_},
704  p_b_scale_grid{p_b_scale_grid_},
705  a_element_op{a_element_op_},
706  b_element_op{b_element_op_},
707  c_element_op{c_element_op_}
708  {
709 
710  // populate pointers for Ds
711  static_for<0, NumDTensor, 1>{}([&](auto i) {
712  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
713 
714  // D pointer
715  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
716  });
717  }
718 
722  const ADataType* p_a_grid;
723  const BDataType* p_b_grid;
725  CDataType* p_c_grid;
726 
729 
730  const AElementwiseOperation a_element_op;
731  const BElementwiseOperation b_element_op;
732  const CElementwiseOperation c_element_op;
733  };
734 
736  {
737  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
738  {
739  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
740  {
741  a_k_split_offset = k_id * karg.KRead / APackedSize;
742  }
743  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
744  {
745  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
746  }
747 
748  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
749  {
750  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
751  }
752  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
753  {
754  // KPack * NLane * KLane * K0 * N0
755  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
756  }
757 
758  if(k_id < karg.KBatch - 1)
759  {
760  karg.K = karg.KRead;
761  }
762  else
763  {
764  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
765  }
766  }
767 
770  };
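// Illustrative example (values assumed): row-major A with KRead = 504, KBatch = 2,
// K = 1000 and APackedSize = 1. Batch k_id = 0 starts at a_k_split_offset = 0 and
// computes K = 504; the last batch (k_id = 1) starts at element offset 504 and computes
// the remaining K = 1000 - 504 * (2 - 1) = 496.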
771 
772  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
773  {
774  // A matrix in LDS memory, dst of blockwise copy
775  if constexpr(ABlockLdsExtraM)
776  {
780  }
781  // The xor tensor transformation requires extra, unnecessary vgpr usage and would cause register spills
782  // in some cases.
784  {
785  constexpr auto a_lds_block_desc =
788 
789  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
790  a_lds_block_desc,
796 
797  return a_lds_block_desc_permuted;
798  }
799  else // ColumnMajor A
800  {
801  // The kfold and mpair dimensions are not always required.
802  // More dimensions in merge_transform increase the difficulty of generating immarg offsets
803  // for the compiler.
804  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
805  constexpr auto M1 = MPerBlock / M0;
806 
807  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
808  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
809  constexpr auto KThreadRead = 64 / MPerXdl;
810  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
811 
812  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
813  ? 1
814  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
815  constexpr auto KThreadReadPerm =
816  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
817  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
818  : KThreadRead;
819 
820  // 1 <= mpair <= M0
821  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
822  ? 1
823  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
824  ? M0
825  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
826 
827  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
831  Number<kfold * M0 / mpair>{},
832  Number<mpair>{},
833  AK1Number));
834 
835  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
836  a_lds_block_desc,
837  make_tuple(
841  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
844  make_tuple(
846  make_tuple(
848 
849  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
850  a_lds_block_desc_permuted,
851  make_tuple(
859  Sequence<1>{},
860  Sequence<2>{},
861  Sequence<3>{},
862  Sequence<4>{},
863  Sequence<5>{}),
865  Sequence<2>{},
866  Sequence<0, 3>{},
867  Sequence<4, 5>{},
868  Sequence<6>{},
869  Sequence<7>{}));
870 
871  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
872  a_lds_block_desc_unmerged,
875  Number<KThreadWrite / kfold / KThreadReadPerm>{},
876  Number<kfold>{},
883 
884  return a_lds_block_desc_ak0_m_ak1;
885  }
886  }
887 
888  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
889  {
890  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
893  }
894 
896  {
897  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
898 
899  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
901  make_tuple(I1,
903  I1,
905 
906  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
907  }
908 
911  BlkGemmPipelineVer,
912  BlkGemmPipeSched,
913  BlockSize,
914  ADataType,
915  BDataType,
916  ComputeTypeA,
917  AccDataType,
924  ABlockTransferSrcScalarPerVector,
925  BBlockTransferSrcScalarPerVector,
926  MPerBlock,
927  NPerBlock,
928  KPerBlock,
929  ScaleBlockM,
930  ScaleBlockN,
931  ScaleBlockK,
932  MPerXdl,
933  NPerXdl,
934  MXdlPerWave,
935  NXdlPerWave,
936  KPack,
937  IsInputGemm>())>;
938 
939  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
940  {
941  // LDS allocation for A and B: be careful of alignment
942  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
943  // lds max alignment
944  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
945 
946  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
947  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
948 
949  // LDS allocation for C shuffle in LDS
950  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
952 
953  constexpr auto c_block_size =
954  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
955 
956  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
957  c_block_size * sizeof(CShuffleDataType));
958  }
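// Rough sizing example (values assumed, LDS padding ignored): with MPerBlock = 128,
// KPerBlock = 64 and a 1-byte LDSTypeA (APackedSize = 1), the A tile needs about
// 128 * 64 = 8192 bytes, while the C-shuffle tile needs about
// (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) *
// (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) * sizeof(CShuffleDataType) bytes;
// the kernel reserves the larger of the two because A staging and the C shuffle reuse
// the same LDS allocation.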
959 
960  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
961  __host__ static constexpr bool CheckValidity(const Argument& karg)
962  {
963  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
964  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
965  "Invalid tuning param!");
966 
967  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
972  {
973  if(!(karg.M % MPerBlock == 0))
974  {
975 #if DEBUG_LOG
976  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
977  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
978  << std::endl;
979 
980 #endif // DEBUG_LOG
981  return false;
982  }
983  }
984 
985  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
990  {
991  if(!(karg.N % NPerBlock == 0))
992  {
993 #if DEBUG_LOG
994  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
995  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
996  << std::endl;
997 
998 #endif // DEBUG_LOG
999  return false;
1000  }
1001  }
1002 
1003  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1007  {
1008 
1009  auto K_t = karg.KBatch * KPerBlock;
1010  if(!(karg.K % K_t == 0))
1011  {
1012 #if DEBUG_LOG
1013  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1014  << karg.K << " " << __FILE__ << ":" << __LINE__
1015  << ", in function: " << __func__ << std::endl;
1016 
1017 #endif // DEBUG_LOG
1018  return false;
1019  }
1020  }
1021  else
1022  {
1023  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1024  auto K_t = karg.KBatch * KReadVec;
1025  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1026  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1027  {
1028  return false;
1029  }
1030  }
1031 
1033  {
1034  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1035  {
1036 #if DEBUG_LOG
1037  std::cout << "Arg K (" << karg.K
1038  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1039  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1040  << __LINE__ << ", in function: " << __func__ << std::endl;
1041 
1042 #endif // DEBUG_LOG
1043  return false;
1044  }
1045  }
1046  else
1047  {
1048  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1049  {
1050 #if DEBUG_LOG
1051  std::cout << "Arg M (" << karg.M
1052  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1053  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1054  << __LINE__ << ", in function: " << __func__ << std::endl;
1055 
1056 #endif // DEBUG_LOG
1057  return false;
1058  }
1059  }
1060 
1062  {
1063  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1064  {
1065 #if DEBUG_LOG
1066  std::cout << "Arg N (" << karg.N
1067  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1068  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1069  << __LINE__ << ", in function: " << __func__ << std::endl;
1070 
1071 #endif // DEBUG_LOG
1072  return false;
1073  }
1074  }
1075  else
1076  {
1077  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1078  {
1079 #if DEBUG_LOG
1080  std::cout << "Arg K (" << karg.K
1081  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1082  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1083  << __LINE__ << ", in function: " << __func__ << std::endl;
1084 
1085 #endif // DEBUG_LOG
1086  return false;
1087  }
1088  }
1089 
1091  {
1093  {
1094 #if DEBUG_LOG
1095  std::cout << "Arg N (" << karg.N
1096  << ") value is not a multiple of "
1097  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1098  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1099  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1100 
1101 #endif // DEBUG_LOG
1102  return false;
1103  }
1104  }
1105  else
1106  {
1108  {
1109 #if DEBUG_LOG
1110  std::cout << "Arg M (" << karg.M
1111  << ") value is not a multiple of "
1112  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1113  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1114  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1115 
1116 #endif // DEBUG_LOG
1117  return false;
1118  }
1119  }
1120 
1121  // check gridwise gemm pipeline
1122 #if 0
1123  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1124 
1125  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1126  {
1127  return false;
1128  }
1129 #endif
1130  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1131  return true;
1132  }
1133 
1134  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1135  {
1136  const index_t num_loop = K / KPerBlock;
1137 
1138  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1139  }
1140 
1141  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1142  {
1143  const index_t num_loop = K / KPerBlock;
1144 
1145  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1146  }
1147 
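// Hedged host-side usage sketch (illustrative only; everything outside the members shown
// in this file is an assumption):
//
//   using Gemm = /* an instantiation of this gridwise GEMM */;
//   typename Gemm::Argument arg{p_sorted_token_ids, p_sorted_expert_ids, p_max_token_id,
//                               p_a, p_b, p_ds, p_c, NumTokens, TopK, M, N, K,
//                               StrideA, StrideB, StrideDs, StrideC,
//                               p_a_scale, p_b_scale, k_batch, a_op, b_op, c_op};
//   if(Gemm::CheckValidity(arg))
//   {
//       const auto [gx, gy, gz]  = Gemm::CalculateGridSize(arg.M, arg.N);
//       const bool has_main_loop = Gemm::CalculateHasMainKBlockLoop(arg.K);
//       // dispatch on has_main_loop / CalculateKBlockLoopTailNum(arg.K), then launch e.g.
//       // kernel_moe_gemm<Gemm, true, InMemoryDataOperationEnum::Set>
//       //     <<<dim3(gx, gy, arg.KBatch), BlockSize, 0, stream>>>(arg);
//   }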
1148  template <typename CGridDesc>
1150  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1151  {
1152  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1153  c_grid_desc_m_n,
1158 
1159  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1160  }
1161 
1162  // return block_id to C matrix tile idx (m0, n0) mapping
1163  // if arch = gfx942
1164  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1165  // NPerBlock>;
1166 
1167  template <bool HasMainKBlockLoop,
1168  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1169  TailNumber TailNum = TailNumber::Odd>
1170  __device__ static void Run(const index_t* p_sorted_token_ids,
1171  const index_t* p_sorted_expert_ids,
1172  const index_t* p_max_token_id,
1173  const ADataType* p_a_grid,
1174  const BDataType* p_b_grid,
1175  DsGridPointer& p_ds_grid,
1176  CDataType* p_c_grid,
1177  const AScaleType* p_a_scale_grid,
1178  const BScaleType* p_b_scale_grid,
1179  void* p_shared,
1180  const Problem& problem,
1181  AElementwiseOperation a_element_op,
1182  BElementwiseOperation b_element_op,
1183  CElementwiseOperation c_element_op)
1184  {
1185  ignore = b_element_op;
1186  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1187  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1188  problem.MPadded,
1189  problem.K,
1190  problem.KPadded,
1191  problem.StrideA,
1192  problem.AK0);
1193  const auto b_grid_desc_bpreshuffled =
1195  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1196  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1197  problem.MPadded,
1198  problem.N,
1199  problem.NPadded,
1200  problem.StrideC);
1201 
1202  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1203  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1204  : problem.NumTokens * problem.TopK,
1205  ScaleBlockM),
1206  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1207  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1208  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1209  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1210  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1211  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1212 
1213  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1215  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1216  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1217  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1218  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1219  if(expert_block_id * MPerBlock >= max_token_id)
1220  return;
1221  const index_t expert_id =
1222  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1223  const auto block_mn = [&]() -> std::pair<int, int> {
1224  if constexpr(NSwizzle)
1225  {
1226  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1227  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1228  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1229  const index_t expert_swizzle =
1230  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1231  const index_t bid_new = blockIdx.x - prefix_block;
1232  const index_t nid = __builtin_amdgcn_readfirstlane(
1233  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1234  const index_t mid =
1235  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1236  return {nid, mid};
1237  }
1238  else
1239  {
1240  return {blockIdx.x, blockIdx.y};
1241  }
1242  }();
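// With NSwizzle, block ids belonging to one expert are re-decomposed: 8 consecutive ids
// cover 8 N-tiles of the same M-tile, successive groups of 8 step through the expert's
// ecnt M-tiles, and only then does nid advance to the next 8 N-tiles.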
1243  const index_t block_n_id = block_mn.first;
1244  const index_t block_m_id = block_mn.second;
1245  const index_t token0 =
1246  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1247 
1248  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1249  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1250  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1251  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1252  constexpr auto AKThreads = AK0Threads * AK1Threads;
1253  constexpr auto AMRepeats = MPerBlock / AMThreads;
1254  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1255 
1256  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1257  return;
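// p_sorted_token_ids packs each entry as (topk_slot << 24) | token_id: the low 24 bits
// select the token row, and the high 8 bits select the top-k slot used to expand the row
// index (token_offset * TopK + slot) when this is not the input (gate/up) GEMM.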
1259  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1260  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1261  index_t token_offset = fused_token & 0xffffff;
1262  if constexpr(!IsInputGemm)
1263  {
1264  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1265  }
1266  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1267  });
1268  const index_t expert_stride =
1269  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1270  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1271  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1272  math::integer_divide_ceil(problem.K, ScaleBlockK));
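// When IsInputGemm is true, each expert's B tensor (and its block scales) stores the gate
// and up projections back to back, hence the factor of 2; the up half is addressed later
// as p_b_grid + expert_stride / 2 in the IsInputGemm branch below.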
1273 
1274  // N0, K0, Blocksize*KPack
1275  const index_t n_block_data_idx_on_grid =
1276  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1277 
1278  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1279  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1280  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1281  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1282  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1283 
1284  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1285  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1286  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1287  p_b_scale_grid + expert_id * expert_scale_stride,
1288  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1289 
1290  // A matrix in LDS memory, dst of blockwise copy
1291  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1292 
1293  // B matrix in LDS memory, dst of blockwise copy
1294  // dummy
1295  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1296  // A matrix blockwise copy
1297  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1299  AElementwiseOperation,
1303  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1304  ABlockTransferThreadClusterArrangeOrder,
1305  ADataType,
1306  LDSTypeA,
1307  decltype(a_grid_desc_ak0_m_ak1),
1308  decltype(a_block_desc_ak0_m_ak1),
1309  ABlockTransferSrcAccessOrder,
1311  ABlockTransferSrcVectorDim,
1312  2,
1313  ABlockTransferSrcScalarPerVector,
1314  ABlockTransferDstScalarPerVector_AK1,
1315  1,
1316  1,
1317  AThreadTransferSrcResetCoordinateAfterRun,
1318  true,
1319  IndexType,
1320  1,
1321  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1322  make_multi_index(0, 0, 0),
1323  a_element_op,
1324  a_block_desc_ak0_m_ak1,
1325  make_multi_index(0, 0, 0),
1327  gather_offsets);
1328 
1329  // Thread-wise copy
1330  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1331  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1332  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1333 
1334  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1335  BDataType,
1336  BDataType,
1337  decltype(b_grid_desc_bpreshuffled),
1338  decltype(b_block_desc_bk0_n_bk1),
1341  3,
1342  BBlockTransferSrcScalarPerVector,
1343  BThreadTransferSrcResetCoordinateAfterRun,
1344  true>(b_grid_desc_bpreshuffled,
1345  make_multi_index(n_block_data_idx_on_grid,
1347  0,
1348  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1349 
1350  // LDS allocation for A and B: be careful of alignment
1351  // Cast after lds
1352  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1353  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1354 
1355  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1356  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1357 
1358  // Blockwise GEMM pipeline
1359  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1360  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1361  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1362  decltype(c_thread_buf) c_thread_buf_up;
1363 
1364  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1365  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1366  KPerBlock);
1367 
1368  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1369  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1370  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
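// Illustrative example (values assumed): with NPerBlock = 128, KPerBlock = 64 and
// 128x128 scale blocks (ScaleBlockN = ScaleBlockK = 128), ScaleSliceSizeN =
// ScaleSliceSizeK = 1, i.e. one B scale per block tile per K step, while each thread
// still carries MXdlPerWave (= ScaleSliceSizeM) A scales.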
1371 
1372  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1373  // ScaleSliceSizeK is first dimension in C scale for packed math
1374  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1376 
1377  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1378  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1379  auto a_thread_offset =
1380  get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
1381 
1382  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1384 
1385  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1387 
1388  // get each thread's offset in the scale tensor
1389  // A scale
1390  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1391 
1392  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1393  return;
1394  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
1395  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1396  const index_t fused_token =
1397  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1398  index_t token_offset = fused_token & 0xffffff;
1399  if constexpr(!IsInputGemm)
1400  {
1401  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1402  }
1403  scale_gather_offsets(m0) =
1404  token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1405  });
1406 
1407  auto a_scale_thread_copy =
1409  AScaleType,
1410  decltype(a_scale_grid_desc_am_ak),
1411  decltype(a_scale_thread_desc),
1414  1,
1415  ScaleSliceSizeK,
1416  1,
1417  false,
1418  MXdlPerWave>(
1419  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1420 
1421  auto b_scale_thread_copy =
1423  BScaleType,
1424  decltype(b_scale_grid_desc_bn_ak),
1425  decltype(b_scale_thread_desc),
1428  1,
1429  ScaleSliceSizeK,
1430  1,
1431  false>(
1432  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1433 
1434  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1435  constexpr auto a_scale_thread_slice_copy_step =
1436  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1437  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1438 
1439  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1440  if constexpr(IsInputGemm)
1441  {
1442  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1443  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1444  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1445  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1446  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1447  BDataType,
1448  BDataType,
1449  decltype(b_grid_desc_bpreshuffled),
1450  decltype(b_block_desc_bk0_n_bk1),
1453  3,
1454  BBlockTransferSrcScalarPerVector,
1455  BThreadTransferSrcResetCoordinateAfterRun,
1456  true>(b_grid_desc_bpreshuffled,
1457  make_multi_index(n_block_data_idx_on_grid,
1459  0,
1460  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1461  const BScaleType* p_b_scale_grid_up =
1462  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1463  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1464  p_b_scale_grid_up + expert_id * expert_scale_stride,
1465  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1466  auto b_scale_thread_copy_up =
1468  BScaleType,
1469  decltype(b_scale_grid_desc_bn_ak),
1470  decltype(b_scale_thread_desc),
1473  1,
1474  ScaleSliceSizeK,
1475  1,
1476  false>(
1477  b_scale_grid_desc_bn_ak,
1478  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1479 
1480  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1481  a_grid_desc_ak0_m_ak1,
1482  a_block_desc_ak0_m_ak1,
1483  a_blockwise_copy,
1484  a_grid_buf,
1485  a_block_buf,
1486  a_block_slice_copy_step,
1487 
1488  b_grid_desc_bpreshuffled,
1489  b_block_desc_bk0_n_bk1,
1490  b_blockwise_copy,
1491  b_blockwise_copy_up,
1492  b_grid_buf,
1493  b_grid_buf_up,
1494  b_block_buf,
1495  b_block_slice_copy_step,
1496 
1497  c_scale_thread_desc,
1498  c_thread_buf,
1499  c_thread_buf_up,
1500 
1501  a_scale_grid_desc_am_ak,
1502  a_scale_thread_desc,
1503  a_scale_thread_copy,
1504  a_scale_grid_buf,
1505  a_scale_thread_slice_copy_step,
1506 
1507  b_scale_grid_desc_bn_ak,
1508  b_scale_thread_desc,
1509  b_scale_thread_copy,
1510  b_scale_thread_copy_up,
1511  b_scale_grid_buf,
1512  b_scale_grid_buf_up,
1513  b_scale_thread_slice_copy_step,
1514 
1515  num_k_block_main_loop);
1516  }
1517  else
1518  {
1519  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1520  a_grid_desc_ak0_m_ak1,
1521  a_block_desc_ak0_m_ak1,
1522  a_blockwise_copy,
1523  a_grid_buf,
1524  a_block_buf,
1525  a_block_slice_copy_step,
1526 
1527  b_grid_desc_bpreshuffled,
1528  b_block_desc_bk0_n_bk1,
1529  b_blockwise_copy,
1530  b_grid_buf,
1531  b_block_buf,
1532  b_block_slice_copy_step,
1533 
1534  c_scale_thread_desc,
1535  c_thread_buf,
1536 
1537  a_scale_grid_desc_am_ak,
1538  a_scale_thread_desc,
1539  a_scale_thread_copy,
1540  a_scale_grid_buf,
1541  a_scale_thread_slice_copy_step,
1542 
1543  b_scale_grid_desc_bn_ak,
1544  b_scale_thread_desc,
1545  b_scale_thread_copy,
1546  b_scale_grid_buf,
1547  b_scale_thread_slice_copy_step,
1548 
1549  num_k_block_main_loop);
1550  }
1551 
1552  // shuffle C and write out
1553  {
1554  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1555  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1556  "wrong!");
1557 
1558  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1559 
1560  // transposed XDL
1561  // TODO: hacky, fix it!
1562  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1563  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1564 
1565  // TODO: hacky, fix it!
1566  // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1567  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1568  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1569 
1570  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1571  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1572  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1573  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1574  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1575  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1576  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1577  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1578 
1579  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1580  static_assert(M0 * M1 * M2 == MPerBlock);
1581  static_assert(N4 == 4);
1582  const index_t m1 = get_warp_local_1d_id() / NWave;
1583  const index_t m2 = threadIdx.x % get_warp_size() % M2;
1584 
1585  float topk_weight;
1586  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1587  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1588  if constexpr(MulRoutedWeight)
1589  {
1590  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1591  topk_weight = p_ds_grid[I0][m_pos];
1592  }
1593  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1594  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1595  constexpr index_t c_offset =
1596  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1597  make_tuple(m0, n0, n2 * N4 + n4));
1598  constexpr auto cidx = Number<c_offset>{};
1599  if constexpr(IsInputGemm) // gate/up (gu) fusion, elementwise
1600  {
1601  if constexpr(ActivationOperation == Activation::silu_and_mul)
1602  {
1603  float gate = c_thread_buf[cidx];
1604  float up = c_thread_buf_up[cidx];
1605  if constexpr(MulRoutedWeight)
1606  {
1607  gate = gate * topk_weight;
1608  up = up * topk_weight;
1609  }
1611  {
1612  gate *= 16;
1613  up *= 16;
1614  }
1616  c_thread_buf(cidx) = gate * up;
1617  }
1618  else if(ActivationOperation == Activation::gelu_and_mul)
1619  {
1620  float gate = c_thread_buf[cidx];
1621  float up = c_thread_buf_up[cidx];
1622  if constexpr(MulRoutedWeight)
1623  {
1624  gate = gate * topk_weight;
1625  up = up * topk_weight;
1626  }
1628  {
1629  gate *= 16;
1630  up *= 16;
1631  }
1633  c_thread_buf(cidx) = gate * up;
1634  }
1635  }
1636  else
1637  {
1638  if constexpr(MulRoutedWeight)
1639  {
1640  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1641  }
1642  }
1643  });
1644  });
1645  });
1646  });
1647 
1648  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1650 
1651  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1652  static_cast<CShuffleDataType*>(p_shared),
1653  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1654 
1655  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1656  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1657  make_tuple(
1660  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1661  M1, // M1 = MWave
1662  M2)), // M2 = MPerXdl
1665  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1666  N1, // N1 = NWave
1667  N2, // N2 * N3 * N4 = NPerXdl
1668  N3,
1669  N4))),
1671  make_tuple(
1673 
1674  // calculate origin of thread output tensor on global memory
1675  // blockwise GEMM c matrix starting index
1676  const auto c_thread_mtx_on_block =
1677  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1678 
1679  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1680  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1681 
1682  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1686  make_tuple(Sequence<0>{}));
1687 
1688  const auto m_thread_data_on_block_idx =
1689  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1690  make_multi_index(m_thread_data_on_block));
1691 
1692  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1694  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1696  make_tuple(Sequence<0>{}));
1697 
1698  const auto n_thread_data_on_block_idx =
1699  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1700  make_multi_index(n_thread_data_on_block));
1701 
1702  // shuffle: threadwise copy C from VGPR to LDS
1703  auto c_thread_copy_vgpr_to_lds =
1705  CShuffleDataType,
1706  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1707  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1709  Sequence<CShuffleMXdlPerWavePerShuffle,
1710  CShuffleNXdlPerWavePerShuffle,
1711  I1,
1712  I1,
1713  I1,
1714  N2,
1715  I1,
1716  N4>,
1718  7,
1719  1,
1721  1,
1722  true>{
1723  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1724  make_multi_index(0,
1725  0,
1726  m_thread_data_on_block_idx[I1],
1727  n_thread_data_on_block_idx[I1],
1728  m_thread_data_on_block_idx[I2],
1729  n_thread_data_on_block_idx[I2],
1730  n_thread_data_on_block_idx[I3],
1731  n_thread_data_on_block_idx[I4]),
1733 
1734  using EDataType = CDataType;
1735 
1736  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1737  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1738 
1739  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1741  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1742 
1743  const auto ds_grid_buf = generate_tuple(
1744  [&](auto i) {
1745  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1746  const DDataType* ptr_ = p_ds_grid[i];
1747  // hack logic here to support different kinds of strides. TODO: fix it.
1748  // ascale t, 1; bscale E, N, 1, move ptr to E
1749  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1750  ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1751  },
1752  Number<NumDTensor>{});
1753 
1754  // tuple of reference to C/Ds tensor descriptors
1755  const auto c_ds_desc_refs = concat_tuple_of_reference(
1756  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1757  generate_tie([&](auto i) -> const auto& // return type should be reference
1758  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1759  Number<NumDTensor>{}));
1760 
1761  // tuple of reference to C/Ds tensor descriptors
1762  const auto c_ds_buf_refs = concat_tuple_of_reference(
1763  tie(c_shuffle_block_buf),
1764  generate_tie([&](auto i) -> const auto& // return type should be reference
1765  { return ds_grid_buf[i]; },
1766  Number<NumDTensor>{}));
1767 
1768  // tuple of starting index of C/Ds blockwise copy
1769  const auto idx_c_ds_block_begin =
1772  [&](auto) {
1773  return make_multi_index(block_m_id, 0, block_n_id, 0);
1774  // return make_multi_index(block_work_idx[I0], 0,
1775  // block_work_idx[I1], 0);
1776  },
1777  Number<NumDTensor>{}));
1778 
1779  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1780  c_grid_desc_mblock_mperblock_nblock_nperblock;
1781 
1782  using CDEBlockTransferCluster =
1783  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1784  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1785  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1786  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1788  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1790  decltype(c_ds_desc_refs),
1791  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1792  CElementwiseOperation,
1793  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1794  // support arbitrary type
1795  Sequence<1,
1796  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1797  1,
1798  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1799  CDEBlockTransferCluster,
1800  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1801  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1802  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1803  3, // index_t SrcVectorDim,
1804  3, // index_t DstVectorDim,
1805  CDEShuffleBlockTransferScalarPerVectors,
1810  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1811  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1812  IndexType,
1813  1, // ScatterDim
1814  true, // OutputScatter: false, only use scatter weights
1815  scatter_weight_idx // ScatterWeightIdx: ascale
1816  >{c_ds_desc_refs,
1817  idx_c_ds_block_begin,
1818  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1819  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1820  c_element_op};
1821 
1822  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1823  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1824  // space filling curve for threadwise C in VGPR
1825  constexpr auto sfc_c_vgpr =
1828  Sequence<CShuffleMXdlPerWavePerShuffle,
1829  CShuffleNXdlPerWavePerShuffle,
1830  1,
1831  1,
1832  1,
1833  N2,
1834  1,
1835  N4>>{};
1836 
1837  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1838 
1839  // space filling curve for shuffled blockwise C/D/E
1840  constexpr auto sfc_cde_block =
1843  Sequence<1,
1844  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1845  1,
1846  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1847 
1848  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1849  constexpr auto EMThreads =
1850  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1851  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1852  constexpr auto ENThreads =
1853  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1854  static_for<0, num_access, 1>{}([&](auto access_id) {
1855  // make sure it's safe to write to LDS
1857 
1858  auto dstidx = sfc_cde_block.GetIndex(access_id);
1859  const index_t c_token_pos =
1860  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1861  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1862  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1863  index_t token_offset = fused_token & 0xffffff;
1864  if constexpr(IsInputGemm)
1865  {
1866  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1867  }
1868  scatter_offsets(m0) = token_offset * problem.N;
1869  });
1870 
1871  block_sync_lds();
1872 
1873  // each thread write its data from VGPR to LDS
1874  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1875  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1876  c_thread_buf,
1877  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1878  c_shuffle_block_buf);
1879 
1880  // make sure it's safe to read from LDS
1881  block_sync_lds();
1882 
1883  // each block copy its data from LDS to global
1884  cde_block_copy_lds_and_global.Run(
1885  c_ds_desc_refs,
1886  c_ds_buf_refs,
1887  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1888  tie(c_grid_buf),
1889  scatter_offsets);
1890 
1891  if constexpr(access_id < num_access - 1)
1892  {
1893  constexpr auto cde_lds_and_global_step =
1894  sfc_cde_block.GetForwardStep(access_id);
1895 
1896  // move on Ds
1897  static_for<0, NumDTensor, 1>{}([&](auto i) {
1898  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1899  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1900  });
1901 
1902  // move on E
1903  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1904  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1905  I0,
1906  cde_lds_and_global_step);
1907  }
1908  });
1909  }
1910  }
1911 
1912  template <bool HasMainKBlockLoop,
1913  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1914  TailNumber TailNum = TailNumber::Odd>
1915  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1916  const index_t* p_sorted_expert_ids,
1917  const index_t* p_max_token_id,
1918  const ADataType* p_a_grid,
1919  const BDataType* p_b_grid,
1920  DsGridPointer& p_ds_grid,
1921  CDataType* p_c_grid,
1922  const AScaleType* p_a_scale_grid,
1923  const BScaleType* p_b_scale_grid,
1924  void* p_shared,
1925  void* p_shared1,
1926  const Problem& problem,
1927  AElementwiseOperation a_element_op,
1928  BElementwiseOperation b_element_op,
1929  CElementwiseOperation c_element_op)
1930  {
1931  ignore = b_element_op;
1932  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1933  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1934  problem.MPadded,
1935  problem.K,
1936  problem.KPadded,
1937  problem.StrideA,
1938  problem.AK0);
1939  const auto b_grid_desc_bpreshuffled =
1940  MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
1941  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1942  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1943  problem.MPadded,
1944  problem.N,
1945  problem.NPadded,
1946  problem.StrideC);
1947 
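 // Block-scale descriptors: one scale value per (ScaleBlockM x ScaleBlockK) tile of A
 // and per (ScaleBlockN x ScaleBlockK) tile of B, stored row-major with
 // ceil(K / ScaleBlockK) scales per row.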
1948  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1949  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1950  : problem.NumTokens * problem.TopK,
1951  ScaleBlockM),
1952  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1953  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1954  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1955  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1956  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1957  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1958  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1959  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1960  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1961  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1962  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1963  if(expert_block_id * MPerBlock >= max_token_id)
1964  return;
1965  const index_t expert_id =
1966  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
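 // With NSwizzle the flat block id is remapped to (nid, mid): N-blocks are walked in
 // groups of 8 and interleaved across this expert's M-blocks (expert_swizzle),
 // presumably to improve locality of B reads within an expert.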
1967  const auto block_mn = [&]() -> std::pair<int, int> {
1968  if constexpr(NSwizzle)
1969  {
1970  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1971  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1972  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1973  const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
1974  const index_t bid_new = blockIdx.x - prefix_block;
1975  const index_t nid = __builtin_amdgcn_readfirstlane(
1976  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1977  const index_t mid =
1978  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1979  return {nid, mid};
1980  }
1981  else
1982  {
1983  return {blockIdx.x, blockIdx.y};
1984  }
1985  }();
1986  const index_t block_n_id = block_mn.first;
1987  const index_t block_m_id = block_mn.second;
1988 
1989  const index_t token0 =
1990  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1991 
1992  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1993  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1994  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1995  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1996  constexpr auto AKThreads = AK0Threads * AK1Threads;
1997  constexpr auto AMRepeats = MPerBlock / AMThreads;
1998  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1999 
2000  if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2001  token0 >= problem.NumTokens)
2002  return;
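 // Gather offsets for the A (token) rows handled by this thread: each offset is the
 // source row's starting element (token row * K) in the A grid.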
2003  StaticallyIndexedArray<IndexType, AMRepeats>
2004  gather_offsets; //= p_sorted_token_ids[token_pos];
2005  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2006  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2007  index_t token_offset = fused_token & 0xffffff;
2008  if constexpr(!IsInputGemm)
2009  {
2010  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2011  }
2012  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2013  });
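 // Per-expert strides into B and its block scales; the factor of 2 when IsInputGemm
 // accounts for the gate and up projection weights stored back to back per expert.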
2014  const index_t expert_stride =
2015  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2016  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2017  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2018  math::integer_divide_ceil(problem.K, ScaleBlockK));
2019  // N0, K0, Blocksize*KPack
2020  const index_t n_block_data_idx_on_grid =
2021  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2022 
2023  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2024  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2025  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2026  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2027  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2028 
2029  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2030  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2031  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2032  p_b_scale_grid + expert_id * expert_scale_stride,
2033  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2034 
2035  // A matrix in LDS memory, dst of blockwise copy
2036  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2037 
2038  // B matrix in LDS memory, dst of blockwise copy
2039  // dummy
2040  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2041  // A matrix blockwise copy
2042  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2044  AElementwiseOperation,
2048  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2049  ABlockTransferThreadClusterArrangeOrder,
2050  ADataType,
2051  LDSTypeA,
2052  decltype(a_grid_desc_ak0_m_ak1),
2053  decltype(a_block_desc_ak0_m_ak1),
2054  ABlockTransferSrcAccessOrder,
2056  ABlockTransferSrcVectorDim,
2057  2,
2058  ABlockTransferSrcScalarPerVector,
2059  ABlockTransferDstScalarPerVector_AK1,
2060  1,
2061  1,
2062  AThreadTransferSrcResetCoordinateAfterRun,
2063  true,
2064  IndexType,
2065  1,
2066  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2067  make_multi_index(0, 0, 0),
2068  a_element_op,
2069  a_block_desc_ak0_m_ak1,
2070  make_multi_index(0, 0, 0),
2072  gather_offsets);
2073 
2074  // Thread-wise copy
2075  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
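 // B is preshuffled, so each thread streams its B fragments straight into register
 // (VGPR) ping/pong buffers instead of staging them through LDS.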
2076  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2077  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2078  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2079  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2080  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2081 
2082  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2083  BDataType,
2084  BDataType,
2085  decltype(b_grid_desc_bpreshuffled),
2086  decltype(b_block_desc_bk0_n_bk1),
2089  3,
2090  BBlockTransferSrcScalarPerVector,
2091  BThreadTransferSrcResetCoordinateAfterRun,
2092  true>(b_grid_desc_bpreshuffled,
2093  make_multi_index(n_block_data_idx_on_grid,
2095  0,
2096  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2097 
2098  // LDS allocation for A and B: be careful of alignment
2099  // Cast after lds
2100  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2101  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2102  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2103  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2104  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
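 // Two LDS allocations (p_shared / p_shared1) give A a ping/pong pair, so the next A
 // tile's global->LDS copy can overlap with compute on the current one.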
2105 
2106  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2107  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2108 
2109  // Blockwise GEMM pipeline
2110  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2111  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2112  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2113  decltype(c_thread_buf) c_thread_buf_up;
2114 
2115  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2116  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2117  KPerBlock);
2118 
2119  // scale
2120  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2121  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2122  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
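 // Per-thread scale slice sizes; e.g. when KPerBlock <= ScaleBlockK this collapses to
 // a single K scale per main-loop iteration (ScaleSliceSizeK == 1).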
2123 
2124  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
2125  // ScaleSliceSizeK is first dimension in C scale for packed math
2126  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2128 
2129  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2130  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2131  auto a_thread_offset =
2132  get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
2133 
2134  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2136 
2137  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2139 
2140  // get each thread's offset in the scale tensor
2141  // A scale
2142  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2143 
2144  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2145  return;
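 // A-scale gather: one offset per MXdlPerWave row, pointing at that token's row of
 // K-block scales (token row * ceil(K / ScaleBlockK)).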
2146  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
2147  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2148  const index_t fused_token =
2149  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2150  index_t token_offset = fused_token & 0xffffff;
2151  if constexpr(!IsInputGemm)
2152  {
2153  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2154  }
2155  scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2156  math::integer_divide_ceil(problem.K, ScaleBlockK);
2157  });
2158 
2159  auto a_scale_thread_copy =
2161  AScaleType,
2162  decltype(a_scale_grid_desc_am_ak),
2163  decltype(a_scale_thread_desc),
2166  1,
2167  ScaleSliceSizeK,
2168  1,
2169  false,
2170  MXdlPerWave>(
2171  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2172 
2173  auto b_scale_thread_copy =
2175  BScaleType,
2176  decltype(b_scale_grid_desc_bn_ak),
2177  decltype(b_scale_thread_desc),
2180  1,
2181  ScaleSliceSizeK,
2182  1,
2183  false>(
2184  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2185 
2186  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2187  constexpr auto a_scale_thread_slice_copy_step =
2188  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2189  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2190 
2191  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
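 // Gate+up fused path: a second B copy and B-scale copy read the "up" half of the
 // expert's weights (offset by half the expert stride), and the pipeline accumulates
 // into both c_thread_buf (gate) and c_thread_buf_up (up).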
2192  if constexpr(IsInputGemm)
2193  {
2194  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2195  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2196  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2197  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2198  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2199  BDataType,
2200  BDataType,
2201  decltype(b_grid_desc_bpreshuffled),
2202  decltype(b_block_desc_bk0_n_bk1),
2205  3,
2206  BBlockTransferSrcScalarPerVector,
2207  BThreadTransferSrcResetCoordinateAfterRun,
2208  true>(b_grid_desc_bpreshuffled,
2209  make_multi_index(n_block_data_idx_on_grid,
2211  0,
2212  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2213  const BScaleType* p_b_scale_grid_up =
2214  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2215  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2216  p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2217  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2218  auto b_scale_thread_copy_up =
2220  BScaleType,
2221  decltype(b_scale_grid_desc_bn_ak),
2222  decltype(b_scale_thread_desc),
2225  1,
2226  ScaleSliceSizeK,
2227  1,
2228  false>(
2229  b_scale_grid_desc_bn_ak,
2230  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2231 
2232  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2233  a_grid_desc_ak0_m_ak1,
2234  a_block_desc_ak0_m_ak1,
2235  a_blockwise_copy,
2236  a_grid_buf,
2237  a_block_bufs,
2238  a_block_slice_copy_step,
2239  b_grid_desc_bpreshuffled,
2240  b_block_desc_bk0_n_bk1,
2241  b_blockwise_copy,
2242  b_blockwise_copy_up,
2243  b_grid_buf,
2244  b_grid_buf_up,
2245  b_block_bufs,
2246  b_block_slice_copy_step,
2247  c_scale_thread_desc,
2248  c_thread_buf,
2249  c_thread_buf_up,
2250  a_scale_grid_desc_am_ak,
2251  a_scale_thread_desc,
2252  a_scale_thread_copy,
2253  a_scale_grid_buf,
2254  a_scale_thread_slice_copy_step,
2255  b_scale_grid_desc_bn_ak,
2256  b_scale_thread_desc,
2257  b_scale_thread_copy,
2258  b_scale_thread_copy_up,
2259  b_scale_grid_buf,
2260  b_scale_grid_buf_up,
2261  b_scale_thread_slice_copy_step,
2262  num_k_block_main_loop);
2263  }
2264  else
2265  {
2266  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2267  a_grid_desc_ak0_m_ak1,
2268  a_block_desc_ak0_m_ak1,
2269  a_blockwise_copy,
2270  a_grid_buf,
2271  a_block_bufs,
2272  a_block_slice_copy_step,
2273  b_grid_desc_bpreshuffled,
2274  b_block_desc_bk0_n_bk1,
2275  b_blockwise_copy,
2276  b_grid_buf,
2277  b_block_bufs,
2278  b_block_slice_copy_step,
2279  c_scale_thread_desc,
2280  c_thread_buf,
2281  a_scale_grid_desc_am_ak,
2282  a_scale_thread_desc,
2283  a_scale_thread_copy,
2284  a_scale_grid_buf,
2285  a_scale_thread_slice_copy_step,
2286  b_scale_grid_desc_bn_ak,
2287  b_scale_thread_desc,
2288  b_scale_thread_copy,
2289  b_scale_grid_buf,
2290  b_scale_thread_slice_copy_step,
2291  num_k_block_main_loop);
2292  }
2293 
2294  // shuffle C and write out
2295  {
2296 
2297  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2298  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2299  "wrong!");
2300 
2301  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2302 
2303  // transposed XDL
2304  // TODO: hacky, fix it!
2305  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2306  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2307 
2308  // TODO: hacky, fix it!
2309  // only used to get lengths
2310  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2311  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2312 
2313  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2314  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2315  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2316  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2317  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2318  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2319  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2320  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2321 
2322  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2323  static_assert(M0 * M1 * M2 == MPerBlock);
2324  static_assert(N4 == 4);
2325  const index_t m1 = get_warp_local_1d_id() / NWave;
2326  const index_t m2 = threadIdx.x % get_warp_size() % M2;
2327 
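 // Epilogue in VGPR: optionally scale by the routed top-k weight (first D tensor),
 // and for the gate+up GEMM combine the two accumulators with the selected
 // activation (silu_and_mul / gelu_and_mul) before the shuffle and write-out.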
2328  float topk_weight;
2329  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2330  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2331  if constexpr(MulRoutedWeight)
2332  {
2333  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2334  topk_weight = p_ds_grid[I0][m_pos];
2335  }
2336  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2337  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2338  constexpr index_t c_offset =
2339  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2340  make_tuple(m0, n0, n2 * N4 + n4));
2341  constexpr auto cidx = Number<c_offset>{};
2342  if constexpr(IsInputGemm) // gate+up (gu) fusion, elementwise
2343  {
2344  if constexpr(ActivationOperation == Activation::silu_and_mul)
2345  {
2346  float gate = c_thread_buf[cidx];
2347  float up = c_thread_buf_up[cidx];
2348  if constexpr(MulRoutedWeight)
2349  {
2350  gate = gate * topk_weight;
2351  up = up * topk_weight;
2352  }
2354  {
2355  gate *= 16;
2356  up *= 16;
2357  }
2359  c_thread_buf(cidx) = gate * up;
2360  }
2361  else if(ActivationOperation == Activation::gelu_and_mul)
2362  {
2363  float gate = c_thread_buf[cidx];
2364  float up = c_thread_buf_up[cidx];
2365  if constexpr(MulRoutedWeight)
2366  {
2367  gate = gate * topk_weight;
2368  up = up * topk_weight;
2369  }
2371  {
2372  gate *= 16;
2373  up *= 16;
2374  }
2376  c_thread_buf(cidx) = gate * up;
2377  }
2378  }
2379  else
2380  {
2381  if constexpr(MulRoutedWeight)
2382  {
2383  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2384  }
2385  }
2386 
2387  });
2388  });
2389  });
2390  });
2391 
2392  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2393  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2394 
2395  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2396  static_cast<CShuffleDataType*>(p_shared),
2397  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2398 
2399  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2400  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2401  make_tuple(
2404  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2405  M1, // M1 = MWave
2406  M2)), // M2 = MPerXdl
2409  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2410  N1, // N1 = NWave
2411  N2, // N2 * N3 * N4 = NPerXdl
2412  N3,
2413  N4))),
2415  make_tuple(
2417 
2418  // calculate origin of thread output tensor on global memory
2419  // blockwise GEMM c matrix starting index
2420  const auto c_thread_mtx_on_block =
2421  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2422 
2423  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2424  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2425 
2426  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2430  make_tuple(Sequence<0>{}));
2431 
2432  const auto m_thread_data_on_block_idx =
2433  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2434  make_multi_index(m_thread_data_on_block));
2435 
2436  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2438  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2440  make_tuple(Sequence<0>{}));
2441 
2442  const auto n_thread_data_on_block_idx =
2443  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2444  make_multi_index(n_thread_data_on_block));
2445 
2446  // shuffle: threadwise copy C from VGPR to LDS
2447  auto c_thread_copy_vgpr_to_lds =
2449  CShuffleDataType,
2450  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2451  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2453  Sequence<CShuffleMXdlPerWavePerShuffle,
2454  CShuffleNXdlPerWavePerShuffle,
2455  I1,
2456  I1,
2457  I1,
2458  N2,
2459  I1,
2460  N4>,
2462  7,
2463  1,
2465  1,
2466  true>{
2467  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2468  make_multi_index(0,
2469  0,
2470  m_thread_data_on_block_idx[I1],
2471  n_thread_data_on_block_idx[I1],
2472  m_thread_data_on_block_idx[I2],
2473  n_thread_data_on_block_idx[I2],
2474  n_thread_data_on_block_idx[I3],
2475  n_thread_data_on_block_idx[I4]),
2477 
2478  using EDataType = CDataType;
2479 
2480  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2481  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2482 
2483  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2484  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2485  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2486 
2487  const auto ds_grid_buf = generate_tuple(
2488  [&](auto i) {
2489  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2490  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2491  },
2492  Number<NumDTensor>{});
2493 
2494  // tuple of reference to C/Ds tensor descriptors
2495  const auto c_ds_desc_refs = concat_tuple_of_reference(
2496  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2497  generate_tie([&](auto i) -> const auto& // return type should be reference
2498  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2499  Number<NumDTensor>{}));
2500 
2501  // tuple of reference to C/Ds tensor descriptors
2502  const auto c_ds_buf_refs = concat_tuple_of_reference(
2503  tie(c_shuffle_block_buf),
2504  generate_tie([&](auto i) -> const auto& // return type should be reference
2505  { return ds_grid_buf[i]; },
2506  Number<NumDTensor>{}));
2507 
2508  // tuple of starting index of C/Ds blockwise copy
2509  const auto idx_c_ds_block_begin =
2512  [&](auto) {
2513  return make_multi_index(block_m_id, 0, block_n_id, 0);
2514  // return make_multi_index(block_work_idx[I0], 0,
2515  // block_work_idx[I1], 0);
2516  },
2517  Number<NumDTensor>{}));
2518 
2519  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2520  c_grid_desc_mblock_mperblock_nblock_nperblock;
2521 
2522  using CDEBlockTransferCluster =
2523  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2524  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2525  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
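 // Fused LDS->global store for C and the D tensors: rows (dim 1) are scattered to
 // their destination token rows via the per-row scatter_offsets computed below.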
2526  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2528  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2530  decltype(c_ds_desc_refs),
2531  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2532  CElementwiseOperation,
2533  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2534  // support arbitrary type
2535  Sequence<1,
2536  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2537  1,
2538  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2539  CDEBlockTransferCluster,
2540  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2541  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2542  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2543  3, // index_t SrcVectorDim,
2544  3, // index_t DstVectorDim,
2545  CDEShuffleBlockTransferScalarPerVectors,
2550  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2551  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2552  IndexType,
2553  1, // ScatterDim
2554  true, // OutputScatter: false, only use scatter weights
2555  scatter_weight_idx // ScatterWeightIdx: ascale
2556  >{c_ds_desc_refs,
2557  idx_c_ds_block_begin,
2558  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2559  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2560  c_element_op};
2561 
2562  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2563  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2564  // space filling curve for threadwise C in VGPR
2565  constexpr auto sfc_c_vgpr =
2568  Sequence<CShuffleMXdlPerWavePerShuffle,
2569  CShuffleNXdlPerWavePerShuffle,
2570  1,
2571  1,
2572  1,
2573  N2,
2574  1,
2575  N4>>{};
2576 
2577  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2578 
2579  // space filling curve for shuffled blockwise C/D/E
2580  constexpr auto sfc_cde_block =
2583  Sequence<1,
2584  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2585  1,
2586  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2587 
2588  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2589  constexpr auto EMThreads =
2590  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2591  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2592  constexpr auto ENThreads =
2593  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
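 // For each access: recompute the row scatter offsets for this tile, shuffle the
 // VGPR accumulators into LDS, then scatter the shuffled tile from LDS to global.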
2594  static_for<0, num_access, 1>{}([&](auto access_id) {
2595  // make sure it's safe to write to LDS
2596  StaticallyIndexedArray<IndexType, EMRepeats>
2597  scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2598 
2599  auto dstidx = sfc_cde_block.GetIndex(access_id);
2600  const index_t c_token_pos =
2601  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2602  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2603  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2604  index_t token_offset = fused_token & 0xffffff;
2605  if constexpr(IsInputGemm)
2606  {
2607  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2608  }
2609  scatter_offsets(m0) = token_offset * problem.N;
2610  });
2611 
2612  block_sync_lds();
2613 
2614  // each thread writes its data from VGPR to LDS
2615  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2616  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2617  c_thread_buf,
2618  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2619  c_shuffle_block_buf);
2620 
2621  // make sure it's safe to read from LDS
2622  block_sync_lds();
2623 
2624  // each block copies its data from LDS to global
2625  cde_block_copy_lds_and_global.Run(
2626  c_ds_desc_refs,
2627  c_ds_buf_refs,
2628  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2629  tie(c_grid_buf),
2630  scatter_offsets);
2631 
2632  if constexpr(access_id < num_access - 1)
2633  {
2634  constexpr auto cde_lds_and_global_step =
2635  sfc_cde_block.GetForwardStep(access_id);
2636 
2637  // move on Ds
2638  static_for<0, NumDTensor, 1>{}([&](auto i) {
2639  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2640  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2641  });
2642 
2643  // move on E
2644  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2645  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2646  I0,
2647  cde_lds_and_global_step);
2648  }
2649  });
2650  }
2651  }
2652 };
2653 
2654 } // namespace ck