Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp Source File

gridwise_moe_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put a single-LDS-buffer and a double-LDS-buffer
24 // pipeline in the same kernel function. Blockers:
25 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data
26 // accesses operate on two LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
28 // not use the same LDS buffer when __shared__ is declared inside blkgemmpipe.
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
49  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50 
51  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54  karg.p_sorted_token_ids,
55  karg.p_sorted_expert_ids,
56  karg.p_max_token_id,
57  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
59  karg.p_ds_grid,
60  karg.p_c_grid,
61  p_shared,
62  karg,
63  karg.a_element_op,
64  karg.b_element_op,
65  karg.c_element_op);
66 #else
67  ignore = karg;
68 #endif // end of if (defined(__gfx9__))
69 }
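// Illustrative host-side launch sketch (not part of this header; GridwiseGemmInstance,
// block_size, k_batch, stream and arg are hypothetical placeholders). The x/y grid
// extents come from GridwiseGemm::CalculateGridSize() further below, and the kernel
// derives its split-K slice from blockIdx.z via SplitKBatchOffset:
//
//   const auto [gx, gy, gz] = GridwiseGemmInstance::CalculateGridSize(M, N);
//   kernel_moe_gemm<GridwiseGemmInstance,
//                   /*HasMainKBlockLoop=*/true,
//                   InMemoryDataOperationEnum::Set>
//       <<<dim3(gx, gy, k_batch), dim3(block_size), 0, stream>>>(arg);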
70 
71 template <typename GridwiseGemm,
72  bool HasMainKBlockLoop,
73  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
74  index_t MinimumOccupancy = 1,
75  TailNumber TailNum = TailNumber::Even>
76 __global__ void
77 #if CK_USE_LAUNCH_BOUNDS
78  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
79 #endif
80  // __attribute__((amdgpu_waves_per_eu(1, 1)))
81  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
82 {
83 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
84  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
85  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
86 
87  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
88 
89  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
90  karg.p_sorted_token_ids,
91  karg.p_sorted_expert_ids,
92  karg.p_max_token_id,
93  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
94  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
95  karg.p_ds_grid,
96  karg.p_c_grid,
97  p_shared,
98  p_shared1,
99  karg,
100  karg.a_element_op,
101  karg.b_element_op,
102  karg.c_element_op);
103 #else
104  ignore = karg;
105 #endif // end of if (defined(__gfx9__))
106 }
107 
108 template <typename ALayout,
109  typename BLayout,
110  typename DsLayout,
111  typename CLayout,
112  typename ADataType,
113  typename BDataType,
114  typename AccDataType,
115  typename CShuffleDataType,
116  typename DsDataType,
117  typename CDataType,
118  typename AElementwiseOperation,
119  typename BElementwiseOperation,
120  typename CElementwiseOperation,
122  index_t BlockSize,
123  index_t MPerBlock,
124  index_t NPerBlock,
125  index_t KPerBlock,
126  index_t AK1Value,
127  index_t BK1Value,
128  index_t MPerXdl,
129  index_t NPerXdl,
130  index_t MXdlPerWave,
131  index_t NXdlPerWave,
132  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
133  typename ABlockTransferThreadClusterArrangeOrder,
134  typename ABlockTransferSrcAccessOrder,
135  index_t ABlockTransferSrcVectorDim,
136  index_t ABlockTransferSrcScalarPerVector,
137  index_t ABlockTransferDstScalarPerVector_AK1,
138  bool AThreadTransferSrcResetCoordinateAfterRun,
139  index_t ABlockLdsExtraM,
140  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141  typename BBlockTransferThreadClusterArrangeOrder,
142  typename BBlockTransferSrcAccessOrder,
143  index_t BBlockTransferSrcVectorDim,
144  index_t BBlockTransferSrcScalarPerVector,
145  index_t BBlockTransferDstScalarPerVector_BK1,
146  bool BThreadTransferSrcResetCoordinateAfterRun,
147  index_t BBlockLdsExtraN,
148  index_t CShuffleMXdlPerWavePerShuffle,
149  index_t CShuffleNXdlPerWavePerShuffle,
150  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151  typename CDEShuffleBlockTransferScalarPerVectors,
154  index_t ActivationOperation = 0,
155  bool NSwizzle = false,
156  bool IsInputGemm = true,
157  bool MulRoutedWeight = true,
158  bool PerTokenQuant = false,
159  typename IndexType = index_t,
160  typename ComputeTypeA = CDataType,
161  typename ComputeTypeB = ComputeTypeA,
162  typename LDSTypeA = ADataType,
163  typename LDSTypeB = BDataType>
165 {
166  static constexpr auto I0 = Number<0>{};
167  static constexpr auto I1 = Number<1>{};
168  static constexpr auto I2 = Number<2>{};
169  static constexpr auto I3 = Number<3>{};
170  static constexpr auto I4 = Number<4>{};
171  static constexpr auto I5 = Number<5>{};
172  static constexpr auto I6 = Number<6>{};
173  static constexpr auto I7 = Number<7>{};
174 
 175  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
 176  CDEShuffleBlockTransferScalarPerVectors{}[I0];
177  // K1 should be Number<...>
178  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
179  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
180  static constexpr auto AK1Number = Number<AK1Value>{};
181  static constexpr auto BK1Number = Number<BK1Value>{};
182  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
183 
184  static constexpr index_t NumDTensor = DsDataType::Size();
185 
187  static constexpr index_t KPack =
189  static constexpr index_t KLane =
191 
192  static constexpr index_t KGroup = []() {
 194  // On gfx950 there is an mfma that requires 32 f8 elements as input,
 195  // split into 2 groups of 16 f8 elements.
 196  // The 2 groups are not contiguous in the B preshuffled layout,
 197  // and we do not want them to be contiguous in the B preshuffled layout,
 198  // because a memory instruction can only read 16 f8 elements at a time.
199  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
200  else
201  return 1;
202  }();
203 
204  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
205 
206  static constexpr index_t NLane = NPerXdl;
207  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
208  // static constexpr index_t NumTokens = 1;
209  static constexpr index_t SortedTileSize = MPerBlock;
210 
211  static constexpr auto MakeDsGridPointer()
212  {
213  return generate_tuple(
214  [&](auto i) {
215  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
216 
217  return static_cast<const DDataType*>(nullptr);
218  },
220  }
221 
222  using DsGridPointer = decltype(MakeDsGridPointer());
223 
225 
226  static constexpr index_t APackedSize = []() {
228  return 2;
229  else
230  return 1;
231  }();
232 
233  static constexpr index_t BPackedSize = []() {
235  return 2;
236  else
237  return 1;
238  }();
239 
240  __host__ static auto CalculateGridSize(index_t M, index_t N)
241  {
242  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
243  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
244  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245  const index_t gridy = NSwizzle ? 1 : mblock;
246 
247  return std::make_tuple(gridx, gridy, 1);
248  }
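 // Worked example (illustrative tile sizes only): with MPerBlock = 128, NPerBlock = 128,
 // M = 512 and N = 1024, mblock = 4 and nblock = 8, so the launch grid is (8, 4, 1)
 // without NSwizzle and (32, 1, 1) with NSwizzle, where the swizzled blockIdx.x is later
 // decomposed back into (m, n) tile indices inside Run().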
249 
250  __host__ __device__ static auto CalculateMPadded(index_t M)
251  {
252  return math::integer_least_multiple(M, MPerBlock);
253  }
254 
255  __host__ __device__ static auto CalculateNPadded(index_t N)
256  {
257  return math::integer_least_multiple(N, NPerBlock);
258  }
259 
260  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
261  {
262  return math::integer_divide_ceil(N, NLane);
263  }
264  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
265  {
267  }
268 
269  __host__ __device__ static auto CalculateKPadded(index_t K)
270  {
271  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
272  }
273 
274  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
275  {
276  auto K_t = K_Batch * KPerBlock;
277  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
278  }
279 
280  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
281  {
282  auto K_t = K_Batch * KPerBlock;
283  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
284  }
285 
286  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
287  {
288  auto K_t = K_Batch * KPerBlock;
289  return (K + K_t - 1) / K_t * KPerBlock;
290  }
291 
292  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
293  {
294  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
295  auto K_t = K_Batch * KReadVec;
296  return (K + K_t - 1) / K_t * KReadVec;
297  }
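 // Worked example (illustrative values only): assuming KPerBlock = 64, AK1Value = BK1Value = 8,
 // K = 1000 and K_Batch = 2, then CalculateKPadded(K) = 1024, CalculateKPadded(K, 2) = 512,
 // CalculateAK0Padded(K, 2) = 64, and CalculateKRead(K, 2) = ceil(1000 / 16) * 8 = 504, i.e.
 // the first split-K slice reads 504 of the 1000 K elements and the last slice covers the
 // remaining 496 (see SplitKBatchOffset below).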
298 
299  __host__ __device__ static auto CalculateMBlock(index_t M)
300  {
301  return math::integer_divide_ceil(M, MPerBlock);
302  }
303 
304  __host__ __device__ static auto CalculateNBlock(index_t N)
305  {
306  return math::integer_divide_ceil(N, NPerBlock);
307  }
308 
309  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
310  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
311  {
312  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
313  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
314 
316  TileDesc_K0_MN_K1{},
322  }
323 
324  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
325  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
326  {
327  const auto a_grid_desc_mraw_kraw = [&]() {
328  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
329  {
330  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
331  }
332  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
333  {
334  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
335  }
336  }();
337 
339 
340  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
341  GemmSpec == GemmSpecialization::MNKPadding)
342  {
343  // pad both M and K
344  const auto a_grid_desc_m_k =
345  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
347  make_right_pad_transform(K, KPad - K)),
350 
351  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
352  a_grid_desc_m_k,
357 
358  return a_grid_desc_ak0_m_ak1;
359  }
360  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
361  GemmSpec == GemmSpecialization::MNPadding)
362  {
363  // pad M, but not K
364  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
365  a_grid_desc_mraw_kraw,
367  make_right_pad_transform(M, MPad - M)),
370 
371  return a_grid_desc_ak0_m_ak1;
372  }
373  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
374  GemmSpec == GemmSpecialization::NKPadding)
375  {
376  // pad K, but not M
377  const auto a_grid_desc_m_k = transform_tensor_descriptor(
378  a_grid_desc_mraw_kraw,
382 
383  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
384  a_grid_desc_m_k,
389 
390  return a_grid_desc_ak0_m_ak1;
391  }
392  else
393  {
394  // not pad M or K
395  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
396  a_grid_desc_mraw_kraw,
401 
402  return a_grid_desc_ak0_m_ak1;
403  }
404  }
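 // Illustrative mapping (assumes the usual unmerge of K into (AK0, AK1) with AK1 innermost;
 // several transform expressions are elided in this listing): for the non-padded, row-major
 // case with strides (StrideA, 1), element (ak0, m, ak1) of the resulting (AK0, M, AK1)
 // descriptor addresses the raw element
 //
 //   a_offset = m * StrideA + (ak0 * AK1Value + ak1);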
405 
406  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
407  {
408  constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
410  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
411  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
412  }
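 // The preshuffled B descriptor above is a plain packed 4-D layout, so with
 // NkSwizzleNumber = WarpSize * KPack / KGroup, element (n0, w, k0, x) sits at
 //
 //   b_offset = ((n0 * NWave + w) * K0 + k0) * NkSwizzleNumber + x;
 //
 // which is exactly the stride tuple passed to make_naive_tensor_descriptor.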
413 
414  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
415  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
416  {
417  const auto b_grid_desc_nraw_kraw = [&]() {
419  {
420  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
421  }
423  {
424  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
425  }
426  }();
427 
429 
430  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
431  GemmSpec != GemmSpecialization::Default),
432  "pk_i4_t does not support padding");
433 
434  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
435  GemmSpec == GemmSpecialization::MNKPadding)
436  {
437  // pad both N and K
438  const auto b_grid_desc_n_k =
439  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
441  make_right_pad_transform(K, KPad - K)),
444 
445  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
446  b_grid_desc_n_k,
451 
452  return b_grid_desc_bk0_n_bk1;
453  }
454  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
455  GemmSpec == GemmSpecialization::MNPadding)
456  {
457  // pad N, but not K
458  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
459  b_grid_desc_nraw_kraw,
461  make_right_pad_transform(N, NPad - N)),
464 
465  return b_grid_desc_bk0_n_bk1;
466  }
467  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
468  GemmSpec == GemmSpecialization::MKPadding)
469  {
470  // pad K, but not N
471  const auto b_grid_desc_n_k = transform_tensor_descriptor(
472  b_grid_desc_nraw_kraw,
476 
477  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
478  b_grid_desc_n_k,
483 
484  return b_grid_desc_bk0_n_bk1;
485  }
486  else
487  {
488  // not pad N or K
489  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
490  b_grid_desc_nraw_kraw,
495 
496  return b_grid_desc_bk0_n_bk1;
497  }
498  }
499 
500  template <typename ABlockDesc_AK0_M_AK1>
501  __host__ __device__ static constexpr auto
502  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
503  {
504  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
505 
506  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
507  }
508 
509  template <typename BBlockDesc_BK0_N_BK1>
510  __host__ __device__ static constexpr auto
511  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
512  {
513  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
514  }
515 
516  template <typename ELayout>
517  __host__ __device__ static auto MakeCGridDescriptor_M_N(
518  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
519  {
520  const auto c_grid_desc_mraw_nraw = [&]() {
522  {
523  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
524  }
526  {
527  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
528  }
529  }();
530 
531  // pad M and N
532  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
534  make_right_pad_transform(N, NPad - N)),
537  }
538 
539  template <typename DLayout>
540  __host__ __device__ static auto
542  {
543  const auto c_grid_desc_mraw_nraw = [&]() {
545  {
546  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
547  }
549  {
550  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
551  }
552  }();
553 
554  // pad M and N
555  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
557  make_right_pad_transform(N, NPad - N)),
560  }
561 
562  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
563  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
564  {
565  return generate_tuple(
566  [&](auto i) {
567  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
568  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
569  },
571  }
572 
573  template <typename DsGridDesc>
575  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
576  {
577  return generate_tuple(
578  [&](auto i) {
580  ds_grid_desc_m_n[i], MBlock, NBlock);
581  },
583  }
584 
585  struct Problem
586  {
587  __host__ __device__ Problem(index_t NumTokens_,
588  index_t TopK_,
589  index_t M_,
590  index_t N_,
591  index_t K_,
592  index_t StrideA_,
593  index_t StrideB_,
594  std::array<index_t, NumDTensor> StrideDs_,
595  index_t StrideC_,
596  index_t KBatch_)
597  : NumTokens{NumTokens_},
598  TopK{TopK_},
599  M{M_},
600  N{N_},
601  K{K_},
602  StrideA{StrideA_},
603  StrideB{StrideB_},
604  StrideDs{StrideDs_},
605  StrideC{StrideC_},
606  KBatch{KBatch_},
609  KRead{CalculateKRead(K_, KBatch_)},
610  KPadded{CalculateKPadded(K_, KBatch_)},
611  AK0{CalculateAK0Padded(K_, KBatch_)},
612  BK0{CalculateBK0Padded(K_, KBatch_)},
613  MBlock{CalculateMBlock(M_)},
614  NBlock{CalculateNBlock(N_)},
617  {
618  }
619 
620  __host__ void Print() const
621  {
622  std::cout << "problem {"
623  << "NumTokens:" << NumTokens << ", "
624  << "TopK:" << TopK << ", "
625  << "M:" << M << ", "
626  << "N:" << N << ", "
627  << "K:" << K << ", "
628  << "SA:" << StrideA << ", "
629  << "SB:" << StrideB << ", "
630  << "SC:" << StrideC << ", "
631  << "MP:" << MPadded << ", "
632  << "NP:" << NPadded << ", "
633  << "KRead:" << KRead << ", "
634  << "KP:" << KPadded << ", "
635  << "AK0:" << AK0 << ", "
636  << "BK0:" << BK0 << ", "
637  << "MBlock: " << MBlock << ", "
638  << "NBlock: " << NBlock << "}" << std::endl;
639  }
640 
648  std::array<index_t, NumDTensor> StrideDs;
659  // FOR PRESHUFFLE ONLY
662  };
663 
664  // Argument
666  {
667  __host__ Argument(const index_t* p_sorted_token_ids_,
668  const index_t* p_sorted_expert_ids_,
669  const index_t* p_max_token_id_,
670  const ADataType* p_a_grid_,
671  const BDataType* p_b_grid_,
672  std::array<const void*, NumDTensor> p_ds_grid_,
673  CDataType* p_c_grid_,
674  index_t NumTokens_,
675  index_t TopK_,
676  index_t M_,
677  index_t N_,
678  index_t K_,
679  index_t StrideA_,
680  index_t StrideB_,
681  std::array<index_t, NumDTensor> StrideDs_,
682  index_t StrideC_,
683  index_t k_batch_,
684  AElementwiseOperation a_element_op_,
685  BElementwiseOperation b_element_op_,
686  CElementwiseOperation c_element_op_)
687  : Problem{NumTokens_,
688  TopK_,
689  M_,
690  N_,
691  K_,
692  StrideA_,
693  StrideB_,
694  StrideDs_,
695  StrideC_,
696  k_batch_},
697  p_sorted_token_ids{p_sorted_token_ids_},
698  p_sorted_expert_ids{p_sorted_expert_ids_},
699  p_max_token_id{p_max_token_id_},
700  p_a_grid{p_a_grid_},
701  p_b_grid{p_b_grid_},
702  p_ds_grid{},
703  p_c_grid{p_c_grid_},
704  a_element_op{a_element_op_},
705  b_element_op{b_element_op_},
706  c_element_op{c_element_op_}
707  {
708 
709  // populate pointer, desc for Ds
710  static_for<0, NumDTensor, 1>{}([&](auto i) {
711  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
712 
713  // D pointer
714  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
715  });
716  }
717 
721  const ADataType* p_a_grid;
722  const BDataType* p_b_grid;
724  CDataType* p_c_grid;
725 
726  const AElementwiseOperation a_element_op;
727  const BElementwiseOperation b_element_op;
728  const CElementwiseOperation c_element_op;
729  };
730 
732  {
733  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
734  {
735  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
736  {
737  a_k_split_offset = k_id * karg.KRead / APackedSize;
738  }
739  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
740  {
741  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
742  }
743 
744  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
745  {
746  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
747  }
748  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
749  {
750  // KPack * NLane * KLane * K0 * N0
751  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
752  }
753 
754  if(k_id < karg.KBatch - 1)
755  {
756  karg.K = karg.KRead;
757  }
758  else
759  {
760  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
761  }
762  }
763 
766  };
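 // Worked example (continuing the illustrative numbers from CalculateKRead above, with
 // APackedSize = 1 and row-major A): for K = 1000, KBatch = 2 and KRead = 504, the workgroup
 // with k_id = 0 gets a_k_split_offset = 0 and computes over karg.K = 504, while k_id = 1
 // gets a_k_split_offset = 504 and computes over the remaining karg.K = 1000 - 504 = 496.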
767 
768  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
769  {
770  // A matrix in LDS memory, dst of blockwise copy
771  if constexpr(ABlockLdsExtraM)
772  {
776  }
 777  // The xor tensor transformation requests more unnecessary VGPR usage, which can cause
 778  // register spills in some cases.
780  {
781  constexpr auto a_lds_block_desc =
784 
785  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
786  a_lds_block_desc,
792 
793  return a_lds_block_desc_permuted;
794  }
795  else // ColumnMajor A
796  {
 797  // The kfold and mpair dimensions are not always required.
 798  // More dimensions in merge_transform increase the difficulty of generating immarg offsets
 799  // for the compiler.
800  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
801  constexpr auto M1 = MPerBlock / M0;
802 
803  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
804  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
805  constexpr auto KThreadRead = 64 / MPerXdl;
806  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
807 
808  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
809  ? 1
810  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
811  constexpr auto KThreadReadPerm =
812  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
813  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
814  : KThreadRead;
815 
816  // 1<=mpair<=n0
817  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
818  ? 1
819  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
820  ? M0
821  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
822 
823  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
827  Number<kfold * M0 / mpair>{},
828  Number<mpair>{},
829  AK1Number));
830 
831  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
832  a_lds_block_desc,
833  make_tuple(
837  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
840  make_tuple(
842  make_tuple(
844 
845  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
846  a_lds_block_desc_permuted,
847  make_tuple(
855  Sequence<1>{},
856  Sequence<2>{},
857  Sequence<3>{},
858  Sequence<4>{},
859  Sequence<5>{}),
861  Sequence<2>{},
862  Sequence<0, 3>{},
863  Sequence<4, 5>{},
864  Sequence<6>{},
865  Sequence<7>{}));
866 
867  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
868  a_lds_block_desc_unmerged,
871  Number<KThreadWrite / kfold / KThreadReadPerm>{},
872  Number<kfold>{},
879 
880  return a_lds_block_desc_ak0_m_ak1;
881  }
882  }
883 
884  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
885  {
886  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
889  }
890 
892  {
893  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
894 
895  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
897  make_tuple(I1,
899  I1,
901 
902  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
903  }
904 
907  BlkGemmPipelineVer,
908  BlkGemmPipeSched,
909  BlockSize,
910  ADataType,
911  BDataType,
912  ComputeTypeA,
913  AccDataType,
920  ABlockTransferSrcScalarPerVector,
921  BBlockTransferSrcScalarPerVector,
922  MPerBlock,
923  NPerBlock,
924  KPerBlock,
925  MPerXdl,
926  NPerXdl,
927  MXdlPerWave,
928  NXdlPerWave,
929  KPack,
930  IsInputGemm>())>;
931 
932  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
933  {
934  // LDS allocation for A and B: be careful of alignment
935  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
936  // lds max alignment
937  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
938 
939  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
940  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
941 
942  // LDS allocation for C shuffle in LDS
943  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
945 
946  constexpr auto c_block_size =
947  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
948 
949  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
950  c_block_size * sizeof(CShuffleDataType));
951  }
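 // Note: only the A tile and the C-shuffle tile are staged in LDS (the preshuffled B tile is
 // read directly into VGPRs, see b_block_buf in Run()), so the two allocations reuse the same
 // space and the requirement is their maximum rather than their sum. As a rough illustration
 // (assuming MPerBlock = 128, KPerBlock = 64, LDSTypeA = half_t, APackedSize = 1 and ignoring
 // padding), the A tile alone needs about 128 * 64 * 2 B = 16 KiB.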
952 
 953  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
954  __host__ static constexpr bool CheckValidity(const Argument& karg)
955  {
956  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
957  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
958  "Invalid tuning param!");
959 
960  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
965  {
966  if(!(karg.M % MPerBlock == 0))
967  {
968 #if DEBUG_LOG
969  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
970  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
971  << std::endl;
972 
973 #endif // DEBUG_LOG
974  return false;
975  }
976  }
977 
978  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
983  {
984  if(!(karg.N % NPerBlock == 0))
985  {
986 #if DEBUG_LOG
987  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
988  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
989  << std::endl;
990 
991 #endif // DEBUG_LOG
992  return false;
993  }
994  }
995 
996  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1000  {
1001 
1002  auto K_t = karg.KBatch * KPerBlock;
1003  if(!(karg.K % K_t == 0))
1004  {
1005 #if DEBUG_LOG
1006  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1007  << karg.K << " " << __FILE__ << ":" << __LINE__
1008  << ", in function: " << __func__ << std::endl;
1009 
1010 #endif // DEBUG_LOG
1011  return false;
1012  }
1013  }
1014  else
1015  {
1016  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1017  auto K_t = karg.KBatch * KReadVec;
1018  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1019  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1020  {
1021  return false;
1022  }
1023  }
1024 
1026  {
1027  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1028  {
1029 #if DEBUG_LOG
1030  std::cout << "Arg K (" << karg.K
1031  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1032  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1033  << __LINE__ << ", in function: " << __func__ << std::endl;
1034 
1035 #endif // DEBUG_LOG
1036  return false;
1037  }
1038  }
1039  else
1040  {
1041  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1042  {
1043 #if DEBUG_LOG
1044  std::cout << "Arg M (" << karg.M
1045  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1046  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1047  << __LINE__ << ", in function: " << __func__ << std::endl;
1048 
1049 #endif // DEBUG_LOG
1050  return false;
1051  }
1052  }
1053 
1055  {
1056  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1057  {
1058 #if DEBUG_LOG
1059  std::cout << "Arg N (" << karg.N
1060  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1061  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1062  << __LINE__ << ", in function: " << __func__ << std::endl;
1063 
1064 #endif // DEBUG_LOG
1065  return false;
1066  }
1067  }
1068  else
1069  {
1070  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1071  {
1072 #if DEBUG_LOG
1073  std::cout << "Arg K (" << karg.K
1074  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1075  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1076  << __LINE__ << ", in function: " << __func__ << std::endl;
1077 
1078 #endif // DEBUG_LOG
1079  return false;
1080  }
1081  }
1082 
1084  {
1086  {
1087 #if DEBUG_LOG
1088  std::cout << "Arg N (" << karg.N
1089  << ") value is not a multiple of "
1090  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1091  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1092  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1093 
1094 #endif // DEBUG_LOG
1095  return false;
1096  }
1097  }
1098  else
1099  {
1101  {
1102 #if DEBUG_LOG
1103  std::cout << "Arg M (" << karg.M
1104  << ") value is not a multiple of "
1105  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1106  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1107  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1108 
1109 #endif // DEBUG_LOG
1110  return false;
1111  }
1112  }
1113 
1114  // check gridwise gemm pipeline
1115 #if 1
1116  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1117 
1118  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1119  {
1120  return false;
1121  }
1122 #endif
1123  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1124  return true;
1125  }
1126 
1127  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1128  {
1129  const index_t num_loop = K / KPerBlock;
1130 
1131  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1132  }
1133 
1134  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1135  {
1136  const index_t num_loop = K / KPerBlock;
1137 
1138  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1139  }
1140 
1141  template <typename CGridDesc>
1143  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1144  {
1145  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1146  c_grid_desc_m_n,
1151 
1152  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1153  }
1154 
1155  // return block_id to C matrix tile idx (m0, n0) mapping
1156  // if arch = gfx942
1157  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1158  // NPerBlock>;
1159 
1160  template <bool HasMainKBlockLoop,
1161  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1162  TailNumber TailNum = TailNumber::Odd>
1163  __device__ static void Run(const index_t* p_sorted_token_ids,
1164  const index_t* p_sorted_expert_ids,
1165  const index_t* p_max_token_id,
1166  const ADataType* p_a_grid,
1167  const BDataType* p_b_grid,
1168  DsGridPointer& p_ds_grid,
1169  CDataType* p_c_grid,
1170  void* p_shared,
1171  const Problem& problem,
1172  AElementwiseOperation a_element_op,
1173  BElementwiseOperation b_element_op,
1174  CElementwiseOperation c_element_op)
1175  {
1176  ignore = b_element_op;
1177  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1178  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1179  problem.MPadded,
1180  problem.K,
1181  problem.KPadded,
1182  problem.StrideA,
1183  problem.AK0);
1184  const auto b_grid_desc_bpreshuffled =
1186  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1187  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1188  problem.MPadded,
1189  problem.N,
1190  problem.NPadded,
1191  problem.StrideC);
1192  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1194  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1195  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1196  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1197  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1198  if(expert_block_id * MPerBlock >= max_token_id)
1199  return;
1200  const index_t expert_id =
1201  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1202  const auto block_mn = [&]() -> std::pair<int, int> {
1203  if constexpr(NSwizzle)
1204  {
1205  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1206  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1207  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1208  const index_t expert_swizzle =
1209  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1210  const index_t bid_new = blockIdx.x - prefix_block;
1211  const index_t nid = __builtin_amdgcn_readfirstlane(
1212  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1213  const index_t mid =
1214  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1215  return {nid, mid};
1216  }
1217  else
1218  {
1219  return {blockIdx.x, blockIdx.y};
1220  }
1221  }();
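 // Illustrative trace of the NSwizzle mapping (hypothetical values: ecnt_prefix = 3,
 // expert_swizzle = 2): bid_new = 0..7 yields (nid, mid) = (0..7, 3), bid_new = 8..15 yields
 // (0..7, 4), and bid_new = 16..23 yields (8..15, 3); i.e. blocks sweep groups of 8 N tiles
 // across all M tiles of the current expert before moving on, which keeps reuse of the same
 // B tiles close together in launch order.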
1222 
1223  const index_t block_n_id = block_mn.first;
1224  const index_t block_m_id = block_mn.second;
1225  const index_t token0 =
1226  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1227 
1228  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1229  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1230  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1231  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1232  constexpr auto AKThreads = AK0Threads * AK1Threads;
1233  constexpr auto AMRepeats = MPerBlock / AMThreads;
1234  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1235 
1236  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1237  return;
1239  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1240  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1241  index_t token_offset = fused_token & 0xffffff;
1242  if constexpr(!IsInputGemm)
1243  {
1244  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1245  }
1246  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1247  });
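 // Decode of a sorted-token entry as used above (restating the bit layout implied by the
 // masks): the low 24 bits hold the token index and the high 8 bits hold the top-k slot,
 //
 //   const index_t token  = fused_token & 0xffffff; // row of the token in the A tensor
 //   const index_t k_slot = fused_token >> 24;      // which of the token's top-k experts
 //
 // For the second (down) GEMM (!IsInputGemm), A rows are laid out per (token, slot) pair,
 // hence token_offset = token * TopK + k_slot before scaling by problem.K.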
1248  const index_t expert_stride =
1249  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1250 
1251  // N0, K0, Blocksize*KPack
1252  const index_t n_block_data_idx_on_grid =
1253  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1254 
1255  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1256  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1257  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1258  p_b_grid + expert_id * expert_stride / BPackedSize,
1259  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1260  // A matrix in LDS memory, dst of blockwise copy
1261  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1262 
1263  // B matrix in LDS memory, dst of blockwise copy
1264  // dummy
1265  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1266  // A matrix blockwise copy
1267  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1269  AElementwiseOperation,
1273  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1274  ABlockTransferThreadClusterArrangeOrder,
1275  ADataType,
1276  LDSTypeA,
1277  decltype(a_grid_desc_ak0_m_ak1),
1278  decltype(a_block_desc_ak0_m_ak1),
1279  ABlockTransferSrcAccessOrder,
1281  ABlockTransferSrcVectorDim,
1282  2,
1283  ABlockTransferSrcScalarPerVector,
1284  ABlockTransferDstScalarPerVector_AK1,
1285  1,
1286  1,
1287  AThreadTransferSrcResetCoordinateAfterRun,
1288  true,
1289  IndexType,
1290  1,
1291  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1292  make_multi_index(0, 0, 0),
1293  a_element_op,
1294  a_block_desc_ak0_m_ak1,
1295  make_multi_index(0, 0, 0),
1297  gather_offsets);
1298 
1299  // Thread-wise copy
1300  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1301  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1302  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1303 
1304  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1305  BDataType,
1306  BDataType,
1307  decltype(b_grid_desc_bpreshuffled),
1308  decltype(b_block_desc_bk0_n_bk1),
1311  3,
1312  BBlockTransferSrcScalarPerVector,
1313  BThreadTransferSrcResetCoordinateAfterRun,
1314  true>(b_grid_desc_bpreshuffled,
1315  make_multi_index(n_block_data_idx_on_grid,
1317  0,
1318  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1319 
1320  // LDS allocation for A and B: be careful of alignment
1321  // Cast after lds
1322  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1323  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1324 
1325  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1326  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1327 
1328  // Blockwise GEMM pipeline
1329  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1330  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1331  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1332  decltype(c_thread_buf) c_thread_buf_up;
1333 
1335  float,
1336  c_thread_buf.num_of_v_,
1337  c_thread_buf.s_per_v,
1338  true>
1339  c_thread_buf_fp32;
1340 
1341  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1342  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1343  KPerBlock);
1344  if constexpr(IsInputGemm)
1345  {
1346  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1347  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1348  p_b_grid_up + expert_id * expert_stride / BPackedSize,
1349  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1350  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1351  BDataType,
1352  BDataType,
1353  decltype(b_grid_desc_bpreshuffled),
1354  decltype(b_block_desc_bk0_n_bk1),
1357  3,
1358  BBlockTransferSrcScalarPerVector,
1359  BThreadTransferSrcResetCoordinateAfterRun,
1360  true>(b_grid_desc_bpreshuffled,
1361  make_multi_index(n_block_data_idx_on_grid,
1363  0,
1364  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1365 
1366  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1367  a_grid_desc_ak0_m_ak1,
1368  a_block_desc_ak0_m_ak1,
1369  a_blockwise_copy,
1370  a_grid_buf,
1371  a_block_buf,
1372  a_block_slice_copy_step,
1373  b_grid_desc_bpreshuffled,
1374  b_blockwise_copy,
1375  b_blockwise_copy_up,
1376  b_grid_buf,
1377  b_grid_buf_up,
1378  b_block_buf,
1379  b_block_slice_copy_step,
1380  c_thread_buf,
1381  c_thread_buf_up,
1382  num_k_block_main_loop);
1383  }
1384  else
1385  {
1386  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1387  a_grid_desc_ak0_m_ak1,
1388  a_block_desc_ak0_m_ak1,
1389  a_blockwise_copy,
1390  a_grid_buf,
1391  a_block_buf,
1392  a_block_slice_copy_step,
1393  b_grid_desc_bpreshuffled,
1394  b_blockwise_copy,
1395  b_grid_buf,
1396  b_block_buf,
1397  b_block_slice_copy_step,
1398  c_thread_buf,
1399  num_k_block_main_loop);
1400  }
1401 
1402  // shuffle C and write out
1403  {
1404  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1405  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1406  "wrong!");
1407 
1408  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1409 
1410  // TODO: hacky, fix it!
1411  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1412  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1413 
1414  // TODO: hacky, fix it!
1415  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1416  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1417  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1418 
1419  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1420  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1421  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1422  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1423  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1424  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1425  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1426  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1427 
1428  // mul scales
1429  const float* p_sorted_weights_0 = p_ds_grid[I0];
1430  const float* p_scale_b = p_ds_grid[I1];
1431 
1432  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1433  static_assert(M4 == 4);
1434  const index_t m1 = get_warp_local_1d_id() / NWave;
1435  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1436 
1437  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
1438  {
1439  if constexpr(PerTokenQuant)
1440  {
1441  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1442  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1443  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
1444  }
1445  else
1446  {
1447  p_scale_b += expert_id;
1448  }
1449 
1450  vector_type<int32_t, 4> scale_token_ids;
1451  vector_type<float, 4> topk_weights;
1452  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1453  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
1454  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1455  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1456  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1457  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1458  if constexpr(PerTokenQuant)
1459  {
1460  scale_token_ids =
1461  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
1462  p_sorted_token_ids + m_pos);
1463  }
1464  if constexpr(MulRoutedWeight)
1465  {
1466  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1467  p_ds_grid[I2] + m_pos);
1468  }
1469  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1470  float scale_a = [&]() {
1471  if constexpr(PerTokenQuant)
1472  {
1473  index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
1474  const index_t token_offset = fused_token & 0xffffff;
1475  return token_offset < problem.NumTokens
1476  ? p_sorted_weights_0[IsInputGemm
1477  ? token_offset
1478  : token_offset *
1479  problem.TopK +
1480  (fused_token >>
1481  24)]
1482  : 0.0;
1483  }
1484  else
1485  {
1486  return p_sorted_weights_0[0];
1487  }
1488  }();
1489  constexpr index_t c_offset =
1490  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1491  make_tuple(m0, n0, m2 * M4 + m4));
1492  constexpr auto cidx = Number<c_offset>{};
1493  if constexpr(IsInputGemm) // gu fusion
1494  {
1495  if constexpr(ActivationOperation == Activation::silu_and_mul)
1496  {
1497  const float scale_up =
1498  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1499  PerTokenQuant];
1500  float gate = scale_a * scale_b * c_thread_buf[cidx];
1501  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1502  if constexpr(MulRoutedWeight)
1503  {
1504  gate = gate * topk_weights.AsType<float>()[m4];
1505  up = up * topk_weights.AsType<float>()[m4];
1506  }
1508  {
1509  gate *= 16;
1510  up *= 16;
1511  }
1513  c_thread_buf_fp32(cidx) = gate * up;
1514  }
1515  else if(ActivationOperation == Activation::gelu_and_mul)
1516  {
1517  const float scale_up =
1518  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1519  PerTokenQuant];
1520  float gate = scale_a * scale_b * c_thread_buf[cidx];
1521  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1522  if constexpr(MulRoutedWeight)
1523  {
1524  gate = gate * topk_weights.AsType<float>()[m4];
1525  up = up * topk_weights.AsType<float>()[m4];
1526  }
1528  {
1529  gate *= 16;
1530  up *= 16;
1531  }
1533  c_thread_buf_fp32(cidx) = gate * up;
1534  }
1535  }
1536  else
1537  {
1538  c_thread_buf_fp32(cidx) =
1539  scale_a * scale_b * c_thread_buf[cidx];
1540  if constexpr(MulRoutedWeight)
1541  {
1542  c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
1543  topk_weights.AsType<float>()[m4];
1544  }
1545  }
1546  });
1547  });
1548  });
1549  });
1550  }
1551  else
1552  {
1553  vector_type<float, 4> topk_weights; // for gemm2 only
1554  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1555  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1556  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1557  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1558  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1559  if constexpr(MulRoutedWeight)
1560  {
1561  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1562  p_ds_grid[I2] + m_pos);
1563  }
1564  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1565  constexpr index_t c_offset =
1566  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1567  make_tuple(m0, n0, m2 * M4 + m4));
1568  constexpr auto cidx = Number<c_offset>{};
1569 
1570  if constexpr(IsInputGemm) // gu fusion
1571  {
1572  if constexpr(ActivationOperation == Activation::silu_and_mul)
1573  {
1574  float gate = c_thread_buf[cidx];
1575  float up = c_thread_buf_up[cidx];
1576  if constexpr(MulRoutedWeight)
1577  {
1578  gate = gate * topk_weights.AsType<float>()[m4];
1579  up = up * topk_weights.AsType<float>()[m4];
1580  }
1582  c_thread_buf_fp32(cidx) = gate * up;
1583  }
1584  else if(ActivationOperation == Activation::gelu_and_mul)
1585  {
1586  float gate = c_thread_buf[cidx];
1587  float up = c_thread_buf_up[cidx];
1588  if constexpr(MulRoutedWeight)
1589  {
1590  gate = gate * topk_weights.AsType<float>()[m4];
1591  up = up * topk_weights.AsType<float>()[m4];
1592  }
1594  c_thread_buf_fp32(cidx) = gate * up;
1595  }
1596  }
1597  else
1598  {
1599  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1600  if constexpr(MulRoutedWeight)
1601  {
1602  c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
1603  c_thread_buf_fp32[cidx];
1604  }
1605  }
1606  });
1607  });
1608  });
1609  });
1610  }
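 // The elementwise activation applied to `gate` is on source lines elided from this listing.
 // As a scalar reference only (not the kernel's vectorized code), the two ActivationOperation
 // variants conventionally compute:
 //
 //   float silu_and_mul(float gate, float up) { return gate / (1.f + expf(-gate)) * up; }
 //   float gelu_and_mul(float gate, float up)
 //   {
 //       return 0.5f * gate * (1.f + erff(gate * 0.70710678f)) * up; // erf-based GELU
 //   }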
1611 
1612  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1614 
1615  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1616  static_cast<CShuffleDataType*>(p_shared),
1617  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1618 
1619  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1620  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1621  make_tuple(
1624  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1625  M1, // M1 = MWave
1626  M2, // M2 * M3 * M4 = MPerXdl
1627  M3,
1628  M4)),
1631  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1632  N1, // N1 = NWave
1633  N2))), // N2 = NPerXdl
1635  make_tuple(
1637 
1638  // calculate origin of thread output tensor on global memory
1639  // blockwise GEMM c matrix starting index
1640  const auto c_thread_mtx_on_block =
1641  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1642 
1643  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1644  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1645 
1646  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1648  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1650  make_tuple(Sequence<0>{}));
1651 
1652  const auto m_thread_data_on_block_idx =
1653  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1654  make_multi_index(m_thread_data_on_block));
1655 
1656  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1660  make_tuple(Sequence<0>{}));
1661 
1662  const auto n_thread_data_on_block_idx =
1663  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1664  make_multi_index(n_thread_data_on_block));
1665 
1666  // shuffle: threadwise copy C from VGPR to LDS
1667  auto c_thread_copy_vgpr_to_lds =
1669  CShuffleDataType,
1670  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1671  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1673  Sequence<CShuffleMXdlPerWavePerShuffle,
1674  CShuffleNXdlPerWavePerShuffle,
1675  I1,
1676  I1,
1677  M2,
1678  I1,
1679  M4,
1680  I1>,
1682  7,
1683  1,
1685  1,
1686  true>{
1687  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1688  make_multi_index(0,
1689  0,
1690  m_thread_data_on_block_idx[I1],
1691  n_thread_data_on_block_idx[I1],
1692  m_thread_data_on_block_idx[I2],
1693  m_thread_data_on_block_idx[I3],
1694  m_thread_data_on_block_idx[I4],
1695  n_thread_data_on_block_idx[I2]),
1697 
1698  using EDataType = CDataType;
1699 
1700  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1701  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1702 
1703  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1705  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1706 
1707  const auto ds_grid_buf = generate_tuple(
1708  [&](auto i) {
1709  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1710  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1711  },
1712  Number<NumDTensor>{});
1713 
1714  // tuple of reference to C/Ds tensor descriptors
1715  const auto c_ds_desc_refs = concat_tuple_of_reference(
1716  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1717  generate_tie(
1718  [&](auto i) -> const auto& // return type should be reference
1719  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1720  Number<NumDTensor>{}));
1721 
1722  // tuple of reference to C/Ds tensor descriptors
1723  const auto c_ds_buf_refs = concat_tuple_of_reference(
1724  tie(c_shuffle_block_buf),
1725  generate_tie(
1726  [&](auto i) -> const auto& // return type should be reference
1727  { return ds_grid_buf[i]; },
1728  Number<NumDTensor>{}));
1729 
1730  // tuple of starting index of C/Ds blockwise copy
1731  const auto idx_c_ds_block_begin =
1734  [&](auto) {
1735  return make_multi_index(block_m_id, 0, block_n_id, 0);
1736  // return make_multi_index(block_work_idx[I0], 0,
1737  // block_work_idx[I1], 0);
1738  },
1739  Number<NumDTensor>{}));
1740 
1741  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1742  c_grid_desc_mblock_mperblock_nblock_nperblock;
1743 
1744  using CDEBlockTransferCluster =
1745  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1746  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1747  constexpr index_t scatter_weight_idx = 3; // hack fix felix
1748  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1750  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1752  decltype(c_ds_desc_refs),
1753  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1754  CElementwiseOperation,
1755  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
 1756  // support arbitrary type
1757  Sequence<1,
1758  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1759  1,
1760  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1761  CDEBlockTransferCluster,
1762  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1763  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1764  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1765  3, // index_t SrcVectorDim,
1766  3, // index_t DstVectorDim,
1767  CDEShuffleBlockTransferScalarPerVectors,
1772  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1773  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1774  IndexType,
1775  1, // ScatterDim
1776  true, // OutputScatter: false, only use scatter weights
1777  scatter_weight_idx // ScatterWeightIdx: ascale
1778  >{c_ds_desc_refs,
1779  idx_c_ds_block_begin,
1780  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1781  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1782  c_element_op};
1783 
1784  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1785  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1786  constexpr auto sfc_c_vgpr =
1789  Sequence<CShuffleMXdlPerWavePerShuffle,
1790  CShuffleNXdlPerWavePerShuffle,
1791  1,
1792  1,
1793  M2,
1794  1,
1795  M4,
1796  1>>{};
1797 
1798  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1799 
1800  // space filling curve for shuffled blockwise C/D/E
1801  constexpr auto sfc_cde_block =
1804  Sequence<1,
1805  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1806  1,
1807  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1808 
1809  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1810  constexpr auto EMThreads =
1811  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1812  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1813  constexpr auto ENThreads =
1814  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1815  static_for<0, num_access, 1>{}([&](auto access_id) {
1816  // make sure it's safe to write to LDS
1818 
1819  auto dstidx = sfc_cde_block.GetIndex(access_id);
1820  const index_t c_token_pos =
1821  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1822  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1823  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1824  IndexType token_offset = fused_token & 0xffffff;
1825  if constexpr(IsInputGemm)
1826  {
1827  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1828  }
1829  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1830  });
1831 
1832  block_sync_lds();
1833 
1834  // each thread write its data from VGPR to LDS
1835  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1836  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1837  c_thread_buf_fp32,
1838  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1839  c_shuffle_block_buf);
1840 
1841  // make sure it's safe to read from LDS
1842  block_sync_lds();
1843 
1844  // each block copy its data from LDS to global
1845  cde_block_copy_lds_and_global.Run(
1846  c_ds_desc_refs,
1847  c_ds_buf_refs,
1848  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1849  tie(c_grid_buf),
1850  scatter_offsets);
1851 
1852  if constexpr(access_id < num_access - 1)
1853  {
1854  constexpr auto cde_lds_and_global_step =
1855  sfc_cde_block.GetForwardStep(access_id);
1856 
1857  // move on Ds
1858  static_for<0, NumDTensor, 1>{}([&](auto i) {
1859  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1860  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1861  });
1862 
1863  // move on E
1864  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1865  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1866  I0,
1867  cde_lds_and_global_step);
1868  }
1869  });
1870  }
1871  }
1872 
1873  template <bool HasMainKBlockLoop,
1874  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1875  TailNumber TailNum = TailNumber::Odd>
1876  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1877  const index_t* p_sorted_expert_ids,
1878  const index_t* p_max_token_id,
1879  const ADataType* p_a_grid,
1880  const BDataType* p_b_grid,
1881  DsGridPointer& p_ds_grid,
1882  CDataType* p_c_grid,
1883  void* p_shared,
1884  void* p_shared1,
1885  const Problem& problem,
1886  AElementwiseOperation a_element_op,
1887  BElementwiseOperation b_element_op,
1888  CElementwiseOperation c_element_op)
1889  {
1890  ignore = b_element_op;
1891  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1892  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1893  problem.MPadded,
1894  problem.K,
1895  problem.KPadded,
1896  problem.StrideA,
1897  problem.AK0);
1898  const auto b_grid_desc_bpreshuffled =
1900  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1901  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1902  problem.MPadded,
1903  problem.N,
1904  problem.NPadded,
1905  problem.StrideC);
1906  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1908  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1909  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1910  // static_assert(NSwizzle == false, "to do: fix; needs another PR with merged sorting");
1911  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1912  if(expert_block_id * MPerBlock >= max_token_id)
1913  return;
1914  const index_t expert_id =
1915  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1916  const auto block_mn = [&]() -> std::pair<int, int> {
1917  if constexpr(NSwizzle)
1918  {
1919  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1920  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1921  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1922  const index_t expert_swizzle =
1923  ecnt > 0 ? ecnt : 1; // guard against an expert that owns no blocks
1924  const index_t bid_new = blockIdx.x - prefix_block;
1925  const index_t nid = __builtin_amdgcn_readfirstlane(
1926  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1927  const index_t mid =
1928  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1929  return {nid, mid};
1930  }
1931  else
1932  {
1933  return {blockIdx.x, blockIdx.y};
1934  }
1935  }();
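 // With NSwizzle the flat blockIdx.x is remapped inside each expert's slab of blocks: N-block
 // ids are interleaved in groups of 8 and the M-block id is offset by the expert's prefix
 // count, presumably to improve cache locality across the N dimension.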
1936 
1937  const index_t block_n_id = block_mn.first;
1938  const index_t block_m_id = block_mn.second;
1939  const index_t token0 =
1940  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1941 
1942  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1943  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1944  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1945  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1946  constexpr auto AKThreads = AK0Threads * AK1Threads;
1947  constexpr auto AMRepeats = MPerBlock / AMThreads;
1948  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1949 
1950  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1951  return;
1953  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1954  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1955  index_t token_offset = fused_token & 0xffffff;
1956  if constexpr(!IsInputGemm)
1957  {
1958  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1959  }
1960  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1961  });
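 // p_sorted_token_ids packs the token id in the low 24 bits and (apparently) the top-k slot
 // in the high 8 bits. For the second GEMM the A rows are expanded by TopK, so the row index
 // becomes token * TopK + slot; gather offsets are expressed in units of K elements.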
1962  const index_t expert_stride =
1963  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
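 // Per-expert stride into the B weights: each expert owns an N x K tile, and for the input
 // GEMM the gate and up projection weights are stored back to back, hence the factor of 2.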
1964 
1965  // N0, K0, Blocksize*KPack
1966  const index_t n_block_data_idx_on_grid =
1967  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1968 
1969  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1970  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1971  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1972  p_b_grid + expert_id * expert_stride / BPackedSize,
1973  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1974 
1975  // A matrix in LDS memory, dst of blockwise copy
1976  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1977 
1978  // B tile descriptor, nominally the dst of the blockwise copy
1979  // dummy: B is kept in registers in this path, so this descriptor only sizes the per-thread buffers
1980  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1981  // A matrix blockwise copy
1982  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1984  AElementwiseOperation,
1988  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1989  ABlockTransferThreadClusterArrangeOrder,
1990  ADataType,
1991  LDSTypeA,
1992  decltype(a_grid_desc_ak0_m_ak1),
1993  decltype(a_block_desc_ak0_m_ak1),
1994  ABlockTransferSrcAccessOrder,
1996  ABlockTransferSrcVectorDim,
1997  2,
1998  ABlockTransferSrcScalarPerVector,
1999  ABlockTransferDstScalarPerVector_AK1,
2000  1,
2001  1,
2002  AThreadTransferSrcResetCoordinateAfterRun,
2003  true,
2004  IndexType,
2005  1,
2006  2>(a_grid_desc_ak0_m_ak1,
2007  make_multi_index(0, 0, 0),
2008  a_element_op,
2009  a_block_desc_ak0_m_ak1,
2010  make_multi_index(0, 0, 0),
2012  gather_offsets);
2013 
2014  // Thread-wise copy
2015  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2016  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2017  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2018  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2019  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2020  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
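 // B is preshuffled and double-buffered directly in registers (VGPR ping/pong); it is never
 // staged through LDS in this path.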
2021 
2022  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2023  BDataType,
2024  BDataType,
2025  decltype(b_grid_desc_bpreshuffled),
2026  decltype(b_block_desc_bk0_n_bk1),
2029  3,
2030  BBlockTransferSrcScalarPerVector,
2031  BThreadTransferSrcResetCoordinateAfterRun,
2032  true>(b_grid_desc_bpreshuffled,
2033  make_multi_index(n_block_data_idx_on_grid,
2035  0,
2036  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2037 
2038  // LDS allocation for A (ping/pong): be careful of alignment
2039  // cast the raw __shared__ pointers to the LDS element type after allocation
2040  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2041  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2042  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2043  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2044  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
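 // A, in contrast, is double-buffered in LDS: the two __shared__ allocations passed in as
 // p_shared / p_shared1 form the ping/pong pair consumed by the 2-LDS pipeline.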
2045 
2046  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2047  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2048 
2049  // Blockwise GEMM pipeline
2050  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2051  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2052  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2053  decltype(c_thread_buf) c_thread_buf_up;
2054 
2056  float,
2057  c_thread_buf.num_of_v_,
2058  c_thread_buf.s_per_v,
2059  true>
2060  c_thread_buf_fp32;
2061 
2062  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2063  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2064  KPerBlock);
2065 
2066  if constexpr(IsInputGemm)
2067  {
2068  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2069  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2070  p_b_grid_up + expert_id * expert_stride / BPackedSize,
2071  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2072  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2073  BDataType,
2074  BDataType,
2075  decltype(b_grid_desc_bpreshuffled),
2076  decltype(b_block_desc_bk0_n_bk1),
2079  3,
2080  BBlockTransferSrcScalarPerVector,
2081  BThreadTransferSrcResetCoordinateAfterRun,
2082  true>(b_grid_desc_bpreshuffled,
2083  make_multi_index(n_block_data_idx_on_grid,
2085  0,
2086  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2087  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2088  a_grid_desc_ak0_m_ak1,
2089  a_block_desc_ak0_m_ak1,
2090  a_blockwise_copy,
2091  a_grid_buf,
2092  a_block_bufs,
2093  a_block_slice_copy_step,
2094  b_grid_desc_bpreshuffled,
2095  b_blockwise_copy,
2096  b_blockwise_copy_up,
2097  b_grid_buf,
2098  b_grid_buf_up,
2099  b_block_bufs,
2100  b_block_slice_copy_step,
2101  c_thread_buf,
2102  c_thread_buf_up,
2103  num_k_block_main_loop);
2104  }
2105  else
2106  {
2107 
2108  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2109  a_grid_desc_ak0_m_ak1,
2110  a_block_desc_ak0_m_ak1,
2111  a_blockwise_copy,
2112  a_grid_buf,
2113  a_block_bufs,
2114  a_block_slice_copy_step,
2115  b_grid_desc_bpreshuffled,
2116  b_blockwise_copy,
2117  b_grid_buf,
2118  b_block_bufs,
2119  b_block_slice_copy_step,
2120  c_thread_buf,
2121  num_k_block_main_loop);
2122  }
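 // For the input GEMM the pipeline consumes two B streams (gate and up weights) and fills two
 // accumulator buffers, c_thread_buf and c_thread_buf_up; the second GEMM runs the plain
 // single-B pipeline.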
2123 
2124  // shuffle C and write out
2125  {
2126  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2127  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2128  "MXdlPerWave/NXdlPerWave must be divisible by the per-shuffle Xdl counts");
2129 
2130  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2131 
2132  // TODO: hacky, fix it!
2133  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2134  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2135 
2136  // TODO: hacky, fix it!
2137  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2138  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2139  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2140 
2141  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2142  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2143  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2144  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2145  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2146  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2147  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2148  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2149 
2150  // mul scales
2151  const float* p_sorted_weights_0 = p_ds_grid[I0];
2152  const float* p_scale_b = p_ds_grid[I1];
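 // D operands: D0 supplies the A-side scale (per sorted token when PerTokenQuant, otherwise a
 // single scalar), D1 the B-side scales (per expert, and per output channel when PerTokenQuant),
 // and D2, when MulRoutedWeight is set, the top-k routing weights folded into the output.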
2153 
2154  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2155  static_assert(M4 == 4);
2156  const index_t m1 = get_warp_local_1d_id() / NWave;
2157  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
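 // m1 (wave row) and m3 (lane group within the Xdl tile) locate this thread's rows; together
 // with m0/m2/m4 they reconstruct the absolute row index m_pos used below to look up the
 // per-token scales and routing weights.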
2158 
2159  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
2160  {
2161  if constexpr(PerTokenQuant)
2162  {
2163  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2164  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2165  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
2166  }
2167  else
2168  {
2169  p_scale_b += expert_id;
2170  }
2171 
2172  vector_type<int32_t, 4> scale_token_ids;
2173  vector_type<float, 4> topk_weights;
2174  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2175  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
2176  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2177  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2178  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2179  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2180  if constexpr(PerTokenQuant)
2181  {
2182  scale_token_ids =
2183  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
2184  p_sorted_token_ids + m_pos);
2185  }
2186  if constexpr(MulRoutedWeight)
2187  {
2188  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2189  p_ds_grid[I2] + m_pos);
2190  }
2191  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2192  float scale_a = [&]() {
2193  if constexpr(PerTokenQuant)
2194  {
2195  index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
2196  const index_t token_offset = fused_token & 0xffffff;
2197  return token_offset < problem.NumTokens
2198  ? p_sorted_weights_0[IsInputGemm
2199  ? token_offset
2200  : token_offset *
2201  problem.TopK +
2202  (fused_token >>
2203  24)]
2204  : 0.0;
2205  }
2206  else
2207  {
2208  return p_sorted_weights_0[0];
2209  }
2210  }();
2211  constexpr index_t c_offset =
2212  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2213  make_tuple(m0, n0, m2 * M4 + m4));
2214  constexpr auto cidx = Number<c_offset>{};
2215  if constexpr(IsInputGemm) // gate/up (gu) fusion
2216  {
2217  if constexpr(ActivationOperation == Activation::silu_and_mul)
2218  {
2219  const float scale_up =
2220  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2221  PerTokenQuant];
2222  float gate = scale_a * scale_b * c_thread_buf[cidx];
2223  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2224  if constexpr(MulRoutedWeight)
2225  {
2226  gate = gate * topk_weights.AsType<float>()[m4];
2227  up = up * topk_weights.AsType<float>()[m4];
2228  }
2230  {
2231  gate *= 16;
2232  up *= 16;
2233  }
2235  c_thread_buf_fp32(cidx) = gate * up;
2236  }
2237  else if(ActivationOperation == Activation::gelu_and_mul)
2238  {
2239  const float scale_up =
2240  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2241  PerTokenQuant];
2242  float gate = scale_a * scale_b * c_thread_buf[cidx];
2243  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2244  if constexpr(MulRoutedWeight)
2245  {
2246  gate = gate * topk_weights.AsType<float>()[m4];
2247  up = up * topk_weights.AsType<float>()[m4];
2248  }
2250  {
2251  gate *= 16;
2252  up *= 16;
2253  }
2255  c_thread_buf_fp32(cidx) = gate * up;
2256  }
2257  }
2258  else
2259  {
2260  c_thread_buf_fp32(cidx) =
2261  scale_a * scale_b * c_thread_buf[cidx];
2262  if constexpr(MulRoutedWeight)
2263  {
2264  c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
2265  topk_weights.AsType<float>()[m4];
2266  }
2267  }
2268  });
2269  });
2270  });
2271  });
2272  }
2273  else
2274  {
2275  vector_type<float, 4> topk_weights; // for gemm2 only
2276  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2277  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2278  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2279  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2280  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2281  if constexpr(MulRoutedWeight)
2282  {
2283  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2284  p_ds_grid[I2] + m_pos);
2285  }
2286  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2287  constexpr index_t c_offset =
2288  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2289  make_tuple(m0, n0, m2 * M4 + m4));
2290  constexpr auto cidx = Number<c_offset>{};
2291 
2292  if constexpr(IsInputGemm) // gate/up (gu) fusion
2293  {
2294  if constexpr(ActivationOperation == Activation::silu_and_mul)
2295  {
2296  float gate = c_thread_buf[cidx];
2297  float up = c_thread_buf_up[cidx];
2298  if constexpr(MulRoutedWeight)
2299  {
2300  gate = gate * topk_weights.AsType<float>()[m4];
2301  up = up * topk_weights.AsType<float>()[m4];
2302  }
2304  c_thread_buf_fp32(cidx) = gate * up;
2305  }
2306  else if(ActivationOperation == Activation::gelu_and_mul)
2307  {
2308  float gate = c_thread_buf[cidx];
2309  float up = c_thread_buf_up[cidx];
2310  if constexpr(MulRoutedWeight)
2311  {
2312  gate = gate * topk_weights.AsType<float>()[m4];
2313  up = up * topk_weights.AsType<float>()[m4];
2314  }
2316  c_thread_buf_fp32(cidx) = gate * up;
2317  }
2318  }
2319  else
2320  {
2321  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2322  if constexpr(MulRoutedWeight)
2323  {
2324  c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
2325  c_thread_buf_fp32[cidx];
2326  }
2327  }
2328  });
2329  });
2330  });
2331  });
2332  }
2333 
2334  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2336 
2337  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2338  static_cast<CShuffleDataType*>(p_shared),
2339  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2340 
2341  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2342  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2343  make_tuple(
2346  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2347  M1, // M1 = MWave
2348  M2, // M2 * M3 * M4 = MPerXdl
2349  M3,
2350  M4)),
2353  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2354  N1, // N1 = NWave
2355  N2))), // N2 = NPerXdl
2357  make_tuple(
2359 
2360  // calculate origin of thread output tensor on global memory
2361  // blockwise GEMM c matrix starting index
2362  const auto c_thread_mtx_on_block =
2363  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2364 
2365  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2366  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2367 
2368  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2370  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2372  make_tuple(Sequence<0>{}));
2373 
2374  const auto m_thread_data_on_block_idx =
2375  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2376  make_multi_index(m_thread_data_on_block));
2377 
2378  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2382  make_tuple(Sequence<0>{}));
2383 
2384  const auto n_thread_data_on_block_idx =
2385  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2386  make_multi_index(n_thread_data_on_block));
2387 
2388  // shuffle: threadwise copy C from VGPR to LDS
2389  auto c_thread_copy_vgpr_to_lds =
2391  CShuffleDataType,
2392  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2393  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2395  Sequence<CShuffleMXdlPerWavePerShuffle,
2396  CShuffleNXdlPerWavePerShuffle,
2397  I1,
2398  I1,
2399  M2,
2400  I1,
2401  M4,
2402  I1>,
2404  7,
2405  1,
2407  1,
2408  true>{
2409  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2410  make_multi_index(0,
2411  0,
2412  m_thread_data_on_block_idx[I1],
2413  n_thread_data_on_block_idx[I1],
2414  m_thread_data_on_block_idx[I2],
2415  m_thread_data_on_block_idx[I3],
2416  m_thread_data_on_block_idx[I4],
2417  n_thread_data_on_block_idx[I2]),
2419 
2420  using EDataType = CDataType;
2421 
2422  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2423  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2424 
2425  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2427  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2428 
2429  const auto ds_grid_buf = generate_tuple(
2430  [&](auto i) {
2431  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2432  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2433  },
2434  Number<NumDTensor>{});
2435 
2436  // tuple of reference to C/Ds tensor descriptors
2437  const auto c_ds_desc_refs = concat_tuple_of_reference(
2438  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2439  generate_tie(
2440  [&](auto i) -> const auto& // return type should be reference
2441  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2442  Number<NumDTensor>{}));
2443 
2444  // tuple of reference to C/Ds tensor descriptors
2445  const auto c_ds_buf_refs = concat_tuple_of_reference(
2446  tie(c_shuffle_block_buf),
2447  generate_tie(
2448  [&](auto i) -> const auto& // return type should be reference
2449  { return ds_grid_buf[i]; },
2450  Number<NumDTensor>{}));
2451 
2452  // tuple of starting index of C/Ds blockwise copy
2453  const auto idx_c_ds_block_begin =
2456  [&](auto) {
2457  return make_multi_index(block_m_id, 0, block_n_id, 0);
2458  // return make_multi_index(block_work_idx[I0], 0,
2459  // block_work_idx[I1], 0);
2460  },
2461  Number<NumDTensor>{}));
2462 
2463  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2464  c_grid_desc_mblock_mperblock_nblock_nperblock;
2465 
2466  using CDEBlockTransferCluster =
2467  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2468  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2469  constexpr index_t scatter_weight_idx = 3; // FIXME: hard-coded index of the scatter weights (ascale)
2470  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2472  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2474  decltype(c_ds_desc_refs),
2475  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2476  CElementwiseOperation,
2477  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2478  // support arbitrary type
2479  Sequence<1,
2480  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2481  1,
2482  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2483  CDEBlockTransferCluster,
2484  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2485  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2486  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2487  3, // index_t SrcVectorDim,
2488  3, // index_t DstVectorDim,
2489  CDEShuffleBlockTransferScalarPerVectors,
2494  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2495  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2496  IndexType,
2497  1, // ScatterDim
2498  true, // OutputScatter: false, only use scatter weights
2499  scatter_weight_idx // ScatterWeightIdx: ascale
2500  >{c_ds_desc_refs,
2501  idx_c_ds_block_begin,
2502  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2503  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2504  c_element_op};
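 // The LDS->global writer is the scatter variant: destination rows are not contiguous but are
 // taken from scatter_offsets, so each shuffled row lands at its original token's position
 // (times problem.N) in the output tensor.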
2505 
2506  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2507  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2508  constexpr auto sfc_c_vgpr =
2511  Sequence<CShuffleMXdlPerWavePerShuffle,
2512  CShuffleNXdlPerWavePerShuffle,
2513  1,
2514  1,
2515  M2,
2516  1,
2517  M4,
2518  1>>{};
2519 
2520  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2521 
2522  // space filling curve for shuffled blockwise C/D/E
2523  constexpr auto sfc_cde_block =
2526  Sequence<1,
2527  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2528  1,
2529  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2530 
2531  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "SFC access counts for the VGPR and LDS/global copies must match");
2532  constexpr auto EMThreads =
2533  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2534  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2535  constexpr auto ENThreads =
2536  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2537  static_for<0, num_access, 1>{}([&](auto access_id) {
2538  // make sure it's safe to write to LDS
2540 
2541  auto dstidx = sfc_cde_block.GetIndex(access_id);
2542  const index_t c_token_pos =
2543  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2544  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2545  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2546  IndexType token_offset = fused_token & 0xffffff;
2547  if constexpr(IsInputGemm)
2548  {
2549  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2550  }
2551  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2552  });
2553 
2554  block_sync_lds();
2555 
2556  // each thread write its data from VGPR to LDS
2557  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2558  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2559  c_thread_buf_fp32,
2560  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2561  c_shuffle_block_buf);
2562 
2563  // make sure it's safe to read from LDS
2564  block_sync_lds();
2565 
2566  // each block copy its data from LDS to global
2567  cde_block_copy_lds_and_global.Run(
2568  c_ds_desc_refs,
2569  c_ds_buf_refs,
2570  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2571  tie(c_grid_buf),
2572  scatter_offsets);
2573 
2574  if constexpr(access_id < num_access - 1)
2575  {
2576  constexpr auto cde_lds_and_global_step =
2577  sfc_cde_block.GetForwardStep(access_id);
2578 
2579  // move on Ds
2580  static_for<0, NumDTensor, 1>{}([&](auto i) {
2581  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2582  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2583  });
2584 
2585  // move on E
2586  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2587  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2588  I0,
2589  cde_lds_and_global_step);
2590  }
2591  });
2592  }
2593  }
2594 };
2595 
2596 } // namespace ck