mx_flatmm_kernel.hpp

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

namespace ck_tile {

template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
{
    using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using FlatmmPipeline   = remove_cvref_t<MXFlatmmPipeline_>;
    using BlockGemmShape   = remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename FlatmmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename FlatmmPipeline::BLayout>;
    using ELayout          = remove_cvref_t<typename FlatmmPipeline::CLayout>;
    using DsLayout         = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    using DsDataType       = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    static constexpr index_t KernelBlockSize  = FlatmmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;

    using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
    // Below type is actually the accumulation data type - the output of the block GEMM.
    using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
    static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
    static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;

    static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
    static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;

    static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
    static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
    static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;

    static constexpr index_t NumDTensor = DsDataType::size();

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();
    static constexpr auto I4 = number<4>();
    static constexpr auto I5 = number<5>();

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");
    // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
        // clang-format on
    }

    template <class ScaleM, class ScaleN>
    CK_TILE_HOST static constexpr auto
    GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
    {
        if constexpr(UsePersistentKernel)
        {
            hipDeviceProp_t prop;
            int deviceId = 0; // default device

            constexpr int block_size = MXFlatmmKernel::BlockSize().x;
            int dync_smem_size       = 0;
            int maxActiveBlocksPerCU = 0;

            if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
                throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
                                         hipGetErrorName(hipGetLastError()));

            if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                   &maxActiveBlocksPerCU,
                   reinterpret_cast<void*>(
                       kentry<1, MXFlatmmKernel, remove_cvref_t<decltype(kargs)>>),
                   block_size,
                   dync_smem_size) != hipSuccess)
                throw std::runtime_error(
                    std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
                    hipGetErrorName(hipGetLastError()));

            const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
            const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);

            // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
            //           << ", persistent_block_size: " << persistent_block_size
            //           << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

            if(kargs.k_batch != 1)
                throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
            return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
        }
        else
        {
            return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
        }
    }
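
    // Worked example of the persistent grid sizing above (numbers are hypothetical, not taken
    // from this file): on a device with 304 CUs and an occupancy of 1 block per CU,
    // persistent_block_size = 304; a 4096x4096 output tiled into 256x256 blocks gives
    // total_work_tile_cnt = 256, so the launch uses min(304, 256) = 256 blocks. For larger
    // problems the grid stays capped at 304 and each block loops over several tiles
    // (see operator() below).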

    using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
    CK_TILE_DEVICE static auto
    MakeGemmTensorViews(const ADataType* a_ptr,
                        const BDataType* b_flat_ptr,
                        const std::array<const void*, NumDTensor>& ds_ptr,
                        EDataType* e_ptr,
                        const KernelArgs& kargs,
                        const SplitKBatchOffset& splitk_batch_offset)
    {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        constexpr index_t kKPerBlock    = FlatmmPipeline::kKPerBlock;
        constexpr index_t kNWarpTile    = BlockGemmShape::WarpTile::at(I1);
        constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
        const index_t kFlatKBlocks      = kargs.K / kKPerBlock;
        const index_t kFlatN            = kargs.N / kNWarpTile;
        const auto& b_flat_tensor_view  = [&]() {
            static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
                          "wrong! vector size for B tensor");
            auto&& naive_desc = make_naive_tensor_descriptor_packed(
                make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
            auto&& desc = transform_tensor_descriptor(
                naive_desc,
                make_tuple(make_pass_through_transform(kFlatN),
                           make_merge_transform(
                               make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
            return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
        }();

        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                using DiLayout   = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        make_tuple(kargs.M, kargs.N),
                        make_tuple(kargs.stride_Ds[i], 1),
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        number<1>{});
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        make_tuple(kargs.N, kargs.M),
                        make_tuple(kargs.stride_Ds[i], 1),
                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
                        number<1>{});
                }
            },
            number<NumDTensor>{});

        // TODO: enable vector write for C in ColMajor
        const auto& e_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    e_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_E, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    e_ptr,
                    make_tuple(kargs.N, kargs.M),
                    make_tuple(kargs.stride_E, 1),
                    number<1>{},
                    number<1>{});
            }
        }();

        auto scale_a = kargs.scale_m_ptr;
        auto scale_b = kargs.scale_n_ptr;

        static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
        const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
        const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
        const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);

        // A scale tensor view
        const auto& scale_a_tensor_view = [&]() {
            // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
            const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
                make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
            const auto scale_a_desc = transform_tensor_descriptor(
                scale_a_naive_desc,

            return make_tensor_view<address_space_enum::global>(
                reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
        }();

        // B scale tensor view
        const auto& scale_b_tensor_view = [&]() {
            const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
                make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
            const auto scale_b_desc = transform_tensor_descriptor(
                scale_b_naive_desc,

            return make_tensor_view<address_space_enum::global>(
                reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
        }();

        return make_tuple(a_tensor_view,
                          b_flat_tensor_view,
                          ds_tensor_view,
                          e_tensor_view,
                          scale_a_tensor_view,
                          scale_b_tensor_view);
    }
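
    // Scale-view bookkeeping, illustrated with hypothetical sizes: with BlockScaleSize = 32 there
    // is one e8m0 scale per 32 elements along K. For M = 4096, K = 8192, MXdlPack = 2,
    // KXdlPack = 2, MThreadPerXdl = 16 and hence KThreadPerXdl = 4 (placeholder values),
    // scale_packs_m = ceil(4096 / (2 * 16)) = 128 and scale_packs_k = 8192 / 32 / (2 * 4) = 32.
    // Packing each 2x2 group of e8m0 bytes into one int32_t is what turns the scale reads into
    // dword-wide loads.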

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
            }
        }();

        const auto& b_flat_tensor_view = views.at(I1);

        const auto& ds_pad_view = generate_tuple(
            [&](auto i) {
                const auto& d_tensor_view = views.at(I2);
                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return pad_tensor_view(d_tensor_view[i],
                }
                else
                {
                    return pad_tensor_view(d_tensor_view[i],
                }
            },
            number<NumDTensor>{});

        // TODO: enable vector write for C in ColMajor
        const auto& e_pad_view = [&]() {
            const auto& e_tensor_view = views.at(I3);
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(e_tensor_view,
            }
            else
            {
                return pad_tensor_view(e_tensor_view,
            }
        }();

        return make_tuple(
            a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view      = views.at(I0);
        const auto& b_flat_pad_view = views.at(I1);
        const auto& ds_pad_view     = views.at(I2);
        const auto& e_pad_view      = views.at(I3);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_m, 0});
            }
            else
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {0, i_m});
            }
        }();

        const auto& b_flat_block_window =
            make_tile_window(b_flat_pad_view,
                             {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});

        const auto ds_block_window = generate_tuple(
            [&](auto i) {
                using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_tile_window(ds_pad_view[i],
                                            make_tuple(number<TilePartitioner::MPerBlock>{},
                                                       number<TilePartitioner::NPerBlock>{}),
                                            {i_m, i_n});
                }
                else
                {
                    return make_tile_window(ds_pad_view[i],
                                            make_tuple(number<TilePartitioner::NPerBlock>{},
                                                       number<TilePartitioner::MPerBlock>{}),
                                            {i_n, i_m});
                }
            },
            number<NumDTensor>{});

        auto e_block_window = make_tile_window(
            e_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        static constexpr int BlockScaleSize = 32;

        auto scale_a_block_window = make_tile_window(
            views.at(I4),
            make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
                       number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
            {i_m / MXdlPack, 0});

        auto scale_b_block_window = make_tile_window(
            views.at(I5),
            make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
                       number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
            {i_n / NXdlPack, 0});

        return make_tuple(a_block_window,
                          b_flat_block_window,
                          ds_block_window,
                          e_block_window,
                          scale_a_block_window,
                          scale_b_block_window);
    }
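
    // Placement example for the scale windows above (hypothetical values): with MXdlPack = 2 and
    // a block starting at i_m = 256, the A-scale window origin is {256 / 2, 0} = {128, 0} in the
    // packed scale view; with KPerBlock = 256, BlockScaleSize = 32 and KXdlPack = 2 its K extent
    // is 256 / (32 * 2) = 4 packed scale columns per block.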

    template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
    CK_TILE_DEVICE static void
    RunFlatmm(const ADataType* a_ptr,
              const BDataType* b_flat_ptr,
              const std::array<const void*, NumDTensor>& ds_ptr,
              EDataType* e_ptr,
              void* smem_ptr_ping,
              void* smem_ptr_pong,
              const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
              const SplitKBatchOffset& splitk_batch_offset,
              const index_t block_idx_m,
              const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window       = gemm_tile_windows.at(I0);
        const auto& b_flat_block_window  = gemm_tile_windows.at(I1);
        const auto& d_block_window       = gemm_tile_windows.at(I2);
        const auto& scale_a_block_window = gemm_tile_windows.at(I4);
        const auto& scale_b_block_window = gemm_tile_windows.at(I5);

        static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
                          || ScaleM::GranularityMN == -1           // or ScaleA is disabled
                          || ScaleN::GranularityMN == -1,          // or ScaleB is disabled
                      "ScaleM and ScaleN should have the same GranularityK");
        constexpr bool DoEpiScale =
            (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
            (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0);   // per channel

        auto a_block_window_with_distr =
            ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
                                      a_block_window.get_window_lengths(),
                                      a_block_window.get_window_origin(),
                                      FlatmmPipeline::GetADramTileDistribution());
        const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
                                                    b_flat_block_window,
                                                    scale_a_block_window,
                                                    scale_b_block_window,
                                                    num_loop,
                                                    smem_ptr_ping,
                                                    smem_ptr_pong);

        // Run Epilogue Pipeline
        if constexpr(DoEpiScale)
        {
            auto& c_block_window = gemm_tile_windows.at(I3);
            EpiloguePipeline{}(c_block_window,
                               c_block_tile,
                               d_block_window,
                               smem_ptr_ping,
                               kargs.scale_m_ptr + block_idx_m,
                               kargs.scale_n_ptr + block_idx_n);
        }
        else if(UseDefaultScheduler || (get_warp_id() == 0))
        {
            // Run Epilogue Pipeline
            auto& c_block_window = gemm_tile_windows.at(I3);
            EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
        }
    }
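
    // Interpretation of the scale granularities used above: GranularityMN == -1 disables that
    // scale tensor; GranularityK == 0 means one scale per output row (per-token, ScaleM) or per
    // output column (per-channel, ScaleN), applied by the epilogue; otherwise both scale tensors
    // must share the same GranularityK and are consumed inside the main-loop pipeline via the
    // scale tile windows.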

    template <class ScaleM, class ScaleN>
    CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
                                   int partition_idx = blockIdx.x) const
    {
        int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

        do
        {
            const auto [iM, iN] =
                TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
            const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
            const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

            const SplitKBatchOffset splitk_batch_offset(kargs);
            // options
            const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
                                     splitk_batch_offset.a_k_split_offset / APackedSize;
            const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
                                          splitk_batch_offset.b_k_split_offset / BPackedSize;
            EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

            // allocate LDS
            __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
            __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];

            if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                           EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                           is_any_of<EDataType, fp16_t, bf16_t>::value))
            {
                constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
                RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
                                                          b_flat_ptr,
                                                          kargs.ds_ptr,
                                                          e_ptr,
                                                          smem_ptr_ping,
                                                          smem_ptr_pong,
                                                          kargs,
                                                          splitk_batch_offset,
                                                          i_m,
                                                          i_n);
            }
            else
            {
                static_assert(false,
                              "Unimplemented: atomic_add with odd vector size for fp16/bf16");
            }
            partition_idx += gridDim.x;
        } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
    }
};

} // namespace ck_tile
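
The persistent path of GridSize() caps the launch at the number of thread blocks the device can keep resident. The sketch below isolates that occupancy-capped grid calculation with plain HIP calls; the dummy kernel, the block size of 256, and the 4096x4096 problem with 256x256 tiles are placeholder assumptions for illustration, not values taken from this header.

// persistent_grid_sketch.cpp -- illustrative only; mirrors the logic of MXFlatmmKernel::GridSize()
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cstdio>

__global__ void dummy_kernel() {} // stand-in for the real kentry<...> instantiation

int main()
{
    hipDeviceProp_t prop;
    if(hipGetDeviceProperties(&prop, 0) != hipSuccess)
        return 1;

    int max_blocks_per_cu = 0;
    if(hipOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_cu,
                                                    reinterpret_cast<void*>(dummy_kernel),
                                                    /*blockSize=*/256,
                                                    /*dynSharedMemPerBlk=*/0) != hipSuccess)
        return 1;

    // Placeholder problem: 4096x4096 output with 256x256 tiles -> 16 * 16 = 256 work tiles.
    const int total_work_tile_cnt   = (4096 / 256) * (4096 / 256);
    const int persistent_block_size = prop.multiProcessorCount * max_blocks_per_cu;
    const int grid_x                = std::min(persistent_block_size, total_work_tile_cnt);

    std::printf("CUs=%d blocks/CU=%d tiles=%d -> grid.x=%d\n",
                prop.multiProcessorCount, max_blocks_per_cu, total_work_tile_cnt, grid_x);
    return 0;
}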