Composable Kernel: include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File

gemm_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
28 template <index_t NumDTensor = 0>
29 struct GemmHostArgs
30 {
31  CK_TILE_HOST GemmHostArgs() = default;
32  CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
33  const void* b_ptr_,
34  const std::array<const void*, NumDTensor>& ds_ptr_,
35  void* e_ptr_,
36  index_t k_batch_,
37  index_t M_,
38  index_t N_,
39  index_t K_,
40  index_t stride_A_,
41  index_t stride_B_,
42  const std::array<index_t, NumDTensor>& stride_Ds_,
43  index_t stride_E_)
44  : a_ptr(a_ptr_),
45  b_ptr(b_ptr_),
46  ds_ptr(ds_ptr_),
47  e_ptr(e_ptr_),
48  M(M_),
49  N(N_),
50  K(K_),
51  stride_A(stride_A_),
52  stride_B(stride_B_),
53  stride_Ds(stride_Ds_),
54  stride_E(stride_E_),
55  k_batch(k_batch_)
56  {
57  }
58 
59  const void* a_ptr;
60  const void* b_ptr;
61  const std::array<const void*, NumDTensor> ds_ptr;
62  union
63  {
64  void* e_ptr;
65  void* c_ptr;
66  };
67  index_t M;
68  index_t N;
69  index_t K;
70  index_t stride_A;
71  index_t stride_B;
72  const std::array<index_t, NumDTensor> stride_Ds;
73  union
74  {
75  index_t stride_E;
76  index_t stride_C;
77  };
78 
79  index_t k_batch;
80 };
81 
83 template <index_t NumDTensor = 0>
84 struct GemmKernelArgs
85 {
86  /// @brief The A input tensor's pointer to device memory.
87  const void* a_ptr;
88  /// @brief The B input tensor's pointer to device memory.
89  const void* b_ptr;
90  /// @brief The Ds input tensor's pointer to device memory.
91  const std::array<const void*, NumDTensor> ds_ptr;
92  /// @brief The E output tensor's pointer to device memory.
93  void* e_ptr;
94  /// @brief GEMM's M dimension size.
95  index_t M;
96  /// @brief GEMM's N dimension size.
97  index_t N;
98  /// @brief GEMM's K dimension size.
99  index_t K;
100  /// @brief The distance between consecutive elements of non-contiguous dimension
101  /// (in memory) of A tensor.
102  index_t stride_A;
103  /// @brief The distance between consecutive elements of non-contiguous dimension
104  /// (in memory) of B tensor.
105  index_t stride_B;
106  /// @brief The distance between consecutive elements of non-contiguous dimension
107  /// (in memory) of Ds tensor.
108  std::array<index_t, NumDTensor> stride_Ds;
109  /// @brief The distance between consecutive elements of non-contiguous dimension
110  /// (in memory) of E tensor.
111  index_t stride_E;
112  index_t k_batch;
113 };
114 
151 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
152 struct GemmKernel
153 {
154  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
155  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
156  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
157  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
158  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
159  // TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD
160  using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
161  using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
162  using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
163  static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
164 
165  // Get the persistent kernel if the pipeline has it available
167  {
168  template <typename T>
169  using has_persistent_type = decltype(T::UsePersistentKernel);
170 
171  static constexpr bool value = []() {
172  if constexpr(is_detected<has_persistent_type, GemmPipeline>::value)
173  return GemmPipeline::UsePersistentKernel;
174  else
175  return false;
176  }();
177  };
179 
180  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
181  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
182  // Below type is actually accumulation data type - the output of block GEMM.
183  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
184 
185  static constexpr index_t NumDTensor = DsDataType::size();
186 
187  static constexpr auto I0 = number<0>();
188  static constexpr auto I1 = number<1>();
189  static constexpr auto I2 = number<2>();
190  static constexpr auto I3 = number<3>{};
191 
192  static_assert(DsLayout::size() == DsDataType::size(),
193  "The size of DsLayout and DsDataType should be the same");
194  using KernelArgs = GemmKernelArgs<DsLayout::size()>;
195 
196  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
197  {
198  // clang-format off
199  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
200  // clang-format on
201  }
202 
203  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
204  {
205  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
206  }
207 
214  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
215  {
216  using Kernel = GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
217  const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
218  int occupancy;
219  hip_check_error(
220  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
221  const int grid_size = get_available_compute_units(s) * occupancy;
222  return dim3(grid_size, 1, 1);
223  }
224 
225  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
226 
227  CK_TILE_HOST static constexpr KernelArgs
228  MakeKernelArgs(const GemmHostArgs<NumDTensor>& hostArgs)
229  {
230 
231  return KernelArgs{hostArgs.a_ptr,
232  hostArgs.b_ptr,
233  hostArgs.ds_ptr,
234  hostArgs.e_ptr,
235  hostArgs.M,
236  hostArgs.N,
237  hostArgs.K,
238  hostArgs.stride_A,
239  hostArgs.stride_B,
240  hostArgs.stride_Ds,
241  hostArgs.stride_E,
242  hostArgs.k_batch};
243  }
244 
245  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
246  {
247  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
248  }
249 
250  struct SplitKBatchOffset
251  {
252  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
253  {
254  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
255  const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
256  const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
257 
258  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
259  {
260  a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
261  }
262  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
263  {
264  a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
265  }
266 
267  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
268  {
269  b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
270  }
271  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
272  {
273  b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
274  }
275 
276  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
277  {
278  splitted_k = __builtin_amdgcn_readfirstlane(KRead);
279  }
280  else
281  {
282  splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
283  }
284  }
285 
286  index_t a_k_split_offset;
287  index_t b_k_split_offset;
288  index_t splitted_k;
289  };
290 
291  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
292  {
293  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
294  is_any_of<EDataType, fp16_t, bf16_t>::value)
295  {
296  if(kargs.k_batch != 1)
297  {
298  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
299  {
300  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
301  }
302  return false;
303  }
304  }
305 
306  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
307  {
308  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
309  GemmPipeline::kPadK == false)
310  {
311  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
312  {
313  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
314  "without padding!");
315  }
316  return false;
317  }
318  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
319  {
320  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
321  {
322  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
323  }
324  return false;
325  }
326  }
327  else
328  {
329  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
330  {
331  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
332  {
333  CK_TILE_ERROR(
334  "Can't support M that is not a multiple of MPerBlock without padding!");
335  }
336  return false;
337  }
338  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
339  {
340  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
341  {
342  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
343  }
344  return false;
345  }
346  }
347 
348  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
349  {
350  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
351  {
352  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
353  {
354  CK_TILE_ERROR(
355  "Can't support N that is not a multiple of NPerBlock without padding!");
356  }
357  return false;
358  }
359  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
360  {
361  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
362  {
363  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
364  }
365  return false;
366  }
367  }
368  else
369  {
370  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
371  GemmPipeline::kPadK == false)
372  {
373  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
374  {
375  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
376  "without padding!");
377  }
378  return false;
379  }
380  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
381  {
382  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
383  {
384  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
385  }
386  return false;
387  }
388  }
389 
390  bool DTensorIsValid = true;
391  static_for<0, NumDTensor, 1>{}([&](auto index) {
392  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
393  if(std::is_same_v<DiLayout, ELayout> == false)
394  {
395  DTensorIsValid = false;
396  }
397  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
398  {
399  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
400  {
401  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
402  {
403  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
404  "NPerBlock without padding!");
405  }
406  DTensorIsValid = false;
407  }
408  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
409  {
410  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
411  {
412  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
413  }
414  DTensorIsValid = false;
415  }
416  }
417  else
418  {
419  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
420  {
421  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
422  {
423  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
424  "MPerBlock without padding!");
425  }
426  DTensorIsValid = false;
427  }
428  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
429  {
430  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
431  {
432  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
433  }
434  DTensorIsValid = false;
435  }
436  }
437  });
438 
439  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
440  {
441  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
442  {
443  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
444  {
445  CK_TILE_ERROR(
446  "Can't support N that is not a multiple of NPerBlock without padding!");
447  }
448  return false;
449  }
450  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
451  {
452  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
453  {
454  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
455  }
456  return false;
457  }
458  }
459  else
460  {
461  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
462  {
463  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
464  {
465  CK_TILE_ERROR(
466  "Can't support M that is not a multiple of MPerBlock without padding!");
467  }
468  return false;
469  }
470  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
471  {
472  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
473  {
474  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
475  }
476  return false;
477  }
478  }
479  return DTensorIsValid;
480  }
481 
482  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
483  CK_TILE_DEVICE static auto
484  MakeGemmTensorViews(const ADataType* a_ptr,
485  const BDataType* b_ptr,
486  const std::array<const void*, NumDTensor>& ds_ptr,
487  EDataType* e_ptr,
488  const KernelArgs& kargs,
489  const SplitKBatchOffset& splitk_batch_offset)
490  {
491  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
492  const auto& a_tensor_view = [&]() {
493  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
494  {
495  return make_naive_tensor_view<address_space_enum::global>(
496  a_ptr,
497  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
498  make_tuple(kargs.stride_A, 1),
499  number<GemmPipeline::GetVectorSizeA()>{},
500  number<1>{});
501  }
502  else
503  {
504  return make_naive_tensor_view<address_space_enum::global>(
505  a_ptr,
506  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
507  make_tuple(kargs.stride_A, 1),
508  number<GemmPipeline::GetVectorSizeA()>{},
509  number<1>{});
510  }
511  }();
512 
513  const auto& b_tensor_view = [&]() {
514  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
515  {
516  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
517  {
518  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
519  const index_t K0 = splitk_batch_offset.splitted_k / K1;
520  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
521  const auto b_k0_n_k1_desc =
522  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
523  make_tuple(kargs.N * K1, K1, I1),
524  number<VectorSizeB>{},
525  number<1>{});
526  const auto b_n_k_desc = transform_tensor_descriptor(
527  b_k0_n_k1_desc,
528  make_tuple(make_merge_transform(make_tuple(K0, K1)),
529  make_pass_through_transform(kargs.N)),
530  make_tuple(sequence<0, 2>{}, sequence<1>{}),
531  make_tuple(sequence<1>{}, sequence<0>{}));
532  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
533  }
534  else
535  {
536  return make_naive_tensor_view<address_space_enum::global>(
537  b_ptr,
538  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
539  make_tuple(kargs.stride_B, 1),
540  number<GemmPipeline::GetVectorSizeB()>{},
541  number<1>{});
542  }
543  }
544  else
545  {
546  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
547  {
548  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
549  const index_t K0 = splitk_batch_offset.splitted_k / K1;
550  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
551  const auto b_k0_n_k1_desc =
552  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
553  make_tuple(kargs.N * K1, K1, I1),
554  number<VectorSizeB>{},
555  number<1>{});
556  const auto b_n_k_desc = transform_tensor_descriptor(
557  b_k0_n_k1_desc,
558  make_tuple(make_merge_transform(make_tuple(K0, K1)),
559  make_pass_through_transform(kargs.N)),
560  make_tuple(sequence<0, 2>{}, sequence<1>{}),
561  make_tuple(sequence<1>{}, sequence<0>{}));
562  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
563  }
564  else
565  {
566  return make_naive_tensor_view<address_space_enum::global>(
567  b_ptr,
568  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
569  make_tuple(kargs.stride_B, 1),
570  number<GemmPipeline::GetVectorSizeB()>{},
571  number<1>{});
572  }
573  }
574  }();
575 
576  const auto& ds_tensor_view = generate_tuple(
577  [&](auto i) {
578  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
579  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
580  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
581  {
582  return make_naive_tensor_view<address_space_enum::global>(
583  static_cast<const DDataType_*>(ds_ptr[i]),
584  make_tuple(kargs.M, kargs.N),
585  make_tuple(kargs.stride_Ds[i], 1),
586  number<EpiloguePipeline::GetVectorSizeD(i)>{},
587  number<1>{});
588  }
589  else
590  {
591  return make_naive_tensor_view<address_space_enum::global>(
592  static_cast<const DDataType_*>(ds_ptr[i]),
593  make_tuple(kargs.N, kargs.M),
594  make_tuple(kargs.stride_Ds[i], 1),
595  number<EpiloguePipeline::GetVectorSizeD(i)>{},
596  number<1>{});
597  }
598  },
599  number<NumDTensor>{});
600 
601  // TODO: enable vector write for C in ColMajor
602  const auto& e_tensor_view = [&]() {
603  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
604  {
605  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
606  e_ptr,
607  make_tuple(kargs.M, kargs.N),
608  make_tuple(kargs.stride_E, 1),
609  number<EpiloguePipeline::GetVectorSizeC()>{},
610  number<1>{});
611  }
612  else
613  {
614  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
615  e_ptr,
616  make_tuple(kargs.M, kargs.N),
617  make_tuple(1, kargs.stride_E),
618  number<1>{},
619  number<1>{});
620  }
621  }();
622 
623  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view);
624  }
625 
626  template <typename TensorView>
627  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
628  {
629  const auto& a_pad_view = [&]() {
630  const auto& a_tensor_view = views.at(I0);
631  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
632  {
633  return pad_tensor_view(a_tensor_view,
634  make_tuple(number<TilePartitioner::MPerBlock>{},
635  number<TilePartitioner::KPerBlock>{}),
636  sequence<false, GemmPipeline::kPadK>{});
637  }
638  else
639  {
640  return pad_tensor_view(a_tensor_view,
641  make_tuple(number<TilePartitioner::KPerBlock>{},
642  number<TilePartitioner::MPerBlock>{}),
643  sequence<false, GemmPipeline::kPadM>{});
644  }
645  }();
646 
647  const auto& b_pad_view = [&]() {
648  const auto& b_tensor_view = views.at(I1);
649  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
650  {
651  return pad_tensor_view(b_tensor_view,
652  make_tuple(number<TilePartitioner::NPerBlock>{},
653  number<TilePartitioner::KPerBlock>{}),
654  sequence<false, GemmPipeline::kPadK>{});
655  }
656  else
657  {
658  return pad_tensor_view(b_tensor_view,
659  make_tuple(number<TilePartitioner::KPerBlock>{},
660  number<TilePartitioner::NPerBlock>{}),
661  sequence<false, GemmPipeline::kPadN>{});
662  }
663  }();
664 
665  const auto& ds_pad_view = generate_tuple(
666  [&](auto i) {
667  const auto& d_tensor_view = views.at(I2);
668  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
669  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
670  {
671  return pad_tensor_view(d_tensor_view[i],
672  make_tuple(number<TilePartitioner::MPerBlock>{},
673  number<TilePartitioner::NPerBlock>{}),
674  sequence<false, GemmPipeline::kPadN>{});
675  }
676  else
677  {
678  return pad_tensor_view(d_tensor_view[i],
679  make_tuple(number<TilePartitioner::NPerBlock>{},
680  number<TilePartitioner::MPerBlock>{}),
681  sequence<false, GemmPipeline::kPadM>{});
682  }
683  },
684  number<NumDTensor>{});
685 
686  // TODO vector write in for C in ColMajor
687  const auto& e_pad_view = [&]() {
688  const auto& e_tensor_view = views.at(I3);
689  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
690  {
691  return pad_tensor_view(e_tensor_view,
692  make_tuple(number<TilePartitioner::MPerBlock>{},
693  number<TilePartitioner::NPerBlock>{}),
694  sequence<false, GemmPipeline::kPadN>{});
695  }
696  else
697  {
698  return pad_tensor_view(e_tensor_view,
699  make_tuple(number<TilePartitioner::MPerBlock>{},
700  number<TilePartitioner::NPerBlock>{}),
701  sequence<GemmPipeline::kPadM, false>{});
702  }
703  }();
704 
705  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view);
706  }
707 
708  template <typename PadView>
709  CK_TILE_DEVICE static auto
710  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
711  {
712  const auto& a_pad_view = views.at(I0);
713  const auto& b_pad_view = views.at(I1);
714  const auto& ds_pad_view = views.at(I2);
715  const auto& e_pad_view = views.at(I3);
716 
717  const auto& a_block_window = [&]() {
718  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
719  {
720  return make_tile_window(a_pad_view,
721  make_tuple(number<TilePartitioner::MPerBlock>{},
722  number<TilePartitioner::KPerBlock>{}),
723  {i_m, 0});
724  }
725  else
726  {
727  return make_tile_window(a_pad_view,
728  make_tuple(number<TilePartitioner::KPerBlock>{},
729  number<TilePartitioner::MPerBlock>{}),
730  {0, i_m});
731  }
732  }();
733 
734  const auto& b_block_window = [&]() {
735  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
736  {
737  return make_tile_window(b_pad_view,
738  make_tuple(number<TilePartitioner::NPerBlock>{},
739  number<TilePartitioner::KPerBlock>{}),
740  {i_n, 0});
741  }
742  else
743  {
744  return make_tile_window(b_pad_view,
745  make_tuple(number<TilePartitioner::KPerBlock>{},
746  number<TilePartitioner::NPerBlock>{}),
747  {0, i_n});
748  }
749  }();
750 
751  const auto ds_block_window = generate_tuple(
752  [&](auto i) {
753  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
754  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
755  {
756  return make_tile_window(ds_pad_view[i],
757  make_tuple(number<TilePartitioner::MPerBlock>{},
758  number<TilePartitioner::NPerBlock>{}),
759  {i_m, i_n});
760  }
761  else
762  {
763  return make_tile_window(ds_pad_view[i],
764  make_tuple(number<TilePartitioner::NPerBlock>{},
765  number<TilePartitioner::MPerBlock>{}),
766  {i_n, i_m});
767  }
768  },
769  number<NumDTensor>{});
770 
771  auto e_block_window = make_tile_window(
772  e_pad_view,
773  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
774  {i_m, i_n});
775 
776  return make_tuple(a_block_window, b_block_window, ds_block_window, e_block_window);
777  }
778 
793  template <bool UseDefaultScheduler = true>
794  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
795  const BDataType* b_ptr,
796  const std::array<const void*, NumDTensor>& ds_ptr,
797  EDataType* e_ptr,
798  void* smem_ptr_0,
799  const KernelArgs& kargs,
800  const SplitKBatchOffset& splitk_batch_offset,
801  const index_t block_idx_m,
802  const index_t block_idx_n)
803  {
804  // Create Gemm tensor views, pad views and tile windows
805  const auto& gemm_tensor_views_tuple =
806  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
807  a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
808 
809  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
810  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
811 
812  const index_t num_loop = __builtin_amdgcn_readfirstlane(
813  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
814 
815  // Run GEMM cooperatively by whole workgroup.
816  const auto& a_block_window = gemm_tile_windows.at(I0);
817  const auto& b_block_window = gemm_tile_windows.at(I1);
818  const auto& d_block_window = gemm_tile_windows.at(I2);
819 
820  const auto& c_block_tile = GemmPipeline{}.template operator()(
821  a_block_window, b_block_window, num_loop, smem_ptr_0);
822 
823  if(UseDefaultScheduler || (get_warp_id() == 0))
824  {
825  // Run Epilogue Pipeline
826  auto& c_block_window = gemm_tile_windows.at(I3);
827 
828  EpiloguePipeline{}.template
829  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
830  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
831  }
832  }
833 
851  CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
852  const BDataType* b_ptr,
853  const std::array<const void*, NumDTensor>& ds_ptr,
854  EDataType* e_ptr,
855  void* __restrict__ smem_ptr_0,
856  void* __restrict__ smem_ptr_1,
857  const KernelArgs& kargs,
858  const SplitKBatchOffset& splitk_batch_offset,
859  const index_t block_idx_m,
860  const index_t block_idx_n)
861  {
862  // Create Gemm tensor views, pad views and tile windows
863  const auto& gemm_tensor_views_tuple =
864  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
865  a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
866 
867  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
868  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
869 
870  const index_t num_loop = __builtin_amdgcn_readfirstlane(
871  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
872 
873  // Run GEMM cooperatively by whole workgroup.
874  const auto& a_block_window = gemm_tile_windows.at(I0);
875  const auto& b_block_window = gemm_tile_windows.at(I1);
876  const auto& d_block_window = gemm_tile_windows.at(I2);
877 
878  const auto& c_block_tile = GemmPipeline{}.template operator()(
879  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
880 
881  // Run Epilogue Pipeline
882  auto& c_block_window = gemm_tile_windows.at(I3);
883 
884  EpiloguePipeline{}.template
885  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
886  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
887  }
888 
889  // Non-persistent kernel entry point
890  template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
891  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
892  {
893  const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
894  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
895  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
896  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
897 
898  const SplitKBatchOffset splitk_batch_offset(kargs);
899 
900  // options
901  const ADataType* a_ptr =
902  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
903  const BDataType* b_ptr =
904  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
905 
906  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
907 
908  // allocate LDS
909  __shared__ char smem_ptr_0[GetSmemSize()];
910 
911  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
912  {
913  __shared__ char smem_ptr_1[GetSmemSize()];
914  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
915  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
916  is_any_of<EDataType, fp16_t, bf16_t>::value))
917  {
918  RunGemm2LDS(a_ptr,
919  b_ptr,
920  kargs.ds_ptr,
921  e_ptr,
922  smem_ptr_0,
923  smem_ptr_1,
924  kargs,
925  splitk_batch_offset,
926  i_m,
927  i_n);
928  }
929  }
930  else
931  {
932  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
933  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
934  is_any_of<EDataType, fp16_t, bf16_t>::value))
935  {
936  constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
937  RunGemm<scheduler_type>(a_ptr,
938  b_ptr,
939  kargs.ds_ptr,
940  e_ptr,
941  smem_ptr_0,
942  kargs,
943  splitk_batch_offset,
944  i_m,
945  i_n);
946  }
947  }
948  }
949 
950  // Persistent kernel entry point
951  template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
952  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
953  {
954  const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
955  const auto num_tiles =
956  __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
957  const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
958  auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
959 
960  while(block_id < num_work)
961  {
962  // Get the tile index for this block
963  const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
964  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
965  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
966  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
967 
968  // Get the SplitK offset for this block
969  const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
970  const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
971  const ADataType* a_ptr =
972  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
973  const BDataType* b_ptr =
974  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
975  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
976 
977  // allocate LDS
978  __shared__ char smem_ptr_0[GetSmemSize()];
979  // Run the GEMM
980  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
981  {
982  __shared__ char smem_ptr_1[GetSmemSize()];
983  if constexpr(!(EpiloguePipeline::MemoryOperation ==
984  memory_operation_enum::atomic_add &&
985  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
986  is_any_of<EDataType, fp16_t, bf16_t>::value))
987  {
988  RunGemm2LDS(a_ptr,
989  b_ptr,
990  kargs.ds_ptr,
991  e_ptr,
992  smem_ptr_0,
993  smem_ptr_1,
994  kargs,
995  splitk_batch_offset,
996  i_m,
997  i_n);
998  }
999  }
1000  else
1001  {
1002  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1003  memory_operation_enum::atomic_add &&
1004  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1005  is_any_of<EDataType, fp16_t, bf16_t>::value))
1006  {
1007  RunGemm(a_ptr,
1008  b_ptr,
1009  kargs.ds_ptr,
1010  e_ptr,
1011  smem_ptr_0,
1012  kargs,
1013  splitk_batch_offset,
1014  i_m,
1015  i_n);
1016  }
1017  }
1018  // Advance to the next work item
1019  block_id += grid_size;
1020  if(block_id >= num_work)
1021  {
1022  break;
1023  }
1024  }
1025  }
1026 };
1027 
1028 } // namespace ck_tile
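
The listing above only defines the kernel itself. For orientation, the sketch below shows how a host program would typically drive it: build GemmHostArgs, convert them with MakeKernelArgs, validate them with IsSupportedArgument, then launch over GridSize x BlockSize. This is a minimal sketch, not part of this file: the concrete GemmKernel instantiation (tile partitioner, GEMM pipeline, epilogue) is assumed to be defined elsewhere, and ck_tile::make_kernel / ck_tile::launch_kernel are assumed to behave like the CK Tile host launch utilities used in the library's GEMM examples.

// Hypothetical host-side driver for a GemmKernel instantiation with no D tensors.
// The launch helpers (make_kernel / launch_kernel) and the aggregate host header
// are assumptions based on how the CK Tile examples launch kernels.
#include <stdexcept>

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"

template <typename Kernel>
float run_gemm(const ck_tile::GemmHostArgs<>& host_args, const ck_tile::stream_config& s)
{
    // Translate host arguments into the device-side GemmKernelArgs.
    auto kargs = Kernel::MakeKernelArgs(host_args);

    // Reject shapes the chosen pipeline/epilogue cannot handle
    // (padding, vector-access alignment, split-K restrictions).
    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("unsupported GEMM arguments");
    }

    const dim3 grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    const dim3 blocks = Kernel::BlockSize();

    // Launch the non-persistent entry point; LDS is allocated statically inside the
    // kernel, so no dynamic shared memory is requested here.
    return ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
}

A typical GemmHostArgs for a fully row-major problem with no D tensors would be GemmHostArgs<>{a_dev, b_dev, {}, e_dev, /*k_batch=*/1, M, N, K, /*stride_A=*/K, /*stride_B=*/N, {}, /*stride_E=*/N}. For split-K runs (k_batch > 1) the epilogue is expected to accumulate with atomic_add, which is why IsSupportedArgument rejects odd C vector sizes for fp16/bf16 outputs in that case.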