/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp Source File
smoothquant_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
11 // host-side arguments for the smoothquant kernel
// NOTE(review): this rendered listing skips source lines 12, 20 and 21. The cross-reference
// index at the end of the page places `index_t m` at line 20 and `index_t n` at line 21 of
// this struct; the line carrying the struct's name is not shown here.
13 {
14  const void* p_x; // [m, n], input, fp16/bf16
15  const void* p_smscale; // [1, n], input, columnwise scale, fp32
16 
17  void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
18  void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
19 
22  index_t x_stride; // input row_stride
23  index_t y_stride; // output row_stride
24 };
25 
26 // TODO: Extract some type to wrapper class
27 template <typename Pipeline_>
29 {
// Problem description supplied by the pipeline; the constants below are read from it.
31  using Problem = typename Pipeline::Problem;
32 
// NOTE(review): listing lines 30 and 33-37 are not shown. Per the cross-reference index they
// declare the aliases Pipeline, XDataType, SmoothScaleDataType, ComputeDataType,
// YScaleDataType and QYDataType.
38 
// Tile extents handled by one workgroup along M and N.
39  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
40  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
41  static constexpr bool kPadM = false; // padding along M is never needed
42  static constexpr bool kPadN = Problem::kPadN;
43  static constexpr bool kTwoPass = Problem::kTwoPass;
44 
// Block-shape parameters along N, re-exported from Problem::BlockShape.
45  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
46  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
47  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
48 
// Compile-time integer constants 0 and 1.
49  static constexpr auto I0 = number<0>{};
50  static constexpr auto I1 = number<1>{};
51 
52  struct Kargs
53  {
54  const void* p_x;
55  const void* p_smscale;
56 
57  void* p_yscale;
58  void* p_qy;
59 
62  index_t x_stride; // input row_stride
63  index_t y_stride; // out row_stride
64  };
66 
67  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
68  {
69  return Kargs{hargs.p_x,
70  hargs.p_smscale,
71  hargs.p_yscale,
72  hargs.p_qy,
73  hargs.m,
74  hargs.n,
75  hargs.x_stride,
76  hargs.y_stride};
77  }
78 
79  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
80  {
81  return dim3(integer_divide_ceil(hargs.m, Block_M));
82  }
83 
84  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
85 
86  // clang-format off
87  template <typename T> struct t2s;
88  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
89  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
90  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
91  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
92  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
93  // clang-format on
94 
95  // in byte
96  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
97 
98  CK_TILE_HOST static std::string GetName()
99  {
100  // clang-format off
101  using S_ = typename Problem::BlockShape;
102  auto surfix = [&] () {
103  std::string n;
104  if (kPadN) n += "_pn";
105  if (kTwoPass) n += "_2p";
106  return n; }();
107 
108  #define _SS_ std::string
109  #define _TS_ std::to_string
110  return _SS_("smoothquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
111  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
112  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
113  _SS_(Pipeline::name) + surfix;
114  #undef _SS_
115  #undef _TS_
116  // clang-format on
117  }
118 
// Device entry point. The workgroup starts at row iM = block id * Block_M, builds tile
// windows over the four global-memory buffers in kargs, and hands them to the pipeline
// together with kargs.n and a shared-memory scratch buffer.
// NOTE(review): this rendered listing omits source lines 128, 132, 142, 146, 159, 169 and
// 173 (gaps in the numbering below), so some view arguments and the pad_tensor_view
// arguments are not visible here. Consult the actual header before editing this body.
119  CK_TILE_DEVICE void operator()(Kargs kargs) const
120  {
// first row handled by this workgroup
121  const auto iM = get_block_id() * Block_M;
122 
// Input x: 2-D view with lengths (m, n) and strides (x_stride, 1);
// window of Block_M x Block_N with origin {iM, 0}.
123  const auto x_window = [&]() {
124  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
125  static_cast<const XDataType*>(kargs.p_x),
126  make_tuple(kargs.m, kargs.n),
127  make_tuple(kargs.x_stride, 1),
129  number<1>{});
130 
131  const auto tmp2_ = pad_tensor_view(
133  return make_tile_window(
134  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
135  }();
136 
// Smooth scale: 1-D view of length n with stride 1; window of Block_N with origin {0}.
137  const auto smscale_window = [&]() {
138  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
139  static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
140  make_tuple(kargs.n),
141  make_tuple(1),
143  number<1>{});
144 
145  const auto tmp2_ =
147 
148  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
149  }();
150 
// Output y scale: 1-D view of length m with stride 1; window of Block_M with origin {iM}.
151  auto yscale_window = [&]() {
152  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
153  static_cast<YScaleDataType*>(kargs.p_yscale),
154  make_tuple(kargs.m),
155  make_tuple(1),
156  number<1>{});
157 
158  const auto tmp2_ =
160 
161  return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
162  }();
163 
// Output qy: 2-D view with lengths (m, n) and strides (y_stride, 1);
// window of Block_M x Block_N with origin {iM, 0}.
164  auto qy_window = [&]() {
165  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
166  static_cast<QYDataType*>(kargs.p_qy),
167  make_tuple(kargs.m, kargs.n),
168  make_tuple(kargs.y_stride, 1),
170  number<1>{});
171 
172  auto tmp2_ = pad_tensor_view(
174  return make_tile_window(
175  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
176  }();
177 
// shared-memory scratch buffer of GetSmemSize() bytes, passed to the pipeline
178  __shared__ char smem[GetSmemSize()];
179 
180  Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
181  }
182 };
183 
184 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE index_t get_block_id()
Definition: arch.hpp:78
#define _TS_
#define _SS_
Definition: smoothquant_kernel.hpp:53
index_t m
Definition: smoothquant_kernel.hpp:60
index_t n
Definition: smoothquant_kernel.hpp:61
index_t x_stride
Definition: smoothquant_kernel.hpp:62
const void * p_x
Definition: smoothquant_kernel.hpp:54
void * p_qy
Definition: smoothquant_kernel.hpp:58
const void * p_smscale
Definition: smoothquant_kernel.hpp:55
void * p_yscale
Definition: smoothquant_kernel.hpp:57
index_t y_stride
Definition: smoothquant_kernel.hpp:63
Definition: smoothquant_kernel.hpp:87
Definition: smoothquant_kernel.hpp:13
index_t y_stride
Definition: smoothquant_kernel.hpp:23
const void * p_smscale
Definition: smoothquant_kernel.hpp:15
void * p_qy
Definition: smoothquant_kernel.hpp:18
index_t x_stride
Definition: smoothquant_kernel.hpp:22
void * p_yscale
Definition: smoothquant_kernel.hpp:17
index_t m
Definition: smoothquant_kernel.hpp:20
index_t n
Definition: smoothquant_kernel.hpp:21
const void * p_x
Definition: smoothquant_kernel.hpp:14
Definition: smoothquant_kernel.hpp:29
static constexpr index_t Block_M
Definition: smoothquant_kernel.hpp:39
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: smoothquant_kernel.hpp:67
static constexpr CK_TILE_HOST auto BlockSize()
Definition: smoothquant_kernel.hpp:84
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: smoothquant_kernel.hpp:79
remove_cvref_t< Pipeline_ > Pipeline
Definition: smoothquant_kernel.hpp:30
static constexpr auto I1
Definition: smoothquant_kernel.hpp:50
static constexpr index_t Repeat_N
Definition: smoothquant_kernel.hpp:47
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: smoothquant_kernel.hpp:34
static constexpr bool kPadM
Definition: smoothquant_kernel.hpp:41
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: smoothquant_kernel.hpp:96
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: smoothquant_kernel.hpp:119
static CK_TILE_HOST std::string GetName()
Definition: smoothquant_kernel.hpp:98
static constexpr index_t ThreadPerWarp_N
Definition: smoothquant_kernel.hpp:45
static constexpr bool kTwoPass
Definition: smoothquant_kernel.hpp:43
static constexpr auto I0
Definition: smoothquant_kernel.hpp:49
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: smoothquant_kernel.hpp:33
static constexpr index_t Vector_N
Definition: smoothquant_kernel.hpp:46
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition: smoothquant_kernel.hpp:37
static constexpr bool kPadN
Definition: smoothquant_kernel.hpp:42
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: smoothquant_kernel.hpp:35
static constexpr index_t Block_N
Definition: smoothquant_kernel.hpp:40
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: smoothquant_kernel.hpp:36
typename Pipeline::Problem Problem
Definition: smoothquant_kernel.hpp:31
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52