Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp Source File

gridwise_moe_mx_gemm_bns.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
18 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
 25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
 26 // pipelines in the same kernel function. Blockers:
 27 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
 28 // operate on two LDS chunks.
 29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may not
 30 // use the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
 32 enum Activation
 33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
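// The activation selects the fused gate/up epilogue applied in Run() when IsInputGemm is true:
// silu_and_mul computes Act(gate) * up with Act = SiLU, gelu_and_mul uses Act = GeLU. A minimal
// host-side reference of the SiLU variant (illustrative only, not part of the device epilogue),
// where silu(x) = x / (1 + exp(-x)):
//
//   float silu_and_mul_ref(float gate, float up) { return gate / (1.f + expf(-gate)) * up; }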
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
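// A minimal host-side launch sketch for kernel_moe_mxgemm (illustrative only; the production path
// goes through the corresponding device operation). Grid x/y come from CalculateGridSize() and
// grid z carries the K split; the block size and the hard-coded HasMainKBlockLoop are assumptions
// for the sketch and must match the GridwiseGemm configuration in practice.
#if 0
template <typename GridwiseGemm>
void launch_moe_mxgemm(typename GridwiseGemm::Argument karg, hipStream_t stream)
{
    const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
    ignore                     = gdz; // always 1; the z dimension is used for the K split instead

    constexpr index_t block_size = 256; // assumption: equals GridwiseGemm's BlockSize

    kernel_moe_mxgemm<GridwiseGemm,
                      /*HasMainKBlockLoop=*/true, // select via CalculateHasMainKBlockLoop in practice
                      InMemoryDataOperationEnum::Set>
        <<<dim3(gdx, gdy, karg.KBatch), dim3(block_size), 0, stream>>>(karg);
}
#endif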
74 
75 #if 0
76 template <typename GridwiseGemm,
77  bool HasMainKBlockLoop,
78  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79  index_t MinimumOccupancy = 1,
80  TailNumber TailNum = TailNumber::Even>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
83 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84 #endif
85  // __attribute__((amdgpu_waves_per_eu(1, 1)))
86  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
87 {
88 #if defined(__gfx9__)
89  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91 
92  // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
93 
94  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95  karg.p_sorted_token_ids,
96  karg.p_sorted_expert_ids,
97  karg.p_max_token_id,
98  karg.p_a_grid,
99  karg.p_a_scale_grid,
100  karg.p_b_grid,
101  karg.p_b_scale_grid,
102  karg.p_ds_grid,
103  karg.p_c_grid,
104  p_shared,
105  p_shared1,
106  karg,
107  karg.a_element_op,
108  karg.b_element_op,
109  karg.c_element_op);
110 #else
111  ignore = karg;
112 #endif // end of if (defined(__gfx9__))
113 }
114 #endif
115 
116 template <typename ALayout,
117  typename BLayout,
118  typename DsLayout,
119  typename CLayout,
120  typename ADataType,
121  typename AScaleDataType,
122  typename BDataType,
123  typename BScaleDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t ScaleBlockSize,
133  index_t BlockSize,
134  index_t MPerBlock,
135  index_t NPerBlock,
136  index_t KPerBlock,
137  index_t AK1Value,
138  index_t BK1Value,
139  index_t MPerXdl,
140  index_t NPerXdl,
141  index_t MXdlPerWave,
142  index_t NXdlPerWave,
143  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
144  typename ABlockTransferThreadClusterArrangeOrder,
145  typename ABlockTransferSrcAccessOrder,
146  index_t ABlockTransferSrcVectorDim,
147  index_t ABlockTransferSrcScalarPerVector,
148  index_t ABlockTransferDstScalarPerVector_AK1,
149  bool AThreadTransferSrcResetCoordinateAfterRun,
150  index_t ABlockLdsExtraM,
151  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
152  typename BBlockTransferThreadClusterArrangeOrder,
153  typename BBlockTransferSrcAccessOrder,
154  index_t BBlockTransferSrcVectorDim,
155  index_t BBlockTransferSrcScalarPerVector,
156  index_t BBlockTransferDstScalarPerVector_BK1,
157  bool BThreadTransferSrcResetCoordinateAfterRun,
158  index_t BBlockLdsExtraN,
159  index_t CShuffleMXdlPerWavePerShuffle,
160  index_t CShuffleNXdlPerWavePerShuffle,
161  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162  typename CDEShuffleBlockTransferScalarPerVectors,
165  index_t ActivationOperation = 0,
166  bool NSwizzle = false,
167  bool IsInputGemm = true,
168  bool MulRoutedWeight = true,
169  typename IndexType = index_t,
170  typename ComputeTypeA = ADataType,
171  typename ComputeTypeB = BDataType>
173 {
174  using LDSTypeA = ADataType;
175  using LDSTypeB = BDataType;
176 
177  static constexpr auto I0 = Number<0>{};
178  static constexpr auto I1 = Number<1>{};
179  static constexpr auto I2 = Number<2>{};
180  static constexpr auto I3 = Number<3>{};
181  static constexpr auto I4 = Number<4>{};
182  static constexpr auto I5 = Number<5>{};
183  static constexpr auto I6 = Number<6>{};
184  static constexpr auto I7 = Number<7>{};
185  static constexpr auto I8 = Number<8>{};
186  static constexpr auto I9 = Number<9>{};
187 
189  CDEShuffleBlockTransferScalarPerVectors{}[I0];
190  // K1 should be Number<...>
191  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
192  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
193  static constexpr auto AK1Number = Number<AK1Value>{};
194  static constexpr auto BK1Number = Number<BK1Value>{};
195 
196  static constexpr index_t NumDTensor = DsDataType::Size();
197 
198  static constexpr auto MXdlPack = 2;
199  static constexpr auto NXdlPack = 2;
200  static constexpr auto KXdlPack = 2;
201 
202  static constexpr index_t APackedSize = packed_size_v<ADataType>;
203  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
204 
205  static constexpr bool is_single_rate_mfma = false;
206  static constexpr auto is_scale_mfma = true;
207  using mfma_selector = MfmaSelector<ComputeTypeA,
208  MPerXdl,
209  NPerXdl,
210  ComputeTypeB,
212  is_scale_mfma>;
213  static constexpr index_t KPack = math::max(
215 
216  // static constexpr index_t NumTokens = 1;
217  static constexpr index_t SortedTileSize = MPerBlock;
218 
219  static constexpr auto MakeDsGridPointer()
220  {
221  return generate_tuple(
222  [&](auto i) {
223  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
224 
225  return static_cast<const DDataType*>(nullptr);
226  },
228  }
229 
230  using DsGridPointer = decltype(MakeDsGridPointer());
231 
233 
234  __host__ static auto CalculateGridSize(index_t M, index_t N)
235  {
236  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
237  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
238  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
239  const index_t gridy = NSwizzle ? 1 : mblock;
240 
241  return std::make_tuple(gridx, gridy, 1);
242  }
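 // Example: with MPerBlock = NPerBlock = 128 (illustrative values), M = 512 sorted-token rows and
 // N = 4096 give mblock = 4 and nblock = 32. Without N-swizzling the grid is (gridx, gridy) =
 // (32, 4); with NSwizzle the two dimensions are fused into gridx = 128 and gridy = 1, and the
 // (n, m) tile pair is recovered inside Run() from the sorted expert bookkeeping.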
243 
244  __host__ static auto CalculateMPadded(index_t M)
245  {
246  return math::integer_least_multiple(M, MPerBlock);
247  }
248 
249  __host__ static auto CalculateNPadded(index_t N)
250  {
251  return math::integer_least_multiple(N, NPerBlock);
252  }
253 
254  __host__ static auto CalculateKPadded(index_t K)
255  {
256  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
257  }
258 
259  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
260  {
261  auto K_t = K_Batch * KPerBlock;
262  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
263  }
264 
265  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
266  {
267  auto K_t = K_Batch * KPerBlock;
268  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
269  }
270 
271  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
272  {
273  auto K_t = K_Batch * KPerBlock;
274  return (K + K_t - 1) / K_t * KPerBlock;
275  }
276 
277  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
278  {
279  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
280  auto K_t = K_Batch * KReadVec;
281  return (K + K_t - 1) / K_t * KReadVec;
282  }
283 
284  __host__ static auto CalculateMBlock(index_t M)
285  {
286  return math::integer_divide_ceil(M, MPerBlock);
287  }
288 
289  __host__ static auto CalculateNBlock(index_t N)
290  {
291  return math::integer_divide_ceil(N, NPerBlock);
292  }
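 // Worked example for the K-split helpers, assuming KPerBlock = 128, AK1Value = BK1Value = 16 and
 // K_Batch = 2 (all values illustrative):
 //   CalculateKPadded(1000)      = ceil(1000 / 128) * 128        = 1024
 //   CalculateKPadded(1000, 2)   = ceil(1000 / 256) * 128        = 512   // per K split
 //   CalculateAK0Padded(1000, 2) = ceil(1000 / 256) * (128 / 16) = 32
 //   CalculateKRead(1000, 2)     = ceil(1000 / 32) * 16          = 512   // K elements read per split
 // The last split consumes whatever remains of K (see SplitKBatchOffset below).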
293 
294  template <index_t MNXdlPerWave,
295  index_t MNWaves,
296  index_t MNXdlPack,
297  index_t MNPerXdl,
298  typename TileDesc_K0_MN_K1>
299  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
300  {
301  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
302  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
303 
305  TileDesc_K0_MN_K1{},
308  Number<MNWaves>{},
310  Number<MNPerXdl>{}))),
313  }
314 
315  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
316  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
317  {
318  const auto a_grid_desc_mraw_kraw = [&]() {
319  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
320  {
321  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
322  }
323  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
324  {
325  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
326  }
327  }();
328 
330 
331  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
332  GemmSpec == GemmSpecialization::MNKPadding)
333  {
334  // pad both M and K
335  const auto a_grid_desc_m_k =
336  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
338  make_right_pad_transform(K, KPad - K)),
341 
342  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
343  a_grid_desc_m_k,
348 
349  return a_grid_desc_ak0_m_ak1;
350  }
351  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
352  GemmSpec == GemmSpecialization::MNPadding)
353  {
354  // pad M, but not K
355  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
356  a_grid_desc_mraw_kraw,
358  make_right_pad_transform(M, MPad - M)),
361 
362  return a_grid_desc_ak0_m_ak1;
363  }
364  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
365  GemmSpec == GemmSpecialization::NKPadding)
366  {
367  // pad K, but not M
368  const auto a_grid_desc_m_k = transform_tensor_descriptor(
369  a_grid_desc_mraw_kraw,
373 
374  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
375  a_grid_desc_m_k,
380 
381  return a_grid_desc_ak0_m_ak1;
382  }
383  else
384  {
385  // not pad M or K
386  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
387  a_grid_desc_mraw_kraw,
392 
393  return a_grid_desc_ak0_m_ak1;
394  }
395  }
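 // For a row-major A of logical shape (M, K) with GemmSpec == GemmSpecialization::Default, the
 // returned descriptor views the same memory as a 3-D (AK0, M, AK1) tensor with K = AK0 * AK1Value:
 // K is split into AK1-wide vectors (the LDS write granularity) grouped into AK0 chunks. The
 // padding specializations first right-pad M and/or K up to MPad/KPad before the split.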
396 
397  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
398  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
399  {
400  const auto b_grid_desc_nraw_kraw = [&]() {
402  {
403  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
404  }
406  {
407  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
408  }
409  }();
410 
412 
413  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
414  GemmSpec != GemmSpecialization::Default),
415  "pk_i4_t does not support padding");
416  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
417  GemmSpec != GemmSpecialization::Default),
418  "f4x2_pk_t does not support padding");
419 
420  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
421  GemmSpec == GemmSpecialization::MNKPadding)
422  {
423  // pad both N and K
424  const auto b_grid_desc_n_k =
425  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
427  make_right_pad_transform(K, KPad - K)),
430 
431  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
432  b_grid_desc_n_k,
437 
438  return b_grid_desc_bk0_n_bk1;
439  }
440  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
441  GemmSpec == GemmSpecialization::MNPadding)
442  {
443  // pad N, but not K
444  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
445  b_grid_desc_nraw_kraw,
447  make_right_pad_transform(N, NPad - N)),
450 
451  return b_grid_desc_bk0_n_bk1;
452  }
453  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
454  GemmSpec == GemmSpecialization::MKPadding)
455  {
456  // pad K, but not N
457  const auto b_grid_desc_n_k = transform_tensor_descriptor(
458  b_grid_desc_nraw_kraw,
462 
463  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
464  b_grid_desc_n_k,
469 
470  return b_grid_desc_bk0_n_bk1;
471  }
472  else
473  {
474  // not pad N or K
475  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
476  b_grid_desc_nraw_kraw,
481 
482  return b_grid_desc_bk0_n_bk1;
483  }
484  }
485 
486  template <typename ABlockDesc_AK0_M_AK1>
487  __host__ __device__ static constexpr auto
488  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
489  {
490  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
491 
492  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
493  ABlockDesc_AK0_M_AK1{});
494  }
495 
496  template <typename BBlockDesc_BK0_N_BK1>
497  __host__ __device__ static constexpr auto
498  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
499  {
500  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
501 
502  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
503  BBlockDesc_BK0_N_BK1{});
504  }
505 
506  template <typename ELayout>
507  __host__ __device__ static auto MakeCGridDescriptor_M_N(
508  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
509  {
510  const auto c_grid_desc_mraw_nraw = [&]() {
512  {
513  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
514  }
516  {
517  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
518  }
519  }();
520 
521  // pad M and N
522  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
524  make_right_pad_transform(N, NPad - N)),
527  }
528 
529  template <typename DLayout>
530  __host__ __device__ static auto
532  {
533  const auto c_grid_desc_mraw_nraw = [&]() {
535  {
536  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
537  }
539  {
540  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
541  }
542  }();
543 
544  // pad M and N
545  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
547  make_right_pad_transform(N, NPad - N)),
550  }
551 
552  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
553  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
554  {
555  return generate_tuple(
556  [&](auto i) {
557  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
558  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
559  },
561  }
562 
563  template <typename DsGridDesc>
565  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
566  {
567  return generate_tuple(
568  [&](auto i) {
570  ds_grid_desc_m_n[i], MBlock, NBlock);
571  },
573  }
574 
575  struct Problem
576  {
577  __host__ Problem(index_t NumTokens_,
578  index_t TopK_,
579  index_t M_,
580  index_t N_,
581  index_t K_,
582  index_t StrideA_,
583  index_t StrideScaleA_,
584  index_t StrideB_,
585  index_t StrideScaleB_,
586  std::array<index_t, NumDTensor> StrideDs_,
587  index_t StrideC_,
588  index_t KBatch_)
589  : NumTokens{NumTokens_},
590  TopK{TopK_},
591  M{M_},
592  N{N_},
593  K{K_},
594  StrideA{StrideA_},
595  StrideScaleA{StrideScaleA_},
596  StrideB{StrideB_},
597  StrideScaleB{StrideScaleB_},
598  StrideDs{StrideDs_},
599  StrideC{StrideC_},
600  KBatch{KBatch_},
603  KRead{CalculateKRead(K_, KBatch_)},
604  KPadded{CalculateKPadded(K_, KBatch_)},
605  AK0{CalculateAK0Padded(K_, KBatch_)},
606  BK0{CalculateBK0Padded(K_, KBatch_)},
607  MBlock{CalculateMBlock(M_)},
609  {
610  }
611 
612  __host__ void Print() const
613  {
614  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
615  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
616  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
617  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
618  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
619  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
620  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
621  << ", " << "NBlock: " << NBlock << "}" << std::endl;
622  }
623 
633  std::array<index_t, NumDTensor> StrideDs;
644  };
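 // The derived members (MPadded, NPadded, KRead, KPadded, AK0, BK0, MBlock, NBlock) are computed
 // once on the host from the raw sizes via the Calculate* helpers above. A minimal construction
 // sketch with illustrative sizes (in practice Problem is filled in through Argument, which first
 // divides K and the A/B strides by the packed element sizes):
 //
 //   Problem problem{/*NumTokens=*/512, /*TopK=*/2, /*M=*/1024, /*N=*/4096, /*K=*/1024,
 //                   /*StrideA=*/1024, /*StrideScaleA=*/32, /*StrideB=*/1024, /*StrideScaleB=*/32,
 //                   /*StrideDs=*/{}, /*StrideC=*/4096, /*KBatch=*/1};
 //   problem.Print();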
645 
646  // Argument
 647  struct Argument : public Problem
 648  {
649  __host__ Argument(const index_t* p_sorted_token_ids_,
650  const index_t* p_sorted_expert_ids_,
651  const index_t* p_max_token_id_,
652  const ADataType* p_a_grid_,
653  const AScaleDataType* p_a_scale_grid_,
654  const BDataType* p_b_grid_,
655  const BScaleDataType* p_b_scale_grid_,
656  std::array<const void*, NumDTensor> p_ds_grid_,
657  CDataType* p_c_grid_,
658  index_t NumTokens_,
659  index_t TopK_,
660  index_t M_,
661  index_t N_,
662  index_t K_,
663  index_t StrideA_,
664  index_t StrideScaleA_,
665  index_t StrideB_,
666  index_t StrideScaleB_,
667  std::array<index_t, NumDTensor> StrideDs_,
668  index_t StrideC_,
669  index_t k_batch_,
670  AElementwiseOperation a_element_op_,
671  BElementwiseOperation b_element_op_,
672  CElementwiseOperation c_element_op_)
673  : Problem{NumTokens_,
674  TopK_,
675  M_,
676  N_,
677  K_ / APackedSize,
678  StrideA_ / APackedSize,
679  StrideScaleA_,
680  StrideB_ / BPackedSize,
681  StrideScaleB_,
682  StrideDs_,
683  StrideC_,
684  k_batch_},
685  p_sorted_token_ids{p_sorted_token_ids_},
686  p_sorted_expert_ids{p_sorted_expert_ids_},
687  p_max_token_id{p_max_token_id_},
688  p_a_grid{p_a_grid_},
689  p_a_scale_grid{p_a_scale_grid_},
690  p_b_grid{p_b_grid_},
691  p_b_scale_grid{p_b_scale_grid_},
692  p_ds_grid{},
693  p_c_grid{p_c_grid_},
694  a_element_op{a_element_op_},
695  b_element_op{b_element_op_},
696  c_element_op{c_element_op_}
697  {
698 
699  // populate pointer, desc for Ds
700  static_for<0, NumDTensor, 1>{}([&](auto i) {
701  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
702 
703  // D pointer
704  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
705  });
706  }
707 
711  const ADataType* p_a_grid;
712  const AScaleDataType* p_a_scale_grid;
713  const BDataType* p_b_grid;
714  const BScaleDataType* p_b_scale_grid;
716  CDataType* p_c_grid;
717 
718  const AElementwiseOperation a_element_op;
719  const BElementwiseOperation b_element_op;
720  const CElementwiseOperation c_element_op;
721  };
722 
724  {
725  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
726  {
727  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
728  {
729  a_k_split_offset = k_id * karg.KRead;
730  }
731  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
732  {
733  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
734  }
735 
736  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
737  {
738  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
739  }
740  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
741  {
742  // KPack * NLane * KLane * K0 * N0
743  b_k_split_offset = k_id * karg.KRead;
744  }
745 
746  // Calculate A scale offset
747  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
748  {
749  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
750  }
751  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
752  {
754  k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
755  }
756 
757  // Calculate B scale offset
758  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
759  {
761  k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
762  }
763  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
764  {
765  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
766  }
767 
768  if(k_id < karg.KBatch - 1)
769  {
770  karg.K = karg.KRead;
771  }
772  else
773  {
774  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
775  }
776  }
777 
782  };
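 // Example: with row-major A and column-major B, KRead = 512 and KBatch = 2, the workgroups with
 // k_id = 0 read A/B elements [0, 512) along K and keep karg.K = 512, while the workgroups with
 // k_id = 1 start at K offset 512 and keep the remainder K - 512. The scale offsets advance by
 // KRead / (ScaleBlockSize / PackedSize) entries, i.e. one scale per ScaleBlockSize packed
 // elements of K.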
783 
784  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
785  {
786  // A matrix in LDS memory, dst of blockwise copy
787  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
788  {
792  }
 793  // The xor tensor transformation requires more (unnecessary) VGPRs, which can cause register
 794  // spills in some cases.
796  {
797  constexpr auto a_lds_block_desc =
800 
801  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
802  a_lds_block_desc,
808 
809  return a_lds_block_desc_permuted;
810  }
811  else // ColumnMajor A
812  {
 813  // The kfold and mpair dimensions are not always required.
 814  // More dimensions in the merge_transform increase the difficulty of generating immediate
 815  // offsets for the compiler.
816  constexpr auto WaveSize = 64;
817  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
818  constexpr auto M1 = MPerBlock / M0;
819 
820  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
821  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
822  constexpr auto KThreadRead = WaveSize / MPerXdl;
823  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
824 
825  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
826  ? 1
827  : 128 / (AK1Number * M0 * sizeof(ADataType));
828  constexpr auto KThreadReadPerm =
829  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
830  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
831  : KThreadRead;
832 
 833  // 1 <= mpair <= M0
834  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
835  ? 1
836  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
837  ? M0
838  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
839 
840  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
844  Number<kfold * M0 / mpair>{},
845  Number<mpair>{},
846  AK1Number));
847 
848  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
849  a_lds_block_desc,
850  make_tuple(
854  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
857  make_tuple(
859  make_tuple(
861 
862  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
863  a_lds_block_desc_permuted,
864  make_tuple(
872  Sequence<1>{},
873  Sequence<2>{},
874  Sequence<3>{},
875  Sequence<4>{},
876  Sequence<5>{}),
878  Sequence<2>{},
879  Sequence<0, 3>{},
880  Sequence<4, 5>{},
881  Sequence<6>{},
882  Sequence<7>{}));
883 
884  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
885  a_lds_block_desc_unmerged,
888  Number<KThreadWrite / kfold / KThreadReadPerm>{},
889  Number<kfold>{},
896 
897  return a_lds_block_desc_ak0_m_ak1;
898  }
899  }
900 
901  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
902  {
903  // B matrix in LDS memory, dst of blockwise copy
904  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
905  {
909  }
911  {
912  // NLdsLayer * K0 as logical Bank
913  constexpr auto b_lds_block_desc =
916 
917  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
918  b_lds_block_desc,
924 
925  return b_lds_block_desc_permuted;
926  }
927  else // RowMajor B
928  {
929  constexpr auto WaveSize = 64;
930  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
931  constexpr auto N1 = NPerBlock / N0;
932 
933  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
934  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
935  constexpr auto KThreadRead = WaveSize / NPerXdl;
936  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
937 
938  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
939  ? 1
940  : 128 / (BK1Number * N0 * sizeof(BDataType));
941  constexpr auto KThreadReadPerm =
942  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
943  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
944  : KThreadRead;
945 
 946  // 1 <= npair <= N0
947  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
948  ? 1
949  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
950  ? N0
951  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
952 
953  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
957  Number<kfold * N0 / npair>{},
958  Number<npair>{},
959  BK1Number));
960 
961  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
962  b_lds_block_desc,
963  make_tuple(
967  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
970  make_tuple(
972  make_tuple(
974 
975  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
976  b_lds_block_desc_permuted,
977  make_tuple(
985  Sequence<1>{},
986  Sequence<2>{},
987  Sequence<3>{},
988  Sequence<4>{},
989  Sequence<5>{}),
991  Sequence<2>{},
992  Sequence<0, 3>{},
993  Sequence<4, 5>{},
994  Sequence<6>{},
995  Sequence<7>{}));
996 
997  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
998  b_lds_block_desc_unmerged,
1001  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1002  Number<kfold>{},
1009 
1010  return b_lds_block_desc_bk0_n_bk1;
1011  }
1012  }
1013 
1015  {
1016  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1017  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1018 
1019  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1021  make_tuple(I1,
1023  I1,
1025 
1026  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1027  }
1028 
1031  BlkGemmPipelineVer,
1032  BlkGemmPipeSched,
1033  BlockSize,
1034  ScaleBlockSize,
1035  ADataType,
1036  AScaleDataType,
1037  BDataType,
1038  BScaleDataType,
1039  ComputeTypeA,
1040  AccDataType,
1047  ABlockTransferSrcScalarPerVector,
1048  BBlockTransferSrcScalarPerVector,
1049  MPerBlock,
1050  NPerBlock,
1051  KPerBlock,
1052  MPerXdl,
1053  NPerXdl,
1054  MXdlPerWave,
1055  NXdlPerWave,
1056  KPack,
1057  IsInputGemm>())>;
1058 
1059  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1060  {
1061  // LDS allocation for A and B: be careful of alignment
1062  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1063  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1064 
1065  // lds max alignment
1066  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1067 
1068  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1069  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1070 
1071  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1072  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1073 
1074  // LDS allocation for C shuffle in LDS
1075  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1077 
1078  constexpr auto c_block_size =
1079  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1080 
1081  if constexpr(IsInputGemm)
1082  {
1083  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1084  b_block_space_size_aligned * sizeof(BDataType)) *
1085  2,
1086  c_block_size * sizeof(CShuffleDataType));
1087  }
1088  else
1089  {
1090  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1091  b_block_space_size_aligned * sizeof(BDataType)),
1092  c_block_size * sizeof(CShuffleDataType));
1093  }
1094  }
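 // The LDS requirement is the maximum of (a) the A + B tile storage used by the GEMM pipeline,
 // doubled when IsInputGemm is true so that the extra up-projection B tile staged by Run() also
 // fits, and (b) the C-shuffle staging buffer. The two phases reuse the same allocation, which is
 // why the sizes are combined with math::max rather than summed.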
1095 
 1096  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1097  __host__ static constexpr bool CheckValidity(const Argument& karg)
1098  {
1099  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1100  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1101  "Invalid tuning param!");
1102 
1103  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1104  "KPerBlock should be multiple of ScaleBlockSize");
1105 
1106  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1111  {
1112  if(!(karg.M % MPerBlock == 0))
1113  {
1114  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1115  {
1116  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1117  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1118  << std::endl;
1119  }
1120  return false;
1121  }
1122  }
1123 
1124  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1129  {
1130  if(!(karg.N % NPerBlock == 0))
1131  {
1132  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1133  {
1134  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1135  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1136  << std::endl;
1137  }
1138  return false;
1139  }
1140  }
1141 
1142  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1146  {
1147  auto K_t = karg.KBatch * KPerBlock;
1148  if(!(karg.K % K_t == 0))
1149  {
1150  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1151  {
1152  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1153  << karg.K << " " << __FILE__ << ":" << __LINE__
1154  << ", in function: " << __func__ << std::endl;
1155  }
1156  return false;
1157  }
1158  }
1159  else
1160  {
1161  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1162  auto K_t = karg.KBatch * KReadVec;
1163  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1164  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1165  {
1166  return false;
1167  }
1168  }
1169 
1171  {
1172  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1173  {
1174  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1175  {
1176  std::cout << "Arg K (" << karg.K
1177  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1178  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1179  << __LINE__ << ", in function: " << __func__ << std::endl;
1180  }
1181  return false;
1182  }
1183  }
1184  else
1185  {
1186  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1187  {
1188  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189  {
1190  std::cout << "Arg M (" << karg.M
1191  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1192  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1193  << __LINE__ << ", in function: " << __func__ << std::endl;
1194  }
1195  return false;
1196  }
1197  }
1198 
1200  {
1201  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1202  {
1203  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1204  {
1205  std::cout << "Arg N (" << karg.N
1206  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1207  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1208  << __LINE__ << ", in function: " << __func__ << std::endl;
1209  }
1210  return false;
1211  }
1212  }
1213  else
1214  {
1215  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1216  {
1217  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1218  {
1219  std::cout << "Arg K (" << karg.K
1220  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1221  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1222  << __LINE__ << ", in function: " << __func__ << std::endl;
1223  }
1224  return false;
1225  }
1226  }
1227 
1229  {
1231  {
1232  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1233  {
1234  std::cout << "Arg N (" << karg.N
1235  << ") value is not a multiple of "
1236  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1238  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1239  << std::endl;
1240  }
1241  return false;
1242  }
1243  }
1244  else
1245  {
1247  {
1248  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1249  {
1250  std::cout << "Arg M (" << karg.M
1251  << ") value is not a multiple of "
1252  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1254  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1255  << std::endl;
1256 
1257  return false;
1258  }
1259  }
1260  }
1261 
1262  // check gridwise gemm pipeline
1263 #if 0
1264  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1265 
1266  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1267  {
1268  return false;
1269  }
1270 #endif
1271  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1272  return true;
1273  }
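 // Host-side usage sketch (illustrative): the device-op layer is expected to reject unsupported
 // arguments before any kernel instance is selected, e.g.
 //
 //   if(!GridwiseGemm::CheckValidity(karg)) { /* report unsupported argument and bail out */ }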
1274 
1275  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1276  {
1277  const index_t num_loop = K / KPerBlock;
1278 
1279  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1280  }
1281 
1282  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1283  {
1284  const index_t num_loop = K / KPerBlock;
1285 
1286  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1287  }
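 // The two helpers above drive the host-side template dispatch of kernel_moe_mxgemm. A minimal
 // sketch, where k_per_split is a placeholder for the K length a single workgroup iterates over:
 //
 //   const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(k_per_split);
 //   const auto tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(k_per_split);
 //   // ...then instantiate kernel_moe_mxgemm<GridwiseGemm, has_main_loop, ..., tail_num> via an
 //   // if/else (or switch) over the two runtime values, since both are template parameters.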
1288 
1289  template <typename CGridDesc>
1290  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1291  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1292  {
1293  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1294  c_grid_desc_m_n,
1299 
1300  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1301  }
1302 
1303  // return block_id to C matrix tile idx (m0, n0) mapping
1304  // if arch = gfx942
1305  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1306  // NPerBlock>;
1307 
1309  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1310  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
1311  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1312  "A scale pack data type too large!");
1313  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1314  "B scale pack data type too large!");
1315 
1316  template <bool HasMainKBlockLoop,
1317  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1318  TailNumber TailNum = TailNumber::Odd>
1319  __device__ static void Run(const index_t* p_sorted_token_ids,
1320  const index_t* p_sorted_expert_ids,
1321  const index_t* p_max_token_id,
1322  const ADataType* p_a_grid,
1323  const AScaleDataType* p_a_scale_grid,
1324  const BDataType* p_b_grid,
1325  const BScaleDataType* p_b_scale_grid,
1326  DsGridPointer& p_ds_grid,
1327  CDataType* p_c_grid,
1328  void* p_shared,
1329  const Problem& problem,
1330  AElementwiseOperation a_element_op,
1331  BElementwiseOperation b_element_op,
1332  CElementwiseOperation c_element_op)
1333  {
1334  ignore = b_element_op;
1335  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1336  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1337  problem.MPadded,
1338  problem.K,
1339  problem.KPadded,
1340  problem.StrideA,
1341  problem.AK0);
1342  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1343  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1344  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1345  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1346  problem.MPadded,
1347  problem.N,
1348  problem.NPadded,
1349  problem.StrideC);
1350 
1351  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1352  make_tuple(problem.M / (MXdlPack * MPerXdl),
1353  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1354  (KXdlPack * 64 / MPerXdl),
1355  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1356 
1357  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1358  make_tuple(problem.N / (NXdlPack * NPerXdl),
1359  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1360  (KXdlPack * 64 / NPerXdl),
1361  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1362 
1363  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1365  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1366 
1367  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1368  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1369  if(expert_block_id * MPerBlock >= max_token_id)
1370  return;
1371  const index_t expert_id =
1372  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1373 
1374  const auto block_mn = [&]() -> std::pair<int, int> {
1375  if constexpr(NSwizzle)
1376  {
1377  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1378  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1379  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1380  const index_t expert_swizzle =
1381  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1382  const index_t bid_new = blockIdx.x - prefix_block;
1383  const index_t nid = __builtin_amdgcn_readfirstlane(
1384  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1385  const index_t mid =
1386  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1387  return {nid, mid};
1388  }
1389  else
1390  {
1391  return {blockIdx.x, blockIdx.y};
1392  }
1393  }();
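 // With NSwizzle enabled, the flat blockIdx.x is decoded back into an (n, m) tile pair: blocks are
 // laid out in groups of 8 along N and interleaved over the rows owned by the current expert
 // (expert_swizzle), which tends to keep neighbouring workgroups on nearby C tiles; without
 // swizzling the mapping is simply (blockIdx.x, blockIdx.y).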
1394 
1395  const index_t block_n_id = block_mn.first;
1396  const index_t block_m_id = block_mn.second;
1397  const index_t token0 =
1398  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1399 
1400  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1401  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1402  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1403  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1404  constexpr auto AKThreads = AK0Threads * AK1Threads;
1405  constexpr auto AMRepeats = MPerBlock / AMThreads;
1406  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1407 
1408  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1409  return;
1411  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1412  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1413  index_t token_offset = fused_token & 0xffffff;
1414  if constexpr(!IsInputGemm)
1415  {
1416  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1417  }
1418  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1419  });
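 // Each p_sorted_token_ids entry packs the top-k slot in its upper 8 bits and the token index in
 // the lower 24 bits. For the gate/up GEMM (IsInputGemm) the gather offset addresses the token row
 // of A directly; for the second GEMM (!IsInputGemm) the row becomes token * TopK + topk_slot,
 // because the first GEMM produced one row per (token, expert) pair.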
1420 
1421  const index_t expert_stride =
1422  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1423  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1424  problem.N * (IsInputGemm ? 2 : 1) *
1425  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1426 
1427  // N0, K0, Blocksize*KPack
1428  const index_t n_block_data_idx_on_grid =
1429  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1430 
 1431  // Grid buffer creation
1432  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1433  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1434  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1435  p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1436 
1437  // A, B scale buffer
1438  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1439  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1440  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1441  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1442  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1443 
1444  // lds max alignment
1445  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1446 
1447  // A matrix in LDS memory, dst of blockwise copy
1448  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1449 
1450  // B matrix in LDS memory, dst of blockwise copy
1451  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1452 
1453  // A matrix blockwise copy
1454  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1456  AElementwiseOperation,
1460  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1461  ABlockTransferThreadClusterArrangeOrder,
1462  ADataType,
1463  ADataType,
1464  decltype(a_grid_desc_ak0_m_ak1),
1465  decltype(a_block_desc_ak0_m_ak1),
1466  ABlockTransferSrcAccessOrder,
1468  ABlockTransferSrcVectorDim,
1469  2,
1470  ABlockTransferSrcScalarPerVector,
1471  ABlockTransferDstScalarPerVector_AK1,
1472  1,
1473  1,
1474  AThreadTransferSrcResetCoordinateAfterRun,
1475  true,
1476  IndexType,
1477  1,
1478  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1479  make_multi_index(0, 0, 0),
1480  a_element_op,
1481  a_block_desc_ak0_m_ak1,
1482  make_multi_index(0, 0, 0),
1484  gather_offsets);
1485 
1486  // B matrix blockwise copy
1487  auto b_blockwise_copy =
1489  BElementwiseOperation,
1493  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1494  BBlockTransferThreadClusterArrangeOrder,
1495  BDataType,
1496  BDataType,
1497  decltype(b_grid_desc_bk0_n_bk1),
1498  decltype(b_block_desc_bk0_n_bk1),
1499  BBlockTransferSrcAccessOrder,
1501  BBlockTransferSrcVectorDim,
1502  2,
1503  BBlockTransferSrcScalarPerVector,
1504  BBlockTransferDstScalarPerVector_BK1,
1505  1,
1506  1,
1507  BThreadTransferSrcResetCoordinateAfterRun,
1508  true,
1509  BlockwiseGemmPipe::GlobalBufferNum>(
1510  b_grid_desc_bk0_n_bk1,
1511  make_multi_index(0, n_block_data_idx_on_grid, 0),
1512  b_element_op,
1513  b_block_desc_bk0_n_bk1,
1514  make_multi_index(0, 0, 0),
1516 
1517  // LDS allocation for A and B: be careful of alignment
1518  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1519  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1520 
1521  // Cast after lds
1522  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1523  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1524 
1525  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1526  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1527  a_block_space_size_aligned * sizeof(ADataType)),
1528  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1529 
1530  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1531  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1532 
1533  // Blockwise GEMM pipeline
1534  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1535  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1536  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1537  decltype(c_thread_buf) c_thread_buf_up;
1538 
1540  float,
1541  c_thread_buf.num_of_v_,
1542  c_thread_buf.s_per_v,
1543  true>
1544  c_thread_buf_fp32;
1545 
1546  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1547  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1548  KPerBlock);
1549 
1550  // a and b scale processing
1551  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1552  const auto waveId_m = wave_idx[I0];
1553  const auto waveId_n = wave_idx[I1];
1554 
1555  auto thread_offset_shuffled =
1556  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1557 
1558  auto a_thread_offset_m = waveId_m;
1559 
1560  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1561  AScaleDataType,
1562  AScaleDataType,
1563  decltype(a_scale_grid_desc_am_ak),
1564  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1565  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1566  Sequence<0, 1, 2>, // DimAccessOrder
1567  2, // SrcVectorDim
1568  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1569  1, // SrcScalarStrideInVector
1570  true>(a_scale_grid_desc_am_ak,
1571  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1572  0,
1573  thread_offset_shuffled / scale_pack_size_a));
1574 
1575  // B scale load
1576  auto b_thread_offset_n = waveId_n;
1577 
1578  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1579  BScaleDataType,
1580  BScaleDataType,
1581  decltype(b_scale_grid_desc_bn_ak),
1582  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1583  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1584  Sequence<0, 1, 2>, // DimAccessOrder
1585  2, // SrcVectorDim
1586  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1587  1, // SrcScalarStrideInVector
1588  true>(b_scale_grid_desc_bn_ak,
1589  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1590  0,
1591  thread_offset_shuffled / scale_pack_size_b));
1592 
1593  if constexpr(IsInputGemm)
1594  {
1595  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1596  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1597  auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1598  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1599  a_block_space_size_aligned * sizeof(ADataType) +
1600  b_block_space_size_aligned * sizeof(BDataType)),
1601  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1602 
1603  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1604  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1605  p_b_grid_up + expert_id * expert_stride,
1606  b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1607 
1608  auto b_blockwise_copy_up =
1610  BElementwiseOperation,
1614  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1615  BBlockTransferThreadClusterArrangeOrder,
1616  BDataType,
1617  BDataType,
1618  decltype(b_grid_desc_bk0_n_bk1),
1619  decltype(b_block_desc_bk0_n_bk1),
1620  BBlockTransferSrcAccessOrder,
1622  BBlockTransferSrcVectorDim,
1623  2,
1624  BBlockTransferSrcScalarPerVector,
1625  BBlockTransferDstScalarPerVector_BK1,
1626  1,
1627  1,
1628  BThreadTransferSrcResetCoordinateAfterRun,
1629  true,
1630  BlockwiseGemmPipe::GlobalBufferNum>(
1631  b_grid_desc_bk0_n_bk1,
1632  make_multi_index(0, n_block_data_idx_on_grid, 0),
1633  b_element_op,
1634  b_block_desc_bk0_n_bk1,
1635  make_multi_index(0, 0, 0),
1637 
1638  const BScaleDataType* p_b_scale_grid_up =
1639  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1640  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1641  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1642  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1643 
1644  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1645  BScaleDataType,
1646  BScaleDataType,
1647  decltype(b_scale_grid_desc_bn_ak),
1648  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1649  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1650  Sequence<0, 1, 2>, // DimAccessOrder
1651  2, // SrcVectorDim
1652  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1653  1, // SrcScalarStrideInVector
1654  true>(
1655  b_scale_grid_desc_bn_ak,
1656  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1657  0,
1658  thread_offset_shuffled / scale_pack_size_b));
1659 
1660  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1661  // A
1662  a_grid_desc_ak0_m_ak1,
1663  a_block_desc_ak0_m_ak1,
1664  a_blockwise_copy,
1665  a_grid_buf,
1666  a_block_buf,
1667  a_block_slice_copy_step,
1668  // Gate and Up
1669  b_grid_desc_bk0_n_bk1,
1670  b_block_desc_bk0_n_bk1,
1671  b_blockwise_copy,
1672  b_blockwise_copy_up,
1673  b_grid_buf,
1674  b_grid_buf_up,
1675  b_block_buf,
1676  b_block_buf_up,
1677  b_block_slice_copy_step,
1678  // C
1679  c_thread_buf,
1680  c_thread_buf_up,
1681  // A scale
1682  a_scale_grid_desc_am_ak,
1683  a_scale_thread_copy,
1684  a_scale_grid_buf,
1685  // Gate and Up scale
1686  b_scale_grid_desc_bn_ak,
1687  b_scale_thread_copy,
1688  b_scale_thread_copy_up,
1689  b_scale_grid_buf,
1690  b_scale_grid_buf_up,
1691  num_k_block_main_loop);
1692  }
1693  else
1694  {
1695  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1696  a_grid_desc_ak0_m_ak1, // A
1697  a_block_desc_ak0_m_ak1,
1698  a_blockwise_copy,
1699  a_grid_buf,
1700  a_block_buf,
1701  a_block_slice_copy_step,
1702  b_grid_desc_bk0_n_bk1, // B
1703  b_block_desc_bk0_n_bk1,
1704  b_blockwise_copy,
1705  b_grid_buf,
1706  b_block_buf,
1707  b_block_slice_copy_step,
1708  c_thread_buf, // C
1709  a_scale_grid_desc_am_ak, // A scale
1710  a_scale_thread_copy,
1711  a_scale_grid_buf,
1712  b_scale_grid_desc_bn_ak, // B scale
1713  b_scale_thread_copy,
1714  b_scale_grid_buf,
1715  num_k_block_main_loop);
1716  }
1717 
1718  // shuffle C and write out
1719  {
1720  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1721  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1722  "wrong!");
1723  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1724  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1725  "wrong!");
1726 
1727  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1728  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1729 
1730  // TODO: hacky, fix it!
1731  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1732  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1733 
1734  // TODO: hacky, fix it!
1735  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1736  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1737  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1738 
1739  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1740  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1741  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1742  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1743  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1744  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1745  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1746  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1747  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1748  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1749 
1750  // mul scales
1751  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1752  static_assert(M5 == 4);
1753  const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1754  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1755 
1756  vector_type<float, 4> topk_weights; // for gemm2 only
1757  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1758  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1759  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1760  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1761  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1762  const index_t m_pos = block_m_id * MPerBlock +
1763  m0 * M2 * M1 * M3 * M4 * M5 +
1764  m1 * M2 * M3 * M4 * M5 +
1765  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1766 
1767  if constexpr(MulRoutedWeight)
1768  {
1769  topk_weights =
1770  *c_style_pointer_cast<const vector_type<float, M5>*>(
1771  p_ds_grid[I2] + m_pos);
1772  }
1773  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1774  constexpr index_t c_offset =
1775  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1776  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1777  constexpr auto cidx = Number<c_offset>{};
1778 
 1779  if constexpr(IsInputGemm) // gate/up (gu) fusion
1780  {
1781  if constexpr(ActivationOperation ==
1783  {
1784  float gate = c_thread_buf[cidx];
1785  float up = c_thread_buf_up[cidx];
1786  if constexpr(MulRoutedWeight)
1787  {
1788  gate = gate * topk_weights.AsType<float>()[m5];
1789  up = up * topk_weights.AsType<float>()[m5];
1790  }
1792  c_thread_buf_fp32(cidx) = gate * up;
1793  }
1794  else if(ActivationOperation == Activation::gelu_and_mul)
1795  {
1796  float gate = c_thread_buf[cidx];
1797  float up = c_thread_buf_up[cidx];
1798  if constexpr(MulRoutedWeight)
1799  {
1800  gate = gate * topk_weights.AsType<float>()[m5];
1801  up = up * topk_weights.AsType<float>()[m5];
1802  }
1804  c_thread_buf_fp32(cidx) = gate * up;
1805 
1806  /*float gate = c_thread_buf[cidx];
1807  float up = c_thread_buf_up[cidx];
1808  if constexpr(MulRoutedWeight)
1809  {
1810  gate = gate * topk_weights.AsType<float>()[m5];
1811  //up = up * topk_weights.AsType<float>()[m5];
1812  }
1813  tensor_operation::element_wise::Gelu{}(gate, gate);
1814  c_thread_buf_fp32(cidx) = up;*/
1815  }
1816  }
1817  else
1818  {
1819  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1820  if constexpr(MulRoutedWeight)
1821  {
1822  c_thread_buf_fp32(cidx) =
1823  topk_weights.AsType<float>()[m5] *
1824  c_thread_buf_fp32[cidx];
1825  }
1826  }
1827  });
1828  });
1829  });
1830  });
1831  });
1832  });
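 // Epilogue summary: for the gate/up GEMM (IsInputGemm) each accumulator pair is combined as
 // Act(gate) * up, with both factors first scaled by the routed top-k weight when MulRoutedWeight
 // is set; for the second GEMM the raw accumulator is simply scaled by the top-k weight. The
 // result lands in c_thread_buf_fp32 and is then shuffled through LDS below before being written
 // to global memory.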
1833 
1834  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1836 
1837  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1838  static_cast<CShuffleDataType*>(p_shared),
1839  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1840 
1841  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1842  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1843  make_tuple(
1847  // per shuffle
1848  M1, // M1 = MWave
1849  M2, // M2 = MXdlPack
1850  M3, // M3 * M4 * M5 = MPerXdl
1851  M4,
1852  M5)),
1856  // per shuffle
1857  N1, // N1 = NWave
1858  N2, // N2 = NXdlPack
1859  N3))), // N3 = NPerXdl
1863  Sequence<>{},
1865 
1866  // calculate origin of thread output tensor on global memory
1867  // blockwise GEMM c matrix starting index
1868  const auto c_thread_mtx_on_block =
1869  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1870 
1871  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1872  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1873 
1874  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1876  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1878  make_tuple(Sequence<0>{}));
1879 
1880  const auto m_thread_data_on_block_idx =
1881  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1882  make_multi_index(m_thread_data_on_block));
1883 
1884  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1886  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1888  make_tuple(Sequence<0>{}));
1889 
1890  const auto n_thread_data_on_block_idx =
1891  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1892  make_multi_index(n_thread_data_on_block));
1893 
1894  // shuffle: threadwise copy C from VGPR to LDS
1895  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1896  AccDataType,
1897  CShuffleDataType,
1898  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1899  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1900  tensor_operation::element_wise::PassThrough,
1901  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1902  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1903  I1,
1904  I1,
1905  M2,
1906  N2,
1907  M3,
1908  I1,
1909  M5,
1910  I1>,
1911  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1912  9,
1913  1,
1914  InMemoryDataOperationEnum::Set,
1915  1,
1916  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1917  make_multi_index(0,
1918  0,
1919  m_thread_data_on_block_idx[I1],
1920  n_thread_data_on_block_idx[I1],
1921  m_thread_data_on_block_idx[I2],
1922  n_thread_data_on_block_idx[I2],
1923  m_thread_data_on_block_idx[I3],
1924  m_thread_data_on_block_idx[I4],
1925  m_thread_data_on_block_idx[I5],
1926  n_thread_data_on_block_idx[I3]),
1927  tensor_operation::element_wise::PassThrough{}};
1928 
1929  using EDataType = CDataType;
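// Note: the D tensors are exposed through the same (MBlock, MPerBlock, NBlock,
// NPerBlock) view as the shuffled C tile, so one blockwise transfer can read C
// from LDS together with the Ds from global memory and write
// E = c_element_op(C, Ds...).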
1930 
1931  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1932  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1933 
1934  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1935  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1936  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1937 
1938  const auto ds_grid_buf = generate_tuple(
1939  [&](auto i) {
1940  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1941  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1942  },
1943  Number<NumDTensor>{});
1944 
1945  // tuple of reference to C/Ds tensor descriptors
1946  const auto c_ds_desc_refs = concat_tuple_of_reference(
1947  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1948  generate_tie([&](auto i) -> const auto& // return type should be reference
1949  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1950  Number<NumDTensor>{}));
1951 
1952  // tuple of reference to C/Ds tensor descriptors
1953  const auto c_ds_buf_refs = concat_tuple_of_reference(
1954  tie(c_shuffle_block_buf),
1955  generate_tie([&](auto i) -> const auto& // return type should be reference
1956  { return ds_grid_buf[i]; },
1957  Number<NumDTensor>{}));
1958 
1959  // tuple of starting index of C/Ds blockwise copy
1960  const auto idx_c_ds_block_begin =
1961  container_concat(make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1962  generate_tuple(
1963  [&](auto) {
1964  return make_multi_index(block_m_id, 0, block_n_id, 0);
1965  // return make_multi_index(block_work_idx[I0], 0,
1966  // block_work_idx[I1], 0);
1967  },
1968  Number<NumDTensor>{}));
1969 
1970  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1971  c_grid_desc_mblock_mperblock_nblock_nperblock;
1972 
1973  using CDEBlockTransferCluster =
1974  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1975  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1976  constexpr index_t scatter_weight_idx = 3; // FIXME(felix): hard-coded scatter-weight index
1977  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1978  ThisThreadBlock,
1979  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1980  Tuple<EDataType>,
1981  decltype(c_ds_desc_refs),
1982  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1983  CElementwiseOperation,
1984  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1985  // Sequence support
1986  // arbitrary type
1987  Sequence<1,
1988  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1989  1,
1990  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1991  CDEBlockTransferCluster,
1992  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1993  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1994  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1995  3, // index_t SrcVectorDim,
1996  3, // index_t DstVectorDim,
1997  CDEShuffleBlockTransferScalarPerVectors,
2002  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2003  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2004  IndexType,
2005  1, // ScatterDim
2006  true, // OutputScatter: false, only use scatter weights
2007  scatter_weight_idx // ScatterWeightIdx: ascale
2008  >{c_ds_desc_refs,
2009  idx_c_ds_block_begin,
2010  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2011  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2012  c_element_op};
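// Note: OutputScatter is enabled above, so the destination row of each E-tile row
// is chosen at run time from scatter_offsets (one element offset per row of the
// tile) rather than from the descriptor alone; this is what routes output rows
// back to their originating tokens.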
2013 
2014  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2015  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2016 
2017  constexpr auto sfc_c_vgpr =
2018  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2019  NXdlPerWave / NXdlPack,
2020  1,
2021  1,
2022  MXdlPack,
2023  NXdlPack,
2024  M2,
2025  1,
2026  M4,
2027  1>,
2028  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2029  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2030  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2031  1,
2032  1,
2033  MXdlPack,
2034  NXdlPack,
2035  M2,
2036  1,
2037  M4,
2038  1>>{};
2039 
2040  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2041 
2042  // space filling curve for shuffled blockwise C/D/E
2043  constexpr auto sfc_cde_block =
2044  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2045  Sequence<0, 2, 1, 3>,
2046  Sequence<1,
2047  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2048  1,
2049  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2050 
2051  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2052  constexpr auto EMThreads =
2053  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2054  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2055  constexpr auto ENThreads =
2056  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2057  static_for<0, num_access, 1>{}([&](auto access_id) {
2058  // make sure it's safe to write to LDS
2059  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2060 
2061  auto dstidx = sfc_cde_block.GetIndex(access_id);
2062  const index_t c_token_pos =
2063  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
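// Note: p_sorted_token_ids appears to pack the token id in the low 24 bits and the
// top-k slot in the high 8 bits (see the & 0xffffff / >> 24 extraction below). For
// the input GEMM every (token, slot) pair gets its own output row, hence
// token_offset * TopK + slot; offsets are expressed in elements (row * problem.N).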
2064  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2065  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2066  IndexType token_offset = fused_token & 0xffffff;
2067  if constexpr(IsInputGemm)
2068  {
2069  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2070  }
2071  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2072  });
2073 
2074  block_sync_lds();
2075 
2076  // each thread write its data from VGPR to LDS
2077  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2078  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2079  c_thread_buf_fp32,
2080  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2081  c_shuffle_block_buf);
2082 
2083  // make sure it's safe to read from LDS
2084  block_sync_lds();
2085 
2086  // each block copy its data from LDS to global
2087  cde_block_copy_lds_and_global.Run(
2088  c_ds_desc_refs,
2089  c_ds_buf_refs,
2090  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2091  tie(c_grid_buf),
2092  scatter_offsets);
2093 
2094  if constexpr(access_id < num_access - 1)
2095  {
2096  constexpr auto cde_lds_and_global_step =
2097  sfc_cde_block.GetForwardStep(access_id);
2098 
2099  // move on Ds
2100  static_for<0, NumDTensor, 1>{}([&](auto i) {
2101  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2102  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2103  });
2104 
2105  // move on E
2106  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2107  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2108  I0,
2109  cde_lds_and_global_step);
2110  }
2111  });
2112  }
2113  }
2114 
2115 #if 0
2116  template <bool HasMainKBlockLoop,
2117  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2118  TailNumber TailNum = TailNumber::Odd>
2119  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2120  const index_t* p_sorted_expert_ids,
2121  const index_t* p_max_token_id,
2122  const ADataType* p_a_grid,
2123  const AScaleDataType* p_a_scale_grid,
2124  const BDataType* p_b_grid,
2125  const BScaleDataType* p_b_scale_grid,
2126  DsGridPointer& p_ds_grid,
2127  CDataType* p_c_grid,
2128  void* p_shared,
2129  void* p_shared1,
2130  const Problem& problem,
2131  AElementwiseOperation a_element_op,
2132  BElementwiseOperation b_element_op,
2133  CElementwiseOperation c_element_op)
2134  {
2135  ignore = b_element_op;
2136  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2137  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2138  problem.MPadded,
2139  problem.K,
2140  problem.KPadded,
2141  problem.StrideA,
2142  problem.AK0);
2143  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2144  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2145  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2146  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2147  problem.MPadded,
2148  problem.N,
2149  problem.NPadded,
2150  problem.StrideC);
2151 
2152  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2153  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerXdl),
2154  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2155  (KXdlPack * 64 / MPerXdl),
2156  64 * KXdlPack * MXdlPack / scale_pack_size_a));
2157 
2158  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2159  make_tuple(problem.N / (NXdlPack * NPerXdl),
2160  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2161  (KXdlPack * 64 / NPerXdl),
2162  64 * KXdlPack * NXdlPack / scale_pack_size_b));
2163 
2164  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2166  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2167  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2168  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
2169  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2170  if(expert_block_id * MPerBlock >= max_token_id)
2171  return;
2172  const index_t expert_id =
2173  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2174  const auto block_mn = [&]() -> std::pair<int, int> {
2175  if constexpr(NSwizzle)
2176  {
2177  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2178  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2179  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2180  const index_t expert_swizzle =
2181  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2182  const index_t bid_new = blockIdx.x - prefix_block;
2183  const index_t nid = __builtin_amdgcn_readfirstlane(
2184  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2185  const index_t mid =
2186  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2187  return {nid, mid};
2188  }
2189  else
2190  {
2191  return {blockIdx.x, blockIdx.y};
2192  }
2193  }();
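// Note: with NSwizzle the flat block index is re-mapped so that groups of 8 N tiles
// are interleaved across the M tiles owned by a single expert (ecnt_prefix / ecnt
// come from p_max_token_id), presumably to improve reuse of that expert's weights;
// without NSwizzle, blockIdx.x / blockIdx.y map directly to the (n, m) tile.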
2194 
2195  const index_t block_n_id = block_mn.first;
2196  const index_t block_m_id = block_mn.second;
2197  const index_t token0 =
2198  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2199 
2200  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2201  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2202  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2203  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2204  constexpr auto AKThreads = AK0Threads * AK1Threads;
2205  constexpr auto AMRepeats = MPerBlock / AMThreads;
2206  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2207 
2208  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2209  return;
2210  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2211  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2212  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2213  index_t token_offset = fused_token & 0xffffff;
2214  if constexpr(!IsInputGemm)
2215  {
2216  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2217  }
2218  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2219  });
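// Note: gather_offsets redirects each row of the A tile to its source token in
// global memory (offsets in elements, i.e. row * problem.K); for the second GEMM
// the row index additionally encodes the top-k slot taken from the high bits.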
2220 
2221  const index_t expert_stride =
2222  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2223  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2224  problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
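// Note: each expert owns N * K weight elements (doubled when gate and up are fused
// into one B tensor) and N * integer_divide_ceil(K, ScaleBlockSize / BPackedSize)
// block-scale entries; readfirstlane keeps both strides wave-uniform.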
2225 
2226  // N0, K0, Blocksize*KPack
2227  const index_t n_block_data_idx_on_grid =
2228  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2229 
2230  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2231  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2232 
2233  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2234  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2235 
2236  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2237  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2238  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2239  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2240  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2241 
2242  // A matrix in LDS memory, dst of blockwise copy
2243  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2244 
2245  // B matrix in LDS memory, dst of blockwise copy
2246  // dummy
2247  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2248  // A matrix blockwise copy
2249  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2251  AElementwiseOperation,
2254  Sequence<AK0Number, MPerBlock, AK1Number>,
2255  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2256  ABlockTransferThreadClusterArrangeOrder,
2257  ADataType,
2258  LDSTypeA,
2259  decltype(a_grid_desc_ak0_m_ak1),
2260  decltype(a_block_desc_ak0_m_ak1),
2261  ABlockTransferSrcAccessOrder,
2262  Sequence<0, 1, 2>,
2263  ABlockTransferSrcVectorDim,
2264  2,
2265  ABlockTransferSrcScalarPerVector,
2266  ABlockTransferDstScalarPerVector_AK1,
2267  1,
2268  1,
2269  AThreadTransferSrcResetCoordinateAfterRun,
2270  true,
2271  IndexType,
2272  1,
2273  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2274  make_multi_index(0, 0, 0),
2275  a_element_op,
2276  a_block_desc_ak0_m_ak1,
2277  make_multi_index(0, 0, 0),
2279  gather_offsets);
2280 
2281  // Thread-wise copy
2282  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2283  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2284  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2285  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2286  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2287  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2288 
2289  auto b_blockwise_copy =
2290  ThreadwiseTensorSliceTransfer_v2<BDataType,
2291  BDataType,
2292  decltype(b_grid_desc_bpreshuffled),
2293  decltype(b_block_desc_bk0_n_bk1),
2294  Sequence<Number<NXdlPerWave / NXdlPack>{},
2295  I1,
2296  Number<NXdlPack>{},
2297  Number<KRepeat>{},
2298  Number<BK1Value>{}>,
2299  Sequence<1, 2, 0, 3, 4>,
2300  4,
2301  BBlockTransferSrcScalarPerVector,
2302  BThreadTransferSrcResetCoordinateAfterRun,
2303  true>(
2304  b_grid_desc_bpreshuffled,
2305  make_multi_index(n_block_data_idx_on_grid,
2306  get_warp_local_1d_id() % NWave,
2307  0,
2308  0,
2309  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2310 
2311  // LDS allocation for A and B: be careful of alignment
2312  // Cast after lds
2313  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2314  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2315  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2316  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2317  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2318 
2319  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2320  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2321 
2322  // Blockwise GEMM pipeline
2323  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2324  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2325  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2326  decltype(c_thread_buf) c_thread_buf_up;
2327 
2328  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2329  float,
2330  c_thread_buf.num_of_v_,
2331  c_thread_buf.s_per_v,
2332  true>
2333  c_thread_buf_fp32;
2334 
2335  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2336  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2337  KPerBlock);
2338 
2339  // a and b scale processing
2340  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2341  const auto waveId_m = wave_idx[I0];
2342  const auto waveId_n = wave_idx[I1];
2343 
2344  auto thread_offset_shuffled =
2345  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2346 
2347  auto a_thread_offset_m = waveId_m;
2348 
2349  // get each thread's offset into the scale tensor
2350  const index_t token_scale_pos = block_m_id * MPerBlock;
2351  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2352  return;
2353 
2354  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2355  AScaleDataType,
2356  AScaleDataType,
2357  decltype(a_scale_grid_desc_am_ak),
2358  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2359  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2360  Sequence<0, 1, 2>, // DimAccessOrder
2361  2, // SrcVectorDim
2362  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2363  1, // SrcScalarStrideInVector
2364  true>(a_scale_grid_desc_am_ak,
2365  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2366  0,
2367  thread_offset_shuffled / scale_pack_size_a));
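// Note: a_scale_grid_desc_am_ak (above) tiles the A scales as
// ((NumTokens or M) / (MXdlPack * MPerXdl),
//  ceil(K / (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl),
//  64 * KXdlPack * MXdlPack / scale_pack_size_a), so each lane reads one packed
// group of KXdlPack * MXdlPack scales per slice; the B-scale copy below mirrors
// this layout along N.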
2368 
2369  // B scale load
2370  auto b_thread_offset_n = waveId_n;
2371 
2372  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2373  BScaleDataType,
2374  BScaleDataType,
2375  decltype(b_scale_grid_desc_bn_ak),
2376  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2377  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2378  Sequence<0, 1, 2>, // DimAccessOrder
2379  2, // SrcVectorDim
2380  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2381  1, // SrcScalarStrideInVector
2382  true>(b_scale_grid_desc_bn_ak,
2383  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2384  0,
2385  thread_offset_shuffled / scale_pack_size_b));
2386 
2387  if constexpr(IsInputGemm)
2388  {
2389  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2390  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2391  p_b_grid_up + expert_id * expert_stride / BPackedSize,
2392  b_grid_desc_bpreshuffled.GetElementSpaceSize());
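// Note: for the fused gate/up GEMM the up-projection weights live in the second
// half of each expert's B slab (offset expert_stride / 2, adjusted for packed
// storage), and their block scales start at expert_scale_stride / 2 below.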
2393  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2394  BDataType,
2395  BDataType,
2396  decltype(b_grid_desc_bpreshuffled),
2397  decltype(b_block_desc_bk0_n_bk1),
2398  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
2399  Sequence<1, 2, 0, 3>,
2400  3,
2401  BBlockTransferSrcScalarPerVector,
2402  BThreadTransferSrcResetCoordinateAfterRun,
2403  true>(b_grid_desc_bpreshuffled,
2404  make_multi_index(n_block_data_idx_on_grid,
2405  get_warp_local_1d_id() % NWave,
2406  0,
2407  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2408  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2409  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2410  p_b_scale_grid_up + expert_id * expert_scale_stride,
2411  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2412  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2413  BScaleDataType,
2414  BScaleDataType,
2415  decltype(b_scale_grid_desc_bn_ak),
2416  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2417  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2418  Sequence<0, 1, 2>, // DimAccessOrder
2419  2, // SrcVectorDim
2420  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2421  1, // SrcScalarStrideInVector
2422  true>(
2423  b_scale_grid_desc_bn_ak,
2424  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2425  0,
2426  thread_offset_shuffled / scale_pack_size_b));
2427 
2428  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2429  a_grid_desc_ak0_m_ak1,
2430  a_block_desc_ak0_m_ak1,
2431  a_blockwise_copy,
2432  a_grid_buf,
2433  a_block_bufs,
2434  a_block_slice_copy_step,
2435  b_grid_desc_bpreshuffled,
2436  b_block_desc_bk0_n_bk1,
2437  b_blockwise_copy,
2438  b_blockwise_copy_up,
2439  b_grid_buf,
2440  b_grid_buf_up,
2441  b_block_bufs,
2442  b_block_slice_copy_step,
2443  c_thread_buf,
2444  c_thread_buf_up,
2445  a_scale_grid_desc_am_ak,
2446  a_scale_thread_copy,
2447  a_scale_grid_buf,
2448  b_scale_grid_desc_bn_ak,
2449  b_scale_thread_copy,
2450  b_scale_thread_copy_up,
2451  b_scale_grid_buf,
2452  b_scale_grid_buf_up,
2453  num_k_block_main_loop);
2454  }
2455  else
2456  {
2457  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2458  a_grid_desc_ak0_m_ak1,
2459  a_block_desc_ak0_m_ak1,
2460  a_blockwise_copy,
2461  a_grid_buf,
2462  a_block_bufs,
2463  a_block_slice_copy_step,
2464  b_grid_desc_bpreshuffled,
2465  b_block_desc_bk0_n_bk1,
2466  b_blockwise_copy,
2467  b_grid_buf,
2468  b_block_bufs,
2469  b_block_slice_copy_step,
2470  c_thread_buf,
2471  a_scale_grid_desc_am_ak,
2472  a_scale_thread_copy,
2473  a_scale_grid_buf,
2474  b_scale_grid_desc_bn_ak,
2475  b_scale_thread_copy,
2476  b_scale_grid_buf,
2477  num_k_block_main_loop);
2478  }
2479 
2480  // shuffle C and write out
2481  {
2482  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2483  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2484  "wrong!");
2485 
2486  // TODO: hacky, fix it!
2487  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2488  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2489 
2490  // TODO: hacky, fix it!
2491  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2492  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2493  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2494 
2495  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2496  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2497  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2498  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2499  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2500  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2501  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2502  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2503 
2504  // mul scales
2505 
2506  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2507  static_assert(M4 == 4);
2508  const index_t m1 = get_warp_local_1d_id() / NWave;
2509  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2510 
2511  vector_type<float, 4> topk_weights; // for gemm2 only
2512  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2513  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2514  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2515  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2516  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2517  if constexpr(MulRoutedWeight)
2518  {
2519  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
2520  p_ds_grid[I2] + m_pos);
2521  }
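// Note: when MulRoutedWeight is set, the routed top-k weights are loaded four at a
// time (M4 == 4) from the third D tensor, p_ds_grid[I2], and multiplied into the
// accumulators before the activation is applied.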
2522  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2523  constexpr index_t c_offset =
2524  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2525  make_tuple(m0 / MXdlPack,
2526  n0 / NXdlPack,
2527  m0 % MXdlPack,
2528  n0 % NXdlPack,
2529  m2 * M4 + m4));
2530  constexpr auto cidx = Number<c_offset>{};
2531 
2532  if constexpr(IsInputGemm) // gu fusion
2533  {
2534  if constexpr(ActivationOperation == Activation::silu_and_mul)
2535  {
2536  float gate = c_thread_buf[cidx];
2537  float up = c_thread_buf_up[cidx];
2538  if constexpr(MulRoutedWeight)
2539  {
2540  gate = gate * topk_weights.AsType<float>()[m4];
2541  up = up * topk_weights.AsType<float>()[m4];
2542  }
2543  tensor_operation::element_wise::Silu{}(gate, gate);
2544  c_thread_buf_fp32(cidx) = gate * up;
2545  }
2546  else if(ActivationOperation == Activation::gelu_and_mul)
2547  {
2548  float gate = c_thread_buf[cidx];
2549  float up = c_thread_buf_up[cidx];
2550  if constexpr(MulRoutedWeight)
2551  {
2552  gate = gate * topk_weights.AsType<float>()[m4];
2553  up = up * topk_weights.AsType<float>()[m4];
2554  }
2555  tensor_operation::element_wise::Gelu{}(gate, gate);
2556  c_thread_buf_fp32(cidx) = gate * up;
2557  }
2558  }
2559  else
2560  {
2561  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2562  if constexpr(MulRoutedWeight)
2563  {
2564  c_thread_buf_fp32(cidx) =
2565  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
2566  }
2567  }
2568  });
2569  });
2570  });
2571  });
2572 
2573  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2574  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2575 
2576  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2577  static_cast<CShuffleDataType*>(p_shared),
2578  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2579 
2580  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2581  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2582  make_tuple(make_freeze_transform(I0),
2583  make_unmerge_transform(make_tuple(
2584  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
2585  // shuffle
2586  M1, // M1 = MWave
2587  M2, // M2 * M3 * M4 = MPerXdl
2588  M3,
2589  M4)),
2590  make_freeze_transform(I0),
2591  make_unmerge_transform(make_tuple(
2592  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
2593  // shuffle
2594  N1, // N1 = NWave
2595  N2))), // N2 = NPerXdl
2596  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2597  make_tuple(
2598  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2599 
2600  // calculate origin of thread output tensor on global memory
2601  // blockwise GEMM c matrix starting index
2602  const auto c_thread_mtx_on_block =
2603  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2604 
2605  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2606  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2607 
2608  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2609  make_single_stage_tensor_adaptor(
2610  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2611  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2612  make_tuple(Sequence<0>{}));
2613 
2614  const auto m_thread_data_on_block_idx =
2615  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2616  make_multi_index(m_thread_data_on_block));
2617 
2618  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2619  make_single_stage_tensor_adaptor(
2620  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
2621  make_tuple(Sequence<0, 1, 2>{}),
2622  make_tuple(Sequence<0>{}));
2623 
2624  const auto n_thread_data_on_block_idx =
2625  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2626  make_multi_index(n_thread_data_on_block));
2627 
2628  // shuffle: threadwise copy C from VGPR to LDS
2629  auto c_thread_copy_vgpr_to_lds =
2630  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2631  CShuffleDataType,
2632  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2633  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2635  Sequence<CShuffleMXdlPerWavePerShuffle,
2636  CShuffleNXdlPerWavePerShuffle,
2637  I1,
2638  I1,
2639  M2,
2640  I1,
2641  M4,
2642  I1>,
2643  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2644  7,
2645  1,
2647  1,
2648  true>{
2649  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2650  make_multi_index(0,
2651  0,
2652  m_thread_data_on_block_idx[I1],
2653  n_thread_data_on_block_idx[I1],
2654  m_thread_data_on_block_idx[I2],
2655  m_thread_data_on_block_idx[I3],
2656  m_thread_data_on_block_idx[I4],
2657  n_thread_data_on_block_idx[I2]),
2659 
2660  using EDataType = CDataType;
2661 
2662  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2663  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2664 
2665  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2666  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2667  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2668 
2669  const auto ds_grid_buf = generate_tuple(
2670  [&](auto i) {
2671  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2672  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2673  },
2674  Number<NumDTensor>{});
2675 
2676  // tuple of reference to C/Ds tensor descriptors
2677  const auto c_ds_desc_refs = concat_tuple_of_reference(
2678  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2679  generate_tie([&](auto i) -> const auto& // return type should be reference
2680  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2681  Number<NumDTensor>{}));
2682 
2683  // tuple of reference to C/Ds tensor descriptors
2684  const auto c_ds_buf_refs = concat_tuple_of_reference(
2685  tie(c_shuffle_block_buf),
2686  generate_tie([&](auto i) -> const auto& // return type should be reference
2687  { return ds_grid_buf[i]; },
2688  Number<NumDTensor>{}));
2689 
2690  // tuple of starting index of C/Ds blockwise copy
2691  const auto idx_c_ds_block_begin =
2692  container_concat(make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
2693  generate_tuple(
2694  [&](auto) {
2695  return make_multi_index(block_m_id, 0, block_n_id, 0);
2696  // return make_multi_index(block_work_idx[I0], 0,
2697  // block_work_idx[I1], 0);
2698  },
2699  Number<NumDTensor>{}));
2700 
2701  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2702  c_grid_desc_mblock_mperblock_nblock_nperblock;
2703 
2704  using CDEBlockTransferCluster =
2705  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2706  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2707  constexpr index_t scatter_weight_idx = 3; // FIXME(felix): hard-coded scatter-weight index
2708  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2709  ThisThreadBlock,
2710  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2711  Tuple<EDataType>,
2712  decltype(c_ds_desc_refs),
2713  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2714  CElementwiseOperation,
2715  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2716  // Sequence support
2717  // arbitrary type
2718  Sequence<1,
2719  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2720  1,
2721  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2722  CDEBlockTransferCluster,
2723  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2724  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2725  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2726  3, // index_t SrcVectorDim,
2727  3, // index_t DstVectorDim,
2728  CDEShuffleBlockTransferScalarPerVectors,
2731  Sequence<true>,
2733  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2734  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2735  IndexType,
2736  1, // ScatterDim
2737  true, // OutputScatter: false, only use scatter weights
2738  scatter_weight_idx // ScatterWeightIdx: ascale
2739  >{c_ds_desc_refs,
2740  idx_c_ds_block_begin,
2741  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2742  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2743  c_element_op};
2744 
2745  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2746  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2747  constexpr auto sfc_c_vgpr =
2748  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2749  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2750  Sequence<CShuffleMXdlPerWavePerShuffle,
2751  CShuffleNXdlPerWavePerShuffle,
2752  1,
2753  1,
2754  M2,
2755  1,
2756  M4,
2757  1>>{};
2758 
2759  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2760 
2761  // space filling curve for shuffled blockwise C/D/E
2762  constexpr auto sfc_cde_block =
2763  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2764  Sequence<0, 2, 1, 3>,
2765  Sequence<1,
2766  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2767  1,
2768  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2769 
2770  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2771  constexpr auto EMThreads =
2772  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2773  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2774  constexpr auto ENThreads =
2775  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2776  static_for<0, num_access, 1>{}([&](auto access_id) {
2777  // make sure it's safe to write to LDS
2778  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2779 
2780  auto dstidx = sfc_cde_block.GetIndex(access_id);
2781  const index_t c_token_pos =
2782  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2783  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2784  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2785  IndexType token_offset = fused_token & 0xffffff;
2786  if constexpr(IsInputGemm)
2787  {
2788  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2789  }
2790  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2791  });
2792 
2793  block_sync_lds();
2794 
2795  // each thread write its data from VGPR to LDS
2796  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2797  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2798  c_thread_buf_fp32,
2799  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2800  c_shuffle_block_buf);
2801 
2802  // make sure it's safe to read from LDS
2803  block_sync_lds();
2804 
2805  // each block copy its data from LDS to global
2806  cde_block_copy_lds_and_global.Run(
2807  c_ds_desc_refs,
2808  c_ds_buf_refs,
2809  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2810  tie(c_grid_buf),
2811  scatter_offsets);
2812 
2813  if constexpr(access_id < num_access - 1)
2814  {
2815  constexpr auto cde_lds_and_global_step =
2816  sfc_cde_block.GetForwardStep(access_id);
2817 
2818  // move on Ds
2819  static_for<0, NumDTensor, 1>{}([&](auto i) {
2820  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2821  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2822  });
2823 
2824  // move on E
2825  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2826  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2827  I0,
2828  cde_lds_and_global_step);
2829  }
2830  });
2831  }
2832  }
2833 #endif
2834 };
2835 
2836 } // namespace ck