Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp Source File

gridwise_moe_mx_gemm.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently there is no elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are needed to make sure data
28 // accesses operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C
30 // may not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
37 
38 #if 0
39 template <typename GridwiseGemm,
40  bool HasMainKBlockLoop,
41  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
42  index_t MinimumOccupancy = 1,
43  TailNumber TailNum = TailNumber::Even>
44 __global__ void
45 #if CK_USE_LAUNCH_BOUNDS
46 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
47 #endif
48  // __attribute__((amdgpu_waves_per_eu(1, 1)))
49  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
50 {
51 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
52  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53 
54  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
55 
56  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
57  karg.p_sorted_token_ids,
58  karg.p_sorted_expert_ids,
59  karg.p_max_token_id,
60  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
62  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
64  karg.p_ds_grid,
65  karg.p_c_grid,
66  p_shared,
67  karg,
68  karg.a_element_op,
69  karg.b_element_op,
70  karg.c_element_op);
71 #else
72  ignore = karg;
73 #endif // end of if (defined(__gfx9__))
74 }
75 #endif
76 
77 template <typename GridwiseGemm,
78  bool HasMainKBlockLoop,
79  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
80  index_t MinimumOccupancy = 1,
81  TailNumber TailNum = TailNumber::Even>
82 __global__ void
83 #if CK_USE_LAUNCH_BOUNDS
84  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
85 #endif
86  // __attribute__((amdgpu_waves_per_eu(1, 1)))
87  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
88 {
89 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
90  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92 
93  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
94 
95  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
96  karg.p_sorted_token_ids,
97  karg.p_sorted_expert_ids,
98  karg.p_max_token_id,
99  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
100  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
103  karg.p_ds_grid,
104  karg.p_c_grid,
105  p_shared_0,
106  p_shared_1,
107  karg,
108  karg.a_element_op,
109  karg.b_element_op,
110  karg.c_element_op);
111 #else
112  ignore = karg;
113 #endif // end of if (defined(__gfx9__))
114 }
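// Illustrative host-side launch sketch (not taken from this header): assuming a fully
// specialized GridwiseGemm type `Gridwise`, a populated `Gridwise::Argument karg`, and a
// split-K factor karg.KBatch, the grid shape comes from CalculateGridSize(M, N) and
// blockIdx.z indexes the split-K batch:
//
//   const auto [gx, gy, gz] = Gridwise::CalculateGridSize(karg.M, karg.N);
//   kernel_moe_mxgemm_2lds<Gridwise,
//                          /*HasMainKBlockLoop=*/true,
//                          InMemoryDataOperationEnum::Set>
//       <<<dim3(gx, gy, karg.KBatch), dim3(BlockSize)>>>(karg);
//
// The CK device operations wrap this in their own launch helpers; the snippet above only
// sketches how the grid dimensions and the Argument struct are consumed.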
115 
116 template <typename ALayout,
117  typename BLayout,
118  typename DsLayout,
119  typename CLayout,
120  typename ADataType,
121  typename AScaleDataType,
122  typename BDataType,
123  typename BScaleDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t ScaleBlockSize, // Scaling block size
133  index_t BlockSize, // Thread block size
134  index_t MPerBlock,
135  index_t NPerBlock,
136  index_t KPerBlock,
137  index_t AK1Value,
138  index_t BK1Value,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144  typename ABlockTransferThreadClusterArrangeOrder,
145  typename ABlockTransferSrcAccessOrder,
146  index_t ABlockTransferSrcVectorDim,
147  index_t ABlockTransferSrcScalarPerVector,
148  index_t ABlockTransferDstScalarPerVector_AK1,
149  bool AThreadTransferSrcResetCoordinateAfterRun,
150  index_t ABlockLdsExtraM,
151  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152  typename BBlockTransferThreadClusterArrangeOrder,
153  typename BBlockTransferSrcAccessOrder,
154  index_t BBlockTransferSrcVectorDim,
155  index_t BBlockTransferSrcScalarPerVector,
156  index_t BBlockTransferDstScalarPerVector_BK1,
157  bool BThreadTransferSrcResetCoordinateAfterRun,
158  index_t BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  typename CDEShuffleBlockTransferScalarPerVectors,
165  index_t ActivationOperation = 0,
166  bool NSwizzle = false,
167  bool IsInputGemm = true,
168  bool MulRoutedWeight = true,
169  typename IndexType = index_t,
170  typename ComputeTypeA = ADataType,
171  typename ComputeTypeB = BDataType>
173 {
174  using LDSTypeA = ADataType;
175  using LDSTypeB = BDataType;
176 
177  static constexpr auto I0 = Number<0>{};
178  static constexpr auto I1 = Number<1>{};
179  static constexpr auto I2 = Number<2>{};
180  static constexpr auto I3 = Number<3>{};
181  static constexpr auto I4 = Number<4>{};
182  static constexpr auto I5 = Number<5>{};
183  static constexpr auto I6 = Number<6>{};
184  static constexpr auto I7 = Number<7>{};
185  static constexpr auto I8 = Number<8>{};
186  static constexpr auto I9 = Number<9>{};
187 
188  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
189  CDEShuffleBlockTransferScalarPerVectors{}[I0];
190  // K1 should be Number<...>
191  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
192  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
193  static constexpr auto AK1Number = Number<AK1Value>{};
194  static constexpr auto BK1Number = Number<BK1Value>{};
195 
196  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
197  static constexpr bool is_single_rate_mfma = false;
198  static constexpr auto is_scale_mfma = true;
199 
200  static constexpr index_t NumDTensor = DsDataType::Size();
201 
202  static constexpr auto MXdlPack = 2;
203  static constexpr auto NXdlPack = 2;
204  static constexpr auto KXdlPack = 2;
205 
206  // KPack is at least the k_per_blk of the selected MFMA instruction
207  // and should be a multiple of k_per_blk.
208  //
209  // TODO: Move this to the blockwise pipeline base.
210  // KPack is expressed in packed data types for packed (pk) A/B.
211 
212  static constexpr index_t APackedSize = packed_size_v<ADataType>;
213  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
214 
215  using mfma_selector = MfmaSelector<ComputeTypeA,
216  MPerXdl,
217  NPerXdl,
218  ComputeTypeB,
220  is_scale_mfma>;
221  static constexpr index_t KPack =
223 
224  // static constexpr index_t NumTokens = 1;
225  static constexpr index_t SortedTileSize = MPerBlock;
226 
227  static constexpr auto MakeDsGridPointer()
228  {
229  return generate_tuple(
230  [&](auto i) {
231  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
232 
233  return static_cast<const DDataType*>(nullptr);
234  },
236  }
237 
238  using DsGridPointer = decltype(MakeDsGridPointer());
239 
241 
242  __host__ static auto CalculateGridSize(index_t M, index_t N)
243  {
244  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
245  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
246  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
247  const index_t gridy = NSwizzle ? 1 : mblock;
248 
249  return std::make_tuple(gridx, gridy, 1);
250  }
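// Example: with MPerBlock = 128, NPerBlock = 128, M = 512 and N = 1024 this yields
// mblock = 4 and nblock = 8, so the grid is (8, 4, 1) without N-swizzling and (32, 1, 1)
// when NSwizzle is enabled (the M/N block ids are then decoded from a single flattened
// blockIdx.x).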
251 
252  __host__ static auto CalculateMPadded(index_t M)
253  {
254  return math::integer_least_multiple(M, MPerBlock);
255  }
256 
257  __host__ static auto CalculateNPadded(index_t N)
258  {
259  return math::integer_least_multiple(N, NPerBlock);
260  }
261 
262  __host__ static auto CalculateKPadded(index_t K)
263  {
264  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
265  }
266 
267  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
268  {
269  auto K_t = K_Batch * KPerBlock;
270  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
271  }
272 
273  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
274  {
275  auto K_t = K_Batch * KPerBlock;
276  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
277  }
278 
279  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * KPerBlock;
283  }
284 
285  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
286  {
287  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
288  auto K_t = K_Batch * KReadVec;
289  return (K + K_t - 1) / K_t * KReadVec;
290  }
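// Worked example for the split-K helpers above, assuming KPerBlock = 256,
// AK1Value = BK1Value = 16, K = 1000 and K_Batch = 2:
//   CalculateKPadded(1000, 2)   -> ceil(1000 / 512) * 256        = 512
//   CalculateAK0Padded(1000, 2) -> ceil(1000 / 512) * (256 / 16) = 32
//   CalculateKRead(1000, 2)     -> ceil(1000 / 32) * 16          = 512
// i.e. each of the two split-K batches reads 512 (padded) K elements.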
291 
292  __host__ static auto CalculateMBlock(index_t M)
293  {
294  return math::integer_divide_ceil(M, MPerBlock);
295  }
296 
297  __host__ static auto CalculateNBlock(index_t N)
298  {
299  return math::integer_divide_ceil(N, NPerBlock);
300  }
301 
302  template <index_t MNXdlPerWave,
303  index_t MNWaves,
304  index_t MNXdlPack,
305  index_t MNPerXdl,
306  typename TileDesc_K0_MN_K1>
307  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
308  {
309  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
310  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
311  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
312 
313  constexpr auto permuted_desc = transform_tensor_descriptor(
314  TileDesc_K0_MN_K1{},
319 
321  permuted_desc,
324  Number<MNWaves>{},
326  Number<MNPerXdl>{}))),
329  }
330 
331  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
332  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
333  {
334  const auto a_grid_desc_mraw_kraw = [&]() {
335  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
336  {
337  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
338  }
339  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
340  {
341  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
342  }
343  }();
344 
346 
347  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
348  GemmSpec == GemmSpecialization::MNKPadding)
349  {
350  // pad both M and K
351  const auto a_grid_desc_m_k =
352  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
354  make_right_pad_transform(K, KPad - K)),
357 
358  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
359  a_grid_desc_m_k,
364 
365  return a_grid_desc_ak0_m_ak1;
366  }
367  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
368  GemmSpec == GemmSpecialization::MNPadding)
369  {
370  // pad M, but not K
371  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
372  a_grid_desc_mraw_kraw,
373  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
374  make_right_pad_transform(M, MPad - M)),
377 
378  const auto a_grid_desc_permuted = transform_tensor_descriptor(
379  a_grid_desc_ak0_m_ak1,
382  make_pass_through_transform(AK1Value)),
385 
386  const auto a_grid_desc = transform_tensor_descriptor(
387  a_grid_desc_permuted,
388  make_tuple(
391  make_pass_through_transform(AK1Value)),
394  return a_grid_desc;
395  }
396  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
397  GemmSpec == GemmSpecialization::NKPadding)
398  {
399  // pad K, but not M
400  const auto a_grid_desc_m_k = transform_tensor_descriptor(
401  a_grid_desc_mraw_kraw,
405 
406  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
407  a_grid_desc_m_k,
412 
413  return a_grid_desc_ak0_m_ak1;
414  }
415  else
416  {
417  // not pad M or K
418  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
419  a_grid_desc_mraw_kraw,
420  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
424 
425  const auto a_grid_desc_permuted = transform_tensor_descriptor(
426  a_grid_desc_ak0_m_ak1,
429  make_pass_through_transform(AK1Value)),
432 
433  const auto a_grid_desc = transform_tensor_descriptor(
434  a_grid_desc_permuted,
435  make_tuple(
438  make_pass_through_transform(AK1Value)),
441 
442  return a_grid_desc;
443  }
444  }
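// Illustrative shape note: the descriptors above view the K axis of A as
// (K / KPerBlock) x AK0 x AK1 with AK0 = KPerBlock / AK1Value. For example, with
// KPerBlock = 256, AK1Value = 16 and K = 1024 (already a multiple of KPerBlock), K
// decomposes as 4 x 16 x 16 before the block-id permutation is applied; the padding
// variants first pad M and/or K up to MPad / KPad.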
445 
446  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
447  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
448  {
449  const auto b_grid_desc_nraw_kraw = [&]() {
451  {
452  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
453  }
455  {
456  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
457  }
458  }();
459 
461 
462  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
463  GemmSpec != GemmSpecialization::Default),
464  "pk_i4_t does not support padding");
465  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
466  (GemmSpec != GemmSpecialization::Default &&
467  GemmSpec != GemmSpecialization::MPadding)),
468  "f4x2_pk_t does not support K padding");
469 
470  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
471  GemmSpec == GemmSpecialization::MNKPadding)
472  {
473  // pad both N and K
474  const auto b_grid_desc_n_k =
475  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
477  make_right_pad_transform(K, KPad - K)),
480 
481  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
482  b_grid_desc_n_k,
487 
488  return b_grid_desc_bk0_n_bk1;
489  }
490  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
491  GemmSpec == GemmSpecialization::MNPadding)
492  {
493  // pad N, but not K
494  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
495  b_grid_desc_nraw_kraw,
497  make_right_pad_transform(N, NPad - N)),
500 
501  return b_grid_desc_bk0_n_bk1;
502  }
503  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
504  GemmSpec == GemmSpecialization::MKPadding)
505  {
506  // pad K, but not N
507  const auto b_grid_desc_n_k = transform_tensor_descriptor(
508  b_grid_desc_nraw_kraw,
512 
513  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
514  b_grid_desc_n_k,
519 
520  return b_grid_desc_bk0_n_bk1;
521  }
522  else
523  {
524  // not pad N or K
525  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
526  b_grid_desc_nraw_kraw,
527  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
531 
532  const auto b_grid_desc_permuted = transform_tensor_descriptor(
533  b_grid_desc_bk0_n_bk1,
536  make_pass_through_transform(BK1Value)),
539 
540  const auto b_grid_desc = transform_tensor_descriptor(
541  b_grid_desc_permuted,
542  make_tuple(
545  make_pass_through_transform(BK1Value)),
548 
549  return b_grid_desc;
550  }
551  }
552 
553  template <typename ABlockDesc_AK0_M_AK1>
554  __host__ __device__ static constexpr auto
555  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
556  {
557  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
558 
559  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
560  ABlockDesc_AK0_M_AK1{});
561  }
562 
563  template <typename BBlockDesc_BK0_N_BK1>
564  __host__ __device__ static constexpr auto
565  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
566  {
567  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
568 
569  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
570  BBlockDesc_BK0_N_BK1{});
571  }
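// Example: with MPerBlock = 128, MXdlPerWave = 2 and MPerXdl = 32, MWaves evaluates to
// 128 / (2 * 32) = 2; NWaves follows the same formula for the N dimension.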
572 
573  template <typename ELayout>
574  __host__ __device__ static auto MakeCGridDescriptor_M_N(
575  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
576  {
577  const auto c_grid_desc_mraw_nraw = [&]() {
579  {
580  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
581  }
583  {
584  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
585  }
586  }();
587 
588  // pad M and N
589  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
591  make_right_pad_transform(N, NPad - N)),
594  }
595 
596  template <typename DLayout>
597  __host__ __device__ static auto
598  MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
599  {
600  const auto c_grid_desc_mraw_nraw = [&]() {
602  {
603  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
604  }
606  {
607  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
608  }
609  }();
610 
611  // pad M and N
612  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
614  make_right_pad_transform(N, NPad - N)),
617  }
618 
619  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
620  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
621  {
622  return generate_tuple(
623  [&](auto i) {
624  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
625  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
626  },
628  }
629 
630  template <typename DsGridDesc>
632  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
633  {
634  return generate_tuple(
635  [&](auto i) {
637  ds_grid_desc_m_n[i], MBlock, NBlock);
638  },
640  }
641 
642  struct Problem
643  {
644  __host__ Problem(index_t NumTokens_,
645  index_t TopK_,
646  index_t M_,
647  index_t N_,
648  index_t K_,
649  index_t StrideA_,
650  index_t StrideScaleA_,
651  index_t StrideB_,
652  index_t StrideScaleB_,
653  std::array<index_t, NumDTensor> StrideDs_,
654  index_t StrideC_,
655  index_t KBatch_)
656  : NumTokens{NumTokens_},
657  TopK{TopK_},
658  M{M_},
659  N{N_},
660  K{K_},
661  StrideA{StrideA_},
662  StrideScaleA{StrideScaleA_},
663  StrideB{StrideB_},
664  StrideScaleB{StrideScaleB_},
665  StrideDs{StrideDs_},
666  StrideC{StrideC_},
667  KBatch{KBatch_},
668  MPadded{CalculateMPadded(M_)},
669  NPadded{CalculateNPadded(N_)},
670  KRead{CalculateKRead(K_, KBatch_)},
671  KPadded{CalculateKPadded(K_, KBatch_)},
672  AK0{CalculateAK0Padded(K_, KBatch_)},
673  BK0{CalculateBK0Padded(K_, KBatch_)},
674  MBlock{CalculateMBlock(M_)},
675  NBlock{CalculateNBlock(N_)}
676  {
677  }
678 
679  __host__ void Print() const
680  {
681  std::cout << "problem {"
682  << "NumTokens:" << NumTokens << ", "
683  << "TopK:" << TopK << ", "
684  << "M:" << M << ", "
685  << "N:" << N << ", "
686  << "K:" << K << ", "
687  << "SA:" << StrideA << ", "
688  << "SScaleA:" << StrideScaleA << ", "
689  << "SB:" << StrideB << ", "
690  << "SScaleB:" << StrideScaleB << ", "
691  << "SC:" << StrideC << ", "
692  << "MP:" << MPadded << ", "
693  << "NP:" << NPadded << ", "
694  << "KRead:" << KRead << ", "
695  << "KP:" << KPadded << ", "
696  << "AK0:" << AK0 << ", "
697  << "BK0:" << BK0 << ", "
698  << "MBlock: " << MBlock << ", "
699  << "NBlock: " << NBlock << "}" << std::endl;
700  }
701 
702  index_t NumTokens;
703  index_t TopK;
704  index_t M;
705  index_t N;
706  index_t K;
707  index_t StrideA;
708  index_t StrideScaleA;
709  index_t StrideB;
710  index_t StrideScaleB;
711  std::array<index_t, NumDTensor> StrideDs;
712  index_t StrideC;
713  index_t KBatch;
714  index_t MPadded;
715  index_t NPadded;
716  index_t KRead;
717  index_t KPadded;
718  index_t AK0;
719  index_t BK0;
720  index_t MBlock;
721  index_t NBlock;
722  };
723 
724  // Argument
725  struct Argument : public Problem
726  {
727  __host__ Argument(const index_t* p_sorted_token_ids_,
728  const index_t* p_sorted_expert_ids_,
729  const index_t* p_max_token_id_,
730  const ADataType* p_a_grid_,
731  const AScaleDataType* p_a_scale_grid_,
732  const BDataType* p_b_grid_,
733  const BScaleDataType* p_b_scale_grid_,
734  std::array<const void*, NumDTensor> p_ds_grid_,
735  CDataType* p_c_grid_,
736  index_t NumTokens_,
737  index_t TopK_,
738  index_t M_,
739  index_t N_,
740  index_t K_,
741  index_t StrideA_,
742  index_t StrideScaleA_,
743  index_t StrideB_,
744  index_t StrideScaleB_,
745  std::array<index_t, NumDTensor> StrideDs_,
746  index_t StrideC_,
747  index_t k_batch_,
748  AElementwiseOperation a_element_op_,
749  BElementwiseOperation b_element_op_,
750  CElementwiseOperation c_element_op_)
751  : Problem{NumTokens_,
752  TopK_,
753  M_,
754  N_,
755  K_ / APackedSize,
756  StrideA_ / APackedSize,
757  StrideScaleA_,
758  StrideB_ / BPackedSize,
759  StrideScaleB_,
760  StrideDs_,
761  StrideC_,
762  k_batch_},
763  p_sorted_token_ids{p_sorted_token_ids_},
764  p_sorted_expert_ids{p_sorted_expert_ids_},
765  p_max_token_id{p_max_token_id_},
766  p_a_grid{p_a_grid_},
767  p_a_scale_grid{p_a_scale_grid_},
768  p_b_grid{p_b_grid_},
769  p_b_scale_grid{p_b_scale_grid_},
770  p_ds_grid{},
771  p_c_grid{p_c_grid_},
772  a_element_op{a_element_op_},
773  b_element_op{b_element_op_},
774  c_element_op{c_element_op_}
775  {
776 
777  // populate pointer, desc for Ds
778  static_for<0, NumDTensor, 1>{}([&](auto i) {
779  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
780 
781  // D pointer
782  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
783  });
784  }
785 
786  const index_t* p_sorted_token_ids;
787  const index_t* p_sorted_expert_ids;
788  const index_t* p_max_token_id;
789  const ADataType* p_a_grid;
790  const AScaleDataType* p_a_scale_grid;
791  const BDataType* p_b_grid;
792  const BScaleDataType* p_b_scale_grid;
793  DsGridPointer p_ds_grid;
794  CDataType* p_c_grid;
795 
796  const AElementwiseOperation a_element_op;
797  const BElementwiseOperation b_element_op;
798  const CElementwiseOperation c_element_op;
799  };
800 
801  struct SplitKBatchOffset
802  {
803  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
804  {
805  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
806  {
807  a_k_split_offset = k_id * karg.KRead;
808  }
809  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
810  {
811  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
812  }
813 
814  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
815  {
816  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
817  }
818  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
819  {
820  // KPack * NLane * KLane * K0 * N0
821  b_k_split_offset = k_id * karg.KRead;
822  }
823 
824  // Calculate A scale offset
825  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
826  {
827  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
828  }
829  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
830  {
831  a_scale_k_split_offset =
832  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
833  }
834 
835  // Calculate B scale offset
836  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
837  {
838  b_scale_k_split_offset =
839  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
840  }
841  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
842  {
843  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
844  }
845 
846  if(k_id < karg.KBatch - 1)
847  {
848  karg.K = karg.KRead;
849  }
850  else
851  {
852  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
853  }
854  }
855 
856  index_t a_k_split_offset;
857  index_t b_k_split_offset;
858  index_t a_scale_k_split_offset;
859  index_t b_scale_k_split_offset;
860  };
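// Worked example for the offsets above, assuming row-major A, column-major B,
// KRead = 512, ScaleBlockSize = 32, APackedSize = BPackedSize = 2 and k_id = 1:
//   a_k_split_offset       = 1 * 512            = 512 (A elements)
//   b_k_split_offset       = 1 * 512            = 512 (B elements)
//   a_scale_k_split_offset = 1 * 512 / (32 / 2) = 32  (A scale elements)
//   b_scale_k_split_offset = 1 * 512 / (32 / 2) = 32  (B scale elements)
// The last split-K batch additionally shrinks karg.K to whatever K remains.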
861 
862  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
863  {
864  // A matrix in LDS memory, dst of blockwise copy
865  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
866  {
867  // contiguous in LDS
871  }
872  // The xor tensor transformation requires more (unnecessary) VGPR usage and can cause
873  // register spills in some cases.
875  {
876  constexpr auto a_lds_block_desc =
879 
880  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
881  a_lds_block_desc,
887 
888  return a_lds_block_desc_permuted;
889  }
890  else // ColumnMajor A
891  {
892  // The kfold and mpair dimensions are not always required.
893  // More dimensions in merge_transform increase the difficulty of generating
894  // immediate-argument (immarg) offsets for the compiler.
895  constexpr auto WaveSize = 64;
896  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
897  constexpr auto M1 = MPerBlock / M0;
898 
899  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
900  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
901  constexpr auto KThreadRead = WaveSize / MPerXdl;
902  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
903 
904  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
905  ? 1
906  : 128 / (AK1Number * M0 * sizeof(ADataType));
907  constexpr auto KThreadReadPerm =
908  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
909  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
910  : KThreadRead;
911 
912  // 1 <= mpair <= M0
913  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
914  ? 1
915  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
916  ? M0
917  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
918 
919  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
923  Number<kfold * M0 / mpair>{},
924  Number<mpair>{},
925  AK1Number));
926 
927  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
928  a_lds_block_desc,
929  make_tuple(
933  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
936  make_tuple(
938  make_tuple(
940 
941  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
942  a_lds_block_desc_permuted,
943  make_tuple(
951  Sequence<1>{},
952  Sequence<2>{},
953  Sequence<3>{},
954  Sequence<4>{},
955  Sequence<5>{}),
957  Sequence<2>{},
958  Sequence<0, 3>{},
959  Sequence<4, 5>{},
960  Sequence<6>{},
961  Sequence<7>{}));
962 
963  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
964  a_lds_block_desc_unmerged,
967  Number<KThreadWrite / kfold / KThreadReadPerm>{},
968  Number<kfold>{},
975 
976  return a_lds_block_desc_ak0_m_ak1;
977  }
978  }
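// Example for the bank-conflict parameters above, assuming a 1-byte ADataType (e.g. an
// 8-bit float), AK1Number = 16, M0 = 32 and MPerXdl = 32: AK1Number * M0 * 1 = 512 > 128
// gives kfold = 1, and AK1Number * MPerXdl * 1 = 512 > 128 gives mpair = 1. Smaller
// products (< 128 bytes) would yield kfold > 1 / mpair > 1, folding several rows into one
// 128-byte chunk of LDS.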
979 
980  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
981  {
982  // B matrix in LDS memory, dst of blockwise copy
983  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
984  {
985  // contiguous in lds
989  }
991  {
992  // NLdsLayer * K0 as logical Bank
993  constexpr auto b_lds_block_desc =
996 
997  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
998  b_lds_block_desc,
1004 
1005  return b_lds_block_desc_permuted;
1006  }
1007  else // RowMajor B
1008  {
1009  constexpr auto WaveSize = 64;
1010  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1011  constexpr auto N1 = NPerBlock / N0;
1012 
1013  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1014  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1015  constexpr auto KThreadRead = WaveSize / NPerXdl;
1016  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1017 
1018  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1019  ? 1
1020  : 128 / (BK1Number * N0 * sizeof(BDataType));
1021  constexpr auto KThreadReadPerm =
1022  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1023  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1024  : KThreadRead;
1025 
1026  // 1<=npair<=n0
1027  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1028  ? 1
1029  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1030  ? N0
1031  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1032 
1033  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1037  Number<kfold * N0 / npair>{},
1038  Number<npair>{},
1039  BK1Number));
1040 
1041  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1042  b_lds_block_desc,
1043  make_tuple(
1047  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1050  make_tuple(
1052  make_tuple(
1054 
1055  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1056  b_lds_block_desc_permuted,
1057  make_tuple(
1065  Sequence<1>{},
1066  Sequence<2>{},
1067  Sequence<3>{},
1068  Sequence<4>{},
1069  Sequence<5>{}),
1071  Sequence<2>{},
1072  Sequence<0, 3>{},
1073  Sequence<4, 5>{},
1074  Sequence<6>{},
1075  Sequence<7>{}));
1076 
1077  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1078  b_lds_block_desc_unmerged,
1081  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1082  Number<kfold>{},
1089 
1090  return b_lds_block_desc_bk0_n_bk1;
1091  }
1092  }
1093 
1094  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
1095  {
1096  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1097  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1098 
1099  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1101  make_tuple(I1,
1103  I1,
1105 
1106  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1107  }
1108 
1111  BlkGemmPipelineVer,
1112  BlkGemmPipeSched,
1113  BlockSize,
1114  ScaleBlockSize,
1115  ADataType,
1116  AScaleDataType,
1117  BDataType,
1118  BScaleDataType,
1119  ComputeTypeA,
1120  AccDataType,
1127  ABlockTransferSrcScalarPerVector,
1128  BBlockTransferSrcScalarPerVector,
1129  MPerBlock,
1130  NPerBlock,
1131  KPerBlock,
1132  MPerXdl,
1133  NPerXdl,
1134  MXdlPerWave,
1135  NXdlPerWave,
1136  KPack,
1137  IsInputGemm>())>;
1138 
1139  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1140  {
1141  // LDS allocation for A and B: be careful of alignment
1142  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1143  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1144 
1145  // lds max alignment
1146  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1147 
1148  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1149  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1150 
1151  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1152  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1153 
1154  // LDS allocation for C shuffle in LDS
1155  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1157 
1158  constexpr auto c_block_size =
1159  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1160 
1161  if constexpr(IsInputGemm)
1162  {
1163  return math::max(a_block_space_size_aligned * sizeof(ADataType) +
1164  b_block_space_size_aligned * sizeof(BDataType) * 2,
1165  c_block_size * sizeof(CShuffleDataType));
1166  }
1167  else
1168  {
1169  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1170  b_block_space_size_aligned * sizeof(BDataType)),
1171  c_block_size * sizeof(CShuffleDataType));
1172  }
1173  }
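// Rough sizing example, assuming 1-byte A/B element types and
// a_block_space_size_aligned = b_block_space_size_aligned = 16384 elements: the gate+up
// path (IsInputGemm) needs max(16384 + 2 * 16384, c_block_size * sizeof(CShuffleDataType))
// >= 48 KiB of LDS, while the second GEMM needs max(16384 + 16384, ...) >= 32 KiB, since
// only the input GEMM stages two B tiles (gate and up) at once.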
1174 
1175  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1176  __host__ static constexpr bool CheckValidity(const Argument& karg)
1177  {
1178  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1179  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1180  "Invalid tuning param!");
1181 
1182  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1183  "KPerBlock should be multiple of ScaleBlockSize");
1184 
1185  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1190  {
1191  if(!(karg.M % MPerBlock == 0))
1192  {
1193  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1194  {
1195  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1196  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1197  << std::endl;
1198  }
1199  return false;
1200  }
1201  }
1202 
1203  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1208  {
1209  if(!(karg.N % NPerBlock == 0))
1210  {
1211  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1212  {
1213  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1214  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1215  << std::endl;
1216  }
1217  return false;
1218  }
1219  }
1220 
1221  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1225  {
1226  auto K_t = karg.KBatch * KPerBlock;
1227  if(!(karg.K % K_t == 0))
1228  {
1229  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1230  {
1231  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1232  << karg.K << " " << __FILE__ << ":" << __LINE__
1233  << ", in function: " << __func__ << std::endl;
1234  }
1235  return false;
1236  }
1237  }
1238  else
1239  {
1240  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1241  auto K_t = karg.KBatch * KReadVec;
1242  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1243  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1244  {
1245  return false;
1246  }
1247  }
1248 
1250  {
1251  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1252  {
1253  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1254  {
1255  std::cout << "Arg K (" << karg.K
1256  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1257  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1258  << __LINE__ << ", in function: " << __func__ << std::endl;
1259  }
1260  return false;
1261  }
1262  }
1263  else
1264  {
1265  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1266  {
1267  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1268  {
1269  std::cout << "Arg M (" << karg.M
1270  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1271  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1272  << __LINE__ << ", in function: " << __func__ << std::endl;
1273  }
1274  return false;
1275  }
1276  }
1277 
1279  {
1280  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1281  {
1282  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1283  {
1284  std::cout << "Arg N (" << karg.N
1285  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1286  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1287  << __LINE__ << ", in function: " << __func__ << std::endl;
1288  }
1289  return false;
1290  }
1291  }
1292  else
1293  {
1294  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1295  {
1296  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1297  {
1298  std::cout << "Arg K (" << karg.K
1299  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1300  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1301  << __LINE__ << ", in function: " << __func__ << std::endl;
1302  }
1303  return false;
1304  }
1305  }
1306 
1308  {
1310  {
1311  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1312  {
1313  std::cout << "Arg N (" << karg.N
1314  << ") value is not a multiple of "
1315  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1317  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1318  << std::endl;
1319  }
1320  return false;
1321  }
1322  }
1323  else
1324  {
1326  {
1327  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1328  {
1329  std::cout << "Arg M (" << karg.M
1330  << ") value is not a multiple of "
1331  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1333  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1334  << std::endl;
1335 
1336  }
1337  return false;
1338  }
1339  }
1340 
1341  // check gridwise gemm pipeline
1342 #if 0
1343  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1344 
1345  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1346  {
1347  return false;
1348  }
1349 #endif
1350  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1351  return true;
1352  }
1353 
1354  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1355  {
1356  const index_t num_loop = K / KPerBlock;
1357 
1358  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1359  }
1360 
1361  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1362  {
1363  const index_t num_loop = K / KPerBlock;
1364 
1365  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1366  }
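// Illustrative dispatch sketch (an assumed host-side wrapper, not part of this header):
// the two helpers above are typically evaluated on the host to select the kernel template
// instantiation before launch, e.g.
//
//   const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(k_per_split);
//   const auto tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(k_per_split);
//   // ...then branch over (has_main_loop, tail_num) to pick the matching
//   // kernel_moe_mxgemm_2lds<GridwiseGemm, HasMainKBlockLoop, ..., TailNum>
//   // instantiation, since both values are non-type template parameters.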
1367 
1368  template <typename CGridDesc>
1369  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1370  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1371  {
1372  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1373  c_grid_desc_m_n,
1378 
1379  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1380  }
1381 
1382  // return block_id to C matrix tile idx (m0, n0) mapping
1383  // if arch = gfx942
1384  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1385  // NPerBlock>;
1386 
1388  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1389  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1390  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1391  "A scale pack data type too large!");
1392  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1393  "B scale pack data type too large!");
1394 
1395  static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
1396  is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
1397  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1398 
1399 #if 0
1400  template <bool HasMainKBlockLoop,
1401  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1402  TailNumber TailNum = TailNumber::Odd>
1403  __device__ static void Run(const index_t* p_sorted_token_ids,
1404  const index_t* p_sorted_expert_ids,
1405  const index_t* p_max_token_id,
1406  const ADataType* p_a_grid,
1407  const AScaleDataType* p_a_scale_grid,
1408  const BDataType* p_b_grid,
1409  const BScaleDataType* p_b_scale_grid,
1410  DsGridPointer& p_ds_grid,
1411  CDataType* p_c_grid,
1412  void* p_shared,
1413  const Problem& problem,
1414  AElementwiseOperation a_element_op,
1415  BElementwiseOperation b_element_op,
1416  CElementwiseOperation c_element_op)
1417  {
1418  ignore = a_element_op;
1419  ignore = b_element_op;
1420  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1421  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1422  problem.MPadded,
1423  problem.K,
1424  problem.KPadded,
1425  problem.StrideA,
1426  problem.AK0);
1427  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1428  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1429  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1430  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1431  problem.MPadded,
1432  problem.N,
1433  problem.NPadded,
1434  problem.StrideC);
1435 
1436  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1437  make_tuple(problem.M / (MXdlPack * MPerXdl),
1438  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1439  (KXdlPack * 64 / MPerXdl),
1440  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1441 
1442  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1443  make_tuple(problem.N / (NXdlPack * NPerXdl),
1444  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1445  (KXdlPack * 64 / NPerXdl),
1446  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1447 
1448  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1450  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1451 
1452  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1453  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1454  if(expert_block_id * MPerBlock >= max_token_id)
1455  return;
1456  const index_t expert_id =
1457  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1458 
1459  const auto block_mn = [&]() -> std::pair<int, int> {
1460  if constexpr(NSwizzle)
1461  {
1462  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1463  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1464  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1465  const index_t expert_swizzle =
1466  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1467  const index_t bid_new = blockIdx.x - prefix_block;
1468  const index_t nid = __builtin_amdgcn_readfirstlane(
1469  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1470  const index_t mid =
1471  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1472  return {nid, mid};
1473  }
1474  else
1475  {
1476  return {blockIdx.x, blockIdx.y};
1477  }
1478  }();
1479 
1480  const index_t block_n_id = block_mn.first;
1481  const index_t block_m_id = block_mn.second;
1482  const index_t token0 =
1483  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1484 
1485  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1486  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1487  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1488  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1489  constexpr auto AKThreads = AK0Threads * AK1Threads;
1490  constexpr auto AMRepeats = MPerBlock / AMThreads;
1491  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1492 
1493  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1494  return;
1495  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1496  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1497  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1498  index_t token_offset = fused_token & 0xffffff;
1499  if constexpr(!IsInputGemm)
1500  {
1501  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1502  }
1503  gather_offsets(m0) = static_cast<IndexType>(token_offset);
1504  });
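// Example of the fused token-id decoding above: a sorted entry of 0x02000105 encodes
// token 0x000105 = 261 in the low 24 bits and top-k slot 2 in the high 8 bits. For the
// second GEMM (!IsInputGemm) the gather offset becomes 261 * TopK + 2, while the first
// GEMM gathers directly at token 261.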
1505 
1506  const index_t expert_stride =
1507  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1508  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1509  problem.N * (IsInputGemm ? 2 : 1) *
1510  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1511 
1512  // N0, K0, Blocksize*KPack
1513  const index_t n_block_data_idx_on_grid =
1514  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1515 
1516  // Grid buffer creation
1517  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1518  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1519  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1520  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1521 
1522  // A, B scale buffer
1523  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1524  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1525  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1526  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1527  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1528 
1529  // lds max alignment
1530  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1531 
1532  // A matrix in LDS memory, dst of blockwise copy
1533  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1534 
1535  // B matrix in LDS memory, dst of blockwise copy
1536  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1537 
1538  // A matrix blockwise direct to LDS copy
1539  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
1541  Sequence<AK0Number, MPerBlock, AK1Number>,
1542  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1543  ABlockTransferThreadClusterArrangeOrder,
1544  ADataType,
1545  ADataType,
1546  decltype(a_grid_desc_ak0_m_ak1),
1547  decltype(a_block_desc_ak0_m_ak1),
1548  ABlockTransferSrcAccessOrder,
1549  ABlockTransferSrcVectorDim,
1550  2,
1551  ABlockTransferSrcScalarPerVector,
1552  IndexType,
1553  1>(a_grid_desc_ak0_m_ak1,
1554  make_multi_index(0, 0, 0),
1555  a_block_desc_ak0_m_ak1,
1556  make_multi_index(0, 0, 0),
1557  gather_offsets);
1558 
1559  // B matrix blockwise copy
1560  auto b_blockwise_copy =
1561  ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
1562  Sequence<BK0Number, NPerBlock, BK1Number>,
1563  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1564  BBlockTransferThreadClusterArrangeOrder,
1565  BDataType,
1566  BDataType,
1567  decltype(b_grid_desc_bk0_n_bk1),
1568  decltype(b_block_desc_bk0_n_bk1),
1569  BBlockTransferSrcAccessOrder,
1570  BBlockTransferSrcVectorDim,
1571  2,
1572  BBlockTransferSrcScalarPerVector>(
1573  b_grid_desc_bk0_n_bk1,
1574  make_multi_index(0, n_block_data_idx_on_grid, 0),
1575  b_block_desc_bk0_n_bk1,
1576  make_multi_index(0, 0, 0));
1577 
1578  // LDS allocation for A and B: be careful of alignment
1579  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1580  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1581 
1582  // Cast after lds
1583  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1584  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1585 
1586  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1587  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1588  a_block_space_size_aligned * sizeof(ADataType)),
1589  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1590 
1591  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1592  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1593 
1594  // Blockwise GEMM pipeline
1595  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1596  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1597  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1598  decltype(c_thread_buf) c_thread_buf_up;
1599 
1600  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1601  float,
1602  c_thread_buf.num_of_v_,
1603  c_thread_buf.s_per_v,
1604  true>
1605  c_thread_buf_fp32;
1606 
1607  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1608  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1609  KPerBlock);
1610 
1611  // a and b scale processing
1612  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1613  const auto waveId_m = wave_idx[I0];
1614  const auto waveId_n = wave_idx[I1];
1615 
1616  auto thread_offset_shuffled =
1617  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1618 
1619  auto a_thread_offset_m = waveId_m;
1620 
1621  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1622  AScaleDataType,
1623  AScaleDataType,
1624  decltype(a_scale_grid_desc_am_ak),
1625  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1626  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1627  Sequence<0, 1, 2>, // DimAccessOrder
1628  2, // SrcVectorDim
1629  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1630  1, // SrcScalarStrideInVector
1631  true>(a_scale_grid_desc_am_ak,
1632  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1633  0,
1634  thread_offset_shuffled / scale_pack_size_a));
1635 
1636  // B scale load
1637  auto b_thread_offset_n = waveId_n;
1638 
1639  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1640  BScaleDataType,
1641  BScaleDataType,
1642  decltype(b_scale_grid_desc_bn_ak),
1643  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1644  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1645  Sequence<0, 1, 2>, // DimAccessOrder
1646  2, // SrcVectorDim
1647  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1648  1, // SrcScalarStrideInVector
1649  true>(b_scale_grid_desc_bn_ak,
1650  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1651  0,
1652  thread_offset_shuffled / scale_pack_size_b));
1653 
1654  if constexpr(IsInputGemm)
1655  {
1656  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1657  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1658  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1659  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1660  a_block_space_size_aligned * sizeof(ADataType) +
1661  b_block_space_size_aligned * sizeof(BDataType)),
1662  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1663 
1664  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1665  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1666  p_b_grid_up + expert_id * expert_stride,
1667  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1668 
1669  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
1671  Sequence<BK0Number, NPerBlock, BK1Number>,
1672  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1673  BBlockTransferThreadClusterArrangeOrder,
1674  BDataType,
1675  BDataType,
1676  decltype(b_grid_desc_bk0_n_bk1),
1677  decltype(b_block_desc_bk0_n_bk1),
1678  BBlockTransferSrcAccessOrder,
1679  BBlockTransferSrcVectorDim,
1680  2,
1681  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
1682  make_multi_index(0, n_block_data_idx_on_grid, 0),
1683  b_block_desc_bk0_n_bk1,
1684  make_multi_index(0, 0, 0));
1685 
1686  const BScaleDataType* p_b_scale_grid_up =
1687  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1688  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1689  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1690  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1691 
1692  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1693  BScaleDataType,
1694  BScaleDataType,
1695  decltype(b_scale_grid_desc_bn_ak),
1696  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1697  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1698  Sequence<0, 1, 2>, // DimAccessOrder
1699  2, // SrcVectorDim
1700  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1701  1, // SrcScalarStrideInVector
1702  true>(
1703  b_scale_grid_desc_bn_ak,
1704  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1705  0,
1706  thread_offset_shuffled / scale_pack_size_b));
1707 
1708  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1709  // A
1710  a_grid_desc_ak0_m_ak1,
1711  a_block_desc_ak0_m_ak1,
1712  a_blockwise_copy,
1713  a_grid_buf,
1714  a_block_buf,
1715  a_block_slice_copy_step,
1716  // Gate and Up
1717  b_grid_desc_bk0_n_bk1,
1718  b_block_desc_bk0_n_bk1,
1719  b_blockwise_copy,
1720  b_blockwise_copy_up,
1721  b_grid_buf,
1722  b_grid_buf_up,
1723  b_block_buf,
1724  b_block_buf_up,
1725  b_block_slice_copy_step,
1726  // C
1727  c_thread_buf,
1728  c_thread_buf_up,
1729  // A scale
1730  a_scale_grid_desc_am_ak,
1731  a_scale_thread_copy,
1732  a_scale_grid_buf,
1733  // Gate and Up scale
1734  b_scale_grid_desc_bn_ak,
1735  b_scale_thread_copy,
1736  b_scale_thread_copy_up,
1737  b_scale_grid_buf,
1738  b_scale_grid_buf_up,
1739  num_k_block_main_loop);
1740  }
1741  else
1742  {
1743  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1744  a_grid_desc_ak0_m_ak1, // A
1745  a_block_desc_ak0_m_ak1,
1746  a_blockwise_copy,
1747  a_grid_buf,
1748  a_block_buf,
1749  a_block_slice_copy_step,
1750  b_grid_desc_bk0_n_bk1, // B
1751  b_block_desc_bk0_n_bk1,
1752  b_blockwise_copy,
1753  b_grid_buf,
1754  b_block_buf,
1755  b_block_slice_copy_step,
1756  c_thread_buf, // C
1757  a_scale_grid_desc_am_ak, // A scale
1758  a_scale_thread_copy,
1759  a_scale_grid_buf,
1760  b_scale_grid_desc_bn_ak, // B scale
1761  b_scale_thread_copy,
1762  b_scale_grid_buf,
1763  num_k_block_main_loop);
1764  }
1765 
1766  // shuffle C and write out
1767  {
1768  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1769  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1770  "wrong!");
1771  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1772  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1773  "wrong!");
1774 
1775  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1776  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1777 
1778  // TODO: hacky, fix it!
1779  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1780  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1781 
1782  // TODO: hacky, fix it!
1783  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1784  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1785  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1786 
1787  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1788  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1789  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1790  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1791  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1792  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1793  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1794  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1795  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1796  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1797 
1798  // mul scales
1799  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1800  static_assert(M5 == 4);
1801  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1802  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1803 
1804  vector_type<float, 4> topk_weights; // for gemm2 only
1805  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1806  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1807  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1808  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1809  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1810  const index_t m_pos = block_m_id * MPerBlock +
1811  m0 * M2 * M1 * M3 * M4 * M5 +
1812  m1 * M2 * M3 * M4 * M5 +
1813  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1814  if constexpr(MulRoutedWeight)
1815  {
1816  topk_weights =
1817  *c_style_pointer_cast<const vector_type<float, M5>*>(
1818  p_ds_grid[I2] + m_pos);
1819  }
1820  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1821  constexpr index_t c_offset =
1822  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1823  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1824  constexpr auto cidx = Number<c_offset>{};
1825 
1826  if constexpr(IsInputGemm) // gate & up fusion
1827  {
1828  if constexpr(ActivationOperation ==
1830  {
1831  float gate = c_thread_buf[cidx];
1832  float up = c_thread_buf_up[cidx];
1833  if constexpr(MulRoutedWeight)
1834  {
1835  gate = gate * topk_weights.AsType<float>()[m5];
1836  up = up * topk_weights.AsType<float>()[m5];
1837  }
1838  tensor_operation::element_wise::Silu{}(gate, gate);
1839  c_thread_buf_fp32(cidx) = gate * up;
1840  }
1841  else if(ActivationOperation == Activation::gelu_and_mul)
1842  {
1843  float gate = c_thread_buf[cidx];
1844  float up = c_thread_buf_up[cidx];
1845  if constexpr(MulRoutedWeight)
1846  {
1847  gate = gate * topk_weights.AsType<float>()[m5];
1848  up = up * topk_weights.AsType<float>()[m5];
1849  }
1850  tensor_operation::element_wise::Gelu{}(gate, gate);
1851  c_thread_buf_fp32(cidx) = gate * up;
1852 
1853  /*float gate = c_thread_buf[cidx];
1854  float up = c_thread_buf_up[cidx];
1855  if constexpr(MulRoutedWeight)
1856  {
1857  gate = gate * topk_weights.AsType<float>()[m5];
1858  //up = up * topk_weights.AsType<float>()[m5];
1859  }
1860  tensor_operation::element_wise::Gelu{}(gate, gate);
1861  c_thread_buf_fp32(cidx) = up;*/
1862  }
1863  }
1864  else
1865  {
1866  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1867  if constexpr(MulRoutedWeight)
1868  {
1869  c_thread_buf_fp32(cidx) =
1870  topk_weights.AsType<float>()[m5] *
1871  c_thread_buf_fp32[cidx];
1872  }
1873  }
1874  });
1875  });
1876  });
1877  });
1878  });
1879  });
1880 
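// Epilogue: accumulators are staged through LDS in CShuffle tiles, then
// scattered per sorted token to the output tensor in global memory.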
1881  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1882  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
1883 
1884  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1885  static_cast<CShuffleDataType*>(p_shared),
1886  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1887 
1888  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1889  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1890  make_tuple(
1891  make_freeze_transform(I0),
1892  make_unmerge_transform(make_tuple(
1893  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave)
1894  // per shuffle
1895  M1, // M1 = MWave
1896  M2, // M2 = MXdlPack
1897  M3, // M3 * M4 * M5 = MPerXdl
1898  M4,
1899  M5)),
1900  make_freeze_transform(I0),
1901  make_unmerge_transform(make_tuple(
1902  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
1903  // per shuffle
1904  N1, // N1 = NWave
1905  N2, // N2 = NXdlPack
1906  N3))), // N3 = NPerXdl
1907  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1908  make_tuple(Sequence<>{},
1909  Sequence<0, 2, 4, 6, 7, 8>{},
1910  Sequence<>{},
1911  Sequence<1, 3, 5, 9>{}));
1912 
1913  // calculate origin of thread output tensor on global memory
1914  // blockwise GEMM c matrix starting index
1915  const auto c_thread_mtx_on_block =
1916  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1917 
1918  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1919  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1920 
1921  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1922  make_single_stage_tensor_adaptor(
1923  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1924  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
1925  make_tuple(Sequence<0>{}));
1926 
1927  const auto m_thread_data_on_block_idx =
1928  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1929  make_multi_index(m_thread_data_on_block));
1930 
1931  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1932  make_single_stage_tensor_adaptor(
1933  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1934  make_tuple(Sequence<0, 1, 2, 3>{}),
1935  make_tuple(Sequence<0>{}));
1936 
1937  const auto n_thread_data_on_block_idx =
1938  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1939  make_multi_index(n_thread_data_on_block));
1940 
1941  // shuffle: threadwise copy C from VGPR to LDS
1942  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1943  AccDataType,
1944  CShuffleDataType,
1945  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1946  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1948  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1949  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1950  I1,
1951  I1,
1952  M2,
1953  N2,
1954  M3,
1955  I1,
1956  M5,
1957  I1>,
1958  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1959  9,
1960  1,
1962  1,
1963  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1964  make_multi_index(0,
1965  0,
1966  m_thread_data_on_block_idx[I1],
1967  n_thread_data_on_block_idx[I1],
1968  m_thread_data_on_block_idx[I2],
1969  n_thread_data_on_block_idx[I2],
1970  m_thread_data_on_block_idx[I3],
1971  m_thread_data_on_block_idx[I4],
1972  m_thread_data_on_block_idx[I5],
1973  n_thread_data_on_block_idx[I3]),
1975 
1976  using EDataType = CDataType;
1977 
1978  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1979  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1980 
1981  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1983  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1984 
1985  const auto ds_grid_buf = generate_tuple(
1986  [&](auto i) {
1987  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1988  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1989  },
1990  Number<NumDTensor>{});
1991 
1992  // tuple of reference to C/Ds tensor descriptors
1993  const auto c_ds_desc_refs = concat_tuple_of_reference(
1994  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1995  generate_tie([&](auto i) -> const auto& // return type should be reference
1996  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1997  Number<NumDTensor>{}));
1998 
1999  // tuple of reference to C/Ds tensor descriptors
2000  const auto c_ds_buf_refs = concat_tuple_of_reference(
2001  tie(c_shuffle_block_buf),
2002  generate_tie([&](auto i) -> const auto& // return type should be reference
2003  { return ds_grid_buf[i]; },
2004  Number<NumDTensor>{}));
2005 
2006  // tuple of starting index of C/Ds blockwise copy
2007  const auto idx_c_ds_block_begin =
2010  [&](auto) {
2011  return make_multi_index(block_m_id, 0, block_n_id, 0);
2012  // return make_multi_index(block_work_idx[I0], 0,
2013  // block_work_idx[I1], 0);
2014  },
2015  Number<NumDTensor>{}));
2016 
2017  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2018  c_grid_desc_mblock_mperblock_nblock_nperblock;
2019 
2020  using CDEBlockTransferCluster =
2021  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2022  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2023  constexpr index_t scatter_weight_idx = 3; // FIXME(felix): hard-coded scatter-weight index
2024  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2025  ThisThreadBlock,
2026  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2027  Tuple<EDataType>,
2028  decltype(c_ds_desc_refs),
2029  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2030  CElementwiseOperation,
2031  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2032  // Sequence support
2033  // arbitrary type
2034  Sequence<1,
2035  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2036  1,
2037  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2038  CDEBlockTransferCluster,
2039  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2040  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2041  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2042  3, // index_t SrcVectorDim,
2043  3, // index_t DstVectorDim,
2044  CDEShuffleBlockTransferScalarPerVectors,
2047  Sequence<true>,
2049  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2050  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2051  IndexType,
2052  1, // ScatterDim
2053  true, // OutputScatter: false, only use scatter weights
2054  scatter_weight_idx // ScatterWeightIdx: ascale
2055  >{c_ds_desc_refs,
2056  idx_c_ds_block_begin,
2057  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2058  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2059  c_element_op};
2060 
2061  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2062  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2063 
2064  constexpr auto sfc_c_vgpr =
2065  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2066  NXdlPerWave / NXdlPack,
2067  1,
2068  1,
2069  MXdlPack,
2070  NXdlPack,
2071  M2,
2072  1,
2073  M4,
2074  1>,
2075  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2076  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2077  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2078  1,
2079  1,
2080  MXdlPack,
2081  NXdlPack,
2082  M2,
2083  1,
2084  M4,
2085  1>>{};
2086 
2087  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2088 
2089  // space filling curve for shuffled blockwise C/D/E
2090  constexpr auto sfc_cde_block =
2091  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2092  Sequence<0, 2, 1, 3>,
2093  Sequence<1,
2094  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2095  1,
2096  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2097 
2098  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2099  constexpr auto EMThreads =
2100  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2101  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2102  constexpr auto ENThreads =
2103  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2104  static_for<0, num_access, 1>{}([&](auto access_id) {
2105  // make sure it's safe to write to LDS
2106  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2107 
2108  auto dstidx = sfc_cde_block.GetIndex(access_id);
2109  const index_t c_token_pos =
2110  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
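// Each sorted-token entry packs the token index in its low 24 bits and the
// top-k slot in its high 8 bits; the scatter offset is the destination row
// (token, or token * TopK + slot for the input GEMM) times the row stride N.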
2111  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2112  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2113  IndexType token_offset = fused_token & 0xffffff;
2114  if constexpr(IsInputGemm)
2115  {
2116  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2117  }
2118  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2119  });
2120 
2121  block_sync_lds();
2122 
2123  // each thread write its data from VGPR to LDS
2124  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2125  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2126  c_thread_buf_fp32,
2127  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2128  c_shuffle_block_buf);
2129 
2130  // make sure it's safe to read from LDS
2131  block_sync_lds();
2132 
2133  // each block copy its data from LDS to global
2134  cde_block_copy_lds_and_global.Run(
2135  c_ds_desc_refs,
2136  c_ds_buf_refs,
2137  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2138  tie(c_grid_buf),
2139  scatter_offsets);
2140 
2141  if constexpr(access_id < num_access - 1)
2142  {
2143  constexpr auto cde_lds_and_global_step =
2144  sfc_cde_block.GetForwardStep(access_id);
2145 
2146  // move on Ds
2147  static_for<0, NumDTensor, 1>{}([&](auto i) {
2148  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2149  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2150  });
2151 
2152  // move on E
2153  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2154  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2155  I0,
2156  cde_lds_and_global_step);
2157  }
2158  });
2159  }
2160  }
2161 #endif
2162 
2163  template <bool HasMainKBlockLoop,
2164  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2165  TailNumber TailNum = TailNumber::Odd>
2166  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2167  const index_t* p_sorted_expert_ids,
2168  const index_t* p_max_token_id,
2169  const ADataType* p_a_grid,
2170  const AScaleDataType* p_a_scale_grid,
2171  const BDataType* p_b_grid,
2172  const BScaleDataType* p_b_scale_grid,
2173  DsGridPointer& p_ds_grid,
2174  CDataType* p_c_grid,
2175  void* p_shared_0,
2176  void* p_shared_1,
2177  const Problem& problem,
2178  AElementwiseOperation a_element_op,
2179  BElementwiseOperation b_element_op,
2180  CElementwiseOperation c_element_op)
2181  {
2182  ignore = a_element_op;
2183  ignore = b_element_op;
2184  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2185  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2186  problem.MPadded,
2187  problem.K,
2188  problem.KPadded,
2189  problem.StrideA,
2190  problem.AK0);
2191  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2192  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2193  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2194  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2195  problem.MPadded,
2196  problem.N,
2197  problem.NPadded,
2198  problem.StrideC);
2199 
2200  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2201  make_tuple(problem.M / (MXdlPack * MPerXdl),
2202  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2203  (KXdlPack * 64 / MPerXdl),
2204  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2205 
2206  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2207  make_tuple(problem.N / (NXdlPack * NPerXdl),
2208  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2209  (KXdlPack * 64 / NPerXdl),
2210  64 * KXdlPack * NXdlPack / scale_pack_size_b));
2211 
2212  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2214  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2215 
2216  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2217  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2218  if(expert_block_id * MPerBlock >= max_token_id)
2219  return;
2220  const index_t expert_id =
2221  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2222  const auto block_mn = [&]() -> std::pair<int, int> {
2223  if constexpr(NSwizzle)
2224  {
2225  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2226  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2227  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2228  const index_t expert_swizzle =
2229  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2230  const index_t bid_new = blockIdx.x - prefix_block;
2231  const index_t nid = __builtin_amdgcn_readfirstlane(
2232  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2233  const index_t mid =
2234  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2235  return {nid, mid};
2236  }
2237  else
2238  {
2239  return {blockIdx.x, blockIdx.y};
2240  }
2241  }();
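// With NSwizzle, the flat block id is remapped so that groups of 8
// consecutive N-blocks are walked across all of this expert's M-blocks
// before advancing to the next group, presumably to improve reuse of B
// tiles between neighbouring workgroups.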
2242 
2243  const index_t block_n_id = block_mn.first;
2244  const index_t block_m_id = block_mn.second;
2245  const index_t token0 =
2246  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2247 
2248  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2249  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2250  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2251  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2252  constexpr auto AKThreads = AK0Threads * AK1Threads;
2253  constexpr auto AMRepeats = MPerBlock / AMThreads;
2254  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2255 
2256  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2257  return;
2258  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2259  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2260  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2261  index_t token_offset = fused_token & 0xffffff;
2262  if constexpr(!IsInputGemm)
2263  {
2264  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2265  }
2266  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2267  });
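// gather_offsets maps each sorted (fused) token id to a source row of A;
// for the down GEMM (!IsInputGemm) the rows are the per-(token, top-k slot)
// outputs of the first GEMM, hence the token * TopK + slot indexing.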
2268 
2269  const index_t expert_stride =
2270  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2271  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2272  problem.N * (IsInputGemm ? 2 : 1) *
2273  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
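// Every expert owns an N x K weight slab (doubled when gate and up are
// fused in the input GEMM) plus a matching slab of MX block scales;
// expert_id selects the slab through these strides.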
2274 
2275  // N0, K0, Blocksize*KPack
2276  const index_t n_block_data_idx_on_grid =
2277  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2278 
2279  // Grid buffer creation
2280  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2281  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2282  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2283  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2284 
2285  // A, B scale buffer
2286  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2287  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2288  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2289  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2290  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2291 
2292  // lds max alignment
2293  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2294 
2295  // A matrix in LDS memory, dst of blockwise copy
2296  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2297 
2298  // B matrix in LDS memory, dst of blockwise copy
2299  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2300 
2301  // A matrix blockwise direct to LDS copy
2302  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2305  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2306  ABlockTransferThreadClusterArrangeOrder,
2307  ADataType,
2308  ADataType,
2309  decltype(a_grid_desc_ak0_m_ak1),
2310  decltype(a_block_desc_ak0_m_ak1),
2311  ABlockTransferSrcAccessOrder,
2312  ABlockTransferSrcVectorDim,
2313  2,
2314  ABlockTransferSrcScalarPerVector,
2315  IndexType,
2316  1>(a_grid_desc_ak0_m_ak1,
2317  make_multi_index(0, 0, 0),
2318  a_block_desc_ak0_m_ak1,
2319  make_multi_index(0, 0, 0),
2320  gather_offsets);
2321 
2322  // B matrix blockwise copy
2323  auto b_blockwise_copy =
2326  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2327  BBlockTransferThreadClusterArrangeOrder,
2328  BDataType,
2329  BDataType,
2330  decltype(b_grid_desc_bk0_n_bk1),
2331  decltype(b_block_desc_bk0_n_bk1),
2332  BBlockTransferSrcAccessOrder,
2333  BBlockTransferSrcVectorDim,
2334  2,
2335  BBlockTransferSrcScalarPerVector>(
2336  b_grid_desc_bk0_n_bk1,
2337  make_multi_index(0, n_block_data_idx_on_grid, 0),
2338  b_block_desc_bk0_n_bk1,
2339  make_multi_index(0, 0, 0));
2340 
2341  // LDS allocation for A and B: be careful of alignment
2342  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2343  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2344 
2345  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2346  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2347 
2348  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2349  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2350  a_block_space_size_aligned * sizeof(ADataType)),
2351  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2352 
2353  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2354  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2355 
2356  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2357  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2358  a_block_space_size_aligned * sizeof(ADataType)),
2359  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2360 
2361  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2362  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
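// Ping/pong LDS tiles: A and B each get one tile in p_shared_0 and one in
// p_shared_1 so the 2-LDS pipeline can overlap the global->LDS copy of the
// next K tile with the MFMA work on the current one.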
2363 
2364  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2365  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2366 
2367  // Blockwise GEMM pipeline
2368  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2369  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2370  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2371  decltype(c_thread_buf) c_thread_buf_up;
2372 
2373  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2374  float,
2375  c_thread_buf.num_of_v_,
2376  c_thread_buf.s_per_v,
2377  true>
2378  c_thread_buf_fp32;
2379 
2380  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2381  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2382  KPerBlock);
2383 
2384  // a and b scale processing
2385  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2386  const auto waveId_m = wave_idx[I0];
2387  const auto waveId_n = wave_idx[I1];
2388 
2389  auto thread_offset_shuffled =
2390  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2391 
2392  auto a_thread_offset_m = waveId_m;
2393 
2394  // get each thread's offset in the scale tensor
2395  const index_t token_scale_pos = block_m_id * MPerBlock;
2396  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2397  return;
2398 
2399  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2400  AScaleDataType,
2401  AScaleDataType,
2402  decltype(a_scale_grid_desc_am_ak),
2403  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2404  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2405  Sequence<0, 1, 2>, // DimAccessOrder
2406  2, // SrcVectorDim
2407  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2408  1, // SrcScalarStrideInVector
2409  true>(a_scale_grid_desc_am_ak,
2410  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2411  0,
2412  thread_offset_shuffled / scale_pack_size_a));
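// The MX block scales appear to be stored pre-shuffled: each lane fetches
// the KXdlPack x MXdlPack scales it needs with a single vectorized read,
// scale_pack_size_a presumably being the number of e8m0 scales packed into
// one AScaleDataType element.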
2413 
2414  // B scale load
2415  auto b_thread_offset_n = waveId_n;
2416 
2417  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2418  BScaleDataType,
2419  BScaleDataType,
2420  decltype(b_scale_grid_desc_bn_ak),
2421  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2422  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2423  Sequence<0, 1, 2>, // DimAccessOrder
2424  2, // SrcVectorDim
2425  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2426  1, // SrcScalarStrideInVector
2427  true>(b_scale_grid_desc_bn_ak,
2428  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2429  0,
2430  thread_offset_shuffled / scale_pack_size_b));
2431 
2432  if constexpr(IsInputGemm)
2433  {
2434  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2435  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2436  p_b_grid_up + expert_id * expert_stride,
2437  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
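// For the fused input GEMM the up-projection weights occupy the second half
// of each expert's weight slab, so the up pointer starts expert_stride / 2
// past the gate weights (and expert_scale_stride / 2 for its scales below).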
2438 
2439  // lds ping pong buffers for up
2440  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
2441  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
2442  auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2443  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2444  a_block_space_size_aligned * sizeof(ADataType) +
2445  b_block_space_size_aligned * sizeof(BDataType)),
2446  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2447  auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2448  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2449  a_block_space_size_aligned * sizeof(ADataType) +
2450  b_block_space_size_aligned * sizeof(BDataType)),
2451  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2452 
2453  auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);
2454 
2455  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
2458  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2459  BBlockTransferThreadClusterArrangeOrder,
2460  BDataType,
2461  BDataType,
2462  decltype(b_grid_desc_bk0_n_bk1),
2463  decltype(b_block_desc_bk0_n_bk1),
2464  BBlockTransferSrcAccessOrder,
2465  BBlockTransferSrcVectorDim,
2466  2,
2467  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
2468  make_multi_index(0, n_block_data_idx_on_grid, 0),
2469  b_block_desc_bk0_n_bk1,
2470  make_multi_index(0, 0, 0));
2471 
2472  const BScaleDataType* p_b_scale_grid_up =
2473  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2474  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2475  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2476  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2477 
2478  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2479  BScaleDataType,
2480  BScaleDataType,
2481  decltype(b_scale_grid_desc_bn_ak),
2482  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2483  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2484  Sequence<0, 1, 2>, // DimAccessOrder
2485  2, // SrcVectorDim
2486  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2487  1, // SrcScalarStrideInVector
2488  true>(
2489  b_scale_grid_desc_bn_ak,
2490  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2491  0,
2492  thread_offset_shuffled / scale_pack_size_b));
2493 
2494  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2495  // A
2496  a_grid_desc_ak0_m_ak1,
2497  a_block_desc_ak0_m_ak1,
2498  a_blockwise_copy,
2499  a_grid_buf,
2500  a_block_bufs,
2501  a_block_slice_copy_step,
2502  // Gate and Up
2503  b_grid_desc_bk0_n_bk1,
2504  b_block_desc_bk0_n_bk1,
2505  b_blockwise_copy,
2506  b_blockwise_copy_up,
2507  b_grid_buf,
2508  b_grid_buf_up,
2509  b_block_bufs,
2510  b_block_bufs_up,
2511  b_block_slice_copy_step,
2512  // C
2513  c_thread_buf,
2514  c_thread_buf_up,
2515  // A scale
2516  a_scale_grid_desc_am_ak,
2517  a_scale_thread_copy,
2518  a_scale_grid_buf,
2519  // B scale
2520  b_scale_grid_desc_bn_ak,
2521  b_scale_thread_copy,
2522  b_scale_thread_copy_up,
2523  b_scale_grid_buf,
2524  b_scale_grid_buf_up,
2525  num_k_block_main_loop);
2526  }
2527  else
2528  {
2529  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2530  a_grid_desc_ak0_m_ak1, // A
2531  a_block_desc_ak0_m_ak1,
2532  a_blockwise_copy,
2533  a_grid_buf,
2534  a_block_bufs,
2535  a_block_slice_copy_step,
2536  b_grid_desc_bk0_n_bk1, // B
2537  b_block_desc_bk0_n_bk1,
2538  b_blockwise_copy,
2539  b_grid_buf,
2540  b_block_bufs,
2541  b_block_slice_copy_step,
2542  c_thread_buf, // C
2543  a_scale_grid_desc_am_ak, // A scale
2544  a_scale_thread_copy,
2545  a_scale_grid_buf,
2546  b_scale_grid_desc_bn_ak, // B scale
2547  b_scale_thread_copy,
2548  b_scale_grid_buf,
2549  num_k_block_main_loop);
2550  }
2551 
2552  // shuffle C and write out
2553  {
2554  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2555  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2556  "wrong!");
2557  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2558  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2559  "wrong!");
2560 
2561  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2562  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2563 
2564  // TODO: hacky, fix it!
2565  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2566  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2567 
2568  // TODO: hacky, fix it!
2569  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2570  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2571  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2572 
2573  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2574  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2575  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2576  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2577  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2578  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2579  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2580  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2581  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2582  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2583 
2584  // mul scales
2585 
2586  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2587  static_assert(M5 == 4);
2588  const index_t m1 = get_warp_local_1d_id() / NWave;
2589  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2590 
2591  vector_type<float, 4> topk_weights; // for gemm2 only
2592  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2593  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2594  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2595  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2596  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2597  const index_t m_pos = block_m_id * MPerBlock +
2598  m0 * M2 * M1 * M3 * M4 * M5 +
2599  m1 * M2 * M3 * M4 * M5 +
2600  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2601  if constexpr(MulRoutedWeight)
2602  {
2603  topk_weights =
2604  *c_style_pointer_cast<const vector_type<float, M5>*>(
2605  p_ds_grid[I2] + m_pos);
2606  }
2607  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2608  constexpr index_t c_offset =
2609  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2610  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2611  constexpr auto cidx = Number<c_offset>{};
2612 
2613  if constexpr(IsInputGemm) // gu fusion
2614  {
2615  if constexpr(ActivationOperation ==
2616  Activation::silu_and_mul)
2617  {
2618  float gate = c_thread_buf[cidx];
2619  float up = c_thread_buf_up[cidx];
2620  if constexpr(MulRoutedWeight)
2621  {
2622  gate = gate * topk_weights.AsType<float>()[m5];
2623  up = up * topk_weights.AsType<float>()[m5];
2624  }
2625  tensor_operation::element_wise::Silu{}(gate, gate);
2626  c_thread_buf_fp32(cidx) = gate * up;
2627  }
2628  else if(ActivationOperation == Activation::gelu_and_mul)
2629  {
2630  float gate = c_thread_buf[cidx];
2631  float up = c_thread_buf_up[cidx];
2632  if constexpr(MulRoutedWeight)
2633  {
2634  gate = gate * topk_weights.AsType<float>()[m5];
2635  up = up * topk_weights.AsType<float>()[m5];
2636  }
2637  tensor_operation::element_wise::Gelu{}(gate, gate);
2638  c_thread_buf_fp32(cidx) = gate * up;
2639  }
2640  }
2641  else
2642  {
2643  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2644  if constexpr(MulRoutedWeight)
2645  {
2646  c_thread_buf_fp32(cidx) =
2647  topk_weights.AsType<float>()[m5] *
2648  c_thread_buf_fp32[cidx];
2649  }
2650  }
2651  });
2652  });
2653  });
2654  });
2655  });
2656  });
2657 
2658  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2659  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2660 
2661  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2662  static_cast<CShuffleDataType*>(p_shared_0),
2663  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2664 
2665  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2666  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2667  make_tuple(
2668  make_freeze_transform(I0),
2669  make_unmerge_transform(make_tuple(
2670  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2671  // shuffle
2672  M1, // M1 = MWave
2673  M2, // M2 = MXdlPack
2674  M3, // M3 * M4 * M5 = MPerXdl
2675  M4,
2676  M5)),
2677  make_freeze_transform(I0),
2678  make_unmerge_transform(make_tuple(
2679  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
2680  // per shuffle
2681  N1, // N1 = NWave
2682  N2, // N2 = NXdlPack
2683  N3))), // N3 = NPerXdl
2684  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2685  make_tuple(Sequence<>{},
2686  Sequence<0, 2, 4, 6, 7, 8>{},
2687  Sequence<>{},
2688  Sequence<1, 3, 5, 9>{}));
2689 
2690  // calculate origin of thread output tensor on global memory
2691  // blockwise GEMM c matrix starting index
2692  const auto c_thread_mtx_on_block =
2693  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2694 
2695  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2696  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2697 
2698  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2699  make_single_stage_tensor_adaptor(
2700  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2701  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2702  make_tuple(Sequence<0>{}));
2703 
2704  const auto m_thread_data_on_block_idx =
2705  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2706  make_multi_index(m_thread_data_on_block));
2707 
2708  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2709  make_single_stage_tensor_adaptor(
2710  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2711  make_tuple(Sequence<0, 1, 2, 3>{}),
2712  make_tuple(Sequence<0>{}));
2713 
2714  const auto n_thread_data_on_block_idx =
2715  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2716  make_multi_index(n_thread_data_on_block));
2717 
2718  // shuffle: threadwise copy C from VGPR to LDS
2719  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2720  AccDataType,
2721  CShuffleDataType,
2722  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2723  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2725  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2726  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2727  I1,
2728  I1,
2729  M2,
2730  N2,
2731  M3,
2732  I1,
2733  M5,
2734  I1>,
2736  9,
2737  1,
2739  1,
2740  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2741  make_multi_index(0,
2742  0,
2743  m_thread_data_on_block_idx[I1],
2744  n_thread_data_on_block_idx[I1],
2745  m_thread_data_on_block_idx[I2],
2746  n_thread_data_on_block_idx[I2],
2747  m_thread_data_on_block_idx[I3],
2748  m_thread_data_on_block_idx[I4],
2749  m_thread_data_on_block_idx[I5],
2750  n_thread_data_on_block_idx[I3]),
2752 
2753  using EDataType = CDataType;
2754 
2755  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2756  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2757 
2758  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2760  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2761 
2762  const auto ds_grid_buf = generate_tuple(
2763  [&](auto i) {
2764  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2765  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2766  },
2767  Number<NumDTensor>{});
2768 
2769  // tuple of reference to C/Ds tensor descriptors
2770  const auto c_ds_desc_refs = concat_tuple_of_reference(
2771  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2772  generate_tie(
2773  [&](auto i) -> const auto& // return type should be reference
2774  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2775  Number<NumDTensor>{}));
2776 
2777  // tuple of reference to C/Ds tensor descriptors
2778  const auto c_ds_buf_refs = concat_tuple_of_reference(
2779  tie(c_shuffle_block_buf),
2780  generate_tie(
2781  [&](auto i) -> const auto& // return type should be reference
2782  { return ds_grid_buf[i]; },
2783  Number<NumDTensor>{}));
2784 
2785  // tuple of starting index of C/Ds blockwise copy
2786  const auto idx_c_ds_block_begin =
2789  [&](auto) {
2790  return make_multi_index(block_m_id, 0, block_n_id, 0);
2791  // return make_multi_index(block_work_idx[I0], 0,
2792  // block_work_idx[I1], 0);
2793  },
2794  Number<NumDTensor>{}));
2795 
2796  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2797  c_grid_desc_mblock_mperblock_nblock_nperblock;
2798 
2799  using CDEBlockTransferCluster =
2800  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2801  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2802  constexpr index_t scatter_weight_idx = 3; // FIXME(felix): hard-coded scatter-weight index
2803  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2804  ThisThreadBlock,
2805  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2806  Tuple<EDataType>,
2807  decltype(c_ds_desc_refs),
2808  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2809  CElementwiseOperation,
2810  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2811  // Sequence support
2812  // arbitrary type
2813  Sequence<1,
2814  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2815  1,
2816  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2817  CDEBlockTransferCluster,
2818  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2819  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2820  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2821  3, // index_t SrcVectorDim,
2822  3, // index_t DstVectorDim,
2823  CDEShuffleBlockTransferScalarPerVectors,
2828  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2829  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2830  IndexType,
2831  1, // ScatterDim
2832  true, // OutputScatter: false, only use scatter weights
2833  scatter_weight_idx // ScatterWeightIdx: ascale
2834  >{c_ds_desc_refs,
2835  idx_c_ds_block_begin,
2836  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2837  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2838  c_element_op};
2839 
2840  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2841  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2842 
2843  constexpr auto sfc_c_vgpr =
2844  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2845  NXdlPerWave / NXdlPack,
2846  1,
2847  1,
2848  MXdlPack,
2849  NXdlPack,
2850  M2,
2851  1,
2852  M4,
2853  1>,
2854  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2855  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2856  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2857  1,
2858  1,
2859  MXdlPack,
2860  NXdlPack,
2861  M2,
2862  1,
2863  M4,
2864  1>>{};
2865 
2866  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2867 
2868  // space filling curve for shuffled blockwise C/D/E
2869  constexpr auto sfc_cde_block =
2870  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2871  Sequence<0, 2, 1, 3>,
2872  Sequence<1,
2873  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2874  1,
2875  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2876 
2877  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2878  constexpr auto EMThreads =
2879  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2880  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2881  constexpr auto ENThreads =
2882  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2883  static_for<0, num_access, 1>{}([&](auto access_id) {
2884  // make sure it's safe to write to LDS
2885  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2886 
2887  auto dstidx = sfc_cde_block.GetIndex(access_id);
2888  const index_t c_token_pos =
2889  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2890  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2891  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2892  IndexType token_offset = fused_token & 0xffffff;
2893  if constexpr(IsInputGemm)
2894  {
2895  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2896  }
2897  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2898  });
2899 
2900  block_sync_lds();
2901 
2902  // each thread write its data from VGPR to LDS
2903  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2904  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2905  c_thread_buf_fp32,
2906  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2907  c_shuffle_block_buf);
2908 
2909  // make sure it's safe to read from LDS
2910  block_sync_lds();
2911 
2912  // each block copy its data from LDS to global
2913  cde_block_copy_lds_and_global.Run(
2914  c_ds_desc_refs,
2915  c_ds_buf_refs,
2916  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2917  tie(c_grid_buf),
2918  scatter_offsets);
2919 
2920  if constexpr(access_id < num_access - 1)
2921  {
2922  constexpr auto cde_lds_and_global_step =
2923  sfc_cde_block.GetForwardStep(access_id);
2924 
2925  // move on Ds
2926  static_for<0, NumDTensor, 1>{}([&](auto i) {
2927  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2928  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2929  });
2930 
2931  // move on E
2932  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2933  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2934  I0,
2935  cde_lds_and_global_step);
2936  }
2937  });
2938  }
2939  }
2940 };
2941 
2942 } // namespace ck