Composable Kernel: include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File

streamk_gemm_kernel.hpp
1 // Copyright © Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 namespace reboot {
12 
 20 struct StreamKHostArgs : public UniversalGemmHostArgs<>
 21 {
22  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
23  const void* b_ptr_,
24  void* c_ptr_,
25  index_t M_,
26  index_t N_,
27  index_t K_,
28  index_t stride_A_,
29  index_t stride_B_,
30  index_t stride_C_,
31  StreamKReductionStrategy reduction_strategy_)
32  : UniversalGemmHostArgs<>({a_ptr_},
33  {b_ptr_},
34  {/*ds_ptr*/},
35  c_ptr_,
36  /*k_batch_ =*/1,
37  M_,
38  N_,
39  K_,
40  {stride_A_},
41  {stride_B_},
42  {/*stride_Ds_*/},
43  stride_C_),
44  reduction_strategy{reduction_strategy_}
45  {
46  }
47 
 48  ck_tile::StreamKReductionStrategy reduction_strategy;
 49 };
50 
 55 // The main kernel functions are the operator() functions: one for the Persistent and
 56 // one for the Non-Persistent data-parallel section of the Stream-K algorithm.
57 //
58 // Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
59 // `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
60 // `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
61 // main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
62 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 63 struct StreamKKernel
 64 {
 68  using UniversalGemmKernel = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
 70  static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
 71  static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
 72 
73  using TilePartitioner = TilePartitioner_;
74  using GemmPipeline = GemmPipeline_;
75  using EpiloguePipeline = EpiloguePipeline_;
76 
77  static_assert(
78  TilePartitioner::PERSISTENT == PersistentDP,
79  "Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
80 
82  using ALayout = typename GemmPipeline::ALayout;
83  using BLayout = typename GemmPipeline::BLayout;
84  using CLayout = typename GemmPipeline::CLayout;
85 
87  using ADataType = typename GemmPipeline::ADataType;
88  using BDataType = typename GemmPipeline::BDataType;
89  using CDataType = typename EpiloguePipeline::ODataType;
90  using AccDataType = typename EpiloguePipeline::AccDataType;
91 
 92  template <typename T>
 93  static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
 94 
96  static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
97  "ALayout and ADataType must be scalars.");
98 
100  static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
101  "BLayout and BDataType must be scalars.");
102 
104  static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
105  "CLayout and CDataType must be scalars.");
106 
 107  struct StreamKKernelArgs : UniversalGemmKernelArgs<>
 108  {
109  StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
110  : UniversalGemmKernelArgs{host_args.as_ptr,
111  host_args.bs_ptr,
112  host_args.ds_ptr,
113  host_args.e_ptr,
114  host_args.M,
115  host_args.N,
116  host_args.K,
117  host_args.stride_As,
118  host_args.stride_Bs,
119  host_args.stride_Ds,
120  host_args.stride_E,
 121  host_args.k_batch},
 122  reduction_strategy{host_args.reduction_strategy},
123  // The workspace pointer is set to nullptr because we must first
124  // instantiate the TilePartitioner to get the necessary size
125  workspace_ptr{nullptr},
126  tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
127 
128  {
129  }
130 
 131  /// The strategy used by work groups to compute final results in the C tensor.
 132  StreamKReductionStrategy reduction_strategy;
 134  /// A pointer to a device memory buffer for accumulating partial results under the reduction strategy.
 135  void* workspace_ptr;
 137  /// An instance of the TilePartitioner class that assists with mapping workgroups to the C tensor.
 138  TilePartitioner tile_partitioner;
 139  };
140 
143 
144  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
145  {
146  // clang-format off
147  using P_ = GemmPipeline;
148  using WarpTile = typename P_::BlockGemmShape::WarpTile;
149 
150  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
151  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
152  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
153  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
154  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
155  // clang-format on
156  }
157 
160  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
161  {
162  return tile_partitioner.grid_size();
163  }
164 
169  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
 170  {
 171  return UniversalGemmKernel::MaxOccupancyGridSize(s);
 172  }
173 
174  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
 175  {
 176  return dim3(kBlockSize);
 177  }
178 
 186  CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
 187  int num_cu = NumCU(),
188  int occupancy = Occupancy())
189  {
190  const index_t grid = num_cu * occupancy;
191 
192  return StreamKKernelArgs{host_args, grid};
193  }
194 
195  template <bool UseDefaultScheduler = true>
196  CK_TILE_DEVICE static void
197  RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
198  const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
199  const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
200  CDataType* c_ptr,
201  void* smem_ptr_0,
202  const typename UniversalGemmKernel::KernelArgs& kargs,
203  const index_t num_loop,
204  const index_t block_idx_m,
205  const index_t block_idx_n,
206  const index_t k_size)
207  {
208  // Create Gemm tensor views, pad views and tile windows
209  const auto& gemm_tensor_views_tuple =
210  UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
211  as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
212 
213  const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
214  auto gemm_tile_windows =
215  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
216 
217  // Run GEMM cooperatively by whole workgroup.
218  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
219  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
220  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
221 
 222  // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
 223  // has_hot_loop and tail_num here. This is a pattern similar to the one used in grouped GEMM.
 224  // In this case, we call the GemmPipeline's operator() function that takes both has_hot_loop
 225  // and tail_num.
226  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
227  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
228 
229  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
230  bs_block_window[UniversalGemmKernel::I0],
231  num_loop,
232  has_hot_loop,
233  tail_num,
234  smem_ptr_0);
235 
236  if(UseDefaultScheduler || (get_warp_id() == 0))
237  {
238  // Run Epilogue Pipeline
239  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
240 
241  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
242  }
243  }
244 
 245  CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
 246  {
 247  return UniversalGemmKernel::IsSupportedArgument(kargs);
 248  }
249 
 252  CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
 253  {
254  return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
255  }
256 
259  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
260  {
261  kargs.workspace_ptr = workspace_ptr;
262  }
263 
 274  CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
 275  index_t tile_idx,
276  index_t num_loop,
277  index_t i_k_a,
278  index_t i_k_b,
279  index_t k_size,
280  void* smem_ptr_0) const
281  {
282  const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
283  index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
284  index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
285 
286  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
287  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
288  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
289 
290  // Run the GEMM pipeline and Epilogue.
291  RunGemm(
292  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
293  }
294 
 301  CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
 302  index_t cta_idx) const
303  {
304  auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
305  workgroup_barrier sk_flags(sk_flags_ptr);
306  sk_flags.wait_set(0, 1, cta_idx);
307  }
308 
 314  CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
 315  {
316  auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
317  workgroup_barrier sk_flags(sk_flags_ptr);
318  sk_flags.wait_eq(1, cta_idx);
319  }
320 
326  template <typename OAccTile>
327  CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
328  const OAccTile& in_block_tile) const
329  {
330  using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
331  constexpr auto o_spans = BlockType::get_distributed_spans();
332  sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
333  sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
334  constexpr auto idx = make_tuple(idx0, idx1);
335  in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
336  });
337  });
338  }
339 
347  template <typename DataType, typename OAccTileDist>
 348  CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
 349  index_t cta_idx,
350  const OAccTileDist& c_block_tile_dist) const
351  {
352  const auto c_block_tile_buffer_size =
353  TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
354  void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
355  kargs.tile_partitioner.get_flags_buffer_size() +
356  cta_idx * c_block_tile_buffer_size;
357 
358  const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
359  static_cast<DataType*>(partial_buffer_ptr),
 360  make_tuple(TilePartitioner::MPerBlock, TilePartitioner::NPerBlock),
 361  make_tuple(TilePartitioner::NPerBlock, 1),
362  number<GemmPipeline::GetVectorSizeC()>{},
363  number<1>{});
364 
365  auto partial_tile_window = make_tile_window(
366  partial_tensor_view,
 367  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
 368  {0, 0},
369  c_block_tile_dist);
370 
371  return load_tile(partial_tile_window);
372  }
373 
380  template <typename OAccTile>
 381  CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
 382  index_t cta_idx,
383  const OAccTile& c_block_tile) const
384  {
385  const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
386  TilePartitioner::NPerBlock *
387  sizeof(typename OAccTile::DataType);
388  void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
389  kargs.tile_partitioner.get_flags_buffer_size() +
390  cta_idx * c_block_tile_buffer_size;
391 
392  const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
393  static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
 394  make_tuple(TilePartitioner::MPerBlock, TilePartitioner::NPerBlock),
 395  make_tuple(TilePartitioner::NPerBlock, 1),
396  number<GemmPipeline::GetVectorSizeC()>{},
397  number<1>{});
398 
399  auto partial_tile_window = make_tile_window(
400  partial_tensor_view,
 401  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
 402  {0, 0});
403 
404  store_tile(partial_tile_window, c_block_tile);
405  }
406 
414  CK_TILE_DEVICE void
415  StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
416  {
417  index_t iter_start, iter_end;
418  kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
419 
420  while(iter_start < iter_end)
421  {
422  // Get the 1D tile index in the C tensor that this workgroup will work in for this
423  // iteration of the loop.
424  index_t tile_idx =
425  amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
426 
427  // Get the start and end boundaries for the current tile.
428  index_t tile_iter_start, tile_iter_end;
429  kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
430 
431  // Get the start and end iteration within the current tile for the workgroup.
432  index_t local_iter_start = amd_wave_read_first_lane(
433  kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
434  index_t local_iter_end =
435  amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
436  tile_iter_start, iter_end, tile_iter_end));
437 
438  // Get the iteration length.
439  index_t num_loop_sk = local_iter_end - local_iter_start;
440 
441  // Determine the total size along the K dimension the workgroup is using in this
442  // iteration (used to construct tensor views).
443  index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
444 
445  // Get the K offsets for the A and B tensors
446  auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
447  local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
448 
449  if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
450  {
451  BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
452  }
453  else
454  {
455  const auto c_macro_tile_idx =
456  kargs.tile_partitioner.get_output_tile_index(tile_idx);
457  index_t i_m =
458  c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
459  index_t i_n =
460  c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
461 
462  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
463  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
464  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
465 
466  // Create Gemm tensor views, pad views and tile windows
467  const auto& gemm_tensor_views_tuple =
468  UniversalGemmKernel::template MakeGemmTensorViews<
469  EpiloguePipeline::MemoryOperation>(
470  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
471 
472  const auto& gemm_pad_views =
473  UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
474  auto gemm_tile_windows =
475  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
476 
477  // Run GEMM cooperatively by whole workgroup.
478  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
479  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
480  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
481 
 482  // Since num_loop can vary per WG and per iteration of the Stream-K while loop,
 483  // we compute has_hot_loop and tail_num here. This is a pattern similar to the one
 484  // used in grouped GEMM. In this case, we call the GemmPipeline's operator() function
 485  // that takes both has_hot_loop and tail_num.
486  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
487  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
488 
489  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
490  bs_block_window[UniversalGemmKernel::I0],
491  num_loop_sk,
492  has_hot_loop,
493  tail_num,
494  smem_ptr_0);
495 
496  auto tile_started = iter_start == tile_iter_start;
497  auto tile_ended = iter_end >= tile_iter_end;
498  if(!tile_started)
499  {
500  StorePartial(kargs, cta_idx, c_block_tile);
501  // Ensure device-wide visibility of partial results stored in global memory
502  // before signaling completion. __threadfence() guarantees that all global
503  // memory writes by this thread are visible to other threads on the device.
504  __threadfence(); // send signal when the store is done
505  SignalStorePartialDone(kargs, cta_idx);
506  }
507  else
508  {
509  auto accum_block_tile = c_block_tile;
510  if(!tile_ended)
511  {
512  const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
513  const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
514  const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
515  int accum_iters = local_iter_end - local_iter_start;
516  int next_cta = cta_idx + 1;
517 
518  while(accum_iters < iter_per_tile)
519  {
520  WaitStorePartialDone(kargs, next_cta);
521 
522  using BlockType = remove_cvref_t<decltype(c_block_tile)>;
523  AddBlockTile(
524  accum_block_tile,
525  LoadPartial<typename BlockType::DataType>(
526  kargs, next_cta, c_block_tile.get_tile_distribution()));
527 
528  accum_iters += iter_per_cta + (next_cta < extra_iters);
529  ++next_cta;
530  }
531  }
532 
533  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
 534  EpiloguePipeline{}(
 535  c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
536  }
537  }
538 
539  // Prepare for next Stream-K loop iteration.
540  iter_start = tile_iter_end;
541  block_sync_lds();
542  }
543  }
544 
552  template <bool U = PersistentDP>
553  CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
554  {
555  // Allocate LDS
556  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
557 
558  index_t block_idx = ck_tile::get_block_1d_id();
559  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
560  index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
561  bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
562 
 563  // Check whether this workgroup is in the data-parallel section
564  if(is_dp_ctas)
565  {
566  BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
567  }
568  else
569  {
570  // Stream-K
571  StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
572  }
573  }
574 
583  template <bool U = PersistentDP>
584  CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
585  {
586  // Allocate LDS
587  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
588 
589  index_t block_idx = ck_tile::get_block_1d_id();
590  index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
591 
592  // Data-parallel section
593  for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
594  tile_idx += kargs.tile_partitioner.get_grid())
595  {
596  BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
597  }
598 
599  // Stream-K section
600  StreamKGemm(kargs, block_idx, smem_ptr_0);
601  }
602 
603  private:
610  template <typename ALayout, typename BLayout>
 611  CK_TILE_DEVICE static auto
 612  GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
613  {
614  index_t stride_offset_a;
615  index_t stride_offset_b;
616  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
617  {
618  stride_offset_a = stride_a;
619  }
620  else
621  {
622  stride_offset_a = 1;
623  }
624 
625  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
626  {
627  stride_offset_b = stride_b;
628  }
629  else
630  {
631  stride_offset_b = 1;
632  }
633 
634  index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
635 
636  return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
637  }
638 
639  CK_TILE_HOST static int NumCU()
640  {
641  hipDeviceProp_t dev_prop;
642  hipDevice_t dev;
643  hip_check_error(hipGetDevice(&dev));
644  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
645  int num_cu = dev_prop.multiProcessorCount;
646 
647  return num_cu;
648  }
649 
654  CK_TILE_HOST static int Occupancy()
655  {
656  int occupancy;
657 
 658  // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1
659  constexpr int min_block_per_cu = 1;
660  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
661 
 662  hip_check_error(
 663  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
664 
665  return occupancy;
666  }
667 };
668 } // namespace reboot
669 
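For orientation, the reboot::StreamKKernel above is typically driven from the host along the lines sketched below. This is a minimal sketch, assuming a concrete TilePartitioner/GemmPipeline/EpiloguePipeline instantiation, valid device pointers, and the relevant ck_tile and HIP headers; the hipMalloc-based workspace allocation and the launch step are illustrative only, not the library's prescribed flow.

// Minimal host-side sketch (assumed kernel instantiation and device pointers; not part of this header).
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

void run_streamk_gemm(const void* a_dev, const void* b_dev, void* c_dev,
                      ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                      ck_tile::index_t stride_A, ck_tile::index_t stride_B, ck_tile::index_t stride_C)
{
    ck_tile::reboot::StreamKHostArgs host_args(a_dev, b_dev, c_dev, M, N, K,
                                               stride_A, stride_B, stride_C,
                                               ck_tile::StreamKReductionStrategy::Reduction);

    // The grid defaults to NumCU() * Occupancy(); the TilePartitioner is constructed inside.
    auto kargs = Kernel::MakeKernelArgs(host_args);

    // The reduction strategy needs device memory for completion flags and partial C tiles.
    void* workspace = nullptr;
    ck_tile::hip_check_error(hipMalloc(&workspace, Kernel::GetWorkSpaceSize(kargs)));
    Kernel::SetWorkSpacePointer(kargs, workspace);

    const dim3 grids  = Kernel::GridSize(kargs.tile_partitioner);
    const dim3 blocks = Kernel::BlockSize();
    // Launch through the library's kernel-entry helpers with kargs, grids, and blocks (omitted here).
}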
 677 struct StreamKHostArgs : public UniversalGemmHostArgs<>
 678 {
679  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
680  const void* b_ptr_,
681  void* c_ptr_,
682  index_t M_,
683  index_t N_,
684  index_t K_,
685  index_t stride_A_,
686  index_t stride_B_,
687  index_t stride_C_,
688  StreamKReductionStrategy reduction_strategy_,
689  uint32_t num_sk_blocks_ = 0xffffffff)
690  : UniversalGemmHostArgs<>({a_ptr_},
691  {b_ptr_},
692  {/*ds_ptr*/},
693  c_ptr_,
694  /*k_batch_ =*/1,
695  M_,
696  N_,
697  K_,
698  {stride_A_},
699  {stride_B_},
700  {/*stride_Ds_*/},
701  stride_C_),
702  reduction_strategy{reduction_strategy_},
703  num_sk_blocks{num_sk_blocks_}
704  {
705  }
706 
 707  ck_tile::StreamKReductionStrategy reduction_strategy;
 708  uint32_t num_sk_blocks;
 709 };
710 
711 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 712 struct StreamKKernel
 713 {
 717  using UniversalGemmKernel = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
 719  static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
 721  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
 722  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
 723  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
 726  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
 727  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
 728  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
 731  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
 732  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
 733  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
 734 
 736  static_assert(!is_detected<is_tuple, ALayout>::value &&
 737  !is_detected<is_tuple, ADataType>::value,
 738  "ALayout and ADataType must be scalars.");
 739 
 741  static_assert(!is_detected<is_tuple, BLayout>::value &&
 742  !is_detected<is_tuple, BDataType>::value,
 743  "BLayout and BDataType must be scalars.");
 744 
 746  static_assert(!is_detected<is_tuple, CLayout>::value &&
 747  !is_detected<is_tuple, CDataType>::value,
 748  "CLayout and CDataType must be scalars.");
749 
 750  struct StreamKKernelArgs : UniversalGemmKernelArgs<>
 751  {
 752  /// The strategy used by work groups to compute final results in the C tensor.
 753  StreamKReductionStrategy reduction_strategy;
 754  /// The number of stream-k blocks.
 755  uint32_t num_sk_blocks;
 757  /// A pointer to a device memory buffer for accumulating partial results under the reduction strategy.
 758  void* workspace_ptr;
 760  /// An instance of the TilePartitioner class that assists with mapping workgroups to the C tensor.
 761  TilePartitioner tile_partitioner;
 762  };
763 
766 
767  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
768  {
769  // clang-format off
770  using P_ = GemmPipeline;
771  using WarpTile = typename P_::BlockGemmShape::WarpTile;
772 
773  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
774  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
775  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
776  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
777  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
778  // clang-format on
779  }
780 
783  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
784  {
785  return tile_partitioner.GridSize();
786  }
787 
792  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
 793  {
 794  return UniversalGemmKernel::MaxOccupancyGridSize(s);
 795  }
796 
797  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
 798  {
 799  return dim3(kBlockSize);
 800  }
801 
 809  CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
 810  int num_cu = NumCU(),
811  int occupancy = Occupancy())
812  {
813  return StreamKKernelArgs{{host_args.as_ptr,
814  host_args.bs_ptr,
815  host_args.ds_ptr,
816  host_args.e_ptr,
817  host_args.M,
818  host_args.N,
819  host_args.K,
820  host_args.stride_As,
821  host_args.stride_Bs,
822  host_args.stride_Ds,
823  host_args.stride_E,
824  host_args.k_batch},
825  host_args.reduction_strategy,
826  host_args.num_sk_blocks,
827  // The workspace pointer is set to nullptr because we must first
828  // instantiate the TilePartitioner to get the necessary size
829  /*workspace_ptr =*/nullptr,
830  TilePartitioner{static_cast<uint32_t>(host_args.M),
831  static_cast<uint32_t>(host_args.N),
832  static_cast<uint32_t>(host_args.K),
833  static_cast<uint32_t>(num_cu),
834  static_cast<uint32_t>(occupancy),
835  host_args.num_sk_blocks}};
836  }
837 
838  template <bool UseDefaultScheduler = true>
839  CK_TILE_DEVICE static void
840  RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
841  const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
842  const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
843  CDataType* c_ptr,
844  void* smem_ptr_0,
845  const typename UniversalGemmKernel::KernelArgs& kargs,
846  const index_t num_loop,
847  const index_t block_idx_m,
848  const index_t block_idx_n,
849  const index_t k_size)
850  {
851  // Create Gemm tensor views, pad views and tile windows
852  const auto& gemm_tensor_views_tuple =
853  UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
854  as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
855 
856  const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
857  auto gemm_tile_windows =
858  UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
859 
860  // Run GEMM cooperatively by whole workgroup.
861  const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
862  const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
863  const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
864 
 865  // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
 866  // has_hot_loop and tail_num here. This is a pattern similar to the one used in grouped GEMM.
 867  // In this case, we call the GemmPipeline's operator() function that takes both has_hot_loop
 868  // and tail_num.
869  const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
870  const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
871 
872  const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
873  bs_block_window[UniversalGemmKernel::I0],
874  num_loop,
875  has_hot_loop,
876  tail_num,
877  smem_ptr_0);
878 
879  if(UseDefaultScheduler || (get_warp_id() == 0))
880  {
881  // Run Epilogue Pipeline
882  auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
883 
884  EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
885  }
886  }
887 
 888  CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
 889  {
 890  if(kargs.reduction_strategy != StreamKReductionStrategy::Atomic)
 891  {
 892  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
 893  {
 894  CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
 895  }
 896  return false;
 897  }
 898  return UniversalGemmKernel::IsSupportedArgument(kargs);
 899  }
900 
 903  CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
 904  {
 905  // For reduction, we need to determine the amount of device space for accumulation
 906  // results and semaphores.
 907  if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
 908  {
909  return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
910  }
911 
912  // Otherwise, no additional space is needed since blocks atomically store their results.
913  return 0;
914  }
915 
918  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
919  {
920  kargs.workspace_ptr = workspace_ptr;
921  }
922 
 924  CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
 925  {
926  // Allocate LDS
927  __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
928 
929  uint32_t block_idx = ck_tile::get_block_1d_id();
930 
931  bool is_padding_block =
932  amd_wave_read_first_lane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
933  block_idx < kargs.tile_partitioner.dp_start_block_idx);
934 
 935  // Padding blocks exist only to align the DP blocks with the number of CUs; they
 936  // should not take part in the GEMM.
937  if(is_padding_block)
938  return;
939 
940  // Determine the K offset of the first and final macro tile in the A and B tensors along the
941  // K dimension.
942  uint32_t iter_start, iter_end;
943  kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
944 
945  // Main Stream-K loop
946  while(true)
947  {
 948  // Determine the number of macro tiles in A and B this WG is responsible for in the
949  // current C macro tile.
950  uint32_t current_iter_length = amd_wave_read_first_lane(
951  kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
952 
953  // Determine the 1D tile_idx and the iter_offset for this WG.
954  // The tile_idx is the 1D macro tile index in the C tensor.
955  // The iter_offset is the starting macro tile index in the K dimension for the WG in the
956  // current iteration of the while loop.
957  uint32_t tile_idx, iter_offset;
958  kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
959 
960  // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
961  auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
962 
963  // Get the offsets in A, B, C tensors.
964  index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
965  TilePartitioner::MPerBlock);
966  index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
967  TilePartitioner::NPerBlock);
968  auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
969  static_cast<index_t>(iter_offset), kargs.stride_As[0], kargs.stride_Bs[0]);
970 
971  // Determine the total size along the K dimension the WG is using in this iteration
972  // (used to construct tensor views).
973  index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
974 
975  // Update pointer offsets for A, B, and C.
976  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
977  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
978  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
979 
980  // Run the GEMM pipeline and Epilogue.
981  RunGemm({a_ptr},
982  {b_ptr},
983  {/*ds_ptr*/},
984  c_ptr,
985  smem_ptr_0,
986  kargs,
987  current_iter_length,
988  i_m,
989  i_n,
990  k_size);
991 
992  // Prepare for next Stream-K loop iteration.
993  iter_start += current_iter_length;
994  if(iter_end <= iter_start)
995  break;
996  block_sync_lds();
997  }
998  }
999 
1000  private:
1007  template <typename ALayout, typename BLayout>
 1008  CK_TILE_DEVICE static auto
 1009  GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
1010  {
1011  index_t stride_offset_a;
1012  index_t stride_offset_b;
1013  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
1014  {
1015  stride_offset_a = stride_a;
1016  }
1017  else
1018  {
1019  stride_offset_a = 1;
1020  }
1021 
1022  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1023  {
1024  stride_offset_b = stride_b;
1025  }
1026  else
1027  {
1028  stride_offset_b = 1;
1029  }
1030 
1031  index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
1032 
1033  return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
1034  }
1035 
1036  CK_TILE_HOST static int NumCU()
1037  {
1038  hipDeviceProp_t dev_prop;
1039  hipDevice_t dev;
1040  hip_check_error(hipGetDevice(&dev));
1041  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
1042  int num_cu = dev_prop.multiProcessorCount;
1043 
1044  return num_cu;
1045  }
1046 
1051  CK_TILE_HOST static int Occupancy()
1052  {
1053  int occupancy;
1054 
 1055  // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1
1056  constexpr int min_block_per_cu = 1;
1057  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
1058 
1060  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
1061 
1062  return occupancy;
1063  }
1064 };
1065 
1066 } // namespace ck_tile
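To make the work split in StreamKGemm concrete: the tile partitioner flattens the output into tiles × iterations-per-tile units of K work and hands each stream-k CTA a contiguous [iter_start, iter_end) slice. The helper below is an illustrative host-side model of that split, not the TilePartitioner API; it assumes the first extra_iters CTAs take one additional iteration, mirroring the accum_iters += iter_per_cta + (next_cta < extra_iters) update in the reduction path.

#include <cstdint>
#include <utility>

// Illustrative model of the Stream-K iteration split across stream-k CTAs (assumed scheme).
// total_iters = num_sk_tiles * iters_per_tile; each CTA gets iters_per_cta iterations,
// and the first extra_iters CTAs get one more, so the ranges tile [0, total_iters) exactly.
std::pair<int32_t, int32_t> sk_iter_boundaries(int32_t cta_idx,
                                               int32_t num_sk_ctas,
                                               int32_t num_sk_tiles,
                                               int32_t iters_per_tile)
{
    const int32_t total_iters   = num_sk_tiles * iters_per_tile;
    const int32_t iters_per_cta = total_iters / num_sk_ctas;
    const int32_t extra_iters   = total_iters % num_sk_ctas;

    // CTAs [0, extra_iters) own (iters_per_cta + 1) iterations each.
    const int32_t iter_start =
        cta_idx * iters_per_cta + (cta_idx < extra_iters ? cta_idx : extra_iters);
    const int32_t iter_end = iter_start + iters_per_cta + (cta_idx < extra_iters ? 1 : 0);
    return {iter_start, iter_end};
}

// A CTA whose [iter_start, iter_end) range does not begin at a tile boundary stores a partial
// tile (StorePartial) and signals; the CTA that owns the start of that tile waits
// (WaitStorePartialDone), accumulates the partials (AddBlockTile), and writes the final C tile.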