gridwise_moe_mx_gemm_bpreshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
16 
20 
21 #define DEBUG_LOG 0
22 
23 namespace ck {
24 
25 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26 // pipelines in the same kernel function. Blockers:
27 // 1. Two separate declarations of the __shared__ pointer are needed to make sure data accesses
28 // operate on two distinct LDS chunks.
29 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
30 // not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
31 
32 enum Activation
33 {
34  gelu_and_mul = 0,
35  silu_and_mul = 1
36 };
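// Added commentary (not in the original source): both activation variants implement a gated
// fusion over the gate/up halves of the first GEMM output, conceptually
//
//   out = act(gate) * up;   // act is Silu for silu_and_mul, Gelu for gelu_and_mul
//
// optionally scaled by the routed top-k weight when MulRoutedWeight is set (see the C-shuffle
// epilogue further below).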
37 
38 template <typename GridwiseGemm,
39  bool HasMainKBlockLoop,
40  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41  index_t MinimumOccupancy = 1,
42  TailNumber TailNum = TailNumber::Even>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46 #endif
47  // __attribute__((amdgpu_waves_per_eu(1, 1)))
48  kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49 {
50 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
61  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62  karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
63  karg.p_ds_grid,
64  karg.p_c_grid,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70 #else
71  ignore = karg;
72 #endif // end of if (defined(__gfx9__))
73 }
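// Added commentary (not in the original source): a minimal host-side launch sketch, assuming a
// HIP stream and an already-populated GridwiseGemm::Argument. The grid z-dimension carries the
// split-K batch id consumed by SplitKBatchOffset inside the kernel; the real device op in CK also
// derives HasMainKBlockLoop / TailNum from the problem size.
//
//   const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); // gdz == 1
//   dim3 grid(gdx, gdy, karg.KBatch);
//   dim3 block(BlockSize);
//   kernel_moe_mxgemm<GridwiseGemm,
//                     /*HasMainKBlockLoop=*/true,
//                     InMemoryDataOperationEnum::Set>
//       <<<grid, block, 0, stream>>>(karg);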
74 
75 template <typename GridwiseGemm,
76  bool HasMainKBlockLoop,
77  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78  index_t MinimumOccupancy = 1,
79  TailNumber TailNum = TailNumber::Even>
80 __global__ void
81 #if CK_USE_LAUNCH_BOUNDS
82  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
83 #endif
84  // __attribute__((amdgpu_waves_per_eu(1, 1)))
85  kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
86 {
87 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
88  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
89  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90 
91  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
92 
93  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
94  karg.p_sorted_token_ids,
95  karg.p_sorted_expert_ids,
96  karg.p_max_token_id,
97  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
98  karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
99  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
100  karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
101  karg.p_ds_grid,
102  karg.p_c_grid,
103  p_shared_0,
104  p_shared_1,
105  karg,
106  karg.a_element_op,
107  karg.b_element_op,
108  karg.c_element_op);
109 #else
110  ignore = karg;
111 #endif // end of if (defined(__gfx9__))
112 }
113 
114 template <typename ALayout,
115  typename BLayout,
116  typename DsLayout,
117  typename CLayout,
118  typename ADataType,
119  typename AScaleDataType,
120  typename BDataType,
121  typename BScaleDataType,
122  typename AccDataType,
123  typename CShuffleDataType,
124  typename DsDataType,
125  typename CDataType,
126  typename AElementwiseOperation,
127  typename BElementwiseOperation,
128  typename CElementwiseOperation,
130  index_t ScaleBlockSize, // Scaling block size
131  index_t BlockSize, // Thread block size
132  index_t MPerBlock,
133  index_t NPerBlock,
134  index_t KPerBlock,
135  index_t AK1Value,
136  index_t BK1Value,
137  index_t MPerXdl,
138  index_t NPerXdl,
139  index_t MXdlPerWave,
140  index_t NXdlPerWave,
141  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
142  typename ABlockTransferThreadClusterArrangeOrder,
143  typename ABlockTransferSrcAccessOrder,
144  index_t ABlockTransferSrcVectorDim,
145  index_t ABlockTransferSrcScalarPerVector,
146  index_t ABlockTransferDstScalarPerVector_AK1,
147  bool AThreadTransferSrcResetCoordinateAfterRun,
148  index_t ABlockLdsExtraM,
149  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
150  typename BBlockTransferThreadClusterArrangeOrder,
151  typename BBlockTransferSrcAccessOrder,
152  index_t BBlockTransferSrcVectorDim,
153  index_t BBlockTransferSrcScalarPerVector,
154  index_t BBlockTransferDstScalarPerVector_BK1,
155  bool BThreadTransferSrcResetCoordinateAfterRun,
156  index_t BBlockLdsExtraN,
157  index_t CShuffleMXdlPerWavePerShuffle,
158  index_t CShuffleNXdlPerWavePerShuffle,
159  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
160  typename CDEShuffleBlockTransferScalarPerVectors,
163  index_t ActivationOperation = 0,
164  bool NSwizzle = false,
165  bool IsInputGemm = true,
166  bool MulRoutedWeight = true,
167  typename IndexType = index_t,
168  typename ComputeTypeA = ADataType,
169  typename ComputeTypeB = BDataType>
171 {
172  using LDSTypeA = ADataType;
173  using LDSTypeB = BDataType;
174 
175  static constexpr auto I0 = Number<0>{};
176  static constexpr auto I1 = Number<1>{};
177  static constexpr auto I2 = Number<2>{};
178  static constexpr auto I3 = Number<3>{};
179  static constexpr auto I4 = Number<4>{};
180  static constexpr auto I5 = Number<5>{};
181  static constexpr auto I6 = Number<6>{};
182  static constexpr auto I7 = Number<7>{};
183  static constexpr auto I8 = Number<8>{};
184  static constexpr auto I9 = Number<9>{};
185 
186  static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
187  CDEShuffleBlockTransferScalarPerVectors{}[I0];
188  // K1 should be Number<...>
189  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
190  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
191  static constexpr auto AK1Number = Number<AK1Value>{};
192  static constexpr auto BK1Number = Number<BK1Value>{};
193 
194  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
195  static constexpr bool is_single_rate_mfma = false;
196  static constexpr auto is_scale_mfma = true;
197 
198  static constexpr index_t NumDTensor = DsDataType::Size();
199 
200  static constexpr auto MXdlPack = 2;
201  static constexpr auto NXdlPack = 2;
202  static constexpr auto KXdlPack = 2;
203 
204  //> KPack is at least the k_per_blk of the selected mfma instruction.
205  //
206  // Should be a multiple of k_per_blk.
207  // TODO: move this to the blockwise pipeline base
208  // KPack is expressed in packed data types for pk A/B
209 
210  static constexpr index_t APackedSize = packed_size_v<ADataType>;
211  static constexpr index_t BPackedSize = packed_size_v<BDataType>;
212 
213  using mfma_selector = MfmaSelector<ComputeTypeA,
214  MPerXdl,
215  NPerXdl,
216  ComputeTypeB,
218  is_scale_mfma>;
219  static constexpr index_t KPack =
221 
222  static constexpr index_t NLane = NPerXdl;
223  static constexpr index_t KLane = 64 / NLane;
224  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
225  static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
226 
227  // static constexpr index_t NumTokens = 1;
228  static constexpr index_t SortedTileSize = MPerBlock;
229 
231  static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
232  static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
233  static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
234  "A scale pack data type too large!");
235  static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
236  "B scale pack data type too large!");
237 
238  static constexpr auto MakeDsGridPointer()
239  {
240  return generate_tuple(
241  [&](auto i) {
242  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
243 
244  return static_cast<const DDataType*>(nullptr);
245  },
247  }
248 
249  using DsGridPointer = decltype(MakeDsGridPointer());
250 
252 
253  __host__ static auto CalculateGridSize(index_t M, index_t N)
254  {
255  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
256  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
257  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
258  const index_t gridy = NSwizzle ? 1 : mblock;
259 
260  return std::make_tuple(gridx, gridy, 1);
261  }
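// Added commentary (not in the original source): a worked example of the mapping above,
// assuming MPerBlock = NPerBlock = 128.
//
//   M = 512, N = 1024  ->  mblock = 4, nblock = 8
//   NSwizzle == false  ->  grid = (nblock, mblock, 1)     = (8, 4, 1)
//   NSwizzle == true   ->  grid = (nblock * mblock, 1, 1) = (32, 1, 1)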
262 
263  __host__ static auto CalculateMPadded(index_t M)
264  {
265  return math::integer_least_multiple(M, MPerBlock);
266  }
267 
268  __host__ static auto CalculateNPadded(index_t N)
269  {
270  return math::integer_least_multiple(N, NPerBlock);
271  }
272 
273  __host__ static auto CalculateBN0Shuffled(index_t N)
274  {
275  return math::integer_divide_ceil(N, NLane);
276  }
277  __host__ static auto CalculateBK0Shuffled(index_t K)
278  {
280  }
281 
282  __host__ static auto CalculateKPadded(index_t K)
283  {
284  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
285  }
286 
287  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
288  {
289  auto K_t = K_Batch * KPerBlock;
290  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
291  }
292 
293  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
294  {
295  auto K_t = K_Batch * KPerBlock;
296  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
297  }
298 
299  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
300  {
301  auto K_t = K_Batch * KPerBlock;
302  return (K + K_t - 1) / K_t * KPerBlock;
303  }
304 
305  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
306  {
307  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
308  auto K_t = K_Batch * KReadVec;
309  return (K + K_t - 1) / K_t * KReadVec;
310  }
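// Added commentary (not in the original source): a worked example of the split-K helpers above,
// assuming KPerBlock = 128, AK1Value = BK1Value = 16, K = 1024 and K_Batch = 2.
//
//   CalculateKPadded(1024, 2)   = ceil(1024 / 256) * 128        = 512  // padded K per split
//   CalculateAK0Padded(1024, 2) = ceil(1024 / 256) * (128 / 16) = 32   // AK0 per split
//   CalculateKRead(1024, 2)     = ceil(1024 / 32) * 16          = 512  // K elements read per split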
311 
312  __host__ static auto CalculateMBlock(index_t M)
313  {
314  return math::integer_divide_ceil(M, MPerBlock);
315  }
316 
317  __host__ static auto CalculateNBlock(index_t N)
318  {
319  return math::integer_divide_ceil(N, NPerBlock);
320  }
321 
322  template <index_t MNXdlPerWave,
323  index_t MNWaves,
324  index_t MNXdlPack,
325  index_t MNPerXdl,
326  bool IsXor,
327  typename TileDesc_K0_MN_K1>
328  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
329  {
330  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
331  constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
332  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
333 
334  if constexpr(IsXor)
335  {
336  constexpr auto permuted_desc = transform_tensor_descriptor(
337  TileDesc_K0_MN_K1{},
342 
344  permuted_desc,
345  make_tuple(
348  Number<MNWaves>{},
350  Number<MNPerXdl>{}))),
353  }
354  else
355  {
357  TileDesc_K0_MN_K1{},
358  make_tuple(
361  Number<MNWaves>{},
363  Number<MNPerXdl>{}))),
366  }
367  }
368 
369  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
370  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
371  {
372  const auto a_grid_desc_mraw_kraw = [&]() {
373  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
375  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
376  }
377  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
380  }
381  }();
382 
384 
385  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386  GemmSpec == GemmSpecialization::MNKPadding)
387  {
388  // pad both M and K
389  const auto a_grid_desc_m_k =
390  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
392  make_right_pad_transform(K, KPad - K)),
395 
396  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397  a_grid_desc_m_k,
402 
403  return a_grid_desc_ak0_m_ak1;
404  }
405  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406  GemmSpec == GemmSpecialization::MNPadding)
407  {
408  // pad M, but not K
409  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
410  a_grid_desc_mraw_kraw,
412  make_right_pad_transform(M, MPad - M)),
415 
416  return a_grid_desc_ak0_m_ak1;
417  }
418  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419  GemmSpec == GemmSpecialization::NKPadding)
420  {
421  // pad K, but not M
422  const auto a_grid_desc_m_k = transform_tensor_descriptor(
423  a_grid_desc_mraw_kraw,
427 
428  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
429  a_grid_desc_m_k,
434 
435  return a_grid_desc_ak0_m_ak1;
436  }
437  else
438  {
439  // not pad M or K
440  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
441  a_grid_desc_mraw_kraw,
442  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
446 
447  const auto a_grid_desc_permuted = transform_tensor_descriptor(
448  a_grid_desc_ak0_m_ak1,
451  make_pass_through_transform(AK1Value)),
454 
455  const auto a_grid_desc = transform_tensor_descriptor(
456  a_grid_desc_permuted,
457  make_tuple(
460  make_pass_through_transform(AK1Value)),
463 
464  return a_grid_desc;
465  }
466  }
467 
468  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
469  {
470  constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
472  make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
473  }
474 
475  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
476  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
477  {
478  const auto b_grid_desc_nraw_kraw = [&]() {
480  {
481  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
482  }
484  {
485  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
486  }
487  }();
488 
490 
491  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
492  GemmSpec != GemmSpecialization::Default),
493  "pk_i4_t does not support padding");
494  static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
495  GemmSpec != GemmSpecialization::Default),
496  "f4x2_pk_t does not support padding");
497 
498  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
499  GemmSpec == GemmSpecialization::MNKPadding)
500  {
501  // pad both N and K
502  const auto b_grid_desc_n_k =
503  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
505  make_right_pad_transform(K, KPad - K)),
508 
509  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
510  b_grid_desc_n_k,
515 
516  return b_grid_desc_bk0_n_bk1;
517  }
518  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
519  GemmSpec == GemmSpecialization::MNPadding)
520  {
521  // pad N, but not K
522  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
523  b_grid_desc_nraw_kraw,
525  make_right_pad_transform(N, NPad - N)),
528 
529  return b_grid_desc_bk0_n_bk1;
530  }
531  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
532  GemmSpec == GemmSpecialization::MKPadding)
533  {
534  // pad K, but not N
535  const auto b_grid_desc_n_k = transform_tensor_descriptor(
536  b_grid_desc_nraw_kraw,
540 
541  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
542  b_grid_desc_n_k,
547 
548  return b_grid_desc_bk0_n_bk1;
549  }
550  else
551  {
552  // not pad N or K
553  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
554  b_grid_desc_nraw_kraw,
555  make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
559 
560  const auto b_grid_desc_permuted = transform_tensor_descriptor(
561  b_grid_desc_bk0_n_bk1,
564  make_pass_through_transform(BK1Value)),
567 
568  const auto b_grid_desc = transform_tensor_descriptor(
569  b_grid_desc_permuted,
570  make_tuple(
573  make_pass_through_transform(BK1Value)),
576 
577  return b_grid_desc;
578  }
579  }
580 
581  template <typename ABlockDesc_AK0_M_AK1>
582  __host__ __device__ static constexpr auto
583  MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
584  {
585  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
586 
587  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
588  ABlockDesc_AK0_M_AK1{});
589  }
590 
591  template <typename BBlockDesc_BK0_N_BK1>
592  __host__ __device__ static constexpr auto
593  MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
594  {
595  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
596 
597  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
598  BBlockDesc_BK0_N_BK1{});
599  }
600 
601  template <typename ELayout>
602  __host__ __device__ static auto MakeCGridDescriptor_M_N(
603  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
604  {
605  const auto c_grid_desc_mraw_nraw = [&]() {
607  {
608  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
609  }
611  {
612  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
613  }
614  }();
615 
616  // pad M and N
617  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
619  make_right_pad_transform(N, NPad - N)),
622  }
623 
624  template <typename DLayout>
625  __host__ __device__ static auto
627  {
628  const auto c_grid_desc_mraw_nraw = [&]() {
630  {
631  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
632  }
634  {
635  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
636  }
637  }();
638 
639  // pad M and N
640  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
642  make_right_pad_transform(N, NPad - N)),
645  }
646 
647  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
648  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
649  {
650  return generate_tuple(
651  [&](auto i) {
652  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
653  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
654  },
656  }
657 
658  template <typename DsGridDesc>
660  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
661  {
662  return generate_tuple(
663  [&](auto i) {
665  ds_grid_desc_m_n[i], MBlock, NBlock);
666  },
668  }
669 
670  struct Problem
671  {
672  __host__ Problem(index_t NumTokens_,
673  index_t TopK_,
674  index_t M_,
675  index_t N_,
676  index_t K_,
677  index_t StrideA_,
678  index_t StrideScaleA_,
679  index_t StrideB_,
680  index_t StrideScaleB_,
681  std::array<index_t, NumDTensor> StrideDs_,
682  index_t StrideC_,
683  index_t KBatch_)
684  : NumTokens{NumTokens_},
685  TopK{TopK_},
686  M{M_},
687  N{N_},
688  K{K_},
689  StrideA{StrideA_},
690  StrideScaleA{StrideScaleA_},
691  StrideB{StrideB_},
692  StrideScaleB{StrideScaleB_},
693  StrideDs{StrideDs_},
694  StrideC{StrideC_},
695  KBatch{KBatch_},
698  KRead{CalculateKRead(K_, KBatch_)},
699  KPadded{CalculateKPadded(K_, KBatch_)},
700  AK0{CalculateAK0Padded(K_, KBatch_)},
701  BK0{CalculateBK0Padded(K_, KBatch_)},
702  MBlock{CalculateMBlock(M_)},
703  NBlock{CalculateNBlock(N_)},
706  {
707  }
708 
709  __host__ void Print() const
710  {
711  std::cout << "problem {"
712  << "NumTokens:" << NumTokens << ", "
713  << "TopK:" << TopK << ", "
714  << "M:" << M << ", "
715  << "N:" << N << ", "
716  << "K:" << K << ", "
717  << "SA:" << StrideA << ", "
718  << "SScaleA:" << StrideScaleA << ", "
719  << "SB:" << StrideB << ", "
720  << "SScaleB:" << StrideScaleB << ", "
721  << "SC:" << StrideC << ", "
722  << "MP:" << MPadded << ", "
723  << "NP:" << NPadded << ", "
724  << "KRead:" << KRead << ", "
725  << "KP:" << KPadded << ", "
726  << "AK0:" << AK0 << ", "
727  << "BK0:" << BK0 << ", "
728  << "MBlock: " << MBlock << ", "
729  << "NBlock: " << NBlock << "}" << std::endl;
730  }
731 
741  std::array<index_t, NumDTensor> StrideDs;
752  // FOR PRESHUFFLE ONLY
755  };
756 
757  // Argument
759  {
760  __host__ Argument(const index_t* p_sorted_token_ids_,
761  const index_t* p_sorted_expert_ids_,
762  const index_t* p_max_token_id_,
763  const ADataType* p_a_grid_,
764  const AScaleDataType* p_a_scale_grid_,
765  const BDataType* p_b_grid_,
766  const BScaleDataType* p_b_scale_grid_,
767  std::array<const void*, NumDTensor> p_ds_grid_,
768  CDataType* p_c_grid_,
769  index_t NumTokens_,
770  index_t TopK_,
771  index_t M_,
772  index_t N_,
773  index_t K_,
774  index_t StrideA_,
775  index_t StrideScaleA_,
776  index_t StrideB_,
777  index_t StrideScaleB_,
778  std::array<index_t, NumDTensor> StrideDs_,
779  index_t StrideC_,
780  index_t k_batch_,
781  AElementwiseOperation a_element_op_,
782  BElementwiseOperation b_element_op_,
783  CElementwiseOperation c_element_op_)
784  : Problem{NumTokens_,
785  TopK_,
786  M_,
787  N_,
788  K_ / APackedSize,
789  StrideA_ / APackedSize,
790  StrideScaleA_,
791  StrideB_ / BPackedSize,
792  StrideScaleB_,
793  StrideDs_,
794  StrideC_,
795  k_batch_},
796  p_sorted_token_ids{p_sorted_token_ids_},
797  p_sorted_expert_ids{p_sorted_expert_ids_},
798  p_max_token_id{p_max_token_id_},
799  p_a_grid{p_a_grid_},
800  p_a_scale_grid{p_a_scale_grid_},
801  p_b_grid{p_b_grid_},
802  p_b_scale_grid{p_b_scale_grid_},
803  p_ds_grid{},
804  p_c_grid{p_c_grid_},
805  a_element_op{a_element_op_},
806  b_element_op{b_element_op_},
807  c_element_op{c_element_op_}
808  {
809 
810  // populate pointer, desc for Ds
811  static_for<0, NumDTensor, 1>{}([&](auto i) {
812  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
813 
814  // D pointer
815  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
816  });
817  }
818 
822  const ADataType* p_a_grid;
823  const AScaleDataType* p_a_scale_grid;
824  const BDataType* p_b_grid;
825  const BScaleDataType* p_b_scale_grid;
827  CDataType* p_c_grid;
828 
829  const AElementwiseOperation a_element_op;
830  const BElementwiseOperation b_element_op;
831  const CElementwiseOperation c_element_op;
832  };
833 
835  {
836  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
837  {
838  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
839  {
840  a_k_split_offset = k_id * karg.KRead;
841  }
842  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
843  {
844  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
845  }
846 
847  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
848  {
849  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
850  }
851  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
852  {
853  // KPack * NLane * KLane * K0 * N0
854  b_k_split_offset = k_id * karg.KRead * NPerXdl;
855  }
856 
857  // Calculate A scale offset
858  a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
859  MPerXdl / scale_pack_size_a;
860 
861  // Calculate B scale offset
862  b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
863  NPerXdl / scale_pack_size_b;
864 
865  if(k_id < karg.KBatch - 1)
866  {
867  karg.K = karg.KRead;
868  }
869  else
870  {
871  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
872  }
873  }
874 
879  };
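// Added commentary (not in the original source): with the example numbers used above
// (KRead = 512, KBatch = 2), a row-major A and the preshuffled (column-major) B path, the block
// with k_id = 1 would start at
//
//   a_k_split_offset = 1 * 512            // elements along K of A
//   b_k_split_offset = 1 * 512 * NPerXdl  // preshuffled B layout
//
// and, being the last split, its karg.K is trimmed to K - KRead * (KBatch - 1).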
880 
881  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
882  {
883  // A matrix in LDS memory, dst of blockwise copy
884  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
885  {
886  // contiguous in LDS
890  }
891  // The xor tensor transformation requires extra, unnecessary VGPR usage and can cause register
892  // spills in some cases.
894  {
895  constexpr auto a_lds_block_desc =
898 
899  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
900  a_lds_block_desc,
906 
907  return a_lds_block_desc_permuted;
908  }
909  else // ColumnMajor A
910  {
911  // The kfold and mpair dimensions are not always required.
912  // More dimensions in the merge_transform increase the difficulty of generating immediate-argument
913  // (immarg) offsets for the compiler.
914  constexpr auto WaveSize = 64;
915  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
916  constexpr auto M1 = MPerBlock / M0;
917 
918  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
919  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
920  constexpr auto KThreadRead = WaveSize / MPerXdl;
921  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
922 
923  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
924  ? 1
925  : 128 / (AK1Number * M0 * sizeof(ADataType));
926  constexpr auto KThreadReadPerm =
927  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
928  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
929  : KThreadRead;
930 
931  // 1 <= mpair <= M0
932  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
933  ? 1
934  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
935  ? M0
936  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
937 
938  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
942  Number<kfold * M0 / mpair>{},
943  Number<mpair>{},
944  AK1Number));
945 
946  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
947  a_lds_block_desc,
948  make_tuple(
952  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
955  make_tuple(
957  make_tuple(
959 
960  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
961  a_lds_block_desc_permuted,
962  make_tuple(
970  Sequence<1>{},
971  Sequence<2>{},
972  Sequence<3>{},
973  Sequence<4>{},
974  Sequence<5>{}),
976  Sequence<2>{},
977  Sequence<0, 3>{},
978  Sequence<4, 5>{},
979  Sequence<6>{},
980  Sequence<7>{}));
981 
982  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
983  a_lds_block_desc_unmerged,
986  Number<KThreadWrite / kfold / KThreadReadPerm>{},
987  Number<kfold>{},
994 
995  return a_lds_block_desc_ak0_m_ak1;
996  }
997  }
998 
999  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
1000  {
1001  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1003  I1,
1004  Number<NXdlPack>{},
1005  Number<KRepeat>{},
1006  Number<BK1Value>{}));
1007  }
1008 
1010  {
1011  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1012 
1013  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1015  make_tuple(I1,
1017  I1,
1019 
1020  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1021  }
1022 
1025  BlkGemmPipelineVer,
1026  BlkGemmPipeSched,
1027  BlockSize,
1028  ScaleBlockSize,
1029  ADataType,
1030  AScaleDataType,
1031  BDataType,
1032  BScaleDataType,
1033  ComputeTypeA,
1034  AccDataType,
1041  ABlockTransferSrcScalarPerVector,
1042  BBlockTransferSrcScalarPerVector,
1043  MPerBlock,
1044  NPerBlock,
1045  KPerBlock,
1046  MPerXdl,
1047  NPerXdl,
1048  MXdlPerWave,
1049  NXdlPerWave,
1050  KPack,
1051  IsInputGemm>())>;
1052 
1053  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1054  {
1055  // LDS allocation for A and B: be careful of alignment
1056  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1057  // lds max alignment
1058  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1059 
1060  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1061  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1062 
1063  // LDS allocation for C shuffle in LDS
1064  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1066 
1067  constexpr auto c_block_size =
1068  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1069 
1070  return math::max(a_block_space_size_aligned * sizeof(ADataType),
1071  c_block_size * sizeof(CShuffleDataType));
1072  }
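// Added commentary (not in the original source): because B is preshuffled and staged in VGPRs
// (see b_block_buf below), only the A tile lives in LDS during the main loop, and the same LDS is
// then reused for the C-shuffle epilogue; hence the max() of the two sizes rather than their sum:
//
//   lds_bytes = max(a_block_space_size_aligned * sizeof(ADataType),
//                   c_block_size * sizeof(CShuffleDataType));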
1073 
1074  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1075  __host__ static constexpr bool CheckValidity(const Argument& karg)
1076  {
1077  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1078  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1079  "Invalid tuning param!");
1080 
1081  static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1082  "KPerBlock should be multiple of ScaleBlockSize");
1083 
1084  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1089  {
1090  if(!(karg.M % MPerBlock == 0))
1091  {
1092  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1093  {
1094  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1095  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1096  << std::endl;
1097  }
1098  return false;
1099  }
1100  }
1101 
1102  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1107  {
1108  if(!(karg.N % NPerBlock == 0))
1109  {
1110  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1111  {
1112  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1113  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1114  << std::endl;
1115  }
1116  return false;
1117  }
1118  }
1119 
1120  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1124  {
1125  auto K_t = karg.KBatch * KPerBlock;
1126  if(!(karg.K % K_t == 0))
1127  {
1128  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1129  {
1130  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1131  << karg.K << " " << __FILE__ << ":" << __LINE__
1132  << ", in function: " << __func__ << std::endl;
1133  }
1134  return false;
1135  }
1136  }
1137  else
1138  {
1139  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1140  auto K_t = karg.KBatch * KReadVec;
1141  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1142  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1143  {
1144  return false;
1145  }
1146  }
1147 
1149  {
1150  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1151  {
1152  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1153  {
1154  std::cout << "Arg K (" << karg.K
1155  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1156  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1157  << __LINE__ << ", in function: " << __func__ << std::endl;
1158  }
1159  return false;
1160  }
1161  }
1162  else
1163  {
1164  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1165  {
1166  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1167  {
1168  std::cout << "Arg M (" << karg.M
1169  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1170  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1171  << __LINE__ << ", in function: " << __func__ << std::endl;
1172  }
1173  return false;
1174  }
1175  }
1176 
1178  {
1179  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1180  {
1181  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1182  {
1183  std::cout << "Arg N (" << karg.N
1184  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1185  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1186  << __LINE__ << ", in function: " << __func__ << std::endl;
1187  }
1188  return false;
1189  }
1190  }
1191  else
1192  {
1193  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1194  {
1195  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1196  {
1197  std::cout << "Arg K (" << karg.K
1198  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1199  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1200  << __LINE__ << ", in function: " << __func__ << std::endl;
1201  }
1202  return false;
1203  }
1204  }
1205 
1207  {
1209  {
1210  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1211  {
1212  std::cout << "Arg N (" << karg.N
1213  << ") value is not a multiple of "
1214  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1216  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1217  << std::endl;
1218  }
1219  return false;
1220  }
1221  }
1222  else
1223  {
1225  {
1226  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1227  {
1228  std::cout << "Arg M (" << karg.M
1229  << ") value is not a multiple of "
1230  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1232  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1233  << std::endl;
1234 
1235  return false;
1236  }
1237  }
1238  }
1239 
1240  // check gridwise gemm pipeline
1241 #if 0
1242  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1243 
1244  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1245  {
1246  return false;
1247  }
1248 #endif
1249  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1250  return true;
1251  }
1252 
1253  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1254  {
1255  const index_t num_loop = K / KPerBlock;
1256 
1257  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1258  }
1259 
1260  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1261  {
1262  const index_t num_loop = K / KPerBlock;
1263 
1264  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1265  }
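// Added commentary (not in the original source): a hedged sketch of how a device op would
// typically combine the host-side helpers above before picking a kernel instantiation.
//
//   if(!GridwiseGemm::CheckValidity(karg))
//       return false; // fall back to another instance
//   const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(karg.K);
//   const auto tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(karg.K);
//   // has_main_loop / tail_num then select the HasMainKBlockLoop / TailNum template
//   // arguments of kernel_moe_mxgemm(_2lds).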
1266 
1267  template <typename CGridDesc>
1268  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1269  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1270  {
1271  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1272  c_grid_desc_m_n,
1277 
1278  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1279  }
1280 
1281  // return block_id to C matrix tile idx (m0, n0) mapping
1282  // if arch = gfx942
1283  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1284  // NPerBlock>;
1285 
1286 #if 0
1287  template <bool HasMainKBlockLoop,
1288  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1289  TailNumber TailNum = TailNumber::Odd>
1290  __device__ static void Run(const index_t* p_sorted_token_ids,
1291  const index_t* p_sorted_expert_ids,
1292  const index_t* p_max_token_id,
1293  const ADataType* p_a_grid,
1294  const AScaleDataType* p_a_scale_grid,
1295  const BDataType* p_b_grid,
1296  const BScaleDataType* p_b_scale_grid,
1297  DsGridPointer& p_ds_grid,
1298  CDataType* p_c_grid,
1299  void* p_shared,
1300  const Problem& problem,
1301  AElementwiseOperation a_element_op,
1302  BElementwiseOperation b_element_op,
1303  CElementwiseOperation c_element_op)
1304  {
1305  ignore = b_element_op;
1306  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1307  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1308  problem.MPadded,
1309  problem.K,
1310  problem.KPadded,
1311  problem.StrideA,
1312  problem.AK0);
1313  const auto b_grid_desc_bpreshuffled =
1314  MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
1315  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1316  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1317  problem.MPadded,
1318  problem.N,
1319  problem.NPadded,
1320  problem.StrideC);
1321 
1322  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1323  make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
1324  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1325  (KXdlPack * 64 / MPerXdl),
1326  64 * KXdlPack * MXdlPack / scale_pack_size_a));
1327 
1328  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1329  make_tuple(problem.N / (NXdlPack * NPerXdl),
1330  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1331  (KXdlPack * 64 / NPerXdl),
1332  64 * KXdlPack * NXdlPack / scale_pack_size_b));
1333 
1334  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1336  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1337  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1338  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1339  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1340  if(expert_block_id * MPerBlock >= max_token_id)
1341  return;
1342  const index_t expert_id =
1343  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1344 
1345  const auto block_mn = [&]() -> std::pair<int, int> {
1346  if constexpr(NSwizzle)
1347  {
1348  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1349  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1350  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1351  const index_t expert_swizzle =
1352  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1353  const index_t bid_new = blockIdx.x - prefix_block;
1354  const index_t nid = __builtin_amdgcn_readfirstlane(
1355  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1356  const index_t mid =
1357  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1358  return {nid, mid};
1359  }
1360  else
1361  {
1362  return {blockIdx.x, blockIdx.y};
1363  }
1364  }();
1365 
1366  const index_t block_n_id = block_mn.first;
1367  const index_t block_m_id = block_mn.second;
1368  const index_t token0 =
1369  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1370 
1371  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1372  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1373  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1374  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1375  constexpr auto AKThreads = AK0Threads * AK1Threads;
1376  constexpr auto AMRepeats = MPerBlock / AMThreads;
1377  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1378 
1379  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1380  return;
1381  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
1382  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1383  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1384  index_t token_offset = fused_token & 0xffffff;
1385  if constexpr(!IsInputGemm)
1386  {
1387  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1388  }
1389  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
1390  });
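 // Added commentary (not in the original source): p_sorted_token_ids packs the top-k slot into
 // the upper 8 bits and the token index into the lower 24 bits, i.e.
 //
 //   fused_token = (topk_slot << 24) | token_id;
 //   token_id    = fused_token & 0xffffff;
 //   topk_slot   = fused_token >> 24;
 //
 // When !IsInputGemm (the second GEMM), the gather offset also folds in the top-k slot, as above.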
1391  const index_t expert_stride =
1392  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1393  const index_t expert_scale_stride =
1394  __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
1395  math::integer_divide_ceil(problem.K, ScaleBlockSize));
1396 
1397  // N0, K0, Blocksize*KPack
1398  const index_t n_block_data_idx_on_grid =
1399  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1400 
1401  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1402  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1403  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1404  p_b_grid + expert_id * expert_stride / BPackedSize,
1405  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1406 
1407  // A, B scale buffer
1408  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1409  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1410  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1411  p_b_scale_grid + expert_id * expert_scale_stride,
1412  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1413 
1414  // A matrix in LDS memory, dst of blockwise copy
1415  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1416 
1417  // B matrix in LDS memory, dst of blockwise copy
1418  // dummy
1419  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1420  // A matrix blockwise copy
1421  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1423  AElementwiseOperation,
1426  Sequence<AK0Number, MPerBlock, AK1Number>,
1427  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1428  ABlockTransferThreadClusterArrangeOrder,
1429  ADataType,
1430  LDSTypeA,
1431  decltype(a_grid_desc_ak0_m_ak1),
1432  decltype(a_block_desc_ak0_m_ak1),
1433  ABlockTransferSrcAccessOrder,
1434  Sequence<0, 1, 2>,
1435  ABlockTransferSrcVectorDim,
1436  2,
1437  ABlockTransferSrcScalarPerVector,
1438  ABlockTransferDstScalarPerVector_AK1,
1439  1,
1440  1,
1441  AThreadTransferSrcResetCoordinateAfterRun,
1442  true,
1443  IndexType,
1444  1,
1445  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1446  make_multi_index(0, 0, 0),
1447  a_element_op,
1448  a_block_desc_ak0_m_ak1,
1449  make_multi_index(0, 0, 0),
1451  gather_offsets);
1452 
1453  // Thread-wise copy
1454  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1455  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1456  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1457 
1458  auto b_blockwise_copy =
1459  ThreadwiseTensorSliceTransfer_v2<BDataType,
1460  BDataType,
1461  decltype(b_grid_desc_bpreshuffled),
1462  decltype(b_block_desc_bk0_n_bk1),
1463  Sequence<Number<NXdlPerWave / NXdlPack>{},
1464  I1,
1465  Number<NXdlPack>{},
1466  Number<KRepeat>{},
1467  Number<BK1Value>{}>,
1468  Sequence<1, 2, 0, 3>,
1469  4,
1470  BBlockTransferSrcScalarPerVector,
1471  BThreadTransferSrcResetCoordinateAfterRun,
1472  true>(
1473  b_grid_desc_bpreshuffled,
1474  make_multi_index(n_block_data_idx_on_grid,
1476  0,
1477  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1478 
1479  // LDS allocation for A and B: be careful of alignment
1480  // Cast after lds
1481  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1482  static_cast<LDSTypeA*>(p_shared),
1483  a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
1484 
1485  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1486  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1487 
1488  // Blockwise GEMM pipeline
1489  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1490  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1491  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1492  decltype(c_thread_buf) c_thread_buf_up;
1493 
1494  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
1495  float,
1496  c_thread_buf.num_of_v_,
1497  c_thread_buf.s_per_v,
1498  true>
1499  c_thread_buf_fp32;
1500 
1501  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1502  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1503  KPerBlock);
1504 
1505  // a and b scale processing
1506  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1507  const auto waveId_m = wave_idx[I0];
1508  const auto waveId_n = wave_idx[I1];
1509 
1510  static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1511 
1512  auto thread_offset_shuffled =
1513  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1514 
1515  auto a_thread_offset_m = waveId_m;
1516 
1517  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1518  AScaleDataType,
1519  AScaleDataType,
1520  decltype(a_scale_grid_desc_am_ak),
1521  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1522  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1523  Sequence<0, 1, 2>, // DimAccessOrder
1524  2, // SrcVectorDim
1525  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1526  1, // SrcScalarStrideInVector
1527  true>(a_scale_grid_desc_am_ak,
1528  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1529  0,
1530  thread_offset_shuffled / scale_pack_size_a));
1531 
1532  // B scale load
1533  auto b_thread_offset_n = waveId_n;
1534 
1535  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1536  BScaleDataType,
1537  BScaleDataType,
1538  decltype(b_scale_grid_desc_bn_ak),
1539  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1540  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1541  Sequence<0, 1, 2>, // DimAccessOrder
1542  2, // SrcVectorDim
1543  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1544  1, // SrcScalarStrideInVector
1545  true>(b_scale_grid_desc_bn_ak,
1546  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1547  0,
1548  thread_offset_shuffled / scale_pack_size_b));
1549 
1550  if constexpr(IsInputGemm)
1551  {
1552  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1553  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1554  p_b_grid_up + expert_id * expert_stride / BPackedSize,
1555  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1556  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1557  BDataType,
1558  BDataType,
1559  decltype(b_grid_desc_bpreshuffled),
1560  decltype(b_block_desc_bk0_n_bk1),
1561  Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
1562  Sequence<1, 2, 0, 3>,
1563  3,
1564  BBlockTransferSrcScalarPerVector,
1565  BThreadTransferSrcResetCoordinateAfterRun,
1566  true>(b_grid_desc_bpreshuffled,
1567  make_multi_index(n_block_data_idx_on_grid,
1569  0,
1570  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1571  const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
1572  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1573  p_b_scale_grid_up + expert_id * expert_scale_stride,
1574  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1575  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1576  BScaleDataType,
1577  BScaleDataType,
1578  decltype(b_scale_grid_desc_bn_ak),
1579  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1580  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1581  Sequence<0, 1, 2>, // DimAccessOrder
1582  2, // SrcVectorDim
1583  KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1584  1, // SrcScalarStrideInVector
1585  true>(
1586  b_scale_grid_desc_bn_ak,
1587  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1588  0,
1589  thread_offset_shuffled / scale_pack_size_b));
1590 
1591  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1592  a_grid_desc_ak0_m_ak1,
1593  a_block_desc_ak0_m_ak1,
1594  a_blockwise_copy,
1595  a_grid_buf,
1596  a_block_buf,
1597  a_block_slice_copy_step,
1598  b_grid_desc_bpreshuffled,
1599  b_block_desc_bk0_n_bk1,
1600  b_blockwise_copy,
1601  b_blockwise_copy_up,
1602  b_grid_buf,
1603  b_grid_buf_up,
1604  b_block_buf,
1605  b_block_slice_copy_step,
1606  c_thread_buf,
1607  c_thread_buf_up,
1608  a_scale_grid_desc_am_ak,
1609  a_scale_thread_copy,
1610  a_scale_grid_buf,
1611  b_scale_grid_desc_bn_ak,
1612  b_scale_thread_copy,
1613  b_scale_thread_copy_up,
1614  b_scale_grid_buf,
1615  b_scale_grid_buf_up,
1616  num_k_block_main_loop);
1617  }
1618  else
1619  {
1620  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1621  a_grid_desc_ak0_m_ak1,
1622  a_block_desc_ak0_m_ak1,
1623  a_blockwise_copy,
1624  a_grid_buf,
1625  a_block_buf,
1626  a_block_slice_copy_step,
1627  b_grid_desc_bpreshuffled,
1628  b_block_desc_bk0_n_bk1,
1629  b_blockwise_copy,
1630  b_grid_buf,
1631  b_block_buf,
1632  b_block_slice_copy_step,
1633  c_thread_buf,
1634  a_scale_grid_desc_am_ak,
1635  a_scale_thread_copy,
1636  a_scale_grid_buf,
1637  b_scale_grid_desc_bn_ak,
1638  b_scale_thread_copy,
1639  b_scale_grid_buf,
1640  num_k_block_main_loop);
1641  }
1642 
1643  // shuffle C and write out
1644  {
1645  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1646  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1647  "wrong!");
1648 
1649  // TODO: hacky, fix it!
1650  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1651  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1652 
1653  // TODO: hacky, fix it!
1654  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1655  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1656  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1657 
1658  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1659  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1660  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1661  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1662  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1663  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1664  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1665  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1666 
1667  // mul scales
1668  static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1669  static_assert(M4 == 4);
1670  const index_t m1 = get_warp_local_1d_id() / NWave;
1671  const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1672 
1673  vector_type<float, 4> topk_weights; // for gemm2 only
1674  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1675  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1676  static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1677  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1678  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1679  if constexpr(MulRoutedWeight)
1680  {
1681  topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
1682  p_ds_grid[I2] + m_pos);
1683  }
1684  static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1685  constexpr index_t c_offset =
1686  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1687  make_tuple(m0, n0, m2 * M4 + m4));
1688  constexpr auto cidx = Number<c_offset>{};
1689 
1690  if constexpr(IsInputGemm) // gate-up (gu) fusion
1691  {
1692  if constexpr(ActivationOperation == Activation::silu_and_mul)
1693  {
1694  float gate = c_thread_buf[cidx];
1695  float up = c_thread_buf_up[cidx];
1696  if constexpr(MulRoutedWeight)
1697  {
1698  gate = gate * topk_weights.AsType<float>()[m4];
1699  up = up * topk_weights.AsType<float>()[m4];
1700  }
1701  tensor_operation::element_wise::Silu{}(gate, gate);
1702  c_thread_buf_fp32(cidx) = gate * up;
1703  }
1704  else if(ActivationOperation == Activation::gelu_and_mul)
1705  {
1706  float gate = c_thread_buf[cidx];
1707  float up = c_thread_buf_up[cidx];
1708  if constexpr(MulRoutedWeight)
1709  {
1710  gate = gate * topk_weights.AsType<float>()[m4];
1711  up = up * topk_weights.AsType<float>()[m4];
1712  }
1713  tensor_operation::element_wise::Gelu{}(gate, gate);
1714  c_thread_buf_fp32(cidx) = gate * up;
1715  }
1716  }
1717  else
1718  {
1719  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1720  if constexpr(MulRoutedWeight)
1721  {
1722  c_thread_buf_fp32(cidx) =
1723  topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
1724  }
1725  }
1726  });
1727  });
1728  });
1729  });
1730 
1731  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1733 
1734  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1735  static_cast<CShuffleDataType*>(p_shared),
1736  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1737 
1738  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1739  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1740  make_tuple(
1743  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1744  M1, // M1 = MWave
1745  M2, // M2 * M3 * M4 = MPerXdl
1746  M3,
1747  M4)),
1750  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1751  N1, // N1 = NWave
1752  N2))), // N2 = NPerXdl
1753  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
1754  make_tuple(
1755  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
1756 
1757  // calculate origin of thread output tensor on global memory
1758  // blockwise GEMM c matrix starting index
1759  const auto c_thread_mtx_on_block =
1760  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1761 
1762  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1763  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1764 
1765  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1767  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1768  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
1769  make_tuple(Sequence<0>{}));
1770 
1771  const auto m_thread_data_on_block_idx =
1772  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1773  make_multi_index(m_thread_data_on_block));
1774 
1775  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1778  make_tuple(Sequence<0, 1, 2>{}),
1779  make_tuple(Sequence<0>{}));
1780 
1781  const auto n_thread_data_on_block_idx =
1782  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1783  make_multi_index(n_thread_data_on_block));
1784 
1785  // shuffle: threadwise copy C from VGPR to LDS
1786  auto c_thread_copy_vgpr_to_lds =
1787  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
1788  CShuffleDataType,
1789  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1790  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1792  Sequence<CShuffleMXdlPerWavePerShuffle,
1793  CShuffleNXdlPerWavePerShuffle,
1794  I1,
1795  I1,
1796  M2,
1797  I1,
1798  M4,
1799  I1>,
1800  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1801  7,
1802  1,
1804  1,
1805  true>{
1806  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1807  make_multi_index(0,
1808  0,
1809  m_thread_data_on_block_idx[I1],
1810  n_thread_data_on_block_idx[I1],
1811  m_thread_data_on_block_idx[I2],
1812  m_thread_data_on_block_idx[I3],
1813  m_thread_data_on_block_idx[I4],
1814  n_thread_data_on_block_idx[I2]),
1816 
1817  using EDataType = CDataType;
1818 
1819  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1820  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1821 
1822  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1824  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1825 
1826  const auto ds_grid_buf = generate_tuple(
1827  [&](auto i) {
1828  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1829  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1830  },
1831  Number<NumDTensor>{});
1832 
1833  // tuple of references to the C/Ds tensor descriptors
1834  const auto c_ds_desc_refs = concat_tuple_of_reference(
1835  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1836  generate_tie([&](auto i) -> const auto& // return type should be reference
1837  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1838  Number<NumDTensor>{}));
1839 
1840  // tuple of references to the C/Ds buffers
1841  const auto c_ds_buf_refs = concat_tuple_of_reference(
1842  tie(c_shuffle_block_buf),
1843  generate_tie([&](auto i) -> const auto& // return type should be reference
1844  { return ds_grid_buf[i]; },
1845  Number<NumDTensor>{}));
1846 
1847  // tuple of starting index of C/Ds blockwise copy
1848  const auto idx_c_ds_block_begin =
1851  [&](auto) {
1852  return make_multi_index(block_m_id, 0, block_n_id, 0);
1853  // return make_multi_index(block_work_idx[I0], 0,
1854  // block_work_idx[I1], 0);
1855  },
1856  Number<NumDTensor>{}));
1857 
1858  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1859  c_grid_desc_mblock_mperblock_nblock_nperblock;
1860 
1861  using CDEBlockTransferCluster =
1862  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1863  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1864  constexpr index_t scatter_weight_idx = 1; // hack fix felix
1865  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1867  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1868  Tuple<EDataType>,
1869  decltype(c_ds_desc_refs),
1870  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1871  CElementwiseOperation,
1872  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1873  // support arbitrary type
1874  Sequence<1,
1875  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1876  1,
1877  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1878  CDEBlockTransferCluster,
1879  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1880  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1881  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1882  3, // index_t SrcVectorDim,
1883  3, // index_t DstVectorDim,
1884  CDEShuffleBlockTransferScalarPerVectors,
1887  Sequence<true>,
1889  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1890  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1891  IndexType,
1892  1, // ScatterDim
1893  true, // OutputScatter: false, only use scatter weights
1894  scatter_weight_idx // ScatterWeightIdx: ascale
1895  >{c_ds_desc_refs,
1896  idx_c_ds_block_begin,
1897  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1898  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1899  c_element_op};
1900 
1901  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1902  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1903  constexpr auto sfc_c_vgpr =
1904  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
1905  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1906  Sequence<CShuffleMXdlPerWavePerShuffle,
1907  CShuffleNXdlPerWavePerShuffle,
1908  1,
1909  1,
1910  M2,
1911  1,
1912  M4,
1913  1>>{};
1914 
1915  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1916 
1917  // space filling curve for shuffled blockwise C/D/E
1918  constexpr auto sfc_cde_block =
1919  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
1920  Sequence<0, 2, 1, 3>,
1921  Sequence<1,
1922  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1923  1,
1924  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1925 
1926  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
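// Note on the shuffle loop count: the space-filling curves above tile the accumulator
// in CShuffle-sized chunks, so
//   num_access = (MXdlPerWave / CShuffleMXdlPerWavePerShuffle)
//              * (NXdlPerWave / CShuffleNXdlPerWavePerShuffle)
// and the static_assert checks that the VGPR-side and LDS/global-side curves agree on
// that count.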
1927  constexpr auto EMThreads =
1928  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1929  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1930  constexpr auto ENThreads =
1931  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1932  static_for<0, num_access, 1>{}([&](auto access_id) {
1933  // make sure it's safe to write to LDS
1934  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
1935 
1936  auto dstidx = sfc_cde_block.GetIndex(access_id);
1937  const index_t c_token_pos =
1938  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1939  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1940  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1941  IndexType token_offset = fused_token & 0xffffff;
1942  if constexpr(IsInputGemm)
1943  {
1944  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1945  }
1946  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1947  });
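// Scatter-offset sketch (mirrors the loop above): each entry of p_sorted_token_ids
// packs the destination row in its low 24 bits and the top-k slot in its high 8 bits:
//   row = fused_token & 0xffffff;                                   // token index
//   if (IsInputGemm) row = row * problem.TopK + (fused_token >> 24); // expanded per slot
//   scatter_offsets(m0) = row * problem.N;                           // element offset of that row
// The gate/up GEMM therefore stores its result per (token, top-k slot) so the later
// down-projection GEMM can gather it back and apply the routed weights.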
1948 
1949  block_sync_lds();
1950 
1951  // each thread write its data from VGPR to LDS
1952  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1953  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1954  c_thread_buf_fp32,
1955  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1956  c_shuffle_block_buf);
1957 
1958  // make sure it's safe to read from LDS
1959  block_sync_lds();
1960 
1961  // each block copy its data from LDS to global
1962  cde_block_copy_lds_and_global.Run(
1963  c_ds_desc_refs,
1964  c_ds_buf_refs,
1965  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1966  tie(c_grid_buf),
1967  scatter_offsets);
1968 
1969  if constexpr(access_id < num_access - 1)
1970  {
1971  constexpr auto cde_lds_and_global_step =
1972  sfc_cde_block.GetForwardStep(access_id);
1973 
1974  // move on Ds
1975  static_for<0, NumDTensor, 1>{}([&](auto i) {
1976  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1977  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1978  });
1979 
1980  // move on E
1981  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1982  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1983  I0,
1984  cde_lds_and_global_step);
1985  }
1986  });
1987  }
1988  }
1989 #endif
1990 
1991  template <bool HasMainKBlockLoop,
1992  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1993  TailNumber TailNum = TailNumber::Odd>
1994  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1995  const index_t* p_sorted_expert_ids,
1996  const index_t* p_max_token_id,
1997  const ADataType* p_a_grid,
1998  const AScaleDataType* p_a_scale_grid,
1999  const BDataType* p_b_grid,
2000  const BScaleDataType* p_b_scale_grid,
2001  DsGridPointer& p_ds_grid,
2002  CDataType* p_c_grid,
2003  void* p_shared_0,
2004  void* p_shared_1,
2005  const Problem& problem,
2006  AElementwiseOperation a_element_op,
2007  BElementwiseOperation b_element_op,
2008  CElementwiseOperation c_element_op)
2009  {
2010  ignore = a_element_op;
2011  ignore = b_element_op;
2012  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2013  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2014  problem.MPadded,
2015  problem.K,
2016  problem.KPadded,
2017  problem.StrideA,
2018  problem.AK0);
2019  const auto b_grid_desc_bpreshuffled =
2020  MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
2021  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2022  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2023  problem.MPadded,
2024  problem.N,
2025  problem.NPadded,
2026  problem.StrideC);
2027 
2028  // We pad M unconditionally for the scale tensor
2029  const auto Padded_Scale_M =
2030  math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
2031  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2032  make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2033  math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2034  (KXdlPack * 64 / MPerXdl),
2036  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2037  (ScaleBlockSize / APackedSize)) *
2038  MPerXdl * MXdlPack / scale_pack_size_a,
2040  1));
2041 
2042  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2043  make_tuple(problem.N / (NXdlPack * NPerXdl),
2044  math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2045  (KXdlPack * 64 / NPerXdl),
2047  make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2048  (ScaleBlockSize / BPackedSize)) *
2049  NPerXdl * NXdlPack / scale_pack_size_b,
2051  1));
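// Scale-descriptor layout note (assuming the usual MX block-scaling scheme, one e8m0
// scale per ScaleBlockSize packed elements along K): rows are grouped by
// MXdlPack * MPerXdl (resp. NXdlPack * NPerXdl) and the K-direction scales by
// KXdlPack * 64 / MPerXdl (resp. / NPerXdl), so each lane can fetch its
// KXdlPack x MXdlPack (or x NXdlPack) scales with a single vector load of
// scale_pack_size_{a,b} packed e8m0 values.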
2052 
2053  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2054  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2055  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2056 
2057  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2058  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2059  if(expert_block_id * MPerBlock >= max_token_id)
2060  return;
2061  const index_t expert_id =
2062  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2063  const auto block_mn = [&]() -> std::pair<int, int> {
2064  if constexpr(NSwizzle)
2065  {
2066  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2067  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2068  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2069  const index_t expert_swizzle =
2070  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2071  const index_t bid_new = blockIdx.x - prefix_block;
2072  const index_t nid = __builtin_amdgcn_readfirstlane(
2073  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2074  const index_t mid =
2075  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2076  return {nid, mid};
2077  }
2078  else
2079  {
2080  return {blockIdx.x, blockIdx.y};
2081  }
2082  }();
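// NSwizzle block-id mapping sketch: within one expert's group of blocks the linear id
// bid_new walks 8 consecutive N-blocks per M-row and cycles through the expert's
// expert_swizzle M-rows before moving to the next group of 8 N-blocks, e.g. with
// expert_swizzle == 2:
//   bid_new  0..7   -> nid 0..7,   mid = ecnt_prefix
//   bid_new  8..15  -> nid 0..7,   mid = ecnt_prefix + 1
//   bid_new 16..23  -> nid 8..15,  mid = ecnt_prefix
// presumably so that the same B tiles are reused across the expert's M-rows in launch order.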
2083 
2084  const index_t block_n_id = block_mn.first;
2085  const index_t block_m_id = block_mn.second;
2086  const index_t token0 =
2087  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2088 
2089  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2090  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2091  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2092  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2093  constexpr auto AKThreads = AK0Threads * AK1Threads;
2094  constexpr auto AMRepeats = MPerBlock / AMThreads;
2095  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2096 
2097  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2098  return;
2099  StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
2100  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2101  const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2102  index_t token_offset = fused_token & 0xffffff;
2103  if constexpr(!IsInputGemm)
2104  {
2105  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2106  }
2107  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2108  });
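// Gather-offset sketch: the same 24/8-bit token/slot packing drives the A-side row
// gather.  For the gate/up GEMM (IsInputGemm) a row of A is the token's input
// activation, while for the down GEMM the A rows live in the (token, top-k slot)
// expanded intermediate, hence row * TopK + slot.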
2109 
2110  const index_t expert_stride =
2111  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2112  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2113  problem.N * (IsInputGemm ? 2 : 1) *
2114  math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
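// When IsInputGemm, each expert stores its gate and up projection weights back to
// back, which is why expert_stride / expert_scale_stride carry the factor of 2 and the
// "up" pointers below start at p_b_grid + expert_stride / 2 (and
// p_b_scale_grid + expert_scale_stride / 2).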
2115 
2116  // N0, K0, Blocksize*KPack
2117  const index_t n_block_data_idx_on_grid =
2118  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
2119 
2120  // Grid buffer creation
2121  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2122  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2123  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2124  p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2125 
2126  // A, B scale buffer
2127  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2128  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2129  const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2130  p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2131  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2132 
2133  // A matrix in LDS memory, dst of blockwise copy
2134  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2135 
2136  // B matrix in LDS memory, dst of blockwise copy
2137  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2138 
2139  // A matrix blockwise direct to LDS copy
2140  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
2143  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2144  ABlockTransferThreadClusterArrangeOrder,
2145  ADataType,
2146  ADataType,
2147  decltype(a_grid_desc_ak0_m_ak1),
2148  decltype(a_block_desc_ak0_m_ak1),
2149  ABlockTransferSrcAccessOrder,
2150  ABlockTransferSrcVectorDim,
2151  2,
2152  ABlockTransferSrcScalarPerVector,
2153  IndexType,
2154  1>(a_grid_desc_ak0_m_ak1,
2155  make_multi_index(0, 0, 0),
2156  a_block_desc_ak0_m_ak1,
2157  make_multi_index(0, 0, 0),
2158  gather_offsets);
2159 
2160  // Thread-wise copy
2161  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2162  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2163  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2164  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2165  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2166  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2167 
2168  auto b_blockwise_copy =
2170  BDataType,
2171  decltype(b_grid_desc_bpreshuffled),
2172  decltype(b_block_desc_bk0_n_bk1),
2173  Sequence<Number<NXdlPerWave / NXdlPack>{},
2174  I1,
2175  Number<NXdlPack>{},
2176  Number<KRepeat>{},
2177  Number<BK1Value>{}>,
2179  4,
2180  BBlockTransferSrcScalarPerVector,
2181  BThreadTransferSrcResetCoordinateAfterRun,
2182  true>(
2183  b_grid_desc_bpreshuffled,
2184  make_multi_index(n_block_data_idx_on_grid,
2186  0,
2187  0,
2188  KPack * (get_thread_local_1d_id() % WarpSize)));
2189 
2190  // LDS allocation for A and B: be careful of alignment
2191  // cast the raw LDS pointers to the element type after allocation
2192  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2193  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2194  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2195  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2196  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2197 
2198  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2199  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
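// Double-buffering layout of this 2-LDS variant: A tiles ping-pong between the two LDS
// allocations (p_shared_0 / p_shared_1), while the preshuffled B tiles ping-pong
// between the two VGPR static buffers above, so global loads for the next K chunk can
// overlap the MFMA work on the current one.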
2200 
2201  // Blockwise GEMM pipeline
2202  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2203  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2204  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2205  decltype(c_thread_buf) c_thread_buf_up;
2206 
2208  float,
2209  c_thread_buf.num_of_v_,
2210  c_thread_buf.s_per_v,
2211  true>
2212  c_thread_buf_fp32;
2213 
2214  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2215  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2216  KPerBlock);
2217 
2218  // a and b scale processing
2219  const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2220  const auto waveId_m = wave_idx[I0];
2221  const auto waveId_n = wave_idx[I1];
2222 
2223  auto thread_offset_shuffled =
2224  get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2225 
2226  auto a_thread_offset_m = waveId_m;
2227 
2228  // get each thread's offset in the scale tensor
2229  const index_t token_scale_pos = block_m_id * MPerBlock;
2230  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2231  return;
2232 
2233  auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2234  AScaleDataType,
2235  AScaleDataType,
2236  decltype(a_scale_grid_desc_am_ak),
2237  decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2238  Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2239  Sequence<0, 1, 2>, // DimAccessOrder
2240  2, // SrcVectorDim
2241  KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2242  1, // SrcScalarStrideInVector
2243  true>(a_scale_grid_desc_am_ak,
2244  make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2245  0,
2246  thread_offset_shuffled / scale_pack_size_a));
2247 
2248  // B scale load
2249  auto b_thread_offset_n = waveId_n;
2250 
2251  auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2252  BScaleDataType,
2253  BScaleDataType,
2254  decltype(b_scale_grid_desc_bn_ak),
2255  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2256  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2257  Sequence<0, 1, 2>, // DimAccessOrder
2258  2, // SrcVectorDim
2259  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2260  1, // SrcScalarStrideInVector
2261  true>(b_scale_grid_desc_bn_ak,
2262  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2263  0,
2264  thread_offset_shuffled / scale_pack_size_b));
2265 
2266  if constexpr(IsInputGemm)
2267  {
2268  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2269  const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2270  p_b_grid_up + expert_id * expert_stride,
2271  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2272  auto b_blockwise_copy_up =
2274  BDataType,
2275  decltype(b_grid_desc_bpreshuffled),
2276  decltype(b_block_desc_bk0_n_bk1),
2277  Sequence<Number<NXdlPerWave / NXdlPack>{},
2278  I1,
2279  Number<NXdlPack>{},
2280  Number<KRepeat>{},
2281  Number<BK1Value>{}>,
2283  4,
2284  BBlockTransferSrcScalarPerVector,
2285  BThreadTransferSrcResetCoordinateAfterRun,
2286  true>(
2287  b_grid_desc_bpreshuffled,
2288  make_multi_index(n_block_data_idx_on_grid,
2290  0,
2291  0,
2292  KPack * (get_thread_local_1d_id() % WarpSize)));
2293  const BScaleDataType* p_b_scale_grid_up =
2294  p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2295  const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2296  p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2297  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2298 
2299  auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2300  BScaleDataType,
2301  BScaleDataType,
2302  decltype(b_scale_grid_desc_bn_ak),
2303  decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2304  Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2305  Sequence<0, 1, 2>, // DimAccessOrder
2306  2, // SrcVectorDim
2307  KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2308  1, // SrcScalarStrideInVector
2309  true>(
2310  b_scale_grid_desc_bn_ak,
2311  make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2312  0,
2313  thread_offset_shuffled / scale_pack_size_b));
2314 
2315  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2316  // A
2317  a_grid_desc_ak0_m_ak1,
2318  a_block_desc_ak0_m_ak1,
2319  a_blockwise_copy,
2320  a_grid_buf,
2321  a_block_bufs,
2322  a_block_slice_copy_step,
2323  // Gate and Up
2324  b_grid_desc_bpreshuffled,
2325  b_block_desc_bk0_n_bk1,
2326  b_blockwise_copy,
2327  b_blockwise_copy_up,
2328  b_grid_buf,
2329  b_grid_buf_up,
2330  b_block_bufs,
2331  b_block_slice_copy_step,
2332  // C
2333  c_thread_buf,
2334  c_thread_buf_up,
2335  // A scale
2336  a_scale_grid_desc_am_ak,
2337  a_scale_thread_copy,
2338  a_scale_grid_buf,
2339  // B scale
2340  b_scale_grid_desc_bn_ak,
2341  b_scale_thread_copy,
2342  b_scale_thread_copy_up,
2343  b_scale_grid_buf,
2344  b_scale_grid_buf_up,
2345  num_k_block_main_loop);
2346  }
2347  else
2348  {
2349  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2350  a_grid_desc_ak0_m_ak1, // A
2351  a_block_desc_ak0_m_ak1,
2352  a_blockwise_copy,
2353  a_grid_buf,
2354  a_block_bufs,
2355  a_block_slice_copy_step,
2356  b_grid_desc_bpreshuffled, // B
2357  b_block_desc_bk0_n_bk1,
2358  b_blockwise_copy,
2359  b_grid_buf,
2360  b_block_bufs,
2361  b_block_slice_copy_step,
2362  c_thread_buf, // C
2363  a_scale_grid_desc_am_ak, // A scale
2364  a_scale_thread_copy,
2365  a_scale_grid_buf,
2366  b_scale_grid_desc_bn_ak, // B scale
2367  b_scale_thread_copy,
2368  b_scale_grid_buf,
2369  num_k_block_main_loop);
2370  }
2371 
2372  // shuffle C and write out
2373  {
2374  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2375  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2376  "wrong!");
2377  static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2378  CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2379  "wrong!");
2380 
2381  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2382 
2383  // TODO: hacky, fix it!
2384  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2385  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2386 
2387  // TODO: hacky, fix it!
2388  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2389  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2390  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2391 
2392  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2393  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2394  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2395  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2396  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2397  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2398  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2399  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2400  constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2401  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2402 
2403  // multiply in the routed top-k weights and apply the fused gate/up activation
2404 
2405  static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2406  static_assert(M5 == 4);
2407  const index_t m1 = get_warp_local_1d_id() / NWave;
2408  const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2409 
2410  vector_type<float, 4> topk_weights; // for gemm2 only
2411  static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2412  static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2413  static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2414  static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2415  static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2416  const index_t m_pos = block_m_id * MPerBlock +
2417  m0 * M2 * M1 * M3 * M4 * M5 +
2418  m1 * M2 * M3 * M4 * M5 +
2419  imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2420  if constexpr(MulRoutedWeight)
2421  {
2422  topk_weights =
2423  *c_style_pointer_cast<const vector_type<float, M5>*>(
2424  p_ds_grid[I2] + m_pos);
2425  }
2426  static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2427  constexpr index_t c_offset =
2428  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2429  make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2430  constexpr auto cidx = Number<c_offset>{};
2431 
2432  if constexpr(IsInputGemm) // gate-up (gu) fusion
2433  {
2434  if constexpr(ActivationOperation ==
2435  Activation::silu_and_mul)
2436  {
2437  float gate = c_thread_buf[cidx];
2438  float up = c_thread_buf_up[cidx];
2439  if constexpr(MulRoutedWeight)
2440  {
2441  gate = gate * topk_weights.AsType<float>()[m5];
2442  up = up * topk_weights.AsType<float>()[m5];
2443  }
2445  c_thread_buf_fp32(cidx) = gate * up;
2446  }
2447  else if(ActivationOperation == Activation::gelu_and_mul)
2448  {
2449  float gate = c_thread_buf[cidx];
2450  float up = c_thread_buf_up[cidx];
2451  if constexpr(MulRoutedWeight)
2452  {
2453  gate = gate * topk_weights.AsType<float>()[m5];
2454  up = up * topk_weights.AsType<float>()[m5];
2455  }
2457  c_thread_buf_fp32(cidx) = gate * up;
2458  }
2459  }
2460  else
2461  {
2462  c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2463  if constexpr(MulRoutedWeight)
2464  {
2465  c_thread_buf_fp32(cidx) =
2466  topk_weights.AsType<float>()[m5] *
2467  c_thread_buf_fp32[cidx];
2468  }
2469  }
2470  });
2471  });
2472  });
2473  });
2474  });
2475  });
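// Fused activation sketch for the gate/up GEMM (the elementwise lines elided from this
// listing presumably apply the activation to `gate` before the stored product):
//   silu_and_mul : out = silu(gate) * up,  with silu(x) = x / (1 + exp(-x))
//   gelu_and_mul : out = gelu(gate) * up
// When MulRoutedWeight is set, gate/up (or the plain accumulator in the down GEMM) are
// additionally scaled by the routed top-k weight read from p_ds_grid[I2].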
2476 
2477  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2478  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
2479 
2480  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2481  static_cast<CShuffleDataType*>(p_shared_0),
2482  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2483 
2484  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2485  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2486  make_tuple(
2489  Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2490  // shuffle
2491  M1, // M1 = MWave
2492  M2, // M2 * M3 * M4 = MPerXdl
2493  M3,
2494  M4,
2495  M5)),
2499  // per shuffle
2500  N1, // N1 = NWave
2501  N2, // N2 = NXdlPack
2502  N3))), // N3 = NPerXdl
2506  Sequence<>{},
2508 
2509  // calculate origin of thread output tensor on global memory
2510  // blockwise GEMM c matrix starting index
2511  const auto c_thread_mtx_on_block =
2512  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2513 
2514  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2515  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2516 
2517  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2518  make_single_stage_tensor_adaptor(
2519  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2520  make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
2521  make_tuple(Sequence<0>{}));
2522 
2523  const auto m_thread_data_on_block_idx =
2524  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2525  make_multi_index(m_thread_data_on_block));
2526 
2527  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2528  make_single_stage_tensor_adaptor(
2529  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2530  make_tuple(Sequence<0, 1, 2, 3>{}),
2531  make_tuple(Sequence<0>{}));
2532 
2533  const auto n_thread_data_on_block_idx =
2534  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2535  make_multi_index(n_thread_data_on_block));
2536 
2537  // shuffle: threadwise copy C from VGPR to LDS
2538  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2539  AccDataType,
2540  CShuffleDataType,
2541  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2542  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2544  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2545  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2546  I1,
2547  I1,
2548  M2,
2549  N2,
2550  M3,
2551  I1,
2552  M5,
2553  I1>,
2555  9,
2556  1,
2558  1,
2559  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2560  make_multi_index(0,
2561  0,
2562  m_thread_data_on_block_idx[I1],
2563  n_thread_data_on_block_idx[I1],
2564  m_thread_data_on_block_idx[I2],
2565  n_thread_data_on_block_idx[I2],
2566  m_thread_data_on_block_idx[I3],
2567  m_thread_data_on_block_idx[I4],
2568  m_thread_data_on_block_idx[I5],
2569  n_thread_data_on_block_idx[I3]),
2571 
2572  using EDataType = CDataType;
2573 
2574  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2575  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2576 
2577  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2578  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2579  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2580 
2581  const auto ds_grid_buf = generate_tuple(
2582  [&](auto i) {
2583  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2584  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2585  },
2586  Number<NumDTensor>{});
2587 
2588  // tuple of reference to C/Ds tensor descriptors
2589  const auto c_ds_desc_refs = concat_tuple_of_reference(
2590  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2591  generate_tie(
2592  [&](auto i) -> const auto& // return type should be reference
2593  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2594  Number<NumDTensor>{}));
2595 
2596  // tuple of reference to C/Ds tensor descriptors
2597  const auto c_ds_buf_refs = concat_tuple_of_reference(
2598  tie(c_shuffle_block_buf),
2599  generate_tie(
2600  [&](auto i) -> const auto& // return type should be reference
2601  { return ds_grid_buf[i]; },
2602  Number<NumDTensor>{}));
2603 
2604  // tuple of starting index of C/Ds blockwise copy
2605  const auto idx_c_ds_block_begin =
2608  [&](auto) {
2609  return make_multi_index(block_m_id, 0, block_n_id, 0);
2610  // return make_multi_index(block_work_idx[I0], 0,
2611  // block_work_idx[I1], 0);
2612  },
2613  Number<NumDTensor>{}));
2614 
2615  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2616  c_grid_desc_mblock_mperblock_nblock_nperblock;
2617 
2618  using CDEBlockTransferCluster =
2619  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2620  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2621  constexpr index_t scatter_weight_idx = 3; // hack fix felix
2622  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2624  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2625  Tuple<EDataType>,
2626  decltype(c_ds_desc_refs),
2627  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2628  CElementwiseOperation,
2629  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2630  // Sequence support
2631  // arbitrary type
2632  Sequence<1,
2633  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2634  1,
2635  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2636  CDEBlockTransferCluster,
2637  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2638  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2639  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2640  3, // index_t SrcVectorDim,
2641  3, // index_t DstVectorDim,
2642  CDEShuffleBlockTransferScalarPerVectors,
2647  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2648  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2649  IndexType,
2650  1, // ScatterDim
2651  true, // OutputScatter: false, only use scatter weights
2652  scatter_weight_idx // ScatterWeightIdx: ascale
2653  >{c_ds_desc_refs,
2654  idx_c_ds_block_begin,
2655  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2656  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2657  c_element_op};
2658 
2659  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2660  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2661 
2662  constexpr auto sfc_c_vgpr =
2663  SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2664  NXdlPerWave / NXdlPack,
2665  1,
2666  1,
2667  MXdlPack,
2668  NXdlPack,
2669  M2,
2670  1,
2671  M4,
2672  1>,
2673  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2674  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2675  CShuffleNXdlPerWavePerShuffle / NXdlPack,
2676  1,
2677  1,
2678  MXdlPack,
2679  NXdlPack,
2680  M2,
2681  1,
2682  M4,
2683  1>>{};
2684 
2685  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2686 
2687  // space filling curve for shuffled blockwise C/D/E
2688  constexpr auto sfc_cde_block =
2689  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2690  Sequence<0, 2, 1, 3>,
2691  Sequence<1,
2692  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2693  1,
2694  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2695 
2696  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2697  constexpr auto EMThreads =
2698  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2699  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2700  constexpr auto ENThreads =
2701  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2702  static_for<0, num_access, 1>{}([&](auto access_id) {
2703  // make sure it's safe to write to LDS
2704  StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
2705 
2706  auto dstidx = sfc_cde_block.GetIndex(access_id);
2707  const index_t c_token_pos =
2708  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2709  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2710  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2711  IndexType token_offset = fused_token & 0xffffff;
2712  if constexpr(IsInputGemm)
2713  {
2714  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2715  }
2716  scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2717  });
2718 
2719  block_sync_lds();
2720 
2721  // each thread write its data from VGPR to LDS
2722  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2723  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2724  c_thread_buf_fp32,
2725  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2726  c_shuffle_block_buf);
2727 
2728  // make sure it's safe to read from LDS
2729  block_sync_lds();
2730 
2731  // each block copy its data from LDS to global
2732  cde_block_copy_lds_and_global.Run(
2733  c_ds_desc_refs,
2734  c_ds_buf_refs,
2735  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2736  tie(c_grid_buf),
2737  scatter_offsets);
2738 
2739  if constexpr(access_id < num_access - 1)
2740  {
2741  constexpr auto cde_lds_and_global_step =
2742  sfc_cde_block.GetForwardStep(access_id);
2743 
2744  // move on Ds
2745  static_for<0, NumDTensor, 1>{}([&](auto i) {
2746  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2747  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2748  });
2749 
2750  // move on E
2751  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2752  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2753  I0,
2754  cde_lds_and_global_step);
2755  }
2756  });
2757  }
2758  }
2759 };
2760 
2761 } // namespace ck