// Reference MoE GEMM device kernel (fragmentary listing).
// NOTE(review): the embedded line numbers jump, so statements, braces and some
// assignment left-hand sides are missing here. The comments below describe only
// the statements that are visible; anything depending on an omitted line is
// marked as an assumption to confirm against the full source.
// Presumably this is moe_gemm_kernel, whose full signature appears in the
// cross-reference text at the end of this listing - TODO confirm.
14 template <
typename ADataType,
22 typename ActivationOp = identity>
29 const AccDataType* expert_weight_ptr,
44 float* expert_bias_ptr)
// Compile-time flag: MoeGemmKind == 3 is treated as the split-k variant.
46 constexpr
auto is_split_k = MoeGemmKind == 3;
// One thread per output element: flat index from a 1D grid of 1D blocks.
47 int idx = blockIdx.x * blockDim.x + threadIdx.x;
// For MoeGemmKind == 1 each thread covers column col and column col + problem_N,
// so only N / 2 columns are enumerated per row.
48 int problem_N = MoeGemmKind == 1 ? N / 2 : N;
49 int row = idx / problem_N;
50 int col = idx % problem_N;
// Work is done only for rows below the count stored in p_max_token_id_[0].
56 if(row < p_max_token_id_[0])
// The expert id is stored once per group of TokensPerBlock consecutive sorted rows.
58 expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
// Both ids start from the low 24 bits (mask 0xffffff) of the sorted token entry.
59 gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
60 scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
// NOTE(review): the body of this range check is not visible in the listing.
61 if(gather_token_id >= Num_tokens)
// Expand a token id to token * TopK + slot, where slot is taken from bits 24 and
// above of the sorted entry. The conditions selecting these two updates are on
// omitted lines - TODO confirm which MoeGemmKind values use each.
67 gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
71 scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
// acc / acc_up hold the scaled totals; acc_temp / acc_up_temp hold the unscaled
// partial sums that are folded into the totals at scale-group boundaries.
81 AccDataType acc = 0.0;
82 AccDataType acc_up = 0.0;
84 AccDataType acc_temp = 0.0;
85 AccDataType acc_up_temp = 0.0;
// Scale table strides: ceil(M / granularity_m) and ceil(N / granularity_n)
// entries per K group, and scale_B_stride * K / granularity_k entries per expert.
91 index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
92 index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
93 index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
// Reduction over the K dimension.
95 for(
int k = 0; k < K; ++k)
// At every multiple of scale_granularity_k: fold the partial sums into the
// totals using the current scales, then load the scales for the new K group.
// NOTE(review): the reset of acc_temp / acc_up_temp is presumably on the
// omitted lines that follow the two accumulations - TODO confirm.
97 if(k % scale_granularity_k == 0)
100 acc += acc_temp * scale_A * scale_B;
101 acc_up += acc_up_temp * scale_A * scale_B_up;
// A scale is indexed by (token group, K group).
106 scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
107 (k / scale_granularity_k) * scale_A_stride];
// B scale lookup indexed by (expert, column group, K group); the assignment
// target is on an omitted line, presumably scale_B - TODO confirm.
109 scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
110 (k / scale_granularity_k) * scale_B_stride];
// MoeGemmKind == 1 also needs the scale for the second column, col + problem_N.
111 if constexpr(MoeGemmKind == 1)
112 scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
113 (col + problem_N) / scale_granularity_n +
114 (k / scale_granularity_k) * scale_B_stride];
// Element offset into A for (gather_token_id, k), depending on LayoutA.
// NOTE(review): computed in int; may overflow for very large A - confirm sizes.
120 int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
121 ? gather_token_id * strideA + k
122 : k * strideA + gather_token_id;
// Element offset into B: expert base offset (N * K elements per expert, widened
// to long) plus a layout-dependent offset for (k, col). The assignment target
// is on an omitted line, presumably b_index - TODO confirm.
125 long(expert_id) * N * K +
126 ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
127 : k * strideB + col);
// Same offset computation for the second column, col + problem_N.
129 if constexpr(MoeGemmKind == 1)
130 b_index_up = long(expert_id) * N * K +
131 ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
132 ? (col + problem_N) * strideB + k
133 : k * strideB + col + problem_N);
// Load the A element as AccDataType. The packed int4 / packed fp4 branch bodies
// are not visible; the remaining branch is a plain type_convert.
138 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
146 else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
156 v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
// Load the B element(s). For the packed types the visible code picks the hi or
// lo member of a two-value unpacked result; the selecting condition is on
// omitted lines - TODO confirm.
158 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
165 if constexpr(MoeGemmKind == 1)
170 v_b_up = fp32_val_up.hi;
172 v_b_up = fp32_val_up.lo;
175 else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
182 if constexpr(MoeGemmKind == 1)
187 v_b_up = fp32_val_up.hi;
189 v_b_up = fp32_val_up.lo;
194 v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
195 if constexpr(MoeGemmKind == 1)
196 v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
// Multiply-accumulate into the unscaled partial sums.
198 acc_temp += v_a * v_b;
199 if constexpr(MoeGemmKind == 1)
200 acc_up_temp += v_a * v_b_up;
// Fold the last partial sums into the totals (same form as the in-loop fold).
203 acc += acc_temp * scale_A * scale_B;
204 acc_up += acc_up_temp * scale_A * scale_B_up;
// Optional per-expert bias: read only when a bias pointer was supplied and the
// kernel is not the split-k variant; otherwise both biases stay 0.
// NOTE(review): expert_id * N + col is evaluated in int - confirm it cannot overflow.
206 float bias = 0.f, bias_up = 0.f;
207 if(expert_bias_ptr !=
nullptr && !is_split_k)
209 bias = expert_bias_ptr[expert_id * N + col];
210 if constexpr(MoeGemmKind == 1)
211 bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
// Element offset into C for (scatter_token_id, col), depending on LayoutC.
214 int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
215 ? scatter_token_id * strideC + col
216 : col * strideC + scatter_token_id;
// MoeGemmKind 0 and 1: store ActivationOp applied to (acc + bias, second operand)
// directly. The second operand is acc_up + bias_up for kind 1 and the constant 1
// otherwise.
217 if constexpr(MoeGemmKind < 2)
219 C[c_index] = ck_tile::type_convert<CDataType>(
220 ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
// Other kinds: scale by the routing weight for this sorted row (1 for split-k).
// The assignment target is on an omitted line, presumably the weight variable
// used on the next line - TODO confirm.
226 is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
227 CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);
// res is placed in slot 1 or slot 0 of a two-element value add_v (the selecting
// condition is not visible), which is then added atomically at the address of
// C with the lowest bit of c_index cleared.
// NOTE(review): presumably the slot follows the parity of c_index and the other
// slot is zero - confirm against the omitted lines.
233 add_v.template get_as<CDataType>()[1] = res;
238 add_v.template get_as<CDataType>()[0] = res;
241 atomic_add_g<CDataType, 2>(
reinterpret_cast<CDataType*
>(C + (c_index & 0xffff'fffe)),
// Host-side launcher for the reference MoE GEMM kernel (fragmentary listing).
// NOTE(review): the embedded line numbers jump, so parts of the template and
// parameter lists and of the launch statement are missing here. Presumably this
// is reference_moe_gemm_gpu, whose full signature appears in the cross-reference
// text at the end of this listing - TODO confirm.
247 template <
typename ADataType,
249 typename AccDataType,
255 typename ActivationOp = identity>
257 const index_t* p_sorted_expert_ids_,
258 const index_t* p_max_token_id_,
259 const ADataType* a_ptr,
260 const BDataType* b_ptr,
262 const AccDataType* expert_weight_ptr,
// The bias pointer is optional and defaults to nullptr.
277 float* exp_bias =
nullptr)
// Same column count rule as the kernel: N / 2 columns when MoeGemmKind == 1.
279 int problem_N = MoeGemmKind == 1 ? N / 2 : N;
// One thread per (row, column) output element.
// NOTE(review): M * problem_N is an int product - confirm it cannot overflow.
280 int totalElements = M * problem_N;
// Fixed block size of 256 threads; grid size is the ceiling of
// totalElements / numThreadsPerBlock.
281 int numThreadsPerBlock = 256;
282 int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
// Kernel launch on a 1D grid; the last visible template argument is ActivationOp
// and the first visible arguments are the sorted token and expert id arrays.
// NOTE(review): no launch error check is visible in this listing - confirm
// whether one follows on the omitted lines.
292 ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
293 p_sorted_expert_ids_,
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:105
float fp32x2_t
Definition: bfloat16.hpp:434
int32_t index_t
Definition: integer.hpp:9
__global__ void moe_gemm_kernel(const ck_tile::index_t *p_sorted_token_ids_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const ADataType *A, const BDataType *B, CDataType *C, const AccDataType *expert_weight_ptr, ck_tile::index_t Num_tokens, ck_tile::index_t TokensPerBlock, ck_tile::index_t TopK, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *expert_bias_ptr)
Definition: reference_moe_gemm.hpp:23
void reference_moe_gemm_gpu(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const AccDataType *expert_weight_ptr, index_t Num_tokens, index_t TokensPerBlock, index_t TopK, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *exp_bias=nullptr)
Definition: reference_moe_gemm.hpp:256
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:350
Definition: numeric.hpp:81