/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp Source File
moe_smoothquant_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
11 // host side args
13 {
14  const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
15  const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
16  const void* p_topk_ids; // [tokens, topk]
17 
18  void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
19  void* p_qy; // [topk * tokens, hidden_size], output
20 
25  index_t x_stride; // input x row stride
26  index_t y_stride; // output y stride(stride for topk)
27 };
28 
29 // TODO: Extract some type to wrapper class
30 template <typename Pipeline_>
32 {
34  using Problem = typename Pipeline::Problem;
35 
41 
42  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
43  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
44  static constexpr bool kPadM = false; // always no need to pad along M
45  static constexpr bool kPadN = Problem::kPadN;
46  static constexpr bool kTwoPass = Problem::kTwoPass;
47 
48  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
49  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
50  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
51 
52  static constexpr auto I0 = number<0>{};
53  static constexpr auto I1 = number<1>{};
54 
55  static_assert(Problem::BlockShape::Repeat_M == 1);
56 
57  struct Kargs
58  {
59  const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
60  const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
61  const void* p_topk_ids; // [tokens, topk]
62 
63  void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
64  void* p_qy; // [topk, tokens, hidden_size], output
65 
70  index_t x_stride; // input x row stride
71  index_t y_stride; // output y stride(stride for topk)
72  };
74 
75  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
76  {
77  return Kargs{hargs.p_x,
78  hargs.p_smscale,
79  hargs.p_topk_ids,
80  hargs.p_yscale,
81  hargs.p_qy,
82  hargs.tokens,
83  hargs.hidden_size,
84  hargs.experts,
85  hargs.topk,
86  hargs.x_stride,
87  hargs.y_stride};
88  }
89 
90  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
91  {
92  return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1);
93  }
94 
95  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
96 
97  // clang-format off
98  template <typename T> struct t2s;
99  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
100  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
101  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
102  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
103  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
104  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
105  // clang-format on
106 
107  // in byte
108  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
109 
110  CK_TILE_HOST static std::string GetName()
111  {
112  // clang-format off
113  using S_ = typename Problem::BlockShape;
114  auto surfix = [&] () {
115  std::string n;
116  if (kPadN) n += "_pn";
117  if (kTwoPass) n += "_2p";
118  return n; }();
119 
120  #define _SS_ std::string
121  #define _TS_ std::to_string
122  return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
123  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
124  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
125  _SS_(Pipeline::name) + surfix;
126  #undef _SS_
127  #undef _TS_
128  // clang-format on
129  }
130 
131  CK_TILE_DEVICE void operator()(Kargs kargs) const
132  {
133  const index_t i_topk = blockIdx.x;
134  const index_t i_token = blockIdx.y * Block_M;
135  const index_t i_token_in_thrd =
136  __builtin_amdgcn_readfirstlane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N);
137 
138  const index_t i_expert = reinterpret_cast<const index_t*>(
139  kargs.p_topk_ids)[(i_token + i_token_in_thrd) * kargs.topk + i_topk];
140 
141  // [tokens ,hidden_size]
142  const auto x_window = [&]() {
143  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
144  static_cast<const XDataType*>(kargs.p_x),
145  make_tuple(kargs.tokens, kargs.hidden_size),
146  make_tuple(kargs.x_stride, 1),
148  number<1>{});
149 
150  const auto tmp2_ = pad_tensor_view(
152  return make_tile_window(
153  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
154  }();
155 
156  // [experts, hidden_size],
157  const auto smscale_window = [&]() {
158  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
159  static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
160  i_expert * kargs.hidden_size,
161  make_tuple(kargs.hidden_size),
162  make_tuple(1),
164  number<1>{});
165 
166  const auto tmp2_ =
168 
169  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
170  }();
171 
172  // [topk, tokens]
173  auto yscale_window = [&]() {
174  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
175  static_cast<YScaleDataType*>(kargs.p_yscale) + i_topk * kargs.tokens,
176  make_tuple(kargs.tokens),
177  make_tuple(1),
178  number<1>{});
179 
180  const auto tmp2_ =
182 
183  return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {i_token});
184  }();
185 
186  // [topk, tokens, hidden_size]
187  auto qy_window = [&]() {
188  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
189  static_cast<QYDataType*>(kargs.p_qy) + i_topk * kargs.tokens * kargs.y_stride,
190  make_tuple(kargs.tokens, kargs.hidden_size),
191  make_tuple(kargs.y_stride, 1),
193  number<1>{});
194 
195  auto tmp2_ = pad_tensor_view(
197  return make_tile_window(
198  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
199  }();
200 
201  __shared__ char smem[GetSmemSize()];
202 
203  Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
204  }
205 };
206 
207 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
Definition: moe_smoothquant_kernel.hpp:58
index_t y_stride
Definition: moe_smoothquant_kernel.hpp:71
const void * p_x
Definition: moe_smoothquant_kernel.hpp:59
index_t tokens
Definition: moe_smoothquant_kernel.hpp:66
index_t topk
Definition: moe_smoothquant_kernel.hpp:69
index_t x_stride
Definition: moe_smoothquant_kernel.hpp:70
void * p_yscale
Definition: moe_smoothquant_kernel.hpp:63
const void * p_smscale
Definition: moe_smoothquant_kernel.hpp:60
void * p_qy
Definition: moe_smoothquant_kernel.hpp:64
index_t experts
Definition: moe_smoothquant_kernel.hpp:68
const void * p_topk_ids
Definition: moe_smoothquant_kernel.hpp:61
index_t hidden_size
Definition: moe_smoothquant_kernel.hpp:67
Definition: moe_smoothquant_kernel.hpp:98
Definition: moe_smoothquant_kernel.hpp:13
index_t x_stride
Definition: moe_smoothquant_kernel.hpp:25
index_t topk
Definition: moe_smoothquant_kernel.hpp:24
index_t hidden_size
Definition: moe_smoothquant_kernel.hpp:22
void * p_yscale
Definition: moe_smoothquant_kernel.hpp:18
index_t experts
Definition: moe_smoothquant_kernel.hpp:23
index_t y_stride
Definition: moe_smoothquant_kernel.hpp:26
index_t tokens
Definition: moe_smoothquant_kernel.hpp:21
const void * p_topk_ids
Definition: moe_smoothquant_kernel.hpp:16
const void * p_smscale
Definition: moe_smoothquant_kernel.hpp:15
void * p_qy
Definition: moe_smoothquant_kernel.hpp:19
const void * p_x
Definition: moe_smoothquant_kernel.hpp:14
Definition: moe_smoothquant_kernel.hpp:32
static constexpr bool kTwoPass
Definition: moe_smoothquant_kernel.hpp:46
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: moe_smoothquant_kernel.hpp:37
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: moe_smoothquant_kernel.hpp:131
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: moe_smoothquant_kernel.hpp:39
static constexpr bool kPadM
Definition: moe_smoothquant_kernel.hpp:44
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: moe_smoothquant_kernel.hpp:75
static constexpr auto I0
Definition: moe_smoothquant_kernel.hpp:52
static constexpr bool kPadN
Definition: moe_smoothquant_kernel.hpp:45
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition: moe_smoothquant_kernel.hpp:40
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: moe_smoothquant_kernel.hpp:38
remove_cvref_t< Pipeline_ > Pipeline
Definition: moe_smoothquant_kernel.hpp:33
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: moe_smoothquant_kernel.hpp:36
static constexpr index_t Vector_N
Definition: moe_smoothquant_kernel.hpp:49
static constexpr CK_TILE_HOST auto BlockSize()
Definition: moe_smoothquant_kernel.hpp:95
static constexpr index_t Block_N
Definition: moe_smoothquant_kernel.hpp:43
static CK_TILE_HOST std::string GetName()
Definition: moe_smoothquant_kernel.hpp:110
typename Pipeline::Problem Problem
Definition: moe_smoothquant_kernel.hpp:34
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: moe_smoothquant_kernel.hpp:90
static constexpr index_t Repeat_N
Definition: moe_smoothquant_kernel.hpp:50
static constexpr index_t ThreadPerWarp_N
Definition: moe_smoothquant_kernel.hpp:48
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: moe_smoothquant_kernel.hpp:108
static constexpr auto I1
Definition: moe_smoothquant_kernel.hpp:53
static constexpr index_t Block_M
Definition: moe_smoothquant_kernel.hpp:42
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52