/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp Source File
streamk_gemm_kernel.hpp
Go to the documentation of this file.
1 // Copyright © Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
20 {
// Constructs the Stream K host arguments for one A, one B and one C tensor.
//
// The A/B pointers and strides are each wrapped in a single-element brace list and
// forwarded to the UniversalGemmHostArgs<> base; the D-tensor pointer and stride lists
// are passed empty and k_batch is fixed to 1. reduction_strategy_ and num_sk_blocks_
// are stored in the members of the same name.
//
// NOTE(review): num_sk_blocks_ defaults to 0xffffffff; presumably a "not specified"
// sentinel that the tile partitioner interprets -- confirm against the TilePartitioner.
21  CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
22  const void* b_ptr_,
23  void* c_ptr_,
24  index_t M_,
25  index_t N_,
26  index_t K_,
27  index_t stride_A_,
28  index_t stride_B_,
29  index_t stride_C_,
30  StreamKReductionStrategy reduction_strategy_,
31  uint32_t num_sk_blocks_ = 0xffffffff)
32  : UniversalGemmHostArgs<>({a_ptr_},
33  {b_ptr_},
34  {/*ds_ptr*/},
35  c_ptr_,
36  /*k_batch_ =*/1,
37  M_,
38  N_,
39  K_,
40  {stride_A_},
41  {stride_B_},
42  {/*stride_Ds_*/},
43  stride_C_),
44  reduction_strategy{reduction_strategy_},
45  num_sk_blocks{num_sk_blocks_}
46  {
47  }
48 
51 };
52 
53 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
55 {
60 
62 
66 
71 
76 
80  "ALayout and ADataType must be scalars.");
81 
85  "BLayout and BDataType must be scalars.");
86 
90  "CLayout and CDataType must be scalars.");
91 
93  {
104  };
105 
108 
// Builds the kernel's name string from compile-time properties of GemmPipeline.
//
// The outer concat call receives '_' followed by these fields, in order:
//   "streamk", gemm_prec_str<ADataType, BDataType>(),
//   P_::MPerBlock / NPerBlock / KPerBlock,
//   the three WarpTile entries (indices 0, 1, 2),
//   P_::GetVectorSizeA() / GetVectorSizeB() / GetVectorSizeC(),
//   P_::kPadM / kPadN / kPadK.
// Each multi-value field is itself produced by concat with 'x' as its first argument.
// NOTE(review): concat presumably treats its first argument as the separator -- confirm
// in concat.hpp.
// NOTE(review): the return type is a const-qualified std::string by value, which stops
// callers from moving out of the result; a plain std::string may be preferable.
109  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
110  {
111  // clang-format off
112  using P_ = GemmPipeline;
113  using WarpTile = typename P_::BlockGemmShape::WarpTile;
114 
115  return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
116  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
117  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
118  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
119  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
120  // clang-format on
121  }
122 
// Returns the launch grid dimensions for the Stream K kernel.
// This is a pure forwarder: the result is whatever tile_partitioner.GridSize() reports.
125  CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
126  {
127  return tile_partitioner.GridSize();
128  }
129 
134  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
135  {
137  }
138 
139  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
140  {
142  }
143 
145  {
    // Body of MakeKernelArgs(const StreamKHostArgs& host_args); the signature line
    // (listing line 144) is not part of this view.
    //
    // Queries the kernel occupancy and the device's compute-unit count on the host,
    // both of which are handed to the TilePartitioner constructed below.
146  uint32_t occupancy = static_cast<uint32_t>(Occupancy());
147  uint32_t num_cu = static_cast<uint32_t>(NumCU());
148 
    // The inner brace list copies the universal GEMM fields from host_args unchanged;
    // the outer list appends the Stream K specific members (reduction strategy,
    // number of stream k blocks, workspace pointer, tile partitioner).
149  return StreamKKernelArgs{{host_args.as_ptr,
150  host_args.bs_ptr,
151  host_args.ds_ptr,
152  host_args.e_ptr,
153  host_args.M,
154  host_args.N,
155  host_args.K,
156  host_args.stride_As,
157  host_args.stride_Bs,
158  host_args.stride_Ds,
159  host_args.stride_E,
160  host_args.k_batch},
161  host_args.reduction_strategy,
162  host_args.num_sk_blocks,
163  // The workspace pointer is set to nullptr because we must first
164  // instantiate the TilePartitioner to get the necessary size
165  /*workspace_ptr =*/nullptr,
    // NOTE(review): M, N and K are index_t (int32_t) and are cast to uint32_t
    // here, so a negative dimension would wrap to a large value -- confirm that
    // callers validate the problem sizes beforehand.
166  TilePartitioner{static_cast<uint32_t>(host_args.M),
167  static_cast<uint32_t>(host_args.N),
168  static_cast<uint32_t>(host_args.K),
169  num_cu,
170  occupancy,
171  host_args.num_sk_blocks}};
172  }
173 
174  CK_TILE_HOST static bool
176  {
178  }
179 
183  {
    // Body of GetWorkSpaceSize(const StreamKKernelArgs& kargs): computes the buffer
    // size needed to store accumulation results for Stream K. The signature line and
    // the condition guarding the first return (listing lines 182 and 186) are not
    // part of this view.
184  // For reduction, we need to determine the amount of device space for accumulation
185  // results and semaphores.
187  {
    // The size is delegated to the tile partitioner, parameterised by the size of one
    // C element.
188  return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
189  }
190 
191  // Otherwise, no additional space is needed since blocks atomically store their results.
192  return 0;
193  }
194 
// Stores workspace_ptr into kargs.workspace_ptr, overwriting the previous value.
// No allocation or validation happens here; the pointer is recorded as given.
// NOTE(review): MakeKernelArgs leaves workspace_ptr as nullptr, so this is presumably
// called once a buffer sized by GetWorkSpaceSize has been allocated -- confirm with callers.
197  CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
198  {
199  kargs.workspace_ptr = workspace_ptr;
200  }
201 
202  // Temporary placeholder so that the static Occupancy() function can be compiled.
203  // Occupancy() instantiates kentry with this class as the kernel type, so the class
    // must provide an operator(); the body is intentionally empty and ignores its argument.
