Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File

gridwise_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
24 // in the same kernel function. Blockers:
25 // 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses operate
26 // on two LDS chunks.
27 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not reuse
28 // the same LDS buffer if __shared__ were declared inside the block GEMM pipeline.
29 
30 enum Activation
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
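For reference, these enumerators select the fused activation applied between the "gate" and "up" halves of the first MoE GEMM output. A standalone scalar sketch of the math each one denotes, assuming the usual MoE gate/up fusion (the kernel itself applies this per accumulator element, vectorized):

#include <cmath>

// Scalar sketch only; not the kernel's actual vectorized code path.
inline float silu_and_mul_ref(float gate, float up)
{
    return (gate / (1.0f + std::exp(-gate))) * up; // SiLU(gate) * up
}

inline float gelu_and_mul_ref(float gate, float up)
{
    // tanh-approximated GELU(gate) * up
    const float g3 = gate * gate * gate;
    const float g  = 0.5f * gate * (1.0f + std::tanh(0.7978845608f * (gate + 0.044715f * g3)));
    return g * up;
}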
35 
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
49  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
50 
51  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54  karg.p_sorted_token_ids,
55  karg.p_sorted_expert_ids,
56  karg.p_max_token_id,
57  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
59  karg.p_ds_grid,
60  karg.p_c_grid,
61  karg.p_a_scale_grid,
62  karg.p_b_scale_grid,
63  p_shared,
64  karg,
65  karg.a_element_op,
66  karg.b_element_op,
67  karg.c_element_op);
68 #else
69  ignore = karg;
70 #endif // end of if (defined(__gfx9__))
71 }
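A hedged host-side sketch of launching this kernel: grid x/y come from GridwiseGemm::CalculateGridSize and the split-K factor occupies the z dimension, since the kernel builds its SplitKBatchOffset from blockIdx.z. The wrapper function, the block size of 256, and the choice of memory operation are illustrative assumptions; the actual launcher lives in the device-op layer.

// Illustrative only. Assumes a concrete instantiation named Gridwise and a filled
// Gridwise::Argument; BlockSize is assumed to have been instantiated as 256.
template <typename Gridwise>
void launch_moe_gemm_sketch(typename Gridwise::Argument arg, hipStream_t stream)
{
    const auto [gx, gy, gz] = Gridwise::CalculateGridSize(arg.M, arg.N);
    (void)gz;
    const dim3 grid(gx, gy, arg.KBatch); // blockIdx.z selects the split-K slice
    const dim3 block(256);

    // For KBatch > 1 the C write typically needs AtomicAdd instead of Set.
    kernel_moe_gemm<Gridwise, /*HasMainKBlockLoop=*/true, InMemoryDataOperationEnum::Set>
        <<<grid, block, 0, stream>>>(arg);
}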
72 
73 template <typename GridwiseGemm,
74  bool HasMainKBlockLoop,
75  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
76  index_t MinimumOccupancy = 1,
77  TailNumber TailNum = TailNumber::Even>
78 __global__ void
79 #if CK_USE_LAUNCH_BOUNDS
80  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
81 #endif
82  // __attribute__((amdgpu_waves_per_eu(1, 1)))
83  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
84 {
85 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
86  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
88 
89  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
90 
91  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
92  karg.p_sorted_token_ids,
93  karg.p_sorted_expert_ids,
94  karg.p_max_token_id,
95  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
96  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
97  karg.p_ds_grid,
98  karg.p_c_grid,
99  karg.p_a_scale_grid,
100  karg.p_b_scale_grid,
101  p_shared,
102  p_shared1,
103  karg,
104  karg.a_element_op,
105  karg.b_element_op,
106  karg.c_element_op);
107 #else
108  ignore = karg;
109 #endif // end of if (defined(__gfx9__))
110 }
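Relating the two kernels back to the comment at the top of this namespace: __shared__ storage is reserved for the lifetime of the whole kernel, so the LDS chunks are declared here at kernel scope and handed to the pipeline as raw pointers, rather than declared inside the block GEMM pipeline object. A minimal sketch of the pattern:

// Sketch only. If the pipeline declared its own __shared__ array internally, that LDS
// would stay allocated until the kernel returns, so the A/B staging buffers and the
// C-shuffle buffer could never overlap. Declaring the raw chunks in the kernel body
// lets different stages reinterpret the same bytes (single-LDS pipeline) or ping-pong
// between two chunks (double-LDS pipeline).
__global__ void lds_ownership_sketch()
{
    __shared__ char lds0[16384]; // placeholder sizes; the real kernels use
    __shared__ char lds1[16384]; // GridwiseGemm::GetSharedMemoryNumberOfByte()
    // pipeline.Run(..., lds0, ...);        // single-LDS variant
    // pipeline.Run_2Lds(..., lds0, lds1);  // double-LDS variant
}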
111 
112 template <typename ALayout,
113  typename BLayout,
114  typename DsLayout,
115  typename CLayout,
116  typename ADataType,
117  typename BDataType,
118  typename AccDataType,
119  typename CShuffleDataType,
120  typename DsDataType,
121  typename CDataType,
122  typename AElementwiseOperation,
123  typename BElementwiseOperation,
124  typename CElementwiseOperation,
125  tensor_operation::device::GemmSpecialization GemmSpec,
126  index_t BlockSize,
127  index_t ScaleBlockM,
128  index_t ScaleBlockN,
129  index_t ScaleBlockK,
130  index_t MPerBlock,
131  index_t NPerBlock,
132  index_t KPerBlock,
133  index_t AK1Value,
134  index_t BK1Value,
135  index_t MPerXdl,
136  index_t NPerXdl,
137  index_t MXdlPerWave,
138  index_t NXdlPerWave,
139  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
140  typename ABlockTransferThreadClusterArrangeOrder,
141  typename ABlockTransferSrcAccessOrder,
142  index_t ABlockTransferSrcVectorDim,
143  index_t ABlockTransferSrcScalarPerVector,
144  index_t ABlockTransferDstScalarPerVector_AK1,
145  bool AThreadTransferSrcResetCoordinateAfterRun,
146  index_t ABlockLdsExtraM,
147  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
148  typename BBlockTransferThreadClusterArrangeOrder,
149  typename BBlockTransferSrcAccessOrder,
150  index_t BBlockTransferSrcVectorDim,
151  index_t BBlockTransferSrcScalarPerVector,
152  index_t BBlockTransferDstScalarPerVector_BK1,
153  bool BThreadTransferSrcResetCoordinateAfterRun,
154  index_t BBlockLdsExtraN,
155  index_t CShuffleMXdlPerWavePerShuffle,
156  index_t CShuffleNXdlPerWavePerShuffle,
157  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
158  typename CDEShuffleBlockTransferScalarPerVectors,
159  BlockGemmPipelineScheduler BlkGemmPipeSched,
160  BlockGemmPipelineVersion BlkGemmPipelineVer,
161  index_t ActivationOperation = 0,
162  bool NSwizzle = false,
163  bool IsInputGemm = true,
164  bool MulRoutedWeight = true,
165  typename IndexType = index_t,
166  typename ComputeTypeA = CDataType,
167  typename ComputeTypeB = ComputeTypeA,
168  typename LDSTypeA = ADataType,
169  typename LDSTypeB = BDataType>
171 {
172  using AScaleType = float;
173  using BScaleType = float;
174 
175  static constexpr auto I0 = Number<0>{};
176  static constexpr auto I1 = Number<1>{};
177  static constexpr auto I2 = Number<2>{};
178  static constexpr auto I3 = Number<3>{};
179  static constexpr auto I4 = Number<4>{};
180  static constexpr auto I5 = Number<5>{};
181  static constexpr auto I6 = Number<6>{};
182  static constexpr auto I7 = Number<7>{};
183 
184  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
185  CDEShuffleBlockTransferScalarPerVectors{}[I0];
186  // K1 should be Number<...>
187  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
188  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
189  static constexpr auto AK1Number = Number<AK1Value>{};
190  static constexpr auto BK1Number = Number<BK1Value>{};
191  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
192 
193  static constexpr index_t NumDTensor = DsDataType::Size();
194 
196  static constexpr index_t KPack =
198  static constexpr index_t KGroup = []() {
200  // On gfx950 we have an mfma that requires 32 f8 elements as input,
201  // split into 2 groups of 16 f8 elements.
202  // The 2 groups are not contiguous in the B preshuffled layout,
203  // and we do not want them to be contiguous in the B preshuffled layout,
204  // because a memory instruction can only read 16 f8 elements at a time.
205  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
206  else
207  return 1;
208  }();
209  static constexpr index_t KLane =
211  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
212  static constexpr index_t NLane = NPerXdl;
213  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
214  // static constexpr index_t NumTokens = 1;
215  static constexpr index_t SortedTileSize = MPerBlock;
216 
217  static constexpr auto MakeDsGridPointer()
218  {
219  return generate_tuple(
220  [&](auto i) {
221  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
222 
223  return static_cast<const DDataType*>(nullptr);
224  },
226  }
227 
228  using DsGridPointer = decltype(MakeDsGridPointer());
229 
231 
232  static constexpr index_t APackedSize = []() {
234  return 2;
235  else
236  return 1;
237  }();
238 
239  static constexpr index_t BPackedSize = []() {
241  return 2;
242  else
243  return 1;
244  }();
245 
246  __host__ static auto CalculateGridSize(index_t M, index_t N)
247  {
248  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
249  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
250  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
251  const index_t gridy = NSwizzle ? 1 : mblock;
252  return std::make_tuple(gridx, gridy, 1);
253  }
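A worked example of the two grid shapes (tile sizes are placeholders, not a statement about any particular instance):

// Placeholder tile config: MPerBlock = 128, NPerBlock = 128; problem M = 512, N = 4096.
constexpr int mblock_ex = (512 + 128 - 1) / 128;   // 4
constexpr int nblock_ex = (4096 + 128 - 1) / 128;  // 32
static_assert(mblock_ex == 4 && nblock_ex == 32, "");
// NSwizzle == false -> grid = (nblock, mblock, 1) = (32, 4, 1)
// NSwizzle == true  -> grid = (nblock * mblock, 1, 1) = (128, 1, 1); Run() later recovers
//                      the N/M tile indices from blockIdx.x and the per-expert tile counts.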
254 
255  __host__ __device__ static auto CalculateMPadded(index_t M)
256  {
257  return math::integer_least_multiple(M, MPerBlock);
258  }
259 
260  __host__ __device__ static auto CalculateNPadded(index_t N)
261  {
262  return math::integer_least_multiple(N, NPerBlock);
263  }
264 
265  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
266  {
267  return math::integer_divide_ceil(N, NLane);
268  }
269  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
270  {
272  }
273 
274  __host__ __device__ static auto CalculateKPadded(index_t K)
275  {
276  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
277  }
278 
279  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
283  }
284 
285  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
286  {
287  auto K_t = K_Batch * KPerBlock;
288  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
289  }
290 
291  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
292  {
293  auto K_t = K_Batch * KPerBlock;
294  return (K + K_t - 1) / K_t * KPerBlock;
295  }
296 
297  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
298  {
299  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
300  auto K_t = K_Batch * KReadVec;
301  return (K + K_t - 1) / K_t * KReadVec;
302  }
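A worked example of the split-K helpers above (all values are placeholders):

// Placeholder config: K = 4096, KBatch = 4, KPerBlock = 64, AK1 = BK1 = 16 (so lcm = 16).
constexpr int K_ex = 4096, KBatch_ex = 4, KPerBlock_ex = 64, AK1_ex = 16, KReadVec_ex = 16;
constexpr int KRead_ex   = (K_ex + KBatch_ex * KReadVec_ex - 1) / (KBatch_ex * KReadVec_ex) * KReadVec_ex;
constexpr int KPadded_ex = (K_ex + KBatch_ex * KPerBlock_ex - 1) / (KBatch_ex * KPerBlock_ex) * KPerBlock_ex;
constexpr int AK0_ex     = (K_ex + KBatch_ex * KPerBlock_ex - 1) / (KBatch_ex * KPerBlock_ex) * (KPerBlock_ex / AK1_ex);
static_assert(KRead_ex == 1024 && KPadded_ex == 1024 && AK0_ex == 64, "");
// Each of the 4 split-K slices therefore reads KRead = 1024 of the 4096 K elements.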
303 
304  __host__ __device__ static auto CalculateMBlock(index_t M)
305  {
306  return math::integer_divide_ceil(M, MPerBlock);
307  }
308 
309  __host__ __device__ static auto CalculateNBlock(index_t N)
310  {
311  return math::integer_divide_ceil(N, NPerBlock);
312  }
313 
314  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
315  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
316  {
317  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
318  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
319 
321  TileDesc_K0_MN_K1{},
327  }
328 
329  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
330  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
331  {
332  const auto a_grid_desc_mraw_kraw = [&]() {
333  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
334  {
335  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
336  }
337  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
338  {
339  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
340  }
341  }();
342 
344 
345  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
346  GemmSpec == GemmSpecialization::MNKPadding)
347  {
348  // pad both M and K
349  const auto a_grid_desc_m_k =
350  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
352  make_right_pad_transform(K, KPad - K)),
355 
356  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
357  a_grid_desc_m_k,
362 
363  return a_grid_desc_ak0_m_ak1;
364  }
365  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
366  GemmSpec == GemmSpecialization::MNPadding)
367  {
368  // pad M, but not K
369  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
370  a_grid_desc_mraw_kraw,
372  make_right_pad_transform(M, MPad - M)),
375 
376  return a_grid_desc_ak0_m_ak1;
377  }
378  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
379  GemmSpec == GemmSpecialization::NKPadding)
380  {
381  // pad K, but not M
382  const auto a_grid_desc_m_k = transform_tensor_descriptor(
383  a_grid_desc_mraw_kraw,
387 
388  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
389  a_grid_desc_m_k,
394 
395  return a_grid_desc_ak0_m_ak1;
396  }
397  else
398  {
399  // not pad M or K
400  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
401  a_grid_desc_mraw_kraw,
406 
407  return a_grid_desc_ak0_m_ak1;
408  }
409  }
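For the non-padded row-major branch above, the (AK0, M, AK1) view is simply an unmerge of K into AK0 x AK1 on top of the raw (M, K) layout; a scalar sketch of the resulting addressing (hypothetical helper, not library code):

// Element (ak0, m, ak1) of the AK0 x M x AK1 view of a row-major A with leading
// dimension StrideA reads A[m * StrideA + ak0 * AK1 + ak1] (no-padding case).
inline long a_view_offset_rowmajor(long ak0, long m, long ak1, long StrideA, long AK1)
{
    return m * StrideA + ak0 * AK1 + ak1;
}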
410 
411  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
412  {
413  constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack / KGroup>{};
415  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
416  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
417  }
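The preshuffled B view above is a plain 4-D naive descriptor over (N0/NWave, NWave, K0, WarpSize*KPack/KGroup); its strides reduce to the flat offset below (sketch, with the last-dimension size passed in as NK):

// Flat element offset for (n0, nwave, k0, nk) in the preshuffled B view,
// where NK = WarpSize * KPack / KGroup for the given instantiation.
inline long b_preshuffled_offset(long n0, long nwave, long k0, long nk,
                                 long NWave, long K0, long NK)
{
    return ((n0 * NWave + nwave) * K0 + k0) * NK + nk;
}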
418 
419  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
420  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
421  {
422  const auto b_grid_desc_nraw_kraw = [&]() {
424  {
425  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
426  }
428  {
429  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
430  }
431  }();
432 
434 
435  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
436  GemmSpec != GemmSpecialization::Default),
437  "pk_i4_t does not support padding");
438 
439  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
440  GemmSpec == GemmSpecialization::MNKPadding)
441  {
442  // pad both N and K
443  const auto b_grid_desc_n_k =
444  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
446  make_right_pad_transform(K, KPad - K)),
449 
450  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
451  b_grid_desc_n_k,
456 
457  return b_grid_desc_bk0_n_bk1;
458  }
459  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
460  GemmSpec == GemmSpecialization::MNPadding)
461  {
462  // pad N, but not K
463  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
464  b_grid_desc_nraw_kraw,
466  make_right_pad_transform(N, NPad - N)),
469 
470  return b_grid_desc_bk0_n_bk1;
471  }
472  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
473  GemmSpec == GemmSpecialization::MKPadding)
474  {
475  // pad K, but not N
476  const auto b_grid_desc_n_k = transform_tensor_descriptor(
477  b_grid_desc_nraw_kraw,
481 
482  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
483  b_grid_desc_n_k,
488 
489  return b_grid_desc_bk0_n_bk1;
490  }
491  else
492  {
493  // not pad N or K
494  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
495  b_grid_desc_nraw_kraw,
500 
501  return b_grid_desc_bk0_n_bk1;
502  }
503  }
504 
505  template <typename ABlockDesc_AK0_M_AK1>
506  __host__ __device__ static constexpr auto
507  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
508  {
509  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
510 
511  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
512  }
513 
514  template <typename BBlockDesc_BK0_N_BK1>
515  __host__ __device__ static constexpr auto
516  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
517  {
518  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
519  }
520 
521  template <typename ELayout>
522  __host__ __device__ static auto MakeCGridDescriptor_M_N(
523  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
524  {
525  const auto c_grid_desc_mraw_nraw = [&]() {
527  {
528  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
529  }
531  {
532  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
533  }
534  }();
535 
536  // pad M and N
537  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
539  make_right_pad_transform(N, NPad - N)),
542  }
543 
544  template <typename DLayout>
545  __host__ __device__ static auto
547  {
548  const auto c_grid_desc_mraw_nraw = [&]() {
550  {
551  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
552  }
554  {
555  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
556  }
557  }();
558 
559  // pad M and N
560  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
562  make_right_pad_transform(N, NPad - N)),
565  }
566 
567  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
568  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
569  {
570  return generate_tuple(
571  [&](auto i) {
572  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
573  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
574  },
576  }
577 
578  template <typename DsGridDesc>
580  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
581  {
582  return generate_tuple(
583  [&](auto i) {
585  ds_grid_desc_m_n[i], MBlock, NBlock);
586  },
588  }
589 
590  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
591 
592  struct Problem
593  {
594  __host__ __device__ Problem(index_t NumTokens_,
595  index_t TopK_,
596  index_t M_,
597  index_t N_,
598  index_t K_,
599  index_t StrideA_,
600  index_t StrideB_,
601  std::array<index_t, NumDTensor> StrideDs_,
602  index_t StrideC_,
603  index_t KBatch_)
604  : NumTokens{NumTokens_},
605  TopK{TopK_},
606  M{M_},
607  N{N_},
608  K{K_},
609  StrideA{StrideA_},
610  StrideB{StrideB_},
611  StrideDs{StrideDs_},
612  StrideC{StrideC_},
613  KBatch{KBatch_},
616  KRead{CalculateKRead(K_, KBatch_)},
617  KPadded{CalculateKPadded(K_, KBatch_)},
618  AK0{CalculateAK0Padded(K_, KBatch_)},
619  BK0{CalculateBK0Padded(K_, KBatch_)},
620  MBlock{CalculateMBlock(M_)},
621  NBlock{CalculateNBlock(N_)},
624  {
625  }
626 
627  __host__ void Print() const
628  {
629  std::cout << "problem {"
630  << "NumTokens:" << NumTokens << ", "
631  << "TopK:" << TopK << ", "
632  << "M:" << M << ", "
633  << "N:" << N << ", "
634  << "K:" << K << ", "
635  << "SA:" << StrideA << ", "
636  << "SB:" << StrideB << ", "
637  << "SC:" << StrideC << ", "
638  << "MP:" << MPadded << ", "
639  << "NP:" << NPadded << ", "
640  << "KRead:" << KRead << ", "
641  << "KP:" << KPadded << ", "
642  << "AK0:" << AK0 << ", "
643  << "BK0:" << BK0 << ", "
644  << "MBlock: " << MBlock << ", "
645  << "NBlock: " << NBlock << "}" << std::endl;
646  }
647 
655  std::array<index_t, NumDTensor> StrideDs;
666  // FOR PRESHUFFLE ONLY
669  };
670 
671  // Argument
673  {
674  __host__ Argument(const index_t* p_sorted_token_ids_,
675  const index_t* p_sorted_expert_ids_,
676  const index_t* p_max_token_id_,
677  const ADataType* p_a_grid_,
678  const BDataType* p_b_grid_,
679  std::array<const void*, NumDTensor> p_ds_grid_,
680  CDataType* p_c_grid_,
681  index_t NumTokens_,
682  index_t TopK_,
683  index_t M_,
684  index_t N_,
685  index_t K_,
686  index_t StrideA_,
687  index_t StrideB_,
688  std::array<index_t, NumDTensor> StrideDs_,
689  index_t StrideC_,
690  const AScaleType* p_a_scale_grid_,
691  const BScaleType* p_b_scale_grid_,
692  index_t k_batch_,
693  AElementwiseOperation a_element_op_,
694  BElementwiseOperation b_element_op_,
695  CElementwiseOperation c_element_op_)
696  : Problem{NumTokens_,
697  TopK_,
698  M_,
699  N_,
700  K_,
701  StrideA_,
702  StrideB_,
703  StrideDs_,
704  StrideC_,
705  k_batch_},
706  p_sorted_token_ids{p_sorted_token_ids_},
707  p_sorted_expert_ids{p_sorted_expert_ids_},
708  p_max_token_id{p_max_token_id_},
709  p_a_grid{p_a_grid_},
710  p_b_grid{p_b_grid_},
711  p_ds_grid{},
712  p_c_grid{p_c_grid_},
713  p_a_scale_grid{p_a_scale_grid_},
714  p_b_scale_grid{p_b_scale_grid_},
715  a_element_op{a_element_op_},
716  b_element_op{b_element_op_},
717  c_element_op{c_element_op_}
718  {
719 
720  // populate pointer, desc for Ds
721  static_for<0, NumDTensor, 1>{}([&](auto i) {
722  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
723 
724  // D pointer
725  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
726  });
727  }
728 
732  const ADataType* p_a_grid;
733  const BDataType* p_b_grid;
735  CDataType* p_c_grid;
736 
739 
740  const AElementwiseOperation a_element_op;
741  const BElementwiseOperation b_element_op;
742  const CElementwiseOperation c_element_op;
743  };
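A hedged sketch of filling this Argument on the host. "Gridwise" stands for some concrete instantiation of this struct, PassThrough is used as a placeholder element-wise op, and the sorted-token / sorted-expert / max-token-id buffers are assumed to come from the MoE routing and sorting stage; none of the variable names below are defined by this header.

// Illustrative only; argument order follows the constructor above.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

typename Gridwise::Argument arg{
    p_sorted_token_ids,  // fused (topk_slot << 24 | token) entries, padded per M tile
    p_sorted_expert_ids, // expert id for each sorted M tile
    p_max_token_id,      // [0] = number of valid sorted tokens
    p_a,                 // activations
    p_b,                 // expert weights (preshuffled layout)
    {},                  // Ds pointers (e.g. routed top-k weights when MulRoutedWeight)
    p_c,                 // output
    num_tokens, topk, M, N, K,
    stride_a, stride_b, /*StrideDs=*/{}, stride_c,
    p_a_scale, p_b_scale, // block-scale tensors
    /*k_batch=*/1,
    PassThrough{}, PassThrough{}, PassThrough{}};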
744 
746  {
747  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
748  {
749  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
750  {
751  a_k_split_offset = k_id * karg.KRead / APackedSize;
752  }
753  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
754  {
755  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
756  }
757 
758  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
759  {
760  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
761  }
762  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
763  {
764  // KPack * NLane * KLane * K0 * N0
765  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
766  }
767 
768  if(k_id < karg.KBatch - 1)
769  {
770  karg.K = karg.KRead;
771  }
772  else
773  {
774  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
775  }
776  }
777 
780  };
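A worked example of the offsets computed above, continuing the placeholder numbers from the earlier split-K example (row-major A, preshuffled B taking the column-major branch, packed sizes of 1):

// Placeholders: KRead = 1024, NLane = NPerXdl = 16, KBatch = 4, K = 4096, k_id = 2.
constexpr int KRead_s = 1024, NLane_s = 16, KBatch_s = 4, K_s = 4096, k_id_s = 2;
constexpr int a_off_s = k_id_s * KRead_s;            // 2048: shift the A base by 2048 K-elements
constexpr int b_off_s = k_id_s * KRead_s * NLane_s;  // 32768: KPack*NLane*KLane*K0*N0 layout
constexpr int k_this_s =
    (k_id_s < KBatch_s - 1) ? KRead_s : K_s - KRead_s * (KBatch_s - 1); // 1024 for every slice here
static_assert(a_off_s == 2048 && b_off_s == 32768 && k_this_s == 1024, "");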
781 
782  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
783  {
784  // A matrix in LDS memory, dst of blockwise copy
785  if constexpr(ABlockLdsExtraM)
786  {
790  }
791  // The XOR tensor transformation requires additional, unnecessary VGPRs, which can cause
792  // register spills in some cases.
794  {
795  constexpr auto a_lds_block_desc =
798 
799  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
800  a_lds_block_desc,
806 
807  return a_lds_block_desc_permuted;
808  }
809  else // ColumnMajor A
810  {
811  // The kfold and mpair dimensions are not always required.
812  // More dimensions in merge_transform increase the difficulty of generating the immarg offset
813  // for the compiler.
814  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
815  constexpr auto M1 = MPerBlock / M0;
816 
817  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
818  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
819  constexpr auto KThreadRead = 64 / MPerXdl;
820  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
821 
822  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
823  ? 1
824  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
825  constexpr auto KThreadReadPerm =
826  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
827  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
828  : KThreadRead;
829 
830  // 1<=mpair<=n0
831  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
832  ? 1
833  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
834  ? M0
835  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
836 
837  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
841  Number<kfold * M0 / mpair>{},
842  Number<mpair>{},
843  AK1Number));
844 
845  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
846  a_lds_block_desc,
847  make_tuple(
851  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
854  make_tuple(
856  make_tuple(
858 
859  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
860  a_lds_block_desc_permuted,
861  make_tuple(
869  Sequence<1>{},
870  Sequence<2>{},
871  Sequence<3>{},
872  Sequence<4>{},
873  Sequence<5>{}),
875  Sequence<2>{},
876  Sequence<0, 3>{},
877  Sequence<4, 5>{},
878  Sequence<6>{},
879  Sequence<7>{}));
880 
881  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
882  a_lds_block_desc_unmerged,
885  Number<KThreadWrite / kfold / KThreadReadPerm>{},
886  Number<kfold>{},
893 
894  return a_lds_block_desc_ak0_m_ak1;
895  }
896  }
897 
898  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
899  {
900  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
903  }
904 
906  {
907  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
908 
909  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
911  make_tuple(I1,
913  I1,
915 
916  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
917  }
918 
921  BlkGemmPipelineVer,
922  BlkGemmPipeSched,
923  BlockSize,
924  ADataType,
925  BDataType,
926  ComputeTypeA,
927  AccDataType,
934  ABlockTransferSrcScalarPerVector,
935  BBlockTransferSrcScalarPerVector,
936  MPerBlock,
937  NPerBlock,
938  KPerBlock,
939  ScaleBlockM,
940  ScaleBlockN,
941  ScaleBlockK,
942  MPerXdl,
943  NPerXdl,
944  MXdlPerWave,
945  NXdlPerWave,
946  KPack,
947  IsInputGemm>())>;
948 
949  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
950  {
951  // LDS allocation for A and B: be careful of alignment
952  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
953  // lds max alignment
954  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
955 
956  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
957  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
958 
959  // LDS allocation for C shuffle in LDS
960  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
962 
963  constexpr auto c_block_size =
964  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
965 
966  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
967  c_block_size * sizeof(CShuffleDataType));
968  }
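A back-of-the-envelope example of the allocation above, under placeholder assumptions (a 128x64 A tile of 1-byte fp8, a 64x128 C-shuffle tile of 4-byte floats, and ignoring the extra padding the A block descriptor may add):

constexpr int a_lds_bytes_ex = 128 * 64 * 1;   // A stage: MPerBlock * KPerBlock * sizeof(fp8)
constexpr int c_lds_bytes_ex = 64 * 128 * 4;   // C shuffle stage: shuffle tile * sizeof(float)
constexpr int lds_bytes_ex   = a_lds_bytes_ex > c_lds_bytes_ex ? a_lds_bytes_ex : c_lds_bytes_ex;
static_assert(lds_bytes_ex == 32768, "");      // the C-shuffle stage dominates in this example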
969 
970  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
971  __host__ static constexpr bool CheckValidity(const Argument& karg)
972  {
973  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
974  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
975  "Invalid tuning param!");
976 
977  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
982  {
983  if(!(karg.M % MPerBlock == 0))
984  {
985 #if DEBUG_LOG
986  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
987  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
988  << std::endl;
989 
990 #endif // DEBUG_LOG
991  return false;
992  }
993  }
994 
995  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1000  {
1001  if(!(karg.N % NPerBlock == 0))
1002  {
1003 #if DEBUG_LOG
1004  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1005  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1006  << std::endl;
1007 
1008 #endif // DEBUG_LOG
1009  return false;
1010  }
1011  }
1012 
1013  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1017  {
1018 
1019  auto K_t = karg.KBatch * KPerBlock;
1020  if(!(karg.K % K_t == 0))
1021  {
1022 #if DEBUG_LOG
1023  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1024  << karg.K << " " << __FILE__ << ":" << __LINE__
1025  << ", in function: " << __func__ << std::endl;
1026 
1027 #endif // DEBUG_LOG
1028  return false;
1029  }
1030  }
1031  else
1032  {
1033  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1034  auto K_t = karg.KBatch * KReadVec;
1035  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1036  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1037  {
1038  return false;
1039  }
1040  }
1041 
1043  {
1044  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1045  {
1046 #if DEBUG_LOG
1047  std::cout << "Arg K (" << karg.K
1048  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1049  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1050  << __LINE__ << ", in function: " << __func__ << std::endl;
1051 
1052 #endif // DEBUG_LOG
1053  return false;
1054  }
1055  }
1056  else
1057  {
1058  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1059  {
1060 #if DEBUG_LOG
1061  std::cout << "Arg M (" << karg.M
1062  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1063  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1064  << __LINE__ << ", in function: " << __func__ << std::endl;
1065 
1066 #endif // DEBUG_LOG
1067  return false;
1068  }
1069  }
1070 
1072  {
1073  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1074  {
1075 #if DEBUG_LOG
1076  std::cout << "Arg N (" << karg.N
1077  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1078  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1079  << __LINE__ << ", in function: " << __func__ << std::endl;
1080 
1081 #endif // DEBUG_LOG
1082  return false;
1083  }
1084  }
1085  else
1086  {
1087  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1088  {
1089 #if DEBUG_LOG
1090  std::cout << "Arg K (" << karg.K
1091  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1092  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1093  << __LINE__ << ", in function: " << __func__ << std::endl;
1094 
1095 #endif // DEBUG_LOG
1096  return false;
1097  }
1098  }
1099 
1101  {
1103  {
1104 #if DEBUG_LOG
1105  std::cout << "Arg N (" << karg.N
1106  << ") value is not a multiple of "
1107  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1108  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1109  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1110 
1111 #endif // DEBUG_LOG
1112  return false;
1113  }
1114  }
1115  else
1116  {
1118  {
1119 #if DEBUG_LOG
1120  std::cout << "Arg M (" << karg.M
1121  << ") value is not a multiple of "
1122  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1123  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1124  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1125 
1126 #endif // DEBUG_LOG
1127  return false;
1128  }
1129  }
1130 
1131  // check gridwise gemm pipeline
1132 #if 0
1133  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1134 
1135  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1136  {
1137  return false;
1138  }
1139 #endif
1140  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1141  return true;
1142  }
1143 
1144  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1145  {
1146  const index_t num_loop = K / KPerBlock;
1147 
1148  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1149  }
1150 
1151  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1152  {
1153  const index_t num_loop = K / KPerBlock;
1154 
1155  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1156  }
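A hedged sketch of how a host-side dispatcher would combine CheckValidity with the two helpers above to pick a kernel instantiation; enumerating the (main-loop, tail, memory-op) combinations is the device-op layer's job and is only indicated here:

template <typename Gridwise>
bool dispatch_sketch(const typename Gridwise::Argument& arg)
{
    if(!Gridwise::CheckValidity(arg))
        return false; // shape not supported by this tile configuration

    const bool has_main_loop = Gridwise::CalculateHasMainKBlockLoop(arg.K);
    const auto tail          = Gridwise::CalculateKBlockLoopTailNum(arg.K);

    if(has_main_loop && tail == TailNumber::Even)
    {
        // instantiate and launch, e.g.
        // kernel_moe_gemm<Gridwise, true, InMemoryDataOperationEnum::Set, 1, TailNumber::Even>
        //     <<<grid, block, 0, stream>>>(arg);
    }
    // ... remaining (has_main_loop, tail) combinations ...
    return true;
}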
1157 
1158  template <typename CGridDesc>
1160  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1161  {
1162  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1163  c_grid_desc_m_n,
1168 
1169  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1170  }
1171 
1172  // return block_id to C matrix tile idx (m0, n0) mapping
1173  // if arch = gfx942
1174  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1175  // NPerBlock>;
1176 
1177  template <bool HasMainKBlockLoop,
1178  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1179  TailNumber TailNum = TailNumber::Odd>
1180  __device__ static void Run(const index_t* p_sorted_token_ids,
1181  const index_t* p_sorted_expert_ids,
1182  const index_t* p_max_token_id,
1183  const ADataType* p_a_grid,
1184  const BDataType* p_b_grid,
1185  DsGridPointer& p_ds_grid,
1186  CDataType* p_c_grid,
1187  const AScaleType* p_a_scale_grid,
1188  const BScaleType* p_b_scale_grid,
1189  void* p_shared,
1190  const Problem& problem,
1191  AElementwiseOperation a_element_op,
1192  BElementwiseOperation b_element_op,
1193  CElementwiseOperation c_element_op)
1194  {
1195  ignore = b_element_op;
1196  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1197  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1198  problem.MPadded,
1199  problem.K,
1200  problem.KPadded,
1201  problem.StrideA,
1202  problem.AK0);
1203  const auto b_grid_desc_bpreshuffled =
1205  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1206  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1207  problem.MPadded,
1208  problem.N,
1209  problem.NPadded,
1210  problem.StrideC);
1211 
1212  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1213  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1214  : problem.NumTokens * problem.TopK,
1215  ScaleBlockM),
1216  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1217  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1218  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1219  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1220  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1221  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1222 
1223  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1225  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1226  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1227  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1228  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1229  if(expert_block_id * MPerBlock >= max_token_id)
1230  return;
1231  const index_t expert_id =
1232  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1233  const auto block_mn = [&]() -> std::pair<int, int> {
1234  if constexpr(NSwizzle)
1235  {
1236  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1237  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1238  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1239  const index_t expert_swizzle =
1240  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1241  const index_t bid_new = blockIdx.x - prefix_block;
1242  const index_t nid = __builtin_amdgcn_readfirstlane(
1243  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1244  const index_t mid =
1245  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1246  return {nid, mid};
1247  }
1248  else
1249  {
1250  return {blockIdx.x, blockIdx.y};
1251  }
1252  }();
1253  const index_t block_n_id = block_mn.first;
1254  const index_t block_m_id = block_mn.second;
1255  const index_t token0 =
1256  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1257 
1258  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1259  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1260  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1261  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1262  constexpr auto AKThreads = AK0Threads * AK1Threads;
1263  constexpr auto AMRepeats = MPerBlock / AMThreads;
1264  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1265 
1266  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1267  return;
1269  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1270  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1271  index_t token_offset = fused_token & 0xffffff;
1272  if constexpr(!IsInputGemm)
1273  {
1274  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1275  }
1276  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1277  });
1278  const index_t expert_stride =
1279  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1280  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1281  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1282  math::integer_divide_ceil(problem.K, ScaleBlockK));
1283 
1284  // N0, K0, Blocksize*KPack
1285  const index_t n_block_data_idx_on_grid =
1286  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1287 
1288  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1289  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1290  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1291  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1292  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1293 
1294  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1295  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1296  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1297  p_b_scale_grid + expert_id * expert_scale_stride,
1298  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1299 
1300  // A matrix in LDS memory, dst of blockwise copy
1301  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1302 
1303  // B matrix in LDS memory, dst of blockwise copy
1304  // dummy
1305  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1306  // A matrix blockwise copy
1307  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1309  AElementwiseOperation,
1313  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1314  ABlockTransferThreadClusterArrangeOrder,
1315  ADataType,
1316  LDSTypeA,
1317  decltype(a_grid_desc_ak0_m_ak1),
1318  decltype(a_block_desc_ak0_m_ak1),
1319  ABlockTransferSrcAccessOrder,
1321  ABlockTransferSrcVectorDim,
1322  2,
1323  ABlockTransferSrcScalarPerVector,
1324  ABlockTransferDstScalarPerVector_AK1,
1325  1,
1326  1,
1327  AThreadTransferSrcResetCoordinateAfterRun,
1328  true,
1329  IndexType,
1330  1,
1331  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1332  make_multi_index(0, 0, 0),
1333  a_element_op,
1334  a_block_desc_ak0_m_ak1,
1335  make_multi_index(0, 0, 0),
1337  gather_offsets);
1338 
1339  // Thread-wise copy
1340  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1341  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1342  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1343 
1344  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1345  BDataType,
1346  BDataType,
1347  decltype(b_grid_desc_bpreshuffled),
1348  decltype(b_block_desc_bk0_n_bk1),
1351  3,
1352  BBlockTransferSrcScalarPerVector,
1353  BThreadTransferSrcResetCoordinateAfterRun,
1354  true>(b_grid_desc_bpreshuffled,
1355  make_multi_index(n_block_data_idx_on_grid,
1357  0,
1358  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1359 
1360  // LDS allocation for A and B: be careful of alignment
1361  // Cast after lds
1362  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1363  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1364 
1365  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1366  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1367 
1368  // Blockwise GEMM pipeline
1369  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1370  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1371  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1372  decltype(c_thread_buf) c_thread_buf_up;
1373 
1374  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1375  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1376  KPerBlock);
1377 
1378  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1379  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1380  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1381 
1382  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1383  // ScaleSliceSizeK is first dimension in C scale for packed math
1384  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1386 
1387  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1388  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1389  auto a_thread_offset =
1390  get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
1391 
1392  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1394 
1395  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1397 
1398  // get each thread's offset in the scale tensor
1399  // A scale
1400  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1401 
1402  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1403  return;
1404  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
1405  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1406  const index_t fused_token =
1407  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1408  index_t token_offset = fused_token & 0xffffff;
1409  if constexpr(!IsInputGemm)
1410  {
1411  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1412  }
1413  scale_gather_offsets(m0) =
1414  token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1415  });
1416 
1417  auto a_scale_thread_copy =
1419  AScaleType,
1420  decltype(a_scale_grid_desc_am_ak),
1421  decltype(a_scale_thread_desc),
1424  1,
1425  ScaleSliceSizeK,
1426  1,
1427  false,
1428  MXdlPerWave>(
1429  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1430 
1431  auto b_scale_thread_copy =
1433  BScaleType,
1434  decltype(b_scale_grid_desc_bn_ak),
1435  decltype(b_scale_thread_desc),
1438  1,
1439  ScaleSliceSizeK,
1440  1,
1441  false>(
1442  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1443 
1444  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1445  constexpr auto a_scale_thread_slice_copy_step =
1446  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1447  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1448 
1449  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1450  if constexpr(IsInputGemm)
1451  {
1452  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1453  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1454  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1455  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1456  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1457  BDataType,
1458  BDataType,
1459  decltype(b_grid_desc_bpreshuffled),
1460  decltype(b_block_desc_bk0_n_bk1),
1463  3,
1464  BBlockTransferSrcScalarPerVector,
1465  BThreadTransferSrcResetCoordinateAfterRun,
1466  true>(b_grid_desc_bpreshuffled,
1467  make_multi_index(n_block_data_idx_on_grid,
1469  0,
1470  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1471  const BScaleType* p_b_scale_grid_up =
1472  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1473  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1474  p_b_scale_grid_up + expert_id * expert_scale_stride,
1475  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1476  auto b_scale_thread_copy_up =
1478  BScaleType,
1479  decltype(b_scale_grid_desc_bn_ak),
1480  decltype(b_scale_thread_desc),
1483  1,
1484  ScaleSliceSizeK,
1485  1,
1486  false>(
1487  b_scale_grid_desc_bn_ak,
1488  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1489 
1490  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1491  a_grid_desc_ak0_m_ak1,
1492  a_block_desc_ak0_m_ak1,
1493  a_blockwise_copy,
1494  a_grid_buf,
1495  a_block_buf,
1496  a_block_slice_copy_step,
1497 
1498  b_grid_desc_bpreshuffled,
1499  b_block_desc_bk0_n_bk1,
1500  b_blockwise_copy,
1501  b_blockwise_copy_up,
1502  b_grid_buf,
1503  b_grid_buf_up,
1504  b_block_buf,
1505  b_block_slice_copy_step,
1506 
1507  c_scale_thread_desc,
1508  c_thread_buf,
1509  c_thread_buf_up,
1510 
1511  a_scale_grid_desc_am_ak,
1512  a_scale_thread_desc,
1513  a_scale_thread_copy,
1514  a_scale_grid_buf,
1515  a_scale_thread_slice_copy_step,
1516 
1517  b_scale_grid_desc_bn_ak,
1518  b_scale_thread_desc,
1519  b_scale_thread_copy,
1520  b_scale_thread_copy_up,
1521  b_scale_grid_buf,
1522  b_scale_grid_buf_up,
1523  b_scale_thread_slice_copy_step,
1524 
1525  num_k_block_main_loop);
1526  }
1527  else
1528  {
1529  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1530  a_grid_desc_ak0_m_ak1,
1531  a_block_desc_ak0_m_ak1,
1532  a_blockwise_copy,
1533  a_grid_buf,
1534  a_block_buf,
1535  a_block_slice_copy_step,
1536 
1537  b_grid_desc_bpreshuffled,
1538  b_block_desc_bk0_n_bk1,
1539  b_blockwise_copy,
1540  b_grid_buf,
1541  b_block_buf,
1542  b_block_slice_copy_step,
1543 
1544  c_scale_thread_desc,
1545  c_thread_buf,
1546 
1547  a_scale_grid_desc_am_ak,
1548  a_scale_thread_desc,
1549  a_scale_thread_copy,
1550  a_scale_grid_buf,
1551  a_scale_thread_slice_copy_step,
1552 
1553  b_scale_grid_desc_bn_ak,
1554  b_scale_thread_desc,
1555  b_scale_thread_copy,
1556  b_scale_grid_buf,
1557  b_scale_thread_slice_copy_step,
1558 
1559  num_k_block_main_loop);
1560  }
1561 
1562  // shuffle C and write out
1563  {
1564  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1565  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1566  "wrong!");
1567 
1568  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1569 
1570  // transposed XDL
1571  // TODO: hacky, fix it!
1572  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1573  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1574 
1575  // TODO: hacky, fix it!
1576  // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1577  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1578  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1579 
1580  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1581  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1582  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1583  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1584  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1585  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1586  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1587  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1588 
1589  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1590  static_assert(M0 * M1 * M2 == MPerBlock);
1591  static_assert(N4 == 4);
1592  const index_t m1 = get_warp_local_1d_id() / NWave;
1593  const index_t m2 = threadIdx.x % get_warp_size() % M2;
1594 
1595  float topk_weight;
1596  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1597  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1598  if constexpr(MulRoutedWeight)
1599  {
1600  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1601  topk_weight = p_ds_grid[I0][m_pos];
1602  }
1603  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1604  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1605  constexpr index_t c_offset =
1606  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1607  make_tuple(m0, n0, n2 * N4 + n4));
1608  constexpr auto cidx = Number<c_offset>{};
1609  if constexpr(IsInputGemm) // gate-up (gu) fusion, elementwise
1610  {
1611  if constexpr(ActivationOperation == Activation::silu_and_mul)
1612  {
1613  float gate = c_thread_buf[cidx];
1614  float up = c_thread_buf_up[cidx];
1615  if constexpr(MulRoutedWeight)
1616  {
1617  gate = gate * topk_weight;
1618  up = up * topk_weight;
1619  }
1621  {
1622  gate *= 16;
1623  up *= 16;
1624  }
1626  c_thread_buf(cidx) = gate * up;
1627  }
1628  else if(ActivationOperation == Activation::gelu_and_mul)
1629  {
1630  float gate = c_thread_buf[cidx];
1631  float up = c_thread_buf_up[cidx];
1632  if constexpr(MulRoutedWeight)
1633  {
1634  gate = gate * topk_weight;
1635  up = up * topk_weight;
1636  }
1638  {
1639  gate *= 16;
1640  up *= 16;
1641  }
1643  c_thread_buf(cidx) = gate * up;
1644  }
1645  }
1646  else
1647  {
1648  if constexpr(MulRoutedWeight)
1649  {
1650  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1651  }
1652  }
1653  });
1654  });
1655  });
1656  });
1657 
1658  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1660 
1661  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1662  static_cast<CShuffleDataType*>(p_shared),
1663  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1664 
1665  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1666  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1667  make_tuple(
1670  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1671  M1, // M1 = MWave
1672  M2)), // M2 = MPerXdl
1675  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1676  N1, // N1 = NWave
1677  N2, // N2 * N3 * N4 = NPerXdl
1678  N3,
1679  N4))),
1681  make_tuple(
1683 
1684  // calculate origin of thread output tensor on global memory
1685  // blockwise GEMM c matrix starting index
1686  const auto c_thread_mtx_on_block =
1687  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1688 
1689  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1690  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1691 
1692  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1696  make_tuple(Sequence<0>{}));
1697 
1698  const auto m_thread_data_on_block_idx =
1699  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1700  make_multi_index(m_thread_data_on_block));
1701 
1702  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1704  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1706  make_tuple(Sequence<0>{}));
1707 
1708  const auto n_thread_data_on_block_idx =
1709  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1710  make_multi_index(n_thread_data_on_block));
1711 
1712  // shuffle: threadwise copy C from VGPR to LDS
1713  auto c_thread_copy_vgpr_to_lds =
1715  CShuffleDataType,
1716  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1717  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1719  Sequence<CShuffleMXdlPerWavePerShuffle,
1720  CShuffleNXdlPerWavePerShuffle,
1721  I1,
1722  I1,
1723  I1,
1724  N2,
1725  I1,
1726  N4>,
1728  7,
1729  1,
1731  1,
1732  true>{
1733  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1734  make_multi_index(0,
1735  0,
1736  m_thread_data_on_block_idx[I1],
1737  n_thread_data_on_block_idx[I1],
1738  m_thread_data_on_block_idx[I2],
1739  n_thread_data_on_block_idx[I2],
1740  n_thread_data_on_block_idx[I3],
1741  n_thread_data_on_block_idx[I4]),
1743 
1744  using EDataType = CDataType;
1745 
1746  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1747  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1748 
1749  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1751  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1752 
1753  const auto ds_grid_buf = generate_tuple(
1754  [&](auto i) {
1755  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1756  const DDataType* ptr_ = p_ds_grid[i];
1757  // Hack logic here to support different kinds of strides. TODO: fix it.
1758  // A scale: (t, 1); B scale: (E, N, 1); move the pointer to expert E.
1759  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1760  ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1761  },
1762  Number<NumDTensor>{});
1763 
1764  // tuple of reference to C/Ds tensor descriptors
1765  const auto c_ds_desc_refs = concat_tuple_of_reference(
1766  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1767  generate_tie(
1768  [&](auto i) -> const auto& // return type should be reference
1769  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1770  Number<NumDTensor>{}));
1771 
1772  // tuple of reference to C/Ds tensor descriptors
1773  const auto c_ds_buf_refs = concat_tuple_of_reference(
1774  tie(c_shuffle_block_buf),
1775  generate_tie(
1776  [&](auto i) -> const auto& // return type should be reference
1777  { return ds_grid_buf[i]; },
1778  Number<NumDTensor>{}));
1779 
1780  // tuple of starting index of C/Ds blockwise copy
1781  const auto idx_c_ds_block_begin =
1784  [&](auto) {
1785  return make_multi_index(block_m_id, 0, block_n_id, 0);
1786  // return make_multi_index(block_work_idx[I0], 0,
1787  // block_work_idx[I1], 0);
1788  },
1789  Number<NumDTensor>{}));
1790 
1791  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1792  c_grid_desc_mblock_mperblock_nblock_nperblock;
1793 
1794  using CDEBlockTransferCluster =
1795  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1796  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1797  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1798  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1800  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1802  decltype(c_ds_desc_refs),
1803  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1804  CElementwiseOperation,
1805  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1806  // support arbitrary types
1807  Sequence<1,
1808  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1809  1,
1810  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1811  CDEBlockTransferCluster,
1812  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1813  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1814  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1815  3, // index_t SrcVectorDim,
1816  3, // index_t DstVectorDim,
1817  CDEShuffleBlockTransferScalarPerVectors,
1822  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1823  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1824  IndexType,
1825  1, // ScatterDim
1826  true, // OutputScatter: false, only use scatter weights
1827  scatter_weight_idx // ScatterWeightIdx: ascale
1828  >{c_ds_desc_refs,
1829  idx_c_ds_block_begin,
1830  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1831  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1832  c_element_op};
1833 
1834  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1835  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1836  // space filling curve for threadwise C in VGPR
1837  constexpr auto sfc_c_vgpr =
1840  Sequence<CShuffleMXdlPerWavePerShuffle,
1841  CShuffleNXdlPerWavePerShuffle,
1842  1,
1843  1,
1844  1,
1845  N2,
1846  1,
1847  N4>>{};
1848 
1849  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1850 
1851  // space filling curve for shuffled blockwise C/D/E
1852  constexpr auto sfc_cde_block =
1855  Sequence<1,
1856  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1857  1,
1858  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1859 
1860  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1861  constexpr auto EMThreads =
1862  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1863  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1864  constexpr auto ENThreads =
1865  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1866  static_for<0, num_access, 1>{}([&](auto access_id) {
1867  // make sure it's safe to write to LDS
1869 
1870  auto dstidx = sfc_cde_block.GetIndex(access_id);
1871  const index_t c_token_pos =
1872  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1873  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1874  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1875  index_t token_offset = fused_token & 0xffffff;
1876  if constexpr(IsInputGemm)
1877  {
1878  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1879  }
1880  scatter_offsets(m0) = token_offset * problem.N;
1881  });
1882 
1883  block_sync_lds();
1884 
1885  // each thread write its data from VGPR to LDS
1886  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1887  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1888  c_thread_buf,
1889  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1890  c_shuffle_block_buf);
1891 
1892  // make sure it's safe to read from LDS
1893  block_sync_lds();
1894 
1895  // each block copy its data from LDS to global
1896  cde_block_copy_lds_and_global.Run(
1897  c_ds_desc_refs,
1898  c_ds_buf_refs,
1899  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1900  tie(c_grid_buf),
1901  scatter_offsets);
1902 
1903  if constexpr(access_id < num_access - 1)
1904  {
1905  constexpr auto cde_lds_and_global_step =
1906  sfc_cde_block.GetForwardStep(access_id);
1907 
1908  // move on Ds
1909  static_for<0, NumDTensor, 1>{}([&](auto i) {
1910  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1911  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1912  });
1913 
1914  // move on E
1915  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1916  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1917  I0,
1918  cde_lds_and_global_step);
1919  }
1920  });
1921  }
1922  }
1923 
1924  template <bool HasMainKBlockLoop,
1925  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1926  TailNumber TailNum = TailNumber::Odd>
1927  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1928  const index_t* p_sorted_expert_ids,
1929  const index_t* p_max_token_id,
1930  const ADataType* p_a_grid,
1931  const BDataType* p_b_grid,
1932  DsGridPointer& p_ds_grid,
1933  CDataType* p_c_grid,
1934  const AScaleType* p_a_scale_grid,
1935  const BScaleType* p_b_scale_grid,
1936  void* p_shared,
1937  void* p_shared1,
1938  const Problem& problem,
1939  AElementwiseOperation a_element_op,
1940  BElementwiseOperation b_element_op,
1941  CElementwiseOperation c_element_op)
1942  {
1943  ignore = b_element_op;
1944  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1945  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1946  problem.MPadded,
1947  problem.K,
1948  problem.KPadded,
1949  problem.StrideA,
1950  problem.AK0);
1951  const auto b_grid_desc_bpreshuffled =
1953  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1954  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1955  problem.MPadded,
1956  problem.N,
1957  problem.NPadded,
1958  problem.StrideC);
1959 
1960  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1961  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1962  : problem.NumTokens * problem.TopK,
1963  ScaleBlockM),
1964  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1965  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1966  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1967  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1968  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1969  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
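        // both scale tensors hold one value per quantization tile: (ScaleBlockM x ScaleBlockK)
        // for A and (ScaleBlockN x ScaleBlockK) for B, stored row-major with the K-block index
        // as the fastest-varying dimension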
1970  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1972  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1973  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1974  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1975  if(expert_block_id * MPerBlock >= max_token_id)
1976  return;
1977  const index_t expert_id =
1978  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1979  const auto block_mn = [&]() -> std::pair<int, int> {
1980  if constexpr(NSwizzle)
1981  {
1982  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1983  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1984  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1985  const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
1986  const index_t bid_new = blockIdx.x - prefix_block;
1987  const index_t nid = __builtin_amdgcn_readfirstlane(
1988  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1989  const index_t mid =
1990  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1991  return {nid, mid};
1992  }
1993  else
1994  {
1995  return {blockIdx.x, blockIdx.y};
1996  }
1997  }();
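        // with NSwizzle the flat block id is remapped so that an expert's N-block ids are
        // interleaved in groups of 8 across its M-block range (ecnt blocks starting at
        // ecnt_prefix); otherwise blockIdx.x / blockIdx.y map directly to the N / M block ids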
1998  const index_t block_n_id = block_mn.first;
1999  const index_t block_m_id = block_mn.second;
2000 
2001  const index_t token0 =
2002  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2003 
2004  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2005  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2006  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2007  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2008  constexpr auto AKThreads = AK0Threads * AK1Threads;
2009  constexpr auto AMRepeats = MPerBlock / AMThreads;
2010  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2011 
2012  if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2013  token0 >= problem.NumTokens)
2014  return;
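        // gather offsets select the A rows (tokens) this block loads; when IsInputGemm is false
        // the sorted id is expanded by TopK so every (token, slot) pair reads its own row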
2015  StaticallyIndexedArray<IndexType, AMRepeats>
2016  gather_offsets; //= p_sorted_token_ids[token_pos];
2017  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2018  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2019  index_t token_offset = fused_token & 0xffffff;
2020  if constexpr(!IsInputGemm)
2021  {
2022  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2023  }
2024  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2025  });
2026  const index_t expert_stride =
2027  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2028  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2029  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2030  math::integer_divide_ceil(problem.K, ScaleBlockK));
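        // per-expert strides into the B weights and B scales; the factor of 2 when IsInputGemm
        // accounts for the two fused N ranges per expert that are consumed below as the gate
        // and up projections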
2031  // N0, K0, Blocksize*KPack
2032  const index_t n_block_data_idx_on_grid =
2033  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2034 
2035  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2036  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2037  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2038  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2039  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2040 
2041  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2042  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2043  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2044  p_b_scale_grid + expert_id * expert_scale_stride,
2045  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2046 
2047  // A matrix in LDS memory, dst of blockwise copy
2048  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2049 
2050  // B block descriptor, nominal dst of the blockwise copy
2051  // dummy here: preshuffled B is staged in VGPR buffers, so no LDS is allocated for it
2052  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2053  // A matrix blockwise copy
2054  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2056  AElementwiseOperation,
2060  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2061  ABlockTransferThreadClusterArrangeOrder,
2062  ADataType,
2063  LDSTypeA,
2064  decltype(a_grid_desc_ak0_m_ak1),
2065  decltype(a_block_desc_ak0_m_ak1),
2066  ABlockTransferSrcAccessOrder,
2068  ABlockTransferSrcVectorDim,
2069  2,
2070  ABlockTransferSrcScalarPerVector,
2071  ABlockTransferDstScalarPerVector_AK1,
2072  1,
2073  1,
2074  AThreadTransferSrcResetCoordinateAfterRun,
2075  true,
2076  IndexType,
2077  1,
2078  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2079  make_multi_index(0, 0, 0),
2080  a_element_op,
2081  a_block_desc_ak0_m_ak1,
2082  make_multi_index(0, 0, 0),
2084  gather_offsets);
2085 
2086  // Thread-wise copy
2087  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2088  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2089  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2090  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2091  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2092  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2093 
2094  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2095  BDataType,
2096  BDataType,
2097  decltype(b_grid_desc_bpreshuffled),
2098  decltype(b_block_desc_bk0_n_bk1),
2101  3,
2102  BBlockTransferSrcScalarPerVector,
2103  BThreadTransferSrcResetCoordinateAfterRun,
2104  true>(b_grid_desc_bpreshuffled,
2105  make_multi_index(n_block_data_idx_on_grid,
2107  0,
2108  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2109 
2110  // LDS allocation for A and B: be careful of alignment
2111  // cast the raw __shared__ pointers to the LDS element type after allocation
2112  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2113  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2114  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2115  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2116  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
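        // ping-pong A tiles: p_shared and p_shared1 back two separate LDS buffers so the
        // pipeline can load the next K block into one buffer while the other is being consumed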
2117 
2118  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2119  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2120 
2121  // Blockwise GEMM pipeline
2122  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2123  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2124  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2125  decltype(c_thread_buf) c_thread_buf_up;
2126 
2127  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2128  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2129  KPerBlock);
2130 
2131  // scale
2132  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2133  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2134  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
2135 
2136  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
2137  // ScaleSliceSizeK is first dimension in C scale for packed math
2138  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2140 
2141  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2142  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2143  auto a_thread_offset =
2144  get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
2145 
2146  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2148 
2149  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2151 
2152  // get each thread's offset in the scale tensor
2153  // A scale
2154  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2155 
2156  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2157  return;
2158  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
2159  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2160  const index_t fused_token =
2161  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2162  index_t token_offset = fused_token & 0xffffff;
2163  if constexpr(!IsInputGemm)
2164  {
2165  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2166  }
2167  scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2168  math::integer_divide_ceil(problem.K, ScaleBlockK);
2169  });
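        // one A-scale gather offset per MXdlPerWave repeat, mapping this lane's XDL row back to
        // its token row in the A scale tensor (expanded by TopK when IsInputGemm is false)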
2170 
2171  auto a_scale_thread_copy =
2173  AScaleType,
2174  decltype(a_scale_grid_desc_am_ak),
2175  decltype(a_scale_thread_desc),
2178  1,
2179  ScaleSliceSizeK,
2180  1,
2181  false,
2182  MXdlPerWave>(
2183  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2184 
2185  auto b_scale_thread_copy =
2187  BScaleType,
2188  decltype(b_scale_grid_desc_bn_ak),
2189  decltype(b_scale_thread_desc),
2192  1,
2193  ScaleSliceSizeK,
2194  1,
2195  false>(
2196  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2197 
2198  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2199  constexpr auto a_scale_thread_slice_copy_step =
2200  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2201  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2202 
2203  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
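        // number of KPerBlock tiles covered by one ScaleBlockK window, forwarded to the
        // blockwise GEMM pipeline below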
2204  if constexpr(IsInputGemm)
2205  {
2206  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2207  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2208  p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2209  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2210  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2211  BDataType,
2212  BDataType,
2213  decltype(b_grid_desc_bpreshuffled),
2214  decltype(b_block_desc_bk0_n_bk1),
2217  3,
2218  BBlockTransferSrcScalarPerVector,
2219  BThreadTransferSrcResetCoordinateAfterRun,
2220  true>(b_grid_desc_bpreshuffled,
2221  make_multi_index(n_block_data_idx_on_grid,
2223  0,
2224  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2225  const BScaleType* p_b_scale_grid_up =
2226  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2227  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2228  p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2229  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2230  auto b_scale_thread_copy_up =
2232  BScaleType,
2233  decltype(b_scale_grid_desc_bn_ak),
2234  decltype(b_scale_thread_desc),
2237  1,
2238  ScaleSliceSizeK,
2239  1,
2240  false>(
2241  b_scale_grid_desc_bn_ak,
2242  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2243 
2244  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2245  a_grid_desc_ak0_m_ak1,
2246  a_block_desc_ak0_m_ak1,
2247  a_blockwise_copy,
2248  a_grid_buf,
2249  a_block_bufs,
2250  a_block_slice_copy_step,
2251  b_grid_desc_bpreshuffled,
2252  b_block_desc_bk0_n_bk1,
2253  b_blockwise_copy,
2254  b_blockwise_copy_up,
2255  b_grid_buf,
2256  b_grid_buf_up,
2257  b_block_bufs,
2258  b_block_slice_copy_step,
2259  c_scale_thread_desc,
2260  c_thread_buf,
2261  c_thread_buf_up,
2262  a_scale_grid_desc_am_ak,
2263  a_scale_thread_desc,
2264  a_scale_thread_copy,
2265  a_scale_grid_buf,
2266  a_scale_thread_slice_copy_step,
2267  b_scale_grid_desc_bn_ak,
2268  b_scale_thread_desc,
2269  b_scale_thread_copy,
2270  b_scale_thread_copy_up,
2271  b_scale_grid_buf,
2272  b_scale_grid_buf_up,
2273  b_scale_thread_slice_copy_step,
2274  num_k_block_main_loop);
2275  }
2276  else
2277  {
2278  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2279  a_grid_desc_ak0_m_ak1,
2280  a_block_desc_ak0_m_ak1,
2281  a_blockwise_copy,
2282  a_grid_buf,
2283  a_block_bufs,
2284  a_block_slice_copy_step,
2285  b_grid_desc_bpreshuffled,
2286  b_block_desc_bk0_n_bk1,
2287  b_blockwise_copy,
2288  b_grid_buf,
2289  b_block_bufs,
2290  b_block_slice_copy_step,
2291  c_scale_thread_desc,
2292  c_thread_buf,
2293  a_scale_grid_desc_am_ak,
2294  a_scale_thread_desc,
2295  a_scale_thread_copy,
2296  a_scale_grid_buf,
2297  a_scale_thread_slice_copy_step,
2298  b_scale_grid_desc_bn_ak,
2299  b_scale_thread_desc,
2300  b_scale_thread_copy,
2301  b_scale_grid_buf,
2302  b_scale_thread_slice_copy_step,
2303  num_k_block_main_loop);
2304  }
2305 
2306  // shuffle C and write out
2307  {
2308 
2309  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2310  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2311  "MXdlPerWave and NXdlPerWave must be divisible by the per-shuffle XDL counts");
2312 
2313  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2314 
2315  // transposed XDL
2316  // TODO: hacky, fix it!
2317  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2318  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2319 
2320  // TODO: hacky, fix it!
2321  // only used to get lengths
2322  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2323  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2324 
2325  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2326  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2327  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2328  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2329  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2330  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2331  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2332  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2333 
2334  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2335  static_assert(M0 * M1 * M2 == MPerBlock);
2336  static_assert(N4 == 4);
2337  const index_t m1 = get_warp_local_1d_id() / NWave;
2338  const index_t m2 = threadIdx.x % get_warp_size() % M2;
2339 
2340  float topk_weight;
2341  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2342  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2343  if constexpr(MulRoutedWeight)
2344  {
2345  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2346  topk_weight = p_ds_grid[I0][m_pos];
2347  }
2348  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2349  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2350  constexpr index_t c_offset =
2351  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2352  make_tuple(m0, n0, n2 * N4 + n4));
2353  constexpr auto cidx = Number<c_offset>{};
2354  if constexpr(IsInputGemm) // gate & up (gu) fusion, elementwise activation
2355  {
2356  if constexpr(ActivationOperation == Activation::silu_and_mul)
2357  {
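                            // silu_and_mul: the gate accumulator passes through SiLU and is
                            // multiplied by the up accumulator (both scaled by the routed
                            // top-k weight when MulRoutedWeight is set)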
2358  float gate = c_thread_buf[cidx];
2359  float up = c_thread_buf_up[cidx];
2360  if constexpr(MulRoutedWeight)
2361  {
2362  gate = gate * topk_weight;
2363  up = up * topk_weight;
2364  }
2366  {
2367  gate *= 16;
2368  up *= 16;
2369  }
2371  c_thread_buf(cidx) = gate * up;
2372  }
2373  else if(ActivationOperation == Activation::gelu_and_mul)
2374  {
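                            // gelu_and_mul: as above, but with GELU applied to the gate
                            // accumulator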
2375  float gate = c_thread_buf[cidx];
2376  float up = c_thread_buf_up[cidx];
2377  if constexpr(MulRoutedWeight)
2378  {
2379  gate = gate * topk_weight;
2380  up = up * topk_weight;
2381  }
2383  {
2384  gate *= 16;
2385  up *= 16;
2386  }
2388  c_thread_buf(cidx) = gate * up;
2389  }
2390  }
2391  else
2392  {
2393  if constexpr(MulRoutedWeight)
2394  {
2395  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2396  }
2397  }
2398 
2399  });
2400  });
2401  });
2402  });
2403 
2404  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2406 
2407  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2408  static_cast<CShuffleDataType*>(p_shared),
2409  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2410 
2411  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2412  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2413  make_tuple(
2416  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2417  M1, // M1 = MWave
2418  M2)), // M2 = MPerXdl
2421  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2422  N1, // N1 = NWave
2423  N2, // N2 * N3 * N4 = NPerXdl
2424  N3,
2425  N4))),
2427  make_tuple(
2429 
2430  // calculate origin of thread output tensor on global memory
2431  // blockwise GEMM c matrix starting index
2432  const auto c_thread_mtx_on_block =
2433  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2434 
2435  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2436  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2437 
2438  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2442  make_tuple(Sequence<0>{}));
2443 
2444  const auto m_thread_data_on_block_idx =
2445  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2446  make_multi_index(m_thread_data_on_block));
2447 
2448  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2450  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2452  make_tuple(Sequence<0>{}));
2453 
2454  const auto n_thread_data_on_block_idx =
2455  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2456  make_multi_index(n_thread_data_on_block));
2457 
2458  // shuffle: threadwise copy C from VGPR to LDS
2459  auto c_thread_copy_vgpr_to_lds =
2461  CShuffleDataType,
2462  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2463  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2465  Sequence<CShuffleMXdlPerWavePerShuffle,
2466  CShuffleNXdlPerWavePerShuffle,
2467  I1,
2468  I1,
2469  I1,
2470  N2,
2471  I1,
2472  N4>,
2474  7,
2475  1,
2477  1,
2478  true>{
2479  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2480  make_multi_index(0,
2481  0,
2482  m_thread_data_on_block_idx[I1],
2483  n_thread_data_on_block_idx[I1],
2484  m_thread_data_on_block_idx[I2],
2485  n_thread_data_on_block_idx[I2],
2486  n_thread_data_on_block_idx[I3],
2487  n_thread_data_on_block_idx[I4]),
2489 
2490  using EDataType = CDataType;
2491 
2492  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2493  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2494 
2495  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2497  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2498 
2499  const auto ds_grid_buf = generate_tuple(
2500  [&](auto i) {
2501  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2502  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2503  },
2504  Number<NumDTensor>{});
2505 
2506  // tuple of reference to C/Ds tensor descriptors
2507  const auto c_ds_desc_refs = concat_tuple_of_reference(
2508  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2509  generate_tie(
2510  [&](auto i) -> const auto& // return type should be reference
2511  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2512  Number<NumDTensor>{}));
2513 
2514  // tuple of reference to C/Ds tensor descriptors
2515  const auto c_ds_buf_refs = concat_tuple_of_reference(
2516  tie(c_shuffle_block_buf),
2517  generate_tie(
2518  [&](auto i) -> const auto& // return type should be reference
2519  { return ds_grid_buf[i]; },
2520  Number<NumDTensor>{}));
2521 
2522  // tuple of starting index of C/Ds blockwise copy
2523  const auto idx_c_ds_block_begin =
2526  [&](auto) {
2527  return make_multi_index(block_m_id, 0, block_n_id, 0);
2528  // return make_multi_index(block_work_idx[I0], 0,
2529  // block_work_idx[I1], 0);
2530  },
2531  Number<NumDTensor>{}));
2532 
2533  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2534  c_grid_desc_mblock_mperblock_nblock_nperblock;
2535 
2536  using CDEBlockTransferCluster =
2537  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2538  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2539  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // both cases currently use scatter-weight index 1
2540  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2542  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2544  decltype(c_ds_desc_refs),
2545  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2546  CElementwiseOperation,
2547  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2548  // support arbitrary type
2549  Sequence<1,
2550  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2551  1,
2552  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2553  CDEBlockTransferCluster,
2554  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2555  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2556  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2557  3, // index_t SrcVectorDim,
2558  3, // index_t DstVectorDim,
2559  CDEShuffleBlockTransferScalarPerVectors,
2564  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2565  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2566  IndexType,
2567  1, // ScatterDim
2568  true, // OutputScatter: false, only use scatter weights
2569  scatter_weight_idx // ScatterWeightIdx: ascale
2570  >{c_ds_desc_refs,
2571  idx_c_ds_block_begin,
2572  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2573  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2574  c_element_op};
2575 
2576  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2577  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2578  // space filling curve for threadwise C in VGPR
2579  constexpr auto sfc_c_vgpr =
2582  Sequence<CShuffleMXdlPerWavePerShuffle,
2583  CShuffleNXdlPerWavePerShuffle,
2584  1,
2585  1,
2586  1,
2587  N2,
2588  1,
2589  N4>>{};
2590 
2591  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2592 
2593  // space filling curve for shuffled blockwise C/D/E
2594  constexpr auto sfc_cde_block =
2597  Sequence<1,
2598  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2599  1,
2600  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2601 
2602  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "sfc_c_vgpr and sfc_cde_block must issue the same number of accesses");
2603  constexpr auto EMThreads =
2604  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2605  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2606  constexpr auto ENThreads =
2607  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2608  static_for<0, num_access, 1>{}([&](auto access_id) {
2609  // make sure it's safe to write to LDS
2610  StaticallyIndexedArray<IndexType, EMRepeats>
2611  scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2612 
2613  auto dstidx = sfc_cde_block.GetIndex(access_id);
2614  const index_t c_token_pos =
2615  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2616  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2617  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2618  index_t token_offset = fused_token & 0xffffff;
2619  if constexpr(IsInputGemm)
2620  {
2621  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2622  }
2623  scatter_offsets(m0) = token_offset * problem.N;
2624  });
2625 
2626  block_sync_lds();
2627 
2628  // each thread write its data from VGPR to LDS
2629  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2630  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2631  c_thread_buf,
2632  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2633  c_shuffle_block_buf);
2634 
2635  // make sure it's safe to read from LDS
2636  block_sync_lds();
2637 
2638  // each block copy its data from LDS to global
2639  cde_block_copy_lds_and_global.Run(
2640  c_ds_desc_refs,
2641  c_ds_buf_refs,
2642  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2643  tie(c_grid_buf),
2644  scatter_offsets);
2645 
2646  if constexpr(access_id < num_access - 1)
2647  {
2648  constexpr auto cde_lds_and_global_step =
2649  sfc_cde_block.GetForwardStep(access_id);
2650 
2651  // move on Ds
2652  static_for<0, NumDTensor, 1>{}([&](auto i) {
2653  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2654  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2655  });
2656 
2657  // move on E
2658  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2659  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2660  I0,
2661  cde_lds_and_global_step);
2662  }
2663  });
2664  }
2665  }
2666 };
2667 
2668 } // namespace ck
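// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this header): the kernels above consume
// "fused" sorted token ids whose low 24 bits hold the token index and whose
// high 8 bits hold the top-k slot (see the `& 0xffffff` / `>> 24` decoding in
// Run and Run_2Lds). The helper names below are made up for illustration only.
#include <cassert>
#include <cstdint>

struct FusedTokenId
{
    std::uint32_t token; // low 24 bits: row index in the token dimension
    std::uint32_t slot;  // high 8 bits: which top-k copy of the token this row is
};

inline std::uint32_t pack_fused_token_id(std::uint32_t token, std::uint32_t slot)
{
    assert(token < (1u << 24) && slot < (1u << 8));
    return (slot << 24) | token;
}

inline FusedTokenId unpack_fused_token_id(std::uint32_t fused)
{
    return {fused & 0xffffffu, fused >> 24};
}

// The gather/scatter row offsets built above are then
//     a_row_offset = row * K   (A-side gather_offsets)
//     c_row_offset = row * N   (C-side scatter_offsets)
// where row is either `token` or `token * TopK + slot`, depending on the
// IsInputGemm flag (the expanded form is used for the C scatter of the input
// GEMM and for the A gather of the non-input GEMM).
// ---------------------------------------------------------------------------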