Composable Kernel: include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File

streamk_gemm_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 /// @brief The Stream K GEMM kernel host arguments.
21 struct StreamKHostArgs : public UniversalGemmHostArgs<>
22 {
23  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
24  const void* b_ptr_,
25  void* c_ptr_,
26  index_t M_,
27  index_t N_,
28  index_t K_,
29  index_t stride_A_,
30  index_t stride_B_,
31  index_t stride_C_)
32  : UniversalGemmHostArgs<>({a_ptr_},
33  {b_ptr_},
34  {/*ds_ptr*/},
35  c_ptr_,
36  /*k_batch_ =*/1,
37  M_,
38  N_,
39  K_,
40  {stride_A_},
41  {stride_B_},
42  {/*stride_Ds_*/},
43  stride_C_)
44  {
45  }
46 };
47 
48 /// @brief The Stream K GEMM kernel class.
61 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
62 struct StreamKKernel
63 {
64  /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
65  ///        functions.
69  using UniversalGemmKernel = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
70 
71  static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
72  static constexpr bool PersistentDP  = UniversalGemmKernel::PersistentKernel;
73 
74  using TilePartitioner = TilePartitioner_;
75  using GemmPipeline = GemmPipeline_;
76  using EpiloguePipeline = EpiloguePipeline_;
77 
78  static_assert(
79  TilePartitioner::PERSISTENT == PersistentDP,
80  "Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
81 
82  /// @brief Specify the layout configurations for A, B, and C.
85  using ALayout = typename GemmPipeline::ALayout;
86  using BLayout = typename GemmPipeline::BLayout;
87  using CLayout = typename GemmPipeline::CLayout;
88 
89  /// @brief Specify the data type configurations for A, B, and C.
92  using ADataType = typename GemmPipeline::ADataType;
93  using BDataType = typename GemmPipeline::BDataType;
94  using CDataType = typename EpiloguePipeline::ODataType;
95  using AccDataType = typename EpiloguePipeline::AccDataType;
96 
97  template <typename T>
102  static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
103  "ALayout and ADataType must be scalars.");
104 
108  static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
109  "BLayout and BDataType must be scalars.");
110 
114  static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
115  "CLayout and CDataType must be scalars.");
116 
117  struct StreamKKernelArgs : UniversalGemmKernelArgs<>
118  {
119  StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
120  : UniversalGemmKernelArgs{host_args.as_ptr,
121  host_args.bs_ptr,
122  host_args.ds_ptr,
123  host_args.e_ptr,
124  host_args.M,
125  host_args.N,
126  host_args.K,
127  host_args.stride_As,
128  host_args.stride_Bs,
129  host_args.stride_Ds,
130  host_args.stride_E,
131  host_args.k_batch},
132  // The workspace pointer is set to nullptr because we must first
133  // instantiate the TilePartitioner to get the necessary size
134  workspace_ptr{nullptr},
135  tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
136 
137  {
138  }
139 
140  /// @brief A pointer to a buffer in device memory for accumulating partial results via the
141  ///        reduction strategy.
143  void* workspace_ptr;
144 
145  /// @brief An instance of the TilePartitioner class for assisting with mapping workgroups
146  ///        to the C tensor.
148  TilePartitioner tile_partitioner;
149  };
150 
153 
154  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
155  {
156  // clang-format off
157  using P_ = GemmPipeline;
158  using WarpTile = typename P_::BlockGemmShape::WarpTile;
159 
160  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
161  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
162  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
163  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
164  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
165  // clang-format on
166  }
167 
168  /// @brief Compute the grid size for the Stream-K kernel using the tile_partitioner.
172  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
173  {
174  return tile_partitioner.grid_size();
175  }
176 
178  /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
183  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
184  {
185  return UniversalGemmKernel::MaxOccupancyGridSize(s);
186  }
187 
188  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
189  {
190  return dim3(kBlockSize);
191  }
192 
193  /// @brief Constructs kernel arguments for the Stream-K kernel.
202  CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
203  int num_cu = NumCU(),
204  int occupancy = Occupancy())
205  {
206  const index_t grid = num_cu * occupancy;
207 
208  return StreamKKernelArgs{host_args, grid};
209  }
210 
211  template <bool UseDefaultScheduler = true>
212  CK_TILE_DEVICE static void
213  RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
214  const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
215  const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
216  CDataType* c_ptr,
217  void* smem_ptr_0,
218  const typename UniversalGemmKernel::KernelArgs& kargs,
219  const index_t num_loop,
220  const index_t block_idx_m,
221  const index_t block_idx_n,
222  const index_t k_size)
223  {
224  // Create Gemm tensor views, pad views and tile windows
225  const auto& gemm_tensor_views_tuple =
226  UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
227  as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
228 
229  const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
230  auto gemm_tile_windows =
231  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
232 
233  // Run GEMM cooperatively by whole workgroup.
234  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
235  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
236  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
237 
238  // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
239  // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
240  // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
241  // tail_num.
242  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
243  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
244 
245  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
246  bs_block_window[UniversalGemmKernel::I0],
247  num_loop,
248  has_hot_loop,
249  tail_num,
250  smem_ptr_0);
251 
252  if(UseDefaultScheduler || (get_warp_id() == 0))
253  {
254  // Run Epilogue Pipeline
255  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
256 
257  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
258  }
259  }
260 
261  CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
262  {
263  return UniversalGemmKernel::IsSupportedArgument(kargs);
264  }
265 
266  /// @brief Computes the buffer size needed to store accumulation results for Stream-K.
270  CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
271  {
272  return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
273  }
275  /// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
278  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
279  {
280  kargs.workspace_ptr = workspace_ptr;
281  }
282 
284  /// @brief Computes offsets into the A, B, and C tensors, then runs the GEMM pipeline and epilogue.
295  CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
296  index_t tile_idx,
297  index_t num_loop,
298  index_t i_k_a,
299  index_t i_k_b,
300  index_t k_size,
301  void* smem_ptr_0) const
302  {
303  const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
304  index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
305  index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
306 
307  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
308  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
309  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
310 
311  // Run the GEMM pipeline and Epilogue.
312  RunGemm(
313  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
314  }
315 
316  /// @brief Signals that the current thread block (CTA) has completed storing its partial results.
324  CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
325  index_t cta_idx) const
326  {
327  auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
328  workgroup_barrier sk_flags(sk_flags_ptr);
329  sk_flags.wait_set(0, 1, cta_idx);
330  }
331 
332  /// @brief Waits for the thread block (cta_idx) to complete storing its partial results.
339  CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
340  {
341  auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
342  workgroup_barrier sk_flags(sk_flags_ptr);
343  sk_flags.wait_eq(1, cta_idx);
344  }
345 
347  /// @brief Adds the values of a block tile to an output block tile.
353  template <typename OAccTile>
354  CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
355  const OAccTile& in_block_tile) const
356  {
357  using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
358  constexpr auto o_spans = BlockType::get_distributed_spans();
359  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
360  sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
361  constexpr auto idx = make_tuple(idx0, idx1);
362  in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
363  });
364  });
365  }
366 
368  /// @brief Loads a partial block tile from the workspace buffer.
376  template <typename DataType, typename OAccTileDist>
377  CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
378  index_t cta_idx,
379  const OAccTileDist& c_block_tile_dist) const
380  {
381  const auto c_block_tile_buffer_size =
382  TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
383  void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
384  kargs.tile_partitioner.get_flags_buffer_size() +
385  cta_idx * c_block_tile_buffer_size;
386 
387  const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
388  static_cast<DataType*>(partial_buffer_ptr),
389  make_tuple(TilePartitioner::MPerBlock, TilePartitioner::NPerBlock),
390  make_tuple(TilePartitioner::NPerBlock, 1),
391  number<GemmPipeline::GetVectorSizeC()>{},
392  number<1>{});
393 
394  auto partial_tile_window = make_tile_window(
395  partial_tensor_view,
396  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
397  {0, 0},
398  c_block_tile_dist);
399 
400  return load_tile(partial_tile_window);
401  }
402 
404  /// @brief Stores a partial block tile to the workspace buffer.
411  template <typename OAccTile>
412  CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
413  index_t cta_idx,
414  const OAccTile& c_block_tile) const
415  {
416  const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
417  TilePartitioner::NPerBlock *
418  sizeof(typename OAccTile::DataType);
419  void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
420  kargs.tile_partitioner.get_flags_buffer_size() +
421  cta_idx * c_block_tile_buffer_size;
422 
423  const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
424  static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
425  make_tuple(TilePartitioner::MPerBlock, TilePartitioner::NPerBlock),
426  make_tuple(TilePartitioner::NPerBlock, 1),
427  number<GemmPipeline::GetVectorSizeC()>{},
428  number<1>{});
429 
430  auto partial_tile_window = make_tile_window(
431  partial_tensor_view,
432  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
433  {0, 0});
434 
435  store_tile(partial_tile_window, c_block_tile);
436  }
437 
439  /// @brief Runs the main Stream-K algorithm.
447  CK_TILE_DEVICE
448  void StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
449  {
450  index_t iter_start, iter_end;
451  kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
452 
453  while(iter_start < iter_end)
454  {
455  // Get the 1D tile index in the C tensor that this workgroup will work in for this
456  // iteration of the loop.
457  index_t tile_idx =
458  amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
459 
460  // Get the start and end boundaries for the current tile.
461  index_t tile_iter_start, tile_iter_end;
462  kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
463 
464  // Get the start and end iteration within the current tile for the workgroup.
465  index_t local_iter_start = amd_wave_read_first_lane(
466  kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
467  index_t local_iter_end =
468  amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
469  tile_iter_start, iter_end, tile_iter_end));
470 
471  // Get the iteration length.
472  index_t num_loop_sk = local_iter_end - local_iter_start;
473 
474  // Determine the total size along the K dimension the workgroup is using in this
475  // iteration (used to construct tensor views).
476  index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
477 
478  // Get the K offsets for the A and B tensors
479  auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
480  local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
481 
482  if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
483  {
484  BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
485  }
486  else
487  {
488  const auto c_macro_tile_idx =
489  kargs.tile_partitioner.get_output_tile_index(tile_idx);
490  index_t i_m =
491  c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
492  index_t i_n =
493  c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
494 
495  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
496  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
497  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
498 
499  // Create Gemm tensor views, pad views and tile windows
500  const auto& gemm_tensor_views_tuple =
501  UniversalGemmKernel::template MakeGemmTensorViews<
502  EpiloguePipeline::MemoryOperation>(
503  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
504 
505  const auto& gemm_pad_views =
506  UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
507  auto gemm_tile_windows =
508  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
509 
510  // Run GEMM cooperatively by whole workgroup.
511  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
512  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
513  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
514 
515  // Since num_loop can vary per WG and per iteration of the Stream-K while loop,
516  // we compute has_hot_loop and tail_num here. This is a similar pattern used by
517  // grouped GEMM. In this case, we call the GemmPipeline's operator() function
518  // that takes both has_hot_loop and tail_num.
519  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
520  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
521 
522  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
523  bs_block_window[UniversalGemmKernel::I0],
524  num_loop_sk,
525  has_hot_loop,
526  tail_num,
527  smem_ptr_0);
528 
529  auto tile_started = iter_start == tile_iter_start;
530  auto tile_ended = iter_end >= tile_iter_end;
531  if(!tile_started)
532  {
533  StorePartial(kargs, cta_idx, c_block_tile);
534  // Ensure device-wide visibility of partial results stored in global memory
535  // before signaling completion. __threadfence() guarantees that all global
536  // memory writes by this thread are visible to other threads on the device.
537  __threadfence(); // send signal when the store is done
538  SignalStorePartialDone(kargs, cta_idx);
539  }
540  else
541  {
542  auto accum_block_tile = c_block_tile;
543  if(!tile_ended)
544  {
545  const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
546  const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
547  const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
548  int accum_iters = local_iter_end - local_iter_start;
549  int next_cta = cta_idx + 1;
550 
551  while(accum_iters < iter_per_tile)
552  {
553  WaitStorePartialDone(kargs, next_cta);
554 
555  using BlockType = remove_cvref_t<decltype(c_block_tile)>;
556  AddBlockTile(
557  accum_block_tile,
558  LoadPartial<typename BlockType::DataType>(
559  kargs, next_cta, c_block_tile.get_tile_distribution()));
560 
561  accum_iters += iter_per_cta + (next_cta < extra_iters);
562  ++next_cta;
563  }
564  }
565 
566  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
567  EpiloguePipeline{}(
568  c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
569  }
570  }
571 
572  // Prepare for next Stream-K loop iteration.
573  iter_start = tile_iter_end;
574  block_sync_lds();
575  }
576  }
577 
578  /// @brief Entry point for the Stream-K kernel with non-persistent DP.
587  template <bool U = PersistentDP>
588  CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
589  {
590  // Allocate LDS
591  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
592 
593  index_t block_idx = ck_tile::get_block_1d_id();
594  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
595  index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
596  bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
597 
598  // Check if at the data parallel section
599  if(is_dp_ctas)
600  {
601  BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
602  }
603  else
604  {
605  // Stream-K
606  StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
607  }
608  }
609 
610  /// @brief Entry point for the Stream-K kernel with persistent DP.
620  template <bool U = PersistentDP>
621  CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
622  {
623  // Allocate LDS
624  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
625 
626  index_t block_idx = ck_tile::get_block_1d_id();
627  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
628 
629  // Data-parallel section
630  for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
631  tile_idx += kargs.tile_partitioner.get_grid())
632  {
633  BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
634  }
635 
636  // Stream-K section
637  StreamKGemm(kargs, block_idx, smem_ptr_0);
638  }
639 
640  private:
649  template <typename ALayout, typename BLayout>
641  /// @brief Computes the element offsets into the A and B tensors for a given K-iteration
642  ///        offset, taking the tensor layouts into account.
650  CK_TILE_DEVICE static auto
651  GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
652  {
653  index_t stride_offset_a;
654  index_t stride_offset_b;
655  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
656  {
657  stride_offset_a = stride_a;
658  }
659  else
660  {
661  stride_offset_a = 1;
662  }
663 
664  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
665  {
666  stride_offset_b = stride_b;
667  }
668  else
669  {
670  stride_offset_b = 1;
671  }
672 
673  index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
674 
675  return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
676  }
677 
678  CK_TILE_HOST static int NumCU()
679  {
680  hipDeviceProp_t dev_prop;
681  hipDevice_t dev;
682  hip_check_error(hipGetDevice(&dev));
683  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
684  int num_cu = dev_prop.multiProcessorCount;
685 
686  return num_cu;
687  }
688 
695  CK_TILE_HOST static int Occupancy()
696  {
697  int occupancy;
698 
699  // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1.
700  constexpr int min_block_per_cu = 1;
701  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
702 
703  hip_check_error(
704  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
705 
706  return max(occupancy, 1);
707  }
708 };
709 } // namespace ck_tile
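
For context, a minimal host-side setup for this kernel might look like the sketch below. It is only an illustration: the kernel class name follows the listing above, the launch_streamk_gemm helper is hypothetical, and the final launch through the kentry wrapper (the entry point referenced in Occupancy()) is an assumption; in practice the launch is normally handled by CK Tile's host-side launch helpers rather than a raw hipLaunchKernelGGL.

// Hypothetical host-side driver; the launch path is an illustrative assumption.
#include <hip/hip_runtime.h>

template <typename Kernel>
void launch_streamk_gemm(const void* a_dev, const void* b_dev, void* c_dev,
                         ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                         ck_tile::index_t stride_a, ck_tile::index_t stride_b,
                         ck_tile::index_t stride_c, hipStream_t stream)
{
    ck_tile::StreamKHostArgs host_args(a_dev, b_dev, c_dev, M, N, K,
                                       stride_a, stride_b, stride_c);

    // Grid size defaults to num_cu * occupancy (see MakeKernelArgs above).
    auto kargs = Kernel::MakeKernelArgs(host_args);

    // The reduction strategy needs a workspace for synchronization flags and partial C tiles;
    // zero it so the workgroup_barrier flags start cleared (initialization shown for simplicity).
    void* workspace = nullptr;
    ck_tile::hip_check_error(hipMalloc(&workspace, Kernel::GetWorkSpaceSize(kargs)));
    ck_tile::hip_check_error(hipMemset(workspace, 0, Kernel::GetWorkSpaceSize(kargs)));
    Kernel::SetWorkSpacePointer(kargs, workspace);

    if(Kernel::IsSupportedArgument(kargs))
    {
        const dim3 grid  = Kernel::GridSize(kargs.tile_partitioner);
        const dim3 block = Kernel::BlockSize();
        // Assumed launch path: kentry<MinBlockPerCu, Kernel, Args> is the entry point used
        // by Occupancy() above; production code usually goes through CK Tile's launch helpers.
        hipLaunchKernelGGL((ck_tile::kentry<1, Kernel, decltype(kargs)>),
                           grid, block, 0, stream, kargs);
    }
}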
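
The partial-result path (StorePartial, LoadPartial, GetWorkSpaceSize) implies a workspace made of a flags region for the workgroup_barrier followed by one MPerBlock x NPerBlock accumulator tile per Stream-K CTA. The small stand-alone helper below restates that layout for illustration only; the type, the member names, and the example sizes are assumptions, and the authoritative values come from the TilePartitioner's get_workspace_size() and get_flags_buffer_size().

#include <cstddef>
#include <cstdint>

// Illustrative mirror of the workspace layout used by StorePartial/LoadPartial above.
// Not part of the kernel; the TilePartitioner owns the real layout.
struct StreamKWorkspaceLayout
{
    std::size_t flags_bytes; // synchronization flags consumed by workgroup_barrier
    std::size_t tile_bytes;  // MPerBlock * NPerBlock * sizeof(AccDataType)
    std::size_t num_sk_ctas; // number of Stream-K workgroups producing partial tiles

    std::size_t total_bytes() const { return flags_bytes + num_sk_ctas * tile_bytes; }

    // Byte offset of the partial C tile owned by cta_idx, matching
    // workspace_ptr + get_flags_buffer_size() + cta_idx * tile_bytes in the listing.
    std::size_t partial_tile_offset(std::size_t cta_idx) const
    {
        return flags_bytes + cta_idx * tile_bytes;
    }
};

// Example with assumed numbers: 256x256 float accumulator tiles and 8 Stream-K CTAs.
// StreamKWorkspaceLayout layout{8 * sizeof(std::uint32_t), 256 * 256 * sizeof(float), 8};
// layout.partial_tile_offset(3) == layout.flags_bytes + 3 * layout.tile_bytes

The handshake then follows the listing: a CTA that does not own the first iteration of a tile stores its partial tile, issues __threadfence(), and sets its flag via SignalStorePartialDone; the tile owner waits on each contributor's flag with WaitStorePartialDone, accumulates the loaded partial tiles with AddBlockTile, and runs the epilogue once.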