/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp Source File
smoothquant_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
11 // host side args
13 {
14  const void* p_x; // [m ,n], input, fp16/bf16
15  const void* p_smscale; // [1, n], input, columnwise scale, fp32
16 
17  void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
18  void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
19 
22  index_t x_stride; // input row_stride
23  index_t y_stride; // output row_stride
24 };
25 
26 // TODO: Extract some type to wrapper class
27 template <typename Pipeline_>
29 {
31  using Problem = typename Pipeline::Problem;
32 
38 
39  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
40  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
41  static constexpr bool kPadM = false; // always no need to pad along M
42  static constexpr bool kPadN = Problem::kPadN;
43  static constexpr bool kTwoPass = Problem::kTwoPass;
44 
45  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
46  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
47  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
48  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
49 
50  static constexpr auto I0 = number<0>{};
51  static constexpr auto I1 = number<1>{};
52 
53  struct Kargs
54  {
55  const void* p_x;
56  const void* p_smscale;
57 
58  void* p_yscale;
59  void* p_qy;
60 
63  index_t x_stride; // input row_stride
64  index_t y_stride; // out row_stride
65  };
67 
68  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
69  {
70  return Kargs{hargs.p_x,
71  hargs.p_smscale,
72  hargs.p_yscale,
73  hargs.p_qy,
74  hargs.m,
75  hargs.n,
76  hargs.x_stride,
77  hargs.y_stride};
78  }
79 
80  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
81  {
82  return dim3(integer_divide_ceil(hargs.m, Block_M));
83  }
84 
85  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
86 
87  // clang-format off
88  template <typename T> struct t2s;
89  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
90  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
91  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
92  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
93  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
94  // clang-format on
95 
96  // in byte
97  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
98 
99  CK_TILE_HOST static std::string GetName()
100  {
101  // clang-format off
102  using S_ = typename Problem::BlockShape;
103  auto surfix = [&] () {
104  std::string n;
105  if (kPadN) n += "_pn";
106  if (kTwoPass) n += "_2p";
107  return n; }();
108 
109  #define _SS_ std::string
110  #define _TS_ std::to_string
111  return _SS_("smoothquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
112  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
113  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
114  _SS_(Pipeline::name) + surfix;
115  #undef _SS_
116  #undef _TS_
117  // clang-format on
118  }
119 
120  CK_TILE_DEVICE void operator()(Kargs kargs) const
121  {
122  const auto iM = get_block_id() * Block_M;
123 
124  const auto x_window = [&]() {
125  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
126  static_cast<const XDataType*>(kargs.p_x),
127  make_tuple(kargs.m, kargs.n),
128  make_tuple(kargs.x_stride, 1),
130  number<1>{});
131 
132  const auto tmp2_ = pad_tensor_view(
134  return make_tile_window(
135  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
136  }();
137 
138  const auto smscale_window = [&]() {
139  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
140  static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
141  make_tuple(kargs.n),
142  make_tuple(1),
144  number<1>{});
145 
146  const auto tmp2_ =
148 
149  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
150  }();
151 
152  auto yscale_window = [&]() {
153  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
154  static_cast<YScaleDataType*>(kargs.p_yscale),
155  make_tuple(kargs.m),
156  make_tuple(1),
157  number<1>{});
158 
159  const auto tmp2_ =
161 
162  return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
163  }();
164 
165  auto qy_window = [&]() {
166  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
167  static_cast<QYDataType*>(kargs.p_qy),
168  make_tuple(kargs.m, kargs.n),
169  make_tuple(kargs.y_stride, 1),
171  number<1>{});
172 
173  auto tmp2_ = pad_tensor_view(
175  return make_tile_window(
176  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
177  }();
178 
179  __shared__ char smem[GetSmemSize()];
180 
181  Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
182  }
183 };
184 
185 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
#define _TS_
Definition: smoothquant_kernel.hpp:110
#define _SS_
Definition: smoothquant_kernel.hpp:109
struct Kargs
Definition: smoothquant_kernel.hpp:54
index_t m
Definition: smoothquant_kernel.hpp:61
index_t n
Definition: smoothquant_kernel.hpp:62
index_t x_stride
Definition: smoothquant_kernel.hpp:63
const void * p_x
Definition: smoothquant_kernel.hpp:55
void * p_qy
Definition: smoothquant_kernel.hpp:59
const void * p_smscale
Definition: smoothquant_kernel.hpp:56
void * p_yscale
Definition: smoothquant_kernel.hpp:58
index_t y_stride
Definition: smoothquant_kernel.hpp:64
template <typename T> struct t2s
Definition: smoothquant_kernel.hpp:88
Definition: smoothquant_kernel.hpp:13
index_t y_stride
Definition: smoothquant_kernel.hpp:23
const void * p_smscale
Definition: smoothquant_kernel.hpp:15
void * p_qy
Definition: smoothquant_kernel.hpp:18
index_t x_stride
Definition: smoothquant_kernel.hpp:22
void * p_yscale
Definition: smoothquant_kernel.hpp:17
index_t m
Definition: smoothquant_kernel.hpp:20
index_t n
Definition: smoothquant_kernel.hpp:21
const void * p_x
Definition: smoothquant_kernel.hpp:14
Definition: smoothquant_kernel.hpp:29
static constexpr index_t Block_M
Definition: smoothquant_kernel.hpp:39
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: smoothquant_kernel.hpp:68
static constexpr CK_TILE_HOST auto BlockSize()
Definition: smoothquant_kernel.hpp:85
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: smoothquant_kernel.hpp:80
static constexpr index_t kBlockSize
Definition: smoothquant_kernel.hpp:48
remove_cvref_t< Pipeline_ > Pipeline
Definition: smoothquant_kernel.hpp:30
static constexpr auto I1
Definition: smoothquant_kernel.hpp:51
static constexpr index_t Repeat_N
Definition: smoothquant_kernel.hpp:47
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: smoothquant_kernel.hpp:34
static constexpr bool kPadM
Definition: smoothquant_kernel.hpp:41
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: smoothquant_kernel.hpp:97
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: smoothquant_kernel.hpp:120
static CK_TILE_HOST std::string GetName()
Definition: smoothquant_kernel.hpp:99
static constexpr index_t ThreadPerWarp_N
Definition: smoothquant_kernel.hpp:45
static constexpr bool kTwoPass
Definition: smoothquant_kernel.hpp:43
static constexpr auto I0
Definition: smoothquant_kernel.hpp:50
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: smoothquant_kernel.hpp:33
static constexpr index_t Vector_N
Definition: smoothquant_kernel.hpp:46
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition: smoothquant_kernel.hpp:37
static constexpr bool kPadN
Definition: smoothquant_kernel.hpp:42
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: smoothquant_kernel.hpp:35
static constexpr index_t Block_N
Definition: smoothquant_kernel.hpp:40
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: smoothquant_kernel.hpp:36
typename Pipeline::Problem Problem
Definition: smoothquant_kernel.hpp:31
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49