Composable Kernel — source listing for
include/ck_tile/host/reference/reference_moe_gemm.hpp
(ReadTheDocs build of the develop branch).

reference_moe_gemm.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <cstdlib>
7 #include <thread>
8 
9 #include "ck_tile/core.hpp"
11 
12 namespace ck_tile {
13 
/// Reference (naive) MoE GEMM kernel: one thread per output element.
///
/// Launch layout: 1-D grid, flat index over (row, col) with
///   row = token slot in the sorted-token table, col = output column in [0, problem_N).
/// For MoeGemmKind == 1 (gate+up fused) the B/C logical width N holds two halves,
/// so problem_N = N / 2 and each thread produces both the gate and the up value.
///
/// MoeGemmKind: 0 = gemm1_gate_only, 1 = gemm1_gate_up, 2 = gemm2, 3 = gemm1_split_k.
///
/// p_sorted_token_ids_ packs per-slot routing info: low 24 bits = token id,
/// high 8 bits = top-k slot index (see the masks/shifts below).
/// scale_A_ptr / scale_B_ptr hold block-wise quantization scales with the given
/// granularities; expert_bias_ptr may be nullptr (no bias).
///
/// NOTE(review): the M, N, K parameters and the packed_size_a/packed_size_b
/// constants were restored from the documented signature of this kernel; the
/// doc extraction had dropped those source lines.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
          typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
                                const ck_tile::index_t* p_sorted_expert_ids_,
                                const ck_tile::index_t* p_max_token_id_,
                                const ADataType* A,
                                const BDataType* B,
                                CDataType* C,
                                const AccDataType* expert_weight_ptr,
                                ck_tile::index_t Num_tokens,
                                ck_tile::index_t TokensPerBlock,
                                ck_tile::index_t TopK,
                                ck_tile::index_t M,
                                ck_tile::index_t N,
                                ck_tile::index_t K,
                                ck_tile::index_t strideA,
                                ck_tile::index_t strideB,
                                ck_tile::index_t strideC,
                                index_t scale_granularity_m,
                                index_t scale_granularity_n,
                                index_t scale_granularity_k,
                                float* scale_A_ptr,
                                float* scale_B_ptr,
                                float* expert_bias_ptr)
{
    constexpr auto is_split_k = MoeGemmKind == 3;

    // Number of values packed into one storage element of A/B (e.g. 2 for pk_int4_t
    // / pk_fp4_t, 1 otherwise). Restored from the original source; TODO(review):
    // confirm the trait name matches the repository's numeric_traits definition.
    constexpr index_t packed_size_a = numeric_traits<ADataType>::PackedSize;
    constexpr index_t packed_size_b = numeric_traits<BDataType>::PackedSize;

    // Flat 1-D global index; NOTE(review): int math overflows past 2^31 outputs —
    // acceptable for a host reference, but keep problem sizes below that.
    int idx       = blockIdx.x * blockDim.x + threadIdx.x;
    int problem_N = MoeGemmKind == 1 ? N / 2 : N; // gate_up: N holds two halves
    int row       = idx / problem_N;              // sorted-token slot
    int col       = idx % problem_N;              // output column

    index_t gather_token_id  = 0; // where to read A from
    index_t scatter_token_id = 0; // where to write C to
    index_t expert_id        = 0;

    if(row < p_max_token_id_[0])
    {
        expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
        // low 24 bits of the sorted-token entry are the token id
        gather_token_id  = p_sorted_token_ids_[row] & 0xff'ffff;
        scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
        if(gather_token_id >= Num_tokens)
        {
            return; // padding slot — no real token routed here
        }
        // high 8 bits are the top-k slot; gemm2 gathers from the expanded
        // (token * TopK) activation, gemm1 scatters into it
        if(MoeGemmKind == 2)
        {
            gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
        else
        {
            scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
    }
    else
    {
        return;
    }

    if(row < M)
    {
        AccDataType acc    = 0.0;
        AccDataType acc_up = 0.0;

        // Partial sums for the current scale block; folded into acc with the
        // matching scale factors each time a new scale-k block starts.
        AccDataType acc_temp    = 0.0;
        AccDataType acc_up_temp = 0.0;

        float scale_A    = 0;
        float scale_B    = 0;
        float scale_B_up = 0;

        index_t scale_A_stride        = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride        = (N + scale_granularity_n - 1) / scale_granularity_n;
        index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;

        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // fold the finished scale block into the accumulator
                acc += acc_temp * scale_A * scale_B;
                acc_up += acc_up_temp * scale_A * scale_B_up;
                // reset partial sums
                acc_temp    = 0.0;
                acc_up_temp = 0.0;
                // load the scale factors for the new block
                scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B =
                    scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
                                (k / scale_granularity_k) * scale_B_stride];
                if constexpr(MoeGemmKind == 1)
                    scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
                                             (col + problem_N) / scale_granularity_n +
                                             (k / scale_granularity_k) * scale_B_stride];
            }

            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? gather_token_id * strideA + k
                              : k * strideA + gather_token_id;

            // B is selected per expert: each expert owns an N x K slab
            long b_index =
                long(expert_id) * N * K +
                ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
                                                                             : k * strideB + col);
            long b_index_up;
            if constexpr(MoeGemmKind == 1)
                b_index_up = long(expert_id) * N * K +
                             ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                                  ? (col + problem_N) * strideB + k
                                  : k * strideB + col + problem_N);

            AccDataType v_a;
            AccDataType v_b;
            AccDataType v_b_up;
            // Packed 4-bit types store two values per element: unpack to a float
            // pair and pick .lo/.hi by k parity.
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
                if constexpr(MoeGemmKind == 1)
                    v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
            }
            acc_temp += v_a * v_b;
            if constexpr(MoeGemmKind == 1)
                acc_up_temp += v_a * v_b_up;
        }

        // fold the last (possibly partial) scale block
        acc += acc_temp * scale_A * scale_B;
        acc_up += acc_up_temp * scale_A * scale_B_up;

        float bias = 0.f, bias_up = 0.f;
        if(expert_bias_ptr != nullptr && !is_split_k)
        {
            bias = expert_bias_ptr[expert_id * N + col];
            if constexpr(MoeGemmKind == 1)
                bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? scatter_token_id * strideC + col
                          : col * strideC + scatter_token_id;
        if constexpr(MoeGemmKind < 2)
        {
            // gemm1: activation combines gate (and, for gate_up, the up half)
            C[c_index] = ck_tile::type_convert<CDataType>(
                ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
        }
        else
        {
            // moe gemm2 don't use activation; results from different top-k slots
            // of the same token are combined with atomicAdd, weighted per expert.
            auto weight =
                is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
            CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);

            thread_buffer<CDataType, 2> add_v = 0;
            if(c_index % 2)
            {
                // result is the second value of fp16 pair.
                add_v.template get_as<CDataType>()[1] = res;
            }
            else
            {
                // result is the first value of fp16 pair.
                add_v.template get_as<CDataType>()[0] = res;
            }
            // mask last bit to make sure atomicAdd pointer is aligned of DWORD.
            atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
                                       add_v);
        }
    }
}
246 
/// Host-side launcher for the reference MoE GEMM kernel.
///
/// Spawns one device thread per output element (M rows times the effective
/// output width) in blocks of 256 threads and forwards every argument to
/// moe_gemm_kernel unchanged. For MoeGemmKind == 1 (gate+up fused) each thread
/// produces both halves, so only N/2 columns of work are launched.
///
/// All pointer arguments must reference device memory; exp_bias may stay
/// nullptr when no per-expert bias is applied. The launch is asynchronous on
/// the default stream — callers synchronize before reading c_ptr.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
          typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
                            const index_t* p_sorted_expert_ids_,
                            const index_t* p_max_token_id_,
                            const ADataType* a_ptr,
                            const BDataType* b_ptr,
                            CDataType* c_ptr,
                            const AccDataType* expert_weight_ptr,
                            index_t Num_tokens,
                            index_t TokensPerBlock,
                            index_t TopK,
                            index_t M,
                            index_t N,
                            index_t K,
                            index_t stride_a,
                            index_t stride_b,
                            index_t stride_c,
                            index_t scale_granularity_m,
                            index_t scale_granularity_n,
                            index_t scale_granularity_k,
                            float* scale_A_ptr,
                            float* scale_B_ptr,
                            float* exp_bias = nullptr)
{
    // Effective output width: the fused gate_up variant emits two values per thread.
    const int effective_n  = (MoeGemmKind == 1) ? N / 2 : N;
    const int total_work   = M * effective_n;
    constexpr int block_dim = 256; // common, warp-multiple block size
    const int grid_dim      = (total_work + block_dim - 1) / block_dim; // ceil-div

    moe_gemm_kernel<ADataType,
                    BDataType,
                    AccDataType,
                    CDataType,
                    LayoutA,
                    LayoutB,
                    LayoutC,
                    MoeGemmKind,
                    ActivationOp><<<grid_dim, block_dim>>>(p_sorted_token_ids_,
                                                           p_sorted_expert_ids_,
                                                           p_max_token_id_,
                                                           a_ptr,
                                                           b_ptr,
                                                           c_ptr,
                                                           expert_weight_ptr,
                                                           Num_tokens,
                                                           TokensPerBlock,
                                                           TopK,
                                                           M,
                                                           N,
                                                           K,
                                                           stride_a,
                                                           stride_b,
                                                           stride_c,
                                                           scale_granularity_m,
                                                           scale_granularity_n,
                                                           scale_granularity_k,
                                                           scale_A_ptr,
                                                           scale_B_ptr,
                                                           exp_bias);
}
317 
318 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:105
fp32x2_t — two-element float vector type (accessed via .lo/.hi in this file; presumably a packed pair of 32-bit floats — see its definition)
Definition: bfloat16.hpp:434
int32_t index_t
Definition: integer.hpp:9
__global__ void moe_gemm_kernel(const ck_tile::index_t *p_sorted_token_ids_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const ADataType *A, const BDataType *B, CDataType *C, const AccDataType *expert_weight_ptr, ck_tile::index_t Num_tokens, ck_tile::index_t TokensPerBlock, ck_tile::index_t TopK, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *expert_bias_ptr)
Definition: reference_moe_gemm.hpp:23
void reference_moe_gemm_gpu(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const AccDataType *expert_weight_ptr, index_t Num_tokens, index_t TokensPerBlock, index_t TopK, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *exp_bias=nullptr)
Definition: reference_moe_gemm.hpp:256
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:350
Definition: numeric.hpp:81
Definition: debug.hpp:27