/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp Source File
gridwise_moe_mx_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
26 // kernel function Blockers:
27 // 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on
28 // two lds chunks.
29 // 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
30 // buffer when we declare __shared__ inside blkgemmpipe
31 
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
37 
38 #if 0
39 template <typename GridwiseGemm,
40  bool HasMainKBlockLoop,
41  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
42  index_t MinimumOccupancy = 1,
43  TailNumber TailNum = TailNumber::Even>
44 __global__ void
45 #if CK_USE_LAUNCH_BOUNDS
46 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
47 #endif
48  // __attribute__((amdgpu_waves_per_eu(1, 1)))
49  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
50 {
51 #if defined(__gfx9__)
52  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53 
54  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
55 
56  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
57  karg.p_sorted_token_ids,
58  karg.p_sorted_expert_ids,
59  karg.p_max_token_id,
60  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
62  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
64  karg.p_ds_grid,
65  karg.p_c_grid,
66  p_shared,
67  karg,
68  karg.a_element_op,
69  karg.b_element_op,
70  karg.c_element_op);
71 #else
72  ignore = karg;
73 #endif // end of if (defined(__gfx9__))
74 }
75 #endif
76 
77 template <typename GridwiseGemm,
78  bool HasMainKBlockLoop,
79  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
80  index_t MinimumOccupancy = 1,
81  TailNumber TailNum = TailNumber::Even>
82 __global__ void
83 #if CK_USE_LAUNCH_BOUNDS
84 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
85 #endif
86  // __attribute__((amdgpu_waves_per_eu(1, 1)))
87  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
88 {
89 #if defined(__gfx9__)
90  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92 
93  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
94 
95  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
96  karg.p_sorted_token_ids,
97  karg.p_sorted_expert_ids,
98  karg.p_max_token_id,
99  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
100  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
103  karg.p_ds_grid,
104  karg.p_c_grid,
105  p_shared_0,
106  p_shared_1,
107  karg,
108  karg.a_element_op,
109  karg.b_element_op,
110  karg.c_element_op);
111 #else
112  ignore = karg;
113 #endif // end of if (defined(__gfx9__))
114 }
115 
116 template <typename ALayout,
117  typename BLayout,
118  typename DsLayout,
119  typename CLayout,
120  typename ADataType,
121  typename AScaleDataType,
122  typename BDataType,
123  typename BScaleDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t ScaleBlockSize, // Scaling block size
133  index_t BlockSize, // Thread block size
134  index_t MPerBlock,
135  index_t NPerBlock,
136  index_t KPerBlock,
137  index_t AK1Value,
138  index_t BK1Value,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144  typename ABlockTransferThreadClusterArrangeOrder,
145  typename ABlockTransferSrcAccessOrder,
146  index_t ABlockTransferSrcVectorDim,
147  index_t ABlockTransferSrcScalarPerVector,
148  index_t ABlockTransferDstScalarPerVector_AK1,
149  bool AThreadTransferSrcResetCoordinateAfterRun,
150  index_t ABlockLdsExtraM,
151  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152  typename BBlockTransferThreadClusterArrangeOrder,
153  typename BBlockTransferSrcAccessOrder,
154  index_t BBlockTransferSrcVectorDim,
155  index_t BBlockTransferSrcScalarPerVector,
156  index_t BBlockTransferDstScalarPerVector_BK1,
157  bool BThreadTransferSrcResetCoordinateAfterRun,
158  index_t BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  typename CDEShuffleBlockTransferScalarPerVectors,
165  index_t ActivationOperation = 0,
166  bool NSwizzle = false,
167  bool IsInputGemm = true,
168  bool MulRoutedWeight = true,
169  typename IndexType = index_t,
170  typename ComputeTypeA = ADataType,
171  typename ComputeTypeB = BDataType>
173 {
174  using LDSTypeA = ADataType;
175  using LDSTypeB = BDataType;
176 
177  static constexpr auto I0 = Number<0>{};
178  static constexpr auto I1 = Number<1>{};
179  static constexpr auto I2 = Number<2>{};
180  static constexpr auto I3 = Number<3>{};
181  static constexpr auto I4 = Number<4>{};
182  static constexpr auto I5 = Number<5>{};
183  static constexpr auto I6 = Number<6>{};
184  static constexpr auto I7 = Number<7>{};
185  static constexpr auto I8 = Number<8>{};
186  static constexpr auto I9 = Number<9>{};
187 
189  CDEShuffleBlockTransferScalarPerVectors{}[I0];
190  // K1 should be Number<...>
191  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
192  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
193  static constexpr auto AK1Number = Number<AK1Value>{};
194  static constexpr auto BK1Number = Number<BK1Value>{};
195 
196  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
197  static constexpr bool is_single_rate_mfma = false;
198  static constexpr auto is_scale_mfma = true;
199 
200  static constexpr index_t NumDTensor = DsDataType::Size();
201 
202  static constexpr auto MXdlPack = 2;
203  static constexpr auto NXdlPack = 2;
204  static constexpr auto KXdlPack = 2;
205 
206  //> KPack is at least the k_per_blk of selected mfma
207  //
208  // Should be a multiple of k_per_blk.
209  // TODO: Move this to blockwise pipeline base
210  // KPack in packed data types for pk A/B
211 
212  static constexpr index_t APackedSize = packed_size_v<ADataType>;
213  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
214 
215  using mfma_selector = MfmaSelector<ComputeTypeA,
216  MPerXdl,
217  NPerXdl,
218  ComputeTypeB,
220  is_scale_mfma>;
221  static constexpr index_t KPack =
223 
224  // static constexpr index_t NumTokens = 1;
225  static constexpr index_t SortedTileSize = MPerBlock;
226 
227  static constexpr auto MakeDsGridPointer()
228  {
229  return generate_tuple(
230  [&](auto i) {
231  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
232 
233  return static_cast<const DDataType*>(nullptr);
234  },
236  }
237 
238  using DsGridPointer = decltype(MakeDsGridPointer());
239 
241 
242  __host__ static auto CalculateGridSize(index_t M, index_t N)
243  {
244  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
245  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
246  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
247  const index_t gridy = NSwizzle ? 1 : mblock;
248 
249  return std::make_tuple(gridx, gridy, 1);
250  }
251 
252  __host__ static auto CalculateMPadded(index_t M)
253  {
254  return math::integer_least_multiple(M, MPerBlock);
255  }
256 
257  __host__ static auto CalculateNPadded(index_t N)
258  {
259  return math::integer_least_multiple(N, NPerBlock);
260  }
261 
262  __host__ static auto CalculateKPadded(index_t K)
263  {
264  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
265  }
266 
267  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
268  {
269  auto K_t = K_Batch * KPerBlock;
270  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
271  }
272 
273  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
274  {
275  auto K_t = K_Batch * KPerBlock;
276  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
277  }
278 
279  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * KPerBlock;
283  }
284 
285  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
286  {
287  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
288  auto K_t = K_Batch * KReadVec;
289  return (K + K_t - 1) / K_t * KReadVec;
290  }
291 
292  __host__ static auto CalculateMBlock(index_t M)
293  {
294  return math::integer_divide_ceil(M, MPerBlock);
295  }
296 
297  __host__ static auto CalculateNBlock(index_t N)
298  {
299  return math::integer_divide_ceil(N, NPerBlock);
300  }
301 
302  template <index_t MNXdlPerWave,
303  index_t MNWaves,
304  index_t MNXdlPack,
305  index_t MNPerXdl,
306  typename TileDesc_K0_MN_K1>
307  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
308  {
309  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
310  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
311  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
312 
313  constexpr auto permuted_desc = transform_tensor_descriptor(
314  TileDesc_K0_MN_K1{},
319 
321  permuted_desc,
324  Number<MNWaves>{},
326  Number<MNPerXdl>{}))),
329  }
330 
331  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
332  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
333  {
334  const auto a_grid_desc_mraw_kraw = [&]() {
335  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
336  {
337  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
338  }
339  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
340  {
341  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
342  }
343  }();
344 
346 
347  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
348  GemmSpec == GemmSpecialization::MNKPadding)
349  {
350  // pad both M and K
351  const auto a_grid_desc_m_k =
352  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
354  make_right_pad_transform(K, KPad - K)),
357 
358  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
359  a_grid_desc_m_k,
364 
365  return a_grid_desc_ak0_m_ak1;
366  }
367  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
368  GemmSpec == GemmSpecialization::MNPadding)
369  {
370  // pad M, but not K
371  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
372  a_grid_desc_mraw_kraw,
373  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
374  make_right_pad_transform(M, MPad - M)),
377 
378  const auto a_grid_desc_permuted = transform_tensor_descriptor(
379  a_grid_desc_ak0_m_ak1,
382  make_pass_through_transform(AK1Value)),
385 
386  const auto a_grid_desc = transform_tensor_descriptor(
387  a_grid_desc_permuted,
388  make_tuple(
391  make_pass_through_transform(AK1Value)),
394  return a_grid_desc;
395  }
396  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
397  GemmSpec == GemmSpecialization::NKPadding)
398  {
399  // pad K, but not M
400  const auto a_grid_desc_m_k = transform_tensor_descriptor(
401  a_grid_desc_mraw_kraw,
405 
406  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
407  a_grid_desc_m_k,
412 
413  return a_grid_desc_ak0_m_ak1;
414  }
415  else
416  {
417  // not pad M or K
418  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
419  a_grid_desc_mraw_kraw,
420  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
424 
425  const auto a_grid_desc_permuted = transform_tensor_descriptor(
426  a_grid_desc_ak0_m_ak1,
429  make_pass_through_transform(AK1Value)),
432 
433  const auto a_grid_desc = transform_tensor_descriptor(
434  a_grid_desc_permuted,
435  make_tuple(
438  make_pass_through_transform(AK1Value)),
441 
442  return a_grid_desc;
443  }
444  }
445 
446  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
447  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
448  {
449  const auto b_grid_desc_nraw_kraw = [&]() {
451  {
452  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
453  }
455  {
456  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
457  }
458  }();
459 
461 
462  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
463  GemmSpec != GemmSpecialization::Default),
464  "pk_i4_t does not support padding");
465  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
466  (GemmSpec != GemmSpecialization::Default &&
467  GemmSpec != GemmSpecialization::MPadding)),
468  "f4x2_pk_t does not support K padding");
469 
470  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
471  GemmSpec == GemmSpecialization::MNKPadding)
472  {
473  // pad both N and K
474  const auto b_grid_desc_n_k =
475  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
477  make_right_pad_transform(K, KPad - K)),
480 
481  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
482  b_grid_desc_n_k,
487 
488  return b_grid_desc_bk0_n_bk1;
489  }
490  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
491  GemmSpec == GemmSpecialization::MNPadding)
492  {
493  // pad N, but not K
494  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
495  b_grid_desc_nraw_kraw,
497  make_right_pad_transform(N, NPad - N)),
500 
501  return b_grid_desc_bk0_n_bk1;
502  }
503  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
504  GemmSpec == GemmSpecialization::MKPadding)
505  {
506  // pad K, but not N
507  const auto b_grid_desc_n_k = transform_tensor_descriptor(
508  b_grid_desc_nraw_kraw,
512 
513  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
514  b_grid_desc_n_k,
519 
520  return b_grid_desc_bk0_n_bk1;
521  }
522  else
523  {
524  // not pad N or K
525  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
526  b_grid_desc_nraw_kraw,
527  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
531 
532  const auto b_grid_desc_permuted = transform_tensor_descriptor(
533  b_grid_desc_bk0_n_bk1,
536  make_pass_through_transform(BK1Value)),
539 
540  const auto b_grid_desc = transform_tensor_descriptor(
541  b_grid_desc_permuted,
542  make_tuple(
545  make_pass_through_transform(BK1Value)),
548 
549  return b_grid_desc;
550  }
551  }
552 
553  template <typename ABlockDesc_AK0_M_AK1>
554  __host__ __device__ static constexpr auto
555  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
556  {
557  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
558 
559  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
560  ABlockDesc_AK0_M_AK1{});
561  }
562 
563  template <typename BBlockDesc_BK0_N_BK1>
564  __host__ __device__ static constexpr auto
565  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
566  {
567  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
568 
569  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
570  BBlockDesc_BK0_N_BK1{});
571  }
572 
573  template <typename ELayout>
574  __host__ __device__ static auto MakeCGridDescriptor_M_N(
575  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
576  {
577  const auto c_grid_desc_mraw_nraw = [&]() {
579  {
580  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
581  }
583  {
584  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
585  }
586  }();
587 
588  // pad M and N
589  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
591  make_right_pad_transform(N, NPad - N)),
594  }
595 
596  template <typename DLayout>
597  __host__ __device__ static auto
599  {
600  const auto c_grid_desc_mraw_nraw = [&]() {
602  {
603  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
604  }
606  {
607  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
608  }
609  }();
610 
611  // pad M and N
612  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
614  make_right_pad_transform(N, NPad - N)),
617  }
618 
619  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
620  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
621  {
622  return generate_tuple(
623  [&](auto i) {
624  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
625  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
626  },
628  }
629 
630  template <typename DsGridDesc>
632  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
633  {
634  return generate_tuple(
635  [&](auto i) {
637  ds_grid_desc_m_n[i], MBlock, NBlock);
638  },
640  }
641 
642  struct Problem
643  {
644  __host__ Problem(index_t NumTokens_,
645  index_t TopK_,
646  index_t M_,
647  index_t N_,
648  index_t K_,
649  index_t StrideA_,
650  index_t StrideScaleA_,
651  index_t StrideB_,
652  index_t StrideScaleB_,
653  std::array<index_t, NumDTensor> StrideDs_,
654  index_t StrideC_,
655  index_t KBatch_)
656  : NumTokens{NumTokens_},
657  TopK{TopK_},
658  M{M_},
659  N{N_},
660  K{K_},
661  StrideA{StrideA_},
662  StrideScaleA{StrideScaleA_},
663  StrideB{StrideB_},
664  StrideScaleB{StrideScaleB_},
665  StrideDs{StrideDs_},
666  StrideC{StrideC_},
667  KBatch{KBatch_},
670  KRead{CalculateKRead(K_, KBatch_)},
671  KPadded{CalculateKPadded(K_, KBatch_)},
672  AK0{CalculateAK0Padded(K_, KBatch_)},
673  BK0{CalculateBK0Padded(K_, KBatch_)},
674  MBlock{CalculateMBlock(M_)},
676  {
677  }
678 
679  __host__ void Print() const
680  {
681  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
682  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
683  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
684  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
685  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
686  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
687  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
688  << ", " << "NBlock: " << NBlock << "}" << std::endl;
689  }
690 
700  std::array<index_t, NumDTensor> StrideDs;
711  };
712 
713  // Argument
715  {
716  __host__ Argument(const index_t* p_sorted_token_ids_,
717  const index_t* p_sorted_expert_ids_,
718  const index_t* p_max_token_id_,
719  const ADataType* p_a_grid_,
720  const AScaleDataType* p_a_scale_grid_,
721  const BDataType* p_b_grid_,
722  const BScaleDataType* p_b_scale_grid_,
723  std::array<const void*, NumDTensor> p_ds_grid_,
724  CDataType* p_c_grid_,
725  index_t NumTokens_,
726  index_t TopK_,
727  index_t M_,
728  index_t N_,
729  index_t K_,
730  index_t StrideA_,
731  index_t StrideScaleA_,
732  index_t StrideB_,
733  index_t StrideScaleB_,
734  std::array<index_t, NumDTensor> StrideDs_,
735  index_t StrideC_,
736  index_t k_batch_,
737  AElementwiseOperation a_element_op_,
738  BElementwiseOperation b_element_op_,
739  CElementwiseOperation c_element_op_)
740  : Problem{NumTokens_,
741  TopK_,
742  M_,
743  N_,
744  K_ / APackedSize,
745  StrideA_ / APackedSize,
746  StrideScaleA_,
747  StrideB_ / BPackedSize,
748  StrideScaleB_,
749  StrideDs_,
750  StrideC_,
751  k_batch_},
752  p_sorted_token_ids{p_sorted_token_ids_},
753  p_sorted_expert_ids{p_sorted_expert_ids_},
754  p_max_token_id{p_max_token_id_},
755  p_a_grid{p_a_grid_},
756  p_a_scale_grid{p_a_scale_grid_},
757  p_b_grid{p_b_grid_},
758  p_b_scale_grid{p_b_scale_grid_},
759  p_ds_grid{},
760  p_c_grid{p_c_grid_},
761  a_element_op{a_element_op_},
762  b_element_op{b_element_op_},
763  c_element_op{c_element_op_}
764  {
765 
766  // populate pointer, desc for Ds
767  static_for<0, NumDTensor, 1>{}([&](auto i) {
768  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
769 
770  // D pointer
771  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
772  });
773  }
774 
778  const ADataType* p_a_grid;
779  const AScaleDataType* p_a_scale_grid;
780  const BDataType* p_b_grid;
781  const BScaleDataType* p_b_scale_grid;
783  CDataType* p_c_grid;
784 
785  const AElementwiseOperation a_element_op;
786  const BElementwiseOperation b_element_op;
787  const CElementwiseOperation c_element_op;
788  };
789 
791  {
792  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
793  {
794  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
795  {
796  a_k_split_offset = k_id * karg.KRead;
797  }
798  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
799  {
800  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
801  }
802 
803  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
804  {
805  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
806  }
807  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
808  {
809  // KPack * NLane * KLane * K0 * N0
810  b_k_split_offset = k_id * karg.KRead;
811  }
812 
813  // Calculate A scale offset
814  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
815  {
816  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
817  }
818  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
819  {
821  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
822  }
823 
824  // Calculate B scale offset
825  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
826  {
828  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
829  }
830  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
831  {
832  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
833  }
834 
835  if(k_id < karg.KBatch - 1)
836  {
837  karg.K = karg.KRead;
838  }
839  else
840  {
841  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
842  }
843  }
844 
849  };
850 
851  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
852  {
853  // A matrix in LDS memory, dst of blockwise copy
854  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
855  {
856  // contiguous in LDS
860  }
861  // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
862  // in some cases.
864  {
865  constexpr auto a_lds_block_desc =
868 
869  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
870  a_lds_block_desc,
876 
877  return a_lds_block_desc_permuted;
878  }
879  else // ColumnMajor A
880  {
881  // kfold and mpair dimension is not always required.
882  // more dimension in merge_transform increase the difficulty of generating immarg offset
883  // for compiler.
884  constexpr auto WaveSize = 64;
885  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
886  constexpr auto M1 = MPerBlock / M0;
887 
888  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
889  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
890  constexpr auto KThreadRead = WaveSize / MPerXdl;
891  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
892 
893  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
894  ? 1
895  : 128 / (AK1Number * M0 * sizeof(ADataType));
896  constexpr auto KThreadReadPerm =
897  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
898  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
899  : KThreadRead;
900 
901  // 1<=mpair<=n0
902  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
903  ? 1
904  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
905  ? M0
906  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
907 
908  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
912  Number<kfold * M0 / mpair>{},
913  Number<mpair>{},
914  AK1Number));
915 
916  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
917  a_lds_block_desc,
918  make_tuple(
922  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
925  make_tuple(
927  make_tuple(
929 
930  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
931  a_lds_block_desc_permuted,
932  make_tuple(
940  Sequence<1>{},
941  Sequence<2>{},
942  Sequence<3>{},
943  Sequence<4>{},
944  Sequence<5>{}),
946  Sequence<2>{},
947  Sequence<0, 3>{},
948  Sequence<4, 5>{},
949  Sequence<6>{},
950  Sequence<7>{}));
951 
952  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
953  a_lds_block_desc_unmerged,
956  Number<KThreadWrite / kfold / KThreadReadPerm>{},
957  Number<kfold>{},
964 
965  return a_lds_block_desc_ak0_m_ak1;
966  }
967  }
968 
969  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
970  {
971  // B matrix in LDS memory, dst of blockwise copy
972  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
973  {
974  // contiguous in lds
978  }
980  {
981  // NLdsLayer * K0 as logical Bank
982  constexpr auto b_lds_block_desc =
985 
986  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
987  b_lds_block_desc,
993 
994  return b_lds_block_desc_permuted;
995  }
996  else // RowMajor B
997  {
998  constexpr auto WaveSize = 64;
999  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1000  constexpr auto N1 = NPerBlock / N0;
1001 
1002  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1003  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1004  constexpr auto KThreadRead = WaveSize / NPerXdl;
1005  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1006 
1007  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1008  ? 1
1009  : 128 / (BK1Number * N0 * sizeof(BDataType));
1010  constexpr auto KThreadReadPerm =
1011  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1012  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1013  : KThreadRead;
1014 
1015  // 1<=npair<=n0
1016  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1017  ? 1
1018  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1019  ? N0
1020  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1021 
1022  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1026  Number<kfold * N0 / npair>{},
1027  Number<npair>{},
1028  BK1Number));
1029 
1030  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1031  b_lds_block_desc,
1032  make_tuple(
1036  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1039  make_tuple(
1041  make_tuple(
1043 
1044  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1045  b_lds_block_desc_permuted,
1046  make_tuple(
1054  Sequence<1>{},
1055  Sequence<2>{},
1056  Sequence<3>{},
1057  Sequence<4>{},
1058  Sequence<5>{}),
1060  Sequence<2>{},
1061  Sequence<0, 3>{},
1062  Sequence<4, 5>{},
1063  Sequence<6>{},
1064  Sequence<7>{}));
1065 
1066  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1067  b_lds_block_desc_unmerged,
1070  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1071  Number<kfold>{},
1078 
1079  return b_lds_block_desc_bk0_n_bk1;
1080  }
1081  }
1082 
1084  {
1085  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1086  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1087 
1088  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1090  make_tuple(I1,
1092  I1,
1094 
1095  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1096  }
1097 
1100  BlkGemmPipelineVer,
1101  BlkGemmPipeSched,
1102  BlockSize,
1103  ScaleBlockSize,
1104  ADataType,
1105  AScaleDataType,
1106  BDataType,
1107  BScaleDataType,
1108  ComputeTypeA,
1109  AccDataType,
1116  ABlockTransferSrcScalarPerVector,
1117  BBlockTransferSrcScalarPerVector,
1118  MPerBlock,
1119  NPerBlock,
1120  KPerBlock,
1121  MPerXdl,
1122  NPerXdl,
1123  MXdlPerWave,
1124  NXdlPerWave,
1125  KPack,
1126  IsInputGemm>())>;
1127 
1128  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1129  {
1130  // LDS allocation for A and B: be careful of alignment
1131  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1132  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1133 
1134  // lds max alignment
1135  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1136 
1137  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1138  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1139 
1140  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1141  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1142 
1143  // LDS allocation for C shuffle in LDS
1144  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1146 
1147  constexpr auto c_block_size =
1148  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1149 
1150  if constexpr(IsInputGemm)
1151  {
1152  return math::max(a_block_space_size_aligned * sizeof(ADataType) +
1153  b_block_space_size_aligned * sizeof(BDataType) * 2,
1154  c_block_size * sizeof(CShuffleDataType));
1155  }
1156  else
1157  {
1158  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1159  b_block_space_size_aligned * sizeof(BDataType)),
1160  c_block_size * sizeof(CShuffleDataType));
1161  }
1162  }
1163 
1164  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1165  __host__ static constexpr bool CheckValidity(const Argument& karg)
1166  {
1167  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1168  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1169  "Invalid tuning param!");
1170 
1171  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1172  "KPerBlock should be multiple of ScaleBlockSize");
1173 
1174  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1179  {
1180  if(!(karg.M % MPerBlock == 0))
1181  {
1182  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1183  {
1184  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1185  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1186  << std::endl;
1187  }
1188  return false;
1189  }
1190  }
1191 
1192  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1197  {
1198  if(!(karg.N % NPerBlock == 0))
1199  {
1200  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1201  {
1202  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1203  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1204  << std::endl;
1205  }
1206  return false;
1207  }
1208  }
1209 
1210  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1214  {
1215  auto K_t = karg.KBatch * KPerBlock;
1216  if(!(karg.K % K_t == 0))
1217  {
1218  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1219  {
1220  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1221  << karg.K << " " << __FILE__ << ":" << __LINE__
1222  << ", in function: " << __func__ << std::endl;
1223  }
1224  return false;
1225  }
1226  }
1227  else
1228  {
1229  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1230  auto K_t = karg.KBatch * KReadVec;
1231  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1232  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1233  {
1234  return false;
1235  }
1236  }
1237 
1239  {
1240  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1241  {
1242  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1243  {
1244  std::cout << "Arg K (" << karg.K
1245  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1246  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1247  << __LINE__ << ", in function: " << __func__ << std::endl;
1248  }
1249  return false;
1250  }
1251  }
1252  else
1253  {
1254  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1255  {
1256  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1257  {
1258  std::cout << "Arg M (" << karg.M
1259  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1260  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1261  << __LINE__ << ", in function: " << __func__ << std::endl;
1262  }
1263  return false;
1264  }
1265  }
1266 
1268  {
1269  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1270  {
1271  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1272  {
1273  std::cout << "Arg N (" << karg.N
1274  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1275  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1276  << __LINE__ << ", in function: " << __func__ << std::endl;
1277  }
1278  return false;
1279  }
1280  }
1281  else
1282  {
1283  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1284  {
1285  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1286  {
1287  std::cout << "Arg K (" << karg.K
1288  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1289  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1290  << __LINE__ << ", in function: " << __func__ << std::endl;
1291  }
1292  return false;
1293  }
1294  }
1295 
1297  {
1299  {
1300  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1301  {
1302  std::cout << "Arg N (" << karg.N
1303  << ") value is not a multiple of "
1304  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1306  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1307  << std::endl;
1308  }
1309  return false;
1310  }
1311  }
1312  else
1313  {
1315  {
1316  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1317  {
1318  std::cout << "Arg M (" << karg.M
1319  << ") value is not a multiple of "
1320  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1322  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1323  << std::endl;
1324 
1325  return false;
1326  }
1327  }
1328  }
1329 
1330  // check gridwise gemm pipeline
1331 #if 0
1332  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1333 
1334  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1335  {
1336  return false;
1337  }
1338 #endif
1339  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1340  return true;
1341  }
1342 
1343  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1344  {
1345  const index_t num_loop = K / KPerBlock;
1346 
1347  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1348  }
1349 
1350  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1351  {
1352  const index_t num_loop = K / KPerBlock;
1353 
1354  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1355  }
1356 
1357  template <typename CGridDesc>
1358  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1359  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1360  {
1361  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1362  c_grid_desc_m_n,
1367 
1368  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1369  }
1370 
1371  // return block_id to C matrix tile idx (m0, n0) mapping
1372  // if arch = gfx942
1373  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1374  // NPerBlock>;
1375 
1377  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1378  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1379  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1380  "A scale pack data type too large!");
1381  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1382  "B scale pack data type too large!");
1383 
1384  static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
1385  is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
1386  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1387 
1388 #if 0
1389  template <bool HasMainKBlockLoop,
1390  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1391  TailNumber TailNum = TailNumber::Odd>
1392  __device__ static void Run(const index_t* p_sorted_token_ids,
1393  const index_t* p_sorted_expert_ids,
1394  const index_t* p_max_token_id,
1395  const ADataType* p_a_grid,
1396  const AScaleDataType* p_a_scale_grid,
1397  const BDataType* p_b_grid,
1398  const BScaleDataType* p_b_scale_grid,
1399  DsGridPointer& p_ds_grid,
1400  CDataType* p_c_grid,
1401  void* p_shared,
1402  const Problem& problem,
1403  AElementwiseOperation a_element_op,
1404  BElementwiseOperation b_element_op,
1405  CElementwiseOperation c_element_op)
1406  {
1407  ignore = a_element_op;
1408  ignore = b_element_op;
1409  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1410  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1411  problem.MPadded,
1412  problem.K,
1413  problem.KPadded,
1414  problem.StrideA,
1415  problem.AK0);
1416  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1417  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1418  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1419  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1420  problem.MPadded,
1421  problem.N,
1422  problem.NPadded,
1423  problem.StrideC);
1424 
1425  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1426  make_tuple(problem.M / (MXdlPack * MPerXdl),
1427  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1428  (KXdlPack * 64 / MPerXdl),
1429  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1430 
1431  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1432  make_tuple(problem.N / (NXdlPack * NPerXdl),
1433  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1434  (KXdlPack * 64 / NPerXdl),
1435  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1436 
1437  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1439  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1440 
1441  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1442  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1443  if(expert_block_id * MPerBlock >= max_token_id)
1444  return;
1445  const index_t expert_id =
1446  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1447 
1448  const auto block_mn = [&]() -> std::pair<int, int> {
1449  if constexpr(NSwizzle)
1450  {
1451  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1452  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1453  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1454  const index_t expert_swizzle =
1455  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1456  const index_t bid_new = blockIdx.x - prefix_block;
1457  const index_t nid = __builtin_amdgcn_readfirstlane(
1458  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1459  const index_t mid =
1460  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1461  return {nid, mid};
1462  }
1463  else
1464  {
1465  return {blockIdx.x, blockIdx.y};
1466  }
1467  }();
1468 
1469  const index_t block_n_id = block_mn.first;
1470  const index_t block_m_id = block_mn.second;
1471  const index_t token0 =
1472  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1473 
1474  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1475  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1476  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1477  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1478  constexpr auto AKThreads = AK0Threads * AK1Threads;
1479  constexpr auto AMRepeats = MPerBlock / AMThreads;
1480  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1481 
1482  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1483  return;
1484  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1485  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1486  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1487  index_t token_offset = fused_token & 0xffffff;
1488  if constexpr(!IsInputGemm)
1489  {
1490  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1491  }
1492  gather_offsets(m0) = static_cast<IndexType>(token_offset);
1493  });
1494 
1495  const index_t expert_stride =
1496  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1497  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1498  problem.N * (IsInputGemm ? 2 : 1) *
1499  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1500 
1501  // N0, K0, Blocksize*KPack
1502  const index_t n_block_data_idx_on_grid =
1503  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1504 
1505  // Gride buffer creation
1506  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1507  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1508  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1509  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1510 
1511  // A, B scale buffer
1512  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1513  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1514  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1515  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1516  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1517 
1518  // lds max alignment
1519  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1520 
1521  // A matrix in LDS memory, dst of blockwise copy
1522  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1523 
1524  // B matrix in LDS memory, dst of blockwise copy
1525  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1526 
1527  // A matrix blockwise direct to LDS copy
1528  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
1530  Sequence<AK0Number, MPerBlock, AK1Number>,
1531  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1532  ABlockTransferThreadClusterArrangeOrder,
1533  ADataType,
1534  ADataType,
1535  decltype(a_grid_desc_ak0_m_ak1),
1536  decltype(a_block_desc_ak0_m_ak1),
1537  ABlockTransferSrcAccessOrder,
1538  ABlockTransferSrcVectorDim,
1539  2,
1540  ABlockTransferSrcScalarPerVector,
1541  IndexType,
1542  1>(a_grid_desc_ak0_m_ak1,
1543  make_multi_index(0, 0, 0),
1544  a_block_desc_ak0_m_ak1,
1545  make_multi_index(0, 0, 0),
1546  gather_offsets);
1547 
1548  // B matrix blockwise copy
1549  auto b_blockwise_copy =
1550  ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
1551  Sequence<BK0Number, NPerBlock, BK1Number>,
1552  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1553  BBlockTransferThreadClusterArrangeOrder,
1554  BDataType,
1555  BDataType,
1556  decltype(b_grid_desc_bk0_n_bk1),
1557  decltype(b_block_desc_bk0_n_bk1),
1558  BBlockTransferSrcAccessOrder,
1559  BBlockTransferSrcVectorDim,
1560  2,
1561  BBlockTransferSrcScalarPerVector>(
1562  b_grid_desc_bk0_n_bk1,
1563  make_multi_index(0, n_block_data_idx_on_grid, 0),
1564  b_block_desc_bk0_n_bk1,
1565  make_multi_index(0, 0, 0));
1566 
1567  // LDS allocation for A and B: be careful of alignment
1568  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1569  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1570 
1571  // Cast after lds
1572  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1573  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1574 
1575  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1576  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1577  a_block_space_size_aligned * sizeof(ADataType)),
1578  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1579 
1580  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1581  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1582 
1583  // Blockwise GEMM pipeline
1584  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1585  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1586  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1587  decltype(c_thread_buf) c_thread_buf_up;
1588 
1589  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1590  float,
1591  c_thread_buf.num_of_v_,
1592  c_thread_buf.s_per_v,
1593  true>
1594  c_thread_buf_fp32;
1595 
1596  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1597  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1598  KPerBlock);
1599 
1600  // a and b scale processing
1601  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1602  const auto waveId_m = wave_idx[I0];
1603  const auto waveId_n = wave_idx[I1];
1604 
1605  auto thread_offset_shuffled =
1606  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1607 
1608  auto a_thread_offset_m = waveId_m;
1609 
1610  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1611  AScaleDataType,
1612  AScaleDataType,
1613  decltype(a_scale_grid_desc_am_ak),
1614  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1615  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1616  Sequence<0, 1, 2>, // DimAccessOrder
1617  2, // SrcVectorDim
1618  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1619  1, // SrcScalarStrideInVector
1620  true>(a_scale_grid_desc_am_ak,
1621  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1622  0,
1623  thread_offset_shuffled / scale_pack_size_a));
1624 
1625  // B scale load
1626  auto b_thread_offset_n = waveId_n;
1627 
1628  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1629  BScaleDataType,
1630  BScaleDataType,
1631  decltype(b_scale_grid_desc_bn_ak),
1632  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1633  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1634  Sequence<0, 1, 2>, // DimAccessOrder
1635  2, // SrcVectorDim
1636  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1637  1, // SrcScalarStrideInVector
1638  true>(b_scale_grid_desc_bn_ak,
1639  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1640  0,
1641  thread_offset_shuffled / scale_pack_size_b));
1642 
1643  if constexpr(IsInputGemm)
1644  {
1645  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1646  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1647  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1648  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1649  a_block_space_size_aligned * sizeof(ADataType) +
1650  b_block_space_size_aligned * sizeof(BDataType)),
1651  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1652 
1653  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1654  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1655  p_b_grid_up + expert_id * expert_stride,
1656  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1657 
1658  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
1660  Sequence<BK0Number, NPerBlock, BK1Number>,
1661  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1662  BBlockTransferThreadClusterArrangeOrder,
1663  BDataType,
1664  BDataType,
1665  decltype(b_grid_desc_bk0_n_bk1),
1666  decltype(b_block_desc_bk0_n_bk1),
1667  BBlockTransferSrcAccessOrder,
1668  BBlockTransferSrcVectorDim,
1669  2,
1670  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
1671  make_multi_index(0, n_block_data_idx_on_grid, 0),
1672  b_block_desc_bk0_n_bk1,
1673  make_multi_index(0, 0, 0));
1674 
1675  const BScaleDataType* p_b_scale_grid_up =
1676  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1677  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1678  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1679  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1680 
1681  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1682  BScaleDataType,
1683  BScaleDataType,
1684  decltype(b_scale_grid_desc_bn_ak),
1685  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1686  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1687  Sequence<0, 1, 2>, // DimAccessOrder
1688  2, // SrcVectorDim
1689  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1690  1, // SrcScalarStrideInVector
1691  true>(
1692  b_scale_grid_desc_bn_ak,
1693  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1694  0,
1695  thread_offset_shuffled / scale_pack_size_b));
1696 
1697  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1698  // A
1699  a_grid_desc_ak0_m_ak1,
1700  a_block_desc_ak0_m_ak1,
1701  a_blockwise_copy,
1702  a_grid_buf,
1703  a_block_buf,
1704  a_block_slice_copy_step,
1705  // Gate and Up
1706  b_grid_desc_bk0_n_bk1,
1707  b_block_desc_bk0_n_bk1,
1708  b_blockwise_copy,
1709  b_blockwise_copy_up,
1710  b_grid_buf,
1711  b_grid_buf_up,
1712  b_block_buf,
1713  b_block_buf_up,
1714  b_block_slice_copy_step,
1715  // C
1716  c_thread_buf,
1717  c_thread_buf_up,
1718  // A scale
1719  a_scale_grid_desc_am_ak,
1720  a_scale_thread_copy,
1721  a_scale_grid_buf,
1722  // Gate and Up scale
1723  b_scale_grid_desc_bn_ak,
1724  b_scale_thread_copy,
1725  b_scale_thread_copy_up,
1726  b_scale_grid_buf,
1727  b_scale_grid_buf_up,
1728  num_k_block_main_loop);
1729  }
1730  else
1731  {
1732  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1733  a_grid_desc_ak0_m_ak1, // A
1734  a_block_desc_ak0_m_ak1,
1735  a_blockwise_copy,
1736  a_grid_buf,
1737  a_block_buf,
1738  a_block_slice_copy_step,
1739  b_grid_desc_bk0_n_bk1, // B
1740  b_block_desc_bk0_n_bk1,
1741  b_blockwise_copy,
1742  b_grid_buf,
1743  b_block_buf,
1744  b_block_slice_copy_step,
1745  c_thread_buf, // C
1746  a_scale_grid_desc_am_ak, // A scale
1747  a_scale_thread_copy,
1748  a_scale_grid_buf,
1749  b_scale_grid_desc_bn_ak, // B scale
1750  b_scale_thread_copy,
1751  b_scale_grid_buf,
1752  num_k_block_main_loop);
1753  }
1754 
1755  // shuffle C and write out
1756  {
1757  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1758  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1759  "wrong!");
1760  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1761  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1762  "wrong!");
1763 
1764  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1765  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1766 
1767  // TODO: hacky, fix it!
1768  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1769  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1770 
1771  // TODO: hacky, fix it!
1772  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1773  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1774  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1775 
1776  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1777  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1778  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1779  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1780  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1781  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1782  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1783  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1784  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1785  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1786 
1787  // mul scales
1788  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1789  static_assert(M5 == 4);
1790  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1791  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1792 
1793  vector_type<float, 4> topk_weights; // for gemm2 only
1794  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1795  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1796  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1797  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1798  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1799  const index_t m_pos = block_m_id * MPerBlock +
1800  m0 * M2 * M1 * M3 * M4 * M5 +
1801  m1 * M2 * M3 * M4 * M5 +
1802  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1803  if constexpr(MulRoutedWeight)
1804  {
1805  topk_weights =
1806  *c_style_pointer_cast<const vector_type<float, M5>*>(
1807  p_ds_grid[I2] + m_pos);
1808  }
1809  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1810  constexpr index_t c_offset =
1811  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1812  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1813  constexpr auto cidx = Number<c_offset>{};
1814 
1815  if constexpr(IsInputGemm) // gu fusion
1816  {
1817  if constexpr(ActivationOperation ==
1819  {
1820  float gate = c_thread_buf[cidx];
1821  float up = c_thread_buf_up[cidx];
1822  if constexpr(MulRoutedWeight)
1823  {
1824  gate = gate * topk_weights.AsType<float>()[m5];
1825  up = up * topk_weights.AsType<float>()[m5];
1826  }
1827  tensor_operation::element_wise::Silu{}(gate, gate);
1828  c_thread_buf_fp32(cidx) = gate * up;
1829  }
1830  else if(ActivationOperation == Activation::gelu_and_mul)
1831  {
1832  float gate = c_thread_buf[cidx];
1833  float up = c_thread_buf_up[cidx];
1834  if constexpr(MulRoutedWeight)
1835  {
1836  gate = gate * topk_weights.AsType<float>()[m5];
1837  up = up * topk_weights.AsType<float>()[m5];
1838  }
1839  tensor_operation::element_wise::Gelu{}(gate, gate);
1840  c_thread_buf_fp32(cidx) = gate * up;
1841 
1842  /*float gate = c_thread_buf[cidx];
1843  float up = c_thread_buf_up[cidx];
1844  if constexpr(MulRoutedWeight)
1845  {
1846  gate = gate * topk_weights.AsType<float>()[m5];
1847  //up = up * topk_weights.AsType<float>()[m5];
1848  }
1849  tensor_operation::element_wise::Gelu{}(gate, gate);
1850  c_thread_buf_fp32(cidx) = up;*/
1851  }
1852  }
1853  else
1854  {
1855  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1856  if constexpr(MulRoutedWeight)
1857  {
1858  c_thread_buf_fp32(cidx) =
1859  topk_weights.AsType<float>()[m5] *
1860  c_thread_buf_fp32[cidx];
1861  }
1862  }
1863  });
1864  });
1865  });
1866  });
1867  });
1868  });
1869 
1870  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1872 
1873  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1874  static_cast<CShuffleDataType*>(p_shared),
1875  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1876 
1877  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1878  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1879  make_tuple(
1882  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave)
1883  // per shuffle
1884  M1, // M1 = MWave
1885  M2, // M2 = MXdlPack
1886  M3, // M3 * M4 * M5 = MPerXdl
1887  M4,
1888  M5)),
1891  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
1892  // per shuffle
1893  N1, // N1 = NWave
1894  N2, // N2 = NXdlPack
1895  N3))), // N3 = NPerXdl
1896  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1897  make_tuple(Sequence<>{},
1898  Sequence<0, 2, 4, 6, 7, 8>{},
1899  Sequence<>{},
1900  Sequence<1, 3, 5, 9>{}));
1901 
1902  // calculate origin of thread output tensor on global memory
1903  // blockwise GEMM c matrix starting index
1904  const auto c_thread_mtx_on_block =
1905  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1906 
1907  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1908  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1909 
1910  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1912  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1913  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
1914  make_tuple(Sequence<0>{}));
1915 
1916  const auto m_thread_data_on_block_idx =
1917  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1918  make_multi_index(m_thread_data_on_block));
1919 
1920  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1922  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1923  make_tuple(Sequence<0, 1, 2, 3>{}),
1924  make_tuple(Sequence<0>{}));
1925 
1926  const auto n_thread_data_on_block_idx =
1927  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1928  make_multi_index(n_thread_data_on_block));
1929 
1930  // shuffle: threadwise copy C from VGPR to LDS
1931  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1932  AccDataType,
1933  CShuffleDataType,
1934  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1935  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1937  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1938  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1939  I1,
1940  I1,
1941  M2,
1942  N2,
1943  M3,
1944  I1,
1945  M5,
1946  I1>,
1947  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1948  9,
1949  1,
1951  1,
1952  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1953  make_multi_index(0,
1954  0,
1955  m_thread_data_on_block_idx[I1],
1956  n_thread_data_on_block_idx[I1],
1957  m_thread_data_on_block_idx[I2],
1958  n_thread_data_on_block_idx[I2],
1959  m_thread_data_on_block_idx[I3],
1960  m_thread_data_on_block_idx[I4],
1961  m_thread_data_on_block_idx[I5],
1962  n_thread_data_on_block_idx[I3]),
1964 
1965  using EDataType = CDataType;
1966 
1967  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1968  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1969 
1970  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1972  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1973 
1974  const auto ds_grid_buf = generate_tuple(
1975  [&](auto i) {
1976  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1977  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1978  },
1979  Number<NumDTensor>{});
1980 
1981  // tuple of reference to C/Ds tensor descriptors
1982  const auto c_ds_desc_refs = concat_tuple_of_reference(
1983  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1984  generate_tie([&](auto i) -> const auto& // return type should be reference
1985  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1986  Number<NumDTensor>{}));
1987 
1988  // tuple of reference to C/Ds tensor descriptors
1989  const auto c_ds_buf_refs = concat_tuple_of_reference(
1990  tie(c_shuffle_block_buf),
1991  generate_tie([&](auto i) -> const auto& // return type should be reference
1992  { return ds_grid_buf[i]; },
1993  Number<NumDTensor>{}));
1994 
1995  // tuple of starting index of C/Ds blockwise copy
1996  const auto idx_c_ds_block_begin =
1999  [&](auto) {
2000  return make_multi_index(block_m_id, 0, block_n_id, 0);
2001  // return make_multi_index(block_work_idx[I0], 0,
2002  // block_work_idx[I1], 0);
2003  },
2004  Number<NumDTensor>{}));
2005 
2006  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2007  c_grid_desc_mblock_mperblock_nblock_nperblock;
2008 
2009  using CDEBlockTransferCluster =
2010  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2011  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2012  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2013  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2015  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2016  Tuple<EDataType>,
2017  decltype(c_ds_desc_refs),
2018  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2019  CElementwiseOperation,
2020  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2021  // Sequence support
2022  // arbitray type
2023  Sequence<1,
2024  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2025  1,
2026  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2027  CDEBlockTransferCluster,
2028  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2029  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2030  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2031  3, // index_t SrcVectorDim,
2032  3, // index_t DstVectorDim,
2033  CDEShuffleBlockTransferScalarPerVectors,
2036  Sequence<true>,
2038  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2039  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2040  IndexType,
2041  1, // ScatterDim
2042  true, // OutputScatter: false, only use scatter weights
2043  scatter_weight_idx // ScatterWeightIdx: ascale
2044  >{c_ds_desc_refs,
2045  idx_c_ds_block_begin,
2046  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2047  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2048  c_element_op};
2049 
2050  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2051  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2052 
2053  constexpr auto sfc_c_vgpr =
2054  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2055  NXdlPerWave / NXdlPack,
2056  1,
2057  1,
2058  MXdlPack,
2059  NXdlPack,
2060  M2,
2061  1,
2062  M4,
2063  1>,
2064  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2065  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2066  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2067  1,
2068  1,
2069  MXdlPack,
2070  NXdlPack,
2071  M2,
2072  1,
2073  M4,
2074  1>>{};
2075 
2076  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2077 
2078  // space filling curve for shuffled blockwise C/D/E
2079  constexpr auto sfc_cde_block =
2080  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2081  Sequence<0, 2, 1, 3>,
2082  Sequence<1,
2083  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2084  1,
2085  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2086 
2087  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2088  constexpr auto EMThreads =
2089  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2090  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2091  constexpr auto ENThreads =
2092  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2093  static_for<0, num_access, 1>{}([&](auto access_id) {
2094  // make sure it's safe to write to LDS
2095  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2096 
2097  auto dstidx = sfc_cde_block.GetIndex(access_id);
2098  const index_t c_token_pos =
2099  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2100  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2101  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2102  IndexType token_offset = fused_token & 0xffffff;
2103  if constexpr(IsInputGemm)
2104  {
2105  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2106  }
2107  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2108  });
2109 
2110  block_sync_lds();
2111 
2112  // each thread write its data from VGPR to LDS
2113  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2114  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2115  c_thread_buf_fp32,
2116  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2117  c_shuffle_block_buf);
2118 
2119  // make sure it's safe to read from LDS
2120  block_sync_lds();
2121 
2122  // each block copy its data from LDS to global
2123  cde_block_copy_lds_and_global.Run(
2124  c_ds_desc_refs,
2125  c_ds_buf_refs,
2126  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2127  tie(c_grid_buf),
2128  scatter_offsets);
2129 
2130  if constexpr(access_id < num_access - 1)
2131  {
2132  constexpr auto cde_lds_and_global_step =
2133  sfc_cde_block.GetForwardStep(access_id);
2134 
2135  // move on Ds
2136  static_for<0, NumDTensor, 1>{}([&](auto i) {
2137  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2138  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2139  });
2140 
2141  // move on E
2142  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2143  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2144  I0,
2145  cde_lds_and_global_step);
2146  }
2147  });
2148  }
2149  }
2150 #endif
2151 
2152  template <bool HasMainKBlockLoop,
2153  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2154  TailNumber TailNum = TailNumber::Odd>
2155  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2156  const index_t* p_sorted_expert_ids,
2157  const index_t* p_max_token_id,
2158  const ADataType* p_a_grid,
2159  const AScaleDataType* p_a_scale_grid,
2160  const BDataType* p_b_grid,
2161  const BScaleDataType* p_b_scale_grid,
2162  DsGridPointer& p_ds_grid,
2163  CDataType* p_c_grid,
2164  void* p_shared_0,
2165  void* p_shared_1,
2166  const Problem& problem,
2167  AElementwiseOperation a_element_op,
2168  BElementwiseOperation b_element_op,
2169  CElementwiseOperation c_element_op)
2170  {
2171  ignore = a_element_op;
2172  ignore = b_element_op;
2173  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2174  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2175  problem.MPadded,
2176  problem.K,
2177  problem.KPadded,
2178  problem.StrideA,
2179  problem.AK0);
2180  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2181  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2182  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2183  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2184  problem.MPadded,
2185  problem.N,
2186  problem.NPadded,
2187  problem.StrideC);
2188 
2189  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2190  make_tuple(problem.M / (MXdlPack * MPerXdl),
2191  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2192  (KXdlPack * 64 / MPerXdl),
2193  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2194 
2195  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2196  make_tuple(problem.N / (NXdlPack * NPerXdl),
2197  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2198  (KXdlPack * 64 / NPerXdl),
2199  64 * KXdlPack * NXdlPack / scale_pack_size_b));
2200 
2201  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2203  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2204 
2205  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2206  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2207  if(expert_block_id * MPerBlock >= max_token_id)
2208  return;
2209  const index_t expert_id =
2210  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2211  const auto block_mn = [&]() -> std::pair<int, int> {
2212  if constexpr(NSwizzle)
2213  {
2214  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2215  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2216  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2217  const index_t expert_swizzle =
2218  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2219  const index_t bid_new = blockIdx.x - prefix_block;
2220  const index_t nid = __builtin_amdgcn_readfirstlane(
2221  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2222  const index_t mid =
2223  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2224  return {nid, mid};
2225  }
2226  else
2227  {
2228  return {blockIdx.x, blockIdx.y};
2229  }
2230  }();
2231 
2232  const index_t block_n_id = block_mn.first;
2233  const index_t block_m_id = block_mn.second;
2234  const index_t token0 =
2235  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2236 
2237  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2238  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2239  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2240  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2241  constexpr auto AKThreads = AK0Threads * AK1Threads;
2242  constexpr auto AMRepeats = MPerBlock / AMThreads;
2243  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2244 
2245  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2246  return;
2248  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2249  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2250  index_t token_offset = fused_token & 0xffffff;
2251  if constexpr(!IsInputGemm)
2252  {
2253  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2254  }
2255  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2256  });
2257 
2258  const index_t expert_stride =
2259  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2260  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2261  problem.N * (IsInputGemm ? 2 : 1) *
2262  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2263 
2264  // N0, K0, Blocksize*KPack
2265  const index_t n_block_data_idx_on_grid =
2266  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2267 
2268  // Gride buffer creation
2269  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2270  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2271  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2272  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2273 
2274  // A, B scale buffer
2275  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2276  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2277  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2278  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2279  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2280 
2281  // lds max alignment
2282  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2283 
2284  // A matrix in LDS memory, dst of blockwise copy
2285  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2286 
2287  // B matrix in LDS memory, dst of blockwise copy
2288  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2289 
2290  // A matrix blockwise direct to LDS copy
2291  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2294  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2295  ABlockTransferThreadClusterArrangeOrder,
2296  ADataType,
2297  ADataType,
2298  decltype(a_grid_desc_ak0_m_ak1),
2299  decltype(a_block_desc_ak0_m_ak1),
2300  ABlockTransferSrcAccessOrder,
2301  ABlockTransferSrcVectorDim,
2302  2,
2303  ABlockTransferSrcScalarPerVector,
2304  IndexType,
2305  1>(a_grid_desc_ak0_m_ak1,
2306  make_multi_index(0, 0, 0),
2307  a_block_desc_ak0_m_ak1,
2308  make_multi_index(0, 0, 0),
2309  gather_offsets);
2310 
2311  // B matrix blockwise copy
2312  auto b_blockwise_copy =
2315  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2316  BBlockTransferThreadClusterArrangeOrder,
2317  BDataType,
2318  BDataType,
2319  decltype(b_grid_desc_bk0_n_bk1),
2320  decltype(b_block_desc_bk0_n_bk1),
2321  BBlockTransferSrcAccessOrder,
2322  BBlockTransferSrcVectorDim,
2323  2,
2324  BBlockTransferSrcScalarPerVector>(
2325  b_grid_desc_bk0_n_bk1,
2326  make_multi_index(0, n_block_data_idx_on_grid, 0),
2327  b_block_desc_bk0_n_bk1,
2328  make_multi_index(0, 0, 0));
2329 
2330  // LDS allocation for A and B: be careful of alignment
2331  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2332  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2333 
2334  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2335  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2336 
2337  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2338  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2339  a_block_space_size_aligned * sizeof(ADataType)),
2340  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2341 
2342  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2343  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2344 
2345  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2346  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2347  a_block_space_size_aligned * sizeof(ADataType)),
2348  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2349 
2350  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2351  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2352 
2353  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2354  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2355 
2356  // Blockwise GEMM pipeline
2357  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2358  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2359  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2360  decltype(c_thread_buf) c_thread_buf_up;
2361 
2363  float,
2364  c_thread_buf.num_of_v_,
2365  c_thread_buf.s_per_v,
2366  true>
2367  c_thread_buf_fp32;
2368 
2369  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2370  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2371  KPerBlock);
2372 
2373  // a and b scale processing
2374  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2375  const auto waveId_m = wave_idx[I0];
2376  const auto waveId_n = wave_idx[I1];
2377 
2378  auto thread_offset_shuffled =
2379  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2380 
2381  auto a_thread_offset_m = waveId_m;
2382 
2383  // get each thread's offset int the scale tensor
2384  const index_t token_scale_pos = block_m_id * MPerBlock;
2385  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2386  return;
2387 
2388  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2389  AScaleDataType,
2390  AScaleDataType,
2391  decltype(a_scale_grid_desc_am_ak),
2392  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2393  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2394  Sequence<0, 1, 2>, // DimAccessOrder
2395  2, // SrcVectorDim
2396  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2397  1, // SrcScalarStrideInVector
2398  true>(a_scale_grid_desc_am_ak,
2399  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2400  0,
2401  thread_offset_shuffled / scale_pack_size_a));
2402 
2403  // B scale load
2404  auto b_thread_offset_n = waveId_n;
2405 
2406  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2407  BScaleDataType,
2408  BScaleDataType,
2409  decltype(b_scale_grid_desc_bn_ak),
2410  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2411  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2412  Sequence<0, 1, 2>, // DimAccessOrder
2413  2, // SrcVectorDim
2414  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2415  1, // SrcScalarStrideInVector
2416  true>(b_scale_grid_desc_bn_ak,
2417  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2418  0,
2419  thread_offset_shuffled / scale_pack_size_b));
2420 
2421  if constexpr(IsInputGemm)
2422  {
2423  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2424  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2425  p_b_grid_up + expert_id * expert_stride,
2426  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2427 
2428  // lds ping pong buffers for up
2429  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
2430  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
2431  auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2432  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
2433  a_block_space_size_aligned * sizeof(ADataType) +
2434  b_block_space_size_aligned * sizeof(BDataType)),
2435  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2436  auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2437  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
2438  a_block_space_size_aligned * sizeof(ADataType) +
2439  b_block_space_size_aligned * sizeof(BDataType)),
2440  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2441 
2442  auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);
2443 
2444  auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
2447  BBlockTransferThreadClusterLengths_BK0_N_BK1,
2448  BBlockTransferThreadClusterArrangeOrder,
2449  BDataType,
2450  BDataType,
2451  decltype(b_grid_desc_bk0_n_bk1),
2452  decltype(b_block_desc_bk0_n_bk1),
2453  BBlockTransferSrcAccessOrder,
2454  BBlockTransferSrcVectorDim,
2455  2,
2456  BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
2457  make_multi_index(0, n_block_data_idx_on_grid, 0),
2458  b_block_desc_bk0_n_bk1,
2459  make_multi_index(0, 0, 0));
2460 
2461  const BScaleDataType* p_b_scale_grid_up =
2462  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2463  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2464  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2465  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2466 
2467  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2468  BScaleDataType,
2469  BScaleDataType,
2470  decltype(b_scale_grid_desc_bn_ak),
2471  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2472  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2473  Sequence<0, 1, 2>, // DimAccessOrder
2474  2, // SrcVectorDim
2475  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2476  1, // SrcScalarStrideInVector
2477  true>(
2478  b_scale_grid_desc_bn_ak,
2479  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2480  0,
2481  thread_offset_shuffled / scale_pack_size_b));
2482 
2483  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2484  // A
2485  a_grid_desc_ak0_m_ak1,
2486  a_block_desc_ak0_m_ak1,
2487  a_blockwise_copy,
2488  a_grid_buf,
2489  a_block_bufs,
2490  a_block_slice_copy_step,
2491  // Gate and Up
2492  b_grid_desc_bk0_n_bk1,
2493  b_block_desc_bk0_n_bk1,
2494  b_blockwise_copy,
2495  b_blockwise_copy_up,
2496  b_grid_buf,
2497  b_grid_buf_up,
2498  b_block_bufs,
2499  b_block_bufs_up,
2500  b_block_slice_copy_step,
2501  // C
2502  c_thread_buf,
2503  c_thread_buf_up,
2504  // A scale
2505  a_scale_grid_desc_am_ak,
2506  a_scale_thread_copy,
2507  a_scale_grid_buf,
2508  // B scale
2509  b_scale_grid_desc_bn_ak,
2510  b_scale_thread_copy,
2511  b_scale_thread_copy_up,
2512  b_scale_grid_buf,
2513  b_scale_grid_buf_up,
2514  num_k_block_main_loop);
2515  }
2516  else
2517  {
2518  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2519  a_grid_desc_ak0_m_ak1, // A
2520  a_block_desc_ak0_m_ak1,
2521  a_blockwise_copy,
2522  a_grid_buf,
2523  a_block_bufs,
2524  a_block_slice_copy_step,
2525  b_grid_desc_bk0_n_bk1, // B
2526  b_block_desc_bk0_n_bk1,
2527  b_blockwise_copy,
2528  b_grid_buf,
2529  b_block_bufs,
2530  b_block_slice_copy_step,
2531  c_thread_buf, // C
2532  a_scale_grid_desc_am_ak, // A scale
2533  a_scale_thread_copy,
2534  a_scale_grid_buf,
2535  b_scale_grid_desc_bn_ak, // B scale
2536  b_scale_thread_copy,
2537  b_scale_grid_buf,
2538  num_k_block_main_loop);
2539  }
2540 
2541  // shuffle C and write out
2542  {
2543  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2544  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2545  "wrong!");
2546  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2547  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2548  "wrong!");
2549 
2550  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2551  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2552 
2553  // TODO: hacky, fix it!
2554  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2555  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2556 
2557  // TODO: hacky, fix it!
2558  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2559  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2560  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2561 
2562  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2563  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2564  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2565  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2566  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2567  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2568  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2569  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2570  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2571  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2572 
2573  // mul scales
2574 
2575  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2576  static_assert(M5 == 4);
2577  const index_t m1 = get_warp_local_1d_id() / NWave;
2578  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2579 
2580  vector_type<float, 4> topk_weights; // for gemm2 only
2581  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2582  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2583  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2584  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2585  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2586  const index_t m_pos = block_m_id * MPerBlock +
2587  m0 * M2 * M1 * M3 * M4 * M5 +
2588  m1 * M2 * M3 * M4 * M5 +
2589  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2590  if constexpr(MulRoutedWeight)
2591  {
2592  topk_weights =
2593  *c_style_pointer_cast<const vector_type<float, M5>*>(
2594  p_ds_grid[I2] + m_pos);
2595  }
2596  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2597  constexpr index_t c_offset =
2598  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2599  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2600  constexpr auto cidx = Number<c_offset>{};
2601 
2602  if constexpr(IsInputGemm) // gu fusion
2603  {
2604  if constexpr(ActivationOperation ==
2606  {
2607  float gate = c_thread_buf[cidx];
2608  float up = c_thread_buf_up[cidx];
2609  if constexpr(MulRoutedWeight)
2610  {
2611  gate = gate * topk_weights.AsType<float>()[m5];
2612  up = up * topk_weights.AsType<float>()[m5];
2613  }
2615  c_thread_buf_fp32(cidx) = gate * up;
2616  }
2617  else if(ActivationOperation == Activation::gelu_and_mul)
2618  {
2619  float gate = c_thread_buf[cidx];
2620  float up = c_thread_buf_up[cidx];
2621  if constexpr(MulRoutedWeight)
2622  {
2623  gate = gate * topk_weights.AsType<float>()[m5];
2624  up = up * topk_weights.AsType<float>()[m5];
2625  }
2627  c_thread_buf_fp32(cidx) = gate * up;
2628  }
2629  }
2630  else
2631  {
2632  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2633  if constexpr(MulRoutedWeight)
2634  {
2635  c_thread_buf_fp32(cidx) =
2636  topk_weights.AsType<float>()[m5] *
2637  c_thread_buf_fp32[cidx];
2638  }
2639  }
2640  });
2641  });
2642  });
2643  });
2644  });
2645  });
2646 
2647  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2649 
2650  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2651  static_cast<CShuffleDataType*>(p_shared_0),
2652  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2653 
2654  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2655  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2656  make_tuple(
2659  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2660  // shuffle
2661  M1, // M1 = MWave
2662  M2, // M2 * M3 * M4 = MPerXdl
2663  M3,
2664  M4,
2665  M5)),
2669  // per shuffle
2670  N1, // N1 = NWave
2671  N2, // N2 = NXdlPack
2672  N3))), // N3 = NPerXdl
2676  Sequence<>{},
2678 
2679  // calculate origin of thread output tensor on global memory
2680  // blockwise GEMM c matrix starting index
2681  const auto c_thread_mtx_on_block =
2682  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2683 
2684  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2685  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2686 
2687  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2689  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2691  make_tuple(Sequence<0>{}));
2692 
2693  const auto m_thread_data_on_block_idx =
2694  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2695  make_multi_index(m_thread_data_on_block));
2696 
2697  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2699  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2701  make_tuple(Sequence<0>{}));
2702 
2703  const auto n_thread_data_on_block_idx =
2704  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2705  make_multi_index(n_thread_data_on_block));
2706 
2707  // shuffle: threadwise copy C from VGPR to LDS
2708  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2709  AccDataType,
2710  CShuffleDataType,
2711  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2712  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2714  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2715  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2716  I1,
2717  I1,
2718  M2,
2719  N2,
2720  M3,
2721  I1,
2722  M5,
2723  I1>,
2725  9,
2726  1,
2728  1,
2729  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2730  make_multi_index(0,
2731  0,
2732  m_thread_data_on_block_idx[I1],
2733  n_thread_data_on_block_idx[I1],
2734  m_thread_data_on_block_idx[I2],
2735  n_thread_data_on_block_idx[I2],
2736  m_thread_data_on_block_idx[I3],
2737  m_thread_data_on_block_idx[I4],
2738  m_thread_data_on_block_idx[I5],
2739  n_thread_data_on_block_idx[I3]),
2741 
2742  using EDataType = CDataType;
2743 
2744  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2745  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2746 
2747  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2749  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2750 
2751  const auto ds_grid_buf = generate_tuple(
2752  [&](auto i) {
2753  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2754  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2755  },
2756  Number<NumDTensor>{});
2757 
2758  // tuple of reference to C/Ds tensor descriptors
2759  const auto c_ds_desc_refs = concat_tuple_of_reference(
2760  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2761  generate_tie([&](auto i) -> const auto& // return type should be reference
2762  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2763  Number<NumDTensor>{}));
2764 
2765  // tuple of reference to C/Ds tensor descriptors
2766  const auto c_ds_buf_refs = concat_tuple_of_reference(
2767  tie(c_shuffle_block_buf),
2768  generate_tie([&](auto i) -> const auto& // return type should be reference
2769  { return ds_grid_buf[i]; },
2770  Number<NumDTensor>{}));
2771 
2772  // tuple of starting index of C/Ds blockwise copy
2773  const auto idx_c_ds_block_begin =
2776  [&](auto) {
2777  return make_multi_index(block_m_id, 0, block_n_id, 0);
2778  // return make_multi_index(block_work_idx[I0], 0,
2779  // block_work_idx[I1], 0);
2780  },
2781  Number<NumDTensor>{}));
2782 
2783  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2784  c_grid_desc_mblock_mperblock_nblock_nperblock;
2785 
2786  using CDEBlockTransferCluster =
2787  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2788  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2789  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2790  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2792  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2794  decltype(c_ds_desc_refs),
2795  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2796  CElementwiseOperation,
2797  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2798  // Sequence support
2799  // arbitray type
2800  Sequence<1,
2801  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2802  1,
2803  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2804  CDEBlockTransferCluster,
2805  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2806  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2807  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2808  3, // index_t SrcVectorDim,
2809  3, // index_t DstVectorDim,
2810  CDEShuffleBlockTransferScalarPerVectors,
2815  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2816  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2817  IndexType,
2818  1, // ScatterDim
2819  true, // OutputScatter: false, only use scatter weights
2820  scatter_weight_idx // ScatterWeightIdx: ascale
2821  >{c_ds_desc_refs,
2822  idx_c_ds_block_begin,
2823  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2824  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2825  c_element_op};
2826 
2827  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2828  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2829 
2830  constexpr auto sfc_c_vgpr =
2831  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2832  NXdlPerWave / NXdlPack,
2833  1,
2834  1,
2835  MXdlPack,
2836  NXdlPack,
2837  M2,
2838  1,
2839  M4,
2840  1>,
2842  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2843  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2844  1,
2845  1,
2846  MXdlPack,
2847  NXdlPack,
2848  M2,
2849  1,
2850  M4,
2851  1>>{};
2852 
2853  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2854 
2855  // space filling curve for shuffled blockwise C/D/E
2856  constexpr auto sfc_cde_block =
2859  Sequence<1,
2860  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2861  1,
2862  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2863 
2864  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2865  constexpr auto EMThreads =
2866  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2867  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2868  constexpr auto ENThreads =
2869  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2870  static_for<0, num_access, 1>{}([&](auto access_id) {
2871  // make sure it's safe to write to LDS
2873 
2874  auto dstidx = sfc_cde_block.GetIndex(access_id);
2875  const index_t c_token_pos =
2876  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2877  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2878  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2879  IndexType token_offset = fused_token & 0xffffff;
2880  if constexpr(IsInputGemm)
2881  {
2882  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2883  }
2884  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2885  });
2886 
2887  block_sync_lds();
2888 
2889  // each thread write its data from VGPR to LDS
2890  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2891  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2892  c_thread_buf_fp32,
2893  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2894  c_shuffle_block_buf);
2895 
2896  // make sure it's safe to read from LDS
2897  block_sync_lds();
2898 
2899  // each block copy its data from LDS to global
2900  cde_block_copy_lds_and_global.Run(
2901  c_ds_desc_refs,
2902  c_ds_buf_refs,
2903  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2904  tie(c_grid_buf),
2905  scatter_offsets);
2906 
2907  if constexpr(access_id < num_access - 1)
2908  {
2909  constexpr auto cde_lds_and_global_step =
2910  sfc_cde_block.GetForwardStep(access_id);
2911 
2912  // move on Ds
2913  static_for<0, NumDTensor, 1>{}([&](auto i) {
2914  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2915  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2916  });
2917 
2918  // move on E
2919  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2920  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2921  I0,
2922  cde_lds_and_global_step);
2923  }
2924  });
2925  }
2926  }
2927 };
2928 
2929 } // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:266
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:23
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm_bns.hpp:48
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
constexpr auto BlockGemmMXPipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:275
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_mx_gemm.hpp:87
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:297
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: gridwise_moe_mx_gemm.hpp:715
const index_t * p_max_token_id
Definition: gridwise_moe_mx_gemm.hpp:777
CDataType * p_c_grid
Definition: gridwise_moe_mx_gemm.hpp:783
const AElementwiseOperation a_element_op
Definition: gridwise_moe_mx_gemm.hpp:785
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_mx_gemm.hpp:776
const CElementwiseOperation c_element_op
Definition: gridwise_moe_mx_gemm.hpp:787
const BDataType * p_b_grid
Definition: gridwise_moe_mx_gemm.hpp:780
const BScaleDataType * p_b_scale_grid
Definition: gridwise_moe_mx_gemm.hpp:781
DsGridPointer p_ds_grid
Definition: gridwise_moe_mx_gemm.hpp:782
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_mx_gemm.hpp:716
const index_t * p_sorted_token_ids
Definition: gridwise_moe_mx_gemm.hpp:775
const BElementwiseOperation b_element_op
Definition: gridwise_moe_mx_gemm.hpp:786
const AScaleDataType * p_a_scale_grid
Definition: gridwise_moe_mx_gemm.hpp:779
const ADataType * p_a_grid
Definition: gridwise_moe_mx_gemm.hpp:778
Definition: gridwise_moe_mx_gemm.hpp:643
index_t MBlock
Definition: gridwise_moe_mx_gemm.hpp:709
index_t NPadded
Definition: gridwise_moe_mx_gemm.hpp:704
index_t K
Definition: gridwise_moe_mx_gemm.hpp:695
index_t N
Definition: gridwise_moe_mx_gemm.hpp:694
index_t NumTokens
Definition: gridwise_moe_mx_gemm.hpp:691
index_t M
Definition: gridwise_moe_mx_gemm.hpp:693
index_t StrideA
Definition: gridwise_moe_mx_gemm.hpp:696
index_t StrideScaleB
Definition: gridwise_moe_mx_gemm.hpp:699
index_t KRead
Definition: gridwise_moe_mx_gemm.hpp:705
index_t NBlock
Definition: gridwise_moe_mx_gemm.hpp:710
index_t StrideC
Definition: gridwise_moe_mx_gemm.hpp:701
index_t StrideB
Definition: gridwise_moe_mx_gemm.hpp:698
__host__ void Print() const
Definition: gridwise_moe_mx_gemm.hpp:679
index_t BK0
Definition: gridwise_moe_mx_gemm.hpp:708
index_t StrideScaleA
Definition: gridwise_moe_mx_gemm.hpp:697
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_mx_gemm.hpp:700
index_t MPadded
Definition: gridwise_moe_mx_gemm.hpp:703
index_t KBatch
Definition: gridwise_moe_mx_gemm.hpp:702
index_t KPadded
Definition: gridwise_moe_mx_gemm.hpp:706
index_t TopK
Definition: gridwise_moe_mx_gemm.hpp:692
index_t AK0
Definition: gridwise_moe_mx_gemm.hpp:707
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_mx_gemm.hpp:644
Definition: gridwise_moe_mx_gemm.hpp:791
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_mx_gemm.hpp:792
index_t b_k_split_offset
Definition: gridwise_moe_mx_gemm.hpp:846
index_t b_scale_k_split_offset
Definition: gridwise_moe_mx_gemm.hpp:848
index_t a_k_split_offset
Definition: gridwise_moe_mx_gemm.hpp:845
index_t a_scale_k_split_offset
Definition: gridwise_moe_mx_gemm.hpp:847
Definition: gridwise_moe_mx_gemm.hpp:173
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_mx_gemm.hpp:242
static constexpr index_t APackedSize
Definition: gridwise_moe_mx_gemm.hpp:212
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_mx_gemm.hpp:227
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_mx_gemm.hpp:1165
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_mx_gemm.hpp:619
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm.hpp:631
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm.hpp:273
static constexpr auto AK0Number
Definition: gridwise_moe_mx_gemm.hpp:191
static constexpr auto I7
Definition: gridwise_moe_mx_gemm.hpp:184
static constexpr index_t BPackedSize
Definition: gridwise_moe_mx_gemm.hpp:213
remove_cvref_t< decltype(BlockGemmMXPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition: gridwise_moe_mx_gemm.hpp:1126
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_mx_gemm.hpp:252
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm.hpp:279
static constexpr auto KXdlPack
Definition: gridwise_moe_mx_gemm.hpp:204
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_mx_gemm.hpp:969
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_mx_gemm.hpp:331
static constexpr auto I6
Definition: gridwise_moe_mx_gemm.hpp:183
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_mx_gemm.hpp:1350
static constexpr auto AK1Number
Definition: gridwise_moe_mx_gemm.hpp:193
static constexpr auto I9
Definition: gridwise_moe_mx_gemm.hpp:186
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm.hpp:285
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_mx_gemm.hpp:238
static constexpr auto I8
Definition: gridwise_moe_mx_gemm.hpp:185
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_mx_gemm.hpp:446
static constexpr auto I0
Definition: gridwise_moe_mx_gemm.hpp:177
BDataType LDSTypeB
Definition: gridwise_moe_mx_gemm.hpp:175
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_mx_gemm.hpp:555
static constexpr auto NXdlPack
Definition: gridwise_moe_mx_gemm.hpp:203
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_mx_gemm.hpp:574
static constexpr auto I3
Definition: gridwise_moe_mx_gemm.hpp:180
static constexpr index_t scale_pack_size_b
Definition: gridwise_moe_mx_gemm.hpp:1378
static constexpr index_t scale_pack_size_a
Definition: gridwise_moe_mx_gemm.hpp:1377
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_mx_gemm.hpp:292
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_mx_gemm.hpp:262
static constexpr index_t NumDTensor
Definition: gridwise_moe_mx_gemm.hpp:200
static constexpr index_t SortedTileSize
Definition: gridwise_moe_mx_gemm.hpp:225
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_mx_gemm.hpp:297
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_mx_gemm.hpp:851
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_mx_gemm.hpp:188
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_mx_gemm.hpp:2155
ADataType LDSTypeA
Definition: gridwise_moe_mx_gemm.hpp:174
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_mx_gemm.hpp:565
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_mx_gemm.hpp:240
static constexpr auto lcm_AK1_BK1
Definition: gridwise_moe_mx_gemm.hpp:196
static constexpr auto I1
Definition: gridwise_moe_mx_gemm.hpp:178
static constexpr auto BK1Number
Definition: gridwise_moe_mx_gemm.hpp:194
static constexpr auto is_scale_mfma
Definition: gridwise_moe_mx_gemm.hpp:198
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_mx_gemm.hpp:598
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_mx_gemm.hpp:257
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_mx_gemm.hpp:1128
static constexpr auto I5
Definition: gridwise_moe_mx_gemm.hpp:182
static constexpr index_t KPack
Definition: gridwise_moe_mx_gemm.hpp:221
static constexpr auto I4
Definition: gridwise_moe_mx_gemm.hpp:181
static constexpr auto I2
Definition: gridwise_moe_mx_gemm.hpp:179
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_mx_gemm.hpp:307
static constexpr auto MXdlPack
Definition: gridwise_moe_mx_gemm.hpp:202
static constexpr bool is_single_rate_mfma
Definition: gridwise_moe_mx_gemm.hpp:197
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_mx_gemm.hpp:1083
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_mx_gemm.hpp:1358
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_mx_gemm.hpp:267
static constexpr auto BK0Number
Definition: gridwise_moe_mx_gemm.hpp:192
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_mx_gemm.hpp:1343
Definition: xdlops_gemm.hpp:942
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1343
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: static_buffer.hpp:75
Definition: thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition: thread_group_tensor_slice_transfer_gather_direct_load.hpp:57
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:117
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
Definition: data_type.hpp:41
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: data_type.hpp:197
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: unary_element_wise_operation.hpp:981
Definition: unary_element_wise_operation.hpp:308
Definition: unary_element_wise_operation.hpp:1023
Definition: dtype_vector.hpp:10
#define CK_ENV(name)
Definition: env.hpp:129