Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp Source File

gridwise_moe_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24 // pipelines in the same kernel function. Blockers:
25 // 1. Two separate declarations of the __shared__ pointer are needed to make sure data accesses
26 // operate on two distinct LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
28 // not share the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
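
// Illustrative sketch of blocker 1 (not part of this header): the double-LDS pipeline relies on
// two independently declared chunks, e.g.
//     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
//     __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// as done in kernel_moe_gemm_2lds below, so that ping-pong accesses provably touch disjoint LDS
// regions. A single allocation declared inside the pipeline object cannot express this, and it
// would also stay live for the whole shader, preventing C from reusing the A/B LDS window.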
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__)
49  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50 
51  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54  karg.p_sorted_token_ids,
55  karg.p_sorted_expert_ids,
56  karg.p_max_token_id,
57  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
59  karg.p_ds_grid,
60  karg.p_c_grid,
61  p_shared,
62  karg,
63  karg.a_element_op,
64  karg.b_element_op,
65  karg.c_element_op);
66 #else
67  ignore = karg;
68 #endif // end of if (defined(__gfx9__))
69 }
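
// A minimal host-side launch sketch (illustrative only, not part of this header). It assumes a
// HIP runtime, a concrete GridwiseGemm instantiation, and a fully populated
// GridwiseGemm::Argument; BLOCK_SIZE stands in for the BlockSize the instance was built with, and
// error checking plus the remaining HasMainKBlockLoop/TailNumber dispatch branches are omitted.
//
//   template <typename GridwiseGemm>
//   void launch_moe_gemm_sketch(typename GridwiseGemm::Argument karg, hipStream_t stream)
//   {
//       if(!GridwiseGemm::CheckValidity(karg)) { return; }
//       const auto [grid_x, grid_y, grid_z] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
//       if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
//       {
//           kernel_moe_gemm<GridwiseGemm, true, InMemoryDataOperationEnum::Set>
//               <<<dim3(grid_x, grid_y, karg.KBatch), dim3(BLOCK_SIZE), 0, stream>>>(karg);
//       }
//       // ... otherwise dispatch with HasMainKBlockLoop = false, and select the TailNumber
//       // template argument from GridwiseGemm::CalculateKBlockLoopTailNum(karg.K).
//   }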
70 
71 template <typename GridwiseGemm,
72  bool HasMainKBlockLoop,
73  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
74  index_t MinimumOccupancy = 1,
75  TailNumber TailNum = TailNumber::Even>
76 __global__ void
77 #if CK_USE_LAUNCH_BOUNDS
78 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
79 #endif
80  // __attribute__((amdgpu_waves_per_eu(1, 1)))
81  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
82 {
83 #if defined(__gfx9__)
84  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
85  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
86 
87  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
88 
89  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
90  karg.p_sorted_token_ids,
91  karg.p_sorted_expert_ids,
92  karg.p_max_token_id,
93  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
94  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
95  karg.p_ds_grid,
96  karg.p_c_grid,
97  p_shared,
98  p_shared1,
99  karg,
100  karg.a_element_op,
101  karg.b_element_op,
102  karg.c_element_op);
103 #else
104  ignore = karg;
105 #endif // end of if (defined(__gfx9__))
106 }
107 
108 template <typename ALayout,
109  typename BLayout,
110  typename DsLayout,
111  typename CLayout,
112  typename ADataType,
113  typename BDataType,
114  typename AccDataType,
115  typename CShuffleDataType,
116  typename DsDataType,
117  typename CDataType,
118  typename AElementwiseOperation,
119  typename BElementwiseOperation,
120  typename CElementwiseOperation,
122  index_t BlockSize,
123  index_t MPerBlock,
124  index_t NPerBlock,
125  index_t KPerBlock,
126  index_t AK1Value,
127  index_t BK1Value,
128  index_t MPerXdl,
129  index_t NPerXdl,
130  index_t MXdlPerWave,
131  index_t NXdlPerWave,
132  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
133  typename ABlockTransferThreadClusterArrangeOrder,
134  typename ABlockTransferSrcAccessOrder,
135  index_t ABlockTransferSrcVectorDim,
136  index_t ABlockTransferSrcScalarPerVector,
137  index_t ABlockTransferDstScalarPerVector_AK1,
138  bool AThreadTransferSrcResetCoordinateAfterRun,
139  index_t ABlockLdsExtraM,
140  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141  typename BBlockTransferThreadClusterArrangeOrder,
142  typename BBlockTransferSrcAccessOrder,
143  index_t BBlockTransferSrcVectorDim,
144  index_t BBlockTransferSrcScalarPerVector,
145  index_t BBlockTransferDstScalarPerVector_BK1,
146  bool BThreadTransferSrcResetCoordinateAfterRun,
147  index_t BBlockLdsExtraN,
148  index_t CShuffleMXdlPerWavePerShuffle,
149  index_t CShuffleNXdlPerWavePerShuffle,
150  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151  typename CDEShuffleBlockTransferScalarPerVectors,
154  index_t ActivationOperation = 0,
155  bool NSwizzle = false,
156  bool IsInputGemm = true,
157  bool MulRoutedWeight = true,
158  bool PerTokenQuant = false,
159  typename IndexType = index_t,
160  typename ComputeTypeA = CDataType,
161  typename ComputeTypeB = ComputeTypeA,
162  typename LDSTypeA = ADataType,
163  typename LDSTypeB = BDataType>
165 {
166  static constexpr auto I0 = Number<0>{};
167  static constexpr auto I1 = Number<1>{};
168  static constexpr auto I2 = Number<2>{};
169  static constexpr auto I3 = Number<3>{};
170  static constexpr auto I4 = Number<4>{};
171  static constexpr auto I5 = Number<5>{};
172  static constexpr auto I6 = Number<6>{};
173  static constexpr auto I7 = Number<7>{};
174 
176  CDEShuffleBlockTransferScalarPerVectors{}[I0];
177  // K1 should be Number<...>
178  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
179  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
180  static constexpr auto AK1Number = Number<AK1Value>{};
181  static constexpr auto BK1Number = Number<BK1Value>{};
182  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
183 
184  static constexpr index_t NumDTensor = DsDataType::Size();
185 
187  static constexpr index_t KPack =
189  static constexpr index_t KLane =
191 
192  static constexpr index_t KGroup = []() {
194  // On gfx950, there is an mfma that requires 32 f8 elements as input,
195  // split into 2 groups of 16 f8 elements.
196  // The 2 groups are not contiguous in the B preshuffled layout,
197  // and we do not want them to be contiguous there,
198  // because a memory instruction can only read 16 f8 elements at a time.
199  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
200  else
201  return 1;
202  }();
203 
204  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
205 
206  static constexpr index_t NLane = NPerXdl;
207  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
208  // static constexpr index_t NumTokens = 1;
209  static constexpr index_t SortedTileSize = MPerBlock;
210 
211  static constexpr auto MakeDsGridPointer()
212  {
213  return generate_tuple(
214  [&](auto i) {
215  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
216 
217  return static_cast<const DDataType*>(nullptr);
218  },
220  }
221 
222  using DsGridPointer = decltype(MakeDsGridPointer());
223 
225 
226  static constexpr index_t APackedSize = []() {
228  return 2;
229  else
230  return 1;
231  }();
232 
233  static constexpr index_t BPackedSize = []() {
235  return 2;
236  else
237  return 1;
238  }();
239 
240  __host__ static auto CalculateGridSize(index_t M, index_t N)
241  {
242  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
243  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
244  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245  const index_t gridy = NSwizzle ? 1 : mblock;
246 
247  return std::make_tuple(gridx, gridy, 1);
248  }
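
// Worked example (illustrative numbers): with MPerBlock = 128, NPerBlock = 128, M = 512 and
// N = 1024, mblock = 4 and nblock = 8; NSwizzle = false yields a (8, 4, 1) grid, while
// NSwizzle = true folds both tile dimensions into gridx and yields (32, 1, 1).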
249 
250  __host__ __device__ static auto CalculateMPadded(index_t M)
251  {
252  return math::integer_least_multiple(M, MPerBlock);
253  }
254 
255  __host__ __device__ static auto CalculateNPadded(index_t N)
256  {
257  return math::integer_least_multiple(N, NPerBlock);
258  }
259 
260  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
261  {
262  return math::integer_divide_ceil(N, NLane);
263  }
264  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
265  {
267  }
268 
269  __host__ __device__ static auto CalculateKPadded(index_t K)
270  {
271  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
272  }
273 
274  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
275  {
276  auto K_t = K_Batch * KPerBlock;
277  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
278  }
279 
280  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
281  {
282  auto K_t = K_Batch * KPerBlock;
283  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
284  }
285 
286  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
287  {
288  auto K_t = K_Batch * KPerBlock;
289  return (K + K_t - 1) / K_t * KPerBlock;
290  }
291 
292  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
293  {
294  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
295  auto K_t = K_Batch * KReadVec;
296  return (K + K_t - 1) / K_t * KReadVec;
297  }
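
// Worked example (illustrative numbers): with AK1Value = 8 and BK1Value = 16, KReadVec =
// lcm(8, 16) = 16; for K = 1000 and K_Batch = 4, K_t = 64 and KRead = (1063 / 64) * 16 = 256,
// i.e. each split-K batch reads 256 elements along K and the last batch covers the remaining
// 1000 - 3 * 256 = 232 elements.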
298 
299  __host__ __device__ static auto CalculateMBlock(index_t M)
300  {
301  return math::integer_divide_ceil(M, MPerBlock);
302  }
303 
304  __host__ __device__ static auto CalculateNBlock(index_t N)
305  {
306  return math::integer_divide_ceil(N, NPerBlock);
307  }
308 
309  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
310  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
311  {
312  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
313  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
314 
316  TileDesc_K0_MN_K1{},
322  }
323 
324  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
325  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
326  {
327  const auto a_grid_desc_mraw_kraw = [&]() {
328  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
329  {
330  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
331  }
332  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
333  {
334  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
335  }
336  }();
337 
339 
340  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
341  GemmSpec == GemmSpecialization::MNKPadding)
342  {
343  // pad both M and K
344  const auto a_grid_desc_m_k =
345  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
347  make_right_pad_transform(K, KPad - K)),
350 
351  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
352  a_grid_desc_m_k,
357 
358  return a_grid_desc_ak0_m_ak1;
359  }
360  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
361  GemmSpec == GemmSpecialization::MNPadding)
362  {
363  // pad M, but not K
364  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
365  a_grid_desc_mraw_kraw,
367  make_right_pad_transform(M, MPad - M)),
370 
371  return a_grid_desc_ak0_m_ak1;
372  }
373  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
374  GemmSpec == GemmSpecialization::NKPadding)
375  {
376  // pad K, but not M
377  const auto a_grid_desc_m_k = transform_tensor_descriptor(
378  a_grid_desc_mraw_kraw,
382 
383  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
384  a_grid_desc_m_k,
389 
390  return a_grid_desc_ak0_m_ak1;
391  }
392  else
393  {
394  // not pad M or K
395  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
396  a_grid_desc_mraw_kraw,
401 
402  return a_grid_desc_ak0_m_ak1;
403  }
404  }
405 
406  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
407  {
408  constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
410  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
411  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
412  }
413 
414  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
415  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
416  {
417  const auto b_grid_desc_nraw_kraw = [&]() {
419  {
420  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
421  }
423  {
424  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
425  }
426  }();
427 
429 
430  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
431  GemmSpec != GemmSpecialization::Default),
432  "pk_i4_t does not support padding");
433 
434  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
435  GemmSpec == GemmSpecialization::MNKPadding)
436  {
437  // pad both N and K
438  const auto b_grid_desc_n_k =
439  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
441  make_right_pad_transform(K, KPad - K)),
444 
445  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
446  b_grid_desc_n_k,
451 
452  return b_grid_desc_bk0_n_bk1;
453  }
454  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
455  GemmSpec == GemmSpecialization::MNPadding)
456  {
457  // pad N, but not K
458  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
459  b_grid_desc_nraw_kraw,
461  make_right_pad_transform(N, NPad - N)),
464 
465  return b_grid_desc_bk0_n_bk1;
466  }
467  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
468  GemmSpec == GemmSpecialization::MKPadding)
469  {
470  // pad K, but not N
471  const auto b_grid_desc_n_k = transform_tensor_descriptor(
472  b_grid_desc_nraw_kraw,
476 
477  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
478  b_grid_desc_n_k,
483 
484  return b_grid_desc_bk0_n_bk1;
485  }
486  else
487  {
488  // not pad N or K
489  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
490  b_grid_desc_nraw_kraw,
495 
496  return b_grid_desc_bk0_n_bk1;
497  }
498  }
499 
500  template <typename ABlockDesc_AK0_M_AK1>
501  __host__ __device__ static constexpr auto
502  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
503  {
504  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
505 
506  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
507  }
508 
509  template <typename BBlockDesc_BK0_N_BK1>
510  __host__ __device__ static constexpr auto
511  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
512  {
513  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
514  }
515 
516  template <typename ELayout>
517  __host__ __device__ static auto MakeCGridDescriptor_M_N(
518  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
519  {
520  const auto c_grid_desc_mraw_nraw = [&]() {
522  {
523  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
524  }
526  {
527  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
528  }
529  }();
530 
531  // pad M and N
532  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
534  make_right_pad_transform(N, NPad - N)),
537  }
538 
539  template <typename DLayout>
540  __host__ __device__ static auto
542  {
543  const auto c_grid_desc_mraw_nraw = [&]() {
545  {
546  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
547  }
549  {
550  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
551  }
552  }();
553 
554  // pad M and N
555  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
557  make_right_pad_transform(N, NPad - N)),
560  }
561 
562  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
563  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
564  {
565  return generate_tuple(
566  [&](auto i) {
567  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
568  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
569  },
571  }
572 
573  template <typename DsGridDesc>
575  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
576  {
577  return generate_tuple(
578  [&](auto i) {
580  ds_grid_desc_m_n[i], MBlock, NBlock);
581  },
583  }
584 
585  struct Problem
586  {
587  __host__ __device__ Problem(index_t NumTokens_,
588  index_t TopK_,
589  index_t M_,
590  index_t N_,
591  index_t K_,
592  index_t StrideA_,
593  index_t StrideB_,
594  std::array<index_t, NumDTensor> StrideDs_,
595  index_t StrideC_,
596  index_t KBatch_)
597  : NumTokens{NumTokens_},
598  TopK{TopK_},
599  M{M_},
600  N{N_},
601  K{K_},
602  StrideA{StrideA_},
603  StrideB{StrideB_},
604  StrideDs{StrideDs_},
605  StrideC{StrideC_},
606  KBatch{KBatch_},
609  KRead{CalculateKRead(K_, KBatch_)},
610  KPadded{CalculateKPadded(K_, KBatch_)},
611  AK0{CalculateAK0Padded(K_, KBatch_)},
612  BK0{CalculateBK0Padded(K_, KBatch_)},
613  MBlock{CalculateMBlock(M_)},
614  NBlock{CalculateNBlock(N_)},
617  {
618  }
619 
620  __host__ void Print() const
621  {
622  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
623  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
624  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
625  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
626  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
627  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
628  << "NBlock: " << NBlock << "}" << std::endl;
629  }
630 
638  std::array<index_t, NumDTensor> StrideDs;
649  // FOR PRESHUFFLE ONLY
652  };
653 
654  // Argument
656  {
657  __host__ Argument(const index_t* p_sorted_token_ids_,
658  const index_t* p_sorted_expert_ids_,
659  const index_t* p_max_token_id_,
660  const ADataType* p_a_grid_,
661  const BDataType* p_b_grid_,
662  std::array<const void*, NumDTensor> p_ds_grid_,
663  CDataType* p_c_grid_,
664  index_t NumTokens_,
665  index_t TopK_,
666  index_t M_,
667  index_t N_,
668  index_t K_,
669  index_t StrideA_,
670  index_t StrideB_,
671  std::array<index_t, NumDTensor> StrideDs_,
672  index_t StrideC_,
673  index_t k_batch_,
674  AElementwiseOperation a_element_op_,
675  BElementwiseOperation b_element_op_,
676  CElementwiseOperation c_element_op_)
677  : Problem{NumTokens_,
678  TopK_,
679  M_,
680  N_,
681  K_,
682  StrideA_,
683  StrideB_,
684  StrideDs_,
685  StrideC_,
686  k_batch_},
687  p_sorted_token_ids{p_sorted_token_ids_},
688  p_sorted_expert_ids{p_sorted_expert_ids_},
689  p_max_token_id{p_max_token_id_},
690  p_a_grid{p_a_grid_},
691  p_b_grid{p_b_grid_},
692  p_ds_grid{},
693  p_c_grid{p_c_grid_},
694  a_element_op{a_element_op_},
695  b_element_op{b_element_op_},
696  c_element_op{c_element_op_}
697  {
698 
699  // populate pointer, desc for Ds
700  static_for<0, NumDTensor, 1>{}([&](auto i) {
701  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
702 
703  // D pointer
704  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
705  });
706  }
707 
711  const ADataType* p_a_grid;
712  const BDataType* p_b_grid;
714  CDataType* p_c_grid;
715 
716  const AElementwiseOperation a_element_op;
717  const BElementwiseOperation b_element_op;
718  const CElementwiseOperation c_element_op;
719  };
720 
722  {
723  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
724  {
725  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
726  {
727  a_k_split_offset = k_id * karg.KRead / APackedSize;
728  }
729  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
730  {
731  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
732  }
733 
734  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
735  {
736  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
737  }
738  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
739  {
740  // KPack * NLane * KLane * K0 * N0
741  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
742  }
743 
744  if(k_id < karg.KBatch - 1)
745  {
746  karg.K = karg.KRead;
747  }
748  else
749  {
750  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
751  }
752  }
753 
756  };
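
// Worked example (illustrative numbers): with row-major A, KRead = 256, APackedSize = 1 and
// k_id = 2, a_k_split_offset = 2 * 256 = 512 elements into the A grid; for the column-major
// (preshuffled) B path, b_k_split_offset = 2 * 256 * NLane / BPackedSize. Batches 0 .. KBatch-2
// then compute with karg.K = KRead = 256, and the last batch computes the remaining
// K - 256 * (KBatch - 1) elements.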
757 
758  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
759  {
760  // A matrix in LDS memory, dst of blockwise copy
761  if constexpr(ABlockLdsExtraM)
762  {
766  }
767  // The xor tensor transformation requires extra, unnecessary VGPRs and can cause register
768  // spills in some cases.
770  {
771  constexpr auto a_lds_block_desc =
774 
775  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
776  a_lds_block_desc,
782 
783  return a_lds_block_desc_permuted;
784  }
785  else // ColumnMajor A
786  {
787  // The kfold and mpair dimensions are not always required.
788  // More dimensions in merge_transform make it harder for the compiler to generate
789  // immarg offsets.
790  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
791  constexpr auto M1 = MPerBlock / M0;
792 
793  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
794  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
795  constexpr auto KThreadRead = 64 / MPerXdl;
796  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
797 
798  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
799  ? 1
800  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
801  constexpr auto KThreadReadPerm =
802  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
803  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
804  : KThreadRead;
805 
806  // 1 <= mpair <= m0
807  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
808  ? 1
809  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
810  ? M0
811  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
812 
813  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
817  Number<kfold * M0 / mpair>{},
818  Number<mpair>{},
819  AK1Number));
820 
821  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
822  a_lds_block_desc,
823  make_tuple(
827  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
830  make_tuple(
832  make_tuple(
834 
835  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
836  a_lds_block_desc_permuted,
837  make_tuple(
845  Sequence<1>{},
846  Sequence<2>{},
847  Sequence<3>{},
848  Sequence<4>{},
849  Sequence<5>{}),
851  Sequence<2>{},
852  Sequence<0, 3>{},
853  Sequence<4, 5>{},
854  Sequence<6>{},
855  Sequence<7>{}));
856 
857  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
858  a_lds_block_desc_unmerged,
861  Number<KThreadWrite / kfold / KThreadReadPerm>{},
862  Number<kfold>{},
869 
870  return a_lds_block_desc_ak0_m_ak1;
871  }
872  }
873 
874  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
875  {
876  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
879  }
880 
882  {
883  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
884 
885  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
887  make_tuple(I1,
889  I1,
891 
892  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
893  }
894 
897  BlkGemmPipelineVer,
898  BlkGemmPipeSched,
899  BlockSize,
900  ADataType,
901  BDataType,
902  ComputeTypeA,
903  AccDataType,
910  ABlockTransferSrcScalarPerVector,
911  BBlockTransferSrcScalarPerVector,
912  MPerBlock,
913  NPerBlock,
914  KPerBlock,
915  MPerXdl,
916  NPerXdl,
917  MXdlPerWave,
918  NXdlPerWave,
919  KPack,
920  IsInputGemm>())>;
921 
922  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
923  {
924  // LDS allocation for A and B: be careful of alignment
925  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
926  // lds max alignment
927  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
928 
929  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
930  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
931 
932  // LDS allocation for C shuffle in LDS
933  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
935 
936  constexpr auto c_block_size =
937  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
938 
939  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
940  c_block_size * sizeof(CShuffleDataType));
941  }
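
// Worked example (illustrative numbers): if the A block descriptor occupies 128 x 64 LDSTypeA
// elements of 2 bytes each (16 KiB after alignment, APackedSize = 1) and the C shuffle block
// needs 128 x 32 CShuffleDataType elements of 4 bytes each (16 KiB), the function returns
// max(16 KiB, 16 KiB) = 16 KiB; A/B and the C shuffle reuse the same LDS window, which is why
// the maximum rather than the sum is allocated.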
942 
943  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
944  __host__ static constexpr bool CheckValidity(const Argument& karg)
945  {
946  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
947  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
948  "Invalid tuning param!");
949 
950  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
955  {
956  if(!(karg.M % MPerBlock == 0))
957  {
958 #if DEBUG_LOG
959  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
960  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
961  << std::endl;
962 
963 #endif // DEBUG_LOG
964  return false;
965  }
966  }
967 
968  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
973  {
974  if(!(karg.N % NPerBlock == 0))
975  {
976 #if DEBUG_LOG
977  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
978  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
979  << std::endl;
980 
981 #endif // DEBUG_LOG
982  return false;
983  }
984  }
985 
986  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
990  {
991 
992  auto K_t = karg.KBatch * KPerBlock;
993  if(!(karg.K % K_t == 0))
994  {
995 #if DEBUG_LOG
996  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
997  << karg.K << " " << __FILE__ << ":" << __LINE__
998  << ", in function: " << __func__ << std::endl;
999 
1000 #endif // DEBUG_LOG
1001  return false;
1002  }
1003  }
1004  else
1005  {
1006  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1007  auto K_t = karg.KBatch * KReadVec;
1008  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1009  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1010  {
1011  return false;
1012  }
1013  }
1014 
1016  {
1017  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1018  {
1019 #if DEBUG_LOG
1020  std::cout << "Arg K (" << karg.K
1021  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1022  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1023  << __LINE__ << ", in function: " << __func__ << std::endl;
1024 
1025 #endif // DEBUG_LOG
1026  return false;
1027  }
1028  }
1029  else
1030  {
1031  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1032  {
1033 #if DEBUG_LOG
1034  std::cout << "Arg M (" << karg.M
1035  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1036  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1037  << __LINE__ << ", in function: " << __func__ << std::endl;
1038 
1039 #endif // DEBUG_LOG
1040  return false;
1041  }
1042  }
1043 
1045  {
1046  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1047  {
1048 #if DEBUG_LOG
1049  std::cout << "Arg N (" << karg.N
1050  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1051  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1052  << __LINE__ << ", in function: " << __func__ << std::endl;
1053 
1054 #endif // DEBUG_LOG
1055  return false;
1056  }
1057  }
1058  else
1059  {
1060  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1061  {
1062 #if DEBUG_LOG
1063  std::cout << "Arg K (" << karg.K
1064  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1065  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1066  << __LINE__ << ", in function: " << __func__ << std::endl;
1067 
1068 #endif // DEBUG_LOG
1069  return false;
1070  }
1071  }
1072 
1074  {
1076  {
1077 #if DEBUG_LOG
1078  std::cout << "Arg N (" << karg.N
1079  << ") value is not a multiple of "
1080  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1081  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1082  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1083 
1084 #endif // DEBUG_LOG
1085  return false;
1086  }
1087  }
1088  else
1089  {
1091  {
1092 #if DEBUG_LOG
1093  std::cout << "Arg M (" << karg.M
1094  << ") value is not a multiple of "
1095  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1096  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1097  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1098 
1099 #endif // DEBUG_LOG
1100  return false;
1101  }
1102  }
1103 
1104  // check gridwise gemm pipeline
1105 #if 0
1106  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1107 
1108  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1109  {
1110  return false;
1111  }
1112 #endif
1113  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1114  return true;
1115  }
1116 
1117  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1118  {
1119  const index_t num_loop = K / KPerBlock;
1120 
1121  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1122  }
1123 
1124  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1125  {
1126  const index_t num_loop = K / KPerBlock;
1127 
1128  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1129  }
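
// Worked example (illustrative numbers): with KPerBlock = 64 and K = 512, num_loop = 8; whether
// that counts as having a main (hot) K loop and which TailNumber is reported depends on the
// selected BlockwiseGemmPipe (BlockHasHotloop / BlockLoopTailNum), which in turn drives the
// HasMainKBlockLoop / TailNum template arguments of kernel_moe_gemm.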
1130 
1131  template <typename CGridDesc>
1133  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1134  {
1135  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1136  c_grid_desc_m_n,
1141 
1142  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1143  }
1144 
1145  // return block_id to C matrix tile idx (m0, n0) mapping
1146  // if arch = gfx942
1147  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1148  // NPerBlock>;
1149 
1150  template <bool HasMainKBlockLoop,
1151  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1152  TailNumber TailNum = TailNumber::Odd>
1153  __device__ static void Run(const index_t* p_sorted_token_ids,
1154  const index_t* p_sorted_expert_ids,
1155  const index_t* p_max_token_id,
1156  const ADataType* p_a_grid,
1157  const BDataType* p_b_grid,
1158  DsGridPointer& p_ds_grid,
1159  CDataType* p_c_grid,
1160  void* p_shared,
1161  const Problem& problem,
1162  AElementwiseOperation a_element_op,
1163  BElementwiseOperation b_element_op,
1164  CElementwiseOperation c_element_op)
1165  {
1166  ignore = b_element_op;
1167  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1168  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1169  problem.MPadded,
1170  problem.K,
1171  problem.KPadded,
1172  problem.StrideA,
1173  problem.AK0);
1174  const auto b_grid_desc_bpreshuffled =
1176  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1177  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1178  problem.MPadded,
1179  problem.N,
1180  problem.NPadded,
1181  problem.StrideC);
1182  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1184  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1185  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1186  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1187  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1188  if(expert_block_id * MPerBlock >= max_token_id)
1189  return;
1190  const index_t expert_id =
1191  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1192  const auto block_mn = [&]() -> std::pair<int, int> {
1193  if constexpr(NSwizzle)
1194  {
1195  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1196  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1197  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1198  const index_t expert_swizzle =
1199  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1200  const index_t bid_new = blockIdx.x - prefix_block;
1201  const index_t nid = __builtin_amdgcn_readfirstlane(
1202  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1203  const index_t mid =
1204  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1205  return {nid, mid};
1206  }
1207  else
1208  {
1209  return {blockIdx.x, blockIdx.y};
1210  }
1211  }();
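
// NSwizzle mapping, worked example (illustrative numbers): for an expert that owns
// expert_swizzle = 2 sorted M tiles starting at ecnt_prefix, bid_new = 0..7 maps to
// (nid, mid) = (0..7, ecnt_prefix), bid_new = 8..15 maps to (0..7, ecnt_prefix + 1), and
// bid_new = 16 starts the next group of 8 N tiles at (8, ecnt_prefix); i.e. blocks walk the
// N dimension in groups of 8 tiles, revisiting each group for every M tile of the expert.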
1212 
1213  const index_t block_n_id = block_mn.first;
1214  const index_t block_m_id = block_mn.second;
1215  const index_t token0 =
1216  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1217 
1218  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1219  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1220  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1221  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1222  constexpr auto AKThreads = AK0Threads * AK1Threads;
1223  constexpr auto AMRepeats = MPerBlock / AMThreads;
1224  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1225 
1226  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1227  return;
1229  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1230  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1231  index_t token_offset = fused_token & 0xffffff;
1232  if constexpr(!IsInputGemm)
1233  {
1234  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1235  }
1236  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1237  });
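
// The sorted token ids are packed: bits 0..23 hold the token index and bits 24..31 hold the
// top-k slot. Illustrative decode: fused_token = 0x02000005 means token 5, top-k slot 2, so for
// the non-input (second) GEMM the row becomes 5 * problem.TopK + 2; the gather offset then
// scales the row by problem.K to address the A grid.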
1238  const IndexType expert_stride =
1239  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1240  const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1241  // N0, K0, Blocksize*KPack
1242  const index_t n_block_data_idx_on_grid =
1243  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1244 
1245  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1246  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1247  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1248  p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1249  // A matrix in LDS memory, dst of blockwise copy
1250  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1251 
1252  // B matrix in LDS memory, dst of blockwise copy
1253  // dummy
1254  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1255  // A matrix blockwise copy
1256  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1258  AElementwiseOperation,
1262  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1263  ABlockTransferThreadClusterArrangeOrder,
1264  ADataType,
1265  LDSTypeA,
1266  decltype(a_grid_desc_ak0_m_ak1),
1267  decltype(a_block_desc_ak0_m_ak1),
1268  ABlockTransferSrcAccessOrder,
1270  ABlockTransferSrcVectorDim,
1271  2,
1272  ABlockTransferSrcScalarPerVector,
1273  ABlockTransferDstScalarPerVector_AK1,
1274  1,
1275  1,
1276  AThreadTransferSrcResetCoordinateAfterRun,
1277  true,
1278  IndexType,
1279  1,
1280  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1281  make_multi_index(0, 0, 0),
1282  a_element_op,
1283  a_block_desc_ak0_m_ak1,
1284  make_multi_index(0, 0, 0),
1286  gather_offsets);
1287 
1288  // Thread-wise copy
1289  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1290  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1291  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1292 
1293  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1294  BDataType,
1295  BDataType,
1296  decltype(b_grid_desc_bpreshuffled),
1297  decltype(b_block_desc_bk0_n_bk1),
1300  3,
1301  BBlockTransferSrcScalarPerVector,
1302  BThreadTransferSrcResetCoordinateAfterRun,
1303  true>(b_grid_desc_bpreshuffled,
1304  make_multi_index(n_block_data_idx_on_grid,
1306  0,
1307  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1308 
1309  // LDS allocation for A and B: be careful of alignment
1310  // Cast after lds
1311  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1312  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1313 
1314  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1315  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1316 
1317  // Blockwise GEMM pipeline
1318  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1319  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1320  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1321  decltype(c_thread_buf) c_thread_buf_up;
1322 
1324  float,
1325  c_thread_buf.num_of_v_,
1326  c_thread_buf.s_per_v,
1327  true>
1328  c_thread_buf_fp32;
1329 
1330  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1331  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1332  KPerBlock);
1333  if constexpr(IsInputGemm)
1334  {
1335  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1336  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1337  p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1338  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1339  BDataType,
1340  BDataType,
1341  decltype(b_grid_desc_bpreshuffled),
1342  decltype(b_block_desc_bk0_n_bk1),
1345  3,
1346  BBlockTransferSrcScalarPerVector,
1347  BThreadTransferSrcResetCoordinateAfterRun,
1348  true>(b_grid_desc_bpreshuffled,
1349  make_multi_index(n_block_data_idx_on_grid,
1351  0,
1352  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1353 
1354  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1355  a_grid_desc_ak0_m_ak1,
1356  a_block_desc_ak0_m_ak1,
1357  a_blockwise_copy,
1358  a_grid_buf,
1359  a_block_buf,
1360  a_block_slice_copy_step,
1361  b_grid_desc_bpreshuffled,
1362  b_blockwise_copy,
1363  b_blockwise_copy_up,
1364  b_grid_buf,
1365  b_grid_buf_up,
1366  b_block_buf,
1367  b_block_slice_copy_step,
1368  c_thread_buf,
1369  c_thread_buf_up,
1370  num_k_block_main_loop);
1371  }
1372  else
1373  {
1374  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1375  a_grid_desc_ak0_m_ak1,
1376  a_block_desc_ak0_m_ak1,
1377  a_blockwise_copy,
1378  a_grid_buf,
1379  a_block_buf,
1380  a_block_slice_copy_step,
1381  b_grid_desc_bpreshuffled,
1382  b_blockwise_copy,
1383  b_grid_buf,
1384  b_block_buf,
1385  b_block_slice_copy_step,
1386  c_thread_buf,
1387  num_k_block_main_loop);
1388  }
1389 
1390  // shuffle C and write out
1391  {
1392  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1393  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1394  "wrong!");
1395 
1396  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1397 
1398  // TODO: hacky, fix it!
1399  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1400  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1401 
1402  // TODO: hacky, fix it!
1403  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1404  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1405  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1406 
1407  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1408  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1409  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1410  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1411  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1412  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1413  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1414  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1415 
1416  // mul scales
1417  const float* p_sorted_weights_0 = p_ds_grid[I0];
1418  const float* p_scale_b = p_ds_grid[I1];
1419 
1420  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1421  static_assert(M4 == 4);
1422  const index_t m1 = get_warp_local_1d_id() / NWave;
1423  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1424 
1425  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
1426  {
1427  if constexpr(PerTokenQuant)
1428  {
1429  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1430  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1431  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
1432  }
1433  else
1434  {
1435  p_scale_b += expert_id;
1436  }
1437 
1438  vector_type<int32_t, 4> scale_token_ids;
1439  vector_type<float, 4> topk_weights;
1440  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1441  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
1442  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1443  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1444  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1445  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1446  if constexpr(PerTokenQuant)
1447  {
1448  scale_token_ids =
1449  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
1450  p_sorted_token_ids + m_pos);
1451  }
1452  if constexpr(MulRoutedWeight)
1453  {
1454  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1455  p_ds_grid[I2] + m_pos);
1456  }
1457  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1458  float scale_a = [&]() {
1459  if constexpr(PerTokenQuant)
1460  {
1461  index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
1462  const index_t token_offset = fused_token & 0xffffff;
1463  return token_offset < problem.NumTokens
1464  ? p_sorted_weights_0[IsInputGemm
1465  ? token_offset
1466  : token_offset *
1467  problem.TopK +
1468  (fused_token >>
1469  24)]
1470  : 0.0;
1471  }
1472  else
1473  {
1474  return p_sorted_weights_0[0];
1475  }
1476  }();
1477  constexpr index_t c_offset =
1478  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1479  make_tuple(m0, n0, m2 * M4 + m4));
1480  constexpr auto cidx = Number<c_offset>{};
1481  if constexpr(IsInputGemm) // gu (gate-up) fusion
1482  {
1483  if constexpr(ActivationOperation == Activation::silu_and_mul)
1484  {
1485  const float scale_up =
1486  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1487  PerTokenQuant];
1488  float gate = scale_a * scale_b * c_thread_buf[cidx];
1489  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1490  if constexpr(MulRoutedWeight)
1491  {
1492  gate = gate * topk_weights.AsType<float>()[m4];
1493  up = up * topk_weights.AsType<float>()[m4];
1494  }
1496  {
1497  gate *= 16;
1498  up *= 16;
1499  }
1501  c_thread_buf_fp32(cidx) = gate * up;
1502  }
1503  else if(ActivationOperation == Activation::gelu_and_mul)
1504  {
1505  const float scale_up =
1506  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1507  PerTokenQuant];
1508  float gate = scale_a * scale_b * c_thread_buf[cidx];
1509  float up = scale_a * scale_up * c_thread_buf_up[cidx];
1510  if constexpr(MulRoutedWeight)
1511  {
1512  gate = gate * topk_weights.AsType<float>()[m4];
1513  up = up * topk_weights.AsType<float>()[m4];
1514  }
1516  {
1517  gate *= 16;
1518  up *= 16;
1519  }
1521  c_thread_buf_fp32(cidx) = gate * up;
1522  }
1523  }
1524  else
1525  {
1526  c_thread_buf_fp32(cidx) =
1527  scale_a * scale_b * c_thread_buf[cidx];
1528  if constexpr(MulRoutedWeight)
1529  {
1530  c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
1531  topk_weights.AsType<float>()[m4];
1532  }
1533  }
1534  });
1535  });
1536  });
1537  });
1538  }
1539  else
1540  {
1541  vector_type<float, 4> topk_weights; // for gemm2 only
1542  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1543  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1544  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1545  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1546  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1547  if constexpr(MulRoutedWeight)
1548  {
1549  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1550  p_ds_grid[I2] + m_pos);
1551  }
1552  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1553  constexpr index_t c_offset =
1554  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1555  make_tuple(m0, n0, m2 * M4 + m4));
1556  constexpr auto cidx = Number<c_offset>{};
1557 
1558  if constexpr(IsInputGemm) // gu (gate-up) fusion
1559  {
1560  if constexpr(ActivationOperation == Activation::silu_and_mul)
1561  {
1562  float gate = c_thread_buf[cidx];
1563  float up = c_thread_buf_up[cidx];
1564  if constexpr(MulRoutedWeight)
1565  {
1566  gate = gate * topk_weights.AsType<float>()[m4];
1567  up = up * topk_weights.AsType<float>()[m4];
1568  }
1570  c_thread_buf_fp32(cidx) = gate * up;
1571  }
1572  else if(ActivationOperation == Activation::gelu_and_mul)
1573  {
1574  float gate = c_thread_buf[cidx];
1575  float up = c_thread_buf_up[cidx];
1576  if constexpr(MulRoutedWeight)
1577  {
1578  gate = gate * topk_weights.AsType<float>()[m4];
1579  up = up * topk_weights.AsType<float>()[m4];
1580  }
1582  c_thread_buf_fp32(cidx) = gate * up;
1583  }
1584  }
1585  else
1586  {
1587  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1588  if constexpr(MulRoutedWeight)
1589  {
1590  c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
1591  c_thread_buf_fp32[cidx];
1592  }
1593  }
1594  });
1595  });
1596  });
1597  });
1598  }
1599 
1600  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1602 
1603  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1604  static_cast<CShuffleDataType*>(p_shared),
1605  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1606 
1607  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1608  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1609  make_tuple(
1612  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1613  M1, // M1 = MWave
1614  M2, // M2 * M3 * M4 = MPerXdl
1615  M3,
1616  M4)),
1619  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1620  N1, // N1 = NWave
1621  N2))), // N2 = NPerXdl
1623  make_tuple(
1625 
1626  // calculate origin of thread output tensor on global memory
1627  // blockwise GEMM c matrix starting index
1628  const auto c_thread_mtx_on_block =
1629  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1630 
1631  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1632  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1633 
1634  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1636  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1638  make_tuple(Sequence<0>{}));
1639 
1640  const auto m_thread_data_on_block_idx =
1641  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1642  make_multi_index(m_thread_data_on_block));
1643 
1644  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1648  make_tuple(Sequence<0>{}));
1649 
1650  const auto n_thread_data_on_block_idx =
1651  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1652  make_multi_index(n_thread_data_on_block));
1653 
1654  // shuffle: threadwise copy C from VGPR to LDS
1655  auto c_thread_copy_vgpr_to_lds =
1657  CShuffleDataType,
1658  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1659  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1661  Sequence<CShuffleMXdlPerWavePerShuffle,
1662  CShuffleNXdlPerWavePerShuffle,
1663  I1,
1664  I1,
1665  M2,
1666  I1,
1667  M4,
1668  I1>,
1670  7,
1671  1,
1673  1,
1674  true>{
1675  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1676  make_multi_index(0,
1677  0,
1678  m_thread_data_on_block_idx[I1],
1679  n_thread_data_on_block_idx[I1],
1680  m_thread_data_on_block_idx[I2],
1681  m_thread_data_on_block_idx[I3],
1682  m_thread_data_on_block_idx[I4],
1683  n_thread_data_on_block_idx[I2]),
1685 
1686  using EDataType = CDataType;
1687 
1688  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1689  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1690 
1691  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1693  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1694 
1695  const auto ds_grid_buf = generate_tuple(
1696  [&](auto i) {
1697  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1698  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1699  },
1700  Number<NumDTensor>{});
1701 
1702  // tuple of reference to C/Ds tensor descriptors
1703  const auto c_ds_desc_refs = concat_tuple_of_reference(
1704  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1705  generate_tie([&](auto i) -> const auto& // return type should be reference
1706  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1707  Number<NumDTensor>{}));
1708 
1710  // tuple of reference to C/Ds tensor buffers
1710  const auto c_ds_buf_refs = concat_tuple_of_reference(
1711  tie(c_shuffle_block_buf),
1712  generate_tie([&](auto i) -> const auto& // return type should be reference
1713  { return ds_grid_buf[i]; },
1714  Number<NumDTensor>{}));
1715 
1716  // tuple of starting index of C/Ds blockwise copy
1717  const auto idx_c_ds_block_begin =
1720  [&](auto) {
1721  return make_multi_index(block_m_id, 0, block_n_id, 0);
1722  // return make_multi_index(block_work_idx[I0], 0,
1723  // block_work_idx[I1], 0);
1724  },
1725  Number<NumDTensor>{}));
1726 
1727  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1728  c_grid_desc_mblock_mperblock_nblock_nperblock;
1729 
1730  using CDEBlockTransferCluster =
1731  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1732  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1733  constexpr index_t scatter_weight_idx = 3; // hack fix felix
1734  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1736  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1738  decltype(c_ds_desc_refs),
1739  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1740  CElementwiseOperation,
1741  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1742  // support arbitrary type
1743  Sequence<1,
1744  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1745  1,
1746  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1747  CDEBlockTransferCluster,
1748  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1749  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1750  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1751  3, // index_t SrcVectorDim,
1752  3, // index_t DstVectorDim,
1753  CDEShuffleBlockTransferScalarPerVectors,
1758  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1759  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1760  IndexType,
1761  1, // ScatterDim
1762  true, // OutputScatter: false, only use scatter weights
1763  scatter_weight_idx // ScatterWeightIdx: ascale
1764  >{c_ds_desc_refs,
1765  idx_c_ds_block_begin,
1766  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1767  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1768  c_element_op};
1769 
1770  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1771  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1772  constexpr auto sfc_c_vgpr =
1775  Sequence<CShuffleMXdlPerWavePerShuffle,
1776  CShuffleNXdlPerWavePerShuffle,
1777  1,
1778  1,
1779  M2,
1780  1,
1781  M4,
1782  1>>{};
1783 
1784  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1785 
1786  // space filling curve for shuffled blockwise C/D/E
1787  constexpr auto sfc_cde_block =
1790  Sequence<1,
1791  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1792  1,
1793  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1794 
1795  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1796  constexpr auto EMThreads =
1797  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1798  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1799  constexpr auto ENThreads =
1800  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1801  static_for<0, num_access, 1>{}([&](auto access_id) {
1802  // make sure it's safe to write to LDS
1804 
1805  auto dstidx = sfc_cde_block.GetIndex(access_id);
1806  const index_t c_token_pos =
1807  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1808  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1809  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1810  IndexType token_offset = fused_token & 0xffffff;
1811  if constexpr(IsInputGemm)
1812  {
1813  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1814  }
1815  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1816  });
1817 
1818  block_sync_lds();
1819 
1820  // each thread write its data from VGPR to LDS
1821  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1822  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1823  c_thread_buf_fp32,
1824  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1825  c_shuffle_block_buf);
1826 
1827  // make sure it's safe to read from LDS
1828  block_sync_lds();
1829 
1830  // each block copy its data from LDS to global
1831  cde_block_copy_lds_and_global.Run(
1832  c_ds_desc_refs,
1833  c_ds_buf_refs,
1834  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1835  tie(c_grid_buf),
1836  scatter_offsets);
1837 
1838  if constexpr(access_id < num_access - 1)
1839  {
1840  constexpr auto cde_lds_and_global_step =
1841  sfc_cde_block.GetForwardStep(access_id);
1842 
1843  // move on Ds
1844  static_for<0, NumDTensor, 1>{}([&](auto i) {
1845  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1846  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1847  });
1848 
1849  // move on E
1850  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1851  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1852  I0,
1853  cde_lds_and_global_step);
1854  }
1855  });
1856  }
1857  }
1858 
1859  template <bool HasMainKBlockLoop,
1860  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1861  TailNumber TailNum = TailNumber::Odd>
1862  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1863  const index_t* p_sorted_expert_ids,
1864  const index_t* p_max_token_id,
1865  const ADataType* p_a_grid,
1866  const BDataType* p_b_grid,
1867  DsGridPointer& p_ds_grid,
1868  CDataType* p_c_grid,
1869  void* p_shared,
1870  void* p_shared1,
1871  const Problem& problem,
1872  AElementwiseOperation a_element_op,
1873  BElementwiseOperation b_element_op,
1874  CElementwiseOperation c_element_op)
1875  {
1876  ignore = b_element_op;
1877  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1878  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1879  problem.MPadded,
1880  problem.K,
1881  problem.KPadded,
1882  problem.StrideA,
1883  problem.AK0);
1884  const auto b_grid_desc_bpreshuffled =
1886  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1887  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1888  problem.MPadded,
1889  problem.N,
1890  problem.NPadded,
1891  problem.StrideC);
1892  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1894  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1895  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1896  // static_assert(NSwizzle == false, "TODO fix: needs another sorting PR merged");
1897  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1898  if(expert_block_id * MPerBlock >= max_token_id)
1899  return;
1900  const index_t expert_id =
1901  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1902  const auto block_mn = [&]() -> std::pair<int, int> {
1903  if constexpr(NSwizzle)
1904  {
1905  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1906  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1907  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1908  const index_t expert_swizzle =
1909  ecnt > 0 ? ecnt : 1; // guard against an expert with zero blocks
1910  const index_t bid_new = blockIdx.x - prefix_block;
1911  const index_t nid = __builtin_amdgcn_readfirstlane(
1912  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1913  const index_t mid =
1914  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1915  return {nid, mid};
1916  }
1917  else
1918  {
1919  return {blockIdx.x, blockIdx.y};
1920  }
1921  }();
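// Worked example of the NSwizzle mapping above (numbers assumed): blocks of one
// expert first sweep a group of 8 N-tiles before stepping to the next M-tile.
// With expert_swizzle = ecnt = 4 and bid_new = 13:
//   nid = 13 % 8 + 13 / (8 * 4) * 8 = 5
//   mid = ecnt_prefix + (13 / 8) % 4 = ecnt_prefix + 1
// which keeps a given expert's blocks clustered for better B-tile reuse in cache.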
1922 
1923  const index_t block_n_id = block_mn.first;
1924  const index_t block_m_id = block_mn.second;
1925  const index_t token0 =
1926  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1927 
1928  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1929  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1930  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1931  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1932  constexpr auto AKThreads = AK0Threads * AK1Threads;
1933  constexpr auto AMRepeats = MPerBlock / AMThreads;
1934  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1935 
1936  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1937  return;
1939  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1940  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1941  index_t token_offset = fused_token & 0xffffff;
1942  if constexpr(!IsInputGemm)
1943  {
1944  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1945  }
1946  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1947  });
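// Each sorted-token entry packs two fields: bits [23:0] hold the token row and
// bits [31:24] hold the top-k slot. The first (gate/up) GEMM reads each token's A
// row once, so only the token id is used; the second GEMM (!IsInputGemm) owns one
// row per (token, slot) pair, hence token_offset * TopK + slot. Assumed example
// with TopK = 2 and fused_token = 0x01000005:
//   token = 5, slot = 1  ->  row = 5 * 2 + 1 = 11, gather offset = 11 * K.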
1948  const IndexType expert_stride =
1949  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1950  const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1951  // N0, K0, Blocksize*KPack
1952  const index_t n_block_data_idx_on_grid =
1953  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1954 
1955  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1956  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1957  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1958  p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1959 
1960  // A matrix in LDS memory, dst of blockwise copy
1961  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1962 
1963  // B matrix in LDS memory, dst of blockwise copy
1964  // dummy
1965  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1966  // A matrix blockwise copy
1967  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1969  AElementwiseOperation,
1973  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1974  ABlockTransferThreadClusterArrangeOrder,
1975  ADataType,
1976  LDSTypeA,
1977  decltype(a_grid_desc_ak0_m_ak1),
1978  decltype(a_block_desc_ak0_m_ak1),
1979  ABlockTransferSrcAccessOrder,
1981  ABlockTransferSrcVectorDim,
1982  2,
1983  ABlockTransferSrcScalarPerVector,
1984  ABlockTransferDstScalarPerVector_AK1,
1985  1,
1986  1,
1987  AThreadTransferSrcResetCoordinateAfterRun,
1988  true,
1989  IndexType,
1990  1,
1991  2>(a_grid_desc_ak0_m_ak1,
1992  make_multi_index(0, 0, 0),
1993  a_element_op,
1994  a_block_desc_ak0_m_ak1,
1995  make_multi_index(0, 0, 0),
1997  gather_offsets);
1998 
1999  // Thread-wise copy
2000  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2001  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2002  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2003  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2004  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2005  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2006 
2007  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2008  BDataType,
2009  BDataType,
2010  decltype(b_grid_desc_bpreshuffled),
2011  decltype(b_block_desc_bk0_n_bk1),
2014  3,
2015  BBlockTransferSrcScalarPerVector,
2016  BThreadTransferSrcResetCoordinateAfterRun,
2017  true>(b_grid_desc_bpreshuffled,
2018  make_multi_index(n_block_data_idx_on_grid,
2020  0,
2021  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2022 
2023  // LDS allocation for A and B: be careful of alignment
2024  // Cast after lds
2025  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2026  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2027  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2028  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2029  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
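// Double buffering for the 2-LDS pipeline: A tiles alternate between the two LDS
// chunks passed in as p_shared / p_shared1, and the preshuffled B tiles alternate
// between the two VGPR static buffers above, so the loads of the next K tile can be
// issued while the current tile is consumed by the MFMA stage.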
2030 
2031  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2032  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2033 
2034  // Blockwise GEMM pipeline
2035  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2036  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2037  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2038  decltype(c_thread_buf) c_thread_buf_up;
2039 
2041  float,
2042  c_thread_buf.num_of_v_,
2043  c_thread_buf.s_per_v,
2044  true>
2045  c_thread_buf_fp32;
2046 
2047  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2048  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2049  KPerBlock);
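// num_k_block_main_loop counts K tiles per split-K batch: the descriptor lengths
// give AK0 * AK1 = (padded) K of this batch, so the loop count is K / KPerBlock.
// Assumed example: K = 4096, KPerBlock = 256 -> 16 pipeline iterations.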
2050 
2051  if constexpr(IsInputGemm)
2052  {
2053  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2054  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2055  p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2056  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2057  BDataType,
2058  BDataType,
2059  decltype(b_grid_desc_bpreshuffled),
2060  decltype(b_block_desc_bk0_n_bk1),
2063  3,
2064  BBlockTransferSrcScalarPerVector,
2065  BThreadTransferSrcResetCoordinateAfterRun,
2066  true>(b_grid_desc_bpreshuffled,
2067  make_multi_index(n_block_data_idx_on_grid,
2069  0,
2070  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2071  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2072  a_grid_desc_ak0_m_ak1,
2073  a_block_desc_ak0_m_ak1,
2074  a_blockwise_copy,
2075  a_grid_buf,
2076  a_block_bufs,
2077  a_block_slice_copy_step,
2078  b_grid_desc_bpreshuffled,
2079  b_blockwise_copy,
2080  b_blockwise_copy_up,
2081  b_grid_buf,
2082  b_grid_buf_up,
2083  b_block_bufs,
2084  b_block_slice_copy_step,
2085  c_thread_buf,
2086  c_thread_buf_up,
2087  num_k_block_main_loop);
2088  }
2089  else
2090  {
2091 
2092  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2093  a_grid_desc_ak0_m_ak1,
2094  a_block_desc_ak0_m_ak1,
2095  a_blockwise_copy,
2096  a_grid_buf,
2097  a_block_bufs,
2098  a_block_slice_copy_step,
2099  b_grid_desc_bpreshuffled,
2100  b_blockwise_copy,
2101  b_grid_buf,
2102  b_block_bufs,
2103  b_block_slice_copy_step,
2104  c_thread_buf,
2105  num_k_block_main_loop);
2106  }
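// Summary of the two paths above: for the fused input GEMM the expert weight holds
// both the gate and the up projection, so a second B reader (offset by half the
// expert stride) streams the up half and the pipeline fills two accumulators,
// c_thread_buf and c_thread_buf_up; otherwise the single-B overload is used.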
2107 
2108  // shuffle C and write out
2109  {
2110  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2111  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2112  "wrong!");
2113 
2114  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2115 
2116  // TODO: hacky, fix it!
2117  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2118  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2119 
2120  // TODO: hacky, fix it!
2121  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2122  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2123  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2124 
2125  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2126  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2127  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2128  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2129  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2130  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2131  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2132  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2133 
2134  // mul scales
2135  const float* p_sorted_weights_0 = p_ds_grid[I0];
2136  const float* p_scale_b = p_ds_grid[I1];
2137 
2138  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2139  static_assert(M4 == 4);
2140  const index_t m1 = get_warp_local_1d_id() / NWave;
2141  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
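// M decomposition used by the scale/weight indexing below:
//   MPerBlock = M0 * M1 * M2 * M3 * M4, with M1 = MWave, M2 * M3 * M4 = MPerXdl,
//   and M4 = 4 rows owned by one lane per accumulator group (see the static_assert).
// Assumed example: MWave = 2, MPerXdl = 32 (M2 = 2, M3 = 4, M4 = 4), M0 = 2
//   -> MPerBlock = 2 * 2 * 2 * 4 * 4 = 128.
// m1 selects the wave's M position, m3 the lane group inside the XDL tile.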
2142 
2143  if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
2144  {
2145  if constexpr(PerTokenQuant)
2146  {
2147  constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2148  p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2149  get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
2150  }
2151  else
2152  {
2153  p_scale_b += expert_id;
2154  }
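// Scale layout: with PerTokenQuant, B carries one scale per expert and per output
// channel (2 * N scales per expert for the input GEMM: gate scales followed by up
// scales), so the pointer is advanced to this block's N window and this thread's
// lane inside NPerXdl; otherwise a single per-expert scalar is used. The per-token
// A scale (scale_a) is gathered below through the sorted-token ids.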
2155 
2156  vector_type<int32_t, 4> scale_token_ids;
2157  vector_type<float, 4> topk_weights;
2158  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2159  const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
2160  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2161  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2162  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2163  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2164  if constexpr(PerTokenQuant)
2165  {
2166  scale_token_ids =
2167  *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
2168  p_sorted_token_ids + m_pos);
2169  }
2170  if constexpr(MulRoutedWeight)
2171  {
2172  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2173  p_ds_grid[I2] + m_pos);
2174  }
2175  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2176  float scale_a = [&]() {
2177  if constexpr(PerTokenQuant)
2178  {
2179  index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
2180  const index_t token_offset = fused_token & 0xffffff;
2181  return token_offset < problem.NumTokens
2182  ? p_sorted_weights_0[IsInputGemm
2183  ? token_offset
2184  : token_offset *
2185  problem.TopK +
2186  (fused_token >>
2187  24)]
2188  : 0.0;
2189  }
2190  else
2191  {
2192  return p_sorted_weights_0[0];
2193  }
2194  }();
2195  constexpr index_t c_offset =
2196  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2197  make_tuple(m0, n0, m2 * M4 + m4));
2198  constexpr auto cidx = Number<c_offset>{};
2199  if constexpr(IsInputGemm) // gate/up (gu) fusion
2200  {
2201  if constexpr(ActivationOperation == Activation::silu_and_mul)
2202  {
2203  const float scale_up =
2204  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2205  PerTokenQuant];
2206  float gate = scale_a * scale_b * c_thread_buf[cidx];
2207  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2208  if constexpr(MulRoutedWeight)
2209  {
2210  gate = gate * topk_weights.AsType<float>()[m4];
2211  up = up * topk_weights.AsType<float>()[m4];
2212  }
2214  {
2215  gate *= 16;
2216  up *= 16;
2217  }
2219  c_thread_buf_fp32(cidx) = gate * up;
2220  }
2221  else if(ActivationOperation == Activation::gelu_and_mul)
2222  {
2223  const float scale_up =
2224  p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2225  PerTokenQuant];
2226  float gate = scale_a * scale_b * c_thread_buf[cidx];
2227  float up = scale_a * scale_up * c_thread_buf_up[cidx];
2228  if constexpr(MulRoutedWeight)
2229  {
2230  gate = gate * topk_weights.AsType<float>()[m4];
2231  up = up * topk_weights.AsType<float>()[m4];
2232  }
2234  {
2235  gate *= 16;
2236  up *= 16;
2237  }
2239  c_thread_buf_fp32(cidx) = gate * up;
2240  }
2241  }
2242  else
2243  {
2244  c_thread_buf_fp32(cidx) =
2245  scale_a * scale_b * c_thread_buf[cidx];
2246  if constexpr(MulRoutedWeight)
2247  {
2248  c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
2249  topk_weights.AsType<float>()[m4];
2250  }
2251  }
2252  });
2253  });
2254  });
2255  });
2256  }
2257  else
2258  {
2259  vector_type<float, 4> topk_weights; // for gemm2 only
2260  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2261  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2262  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2263  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2264  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2265  if constexpr(MulRoutedWeight)
2266  {
2267  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2268  p_ds_grid[I2] + m_pos);
2269  }
2270  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2271  constexpr index_t c_offset =
2272  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2273  make_tuple(m0, n0, m2 * M4 + m4));
2274  constexpr auto cidx = Number<c_offset>{};
2275 
2276  if constexpr(IsInputGemm) // gate/up (gu) fusion
2277  {
2278  if constexpr(ActivationOperation == Activation::silu_and_mul)
2279  {
2280  float gate = c_thread_buf[cidx];
2281  float up = c_thread_buf_up[cidx];
2282  if constexpr(MulRoutedWeight)
2283  {
2284  gate = gate * topk_weights.AsType<float>()[m4];
2285  up = up * topk_weights.AsType<float>()[m4];
2286  }
2288  c_thread_buf_fp32(cidx) = gate * up;
2289  }
2290  else if(ActivationOperation == Activation::gelu_and_mul)
2291  {
2292  float gate = c_thread_buf[cidx];
2293  float up = c_thread_buf_up[cidx];
2294  if constexpr(MulRoutedWeight)
2295  {
2296  gate = gate * topk_weights.AsType<float>()[m4];
2297  up = up * topk_weights.AsType<float>()[m4];
2298  }
2300  c_thread_buf_fp32(cidx) = gate * up;
2301  }
2302  }
2303  else
2304  {
2305  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2306  if constexpr(MulRoutedWeight)
2307  {
2308  c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
2309  c_thread_buf_fp32[cidx];
2310  }
2311  }
2312  });
2313  });
2314  });
2315  });
2316  }
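// Epilogue of both branches above: for the fused input GEMM the gate and up
// accumulators are combined according to ActivationOperation (silu_and_mul or
// gelu_and_mul, i.e. roughly out = Act(gate) * up), optionally multiplied by the
// routed top-k weights (MulRoutedWeight) and by the dequantization scales when
// provided; for the second GEMM the accumulator is only scaled and weighted.
// The result is staged in c_thread_buf_fp32 for the shuffle + scatter store below.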
2317 
2318  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2319  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2320 
2321  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2322  static_cast<CShuffleDataType*>(p_shared),
2323  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2324 
2325  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2326  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2327  make_tuple(
2330  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2331  M1, // M1 = MWave
2332  M2, // M2 * M3 * M4 = MPerXdl
2333  M3,
2334  M4)),
2337  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2338  N1, // N1 = NWave
2339  N2))), // N2 = NPerXdl
2341  make_tuple(
2343 
2344  // calculate origin of thread output tensor on global memory
2345  // blockwise GEMM c matrix starting index
2346  const auto c_thread_mtx_on_block =
2347  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2348 
2349  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2350  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2351 
2352  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2354  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2356  make_tuple(Sequence<0>{}));
2357 
2358  const auto m_thread_data_on_block_idx =
2359  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2360  make_multi_index(m_thread_data_on_block));
2361 
2362  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2366  make_tuple(Sequence<0>{}));
2367 
2368  const auto n_thread_data_on_block_idx =
2369  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2370  make_multi_index(n_thread_data_on_block));
2371 
2372  // shuffle: threadwise copy C from VGPR to LDS
2373  auto c_thread_copy_vgpr_to_lds =
2375  CShuffleDataType,
2376  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2377  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2379  Sequence<CShuffleMXdlPerWavePerShuffle,
2380  CShuffleNXdlPerWavePerShuffle,
2381  I1,
2382  I1,
2383  M2,
2384  I1,
2385  M4,
2386  I1>,
2388  7,
2389  1,
2391  1,
2392  true>{
2393  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2394  make_multi_index(0,
2395  0,
2396  m_thread_data_on_block_idx[I1],
2397  n_thread_data_on_block_idx[I1],
2398  m_thread_data_on_block_idx[I2],
2399  m_thread_data_on_block_idx[I3],
2400  m_thread_data_on_block_idx[I4],
2401  n_thread_data_on_block_idx[I2]),
2403 
2404  using EDataType = CDataType;
2405 
2406  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2407  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2408 
2409  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2410  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2411  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2412 
2413  const auto ds_grid_buf = generate_tuple(
2414  [&](auto i) {
2415  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2416  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2417  },
2418  Number<NumDTensor>{});
2419 
2420  // tuple of reference to C/Ds tensor descriptors
2421  const auto c_ds_desc_refs = concat_tuple_of_reference(
2422  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2423  generate_tie([&](auto i) -> const auto& // return type should be reference
2424  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2425  Number<NumDTensor>{}));
2426 
2427  // tuple of reference to C/Ds tensor descriptors
2428  const auto c_ds_buf_refs = concat_tuple_of_reference(
2429  tie(c_shuffle_block_buf),
2430  generate_tie([&](auto i) -> const auto& // return type should be reference
2431  { return ds_grid_buf[i]; },
2432  Number<NumDTensor>{}));
2433 
2434  // tuple of starting index of C/Ds blockwise copy
2435  const auto idx_c_ds_block_begin =
2438  [&](auto) {
2439  return make_multi_index(block_m_id, 0, block_n_id, 0);
2440  // return make_multi_index(block_work_idx[I0], 0,
2441  // block_work_idx[I1], 0);
2442  },
2443  Number<NumDTensor>{}));
2444 
2445  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2446  c_grid_desc_mblock_mperblock_nblock_nperblock;
2447 
2448  using CDEBlockTransferCluster =
2449  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2450  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2451  constexpr index_t scatter_weight_idx = 3; // FIXME(felix): hard-coded scatter-weight index
2452  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2454  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2456  decltype(c_ds_desc_refs),
2457  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2458  CElementwiseOperation,
2459  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2460  // support arbitrary type
2461  Sequence<1,
2462  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2463  1,
2464  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2465  CDEBlockTransferCluster,
2466  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2467  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2468  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2469  3, // index_t SrcVectorDim,
2470  3, // index_t DstVectorDim,
2471  CDEShuffleBlockTransferScalarPerVectors,
2476  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2477  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2478  IndexType,
2479  1, // ScatterDim
2480  true, // OutputScatter: scatter output rows via scatter_offsets
2481  scatter_weight_idx // ScatterWeightIdx: ascale
2482  >{c_ds_desc_refs,
2483  idx_c_ds_block_begin,
2484  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2485  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2486  c_element_op};
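// The transfer above reads the shuffled C tile from LDS together with any D
// tensors, applies c_element_op, and writes the result to the global E tensor with
// a scatter along dimension 1 (ScatterDim = 1): each row of the tile uses
// scatter_offsets derived from the sorted-token table instead of a dense row
// offset, so results land directly in the (token, top-k) rows of the output.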
2487 
2488  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2489  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2490  constexpr auto sfc_c_vgpr =
2493  Sequence<CShuffleMXdlPerWavePerShuffle,
2494  CShuffleNXdlPerWavePerShuffle,
2495  1,
2496  1,
2497  M2,
2498  1,
2499  M4,
2500  1>>{};
2501 
2502  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2503 
2504  // space filling curve for shuffled blockwise C/D/E
2505  constexpr auto sfc_cde_block =
2508  Sequence<1,
2509  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2510  1,
2511  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2512 
2513  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
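// Both space-filling curves walk the same number of passes, roughly
//   num_access = (MXdlPerWave / CShuffleMXdlPerWavePerShuffle)
//              * (NXdlPerWave / CShuffleNXdlPerWavePerShuffle).
// Assumed example: MXdlPerWave = NXdlPerWave = 4 with one XDL of each per shuffle
// -> 16 VGPR -> LDS -> global passes through the loop below.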
2514  constexpr auto EMThreads =
2515  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2516  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2517  constexpr auto ENThreads =
2518  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2519  static_for<0, num_access, 1>{}([&](auto access_id) {
2520  // make sure it's safe to write to LDS
2522 
2523  auto dstidx = sfc_cde_block.GetIndex(access_id);
2524  const index_t c_token_pos =
2525  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2526  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2527  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2528  IndexType token_offset = fused_token & 0xffffff;
2529  if constexpr(IsInputGemm)
2530  {
2531  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2532  }
2533  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2534  });
2535 
2536  block_sync_lds();
2537 
2538  // each thread write its data from VGPR to LDS
2539  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2540  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2541  c_thread_buf_fp32,
2542  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2543  c_shuffle_block_buf);
2544 
2545  // make sure it's safe to read from LDS
2546  block_sync_lds();
2547 
2548  // each block copy its data from LDS to global
2549  cde_block_copy_lds_and_global.Run(
2550  c_ds_desc_refs,
2551  c_ds_buf_refs,
2552  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2553  tie(c_grid_buf),
2554  scatter_offsets);
2555 
2556  if constexpr(access_id < num_access - 1)
2557  {
2558  constexpr auto cde_lds_and_global_step =
2559  sfc_cde_block.GetForwardStep(access_id);
2560 
2561  // move on Ds
2562  static_for<0, NumDTensor, 1>{}([&](auto i) {
2563  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2564  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2565  });
2566 
2567  // move on E
2568  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2569  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2570  I0,
2571  cde_lds_and_global_step);
2572  }
2573  });
2574  }
2575  }
2576 };
2577 
2578 } // namespace ck