Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp Source File

gridwise_moe_mx_gemm_bpreshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data
28 // access operates on two LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C
30 // may not use the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
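// The gridwise GEMM's ActivationOperation template parameter selects one of these fused
// activations. In the fused gate/up path below (IsInputGemm == true), the gate result is
// activated and multiplied element-wise with the up result before the C shuffle and write-out.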
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if defined(__gfx9__)
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
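// Illustrative launch sketch (assumptions: the usual CK device-op wiring, BlockSize threads per
// work-group, and KBatch supplied as the grid z-dimension so blockIdx.z picks the split-K batch):
//   auto [gx, gy, gz] = GridwiseGemm::CalculateGridSize(M, N); // gz == 1
//   dim3 grid(gx, gy, karg.KBatch);
//   dim3 block(BlockSize);
//   kernel_moe_mxgemm<GridwiseGemm, HasMainKBlockLoop, CGlobalMemoryDataOperation>
//       <<<grid, block>>>(karg);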
74 
75 template <typename GridwiseGemm,
76  bool HasMainKBlockLoop,
77  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78  index_t MinimumOccupancy = 1,
79  TailNumber TailNum = TailNumber::Even>
80 __global__ void
81 #if CK_USE_LAUNCH_BOUNDS
82 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
83 #endif
84  // __attribute__((amdgpu_waves_per_eu(1, 1)))
85  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
86 {
87 #if defined(__gfx9__)
88  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
89  {
90  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92 
93  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
94 
95  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
96  karg.p_sorted_token_ids,
97  karg.p_sorted_expert_ids,
98  karg.p_max_token_id,
99  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
100  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
103  karg.p_ds_grid,
104  karg.p_c_grid,
105  p_shared_0,
106  p_shared_1,
107  karg,
108  karg.a_element_op,
109  karg.b_element_op,
110  karg.c_element_op);
111  }
112 #else
113  ignore = karg;
114 #endif // end of if (defined(__gfx9__))
115 }
116 
117 template <typename ALayout,
118  typename BLayout,
119  typename DsLayout,
120  typename CLayout,
121  typename ADataType,
122  typename AScaleDataType,
123  typename BDataType,
124  typename BScaleDataType,
125  typename AccDataType,
126  typename CShuffleDataType,
127  typename DsDataType,
128  typename CDataType,
129  typename AElementwiseOperation,
130  typename BElementwiseOperation,
131  typename CElementwiseOperation,
132  tensor_operation::device::GemmSpecialization GemmSpec,
133  index_t ScaleBlockSize, // Scaling block size
134  index_t BlockSize, // Thread block size
135  index_t MPerBlock,
136  index_t NPerBlock,
137  index_t KPerBlock,
138  index_t AK1Value,
139  index_t BK1Value,
140  index_t MPerXdl,
141  index_t NPerXdl,
142  index_t MXdlPerWave,
143  index_t NXdlPerWave,
144  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
145  typename ABlockTransferThreadClusterArrangeOrder,
146  typename ABlockTransferSrcAccessOrder,
147  index_t ABlockTransferSrcVectorDim,
148  index_t ABlockTransferSrcScalarPerVector,
149  index_t ABlockTransferDstScalarPerVector_AK1,
150  bool AThreadTransferSrcResetCoordinateAfterRun,
151  index_t ABlockLdsExtraM,
152  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
153  typename BBlockTransferThreadClusterArrangeOrder,
154  typename BBlockTransferSrcAccessOrder,
155  index_t BBlockTransferSrcVectorDim,
156  index_t BBlockTransferSrcScalarPerVector,
157  index_t BBlockTransferDstScalarPerVector_BK1,
158  bool BThreadTransferSrcResetCoordinateAfterRun,
159  index_t BBlockLdsExtraN,
160  index_t CShuffleMXdlPerWavePerShuffle,
161  index_t CShuffleNXdlPerWavePerShuffle,
162  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
163  typename CDEShuffleBlockTransferScalarPerVectors,
166  index_t ActivationOperation = 0,
167  bool NSwizzle = false,
168  bool IsInputGemm = true,
169  bool MulRoutedWeight = true,
170  typename IndexType = index_t,
171  typename ComputeTypeA = ADataType,
172  typename ComputeTypeB = BDataType>
174 {
175  using LDSTypeA = ADataType;
176  using LDSTypeB = BDataType;
177 
178  static constexpr auto I0 = Number<0>{};
179  static constexpr auto I1 = Number<1>{};
180  static constexpr auto I2 = Number<2>{};
181  static constexpr auto I3 = Number<3>{};
182  static constexpr auto I4 = Number<4>{};
183  static constexpr auto I5 = Number<5>{};
184  static constexpr auto I6 = Number<6>{};
185  static constexpr auto I7 = Number<7>{};
186  static constexpr auto I8 = Number<8>{};
187  static constexpr auto I9 = Number<9>{};
188 
189  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
190  CDEShuffleBlockTransferScalarPerVectors{}[I0];
191  // K1 should be Number<...>
192  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
193  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
194  static constexpr auto AK1Number = Number<AK1Value>{};
195  static constexpr auto BK1Number = Number<BK1Value>{};
196 
197  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
198  static constexpr bool is_single_rate_mfma = false;
199  static constexpr auto is_scale_mfma = true;
200 
201  static constexpr index_t NumDTensor = DsDataType::Size();
202 
203  static constexpr auto MXdlPack = 2;
204  static constexpr auto NXdlPack = 2;
205  static constexpr auto KXdlPack = 2;
206 
207  //> KPack is at least the k_per_blk of the selected mfma
208  //
209  // It should be a multiple of k_per_blk.
210  // TODO: Move this to the blockwise pipeline base.
211  // KPack is counted in packed data types for pk A/B.
212 
213  static constexpr index_t APackedSize = packed_size_v<ADataType>;
214  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
215 
216  using mfma_selector = MfmaSelector<ComputeTypeA,
217  MPerXdl,
218  NPerXdl,
219  ComputeTypeB,
221  is_scale_mfma>;
222  static constexpr index_t KPack =
224 
225  static constexpr index_t NLane = NPerXdl;
226  static constexpr index_t KLane = 64 / NLane;
227  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
228  static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
229 
230  // static constexpr index_t NumTokens = 1;
231  static constexpr index_t SortedTileSize = MPerBlock;
232 
234  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
235  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
236  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
237  "A scale pack data type too large!");
238  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
239  "B scale pack data type too large!");
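// Worked example (assuming mx_scale_t is the 1-byte e8m0 scale type and AScaleDataType packs
// four such scales into 4 bytes): scale_pack_size_a = 4 / 1 = 4, and the assert requires
// KXdlPack * MXdlPack = 2 * 2 = 4 to be divisible by 4, which holds. A scale type packing more
// than KXdlPack * MXdlPack scales would trip the assert.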
240 
241  static constexpr auto MakeDsGridPointer()
242  {
243  return generate_tuple(
244  [&](auto i) {
245  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
246 
247  return static_cast<const DDataType*>(nullptr);
248  },
250  }
251 
252  using DsGridPointer = decltype(MakeDsGridPointer());
253 
255 
256  __host__ static auto CalculateGridSize(index_t M, index_t N)
257  {
258  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
259  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
260  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
261  const index_t gridy = NSwizzle ? 1 : mblock;
262 
263  return std::make_tuple(gridx, gridy, 1);
264  }
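// Worked example: with MPerBlock = 128, NPerBlock = 128, M = 512 and N = 1024 this gives
// mblock = 4 and nblock = 8, so the grid is (8, 4, 1) without NSwizzle and (32, 1, 1) with
// NSwizzle (all tiles flattened onto gridDim.x).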
265 
266  __host__ static auto CalculateMPadded(index_t M)
267  {
268  return math::integer_least_multiple(M, MPerBlock);
269  }
270 
271  __host__ static auto CalculateNPadded(index_t N)
272  {
273  return math::integer_least_multiple(N, NPerBlock);
274  }
275 
276  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
277  {
278  return math::integer_divide_ceil(N, NLane);
279  }
280  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
281  {
283  }
284 
285  __host__ static auto CalculateKPadded(index_t K)
286  {
287  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
288  }
289 
290  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
291  {
292  auto K_t = K_Batch * KPerBlock;
293  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
294  }
295 
296  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
297  {
298  auto K_t = K_Batch * KPerBlock;
299  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
300  }
301 
302  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
303  {
304  auto K_t = K_Batch * KPerBlock;
305  return (K + K_t - 1) / K_t * KPerBlock;
306  }
307 
308  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
309  {
310  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
311  auto K_t = K_Batch * KReadVec;
312  return (K + K_t - 1) / K_t * KReadVec;
313  }
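// Worked example (assuming AK1Value = BK1Value = 16): KReadVec = lcm(16, 16) = 16. For K = 4096
// and K_Batch = 2, K_t = 32 and KRead = ceil(4096 / 32) * 16 = 2048, i.e. each split-K batch
// reads 2048 elements along K; SplitKBatchOffset below assigns the (possibly smaller) remainder
// to the last batch.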
314 
315  __host__ static auto CalculateMBlock(index_t M)
316  {
317  return math::integer_divide_ceil(M, MPerBlock);
318  }
319 
320  __host__ static auto CalculateNBlock(index_t N)
321  {
322  return math::integer_divide_ceil(N, NPerBlock);
323  }
324 
325  template <index_t MNXdlPerWave,
326  index_t MNWaves,
327  index_t MNXdlPack,
328  index_t MNPerXdl,
329  bool IsXor,
330  typename TileDesc_K0_MN_K1>
331  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
332  {
333  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
334  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
335  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
336 
337  if constexpr(IsXor)
338  {
339  constexpr auto permuted_desc = transform_tensor_descriptor(
340  TileDesc_K0_MN_K1{},
345 
347  permuted_desc,
348  make_tuple(
351  Number<MNWaves>{},
353  Number<MNPerXdl>{}))),
356  }
357  else
358  {
360  TileDesc_K0_MN_K1{},
361  make_tuple(
364  Number<MNWaves>{},
366  Number<MNPerXdl>{}))),
369  }
370  }
371 
372  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
373  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
374  {
375  const auto a_grid_desc_mraw_kraw = [&]() {
376  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
377  {
378  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
379  }
380  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
381  {
382  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
383  }
384  }();
385 
387 
388  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
389  GemmSpec == GemmSpecialization::MNKPadding)
390  {
391  // pad both M and K
392  const auto a_grid_desc_m_k =
393  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
395  make_right_pad_transform(K, KPad - K)),
398 
399  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
400  a_grid_desc_m_k,
405 
406  return a_grid_desc_ak0_m_ak1;
407  }
408  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
409  GemmSpec == GemmSpecialization::MNPadding)
410  {
411  // pad M, but not K
412  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
413  a_grid_desc_mraw_kraw,
415  make_right_pad_transform(M, MPad - M)),
418 
419  return a_grid_desc_ak0_m_ak1;
420  }
421  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
422  GemmSpec == GemmSpecialization::NKPadding)
423  {
424  // pad K, but not M
425  const auto a_grid_desc_m_k = transform_tensor_descriptor(
426  a_grid_desc_mraw_kraw,
430 
431  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
432  a_grid_desc_m_k,
437 
438  return a_grid_desc_ak0_m_ak1;
439  }
440  else
441  {
442  // not pad M or K
443  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
444  a_grid_desc_mraw_kraw,
445  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
449 
450  const auto a_grid_desc_permuted = transform_tensor_descriptor(
451  a_grid_desc_ak0_m_ak1,
454  make_pass_through_transform(AK1Value)),
457 
458  const auto a_grid_desc = transform_tensor_descriptor(
459  a_grid_desc_permuted,
460  make_tuple(
463  make_pass_through_transform(AK1Value)),
466 
467  return a_grid_desc;
468  }
469  }
470 
471  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
472  {
473  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
474  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
475  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
477  make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
478  }
479 
480  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
481  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
482  {
483  const auto b_grid_desc_nraw_kraw = [&]() {
485  {
486  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
487  }
489  {
490  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
491  }
492  }();
493 
495 
496  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
497  GemmSpec != GemmSpecialization::Default),
498  "pk_i4_t does not support padding");
499  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
500  GemmSpec != GemmSpecialization::Default),
501  "f4x2_pk_t does not support padding");
502 
503  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
504  GemmSpec == GemmSpecialization::MNKPadding)
505  {
506  // pad both N and K
507  const auto b_grid_desc_n_k =
508  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
510  make_right_pad_transform(K, KPad - K)),
513 
514  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
515  b_grid_desc_n_k,
520 
521  return b_grid_desc_bk0_n_bk1;
522  }
523  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
524  GemmSpec == GemmSpecialization::MNPadding)
525  {
526  // pad N, but not K
527  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
528  b_grid_desc_nraw_kraw,
530  make_right_pad_transform(N, NPad - N)),
533 
534  return b_grid_desc_bk0_n_bk1;
535  }
536  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
537  GemmSpec == GemmSpecialization::MKPadding)
538  {
539  // pad K, but not N
540  const auto b_grid_desc_n_k = transform_tensor_descriptor(
541  b_grid_desc_nraw_kraw,
545 
546  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
547  b_grid_desc_n_k,
552 
553  return b_grid_desc_bk0_n_bk1;
554  }
555  else
556  {
557  // not pad N or K
558  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
559  b_grid_desc_nraw_kraw,
560  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
564 
565  const auto b_grid_desc_permuted = transform_tensor_descriptor(
566  b_grid_desc_bk0_n_bk1,
569  make_pass_through_transform(BK1Value)),
572 
573  const auto b_grid_desc = transform_tensor_descriptor(
574  b_grid_desc_permuted,
575  make_tuple(
578  make_pass_through_transform(BK1Value)),
581 
582  return b_grid_desc;
583  }
584  }
585 
586  template <typename ABlockDesc_AK0_M_AK1>
587  __host__ __device__ static constexpr auto
588  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
589  {
590  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
591 
592  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
593  ABlockDesc_AK0_M_AK1{});
594  }
595 
596  template <typename BBlockDesc_BK0_N_BK1>
597  __host__ __device__ static constexpr auto
598  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
599  {
600  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
601 
602  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
603  BBlockDesc_BK0_N_BK1{});
604  }
605 
606  template <typename ELayout>
607  __host__ __device__ static auto MakeCGridDescriptor_M_N(
608  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
609  {
610  const auto c_grid_desc_mraw_nraw = [&]() {
612  {
613  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
614  }
616  {
617  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
618  }
619  }();
620 
621  // pad M and N
622  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
624  make_right_pad_transform(N, NPad - N)),
627  }
628 
629  template <typename DLayout>
630  __host__ __device__ static auto
632  {
633  const auto c_grid_desc_mraw_nraw = [&]() {
635  {
636  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
637  }
639  {
640  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
641  }
642  }();
643 
644  // pad M and N
645  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
647  make_right_pad_transform(N, NPad - N)),
650  }
651 
652  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
653  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
654  {
655  return generate_tuple(
656  [&](auto i) {
657  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
658  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
659  },
661  }
662 
663  template <typename DsGridDesc>
665  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
666  {
667  return generate_tuple(
668  [&](auto i) {
670  ds_grid_desc_m_n[i], MBlock, NBlock);
671  },
673  }
674 
675  struct Problem
676  {
677  __host__ Problem(index_t NumTokens_,
678  index_t TopK_,
679  index_t M_,
680  index_t N_,
681  index_t K_,
682  index_t StrideA_,
683  index_t StrideScaleA_,
684  index_t StrideB_,
685  index_t StrideScaleB_,
686  std::array<index_t, NumDTensor> StrideDs_,
687  index_t StrideC_,
688  index_t KBatch_)
689  : NumTokens{NumTokens_},
690  TopK{TopK_},
691  M{M_},
692  N{N_},
693  K{K_},
694  StrideA{StrideA_},
695  StrideScaleA{StrideScaleA_},
696  StrideB{StrideB_},
697  StrideScaleB{StrideScaleB_},
698  StrideDs{StrideDs_},
699  StrideC{StrideC_},
700  KBatch{KBatch_},
703  KRead{CalculateKRead(K_, KBatch_)},
704  KPadded{CalculateKPadded(K_, KBatch_)},
705  AK0{CalculateAK0Padded(K_, KBatch_)},
706  BK0{CalculateBK0Padded(K_, KBatch_)},
707  MBlock{CalculateMBlock(M_)},
709  {
710  }
711 
712  __host__ void Print() const
713  {
714  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
715  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
716  << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
717  << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
718  << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
719  << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
720  << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
721  << ", " << "NBlock: " << NBlock << "}" << std::endl;
722  }
723 
733  std::array<index_t, NumDTensor> StrideDs;
744  };
745 
746  // Argument
748  {
749  __host__ Argument(const index_t* p_sorted_token_ids_,
750  const index_t* p_sorted_expert_ids_,
751  const index_t* p_max_token_id_,
752  const ADataType* p_a_grid_,
753  const AScaleDataType* p_a_scale_grid_,
754  const BDataType* p_b_grid_,
755  const BScaleDataType* p_b_scale_grid_,
756  std::array<const void*, NumDTensor> p_ds_grid_,
757  CDataType* p_c_grid_,
758  index_t NumTokens_,
759  index_t TopK_,
760  index_t M_,
761  index_t N_,
762  index_t K_,
763  index_t StrideA_,
764  index_t StrideScaleA_,
765  index_t StrideB_,
766  index_t StrideScaleB_,
767  std::array<index_t, NumDTensor> StrideDs_,
768  index_t StrideC_,
769  index_t k_batch_,
770  AElementwiseOperation a_element_op_,
771  BElementwiseOperation b_element_op_,
772  CElementwiseOperation c_element_op_)
773  : Problem{NumTokens_,
774  TopK_,
775  M_,
776  N_,
777  K_ / APackedSize,
778  StrideA_ / APackedSize,
779  StrideScaleA_,
780  StrideB_ / BPackedSize,
781  StrideScaleB_,
782  StrideDs_,
783  StrideC_,
784  k_batch_},
785  p_sorted_token_ids{p_sorted_token_ids_},
786  p_sorted_expert_ids{p_sorted_expert_ids_},
787  p_max_token_id{p_max_token_id_},
788  p_a_grid{p_a_grid_},
789  p_a_scale_grid{p_a_scale_grid_},
790  p_b_grid{p_b_grid_},
791  p_b_scale_grid{p_b_scale_grid_},
792  p_ds_grid{},
793  p_c_grid{p_c_grid_},
794  a_element_op{a_element_op_},
795  b_element_op{b_element_op_},
796  c_element_op{c_element_op_}
797  {
798 
799  // populate pointer, desc for Ds
800  static_for<0, NumDTensor, 1>{}([&](auto i) {
801  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
802 
803  // D pointer
804  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
805  });
806  }
807 
811  const ADataType* p_a_grid;
812  const AScaleDataType* p_a_scale_grid;
813  const BDataType* p_b_grid;
814  const BScaleDataType* p_b_scale_grid;
816  CDataType* p_c_grid;
817 
818  const AElementwiseOperation a_element_op;
819  const BElementwiseOperation b_element_op;
820  const CElementwiseOperation c_element_op;
821  };
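// Note on packed element counts: the Argument constructor divides K and StrideA by APackedSize
// and StrideB by BPackedSize, so Problem::K and the strides are stored in packed elements. For
// example, with a two-values-per-byte packed ADataType (APackedSize = 2), a logical K of 4096 is
// stored as Problem::K = 2048.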
822 
824  {
825  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
826  {
827  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
828  {
829  a_k_split_offset = k_id * karg.KRead;
830  }
831  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
832  {
833  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
834  }
835 
836  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
837  {
838  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
839  }
840  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
841  {
842  // KPack * NLane * KLane * K0 * N0
843  b_k_split_offset = k_id * karg.KRead * NPerXdl;
844  }
845 
846  // Calculate A scale offset
847  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
848  MPerXdl / scale_pack_size_a;
849 
850  // Calculate B scale offset
851  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
852  NPerXdl / scale_pack_size_b;
853 
854  if(k_id < karg.KBatch - 1)
855  {
856  karg.K = karg.KRead;
857  }
858  else
859  {
860  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
861  }
862  }
863 
868  };
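// Worked example: with a row-major A, KRead = 2048 and k_id = 1, a_k_split_offset = 1 * 2048
// elements; with a column-major A the same batch starts at 1 * 2048 * StrideA. The last batch's
// K is trimmed to K - KRead * (KBatch - 1), so a K that is not a multiple of KRead is still
// covered exactly once across the split-K batches.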
869 
870  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
871  {
872  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
873  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
874 
875  // A matrix in LDS memory, dst of blockwise copy
876  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
877  {
878  // contiguous in LDS
882  }
883  // The xor tensor transformation requires extra, unnecessary VGPR usage and can cause
884  // register spills in some cases.
886  {
887  constexpr auto a_lds_block_desc =
890 
891  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
892  a_lds_block_desc,
898 
899  return a_lds_block_desc_permuted;
900  }
901  else // ColumnMajor A
902  {
903  // The kfold and mpair dimensions are not always required.
904  // More dimensions in merge_transform increase the difficulty of generating immediate-argument
905  // offsets for the compiler.
906  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
907  constexpr auto M1 = MPerBlock / M0;
908 
909  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
910  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
911  constexpr auto KThreadRead = WaveSize / MPerXdl;
912  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
913 
914  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
915  ? 1
916  : 128 / (AK1Number * M0 * sizeof(ADataType));
917  constexpr auto KThreadReadPerm =
918  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
919  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
920  : KThreadRead;
921 
922  // 1<=mpair<=n0
923  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
924  ? 1
925  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
926  ? M0
927  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
928 
929  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
933  Number<kfold * M0 / mpair>{},
934  Number<mpair>{},
935  AK1Number));
936 
937  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
938  a_lds_block_desc,
939  make_tuple(
943  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
946  make_tuple(
948  make_tuple(
950 
951  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
952  a_lds_block_desc_permuted,
953  make_tuple(
961  Sequence<1>{},
962  Sequence<2>{},
963  Sequence<3>{},
964  Sequence<4>{},
965  Sequence<5>{}),
967  Sequence<2>{},
968  Sequence<0, 3>{},
969  Sequence<4, 5>{},
970  Sequence<6>{},
971  Sequence<7>{}));
972 
973  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
974  a_lds_block_desc_unmerged,
977  Number<KThreadWrite / kfold / KThreadReadPerm>{},
978  Number<kfold>{},
985 
986  return a_lds_block_desc_ak0_m_ak1;
987  }
988  }
989 
990  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
991  {
992  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
994  I1,
996  Number<KRepeat>{},
997  Number<BK1Value>{}));
998  }
999 
1001  {
1002  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1003 
1004  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1006  make_tuple(I1,
1008  I1,
1010 
1011  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1012  }
1013 
1016  BlkGemmPipelineVer,
1017  BlkGemmPipeSched,
1018  BlockSize,
1019  ScaleBlockSize,
1020  ADataType,
1021  AScaleDataType,
1022  BDataType,
1023  BScaleDataType,
1024  ComputeTypeA,
1025  AccDataType,
1032  ABlockTransferSrcScalarPerVector,
1033  BBlockTransferSrcScalarPerVector,
1034  MPerBlock,
1035  NPerBlock,
1036  KPerBlock,
1037  MPerXdl,
1038  NPerXdl,
1039  MXdlPerWave,
1040  NXdlPerWave,
1041  KPack,
1042  IsInputGemm>())>;
1043 
1044  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1045  {
1046  // LDS allocation for A and B: be careful of alignment
1047  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1048  // lds max alignment
1049  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1050 
1051  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1052  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1053 
1054  // LDS allocation for C shuffle in LDS
1055  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1057 
1058  constexpr auto c_block_size =
1059  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1060 
1061  return math::max(a_block_space_size_aligned * sizeof(ADataType),
1062  c_block_size * sizeof(CShuffleDataType));
1063  }
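// Rough sizing example (ignoring alignment padding): with MPerBlock = 128, KPerBlock = 256 and a
// 1-byte packed ADataType, the A tile needs about 256 * 128 = 32 KiB of LDS. The result is the
// maximum of the A-tile and C-shuffle-tile sizes because the same LDS buffer is reused for A
// during the main loop and for the C shuffle afterwards; the preshuffled B tiles are staged in
// VGPRs, so no LDS is reserved for B.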
1064 
1066 
1067  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1068  __host__ static constexpr bool CheckValidity(const Argument& karg)
1069  {
1070  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1071  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1072  "Invalid tuning param!");
1073 
1074  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1075  "KPerBlock should be multiple of ScaleBlockSize");
1076 
1077  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1082  {
1083  if(!(karg.M % MPerBlock == 0))
1084  {
1085  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1086  {
1087  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1088  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1089  << std::endl;
1090  }
1091  return false;
1092  }
1093  }
1094 
1095  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1100  {
1101  if(!(karg.N % NPerBlock == 0))
1102  {
1103  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1104  {
1105  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1106  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1107  << std::endl;
1108  }
1109  return false;
1110  }
1111  }
1112 
1113  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1117  {
1118  auto K_t = karg.KBatch * KPerBlock;
1119  if(!(karg.K % K_t == 0))
1120  {
1121  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1122  {
1123  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1124  << karg.K << " " << __FILE__ << ":" << __LINE__
1125  << ", in function: " << __func__ << std::endl;
1126  }
1127  return false;
1128  }
1129  }
1130  else
1131  {
1132  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1133  auto K_t = karg.KBatch * KReadVec;
1134  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1135  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1136  {
1137  return false;
1138  }
1139  }
1140 
1142  {
1143  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1144  {
1145  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1146  {
1147  std::cout << "Arg K (" << karg.K
1148  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1149  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1150  << __LINE__ << ", in function: " << __func__ << std::endl;
1151  }
1152  return false;
1153  }
1154  }
1155  else
1156  {
1157  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1158  {
1159  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1160  {
1161  std::cout << "Arg M (" << karg.M
1162  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1163  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1164  << __LINE__ << ", in function: " << __func__ << std::endl;
1165  }
1166  return false;
1167  }
1168  }
1169 
1171  {
1172  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1173  {
1174  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1175  {
1176  std::cout << "Arg N (" << karg.N
1177  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1178  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1179  << __LINE__ << ", in function: " << __func__ << std::endl;
1180  }
1181  return false;
1182  }
1183  }
1184  else
1185  {
1186  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1187  {
1188  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189  {
1190  std::cout << "Arg K (" << karg.K
1191  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1192  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1193  << __LINE__ << ", in function: " << __func__ << std::endl;
1194  }
1195  return false;
1196  }
1197  }
1198 
1200  {
1202  {
1203  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1204  {
1205  std::cout << "Arg N (" << karg.N
1206  << ") value is not a multiple of "
1207  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1209  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1210  << std::endl;
1211  }
1212  return false;
1213  }
1214  }
1215  else
1216  {
1218  {
1219  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1220  {
1221  std::cout << "Arg M (" << karg.M
1222  << ") value is not a multiple of "
1223  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1225  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1226  << std::endl;
1227  }
1228  return false;
1230  }
1231  }
1232 
1233  // check gridwise gemm pipeline
1234 #if 0
1235  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1236 
1237  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1238  {
1239  return false;
1240  }
1241 #endif
1242  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1243  return true;
1244  }
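// Example of the K check above: with KPerBlock = 256, KBatch = 2 and no K padding, K must be a
// multiple of 512 or CheckValidity returns false; with K padding enabled it only rejects
// configurations where the padded per-batch K would leave the last batch with no work
// (KReadPadSplited * (KBatch - 1) >= K).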
1245 
1246  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1247  {
1248  const index_t num_loop = K / KPerBlock;
1249  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1250  }
1251 
1252  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1253  {
1254  const index_t num_loop = K / KPerBlock;
1255 
1256  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1257  }
1258 
1259  template <typename CGridDesc>
1260  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1261  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1262  {
1263  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1264  c_grid_desc_m_n,
1269 
1270  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1271  }
1272 
1273  // return block_id to C matrix tile idx (m0, n0) mapping
1274  // if arch = gfx942
1275  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1276  // NPerBlock>;
1277 
1278  template <bool HasMainKBlockLoop,
1279  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1280  TailNumber TailNum = TailNumber::Odd>
1281  __device__ static void Run(const index_t* p_sorted_token_ids,
1282  const index_t* p_sorted_expert_ids,
1283  const index_t* p_max_token_id,
1284  const ADataType* p_a_grid,
1285  const AScaleDataType* p_a_scale_grid,
1286  const BDataType* p_b_grid,
1287  const BScaleDataType* p_b_scale_grid,
1288  DsGridPointer& p_ds_grid,
1289  CDataType* p_c_grid,
1290  void* p_shared,
1291  const Problem& problem,
1292  AElementwiseOperation a_element_op,
1293  BElementwiseOperation b_element_op,
1294  CElementwiseOperation c_element_op)
1295  {
1296  ignore = a_element_op;
1297  ignore = b_element_op;
1298  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1299  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1300  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1301  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1302  problem.MPadded,
1303  problem.K,
1304  problem.KPadded,
1305  problem.StrideA,
1306  problem.AK0);
1307  const auto b_grid_desc_bpreshuffled =
1308  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1309  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1310  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1311  problem.MPadded,
1312  problem.N,
1313  problem.NPadded,
1314  problem.StrideC);
1315 
1316  // We pad M unconditionally for the scale tensor
1317  const auto Padded_Scale_M =
1318  math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
1319  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1320  make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
1321  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1322  (KXdlPack * 64 / MPerXdl),
1324  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
1325  (ScaleBlockSize / APackedSize)) *
1326  MPerXdl * MXdlPack / scale_pack_size_a,
1328  1));
1329 
1330  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1331  make_tuple(problem.N / (NXdlPack * NPerXdl),
1332  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1333  (KXdlPack * 64 / NPerXdl),
1335  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
1336  (ScaleBlockSize / BPackedSize)) *
1337  NPerXdl * NXdlPack / scale_pack_size_b,
1339  1));
1340 
1341  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1343  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1344 
1345  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1346  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1347  if(expert_block_id * MPerBlock >= max_token_id)
1348  return;
1349  const index_t expert_id =
1350  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1351  const auto block_mn = [&]() -> std::pair<int, int> {
1352  if constexpr(NSwizzle)
1353  {
1354  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1355  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1356  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1357  const index_t expert_swizzle =
1358  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1359  const index_t bid_new = blockIdx.x - prefix_block;
1360  const index_t nid = __builtin_amdgcn_readfirstlane(
1361  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1362  const index_t mid =
1363  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1364  return {nid, mid};
1365  }
1366  else
1367  {
1368  return {blockIdx.x, blockIdx.y};
1369  }
1370  }();
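// NSwizzle example: suppose this expert owns ecnt = 3 M-tiles starting at prefix ecnt_prefix = 4
// and the current block has bid_new = 13. Then nid = 13 % 8 + 13 / (8 * 3) * 8 = 5 and
// mid = 4 + (13 / 8) % 3 = 5, i.e. each group of 8 consecutive blocks covers 8 N-tiles of one
// M-tile, and those N-tiles are revisited for the expert's other M-tiles before moving on.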
1371 
1372  const index_t block_n_id = block_mn.first;
1373  const index_t block_m_id = block_mn.second;
1374  const index_t token0 =
1375  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1376 
1377  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1378  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1379  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1380  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1381  constexpr auto AKThreads = AK0Threads * AK1Threads;
1382  constexpr auto AMRepeats = MPerBlock / AMThreads;
1383  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
1384 
1385  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1386  return;
1388  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1389  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
1390  index_t token_offset = fused_token & 0xffffff;
1391  if constexpr(!IsInputGemm)
1392  {
1393  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1394  }
1395  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1396  });
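// Fused token encoding: the low 24 bits of each p_sorted_token_ids entry are the token index and
// the high 8 bits are the top-k slot. For example, fused_token = 0x02000135 decodes to token 309
// in slot 2; for the second GEMM (!IsInputGemm) the gathered row becomes 309 * TopK + 2, and the
// gather offset is that row index times problem.K (in packed elements).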
1397 
1398  const index_t expert_stride =
1399  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1400  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1401  problem.N * (IsInputGemm ? 2 : 1) *
1402  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1403 
1404  // N0, K0, Blocksize*KPack
1405  const index_t n_block_data_idx_on_grid =
1406  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
1407 
1408  // Grid buffer creation
1409  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1410  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1411  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1412  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1413 
1414  // A, B scale buffer
1415  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1416  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1417  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1418  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1419  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1420 
1421  // A matrix in LDS memory, dst of blockwise copy
1422  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1423 
1424  // B matrix in LDS memory, dst of blockwise copy
1425  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1426 
1427  // A matrix blockwise direct to LDS copy
1428  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
1431  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1432  ABlockTransferThreadClusterArrangeOrder,
1433  ADataType,
1434  ADataType,
1435  decltype(a_grid_desc_ak0_m_ak1),
1436  decltype(a_block_desc_ak0_m_ak1),
1437  ABlockTransferSrcAccessOrder,
1438  ABlockTransferSrcVectorDim,
1439  2,
1440  ABlockTransferSrcScalarPerVector,
1441  IndexType,
1442  1>(a_grid_desc_ak0_m_ak1,
1443  make_multi_index(0, 0, 0),
1444  a_block_desc_ak0_m_ak1,
1445  make_multi_index(0, 0, 0),
1446  gather_offsets);
1447 
1448  // Thread-wise copy
1449  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1450  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1451  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1452  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1453  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1454  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1455 
1456  auto b_blockwise_copy =
1458  BDataType,
1459  decltype(b_grid_desc_bpreshuffled),
1460  decltype(b_block_desc_bk0_n_bk1),
1461  Sequence<Number<NXdlPerWave / NXdlPack>{},
1462  I1,
1463  Number<NXdlPack>{},
1464  Number<KRepeat>{},
1465  Number<BK1Value>{}>,
1467  4,
1468  BBlockTransferSrcScalarPerVector,
1469  BThreadTransferSrcResetCoordinateAfterRun,
1470  true>(
1471  b_grid_desc_bpreshuffled,
1472  make_multi_index(n_block_data_idx_on_grid,
1474  0,
1475  0,
1476  KPack * (get_thread_local_1d_id() % WarpSize)));
1477 
1478  // LDS allocation for A and B: be careful of alignment
1479  // Cast after lds
1480  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1481  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1482 
1483  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1484  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
1485 
1486  // Blockwise GEMM pipeline
1487  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1488  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1489  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1490  decltype(c_thread_buf) c_thread_buf_up;
1491 
1493  float,
1494  c_thread_buf.num_of_v_,
1495  c_thread_buf.s_per_v,
1496  true>
1497  c_thread_buf_fp32;
1498 
1499  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1500  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1501  KPerBlock);
1502 
1503  // a and b scale processing
1504  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1505  const auto waveId_m = wave_idx[I0];
1506  const auto waveId_n = wave_idx[I1];
1507 
1508  auto thread_offset_shuffled =
1509  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1510 
1511  auto a_thread_offset_m = waveId_m;
1512 
1513  // get each thread's offset into the scale tensor
1514  const index_t token_scale_pos = block_m_id * MPerBlock;
1515  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1516  return;
1517 
1518  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1519  AScaleDataType,
1520  AScaleDataType,
1521  decltype(a_scale_grid_desc_am_ak),
1522  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1523  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1524  Sequence<0, 1, 2>, // DimAccessOrder
1525  2, // SrcVectorDim
1526  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1527  1, // SrcScalarStrideInVector
1528  true>(a_scale_grid_desc_am_ak,
1529  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1530  0,
1531  thread_offset_shuffled / scale_pack_size_a));
1532 
1533  // B scale load
1534  auto b_thread_offset_n = waveId_n;
1535 
1536  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1537  BScaleDataType,
1538  BScaleDataType,
1539  decltype(b_scale_grid_desc_bn_ak),
1540  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1541  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1542  Sequence<0, 1, 2>, // DimAccessOrder
1543  2, // SrcVectorDim
1544  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1545  1, // SrcScalarStrideInVector
1546  true>(b_scale_grid_desc_bn_ak,
1547  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1548  0,
1549  thread_offset_shuffled / scale_pack_size_b));
1550 
1551  if constexpr(IsInputGemm)
1552  {
1553  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1554  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1555  p_b_grid_up + expert_id * expert_stride,
1556  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1557  auto b_blockwise_copy_up =
1559  BDataType,
1560  decltype(b_grid_desc_bpreshuffled),
1561  decltype(b_block_desc_bk0_n_bk1),
1562  Sequence<Number<NXdlPerWave / NXdlPack>{},
1563  I1,
1564  Number<NXdlPack>{},
1565  Number<KRepeat>{},
1566  Number<BK1Value>{}>,
1568  4,
1569  BBlockTransferSrcScalarPerVector,
1570  BThreadTransferSrcResetCoordinateAfterRun,
1571  true>(
1572  b_grid_desc_bpreshuffled,
1573  make_multi_index(n_block_data_idx_on_grid,
1575  0,
1576  0,
1577  KPack * (get_thread_local_1d_id() % WarpSize)));
1578  const BScaleDataType* p_b_scale_grid_up =
1579  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1580  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1581  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1582  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1583 
1584  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1585  BScaleDataType,
1586  BScaleDataType,
1587  decltype(b_scale_grid_desc_bn_ak),
1588  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1589  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1590  Sequence<0, 1, 2>, // DimAccessOrder
1591  2, // SrcVectorDim
1592  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1593  1, // SrcScalarStrideInVector
1594  true>(
1595  b_scale_grid_desc_bn_ak,
1596  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1597  0,
1598  thread_offset_shuffled / scale_pack_size_b));
1599 
1600  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1601  // A
1602  a_grid_desc_ak0_m_ak1,
1603  a_block_desc_ak0_m_ak1,
1604  a_blockwise_copy,
1605  a_grid_buf,
1606  a_block_buf,
1607  a_block_slice_copy_step,
1608  // Gate and Up
1609  b_grid_desc_bpreshuffled,
1610  b_block_desc_bk0_n_bk1,
1611  b_blockwise_copy,
1612  b_blockwise_copy_up,
1613  b_grid_buf,
1614  b_grid_buf_up,
1615  b_block_bufs,
1616  b_block_slice_copy_step,
1617  // C
1618  c_thread_buf,
1619  c_thread_buf_up,
1620  // A scale
1621  a_scale_grid_desc_am_ak,
1622  a_scale_thread_copy,
1623  a_scale_grid_buf,
1624  // B scale
1625  b_scale_grid_desc_bn_ak,
1626  b_scale_thread_copy,
1627  b_scale_thread_copy_up,
1628  b_scale_grid_buf,
1629  b_scale_grid_buf_up,
1630  num_k_block_main_loop);
1631  }
1632  else
1633  {
1634  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1635  a_grid_desc_ak0_m_ak1, // A
1636  a_block_desc_ak0_m_ak1,
1637  a_blockwise_copy,
1638  a_grid_buf,
1639  a_block_buf,
1640  a_block_slice_copy_step,
1641  b_grid_desc_bpreshuffled, // B
1642  b_block_desc_bk0_n_bk1,
1643  b_blockwise_copy,
1644  b_grid_buf,
1645  b_block_bufs,
1646  b_block_slice_copy_step,
1647  c_thread_buf, // C
1648  a_scale_grid_desc_am_ak, // A scale
1649  a_scale_thread_copy,
1650  a_scale_grid_buf,
1651  b_scale_grid_desc_bn_ak, // B scale
1652  b_scale_thread_copy,
1653  b_scale_grid_buf,
1654  num_k_block_main_loop);
1655  }
1656 
1657  // shuffle C and write out
1658  {
1659  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1660  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1661  "wrong!");
1662  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1663  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1664  "wrong!");
1665 
1666  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1667 
1668  // TODO: hacky, fix it!
1669  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1670  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1671 
1672  // TODO: hacky, fix it!
1673  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1674  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1675  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1676 
1677  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1678  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1679  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1680  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1681  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1682  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1683  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1684  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1685  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1686  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1687 
1688  // mul scales
1689 
1690  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1691  static_assert(M5 == 4);
1692  const index_t m1 = get_warp_local_1d_id() / NWave;
1693  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1694 
1695  vector_type<float, 4> topk_weights; // for gemm2 only
1696  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1697  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1698  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1699  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1700  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
1701  const index_t m_pos = block_m_id * MPerBlock +
1702  m0 * M2 * M1 * M3 * M4 * M5 +
1703  m1 * M2 * M3 * M4 * M5 +
1704  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1705  if constexpr(MulRoutedWeight)
1706  {
1707  topk_weights =
1708  *c_style_pointer_cast<const vector_type<float, M5>*>(
1709  p_ds_grid[I2] + m_pos);
1710  }
1711  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1712  constexpr index_t c_offset =
1713  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1714  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1715  constexpr auto cidx = Number<c_offset>{};
1716 
1717  if constexpr(IsInputGemm) // gu fusion
1718  {
1719  if constexpr(ActivationOperation ==
1721  {
1722  float gate = c_thread_buf[cidx];
1723  float up = c_thread_buf_up[cidx];
1724  if constexpr(MulRoutedWeight)
1725  {
1726  gate = gate * topk_weights.AsType<float>()[m5];
1727  up = up * topk_weights.AsType<float>()[m5];
1728  }
1730  c_thread_buf_fp32(cidx) = gate * up;
1731  }
1732  else if(ActivationOperation == Activation::gelu_and_mul)
1733  {
1734  float gate = c_thread_buf[cidx];
1735  float up = c_thread_buf_up[cidx];
1736  if constexpr(MulRoutedWeight)
1737  {
1738  gate = gate * topk_weights.AsType<float>()[m5];
1739  up = up * topk_weights.AsType<float>()[m5];
1740  }
1742  c_thread_buf_fp32(cidx) = gate * up;
1743  }
1744  }
1745  else
1746  {
1747  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1748  if constexpr(MulRoutedWeight)
1749  {
1750  c_thread_buf_fp32(cidx) =
1751  topk_weights.AsType<float>()[m5] *
1752  c_thread_buf_fp32[cidx];
1753  }
1754  }
1755  });
1756  });
1757  });
1758  });
1759  });
1760  });
1761 
1762  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1764 
1765  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1766  static_cast<CShuffleDataType*>(p_shared),
1767  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1768 
1769  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1770  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1771  make_tuple(
1774  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
1775  // shuffle
1776  M1, // M1 = MWave
1777  M2, // M2 * M3 * M4 = MPerXdl
1778  M3,
1779  M4,
1780  M5)),
1784  // per shuffle
1785  N1, // N1 = NWave
1786  N2, // N2 = NXdlPack
1787  N3))), // N3 = NPerXdl
1791  Sequence<>{},
1793 
1794  // calculate origin of thread output tensor on global memory
1795  // blockwise GEMM c matrix starting index
1796  const auto c_thread_mtx_on_block =
1797  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1798 
1799  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1800  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1801 
1802  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1804  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1806  make_tuple(Sequence<0>{}));
1807 
1808  const auto m_thread_data_on_block_idx =
1809  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1810  make_multi_index(m_thread_data_on_block));
1811 
1812  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1814  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1816  make_tuple(Sequence<0>{}));
1817 
1818  const auto n_thread_data_on_block_idx =
1819  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1820  make_multi_index(n_thread_data_on_block));
1821 
1822  // shuffle: threadwise copy C from VGPR to LDS
1823  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1824  AccDataType,
1825  CShuffleDataType,
1826  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1827  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1829  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1830  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1831  I1,
1832  I1,
1833  M2,
1834  N2,
1835  M3,
1836  I1,
1837  M5,
1838  I1>,
1840  9,
1841  1,
1843  1,
1844  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1845  make_multi_index(0,
1846  0,
1847  m_thread_data_on_block_idx[I1],
1848  n_thread_data_on_block_idx[I1],
1849  m_thread_data_on_block_idx[I2],
1850  n_thread_data_on_block_idx[I2],
1851  m_thread_data_on_block_idx[I3],
1852  m_thread_data_on_block_idx[I4],
1853  m_thread_data_on_block_idx[I5],
1854  n_thread_data_on_block_idx[I3]),
1856 
1857  using EDataType = CDataType;
1858 
1859  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1860  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1861 
1862  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1864  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1865 
1866  const auto ds_grid_buf = generate_tuple(
1867  [&](auto i) {
1868  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1869  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1870  },
1871  Number<NumDTensor>{});
1872 
1873  // tuple of reference to C/Ds tensor descriptors
1874  const auto c_ds_desc_refs = concat_tuple_of_reference(
1875  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1876  generate_tie([&](auto i) -> const auto& // return type should be reference
1877  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1878  Number<NumDTensor>{}));
1879 
1880  // tuple of reference to C/Ds tensor descriptors
1881  const auto c_ds_buf_refs = concat_tuple_of_reference(
1882  tie(c_shuffle_block_buf),
1883  generate_tie([&](auto i) -> const auto& // return type should be reference
1884  { return ds_grid_buf[i]; },
1885  Number<NumDTensor>{}));
1886 
1887  // tuple of starting index of C/Ds blockwise copy
1888  const auto idx_c_ds_block_begin =
1889  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
1890  generate_tuple(
1891  [&](auto) {
1892  return make_multi_index(block_m_id, 0, block_n_id, 0);
1893  // return make_multi_index(block_work_idx[I0], 0,
1894  // block_work_idx[I1], 0);
1895  },
1896  Number<NumDTensor>{}));
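// the first entry indexes the CShuffle tile held in LDS; the remaining NumDTensor entries
// position each D tensor window at this block's (MBlock, NBlock) tile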
1897 
1898  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1899  c_grid_desc_mblock_mperblock_nblock_nperblock;
1900 
1901  using CDEBlockTransferCluster =
1902  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1903  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1904  constexpr index_t scatter_weight_idx = 3; // hack fix felix
1905  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1906  ThisThreadBlock,
1907  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1908  Tuple<EDataType>,
1909  decltype(c_ds_desc_refs),
1910  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1911  CElementwiseOperation,
1912  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1913  // Sequence support
1914  // arbitrary type
1915  Sequence<1,
1916  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1917  1,
1918  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1919  CDEBlockTransferCluster,
1920  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1921  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1922  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1923  3, // index_t SrcVectorDim,
1924  3, // index_t DstVectorDim,
1925  CDEShuffleBlockTransferScalarPerVectors,
1926  sequence_merge_t<
1927  Sequence<true>,
1928  uniform_sequence_gen_t<
1929  NumDTensor,
1930  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1931  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1932  IndexType,
1933  1, // ScatterDim
1934  true, // OutputScatter: false, only use scatter weights
1935  scatter_weight_idx // ScatterWeightIdx: ascale
1936  >{c_ds_desc_refs,
1937  idx_c_ds_block_begin,
1938  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1939  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1940  c_element_op};
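// this thread-group copy reads the shuffled C tile from LDS together with the D tensors,
// applies c_element_op, and scatters the resulting rows into the output grid using
// scatter_offsets along dimension 1 (the MPerBlock / token dimension)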
1941 
1942  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1943  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1944 
1945  constexpr auto sfc_c_vgpr =
1946  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
1947  NXdlPerWave / NXdlPack,
1948  1,
1949  1,
1950  MXdlPack,
1951  NXdlPack,
1952  M2,
1953  1,
1954  M4,
1955  1>,
1956  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1957  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1958  CShuffleNXdlPerWavePerShuffle / NXdlPack,
1959  1,
1960  1,
1961  MXdlPack,
1962  NXdlPack,
1963  M2,
1964  1,
1965  M4,
1966  1>>{};
1967 
1968  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
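// the accumulators are written out in num_access passes; each pass moves one
// CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle sub-tile through the LDS
// shuffle buffer before it is stored to global memory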
1969 
1970  // space filling curve for shuffled blockwise C/D/E
1971  constexpr auto sfc_cde_block =
1972  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
1973  Sequence<0, 2, 1, 3>,
1974  Sequence<1,
1975  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1976  1,
1977  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1978 
1979  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1980  constexpr auto EMThreads =
1981  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1982  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1983  constexpr auto ENThreads =
1984  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1985  static_for<0, num_access, 1>{}([&](auto access_id) {
1986  // make sure it's safe to write to LDS
1987  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1988 
1989  auto dstidx = sfc_cde_block.GetIndex(access_id);
1990  const index_t c_token_pos =
1991  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1992  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1993  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1994  IndexType token_offset = fused_token & 0xffffff;
1995  if constexpr(IsInputGemm)
1996  {
1997  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1998  }
1999  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2000  });
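// scatter_offsets holds the destination row start (token_offset * N) for every output row
// of this access; for the input GEMM each token is expanded by its top-k slot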
2001 
2002  block_sync_lds();
2003 
2004  // each thread write its data from VGPR to LDS
2005  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2006  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2007  c_thread_buf_fp32,
2008  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2009  c_shuffle_block_buf);
2010 
2011  // make sure it's safe to read from LDS
2012  block_sync_lds();
2013 
2014  // each block copy its data from LDS to global
2015  cde_block_copy_lds_and_global.Run(
2016  c_ds_desc_refs,
2017  c_ds_buf_refs,
2018  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2019  tie(c_grid_buf),
2020  scatter_offsets);
2021 
2022  if constexpr(access_id < num_access - 1)
2023  {
2024  constexpr auto cde_lds_and_global_step =
2025  sfc_cde_block.GetForwardStep(access_id);
2026 
2027  // move on Ds
2028  static_for<0, NumDTensor, 1>{}([&](auto i) {
2029  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2030  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2031  });
2032 
2033  // move on E
2034  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2035  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2036  I0,
2037  cde_lds_and_global_step);
2038  }
2039  });
2040  }
2041  }
2042 
2043  template <bool HasMainKBlockLoop,
2044  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2045  TailNumber TailNum = TailNumber::Odd>
2046  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2047  const index_t* p_sorted_expert_ids,
2048  const index_t* p_max_token_id,
2049  const ADataType* p_a_grid,
2050  const AScaleDataType* p_a_scale_grid,
2051  const BDataType* p_b_grid,
2052  const BScaleDataType* p_b_scale_grid,
2053  DsGridPointer& p_ds_grid,
2054  CDataType* p_c_grid,
2055  void* p_shared_0,
2056  void* p_shared_1,
2057  const Problem& problem,
2058  AElementwiseOperation a_element_op,
2059  BElementwiseOperation b_element_op,
2060  CElementwiseOperation c_element_op)
2061  {
2062  ignore = a_element_op;
2063  ignore = b_element_op;
2064  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
2065  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
2066  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2067  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2068  problem.MPadded,
2069  problem.K,
2070  problem.KPadded,
2071  problem.StrideA,
2072  problem.AK0);
2073  const auto b_grid_desc_bpreshuffled =
2074  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
2075  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2076  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2077  problem.MPadded,
2078  problem.N,
2079  problem.NPadded,
2080  problem.StrideC);
2081 
2082  // We pad M unconditionally for the scale tensor
2083  const auto Padded_Scale_M =
2084  math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
2085  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2086  make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2087  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2088  (KXdlPack * 64 / MPerXdl),
2089  KXdlPack * MXdlPack * 64 / scale_pack_size_a),
2090  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2091  (ScaleBlockSize / APackedSize)) *
2092  MPerXdl * MXdlPack / scale_pack_size_a,
2093  KXdlPack * MXdlPack * 64 / scale_pack_size_a,
2094  1));
2095 
2096  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2097  make_tuple(problem.N / (NXdlPack * NPerXdl),
2098  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2099  (KXdlPack * 64 / NPerXdl),
2100  KXdlPack * NXdlPack * 64 / scale_pack_size_b),
2101  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2102  (ScaleBlockSize / BPackedSize)) *
2103  NPerXdl * NXdlPack / scale_pack_size_b,
2104  KXdlPack * NXdlPack * 64 / scale_pack_size_b,
2105  1));
2106 
2107  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2108  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2109  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2110 
2111  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2112  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2113  if(expert_block_id * MPerBlock >= max_token_id)
2114  return;
2115  const index_t expert_id =
2116  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2117  const auto block_mn = [&]() -> std::pair<int, int> {
2118  if constexpr(NSwizzle)
2119  {
2120  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2121  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2122  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2123  const index_t expert_swizzle =
2124  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2125  const index_t bid_new = blockIdx.x - prefix_block;
2126  const index_t nid = __builtin_amdgcn_readfirstlane(
2127  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2128  const index_t mid =
2129  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2130  return {nid, mid};
2131  }
2132  else
2133  {
2134  return {blockIdx.x, blockIdx.y};
2135  }
2136  }();
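// with NSwizzle enabled the flat blockIdx.x is decoded into (nid, mid): N tiles are walked
// in groups of 8 and interleaved across the expert's M tiles; otherwise blockIdx.x/blockIdx.y
// map directly to the N/M tile indices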
2137 
2138  const index_t block_n_id = block_mn.first;
2139  const index_t block_m_id = block_mn.second;
2140  const index_t token0 =
2141  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2142 
2143  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2144  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2145  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2146  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2147  constexpr auto AKThreads = AK0Threads * AK1Threads;
2148  constexpr auto AMRepeats = MPerBlock / AMThreads;
2149  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2150 
2151  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2152  return;
2153  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2154  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2155  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2156  index_t token_offset = fused_token & 0xffffff;
2157  if constexpr(!IsInputGemm)
2158  {
2159  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2160  }
2161  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2162  });
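// p_sorted_token_ids packs the token index in the low 24 bits and the top-k slot in the
// high 8 bits; gather_offsets are the per-row starting offsets (in elements of K) consumed
// by the direct-load A copy below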
2163 
2164  const index_t expert_stride =
2165  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2166  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2167  problem.N * (IsInputGemm ? 2 : 1) *
2168  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
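// per-expert strides into the preshuffled B weights and their block scales; both are doubled
// for the input GEMM because the gate and up projection weights are stored back to back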
2169 
2170  // N0, K0, Blocksize*KPack
2171  const index_t n_block_data_idx_on_grid =
2172  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
2173 
2174  // Grid buffer creation
2175  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2176  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2177  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2178  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2179 
2180  // A, B scale buffer
2181  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2182  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2183  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2184  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2185  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2186 
2187  // A matrix in LDS memory, dst of blockwise copy
2188  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2189 
2190  // B matrix in LDS memory, dst of blockwise copy
2191  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2192 
2193  // A matrix blockwise direct to LDS copy
2194  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2195  ThisThreadBlock,
2196  Sequence<AK0Number, MPerBlock, AK1Number>,
2197  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2198  ABlockTransferThreadClusterArrangeOrder,
2199  ADataType,
2200  ADataType,
2201  decltype(a_grid_desc_ak0_m_ak1),
2202  decltype(a_block_desc_ak0_m_ak1),
2203  ABlockTransferSrcAccessOrder,
2204  ABlockTransferSrcVectorDim,
2205  2,
2206  ABlockTransferSrcScalarPerVector,
2207  IndexType,
2208  1>(a_grid_desc_ak0_m_ak1,
2209  make_multi_index(0, 0, 0),
2210  a_block_desc_ak0_m_ak1,
2211  make_multi_index(0, 0, 0),
2212  gather_offsets);
2213 
2214  // Thread-wise copy
2215  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2216  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2217  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2218  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2219  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2220  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
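// the preshuffled B tiles are loaded thread-wise straight into ping/pong VGPR buffers, so B
// needs no LDS staging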
2221 
2222  auto b_blockwise_copy =
2223  ThreadwiseTensorSliceTransfer_v2<BDataType,
2224  BDataType,
2225  decltype(b_grid_desc_bpreshuffled),
2226  decltype(b_block_desc_bk0_n_bk1),
2227  Sequence<Number<NXdlPerWave / NXdlPack>{},
2228  I1,
2229  Number<NXdlPack>{},
2230  Number<KRepeat>{},
2231  Number<BK1Value>{}>,
2232  Sequence<0, 1, 2, 3, 4>,
2233  4,
2234  BBlockTransferSrcScalarPerVector,
2235  BThreadTransferSrcResetCoordinateAfterRun,
2236  true>(
2237  b_grid_desc_bpreshuffled,
2238  make_multi_index(n_block_data_idx_on_grid,
2240  0,
2241  0,
2242  KPack * (get_thread_local_1d_id() % WarpSize)));
2243 
2244  // LDS allocation for A and B: be careful of alignment
2245  // Cast after lds
2246  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2247  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2248  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2249  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2250  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
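// A tiles ping-pong between the two LDS allocations (p_shared_0 / p_shared_1) used by this
// 2-LDS pipeline variant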
2251 
2252  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2253  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2254 
2255  // Blockwise GEMM pipeline
2256  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2257  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2258  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2259  decltype(c_thread_buf) c_thread_buf_up;
2260 
2261  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2262  float,
2263  c_thread_buf.num_of_v_,
2264  c_thread_buf.s_per_v,
2265  true>
2266  c_thread_buf_fp32;
2267 
2268  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2269  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2270  KPerBlock);
2271 
2272  // a and b scale processing
2273  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2274  const auto waveId_m = wave_idx[I0];
2275  const auto waveId_n = wave_idx[I1];
2276 
2277  auto thread_offset_shuffled =
2278  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2279 
2280  auto a_thread_offset_m = waveId_m;
2281 
2282  // get each thread's offset in the scale tensor
2283  const index_t token_scale_pos = block_m_id * MPerBlock;
2284  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2285  return;
2286 
2287  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2288  AScaleDataType,
2289  AScaleDataType,
2290  decltype(a_scale_grid_desc_am_ak),
2291  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2292  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2293  Sequence<0, 1, 2>, // DimAccessOrder
2294  2, // SrcVectorDim
2295  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2296  1, // SrcScalarStrideInVector
2297  true>(a_scale_grid_desc_am_ak,
2298  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2299  0,
2300  thread_offset_shuffled / scale_pack_size_a));
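// each thread fetches KXdlPack * MXdlPack block scales per pipeline step (scale_pack_size_a
// of them packed per element), at an offset derived from its lane id via thread_offset_shuffled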
2301 
2302  // B scale load
2303  auto b_thread_offset_n = waveId_n;
2304 
2305  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2306  BScaleDataType,
2307  BScaleDataType,
2308  decltype(b_scale_grid_desc_bn_ak),
2309  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2310  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2311  Sequence<0, 1, 2>, // DimAccessOrder
2312  2, // SrcVectorDim
2313  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2314  1, // SrcScalarStrideInVector
2315  true>(b_scale_grid_desc_bn_ak,
2316  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2317  0,
2318  thread_offset_shuffled / scale_pack_size_b));
2319 
2320  if constexpr(IsInputGemm)
2321  {
2322  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2323  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2324  p_b_grid_up + expert_id * expert_stride,
2325  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2326  auto b_blockwise_copy_up =
2327  ThreadwiseTensorSliceTransfer_v2<BDataType,
2328  BDataType,
2329  decltype(b_grid_desc_bpreshuffled),
2330  decltype(b_block_desc_bk0_n_bk1),
2331  Sequence<Number<NXdlPerWave / NXdlPack>{},
2332  I1,
2333  Number<NXdlPack>{},
2334  Number<KRepeat>{},
2335  Number<BK1Value>{}>,
2336  Sequence<0, 1, 2, 3, 4>,
2337  4,
2338  BBlockTransferSrcScalarPerVector,
2339  BThreadTransferSrcResetCoordinateAfterRun,
2340  true>(
2341  b_grid_desc_bpreshuffled,
2342  make_multi_index(n_block_data_idx_on_grid,
2344  0,
2345  0,
2346  KPack * (get_thread_local_1d_id() % WarpSize)));
2347  const BScaleDataType* p_b_scale_grid_up =
2348  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2349  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2350  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2351  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2352 
2353  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2354  BScaleDataType,
2355  BScaleDataType,
2356  decltype(b_scale_grid_desc_bn_ak),
2357  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2358  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2359  Sequence<0, 1, 2>, // DimAccessOrder
2360  2, // SrcVectorDim
2361  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2362  1, // SrcScalarStrideInVector
2363  true>(
2364  b_scale_grid_desc_bn_ak,
2365  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2366  0,
2367  thread_offset_shuffled / scale_pack_size_b));
2368 
2369  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2370  // A
2371  a_grid_desc_ak0_m_ak1,
2372  a_block_desc_ak0_m_ak1,
2373  a_blockwise_copy,
2374  a_grid_buf,
2375  a_block_bufs,
2376  a_block_slice_copy_step,
2377  // Gate and Up
2378  b_grid_desc_bpreshuffled,
2379  b_block_desc_bk0_n_bk1,
2380  b_blockwise_copy,
2381  b_blockwise_copy_up,
2382  b_grid_buf,
2383  b_grid_buf_up,
2384  b_block_bufs,
2385  b_block_slice_copy_step,
2386  // C
2387  c_thread_buf,
2388  c_thread_buf_up,
2389  // A scale
2390  a_scale_grid_desc_am_ak,
2391  a_scale_thread_copy,
2392  a_scale_grid_buf,
2393  // B scale
2394  b_scale_grid_desc_bn_ak,
2395  b_scale_thread_copy,
2396  b_scale_thread_copy_up,
2397  b_scale_grid_buf,
2398  b_scale_grid_buf_up,
2399  num_k_block_main_loop);
2400  }
2401  else
2402  {
2403  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2404  a_grid_desc_ak0_m_ak1, // A
2405  a_block_desc_ak0_m_ak1,
2406  a_blockwise_copy,
2407  a_grid_buf,
2408  a_block_bufs,
2409  a_block_slice_copy_step,
2410  b_grid_desc_bpreshuffled, // B
2411  b_block_desc_bk0_n_bk1,
2412  b_blockwise_copy,
2413  b_grid_buf,
2414  b_block_bufs,
2415  b_block_slice_copy_step,
2416  c_thread_buf, // C
2417  a_scale_grid_desc_am_ak, // A scale
2418  a_scale_thread_copy,
2419  a_scale_grid_buf,
2420  b_scale_grid_desc_bn_ak, // B scale
2421  b_scale_thread_copy,
2422  b_scale_grid_buf,
2423  num_k_block_main_loop);
2424  }
2425 
2426  // shuffle C and write out
2427  {
2428  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2429  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2430  "wrong!");
2431  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2432  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2433  "wrong!");
2434 
2435  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2436 
2437  // TODO: hacky, fix it!
2438  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2439  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2440 
2441  // TODO: hacky, fix it!
2442  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2443  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2444  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2445 
2446  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2447  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2448  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2449  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2450  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2451  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2452  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2453  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2454  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2455  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2456 
2457  // apply activation and routed top-k weight scaling to the accumulators
2458 
2459  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2460  static_assert(M5 == 4);
2461  const index_t m1 = get_warp_local_1d_id() / NWave;
2462  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2463 
2464  vector_type<float, 4> topk_weights; // for gemm2 only
2465  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2466  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2467  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2468  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2469  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2470  const index_t m_pos = block_m_id * MPerBlock +
2471  m0 * M2 * M1 * M3 * M4 * M5 +
2472  m1 * M2 * M3 * M4 * M5 +
2473  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2474  if constexpr(MulRoutedWeight)
2475  {
2476  topk_weights =
2477  *c_style_pointer_cast<const vector_type<float, M5>*>(
2478  p_ds_grid[I2] + m_pos);
2479  }
2480  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2481  constexpr index_t c_offset =
2482  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2483  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2484  constexpr auto cidx = Number<c_offset>{};
2485 
2486  if constexpr(IsInputGemm) // gu fusion
2487  {
2488  if constexpr(ActivationOperation ==
2489  Activation::silu_and_mul)
2490  {
2491  float gate = c_thread_buf[cidx];
2492  float up = c_thread_buf_up[cidx];
2493  if constexpr(MulRoutedWeight)
2494  {
2495  gate = gate * topk_weights.AsType<float>()[m5];
2496  up = up * topk_weights.AsType<float>()[m5];
2497  }
2498  tensor_operation::element_wise::Silu{}(gate, gate);
2499  c_thread_buf_fp32(cidx) = gate * up;
2500  }
2501  else if(ActivationOperation == Activation::gelu_and_mul)
2502  {
2503  float gate = c_thread_buf[cidx];
2504  float up = c_thread_buf_up[cidx];
2505  if constexpr(MulRoutedWeight)
2506  {
2507  gate = gate * topk_weights.AsType<float>()[m5];
2508  up = up * topk_weights.AsType<float>()[m5];
2509  }
2510  tensor_operation::element_wise::Gelu{}(gate, gate);
2511  c_thread_buf_fp32(cidx) = gate * up;
2512  }
2513  }
2514  else
2515  {
2516  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2517  if constexpr(MulRoutedWeight)
2518  {
2519  c_thread_buf_fp32(cidx) =
2520  topk_weights.AsType<float>()[m5] *
2521  c_thread_buf_fp32[cidx];
2522  }
2523  }
2524  });
2525  });
2526  });
2527  });
2528  });
2529  });
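// at this point c_thread_buf_fp32 holds activation(gate) * up for the fused gate/up input GEMM,
// or the plain accumulators (optionally scaled by the routed top-k weights) for the output GEMM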
2530 
2531  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2532  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2533 
2534  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2535  static_cast<CShuffleDataType*>(p_shared_0),
2536  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2537 
2538  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2539  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2540  make_tuple(
2541  make_freeze_transform(I0),
2542  make_unmerge_transform(make_tuple(
2543  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2544  // shuffle
2545  M1, // M1 = MWave
2546  M2, // M2 * M3 * M4 = MPerXdl
2547  M3,
2548  M4,
2549  M5)),
2550  make_freeze_transform(I0),
2551  make_unmerge_transform(make_tuple(
2552  Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
2553  // per shuffle
2554  N1, // N1 = NWave
2555  N2, // N2 = NXdlPack
2556  N3))), // N3 = NPerXdl
2557  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2558  make_tuple(Sequence<>{},
2559  Sequence<0, 2, 4, 6, 7, 8>{},
2560  Sequence<>{},
2561  Sequence<1, 3, 5, 9>{}));
2562 
2563  // calculate origin of thread output tensor on global memory
2564  // blockwise GEMM c matrix starting index
2565  const auto c_thread_mtx_on_block =
2566  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2567 
2568  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2569  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2570 
2571  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2572  make_single_stage_tensor_adaptor(
2573  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2574  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2575  make_tuple(Sequence<0>{}));
2576 
2577  const auto m_thread_data_on_block_idx =
2578  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2579  make_multi_index(m_thread_data_on_block));
2580 
2581  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2582  make_single_stage_tensor_adaptor(
2583  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2584  make_tuple(Sequence<0, 1, 2, 3>{}),
2585  make_tuple(Sequence<0>{}));
2586 
2587  const auto n_thread_data_on_block_idx =
2588  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2589  make_multi_index(n_thread_data_on_block));
2590 
2591  // shuffle: threadwise copy C from VGPR to LDS
2592  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2593  AccDataType,
2594  CShuffleDataType,
2595  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2596  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2597  tensor_operation::element_wise::PassThrough,
2598  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2599  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2600  I1,
2601  I1,
2602  M2,
2603  N2,
2604  M3,
2605  I1,
2606  M5,
2607  I1>,
2608  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2609  9,
2610  1,
2611  InMemoryDataOperationEnum::Set,
2612  1,
2613  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2614  make_multi_index(0,
2615  0,
2616  m_thread_data_on_block_idx[I1],
2617  n_thread_data_on_block_idx[I1],
2618  m_thread_data_on_block_idx[I2],
2619  n_thread_data_on_block_idx[I2],
2620  m_thread_data_on_block_idx[I3],
2621  m_thread_data_on_block_idx[I4],
2622  m_thread_data_on_block_idx[I5],
2623  n_thread_data_on_block_idx[I3]),
2624  tensor_operation::element_wise::PassThrough{}};
2625 
2626  using EDataType = CDataType;
2627 
2628  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2629  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2630 
2631  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2632  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2633  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2634 
2635  const auto ds_grid_buf = generate_tuple(
2636  [&](auto i) {
2637  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2638  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2639  },
2640  Number<NumDTensor>{});
2641 
2642  // tuple of reference to C/Ds tensor descriptors
2643  const auto c_ds_desc_refs = concat_tuple_of_reference(
2644  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2645  generate_tie([&](auto i) -> const auto& // return type should be reference
2646  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2647  Number<NumDTensor>{}));
2648 
2649  // tuple of reference to C/Ds tensor descriptors
2650  const auto c_ds_buf_refs = concat_tuple_of_reference(
2651  tie(c_shuffle_block_buf),
2652  generate_tie([&](auto i) -> const auto& // return type should be reference
2653  { return ds_grid_buf[i]; },
2654  Number<NumDTensor>{}));
2655 
2656  // tuple of starting index of C/Ds blockwise copy
2657  const auto idx_c_ds_block_begin =
2658  container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
2659  generate_tuple(
2660  [&](auto) {
2661  return make_multi_index(block_m_id, 0, block_n_id, 0);
2662  // return make_multi_index(block_work_idx[I0], 0,
2663  // block_work_idx[I1], 0);
2664  },
2665  Number<NumDTensor>{}));
2666 
2667  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2668  c_grid_desc_mblock_mperblock_nblock_nperblock;
2669 
2670  using CDEBlockTransferCluster =
2671  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2672  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2673  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2674  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2675  ThisThreadBlock,
2676  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2677  Tuple<EDataType>,
2678  decltype(c_ds_desc_refs),
2679  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2680  CElementwiseOperation,
2681  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2682  // Sequence support
2683  // arbitrary type
2684  Sequence<1,
2685  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2686  1,
2687  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2688  CDEBlockTransferCluster,
2689  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2690  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2691  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2692  3, // index_t SrcVectorDim,
2693  3, // index_t DstVectorDim,
2694  CDEShuffleBlockTransferScalarPerVectors,
2695  sequence_merge_t<
2696  Sequence<true>,
2697  uniform_sequence_gen_t<
2698  NumDTensor,
2699  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2700  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2701  IndexType,
2702  1, // ScatterDim
2703  true, // OutputScatter: false, only use scatter weights
2704  scatter_weight_idx // ScatterWeightIdx: ascale
2705  >{c_ds_desc_refs,
2706  idx_c_ds_block_begin,
2707  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2708  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2709  c_element_op};
2710 
2711  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2712  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2713 
2714  constexpr auto sfc_c_vgpr =
2715  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2716  NXdlPerWave / NXdlPack,
2717  1,
2718  1,
2719  MXdlPack,
2720  NXdlPack,
2721  M2,
2722  1,
2723  M4,
2724  1>,
2725  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2726  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2727  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2728  1,
2729  1,
2730  MXdlPack,
2731  NXdlPack,
2732  M2,
2733  1,
2734  M4,
2735  1>>{};
2736 
2737  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2738 
2739  // space filling curve for shuffled blockwise C/D/E
2740  constexpr auto sfc_cde_block =
2741  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2742  Sequence<0, 2, 1, 3>,
2743  Sequence<1,
2744  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2745  1,
2746  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2747 
2748  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2749  constexpr auto EMThreads =
2750  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2751  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2752  constexpr auto ENThreads =
2753  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2754  static_for<0, num_access, 1>{}([&](auto access_id) {
2755  // make sure it's safe to write to LDS
2756  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2757 
2758  auto dstidx = sfc_cde_block.GetIndex(access_id);
2759  const index_t c_token_pos =
2760  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2761  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2762  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2763  IndexType token_offset = fused_token & 0xffffff;
2764  if constexpr(IsInputGemm)
2765  {
2766  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2767  }
2768  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2769  });
2770 
2771  block_sync_lds();
2772 
2773  // each thread write its data from VGPR to LDS
2774  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2775  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2776  c_thread_buf_fp32,
2777  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2778  c_shuffle_block_buf);
2779 
2780  // make sure it's safe to read from LDS
2781  block_sync_lds();
2782 
2783  // each block copy its data from LDS to global
2784  cde_block_copy_lds_and_global.Run(
2785  c_ds_desc_refs,
2786  c_ds_buf_refs,
2787  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2788  tie(c_grid_buf),
2789  scatter_offsets);
2790 
2791  if constexpr(access_id < num_access - 1)
2792  {
2793  constexpr auto cde_lds_and_global_step =
2794  sfc_cde_block.GetForwardStep(access_id);
2795 
2796  // move on Ds
2797  static_for<0, NumDTensor, 1>{}([&](auto i) {
2798  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2799  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2800  });
2801 
2802  // move on E
2803  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2804  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2805  I0,
2806  cde_lds_and_global_step);
2807  }
2808  });
2809  }
2810  }
2811 };
2812 
2813 } // namespace ck