204  CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {}
205 
206  private:
// Returns the multiProcessorCount reported by the HIP device properties of the device
// obtained from hipGetDevice. Both HIP calls are passed through hip_check_error.
207  CK_TILE_HOST static int NumCU()
208  {
209  hipDeviceProp_t dev_prop;
210  hipDevice_t dev;
211  hip_check_error(hipGetDevice(&dev));
212  hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
213  int num_cu = dev_prop.multiProcessorCount;
214 
215  return num_cu;
216  }
217 
// Returns the value written by hipOccupancyMaxActiveBlocksPerMultiprocessor for this
// kernel's kentry instantiation, queried with kBlockSize and a trailing argument of 0.
// NOTE(review): in the HIP API the last two arguments are the block size and the
// dynamic shared-memory size per block -- confirm against the HIP documentation.
// NOTE(review): listing line 230, which opens the statement containing the HIP call,
// is not part of this view; it presumably wraps the call in hip_check_error like the
// calls in NumCU -- confirm in the original header.
222  CK_TILE_HOST static int Occupancy()
223  {
    // Written through the out-parameter of the HIP query below.
224  int occupancy;
225 
226  // Since occupancy of 1 is valid for stream k, we set min_block_per_cu to 1
227  constexpr int min_block_per_cu = 1;
228  const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
229 
231  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
232 
233  return occupancy;
234  }
235 };
236 
237 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
StreamKReductionStrategy
Definition: streamk_common.hpp:10
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
@ Reduction
Definition: block_to_ctile_map.hpp:1012
unsigned int uint32_t
Definition: stdint.h:126
The Stream K GEMM kernel host arguments.
Definition: streamk_gemm_kernel.hpp:20
uint32_t num_sk_blocks
Definition: streamk_gemm_kernel.hpp:50
ck_tile::StreamKReductionStrategy reduction_strategy
Definition: streamk_gemm_kernel.hpp:49
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_, uint32_t num_sk_blocks_=0xffffffff)
Definition: streamk_gemm_kernel.hpp:21
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: streamk_gemm_kernel.hpp:93
StreamKReductionStrategy reduction_strategy
The strategy used by work groups to compute final results in C tensor.
Definition: streamk_gemm_kernel.hpp:95
uint32_t num_sk_blocks
The number of stream k blocks.
Definition: streamk_gemm_kernel.hpp:97
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial results via the reduction strategy.
Definition: streamk_gemm_kernel.hpp:100
TilePartitioner tile_partitioner
An instance of the TilePartitioner class for assisting with mapping workgroups to the C tensor.
Definition: streamk_gemm_kernel.hpp:103
Definition: streamk_gemm_kernel.hpp:55
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: streamk_gemm_kernel.hpp:59
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:68
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, and C.
Definition: streamk_gemm_kernel.hpp:73
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition: streamk_gemm_kernel.hpp:125
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: streamk_gemm_kernel.hpp:69
static CK_TILE_HOST bool IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs)
Definition: streamk_gemm_kernel.hpp:175
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: streamk_gemm_kernel.hpp:75
CK_TILE_DEVICE void operator()(StreamKKernelArgs) const
Definition: streamk_gemm_kernel.hpp:204
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: streamk_gemm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: streamk_gemm_kernel.hpp:65
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args)
Definition: streamk_gemm_kernel.hpp:144
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: streamk_gemm_kernel.hpp:134
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition: streamk_gemm_kernel.hpp:197
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: streamk_gemm_kernel.hpp:74
static constexpr index_t kBlockSize
Definition: streamk_gemm_kernel.hpp:61
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: streamk_gemm_kernel.hpp:64
static CK_TILE_HOST const std::string GetName()
Definition: streamk_gemm_kernel.hpp:109
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: streamk_gemm_kernel.hpp:139
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition: streamk_gemm_kernel.hpp:182
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: streamk_gemm_kernel.hpp:70
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
index_t K
Definition: universal_gemm_kernel.hpp:70
void * e_ptr
Definition: universal_gemm_kernel.hpp:65
index_t M
Definition: universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
index_t N
Definition: universal_gemm_kernel.hpp:69
index_t stride_E
Definition: universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
index_t k_batch
Definition: universal_gemm_kernel.hpp:80
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:287
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:275
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:370
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:199
Definition: integral_constant.hpp:13
Definition: stream_config.hpp:30