Composable Kernel: include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp Source File

flatmm_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
12 
13 namespace ck_tile {
14 
15 template <index_t NumDTensor = 0>
16 struct FlatmmHostArgs
17 {
18  CK_TILE_HOST FlatmmHostArgs() = default;
19  CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
20  const void* b_ptr_,
21  const std::array<const void*, NumDTensor>& ds_ptr_,
22  void* e_ptr_,
23  index_t k_batch_,
24  index_t M_,
25  index_t N_,
26  index_t K_,
27  index_t stride_A_,
28  index_t stride_B_,
29  const std::array<index_t, NumDTensor>& stride_Ds_,
30  index_t stride_E_)
31  : a_ptr(a_ptr_),
32  b_ptr(b_ptr_),
33  ds_ptr(ds_ptr_),
34  e_ptr(e_ptr_),
35  M(M_),
36  N(N_),
37  K(K_),
38  stride_A(stride_A_),
39  stride_B(stride_B_),
40  stride_Ds(stride_Ds_),
41  stride_E(stride_E_),
42  k_batch(k_batch_)
43  {
44  }
45 
46  const void* a_ptr;
47  const void* b_ptr;
48  const std::array<const void*, NumDTensor> ds_ptr;
49  union
50  {
51  void* e_ptr;
52  void* c_ptr;
53  };
54  index_t M;
55  index_t N;
56  index_t K;
57  index_t stride_A;
58  index_t stride_B;
59  const std::array<index_t, NumDTensor> stride_Ds;
60  union
61  {
62  index_t stride_E;
63  index_t stride_C;
64  };
65 
66  index_t k_batch;
67 };
68 
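As an illustration, the host-side argument struct above can be populated as follows. This is a minimal sketch: the device pointers d_a, d_b, d_e and all sizes/strides are hypothetical, and B is assumed to already be in the shuffled "flat" layout the flatmm pipeline expects.

// Hypothetical example with NumDTensor == 0 (no auxiliary D tensors).
ck_tile::FlatmmHostArgs<0> host_args(d_a,   // A:  M x K, row-major
                                     d_b,   // B:  pre-shuffled flat layout
                                     {},    // ds_ptr (empty)
                                     d_e,   // E:  M x N output
                                     1,     // k_batch (no split-K)
                                     1024,  // M
                                     2048,  // N
                                     4096,  // K
                                     4096,  // stride_A
                                     2048,  // stride_B
                                     {},    // stride_Ds (empty)
                                     2048); // stride_E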
69 template <index_t NumDTensor = 0>
70 struct FlatmmKernelArgs
71 {
72  const void* a_ptr;
73  // const void* b_shuffle_ptr;
74  const void* b_ptr;
75  const std::array<const void*, NumDTensor> ds_ptr;
76  void* e_ptr;
77  index_t M;
78  index_t N;
79  index_t K;
80  index_t stride_A;
81  index_t stride_B;
82  std::array<index_t, NumDTensor> stride_Ds;
83  index_t stride_E;
84  index_t k_batch;
85 };
86 
87 template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
88 struct FlatmmKernel
89 {
90  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
91  using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
93  using BlockGemmShape = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
94  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
95  using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
96  using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
97  using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
98  using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
99  using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
100  static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
101 
102  using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
103  using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
104  // Below type is actually accumulation data type - the output of block GEMM.
105  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
106 
107  static constexpr index_t NumDTensor = DsDataType::size();
108 
109  static constexpr auto I0 = number<0>();
110  static constexpr auto I1 = number<1>();
111  static constexpr auto I2 = number<2>();
112  static constexpr auto I3 = number<3>();
113 
114  static_assert(DsLayout::size() == DsDataType::size(),
115  "The size of DsLayout and DsDataType should be the same");
116  using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
117 
118  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
119  {
120  // clang-format off
121  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
122  // clang-format on
123  }
124 
125  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
126  {
127  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
128  }
129 
130  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
131 
132  CK_TILE_HOST static constexpr KernelArgs
133  MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
134  {
135  return KernelArgs{hostArgs.a_ptr,
136  hostArgs.b_ptr,
137  hostArgs.ds_ptr,
138  hostArgs.e_ptr,
139  hostArgs.M,
140  hostArgs.N,
141  hostArgs.K,
142  hostArgs.stride_A,
143  hostArgs.stride_B,
144  hostArgs.stride_Ds,
145  hostArgs.stride_E,
146  hostArgs.k_batch};
147  }
148 
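For orientation, a typical host-side setup built on the helpers above might look like the sketch below. It assumes `Kernel` is a fully specialized FlatmmKernel and `host_args` a populated FlatmmHostArgs; the actual device launch is left to whatever HIP launch wrapper the application uses.

// Minimal host-side sketch (hypothetical helper, not part of this header).
// Requires <stdexcept> and this flatmm_kernel.hpp header.
template <typename Kernel, ck_tile::index_t NumDTensor>
void prepare_and_check_flatmm(const ck_tile::FlatmmHostArgs<NumDTensor>& host_args)
{
    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("flatmm: unsupported argument combination");
    }
    const auto grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    const auto blocks = Kernel::BlockSize();
    // Launch Kernel{}(kargs) over grids x blocks with the application's HIP launch
    // wrapper (not shown); each block needs Kernel::GetSmemSize() bytes of LDS.
    (void)grids;
    (void)blocks;
}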
149  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
150  {
151  return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
152  }
153 
154  struct SplitKBatchOffset
155  {
156  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
157  {
158  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
159  const index_t K_t = kargs.k_batch * K1;
160  const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
161 
162  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
163  {
164  a_k_split_offset = k_id * KRead;
165  }
166  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
167  {
168  a_k_split_offset = k_id * KRead * kargs.stride_A;
169  }
170 
171  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
172  {
173  b_k_split_offset = k_id * KRead * kargs.stride_B;
174  }
175  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
176  {
177  b_k_split_offset = k_id * KRead;
178  }
179 
180  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
181  {
182  splitted_k = KRead;
183  }
184  else
185  {
186  splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
187  }
188  }
189 
190  index_t a_k_split_offset;
191  index_t b_k_split_offset;
192  index_t splitted_k;
193  };
194 
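To make the split-K bookkeeping above concrete, here is a small worked example; the warp-tile K extent (K1 = 32) and the problem size are assumptions for illustration only.

// Assume K = 1000, k_batch = 3, K1 = 32 (warp-tile K).
// K_t   = k_batch * K1                = 3 * 32 = 96
// KRead = ((K + K_t - 1) / K_t) * K1  = (1095 / 96) * 32 = 11 * 32 = 352
// Work-groups with k_id 0 and 1 read splitted_k = 352 elements of K each;
// the last one (k_id == 2) reads splitted_k = 1000 - 352 * 2 = 296.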
195  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
196  {
197  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
198  is_any_of<EDataType, fp16_t, bf16_t>::value)
199  {
200  if(kargs.k_batch != 1)
201  {
202  std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
203  return false;
204  }
205  }
206 
207  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
208  {
209  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
210  {
211  std::cerr << "Can't support K that is not a multiple of KPerBlock"
212  " without padding!"
213  << std::endl;
214  return false;
215  }
216  if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
217  {
218  std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
219  return false;
220  }
221  }
222  else
223  {
224  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
225  {
226  std::cerr << "Can't support M that is not a multiple of MPerBlock"
227  " without padding!"
228  << std::endl;
229  return false;
230  }
231  if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
232  {
233  std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
234  return false;
235  }
236  }
237 
238  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
239  {
240  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
241  {
242  std::cerr << "Can't support N that is not a multiple of NPerBlock"
243  " without padding!"
244  << std::endl;
245  return false;
246  }
247  if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
248  {
249  std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
250  return false;
251  }
252  }
253  else
254  {
255  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
256  {
257  std::cerr << "Can't support K that is not a multiple of KPerBlock"
258  " without padding!"
259  << std::endl;
260  return false;
261  }
262  if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
263  {
264  std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
265  return false;
266  }
267  }
268 
269  bool DTensorIsValid = {true};
270  static_for<0, NumDTensor, 1>{}([&](auto index) {
271  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
272  if(std::is_same_v<DiLayout, ELayout> == false)
273  {
274  DTensorIsValid = false;
275  }
276  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
277  {
278  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
279  {
280  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
281  "NPerBlock without padding!");
282  DTensorIsValid = false;
283  }
284  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
285  {
286  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
287  DTensorIsValid = false;
288  }
289  }
290  else
291  {
292  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
293  {
294  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
295  "MPerBlock without padding!");
296 
297  DTensorIsValid = false;
298  }
299  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
300  {
301  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
302  DTensorIsValid = false;
303  }
304  }
305  });
306 
307  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
308  {
309  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
310  {
311  std::cerr << "Can't support N that is not a multiple of NPerBlock"
312  " without padding!"
313  << std::endl;
314  return false;
315  }
316  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
317  {
318  std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
319  return false;
320  }
321  }
322  else
323  {
324  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
325  {
326  std::cerr << "Can't support M that is not a multiple of MPerBlock"
327  " without padding!"
328  << std::endl;
329  return false;
330  }
331  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
332  {
333  std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
334  return false;
335  }
336  }
337  return DTensorIsValid;
338  }
339 
340  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
341  CK_TILE_DEVICE static auto
342  MakeGemmTensorViews(const ADataType* a_ptr,
343  const BDataType* b_flat_ptr,
344  const std::array<const void*, NumDTensor>& ds_ptr,
345  EDataType* e_ptr,
346  const KernelArgs& kargs,
347  const SplitKBatchOffset& splitk_batch_offset)
348  {
349  const auto& a_tensor_view = [&]() {
350  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
351  {
352  return make_naive_tensor_view<address_space_enum::global>(
353  a_ptr,
354  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
355  make_tuple(kargs.stride_A, 1),
356  number<FlatmmPipeline::GetVectorSizeA()>{},
357  number<1>{});
358  }
359  else
360  {
361  return make_naive_tensor_view<address_space_enum::global>(
362  a_ptr,
363  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
364  make_tuple(kargs.stride_A, 1),
365  number<FlatmmPipeline::GetVectorSizeA()>{},
366  number<1>{});
367  }
368  }();
369 
370  index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
371  BlockGemmShape::WarpTile::at(number<2>{}));
372  index_t kFlatN = kargs.N * kargs.K / kFlatK;
373  const auto& b_flat_tensor_view = [&]() {
374  return make_naive_tensor_view<address_space_enum::global>(
375  b_flat_ptr,
376  make_tuple(kFlatN, kFlatK),
377  make_tuple(kFlatK, 1),
378  number<FlatmmPipeline::GetVectorSizeB()>{},
379  number<1>{});
380  }();
381 
382  const auto& ds_tensor_view = generate_tuple(
383  [&](auto i) {
384  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
385  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
386  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
387  {
388  return make_naive_tensor_view<address_space_enum::global>(
389  static_cast<const DDataType_*>(ds_ptr[i]),
390  make_tuple(kargs.M, kargs.N),
391  make_tuple(kargs.stride_Ds[i], 1),
392  number<EpiloguePipeline::GetVectorSizeD(i)>{},
393  number<1>{});
394  }
395  else
396  {
397  return make_naive_tensor_view<address_space_enum::global>(
398  static_cast<const DDataType_*>(ds_ptr[i]),
399  make_tuple(kargs.N, kargs.M),
400  make_tuple(kargs.stride_Ds[i], 1),
401  number<EpiloguePipeline::GetVectorSizeD(i)>{},
402  number<1>{});
403  }
404  },
405  number<NumDTensor>{});
406 
407  // TODO: enable vector write for C in ColMajor
408  const auto& e_tensor_view = [&]() {
409  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
410  {
411  return make_naive_tensor_view<address_space_enum::global>(
412  e_ptr,
413  make_tuple(kargs.M, kargs.N),
414  make_tuple(kargs.stride_E, 1),
415  number<EpiloguePipeline::GetVectorSizeC()>{},
416  number<1>{});
417  }
418  else
419  {
420  return make_naive_tensor_view<address_space_enum::global>(
421  e_ptr,
422  make_tuple(kargs.N, kargs.M),
423  make_tuple(kargs.stride_E, 1),
424  number<1>{},
425  number<1>{});
426  }
427  }();
428 
429  return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
430  }
431 
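The flat-B view above reinterprets the shuffled B tensor as a 2-D matrix of shape kFlatN x kFlatK. A hedged numeric example, with flatKPerWarp = 1024 and a warp-tile K of 32 assumed purely for illustration: for K = 4096 (no split-K, so splitted_k == K) and N = 2048,

// kFlatK = flatKPerWarp * (splitted_k / WarpTileK) = 1024 * (4096 / 32) = 131072
// kFlatN = N * K / kFlatK = 2048 * 4096 / 131072   = 64
// i.e. the same N * K elements of B, viewed as a 64 x 131072 row-major matrix.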
432  template <typename TensorView>
433  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
434  {
435  const auto& a_pad_view = [&]() {
436  const auto& a_tensor_view = views.at(I0);
437  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
438  {
439  return pad_tensor_view(a_tensor_view,
443  }
444  else
445  {
446  return pad_tensor_view(a_tensor_view,
450  }
451  }();
452 
453  const auto& b_flat_tensor_view = views.at(I1);
454 
455  const auto& ds_pad_view = generate_tuple(
456  [&](auto i) {
457  const auto& d_tensor_view = views.at(I2);
458  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
459  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
460  {
461  return pad_tensor_view(d_tensor_view[i],
465  }
466  else
467  {
468  return pad_tensor_view(d_tensor_view[i],
472  }
473  },
474  number<NumDTensor>{});
475 
476  // TODO: enable vector write for C in ColMajor
477  const auto& e_pad_view = [&]() {
478  const auto& e_tensor_view = views.at(I3);
479  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
480  {
481  return pad_tensor_view(e_tensor_view,
485  }
486  else
487  {
488  return pad_tensor_view(e_tensor_view,
492  }
493  }();
494 
495  return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
496  }
497 
498  template <typename PadView>
499  CK_TILE_DEVICE static auto
500  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
501  {
502  const auto& a_pad_view = views.at(I0);
503  const auto& b_flat_pad_view = views.at(I1);
504  const auto& ds_pad_view = views.at(I2);
505  const auto& e_pad_view = views.at(I3);
506 
507  const auto& a_block_window = [&]() {
508  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
509  {
510  return make_tile_window(a_pad_view,
513  {i_m, 0});
514  }
515  else
516  {
517  return make_tile_window(a_pad_view,
520  {0, i_m});
521  }
522  }();
523 
524  const auto& b_flat_block_window =
525  make_tile_window(b_flat_pad_view,
528  {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
529 
530  const auto ds_block_window = generate_tuple(
531  [&](auto i) {
532  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
533  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
534  {
535  return make_tile_window(ds_pad_view[i],
538  {i_m, i_n});
539  }
540  else
541  {
542  return make_tile_window(ds_pad_view[i],
545  {i_n, i_m});
546  }
547  },
548  number<NumDTensor>{});
549 
550  auto e_block_window = make_tile_window(
551  e_pad_view,
553  {i_m, i_n});
554 
555  return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
556  }
557 
558  template <bool UseDefaultScheduler = true>
559  CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
560  const BDataType* b_flat_ptr,
561  const std::array<const void*, NumDTensor>& ds_ptr,
562  EDataType* e_ptr,
563  void* smem_ptr,
564  const KernelArgs& kargs,
565  const SplitKBatchOffset& splitk_batch_offset,
566  const index_t block_idx_m,
567  const index_t block_idx_n)
568  {
569  // Create Gemm tensor views, pad views and tile windows
570  const auto& gemm_tensor_views_tuple =
571  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
572  a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
573  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
574  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
575 
576  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
577 
578  // Run GEMM cooperatively by the whole workgroup.
579  const auto& a_block_window = gemm_tile_windows.at(I0);
580  const auto& b_flat_block_window = gemm_tile_windows.at(I1);
581  const auto& d_block_window = gemm_tile_windows.at(I2);
582  const auto& c_block_tile = FlatmmPipeline{}.template operator()(
583  a_block_window, b_flat_block_window, num_loop, smem_ptr);
584  if(UseDefaultScheduler || (get_warp_id() == 0))
585  {
586  // Run Epilogue Pipeline
587  auto& c_block_window = gemm_tile_windows.at(I3);
588 
589  EpiloguePipeline{}.template
590  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
591  c_block_window, c_block_tile, d_block_window, smem_ptr);
592  }
593  }
594 
595  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
596  {
597  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
598  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
599  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
600 
601  const SplitKBatchOffset splitk_batch_offset(kargs);
602  // options
603  const ADataType* a_ptr =
604  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
605  const BDataType* b_flat_ptr =
606  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
607  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
608 
609  // allocate LDS
610  __shared__ char smem_ptr[GetSmemSize()];
611 
612  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
613  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
614  is_any_of<EDataType, fp16_t, bf16_t>::value))
615  {
616  constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
617  RunFlatmm<scheduler_type>(a_ptr,
618  b_flat_ptr,
619  kargs.ds_ptr,
620  e_ptr,
621  smem_ptr,
622  kargs,
623  splitk_batch_offset,
624  i_m,
625  i_n);
626  }
627  }
628 };
629 
630 } // namespace ck_tile