/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp Source File
mixed_prec_flatmm_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 
13 
14 namespace ck_tile {
15 
// NOTE(review): this is a Doxygen source listing. Every line carries its original
// header line number, and some source lines are absent from the listing
// (e.g. 12, 19, 21-30, 39, 80, 101, 105, 168, 439). Names used below but not
// declared in the visible lines -- QuantPackedSize (line 39), SplitKBatchOffset
// (line 101), Underlying, and the A/B/D/E data-type and layout aliases -- are
// declared on those missing lines or in the FlatmmKernel base (flatmm_kernel.hpp).
//
// F16xMXF4FlatmmKernel: a FlatmmKernel variant that, in addition to A, B, Ds and
// E, builds a view and tile window over per-B e8m0 scales (kargs.scale_n_ptr) and
// passes that window to FlatmmPipeline. The name suggests fp16 A x MXFP4 B --
// TODO confirm against FlatmmPipeline::ADataType / BDataType.
16 template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
17 struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
18 {
20 
// Workgroup size and persistent-launch mode are both taken from the pipeline.
31  static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
32  static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
33 
36  // Below type is actually accumulation data type - the output of block GEMM.
// NOTE(review): the alias the comment above describes (source line 37) is not
// shown in this listing.
38 
// Number of warp-tile-wide N slices whose B scales share one row of the flattened
// scale view (see FlatScaleK / FlatScaleN and the scale tile window below).
40  static constexpr int N_Pack = 2;
41 
42  static constexpr index_t NumDTensor = DsDataType::size();
43 
// Compile-time indices into the 5-tuples of views / windows built below, in the
// order: A, flat B, Ds, E, B scales.
44  static constexpr auto I0 = number<0>();
45  static constexpr auto I1 = number<1>();
46  static constexpr auto I2 = number<2>();
47  static constexpr auto I3 = number<3>();
48  static constexpr auto I4 = number<4>();
49 
50  static_assert(DsLayout::size() == DsDataType::size(),
51  "The size of DsLayout and DsDataType should be the same");
52  // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
53 
// Returns "mixed_prec_gemm", the A/B precision string and the pipeline name,
// joined with '_'.
54  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
55  {
56  // clang-format off
57  return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
58  // clang-format on
59  }
60 
// Host-side launch grid.
//  - Non-persistent: x = number of output tiles (TilePartitioner::GridSize(M, N)),
//    y = 1, z = k_batch.
//  - Persistent: x = min(multiProcessorCount * maxActiveBlocksPerCU, #output tiles),
//    y = 1, z = k_batch. operator() then advances partition_idx by gridDim.x to
//    cover the remaining tiles. Only k_batch == 1 is supported; this is checked
//    with assert, so the check is active in debug builds only.
61  template <class ScaleM, class ScaleN>
62  CK_TILE_HOST static constexpr auto
63  GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
64  {
65  if constexpr(UsePersistentKernel)
66  {
67  hipDeviceProp_t prop;
68  int deviceId = 0; // default device
69 
70  constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
71  int dync_smem_size = 0;
72  int maxActiveBlocksPerCU = 0;
73 
// NOTE(review): deviceId is hard-coded to 0 rather than the current device
// (hipGetDevice), and the hipError_t kept in `e` is never checked. If either
// query fails, `prop` may be uninitialised and maxActiveBlocksPerCU may stay 0,
// which would yield grid.x == 0. TODO: use the current device and check `e`.
74  [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
75 
// Occupancy of the kentry<...> instantiation for this kernel at block_size
// threads and no dynamic LDS (one template argument is on the omitted line 80).
76  e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
77  &maxActiveBlocksPerCU,
78  reinterpret_cast<void*>(
79  kentry<1,
81  FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
82  block_size,
83  dync_smem_size);
84 
85  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
86  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
87 
88  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
89  // << ", persistent_block_size: " << persistent_block_size
90  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
91 
92  assert(kargs.k_batch == 1);
93  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
94  }
95  else
96  {
97  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
98  }
99  }
100 
102 
// Builds the global-memory tensor views for one workgroup and returns them as
// make_tuple(A, flat B, Ds, E, B scales). WarpTileN below means
// BlockGemmShape::WarpTile::at(I1).
//  - A:     (M, splitted_k) if row-major, else (splitted_k, M); leading stride stride_A.
//  - B:     flattened to (kFlatN, kFlatK) with kFlatK = K * WarpTileN and
//           kFlatN = N * K / kFlatK (i.e. N / WarpTileN when the division is exact).
//  - Ds, E: (M, N) if row-major, else (N, M); column-major E uses vector size 1.
//  - B scales: kargs.scale_n_ptr reinterpreted as e8m0_t, flattened to
//           (N / N_Pack / WarpTileN, (K / GranularityK) * N_Pack * WarpTileN).
// NOTE(review): the line carrying the function name and the a_ptr parameter
// (source line 105) and the element count passed to generate_tuple (line 168)
// are missing from this listing.
103  template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
104  CK_TILE_DEVICE static auto
106  const BDataType* b_flat_ptr,
107  const std::array<const void*, NumDTensor>& ds_ptr,
108  EDataType* e_ptr,
109  const KernelArgs& kargs,
110  const SplitKBatchOffset& splitk_batch_offset)
111  {
112  const auto& a_tensor_view = [&]() {
113  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
114  {
115  return make_naive_tensor_view<address_space_enum::global>(
116  a_ptr,
117  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
118  make_tuple(kargs.stride_A, 1),
119  number<FlatmmPipeline::GetVectorSizeA()>{},
120  number<1>{});
121  }
122  else
123  {
124  return make_naive_tensor_view<address_space_enum::global>(
125  a_ptr,
126  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
127  make_tuple(kargs.stride_A, 1),
128  number<FlatmmPipeline::GetVectorSizeA()>{},
129  number<1>{});
130  }
131  }();
132 
// NOTE(review): kargs.N * kargs.K is evaluated before the division. If these
// are index_t (int32_t), the product can overflow for large N*K even though the
// result is small; computing N / BlockGemmShape::WarpTile::at(I1) would avoid
// that. Also, unlike A (splitted_k), this view spans the full kargs.K.
133  index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
134  index_t kFlatN = kargs.N * kargs.K / kFlatK;
135 
136  const auto& b_flat_tensor_view = [&]() {
137  return make_naive_tensor_view<address_space_enum::global>(
138  b_flat_ptr,
139  make_tuple(kFlatN, kFlatK),
140  make_tuple(kFlatK, 1),
141  number<FlatmmPipeline::GetVectorSizeB()>{},
142  number<1>{});
143  }();
144 
// One view per D tensor; the layout and element type are looked up per index.
145  const auto& ds_tensor_view = generate_tuple(
146  [&](auto i) {
147  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
148  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
149  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
150  {
151  return make_naive_tensor_view<address_space_enum::global>(
152  static_cast<const DDataType_*>(ds_ptr[i]),
153  make_tuple(kargs.M, kargs.N),
154  make_tuple(kargs.stride_Ds[i], 1),
155  number<EpiloguePipeline::GetVectorSizeD(i)>{},
156  number<1>{});
157  }
158  else
159  {
160  return make_naive_tensor_view<address_space_enum::global>(
161  static_cast<const DDataType_*>(ds_ptr[i]),
162  make_tuple(kargs.N, kargs.M),
163  make_tuple(kargs.stride_Ds[i], 1),
164  number<EpiloguePipeline::GetVectorSizeD(i)>{},
165  number<1>{});
166  }
167  },
169 
170  // TODO: enable vector write for C in ColMajor
171  const auto& e_tensor_view = [&]() {
172  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
173  {
174  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
175  e_ptr,
176  make_tuple(kargs.M, kargs.N),
177  make_tuple(kargs.stride_E, 1),
178  number<EpiloguePipeline::GetVectorSizeC()>{},
179  number<1>{});
180  }
181  else
182  {
183  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
184  e_ptr,
185  make_tuple(kargs.N, kargs.M),
186  make_tuple(kargs.stride_E, 1),
187  number<1>{},
188  number<1>{});
189  }
190  }();
191 
// B scale view. GranularityK comes from the ScaleN type; the vector length of
// 8 e8m0_t elements is hard-coded.
192  auto scale_n = kargs.scale_n_ptr;
193 
194  index_t FlatScaleK =
195  (kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
196  index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
197 
198  const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
199  reinterpret_cast<const e8m0_t*>(scale_n.ptr),
200  make_tuple(FlatScaleN, FlatScaleK),
201  make_tuple(FlatScaleK, 1),
202  number<8>{},
203  number<1>{});
204 
205  return make_tuple(
206  a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
207  }
208 
// Pads the A, Ds and E views with pad_tensor_view. The tile lengths and pad
// flags are on source lines omitted from this listing. The flat B view (I1) and
// the B scale view (I4) are passed through unpadded. The 5-tuple order is kept.
209  template <typename TensorView>
210  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
211  {
212  const auto& a_pad_view = [&]() {
213  const auto& a_tensor_view = views.at(I0);
214  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
215  {
216  return pad_tensor_view(a_tensor_view,
220  }
221  else
222  {
223  return pad_tensor_view(a_tensor_view,
227  }
228  }();
229 
230  const auto& b_flat_tensor_view = views.at(I1);
231 
232  const auto& ds_pad_view = generate_tuple(
233  [&](auto i) {
234  const auto& d_tensor_view = views.at(I2);
235  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
236  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
237  {
238  return pad_tensor_view(d_tensor_view[i],
242  }
243  else
244  {
245  return pad_tensor_view(d_tensor_view[i],
249  }
250  },
252 
253  // TODO vector write in for C in ColMajor
254  const auto& e_pad_view = [&]() {
255  const auto& e_tensor_view = views.at(I3);
256  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
257  {
258  return pad_tensor_view(e_tensor_view,
262  }
263  else
264  {
265  return pad_tensor_view(e_tensor_view,
269  }
270  }();
271 
272  return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
273  }
274 
// Places a tile window on each padded view at this workgroup's output-tile
// origin (i_m, i_n).
//  - A windows start at K offset 0.
//  - The flat B window starts at row i_n / WarpTile::at(I1), column 0.
//  - The scale window starts at row i_n / WarpTile::at(I1) / N_Pack, column 0.
// The window lengths are mostly on source lines omitted from this listing.
// NOTE(review): the last scale-window length, flatKPerWarp * N_Pack * 4 / 32,
// uses undocumented constants -- TODO explain or name them.
275  template <typename PadView>
276  CK_TILE_DEVICE static auto
277  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
278  {
279  const auto& a_pad_view = views.at(I0);
280  const auto& b_flat_pad_view = views.at(I1);
281  const auto& ds_pad_view = views.at(I2);
282  const auto& e_pad_view = views.at(I3);
283 
284  const auto& a_block_window = [&]() {
285  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
286  {
287  return make_tile_window(a_pad_view,
290  {i_m, 0});
291  }
292  else
293  {
294  return make_tile_window(a_pad_view,
297  {0, i_m});
298  }
299  }();
300 
301  const auto& b_flat_block_window =
302  make_tile_window(b_flat_pad_view,
305  {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
306 
307  const auto ds_block_window = generate_tuple(
308  [&](auto i) {
309  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
310  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
311  {
312  return make_tile_window(ds_pad_view[i],
315  {i_m, i_n});
316  }
317  else
318  {
319  return make_tile_window(ds_pad_view[i],
322  {i_n, i_m});
323  }
324  },
326 
327  auto e_block_window = make_tile_window(
328  e_pad_view,
330  {i_m, i_n});
331 
332  auto scale_block_window =
333  make_tile_window(views.at(I4),
335  number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
336  {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
337 
338  return make_tuple(a_block_window,
339  b_flat_block_window,
340  ds_block_window,
341  e_block_window,
342  scale_block_window);
343  }
344 
// Computes one output tile:
//  1. Builds tensor views, then pad views, then tile windows.
//  2. Re-creates the A window with FlatmmPipeline::GetADramTileDistribution().
//  3. Runs FlatmmPipeline on A, flat B and the B-scale window for
//     num_loop = TilePartitioner::GetLoopNum(splitted_k) iterations, using the
//     ping/pong LDS buffers.
//  4. Runs EpiloguePipeline, which reuses smem_ptr_ping.
// DoEpiScale is true when ScaleM or ScaleN is enabled (GranularityMN != -1) with
// GranularityK == 0. In that case the epilogue also receives
// kargs.scale_m_ptr + block_idx_m and kargs.scale_n_ptr + block_idx_n.
345  template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
346  CK_TILE_DEVICE static void
347  RunFlatmm(const ADataType* a_ptr,
348  const BDataType* b_flat_ptr,
349  const std::array<const void*, NumDTensor>& ds_ptr,
350  EDataType* e_ptr,
351  void* smem_ptr_ping,
352  void* smem_ptr_pong,
353  const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
354  const SplitKBatchOffset& splitk_batch_offset,
355  const index_t block_idx_m,
356  const index_t block_idx_n)
357  {
358  // Create Gemm tensor views, pad views and tile windows
359  const auto& gemm_tensor_views_tuple =
360  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
361  a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
362  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
363  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
364 
365  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
366 
367  // Run GEMM cooperatively by whole workgroup.
368  const auto& a_block_window = gemm_tile_windows.at(I0);
369  const auto& b_flat_block_window = gemm_tile_windows.at(I1);
370  const auto& d_block_window = gemm_tile_windows.at(I2);
371  const auto& scale_block_window = gemm_tile_windows.at(I4);
372 
373  static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
374  || ScaleM::GranularityMN == -1 // or ScaleA is disable
375  || ScaleN::GranularityMN == -1, // or ScaleB is disable
376  "ScaleM and ScaleN should have the same GranularityK");
377  constexpr bool DoEpiScale =
378  (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
379  (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
380 
381  auto a_block_window_with_distr =
382  ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
383  a_block_window.get_window_lengths(),
384  a_block_window.get_window_origin(),
385  FlatmmPipeline::GetADramTileDistribution());
386  const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
387  b_flat_block_window,
388  scale_block_window,
389  num_loop,
390  smem_ptr_ping,
391  smem_ptr_pong);
392 
393  // Run Epilogue Pipeline
394  if constexpr(DoEpiScale)
395  {
396  auto& c_block_window = gemm_tile_windows.at(I3);
397  EpiloguePipeline{}(c_block_window,
398  c_block_tile,
399  d_block_window,
400  smem_ptr_ping,
401  kargs.scale_m_ptr + block_idx_m,
402  kargs.scale_n_ptr + block_idx_n);
403  }
// NOTE(review): the DoEpiScale branch above runs the epilogue on every warp,
// while the branch below limits it to warp 0 when UseDefaultScheduler is
// false. Confirm that this asymmetry is intended.
404  else if(UseDefaultScheduler || (get_warp_id() == 0))
405  {
406  // Run Epilogue Pipeline
407  auto& c_block_window = gemm_tile_windows.at(I3);
408  EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
409  }
410  }
411 
// Kernel entry point.
//  - Maps partition_idx (default blockIdx.x) to an output tile with
//    TilePartitioner::GetOutputTileIndex and runs RunFlatmm for that tile.
//  - With UsePersistentKernel, the do/while advances partition_idx by gridDim.x
//    until all TilePartitioner::GridSize(M, N) tiles are done; otherwise the
//    body runs once.
//  - The ping/pong LDS buffers are __shared__ arrays sized by Underlying.
//  - RunFlatmm is compiled out when the epilogue uses atomic_add with an odd C
//    vector size AND a further condition on source line 439 (not shown here).
//  - The UseDefaultScheduler argument is (FlatmmPipeline::NumWaveGroups == 1).
412  template <class ScaleM, class ScaleN>
413  CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
414  int partition_idx = blockIdx.x) const
415  {
416  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
417 
418  do
419  {
420  const auto [iM, iN] =
421  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
422  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
423  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
424 
425  const SplitKBatchOffset splitk_batch_offset(kargs);
426  // options
427  const ADataType* a_ptr =
428  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
// NOTE(review): the B split-K offset is divided by QuantPackedSize, presumably
// the number of B values packed into one BDataType element -- TODO confirm.
429  const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
430  splitk_batch_offset.b_k_split_offset / QuantPackedSize;
431  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
432 
433  // allocate LDS
434  __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
435  __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
436 
437  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
438  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
440  {
441  constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
442  RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
443  b_flat_ptr,
444  kargs.ds_ptr,
445  e_ptr,
446  smem_ptr_ping,
447  smem_ptr_pong,
448  kargs,
449  splitk_batch_offset,
450  i_m,
451  i_n);
452  }
453  partition_idx += gridDim.x;
454  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
455  }
456 };
457 
458 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:22
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__device__ X atomic_add(X *p_dst, const X &x)
Definition: mixed_prec_flatmm_kernel.hpp:18
static constexpr int N_Pack
Definition: mixed_prec_flatmm_kernel.hpp:40
static constexpr auto I4
Definition: mixed_prec_flatmm_kernel.hpp:48
static constexpr index_t KernelBlockSize
Definition: mixed_prec_flatmm_kernel.hpp:31
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: mixed_prec_flatmm_kernel.hpp:277
static CK_TILE_HOST const std::string GetName()
Definition: mixed_prec_flatmm_kernel.hpp:54
static constexpr auto I0
Definition: mixed_prec_flatmm_kernel.hpp:44
static constexpr auto I1
Definition: mixed_prec_flatmm_kernel.hpp:45
static constexpr auto I2
Definition: mixed_prec_flatmm_kernel.hpp:46
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition: mixed_prec_flatmm_kernel.hpp:413
static constexpr int QuantPackedSize
Definition: mixed_prec_flatmm_kernel.hpp:39
static constexpr bool UsePersistentKernel
Definition: mixed_prec_flatmm_kernel.hpp:32
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition: mixed_prec_flatmm_kernel.hpp:63
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition: mixed_prec_flatmm_kernel.hpp:101
static constexpr auto I3
Definition: mixed_prec_flatmm_kernel.hpp:47
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition: mixed_prec_flatmm_kernel.hpp:347
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: mixed_prec_flatmm_kernel.hpp:105
static constexpr index_t NumDTensor
Definition: mixed_prec_flatmm_kernel.hpp:42
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: mixed_prec_flatmm_kernel.hpp:210
Definition: flatmm_kernel.hpp:362
Definition: flatmm_kernel.hpp:229
Definition: flatmm_kernel.hpp:249
static constexpr CK_TILE_HOST auto BlockSize()
Definition: flatmm_kernel.hpp:330
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition: flatmm_kernel.hpp:253
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: flatmm_kernel.hpp:250
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: flatmm_kernel.hpp:258
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: flatmm_kernel.hpp:259
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: flatmm_kernel.hpp:266
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: flatmm_kernel.hpp:254
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPingSize()
Definition: flatmm_kernel.hpp:352
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition: flatmm_kernel.hpp:251
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition: flatmm_kernel.hpp:257
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition: flatmm_kernel.hpp:263
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition: flatmm_kernel.hpp:256
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition: flatmm_kernel.hpp:255
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition: flatmm_kernel.hpp:264
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPongSize()
Definition: flatmm_kernel.hpp:356
Definition: integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:27
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: sequence.hpp:49