/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File
batched_gemm_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
20 {
21  CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_,
22  const void* b_ptr_,
23  void* c_ptr_,
24  ck_tile::index_t k_batch_,
28  ck_tile::index_t stride_A_,
29  ck_tile::index_t stride_B_,
30  ck_tile::index_t stride_C_,
31  ck_tile::index_t batch_stride_A_,
32  ck_tile::index_t batch_stride_B_,
33  ck_tile::index_t batch_stride_C_,
34  ck_tile::index_t batch_count_)
35  : UniversalGemmHostArgs<>({a_ptr_},
36  {b_ptr_},
37  {/*ds_ptr*/},
38  c_ptr_,
39  k_batch_,
40  M_,
41  N_,
42  K_,
43  {stride_A_},
44  {stride_B_},
45  {/*stride_Ds_*/},
46  stride_C_),
47  batch_stride_A(batch_stride_A_),
48  batch_stride_B(batch_stride_B_),
49  batch_stride_E(batch_stride_C_),
50  batch_count(batch_count_)
51  {
52  }
53 
58 };
59 
60 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
62 {
68 
72 
77 
82 
84  static_assert(
86  "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
87 
89  static_assert(
91  "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
92 
96  "C/ELayout and C/EDataType must be scalars.");
97 
99  {
104  };
105 
107 
/// @brief Builds a descriptive kernel-name string.
///
/// Encodes the A/B precision string, per-block tile sizes (M/N/K), vector
/// access sizes (A/B/C) and padding flags (M/N/K), joined with '_' and 'x',
/// e.g. "gemm_batched_<prec>_MxNxK_vAxvBxvC_pMxpNxpK".
108  [[nodiscard]] CK_TILE_HOST static auto GetName() -> const std::string
109  {
110  // clang-format off
111  using P_ = GemmPipeline;
112  return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
113  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
114  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
115  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
116  // clang-format on
117  }
118 
/// @brief Computes the 3D launch grid.
///
/// x = number of MxN output tiles (from the tile partitioner),
/// y = batch index, z = split-K slice. operator() reads the batch index
/// from blockIdx.y and the split-K index from blockIdx.z accordingly.
119  CK_TILE_HOST static constexpr auto
120  GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
121  {
122  return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
123  }
124 
/// @brief Returns the workgroup (block) size for the kernel launch.
///
/// NOTE(review): halving on wave32 devices presumably keeps the number of
/// waves per workgroup constant, assuming kBlockSize is sized for 64-wide
/// waves -- confirm against UniversalGemmKernel::kBlockSize.
125  CK_TILE_HOST static auto BlockSize() -> dim3
126  {
127  if(ck_tile::is_wave32())
128  {
129  return dim3(UniversalGemmKernel::kBlockSize / 2);
130  }
131  else
132  {
133  return dim3(UniversalGemmKernel::kBlockSize);
134  }
135  }
136 
/// @brief Translates host arguments into the device-side kernel argument
/// struct.
///
/// The nested brace list initializes the embedded universal-GEMM argument
/// sub-struct (pointers, problem sizes, strides, k_batch); the trailing four
/// fields are the batch strides for A/B/E and the batch count.
137  CK_TILE_HOST static constexpr BatchedGemmKernelArgs
139  {
140  return BatchedGemmKernelArgs{{hostArgs.as_ptr,
141  hostArgs.bs_ptr,
142  hostArgs.ds_ptr,
143  hostArgs.e_ptr,
144  hostArgs.M,
145  hostArgs.N,
146  hostArgs.K,
147  hostArgs.stride_As,
148  hostArgs.stride_Bs,
149  hostArgs.stride_Ds,
150  hostArgs.stride_E,
151  hostArgs.k_batch},
152  hostArgs.batch_stride_A,
153  hostArgs.batch_stride_B,
154  hostArgs.batch_stride_E,
155  hostArgs.batch_count};
156  }
157 
159  {
160  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
161  }
162 
/// @brief Checks whether the given kernel arguments are supported.
///
/// NOTE(review): the parameter and body lines (164, 166) are omitted from
/// this extracted listing; per the member documentation it takes
/// UniversalGemmKernel::KernelArgs and returns bool, presumably delegating
/// to UniversalGemmKernel::IsSupportedArgument -- confirm in the full header.
163  CK_TILE_HOST static auto
165  {
167  }
168 
170  {
171  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
172  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
173  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
174 
175  const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
176  const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
177 
178  const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
179 
180  // options
181  const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
182  const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
183  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + batch_offset_A +
184  splitk_batch_offset.as_k_split_offset[0];
185 
186  const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
187  const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
188  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + batch_offset_B +
189  splitk_batch_offset.bs_k_split_offset[0];
190 
191  const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E);
192  const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E);
193  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr) + batch_offset_C;
194 
195  // allocate LDS
196  __shared__ char smem_ptr[GetSmemSize()];
197 
199  {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
200  }
201 };
202 
203 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
The Batched GEMM kernel host arguments.
Definition: batched_gemm_kernel.hpp:20
ck_tile::index_t batch_stride_B
Definition: batched_gemm_kernel.hpp:55
ck_tile::index_t batch_stride_A
Definition: batched_gemm_kernel.hpp:54
ck_tile::index_t batch_stride_E
Definition: batched_gemm_kernel.hpp:56
CK_TILE_HOST BatchedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, ck_tile::index_t k_batch_, ck_tile::index_t M_, ck_tile::index_t N_, ck_tile::index_t K_, ck_tile::index_t stride_A_, ck_tile::index_t stride_B_, ck_tile::index_t stride_C_, ck_tile::index_t batch_stride_A_, ck_tile::index_t batch_stride_B_, ck_tile::index_t batch_stride_C_, ck_tile::index_t batch_count_)
Definition: batched_gemm_kernel.hpp:21
ck_tile::index_t batch_count
Definition: batched_gemm_kernel.hpp:57
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: batched_gemm_kernel.hpp:99
index_t batch_stride_E
Definition: batched_gemm_kernel.hpp:102
index_t batch_count
Definition: batched_gemm_kernel.hpp:103
index_t batch_stride_A
Definition: batched_gemm_kernel.hpp:100
index_t batch_stride_B
Definition: batched_gemm_kernel.hpp:101
Definition: batched_gemm_kernel.hpp:62
static constexpr index_t kBlockSize
Definition: batched_gemm_kernel.hpp:67
static constexpr CK_TILE_HOST BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs &hostArgs)
Definition: batched_gemm_kernel.hpp:138
static CK_TILE_HOST auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs) -> bool
Definition: batched_gemm_kernel.hpp:164
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: batched_gemm_kernel.hpp:70
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: batched_gemm_kernel.hpp:75
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: batched_gemm_kernel.hpp:69
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: batched_gemm_kernel.hpp:71
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
Definition: batched_gemm_kernel.hpp:120
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: batched_gemm_kernel.hpp:76
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: batched_gemm_kernel.hpp:66
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: batched_gemm_kernel.hpp:80
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
Definition: batched_gemm_kernel.hpp:169
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: batched_gemm_kernel.hpp:158
static CK_TILE_HOST auto BlockSize() -> dim3
Definition: batched_gemm_kernel.hpp:125
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition: batched_gemm_kernel.hpp:79
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition: batched_gemm_kernel.hpp:74
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: batched_gemm_kernel.hpp:81
static CK_TILE_HOST auto GetName() -> const std::string
Definition: batched_gemm_kernel.hpp:108
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
index_t K
Definition: universal_gemm_kernel.hpp:70
void * e_ptr
Definition: universal_gemm_kernel.hpp:65
index_t M
Definition: universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
index_t N
Definition: universal_gemm_kernel.hpp:69
index_t stride_E
Definition: universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
index_t k_batch
Definition: universal_gemm_kernel.hpp:80
Definition: universal_gemm_kernel.hpp:322
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:365
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:366
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:98
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:96
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:952
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:370
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:199