Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
18 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
28 // operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
30 // reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
37 
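// Illustrative note (not part of the original header): in the fused gate/up path below (see the
// IsInputGemm branch of Run()), the activation is applied to the "gate" partial result and the
// product with the "up" partial result is written out, i.e. roughly
//
//   out = Act(gate) * up;   // Act is SiLU for silu_and_mul, GELU for gelu_and_mul
//
// with the routed top-k weight additionally multiplied in when MulRoutedWeight is set.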
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
74 
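// A hedged host-side launch sketch (illustration only; the device-op layer that instantiates
// GridwiseGemm normally does this). Grid x/y come from GridwiseGemm::CalculateGridSize, and it is
// assumed here that the z dimension carries the split-K batch index consumed by SplitKBatchOffset
// above, that a BlockSize value and a HIP stream are available at the call site:
//
//   const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); // gdz is 1
//   kernel_moe_mxgemm<GridwiseGemm, /*HasMainKBlockLoop*/ true, InMemoryDataOperationEnum::Set>
//       <<<dim3(gdx, gdy, karg.KBatch), dim3(BlockSize), 0, stream>>>(karg);
//
// For KBatch > 1 the C memory operation would typically be AtomicAdd instead of Set.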
75 #if 0
76 template <typename GridwiseGemm,
77  bool HasMainKBlockLoop,
78  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79  index_t MinimumOccupancy = 1,
80  TailNumber TailNum = TailNumber::Even>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
83 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84 #endif
85  // __attribute__((amdgpu_waves_per_eu(1, 1)))
86  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
87 {
88 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
89  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91 
92  // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
93 
94  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95  karg.p_sorted_token_ids,
96  karg.p_sorted_expert_ids,
97  karg.p_max_token_id,
98  karg.p_a_grid,
99  karg.p_a_scale_grid,
100  karg.p_b_grid,
101  karg.p_b_scale_grid,
102  karg.p_ds_grid,
103  karg.p_c_grid,
104  p_shared,
105  p_shared1,
106  karg,
107  karg.a_element_op,
108  karg.b_element_op,
109  karg.c_element_op);
110 #else
111  ignore = karg;
112 #endif // end of if (defined(__gfx9__))
113 }
114 #endif
115 
116 template <typename ALayout,
117  typename BLayout,
118  typename DsLayout,
119  typename CLayout,
120  typename ADataType,
121  typename AScaleDataType,
122  typename BDataType,
123  typename BScaleDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t ScaleBlockSize,
133  index_t BlockSize,
134  index_t MPerBlock,
135  index_t NPerBlock,
136  index_t KPerBlock,
137  index_t AK1Value,
138  index_t BK1Value,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144  typename ABlockTransferThreadClusterArrangeOrder,
145  typename ABlockTransferSrcAccessOrder,
146  index_t ABlockTransferSrcVectorDim,
147  index_t ABlockTransferSrcScalarPerVector,
148  index_t ABlockTransferDstScalarPerVector_AK1,
149  bool AThreadTransferSrcResetCoordinateAfterRun,
150  index_t ABlockLdsExtraM,
151  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152  typename BBlockTransferThreadClusterArrangeOrder,
153  typename BBlockTransferSrcAccessOrder,
154  index_t BBlockTransferSrcVectorDim,
155  index_t BBlockTransferSrcScalarPerVector,
156  index_t BBlockTransferDstScalarPerVector_BK1,
157  bool BThreadTransferSrcResetCoordinateAfterRun,
158  index_t BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  typename CDEShuffleBlockTransferScalarPerVectors,
165  index_t ActivationOperation = 0,
166  bool NSwizzle = false,
167  bool IsInputGemm = true,
168  bool MulRoutedWeight = true,
169  typename IndexType = index_t,
170  typename ComputeTypeA = ADataType,
171  typename ComputeTypeB = BDataType>
173 {
174  using LDSTypeA = ADataType;
175  using LDSTypeB = BDataType;
176 
177  static constexpr auto I0 = Number<0>{};
178  static constexpr auto I1 = Number<1>{};
179  static constexpr auto I2 = Number<2>{};
180  static constexpr auto I3 = Number<3>{};
181  static constexpr auto I4 = Number<4>{};
182  static constexpr auto I5 = Number<5>{};
183  static constexpr auto I6 = Number<6>{};
184  static constexpr auto I7 = Number<7>{};
185  static constexpr auto I8 = Number<8>{};
186  static constexpr auto I9 = Number<9>{};
187 
189  CDEShuffleBlockTransferScalarPerVectors{}[I0];
190  // K1 should be Number<...>
191  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
192  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
193  static constexpr auto AK1Number = Number<AK1Value>{};
194  static constexpr auto BK1Number = Number<BK1Value>{};
195 
196  static constexpr index_t NumDTensor = DsDataType::Size();
197 
198  static constexpr auto MXdlPack = 2;
199  static constexpr auto NXdlPack = 2;
200  static constexpr auto KXdlPack = 2;
201 
202  static constexpr index_t APackedSize = packed_size_v<ADataType>;
203  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
204 
205  static constexpr bool is_single_rate_mfma = false;
206  static constexpr auto is_scale_mfma = true;
207  using mfma_selector = MfmaSelector<ComputeTypeA,
208  MPerXdl,
209  NPerXdl,
210  ComputeTypeB,
212  is_scale_mfma>;
213  static constexpr index_t KPack = math::max(
215 
216  // static constexpr index_t NumTokens = 1;
217  static constexpr index_t SortedTileSize = MPerBlock;
218 
219  static constexpr auto MakeDsGridPointer()
220  {
221  return generate_tuple(
222  [&](auto i) {
223  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
224 
225  return static_cast<const DDataType*>(nullptr);
226  },
228  }
229 
230  using DsGridPointer = decltype(MakeDsGridPointer());
231 
233 
234  __host__ static auto CalculateGridSize(index_t M, index_t N)
235  {
236  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
237  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
238  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
239  const index_t gridy = NSwizzle ? 1 : mblock;
240 
241  return std::make_tuple(gridx, gridy, 1);
242  }
243 
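// Worked example (illustrative tile sizes, not mandated by this header): with MPerBlock = 128,
// NPerBlock = 128, M = 512, N = 256:
//   mblock = ceil(512 / 128) = 4, nblock = ceil(256 / 128) = 2
//   NSwizzle == false -> grid = (nblock, mblock, 1) = (2, 4, 1)
//   NSwizzle == true  -> grid = (nblock * mblock, 1, 1) = (8, 1, 1)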
244  __host__ static auto CalculateMPadded(index_t M)
245  {
246  return math::integer_least_multiple(M, MPerBlock);
247  }
248 
249  __host__ static auto CalculateNPadded(index_t N)
250  {
251  return math::integer_least_multiple(N, NPerBlock);
252  }
253 
254  __host__ static auto CalculateKPadded(index_t K)
255  {
256  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
257  }
258 
259  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
260  {
261  auto K_t = K_Batch * KPerBlock;
262  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
263  }
264 
265  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
266  {
267  auto K_t = K_Batch * KPerBlock;
268  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
269  }
270 
271  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
272  {
273  auto K_t = K_Batch * KPerBlock;
274  return (K + K_t - 1) / K_t * KPerBlock;
275  }
276 
277  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
278  {
279  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
280  auto K_t = K_Batch * KReadVec;
281  return (K + K_t - 1) / K_t * KReadVec;
282  }
283 
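// Worked split-K example (illustrative values): KPerBlock = 256, AK1 = BK1 = 16, K = 960,
// K_Batch = 2:
//   CalculateKPadded(960)      = ceil(960 / 256) * 256        = 1024
//   CalculateKPadded(960, 2)   = ceil(960 / 512) * 256        = 512  (per split-K batch)
//   CalculateKRead(960, 2)     = ceil(960 / 32)  * 16         = 480  (K elements read per batch)
//   CalculateAK0Padded(960, 2) = ceil(960 / 512) * (256 / 16) = 32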
284  __host__ static auto CalculateMBlock(index_t M)
285  {
286  return math::integer_divide_ceil(M, MPerBlock);
287  }
288 
289  __host__ static auto CalculateNBlock(index_t N)
290  {
291  return math::integer_divide_ceil(N, NPerBlock);
292  }
293 
294  template <index_t MNXdlPerWave,
295  index_t MNWaves,
296  index_t MNXdlPack,
297  index_t MNPerXdl,
298  typename TileDesc_K0_MN_K1>
299  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
300  {
301  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
302  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
303 
305  TileDesc_K0_MN_K1{},
308  Number<MNWaves>{},
310  Number<MNPerXdl>{}))),
313  }
314 
315  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
316  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
317  {
318  const auto a_grid_desc_mraw_kraw = [&]() {
319  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
320  {
321  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
322  }
323  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
324  {
325  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
326  }
327  }();
328 
330 
331  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
332  GemmSpec == GemmSpecialization::MNKPadding)
333  {
334  // pad both M and K
335  const auto a_grid_desc_m_k =
336  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
338  make_right_pad_transform(K, KPad - K)),
341 
342  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
343  a_grid_desc_m_k,
348 
349  return a_grid_desc_ak0_m_ak1;
350  }
351  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
352  GemmSpec == GemmSpecialization::MNPadding)
353  {
354  // pad M, but not K
355  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
356  a_grid_desc_mraw_kraw,
358  make_right_pad_transform(M, MPad - M)),
361 
362  return a_grid_desc_ak0_m_ak1;
363  }
364  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
365  GemmSpec == GemmSpecialization::NKPadding)
366  {
367  // pad K, but not M
368  const auto a_grid_desc_m_k = transform_tensor_descriptor(
369  a_grid_desc_mraw_kraw,
373 
374  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
375  a_grid_desc_m_k,
380 
381  return a_grid_desc_ak0_m_ak1;
382  }
383  else
384  {
385  // not pad M or K
386  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
387  a_grid_desc_mraw_kraw,
392 
393  return a_grid_desc_ak0_m_ak1;
394  }
395  }
396 
397  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
398  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
399  {
400  const auto b_grid_desc_nraw_kraw = [&]() {
402  {
403  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
404  }
406  {
407  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
408  }
409  }();
410 
412 
413  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
414  GemmSpec != GemmSpecialization::Default),
415  "pk_i4_t does not support padding");
416  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
417  GemmSpec != GemmSpecialization::Default),
418  "f4x2_pk_t does not support padding");
419 
420  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
421  GemmSpec == GemmSpecialization::MNKPadding)
422  {
423  // pad both N and K
424  const auto b_grid_desc_n_k =
425  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
427  make_right_pad_transform(K, KPad - K)),
430 
431  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
432  b_grid_desc_n_k,
437 
438  return b_grid_desc_bk0_n_bk1;
439  }
440  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
441  GemmSpec == GemmSpecialization::MNPadding)
442  {
443  // pad N, but not K
444  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
445  b_grid_desc_nraw_kraw,
447  make_right_pad_transform(N, NPad - N)),
450 
451  return b_grid_desc_bk0_n_bk1;
452  }
453  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
454  GemmSpec == GemmSpecialization::MKPadding)
455  {
456  // pad K, but not N
457  const auto b_grid_desc_n_k = transform_tensor_descriptor(
458  b_grid_desc_nraw_kraw,
462 
463  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
464  b_grid_desc_n_k,
469 
470  return b_grid_desc_bk0_n_bk1;
471  }
472  else
473  {
474  // not pad N or K
475  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
476  b_grid_desc_nraw_kraw,
481 
482  return b_grid_desc_bk0_n_bk1;
483  }
484  }
485 
486  template <typename ABlockDesc_AK0_M_AK1>
487  __host__ __device__ static constexpr auto
488  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
489  {
490  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
491 
492  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
493  ABlockDesc_AK0_M_AK1{});
494  }
495 
496  template <typename BBlockDesc_BK0_N_BK1>
497  __host__ __device__ static constexpr auto
498  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
499  {
500  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
501 
502  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
503  BBlockDesc_BK0_N_BK1{});
504  }
505 
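// Illustrative wave-count arithmetic (the concrete values are assumptions, not fixed by this
// header): with MPerBlock = 128, MXdlPerWave = 4 and MPerXdl = 16, the A side uses
//   MWaves = MPerBlock / (MXdlPerWave * MPerXdl) = 128 / 64 = 2
// waves along M; NWaves is computed the same way from NPerBlock, NXdlPerWave and NPerXdl.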
506  template <typename ELayout>
507  __host__ __device__ static auto MakeCGridDescriptor_M_N(
508  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
509  {
510  const auto c_grid_desc_mraw_nraw = [&]() {
512  {
513  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
514  }
516  {
517  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
518  }
519  }();
520 
521  // pad M and N
522  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
524  make_right_pad_transform(N, NPad - N)),
527  }
528 
529  template <typename DLayout>
530  __host__ __device__ static auto
532  {
533  const auto c_grid_desc_mraw_nraw = [&]() {
535  {
536  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
537  }
539  {
540  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
541  }
542  }();
543 
544  // pad M and N
545  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
547  make_right_pad_transform(N, NPad - N)),
550  }
551 
552  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
553  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
554  {
555  return generate_tuple(
556  [&](auto i) {
557  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
558  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
559  },
561  }
562 
563  template <typename DsGridDesc>
565  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
566  {
567  return generate_tuple(
568  [&](auto i) {
570  ds_grid_desc_m_n[i], MBlock, NBlock);
571  },
573  }
574 
575  struct Problem
576  {
577  __host__ Problem(index_t NumTokens_,
578  index_t TopK_,
579  index_t M_,
580  index_t N_,
581  index_t K_,
582  index_t StrideA_,
583  index_t StrideScaleA_,
584  index_t StrideB_,
585  index_t StrideScaleB_,
586  std::array<index_t, NumDTensor> StrideDs_,
587  index_t StrideC_,
588  index_t KBatch_)
589  : NumTokens{NumTokens_},
590  TopK{TopK_},
591  M{M_},
592  N{N_},
593  K{K_},
594  StrideA{StrideA_},
595  StrideScaleA{StrideScaleA_},
596  StrideB{StrideB_},
597  StrideScaleB{StrideScaleB_},
598  StrideDs{StrideDs_},
599  StrideC{StrideC_},
600  KBatch{KBatch_},
603  KRead{CalculateKRead(K_, KBatch_)},
604  KPadded{CalculateKPadded(K_, KBatch_)},
605  AK0{CalculateAK0Padded(K_, KBatch_)},
606  BK0{CalculateBK0Padded(K_, KBatch_)},
607  MBlock{CalculateMBlock(M_)},
609  {
610  }
611 
612  __host__ void Print() const
613  {
614  std::cout << "problem {"
615  << "NumTokens:" << NumTokens << ", "
616  << "TopK:" << TopK << ", "
617  << "M:" << M << ", "
618  << "N:" << N << ", "
619  << "K:" << K << ", "
620  << "SA:" << StrideA << ", "
621  << "SScaleA:" << StrideScaleA << ", "
622  << "SB:" << StrideB << ", "
623  << "SScaleB:" << StrideScaleB << ", "
624  << "SC:" << StrideC << ", "
625  << "MP:" << MPadded << ", "
626  << "NP:" << NPadded << ", "
627  << "KRead:" << KRead << ", "
628  << "KP:" << KPadded << ", "
629  << "AK0:" << AK0 << ", "
630  << "BK0:" << BK0 << ", "
631  << "MBlock: " << MBlock << ", "
632  << "NBlock: " << NBlock << "}" << std::endl;
633  }
634 
644  std::array<index_t, NumDTensor> StrideDs;
655  };
656 

657  // Argument
659  {
660  __host__ Argument(const index_t* p_sorted_token_ids_,
661  const index_t* p_sorted_expert_ids_,
662  const index_t* p_max_token_id_,
663  const ADataType* p_a_grid_,
664  const AScaleDataType* p_a_scale_grid_,
665  const BDataType* p_b_grid_,
666  const BScaleDataType* p_b_scale_grid_,
667  std::array<const void*, NumDTensor> p_ds_grid_,
668  CDataType* p_c_grid_,
669  index_t NumTokens_,
670  index_t TopK_,
671  index_t M_,
672  index_t N_,
673  index_t K_,
674  index_t StrideA_,
675  index_t StrideScaleA_,
676  index_t StrideB_,
677  index_t StrideScaleB_,
678  std::array<index_t, NumDTensor> StrideDs_,
679  index_t StrideC_,
680  index_t k_batch_,
681  AElementwiseOperation a_element_op_,
682  BElementwiseOperation b_element_op_,
683  CElementwiseOperation c_element_op_)
684  : Problem{NumTokens_,
685  TopK_,
686  M_,
687  N_,
688  K_ / APackedSize,
689  StrideA_ / APackedSize,
690  StrideScaleA_,
691  StrideB_ / BPackedSize,
692  StrideScaleB_,
693  StrideDs_,
694  StrideC_,
695  k_batch_},
696  p_sorted_token_ids{p_sorted_token_ids_},
697  p_sorted_expert_ids{p_sorted_expert_ids_},
698  p_max_token_id{p_max_token_id_},
699  p_a_grid{p_a_grid_},
700  p_a_scale_grid{p_a_scale_grid_},
701  p_b_grid{p_b_grid_},
702  p_b_scale_grid{p_b_scale_grid_},
703  p_ds_grid{},
704  p_c_grid{p_c_grid_},
705  a_element_op{a_element_op_},
706  b_element_op{b_element_op_},
707  c_element_op{c_element_op_}
708  {
709 
710 // populate pointers and descriptors for Ds
711  static_for<0, NumDTensor, 1>{}([&](auto i) {
712  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
713 
714  // D pointer
715  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
716  });
717  }
718 
722  const ADataType* p_a_grid;
723  const AScaleDataType* p_a_scale_grid;
724  const BDataType* p_b_grid;
725  const BScaleDataType* p_b_scale_grid;
727  CDataType* p_c_grid;
728 
729  const AElementwiseOperation a_element_op;
730  const BElementwiseOperation b_element_op;
731  const CElementwiseOperation c_element_op;
732  };
733 
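// Note on packed element types (illustrative): the Argument constructor divides K and StrideA by
// APackedSize (packed_size_v<ADataType>) and StrideB by BPackedSize before storing them in
// Problem. For a two-elements-per-byte packed type such as f4x2_pk_t (packed size assumed to be
// 2), a logical K of 4096 is stored as K = 2048 packed elements and the strides shrink
// accordingly, while StrideScaleA/StrideScaleB are passed through unchanged.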
735  {
736  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
737  {
738  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
739  {
740  a_k_split_offset = k_id * karg.KRead;
741  }
742  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
743  {
744  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
745  }
746 
747  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
748  {
749  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
750  }
751  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
752  {
753  // KPack * NLane * KLane * K0 * N0
754  b_k_split_offset = k_id * karg.KRead;
755  }
756 
757  // Calculate A scale offset
758  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
759  {
760  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
761  }
762  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
763  {
765  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
766  }
767 
768  // Calculate B scale offset
769  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
770  {
772  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
773  }
774  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
775  {
776  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
777  }
778 
779  if(k_id < karg.KBatch - 1)
780  {
781  karg.K = karg.KRead;
782  }
783  else
784  {
785  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
786  }
787  }
788 
793  };
794 
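// Worked split-K offset example (continuing the illustrative numbers above: KRead = 480,
// KBatch = 2, with ScaleBlockSize = 32 and APackedSize = BPackedSize = 2 assumed):
//   row-major A,    k_id = 1:  a_k_split_offset       = 1 * 480            = 480
//   column-major B, k_id = 1:  b_k_split_offset       = 1 * 480            = 480
//   row-major A,    k_id = 1:  a_scale_k_split_offset = 1 * 480 / (32 / 2) = 30
// and the last batch recomputes karg.K = K - KRead * (KBatch - 1) = 960 - 480 = 480.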
795  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
796  {
797  // A matrix in LDS memory, dst of blockwise copy
798  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
799  {
803  }
804 // The XOR tensor transformation requires extra, unnecessary VGPR usage, which can cause
805 // register spills in some cases.
807  {
808  constexpr auto a_lds_block_desc =
811 
812  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
813  a_lds_block_desc,
819 
820  return a_lds_block_desc_permuted;
821  }
822  else // ColumnMajor A
823  {
824 // The kfold and mpair dimensions are not always required.
825 // More dimensions in merge_transform increase the difficulty of generating immediate-argument
826 // offsets for the compiler.
827  constexpr auto WaveSize = 64;
828  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
829  constexpr auto M1 = MPerBlock / M0;
830 
831  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
832  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
833  constexpr auto KThreadRead = WaveSize / MPerXdl;
834  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
835 
836  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
837  ? 1
838  : 128 / (AK1Number * M0 * sizeof(ADataType));
839  constexpr auto KThreadReadPerm =
840  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
841  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
842  : KThreadRead;
843 
844 // 1 <= mpair <= M0
845  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
846  ? 1
847  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
848  ? M0
849  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
850 
851  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
855  Number<kfold * M0 / mpair>{},
856  Number<mpair>{},
857  AK1Number));
858 
859  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
860  a_lds_block_desc,
861  make_tuple(
865  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
868  make_tuple(
870  make_tuple(
872 
873  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
874  a_lds_block_desc_permuted,
875  make_tuple(
883  Sequence<1>{},
884  Sequence<2>{},
885  Sequence<3>{},
886  Sequence<4>{},
887  Sequence<5>{}),
889  Sequence<2>{},
890  Sequence<0, 3>{},
891  Sequence<4, 5>{},
892  Sequence<6>{},
893  Sequence<7>{}));
894 
895  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
896  a_lds_block_desc_unmerged,
899  Number<KThreadWrite / kfold / KThreadReadPerm>{},
900  Number<kfold>{},
907 
908  return a_lds_block_desc_ak0_m_ak1;
909  }
910  }
911 
912  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
913  {
914  // B matrix in LDS memory, dst of blockwise copy
915  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
916  {
920  }
922  {
923  // NLdsLayer * K0 as logical Bank
924  constexpr auto b_lds_block_desc =
927 
928  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
929  b_lds_block_desc,
935 
936  return b_lds_block_desc_permuted;
937  }
938  else // RowMajor B
939  {
940  constexpr auto WaveSize = 64;
941  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
942  constexpr auto N1 = NPerBlock / N0;
943 
944  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
945  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
946  constexpr auto KThreadRead = WaveSize / NPerXdl;
947  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
948 
949  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
950  ? 1
951  : 128 / (BK1Number * N0 * sizeof(BDataType));
952  constexpr auto KThreadReadPerm =
953  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
954  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
955  : KThreadRead;
956 
957 // 1 <= npair <= N0
958  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
959  ? 1
960  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
961  ? N0
962  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
963 
964  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
968  Number<kfold * N0 / npair>{},
969  Number<npair>{},
970  BK1Number));
971 
972  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
973  b_lds_block_desc,
974  make_tuple(
978  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
981  make_tuple(
983  make_tuple(
985 
986  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
987  b_lds_block_desc_permuted,
988  make_tuple(
996  Sequence<1>{},
997  Sequence<2>{},
998  Sequence<3>{},
999  Sequence<4>{},
1000  Sequence<5>{}),
1002  Sequence<2>{},
1003  Sequence<0, 3>{},
1004  Sequence<4, 5>{},
1005  Sequence<6>{},
1006  Sequence<7>{}));
1007 
1008  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1009  b_lds_block_desc_unmerged,
1012  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1013  Number<kfold>{},
1020 
1021  return b_lds_block_desc_bk0_n_bk1;
1022  }
1023  }
1024 
1026  {
1027  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1028  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1029 
1030  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1032  make_tuple(I1,
1034  I1,
1036 
1037  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1038  }
1039 
1042  BlkGemmPipelineVer,
1043  BlkGemmPipeSched,
1044  BlockSize,
1045  ScaleBlockSize,
1046  ADataType,
1047  AScaleDataType,
1048  BDataType,
1049  BScaleDataType,
1050  ComputeTypeA,
1051  AccDataType,
1058  ABlockTransferSrcScalarPerVector,
1059  BBlockTransferSrcScalarPerVector,
1060  MPerBlock,
1061  NPerBlock,
1062  KPerBlock,
1063  MPerXdl,
1064  NPerXdl,
1065  MXdlPerWave,
1066  NXdlPerWave,
1067  KPack,
1068  IsInputGemm>())>;
1069 
1070  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1071  {
1072  // LDS allocation for A and B: be careful of alignment
1073  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1074  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1075 
1076  // lds max alignment
1077  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1078 
1079  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1080  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1081 
1082  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1083  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1084 
1085  // LDS allocation for C shuffle in LDS
1086  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1088 
1089  constexpr auto c_block_size =
1090  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1091 
1092  if constexpr(IsInputGemm)
1093  {
1094  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1095  b_block_space_size_aligned * sizeof(BDataType)) *
1096  2,
1097  c_block_size * sizeof(CShuffleDataType));
1098  }
1099  else
1100  {
1101  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1102  b_block_space_size_aligned * sizeof(BDataType)),
1103  c_block_size * sizeof(CShuffleDataType));
1104  }
1105  }
1106 
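// Illustrative reading of the LDS budget above (no concrete sizes are implied by this header):
// with the element sizes folded into A_bytes / B_bytes / Cshuffle_bytes, the result is
//   IsInputGemm == true : max((A_bytes + B_bytes) * 2, Cshuffle_bytes)
//   IsInputGemm == false: max( A_bytes + B_bytes,      Cshuffle_bytes)
// The doubled A+B budget in the IsInputGemm case leaves room for the second B ("up") tile that
// Run() places after the gate tile in LDS for the fused gate/up path.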
1107 // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1108  __host__ static constexpr bool CheckValidity(const Argument& karg)
1109  {
1110  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1111  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1112  "Invalid tuning param!");
1113 
1114  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1115  "KPerBlock should be multiple of ScaleBlockSize");
1116 
1117  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1122  {
1123  if(!(karg.M % MPerBlock == 0))
1124  {
1125  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1126  {
1127  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1128  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1129  << std::endl;
1130  }
1131  return false;
1132  }
1133  }
1134 
1135  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1140  {
1141  if(!(karg.N % NPerBlock == 0))
1142  {
1143  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1144  {
1145  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1146  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1147  << std::endl;
1148  }
1149  return false;
1150  }
1151  }
1152 
1153  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1157  {
1158  auto K_t = karg.KBatch * KPerBlock;
1159  if(!(karg.K % K_t == 0))
1160  {
1161  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1162  {
1163  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1164  << karg.K << " " << __FILE__ << ":" << __LINE__
1165  << ", in function: " << __func__ << std::endl;
1166  }
1167  return false;
1168  }
1169  }
1170  else
1171  {
1172  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1173  auto K_t = karg.KBatch * KReadVec;
1174  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1175  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1176  {
1177  return false;
1178  }
1179  }
1180 
1182  {
1183  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1184  {
1185  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1186  {
1187  std::cout << "Arg K (" << karg.K
1188  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1189  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1190  << __LINE__ << ", in function: " << __func__ << std::endl;
1191  }
1192  return false;
1193  }
1194  }
1195  else
1196  {
1197  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1198  {
1199  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1200  {
1201  std::cout << "Arg M (" << karg.M
1202  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1203  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1204  << __LINE__ << ", in function: " << __func__ << std::endl;
1205  }
1206  return false;
1207  }
1208  }
1209 
1211  {
1212  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1213  {
1214  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1215  {
1216  std::cout << "Arg N (" << karg.N
1217  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1218  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1219  << __LINE__ << ", in function: " << __func__ << std::endl;
1220  }
1221  return false;
1222  }
1223  }
1224  else
1225  {
1226  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1227  {
1228  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1229  {
1230  std::cout << "Arg K (" << karg.K
1231  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1232  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1233  << __LINE__ << ", in function: " << __func__ << std::endl;
1234  }
1235  return false;
1236  }
1237  }
1238 
1240  {
1242  {
1243  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1244  {
1245  std::cout << "Arg N (" << karg.N
1246  << ") value is not a multiple of "
1247  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1249  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1250  << std::endl;
1251  }
1252  return false;
1253  }
1254  }
1255  else
1256  {
1258  {
1259  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1260  {
1261  std::cout << "Arg M (" << karg.M
1262  << ") value is not a multiple of "
1263  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1265  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1266  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1267  << std::endl;
1268  }
1269  return false;
1270  }
1271  }
1272 
1273  // check gridwise gemm pipeline
1274 #if 0
1275  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1276 
1277  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1278  {
1279  return false;
1280  }
1281 #endif
1282  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1283  return true;
1284  }
1285 
1286  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1287  {
1288  const index_t num_loop = K / KPerBlock;
1289 
1290  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1291  }
1292 
1293  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1294  {
1295  const index_t num_loop = K / KPerBlock;
1296 
1297  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1298  }
1299 
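// Hedged host-side sketch (illustration only; the surrounding device op normally performs this
// dispatch, and the exact K value passed in is an assumption here):
//   if(!GridwiseGemm::CheckValidity(karg)) { /* skip / fall back to another instance */ }
//   const bool       has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(karg.K);
//   const TailNumber tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(karg.K);
//   // has_main_loop selects HasMainKBlockLoop and tail_num selects TailNum of kernel_moe_mxgemm.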
1300  template <typename CGridDesc>
1301  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1302  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1303  {
1304  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1305  c_grid_desc_m_n,
1310 
1311  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1312  }
1313 
1314  // return block_id to C matrix tile idx (m0, n0) mapping
1315  // if arch = gfx942
1316  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1317  // NPerBlock>;
1318 
1320  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1321  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1322  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1323  "A scale pack data type too large!");
1324  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1325  "B scale pack data type too large!");
1326 
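// Illustrative scale-packing arithmetic (assuming mx_scale_t is a 1-byte e8m0 exponent): the
// packing granularity is KXdlPack * MXdlPack = 2 * 2 = 4 A scales, so an AScaleDataType of up to
// 4 bytes gives scale_pack_size_a in {1, 2, 4}, which divides 4 and satisfies the static_assert;
// an 8-byte pack would trigger "A scale pack data type too large!". The same bound applies to B
// scales via KXdlPack * NXdlPack and scale_pack_size_b.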
1327  template <bool HasMainKBlockLoop,
1328  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1329  TailNumber TailNum = TailNumber::Odd>
1330  __device__ static void Run(const index_t* p_sorted_token_ids,
1331  const index_t* p_sorted_expert_ids,
1332  const index_t* p_max_token_id,
1333  const ADataType* p_a_grid,
1334  const AScaleDataType* p_a_scale_grid,
1335  const BDataType* p_b_grid,
1336  const BScaleDataType* p_b_scale_grid,
1337  DsGridPointer& p_ds_grid,
1338  CDataType* p_c_grid,
1339  void* p_shared,
1340  const Problem& problem,
1341  AElementwiseOperation a_element_op,
1342  BElementwiseOperation b_element_op,
1343  CElementwiseOperation c_element_op)
1344  {
1345  ignore = b_element_op;
1346  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1347  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1348  problem.MPadded,
1349  problem.K,
1350  problem.KPadded,
1351  problem.StrideA,
1352  problem.AK0);
1353  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1354  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1355  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1356  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1357  problem.MPadded,
1358  problem.N,
1359  problem.NPadded,
1360  problem.StrideC);
1361 
1362  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1363  make_tuple(problem.M / (MXdlPack * MPerXdl),
1364  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1365  (KXdlPack * 64 / MPerXdl),
1366  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1367 
1368  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1369  make_tuple(problem.N / (NXdlPack * NPerXdl),
1370  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1371  (KXdlPack * 64 / NPerXdl),
1372  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1373 
1374  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1376  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1377 
1378  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1379  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1380  if(expert_block_id * MPerBlock >= max_token_id)
1381  return;
1382  const index_t expert_id =
1383  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1384 
1385  const auto block_mn = [&]() -> std::pair<int, int> {
1386  if constexpr(NSwizzle)
1387  {
1388  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1389  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1390  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1391  const index_t expert_swizzle =
1392  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1393  const index_t bid_new = blockIdx.x - prefix_block;
1394  const index_t nid = __builtin_amdgcn_readfirstlane(
1395  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1396  const index_t mid =
1397  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1398  return {nid, mid};
1399  }
1400  else
1401  {
1402  return {blockIdx.x, blockIdx.y};
1403  }
1404  }();
1405 
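// Worked N-swizzle example (illustrative values): suppose problem.NBlock = 16 and the current
// expert has ecnt_prefix = 4 sorted M-blocks before it and ecnt = 2 M-blocks of its own
// (expert_swizzle = 2). For blockIdx.x = prefix_block + 9 = 4 * 16 + 9, bid_new = 9 and
//   nid = 9 % 8 + 9 / (8 * 2) * 8 = 1,   mid = 4 + (9 / 8) % 2 = 5,
// i.e. consecutive blocks sweep groups of 8 N-blocks across this expert's M-blocks, which keeps
// B-tile reuse within one expert while still covering all (m, n) tiles.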
1406  const index_t block_n_id = block_mn.first;
1407  const index_t block_m_id = block_mn.second;
1408  const index_t token0 =
1409  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1410 
1411  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1412  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1413  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1414  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1415  constexpr auto AKThreads = AK0Threads * AK1Threads;
1416  constexpr auto AMRepeats = MPerBlock / AMThreads;
1417  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1418 
1419  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1420  return;
1422  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1423  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1424  index_t token_offset = fused_token & 0xffffff;
1425  if constexpr(!IsInputGemm)
1426  {
1427  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1428  }
1429  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1430  });
1431 
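// Illustrative decoding of a fused token id (values made up): fused_token = 0x02000015 encodes
// token 0x15 = 21 in the low 24 bits and top-k slot 2 in the high 8 bits. For IsInputGemm the
// gather row is simply token 21; otherwise (second GEMM) it is 21 * problem.TopK + 2, and the
// element offset stored in gather_offsets is that row times problem.K.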
1432  const index_t expert_stride =
1433  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1434  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1435  problem.N * (IsInputGemm ? 2 : 1) *
1436  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1437 
1438  // N0, K0, Blocksize*KPack
1439  const index_t n_block_data_idx_on_grid =
1440  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1441 
1442 // Grid buffer creation
1443  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1444  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1445  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1446  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1447 
1448  // A, B scale buffer
1449  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1450  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1451  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1452  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1453  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1454 
1455  // lds max alignment
1456  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1457 
1458  // A matrix in LDS memory, dst of blockwise copy
1459  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1460 
1461  // B matrix in LDS memory, dst of blockwise copy
1462  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1463 
1464  // A matrix blockwise copy
1465  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1467  AElementwiseOperation,
1471  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1472  ABlockTransferThreadClusterArrangeOrder,
1473  ADataType,
1474  ADataType,
1475  decltype(a_grid_desc_ak0_m_ak1),
1476  decltype(a_block_desc_ak0_m_ak1),
1477  ABlockTransferSrcAccessOrder,
1479  ABlockTransferSrcVectorDim,
1480  2,
1481  ABlockTransferSrcScalarPerVector,
1482  ABlockTransferDstScalarPerVector_AK1,
1483  1,
1484  1,
1485  AThreadTransferSrcResetCoordinateAfterRun,
1486  true,
1487  IndexType,
1488  1,
1489  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1490  make_multi_index(0, 0, 0),
1491  a_element_op,
1492  a_block_desc_ak0_m_ak1,
1493  make_multi_index(0, 0, 0),
1495  gather_offsets);
1496 
1497  // B matrix blockwise copy
1498  auto b_blockwise_copy =
1500  BElementwiseOperation,
1504  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1505  BBlockTransferThreadClusterArrangeOrder,
1506  BDataType,
1507  BDataType,
1508  decltype(b_grid_desc_bk0_n_bk1),
1509  decltype(b_block_desc_bk0_n_bk1),
1510  BBlockTransferSrcAccessOrder,
1512  BBlockTransferSrcVectorDim,
1513  2,
1514  BBlockTransferSrcScalarPerVector,
1515  BBlockTransferDstScalarPerVector_BK1,
1516  1,
1517  1,
1518  BThreadTransferSrcResetCoordinateAfterRun,
1519  true,
1520  BlockwiseGemmPipe::GlobalBufferNum>(
1521  b_grid_desc_bk0_n_bk1,
1522  make_multi_index(0, n_block_data_idx_on_grid, 0),
1523  b_element_op,
1524  b_block_desc_bk0_n_bk1,
1525  make_multi_index(0, 0, 0),
1527 
1528  // LDS allocation for A and B: be careful of alignment
1529  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1530  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1531 
1532  // Cast after lds
1533  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1534  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1535 
1536  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1537  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1538  a_block_space_size_aligned * sizeof(ADataType)),
1539  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1540 
1541  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1542  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1543 
1544  // Blockwise GEMM pipeline
1545  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1546  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1547  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1548  decltype(c_thread_buf) c_thread_buf_up;
1549 
1551  float,
1552  c_thread_buf.num_of_v_,
1553  c_thread_buf.s_per_v,
1554  true>
1555  c_thread_buf_fp32;
1556 
1557  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1558  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1559  KPerBlock);
1560 
1561  // a and b scale processing
1562  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1563  const auto waveId_m = wave_idx[I0];
1564  const auto waveId_n = wave_idx[I1];
1565 
1566  auto thread_offset_shuffled =
1567  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1568 
1569  auto a_thread_offset_m = waveId_m;
1570 
1571  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1572  AScaleDataType,
1573  AScaleDataType,
1574  decltype(a_scale_grid_desc_am_ak),
1575  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1576  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1577  Sequence<0, 1, 2>, // DimAccessOrder
1578  2, // SrcVectorDim
1579  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1580  1, // SrcScalarStrideInVector
1581  true>(a_scale_grid_desc_am_ak,
1582  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1583  0,
1584  thread_offset_shuffled / scale_pack_size_a));
1585 
1586  // B scale load
1587  auto b_thread_offset_n = waveId_n;
1588 
1589  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1590  BScaleDataType,
1591  BScaleDataType,
1592  decltype(b_scale_grid_desc_bn_ak),
1593  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1594  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1595  Sequence<0, 1, 2>, // DimAccessOrder
1596  2, // SrcVectorDim
1597  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1598  1, // SrcScalarStrideInVector
1599  true>(b_scale_grid_desc_bn_ak,
1600  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1601  0,
1602  thread_offset_shuffled / scale_pack_size_b));
1603 
1604  if constexpr(IsInputGemm)
1605  {
1606  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1607  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1608  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1609  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1610  a_block_space_size_aligned * sizeof(ADataType) +
1611  b_block_space_size_aligned * sizeof(BDataType)),
1612  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1613 
1614  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1615  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1616  p_b_grid_up + expert_id * expert_stride,
1617  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1618 
1619  auto b_blockwise_copy_up =
1621  BElementwiseOperation,
1625  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1626  BBlockTransferThreadClusterArrangeOrder,
1627  BDataType,
1628  BDataType,
1629  decltype(b_grid_desc_bk0_n_bk1),
1630  decltype(b_block_desc_bk0_n_bk1),
1631  BBlockTransferSrcAccessOrder,
1633  BBlockTransferSrcVectorDim,
1634  2,
1635  BBlockTransferSrcScalarPerVector,
1636  BBlockTransferDstScalarPerVector_BK1,
1637  1,
1638  1,
1639  BThreadTransferSrcResetCoordinateAfterRun,
1640  true,
1641  BlockwiseGemmPipe::GlobalBufferNum>(
1642  b_grid_desc_bk0_n_bk1,
1643  make_multi_index(0, n_block_data_idx_on_grid, 0),
1644  b_element_op,
1645  b_block_desc_bk0_n_bk1,
1646  make_multi_index(0, 0, 0),
1648 
1649  const BScaleDataType* p_b_scale_grid_up =
1650  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1651  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1652  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1653  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1654 
1655  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1656  BScaleDataType,
1657  BScaleDataType,
1658  decltype(b_scale_grid_desc_bn_ak),
1659  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1660  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1661  Sequence<0, 1, 2>, // DimAccessOrder
1662  2, // SrcVectorDim
1663  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1664  1, // SrcScalarStrideInVector
1665  true>(
1666  b_scale_grid_desc_bn_ak,
1667  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1668  0,
1669  thread_offset_shuffled / scale_pack_size_b));
1670 
1671  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1672  // A
1673  a_grid_desc_ak0_m_ak1,
1674  a_block_desc_ak0_m_ak1,
1675  a_blockwise_copy,
1676  a_grid_buf,
1677  a_block_buf,
1678  a_block_slice_copy_step,
1679  // Gate and Up
1680  b_grid_desc_bk0_n_bk1,
1681  b_block_desc_bk0_n_bk1,
1682  b_blockwise_copy,
1683  b_blockwise_copy_up,
1684  b_grid_buf,
1685  b_grid_buf_up,
1686  b_block_buf,
1687  b_block_buf_up,
1688  b_block_slice_copy_step,
1689  // C
1690  c_thread_buf,
1691  c_thread_buf_up,
1692  // A scale
1693  a_scale_grid_desc_am_ak,
1694  a_scale_thread_copy,
1695  a_scale_grid_buf,
1696  // Gate and Up scale
1697  b_scale_grid_desc_bn_ak,
1698  b_scale_thread_copy,
1699  b_scale_thread_copy_up,
1700  b_scale_grid_buf,
1701  b_scale_grid_buf_up,
1702  num_k_block_main_loop);
1703  }
1704  else
1705  {
1706  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1707  a_grid_desc_ak0_m_ak1, // A
1708  a_block_desc_ak0_m_ak1,
1709  a_blockwise_copy,
1710  a_grid_buf,
1711  a_block_buf,
1712  a_block_slice_copy_step,
1713  b_grid_desc_bk0_n_bk1, // B
1714  b_block_desc_bk0_n_bk1,
1715  b_blockwise_copy,
1716  b_grid_buf,
1717  b_block_buf,
1718  b_block_slice_copy_step,
1719  c_thread_buf, // C
1720  a_scale_grid_desc_am_ak, // A scale
1721  a_scale_thread_copy,
1722  a_scale_grid_buf,
1723  b_scale_grid_desc_bn_ak, // B scale
1724  b_scale_thread_copy,
1725  b_scale_grid_buf,
1726  num_k_block_main_loop);
1727  }
1728 
1729  // shuffle C and write out
1730  {
1731  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1732  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1733  "wrong!");
1734  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1735  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1736  "wrong!");
1737 
1738  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1739  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1740 
1741  // TODO: hacky, fix it!
1742  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1743  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1744 
1745  // TODO: hacky, fix it!
1746  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1747  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1748  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1749 
1750  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1751  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1752  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1753  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1754  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1755  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1756  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1757  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1758  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1759  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1760 
1761  // mul scales
1762  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1763  static_assert(M5 == 4);
1764  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1765  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1766 
1767  vector_type<float, 4> topk_weights; // for gemm2 only
1768  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1769  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1770  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1771  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1772  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1773  const index_t m_pos = block_m_id * MPerBlock +
1774  m0 * M2 * M1 * M3 * M4 * M5 +
1775  m1 * M2 * M3 * M4 * M5 +
1776  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1777 
1778  if constexpr(MulRoutedWeight)
1779  {
1780  topk_weights =
1781  *c_style_pointer_cast<const vector_type<float, M5>*>(
1782  p_ds_grid[I2] + m_pos);
1783  }
1784  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1785  constexpr index_t c_offset =
1786  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1787  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1788  constexpr auto cidx = Number<c_offset>{};
1789 
1790 if constexpr(IsInputGemm) // gate-up (gu) fusion
1791  {
1792  if constexpr(ActivationOperation ==
1794  {
1795  float gate = c_thread_buf[cidx];
1796  float up = c_thread_buf_up[cidx];
1797  if constexpr(MulRoutedWeight)
1798  {
1799  gate = gate * topk_weights.AsType<float>()[m5];
1800  up = up * topk_weights.AsType<float>()[m5];
1801  }
1803  c_thread_buf_fp32(cidx) = gate * up;
1804  }
1805  else if(ActivationOperation == Activation::gelu_and_mul)
1806  {
1807  float gate = c_thread_buf[cidx];
1808  float up = c_thread_buf_up[cidx];
1809  if constexpr(MulRoutedWeight)
1810  {
1811  gate = gate * topk_weights.AsType<float>()[m5];
1812  up = up * topk_weights.AsType<float>()[m5];
1813  }
1815  c_thread_buf_fp32(cidx) = gate * up;
1816 
1817  /*float gate = c_thread_buf[cidx];
1818  float up = c_thread_buf_up[cidx];
1819  if constexpr(MulRoutedWeight)
1820  {
1821  gate = gate * topk_weights.AsType<float>()[m5];
1822  //up = up * topk_weights.AsType<float>()[m5];
1823  }
1824  tensor_operation::element_wise::Gelu{}(gate, gate);
1825  c_thread_buf_fp32(cidx) = up;*/
1826  }
1827  }
1828  else
1829  {
1830  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1831  if constexpr(MulRoutedWeight)
1832  {
1833  c_thread_buf_fp32(cidx) =
1834  topk_weights.AsType<float>()[m5] *
1835  c_thread_buf_fp32[cidx];
1836  }
1837  }
1838  });
1839  });
1840  });
1841  });
1842  });
1843  });
1844 
1845  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1847 
1848  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1849  static_cast<CShuffleDataType*>(p_shared),
1850  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1851 
1852  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1853  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1854  make_tuple(
1858  // per shuffle
1859  M1, // M1 = MWave
1860  M2, // M2 = MXdlPack
1861  M3, // M3 * M4 * M5 = MPerXdl
1862  M4,
1863  M5)),
1867  // per shuffle
1868  N1, // N1 = NWave
1869  N2, // N2 = NXdlPack
1870  N3))), // N3 = NPerXdl
1874  Sequence<>{},
1876 
1877  // calculate origin of thread output tensor on global memory
1878  // blockwise GEMM c matrix starting index
1879  const auto c_thread_mtx_on_block =
1880  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1881 
1882  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1883  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1884 
1885  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1887  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1889  make_tuple(Sequence<0>{}));
1890 
1891  const auto m_thread_data_on_block_idx =
1892  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1893  make_multi_index(m_thread_data_on_block));
1894 
1895  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1897  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1899  make_tuple(Sequence<0>{}));
1900 
1901  const auto n_thread_data_on_block_idx =
1902  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1903  make_multi_index(n_thread_data_on_block));
1904 
1905  // shuffle: threadwise copy C from VGPR to LDS
1906  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1907  AccDataType,
1908  CShuffleDataType,
1909  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1910  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1912  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1913  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1914  I1,
1915  I1,
1916  M2,
1917  N2,
1918  M3,
1919  I1,
1920  M5,
1921  I1>,
1923  9,
1924  1,
1926  1,
1927  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1928  make_multi_index(0,
1929  0,
1930  m_thread_data_on_block_idx[I1],
1931  n_thread_data_on_block_idx[I1],
1932  m_thread_data_on_block_idx[I2],
1933  n_thread_data_on_block_idx[I2],
1934  m_thread_data_on_block_idx[I3],
1935  m_thread_data_on_block_idx[I4],
1936  m_thread_data_on_block_idx[I5],
1937  n_thread_data_on_block_idx[I3]),
1939 
1940  using EDataType = CDataType;
1941 
1942  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1943  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1944 
1945  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1946  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1947  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1948 
1949  const auto ds_grid_buf = generate_tuple(
1950  [&](auto i) {
1951  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1952  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1953  },
1954  Number<NumDTensor>{});
1955 
1956  // tuple of reference to C/Ds tensor descriptors
1957  const auto c_ds_desc_refs = concat_tuple_of_reference(
1958  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1959  generate_tie(
1960  [&](auto i) -> const auto& // return type should be reference
1961  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1962  Number<NumDTensor>{}));
1963 
1964  // tuple of reference to C/Ds tensor descriptors
1965  const auto c_ds_buf_refs = concat_tuple_of_reference(
1966  tie(c_shuffle_block_buf),
1967  generate_tie(
1968  [&](auto i) -> const auto& // return type should be reference
1969  { return ds_grid_buf[i]; },
1970  Number<NumDTensor>{}));
1971 
1972  // tuple of starting index of C/Ds blockwise copy
1973  const auto idx_c_ds_block_begin =
1976  [&](auto) {
1977  return make_multi_index(block_m_id, 0, block_n_id, 0);
1978  // return make_multi_index(block_work_idx[I0], 0,
1979  // block_work_idx[I1], 0);
1980  },
1981  Number<NumDTensor>{}));
1982 
1983  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1984  c_grid_desc_mblock_mperblock_nblock_nperblock;
1985 
1986  using CDEBlockTransferCluster =
1987  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1988  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1989  constexpr index_t scatter_weight_idx = 3; // hack fix felix
1990  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1991  ThisThreadBlock,
1992  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1993  Tuple<EDataType>,
1994  decltype(c_ds_desc_refs),
1995  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1996  CElementwiseOperation,
1997  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1998  // Sequence support
1999  // arbitrary type
2000  Sequence<1,
2001  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2002  1,
2003  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2004  CDEBlockTransferCluster,
2005  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2006  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2007  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2008  3, // index_t SrcVectorDim,
2009  3, // index_t DstVectorDim,
2010  CDEShuffleBlockTransferScalarPerVectors,
2015  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2016  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2017  IndexType,
2018  1, // ScatterDim
2019  true, // OutputScatter: false, only use scatter weights
2020  scatter_weight_idx // ScatterWeightIdx: ascale
2021  >{c_ds_desc_refs,
2022  idx_c_ds_block_begin,
2023  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2024  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2025  c_element_op};
2026 
2027  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2028  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2029 
2030  constexpr auto sfc_c_vgpr =
2031  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2032  NXdlPerWave / NXdlPack,
2033  1,
2034  1,
2035  MXdlPack,
2036  NXdlPack,
2037  M2,
2038  1,
2039  M4,
2040  1>,
2042  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2043  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2044  1,
2045  1,
2046  MXdlPack,
2047  NXdlPack,
2048  M2,
2049  1,
2050  M4,
2051  1>>{};
2052 
2053  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2054 
2055  // space filling curve for shuffled blockwise C/D/E
2056  constexpr auto sfc_cde_block =
2057  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2058  Sequence<0, 2, 1, 3>,
2059  Sequence<1,
2060  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2061  1,
2062  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2063 
2064  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
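// Note (editorial, illustrative): the write-out below runs in num_access
// passes. Each pass stages one CShuffle tile of the VGPR accumulators in LDS
// (sfc_c_vgpr walks the accumulator tile, sfc_cde_block walks the matching
// window of the MPerBlock x NPerBlock output); the static_assert above checks
// that the two space-filling curves agree on the number of passes.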
2065  constexpr auto EMThreads =
2066  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2067  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2068  constexpr auto ENThreads =
2069  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2070  static_for<0, num_access, 1>{}([&](auto access_id) {
2071  // make sure it's safe to write to LDS
2072  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2073 
2074  auto dstidx = sfc_cde_block.GetIndex(access_id);
2075  const index_t c_token_pos =
2076  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2077  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2078  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2079  IndexType token_offset = fused_token & 0xffffff;
2080  if constexpr(IsInputGemm)
2081  {
2082  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2083  }
2084  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2085  });
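// Note (editorial, illustrative): each p_sorted_token_ids entry appears to pack
// one output row's routing into 32 bits: low 24 bits = token index, high 8
// bits = top-k slot. For the input (gate/up) GEMM the scatter row is
//   (fused & 0xffffff) * TopK + (fused >> 24)
// e.g. with TopK = 4, token 57 in slot 2 lands in row 57 * 4 + 2 = 230 of the
// (NumTokens * TopK)-row output; multiplying by problem.N turns the row into
// an element offset for the scatter.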
2086 
2087  block_sync_lds();
2088 
2089  // each thread write its data from VGPR to LDS
2090  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2091  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2092  c_thread_buf_fp32,
2093  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2094  c_shuffle_block_buf);
2095 
2096  // make sure it's safe to read from LDS
2097  block_sync_lds();
2098 
2099  // each block copy its data from LDS to global
2100  cde_block_copy_lds_and_global.Run(
2101  c_ds_desc_refs,
2102  c_ds_buf_refs,
2103  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2104  tie(c_grid_buf),
2105  scatter_offsets);
2106 
2107  if constexpr(access_id < num_access - 1)
2108  {
2109  constexpr auto cde_lds_and_global_step =
2110  sfc_cde_block.GetForwardStep(access_id);
2111 
2112  // move on Ds
2113  static_for<0, NumDTensor, 1>{}([&](auto i) {
2114  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2115  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2116  });
2117 
2118  // move on E
2119  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2120  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2121  I0,
2122  cde_lds_and_global_step);
2123  }
2124  });
2125  }
2126  }
2127 
2128 #if 0
2129  template <bool HasMainKBlockLoop,
2130  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2131  TailNumber TailNum = TailNumber::Odd>
2132  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2133  const index_t* p_sorted_expert_ids,
2134  const index_t* p_max_token_id,
2135  const ADataType* p_a_grid,
2136  const AScaleDataType* p_a_scale_grid,
2137  const BDataType* p_b_grid,
2138  const BScaleDataType* p_b_scale_grid,
2139  DsGridPointer& p_ds_grid,
2140  CDataType* p_c_grid,
2141  void* p_shared,
2142  void* p_shared1,
2143  const Problem& problem,
2144  AElementwiseOperation a_element_op,
2145  BElementwiseOperation b_element_op,
2146  CElementwiseOperation c_element_op)
2147  {
2148  ignore = b_element_op;
2149  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2150  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2151  problem.MPadded,
2152  problem.K,
2153  problem.KPadded,
2154  problem.StrideA,
2155  problem.AK0);
2156  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2157  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2158  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2159  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2160  problem.MPadded,
2161  problem.N,
2162  problem.NPadded,
2163  problem.StrideC);
2164 
2165  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2166  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerXdl),
2167  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2168  (KXdlPack * 64 / MPerXdl),
2169  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2170 
2171  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2172  make_tuple(problem.N / (NXdlPack * NPerXdl),
2173  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2174  (KXdlPack * 64 / NPerXdl),
2175  64 * KXdlPack * NXdlPack / scale_pack_size_b));
2176 
2177  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2178  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2179  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2180  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2181  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
2182  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2183  if(expert_block_id * MPerBlock >= max_token_id)
2184  return;
2185  const index_t expert_id =
2186  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2187  const auto block_mn = [&]() -> std::pair<int, int> {
2188  if constexpr(NSwizzle)
2189  {
2190  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2191  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2192  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2193  const index_t expert_swizzle =
2194  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2195  const index_t bid_new = blockIdx.x - prefix_block;
2196  const index_t nid = __builtin_amdgcn_readfirstlane(
2197  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2198  const index_t mid =
2199  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2200  return {nid, mid};
2201  }
2202  else
2203  {
2204  return {blockIdx.x, blockIdx.y};
2205  }
2206  }();
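// Note (editorial, illustrative): with NSwizzle enabled the flat blockIdx.x is
// remapped inside each expert's range of sorted M-blocks (located via the
// prefix counts in p_max_token_id): groups of 8 consecutive block ids are
// spread across N before advancing in M, presumably to improve B-tile reuse
// between neighbouring blocks. Without swizzling the plain
// (blockIdx.x, blockIdx.y) = (N-block, M-block) mapping is kept.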
2207 
2208  const index_t block_n_id = block_mn.first;
2209  const index_t block_m_id = block_mn.second;
2210  const index_t token0 =
2211  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2212 
2213  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2214  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2215  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2216  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2217  constexpr auto AKThreads = AK0Threads * AK1Threads;
2218  constexpr auto AMRepeats = MPerBlock / AMThreads;
2219  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2220 
2221  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2222  return;
2223  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2224  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2225  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2226  index_t token_offset = fused_token & 0xffffff;
2227  if constexpr(!IsInputGemm)
2228  {
2229  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2230  }
2231  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2232  });
2233 
2234  const index_t expert_stride =
2235  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2236  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2237  problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2238 
2239  // N0, K0, Blocksize*KPack
2240  const index_t n_block_data_idx_on_grid =
2241  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2242 
2243  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2244  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2245 
2246  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2247  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2248 
2249  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2250  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2251  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2252  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2253  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2254 
2255  // A matrix in LDS memory, dst of blockwise copy
2256  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2257 
2258  // B matrix in LDS memory, dst of blockwise copy
2259  // dummy
2260  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2261  // A matrix blockwise copy
2262  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2264  AElementwiseOperation,
2267  Sequence<AK0Number, MPerBlock, AK1Number>,
2268  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2269  ABlockTransferThreadClusterArrangeOrder,
2270  ADataType,
2271  LDSTypeA,
2272  decltype(a_grid_desc_ak0_m_ak1),
2273  decltype(a_block_desc_ak0_m_ak1),
2274  ABlockTransferSrcAccessOrder,
2275  Sequence<0, 1, 2>,
2276  ABlockTransferSrcVectorDim,
2277  2,
2278  ABlockTransferSrcScalarPerVector,
2279  ABlockTransferDstScalarPerVector_AK1,
2280  1,
2281  1,
2282  AThreadTransferSrcResetCoordinateAfterRun,
2283  true,
2284  IndexType,
2285  1,
2286  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2287  make_multi_index(0, 0, 0),
2288  a_element_op,
2289  a_block_desc_ak0_m_ak1,
2290  make_multi_index(0, 0, 0),
2292  gather_offsets);
2293 
2294  // Thread-wise copy
2295  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2296  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2297  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2298  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2299  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2300  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2301 
2302  auto b_blockwise_copy =
2303  ThreadwiseTensorSliceTransfer_v2<BDataType,
2304  BDataType,
2305  decltype(b_grid_desc_bpreshuffled),
2306  decltype(b_block_desc_bk0_n_bk1),
2307  Sequence<Number<NXdlPerWave / NXdlPack>{},
2308  I1,
2309  Number<NXdlPack>{},
2310  Number<KRepeat>{},
2311  Number<BK1Value>{}>,
2312  Sequence<1, 2, 0, 3, 4>,
2313  4,
2314  BBlockTransferSrcScalarPerVector,
2315  BThreadTransferSrcResetCoordinateAfterRun,
2316  true>(
2317  b_grid_desc_bpreshuffled,
2318  make_multi_index(n_block_data_idx_on_grid,
2319  get_warp_local_1d_id() % NWave,
2320  0,
2321  0,
2322  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2323 
2324  // LDS allocation for A and B: be careful of alignment
2325  // Cast after lds
2326  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2327  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2328  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2329  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2330  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2331 
2332  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2333  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2334 
2335  // Blockwise GEMM pipeline
2336  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2337  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2338  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2339  decltype(c_thread_buf) c_thread_buf_up;
2340 
2341  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2342  float,
2343  c_thread_buf.num_of_v_,
2344  c_thread_buf.s_per_v,
2345  true>
2346  c_thread_buf_fp32;
2347 
2348  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2349  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2350  KPerBlock);
2351 
2352  // a and b scale processing
2353  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2354  const auto waveId_m = wave_idx[I0];
2355  const auto waveId_n = wave_idx[I1];
2356 
2357  auto thread_offset_shuffled =
2358  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2359 
2360  auto a_thread_offset_m = waveId_m;
2361 
2362  // get each thread's offset in the scale tensor
2363  const index_t token_scale_pos = block_m_id * MPerBlock;
2364  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2365  return;
2366 
2367  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2368  AScaleDataType,
2369  AScaleDataType,
2370  decltype(a_scale_grid_desc_am_ak),
2371  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2372  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2373  Sequence<0, 1, 2>, // DimAccessOrder
2374  2, // SrcVectorDim
2375  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2376  1, // SrcScalarStrideInVector
2377  true>(a_scale_grid_desc_am_ak,
2378  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2379  0,
2380  thread_offset_shuffled / scale_pack_size_a));
2381 
2382  // B scale load
2383  auto b_thread_offset_n = waveId_n;
2384 
2385  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2386  BScaleDataType,
2387  BScaleDataType,
2388  decltype(b_scale_grid_desc_bn_ak),
2389  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2390  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2391  Sequence<0, 1, 2>, // DimAccessOrder
2392  2, // SrcVectorDim
2393  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2394  1, // SrcScalarStrideInVector
2395  true>(b_scale_grid_desc_bn_ak,
2396  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2397  0,
2398  thread_offset_shuffled / scale_pack_size_b));
2399 
2400  if constexpr(IsInputGemm)
2401  {
2402  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2403  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2404  p_b_grid_up + expert_id * expert_stride / BPackedSize,
2405  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2406  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2407  BDataType,
2408  BDataType,
2409  decltype(b_grid_desc_bpreshuffled),
2410  decltype(b_block_desc_bk0_n_bk1),
2411  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
2412  Sequence<1, 2, 0, 3>,
2413  3,
2414  BBlockTransferSrcScalarPerVector,
2415  BThreadTransferSrcResetCoordinateAfterRun,
2416  true>(b_grid_desc_bpreshuffled,
2417  make_multi_index(n_block_data_idx_on_grid,
2418  get_warp_local_1d_id() % NWave,
2419  0,
2420  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2421  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2422  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2423  p_b_scale_grid_up + expert_id * expert_scale_stride,
2424  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2425  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2426  BScaleDataType,
2427  BScaleDataType,
2428  decltype(b_scale_grid_desc_bn_ak),
2429  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2430  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2431  Sequence<0, 1, 2>, // DimAccessOrder
2432  2, // SrcVectorDim
2433  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2434  1, // SrcScalarStrideInVector
2435  true>(
2436  b_scale_grid_desc_bn_ak,
2437  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2438  0,
2439  thread_offset_shuffled / scale_pack_size_b));
2440 
2441  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2442  a_grid_desc_ak0_m_ak1,
2443  a_block_desc_ak0_m_ak1,
2444  a_blockwise_copy,
2445  a_grid_buf,
2446  a_block_bufs,
2447  a_block_slice_copy_step,
2448  b_grid_desc_bpreshuffled,
2449  b_block_desc_bk0_n_bk1,
2450  b_blockwise_copy,
2451  b_blockwise_copy_up,
2452  b_grid_buf,
2453  b_grid_buf_up,
2454  b_block_bufs,
2455  b_block_slice_copy_step,
2456  c_thread_buf,
2457  c_thread_buf_up,
2458  a_scale_grid_desc_am_ak,
2459  a_scale_thread_copy,
2460  a_scale_grid_buf,
2461  b_scale_grid_desc_bn_ak,
2462  b_scale_thread_copy,
2463  b_scale_thread_copy_up,
2464  b_scale_grid_buf,
2465  b_scale_grid_buf_up,
2466  num_k_block_main_loop);
2467  }
2468  else
2469  {
2470  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2471  a_grid_desc_ak0_m_ak1,
2472  a_block_desc_ak0_m_ak1,
2473  a_blockwise_copy,
2474  a_grid_buf,
2475  a_block_bufs,
2476  a_block_slice_copy_step,
2477  b_grid_desc_bpreshuffled,
2478  b_block_desc_bk0_n_bk1,
2479  b_blockwise_copy,
2480  b_grid_buf,
2481  b_block_bufs,
2482  b_block_slice_copy_step,
2483  c_thread_buf,
2484  a_scale_grid_desc_am_ak,
2485  a_scale_thread_copy,
2486  a_scale_grid_buf,
2487  b_scale_grid_desc_bn_ak,
2488  b_scale_thread_copy,
2489  b_scale_grid_buf,
2490  num_k_block_main_loop);
2491  }
2492 
2493  // shuffle C and write out
2494  {
2495  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2496  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2497  "wrong!");
2498 
2499  // TODO: hacky, fix it!
2500  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2501  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2502 
2503  // TODO: hacky, fix it!
2504  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2505  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2506  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2507 
2508  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2509  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2510  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2511  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2512  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2513  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2514  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2515  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2516 
2517  // mul scales
2518 
2519  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2520  static_assert(M4 == 4);
2521  const index_t m1 = get_warp_local_1d_id() / NWave;
2522  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2523 
2524  vector_type<float, 4> topk_weights; // for gemm2 only
2525  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2526  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2527  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2528  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2529  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2530  if constexpr(MulRoutedWeight)
2531  {
2532  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2533  p_ds_grid[I2] + m_pos);
2534  }
2535  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2536  constexpr index_t c_offset =
2537  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2538  make_tuple(m0 / MXdlPack,
2539  n0 / NXdlPack,
2540  m0 % MXdlPack,
2541  n0 % NXdlPack,
2542  m2 * M4 + m4));
2543  constexpr auto cidx = Number<c_offset>{};
2544 
2545  if constexpr(IsInputGemm) // gate-up (gu) fusion
2546  {
2547  if constexpr(ActivationOperation == Activation::silu_and_mul)
2548  {
2549  float gate = c_thread_buf[cidx];
2550  float up = c_thread_buf_up[cidx];
2551  if constexpr(MulRoutedWeight)
2552  {
2553  gate = gate * topk_weights.AsType<float>()[m4];
2554  up = up * topk_weights.AsType<float>()[m4];
2555  }
2556  tensor_operation::element_wise::Silu{}(gate, gate);
2557  c_thread_buf_fp32(cidx) = gate * up;
2558  }
2559  else if(ActivationOperation == Activation::gelu_and_mul)
2560  {
2561  float gate = c_thread_buf[cidx];
2562  float up = c_thread_buf_up[cidx];
2563  if constexpr(MulRoutedWeight)
2564  {
2565  gate = gate * topk_weights.AsType<float>()[m4];
2566  up = up * topk_weights.AsType<float>()[m4];
2567  }
2568  tensor_operation::element_wise::Gelu{}(gate, gate);
2569  c_thread_buf_fp32(cidx) = gate * up;
2570  }
2571  }
2572  else
2573  {
2574  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2575  if constexpr(MulRoutedWeight)
2576  {
2577  c_thread_buf_fp32(cidx) =
2578  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
2579  }
2580  }
2581  });
2582  });
2583  });
2584  });
2585 
2586  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2587  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2588 
2589  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2590  static_cast<CShuffleDataType*>(p_shared),
2591  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2592 
2593  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2594  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2597  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
2598  // shuffle
2599  M1, // M1 = MWave
2600  M2, // M2 * M3 * M4 = MPerXdl
2601  M3,
2602  M4)),
2605  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
2606  // shuffle
2607  N1, // N1 = NWave
2608  N2))), // N2 = NPerXdl
2609  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2610  make_tuple(
2611  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2612 
2613  // calculate origin of thread output tensor on global memory
2614  // blockwise GEMM c matrix starting index
2615  const auto c_thread_mtx_on_block =
2616  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2617 
2618  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2619  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2620 
2621  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2622  make_single_stage_tensor_adaptor(
2623  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2624  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2625  make_tuple(Sequence<0>{}));
2626 
2627  const auto m_thread_data_on_block_idx =
2628  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2629  make_multi_index(m_thread_data_on_block));
2630 
2631  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2632  make_single_stage_tensor_adaptor(
2633  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
2634  make_tuple(Sequence<0, 1, 2>{}),
2635  make_tuple(Sequence<0>{}));
2636 
2637  const auto n_thread_data_on_block_idx =
2638  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2639  make_multi_index(n_thread_data_on_block));
2640 
2641  // shuffle: threadwise copy C from VGPR to LDS
2642  auto c_thread_copy_vgpr_to_lds =
2643  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2644  CShuffleDataType,
2645  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2646  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2648  Sequence<CShuffleMXdlPerWavePerShuffle,
2649  CShuffleNXdlPerWavePerShuffle,
2650  I1,
2651  I1,
2652  M2,
2653  I1,
2654  M4,
2655  I1>,
2656  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2657  7,
2658  1,
2660  1,
2661  true>{
2662  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2663  make_multi_index(0,
2664  0,
2665  m_thread_data_on_block_idx[I1],
2666  n_thread_data_on_block_idx[I1],
2667  m_thread_data_on_block_idx[I2],
2668  m_thread_data_on_block_idx[I3],
2669  m_thread_data_on_block_idx[I4],
2670  n_thread_data_on_block_idx[I2]),
2672 
2673  using EDataType = CDataType;
2674 
2675  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2676  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2677 
2678  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2679  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2680  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2681 
2682  const auto ds_grid_buf = generate_tuple(
2683  [&](auto i) {
2684  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2685  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2686  },
2687  Number<NumDTensor>{});
2688 
2689  // tuple of reference to C/Ds tensor descriptors
2690  const auto c_ds_desc_refs = concat_tuple_of_reference(
2691  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2692  generate_tie([&](auto i) -> const auto& // return type should be reference
2693  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2694  Number<NumDTensor>{}));
2695 
2696  // tuple of reference to C/Ds tensor descriptors
2697  const auto c_ds_buf_refs = concat_tuple_of_reference(
2698  tie(c_shuffle_block_buf),
2699  generate_tie([&](auto i) -> const auto& // return type should be reference
2700  { return ds_grid_buf[i]; },
2701  Number<NumDTensor>{}));
2702 
2703  // tuple of starting index of C/Ds blockwise copy
2704  const auto idx_c_ds_block_begin =
2707  [&](auto) {
2708  return make_multi_index(block_m_id, 0, block_n_id, 0);
2709  // return make_multi_index(block_work_idx[I0], 0,
2710  // block_work_idx[I1], 0);
2711  },
2712  Number<NumDTensor>{}));
2713 
2714  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2715  c_grid_desc_mblock_mperblock_nblock_nperblock;
2716 
2717  using CDEBlockTransferCluster =
2718  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2719  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2720  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2721  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2722  ThisThreadBlock,
2723  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2724  Tuple<EDataType>,
2725  decltype(c_ds_desc_refs),
2726  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2727  CElementwiseOperation,
2728  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2729  // Sequence support
2730  // arbitrary type
2731  Sequence<1,
2732  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2733  1,
2734  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2735  CDEBlockTransferCluster,
2736  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2737  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2738  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2739  3, // index_t SrcVectorDim,
2740  3, // index_t DstVectorDim,
2741  CDEShuffleBlockTransferScalarPerVectors,
2744  Sequence<true>,
2746  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2747  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2748  IndexType,
2749  1, // ScatterDim
2750  true, // OutputScatter: false, only use scatter weights
2751  scatter_weight_idx // ScatterWeightIdx: ascale
2752  >{c_ds_desc_refs,
2753  idx_c_ds_block_begin,
2754  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2755  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2756  c_element_op};
2757 
2758  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2759  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2760  constexpr auto sfc_c_vgpr =
2761  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2762  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2763  Sequence<CShuffleMXdlPerWavePerShuffle,
2764  CShuffleNXdlPerWavePerShuffle,
2765  1,
2766  1,
2767  M2,
2768  1,
2769  M4,
2770  1>>{};
2771 
2772  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2773 
2774  // space filling curve for shuffled blockwise C/D/E
2775  constexpr auto sfc_cde_block =
2776  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2777  Sequence<0, 2, 1, 3>,
2778  Sequence<1,
2779  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2780  1,
2781  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2782 
2783  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2784  constexpr auto EMThreads =
2785  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2786  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2787  constexpr auto ENThreads =
2788  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2789  static_for<0, num_access, 1>{}([&](auto access_id) {
2790  // make sure it's safe to write to LDS
2791  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2792 
2793  auto dstidx = sfc_cde_block.GetIndex(access_id);
2794  const index_t c_token_pos =
2795  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2796  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2797  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2798  IndexType token_offset = fused_token & 0xffffff;
2799  if constexpr(IsInputGemm)
2800  {
2801  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2802  }
2803  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2804  });
2805 
2806  block_sync_lds();
2807 
2808  // each thread write its data from VGPR to LDS
2809  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2810  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2811  c_thread_buf_fp32,
2812  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2813  c_shuffle_block_buf);
2814 
2815  // make sure it's safe to read from LDS
2816  block_sync_lds();
2817 
2818  // each block copy its data from LDS to global
2819  cde_block_copy_lds_and_global.Run(
2820  c_ds_desc_refs,
2821  c_ds_buf_refs,
2822  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2823  tie(c_grid_buf),
2824  scatter_offsets);
2825 
2826  if constexpr(access_id < num_access - 1)
2827  {
2828  constexpr auto cde_lds_and_global_step =
2829  sfc_cde_block.GetForwardStep(access_id);
2830 
2831  // move on Ds
2832  static_for<0, NumDTensor, 1>{}([&](auto i) {
2833  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2834  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2835  });
2836 
2837  // move on E
2838  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2839  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2840  I0,
2841  cde_lds_and_global_step);
2842  }
2843  });
2844  }
2845  }
2846 #endif
2847 };
2848 
2849 } // namespace ck