/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File
batched_gemm_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
13 {
15  CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
16  const void* b_ptr_,
17  void* c_ptr_,
18  ck_tile::index_t k_batch_,
22  ck_tile::index_t stride_A_,
23  ck_tile::index_t stride_B_,
24  ck_tile::index_t stride_C_,
25  ck_tile::index_t batch_stride_A_,
26  ck_tile::index_t batch_stride_B_,
27  ck_tile::index_t batch_stride_C_,
28  ck_tile::index_t batch_count_)
29  : GemmHostArgs(a_ptr_,
30  b_ptr_,
31  {},
32  c_ptr_,
33  k_batch_,
34  M_,
35  N_,
36  K_,
37  stride_A_,
38  stride_B_,
39  {},
40  stride_C_),
41  batch_stride_A(batch_stride_A_),
42  batch_stride_B(batch_stride_B_),
43  batch_stride_E(batch_stride_C_),
44  batch_count(batch_count_)
45  {
46  }
47 
52 };
53 
54 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
55 struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
56 {
58 
60 
61  using ADataType = typename Base::ADataType;
62  using BDataType = typename Base::BDataType;
63  using CDataType = typename Base::EDataType;
64 
66  using GemmPipeline = typename Base::GemmPipeline;
68  using ALayout = typename Base::ALayout;
69  using BLayout = typename Base::BLayout;
70  using CLayout = typename Base::ELayout;
71 
72  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
73  {
74  // clang-format off
75  using P_ = GemmPipeline;
76 
77  return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
78  concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
79  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
80  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
81  // clang-format on
82  }
83 
85  {
90  };
91 
93 
94  __host__ static constexpr auto
95  GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
96  {
97  return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
98  }
99 
100  __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
101 
102  CK_TILE_HOST static constexpr BatchedGemmKernelArgs
104  {
105  return BatchedGemmKernelArgs{{hostArgs.a_ptr,
106  hostArgs.b_ptr,
107  {},
108  hostArgs.e_ptr,
109  hostArgs.M,
110  hostArgs.N,
111  hostArgs.K,
112  hostArgs.stride_A,
113  hostArgs.stride_B,
114  {},
115  hostArgs.stride_E,
116  hostArgs.k_batch},
117  hostArgs.batch_stride_A,
118  hostArgs.batch_stride_B,
119  hostArgs.batch_stride_E,
120  hostArgs.batch_count};
121  }
122 
124  {
125  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
126  }
127 
129  {
130  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
131  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
132  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
133 
134  const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
135  const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
136 
137  const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
138 
139  // options
140  const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
141  const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
142  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
143  splitk_batch_offset.a_k_split_offset;
144 
145  const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
146  const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
147  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
148  splitk_batch_offset.b_k_split_offset;
149 
150  const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E);
151  const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E);
152  CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr) + batch_offset_C;
153 
154  // allocate LDS
155  __shared__ char smem_ptr[GetSmemSize()];
156 
157  this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
158  }
159 };
160 
161 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:41
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
Definition: batched_gemm_kernel.hpp:13
ck_tile::index_t batch_stride_B
Definition: batched_gemm_kernel.hpp:49
ck_tile::index_t batch_stride_A
Definition: batched_gemm_kernel.hpp:48
ck_tile::index_t batch_stride_E
Definition: batched_gemm_kernel.hpp:50
CK_TILE_HOST BatchedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, ck_tile::index_t k_batch_, ck_tile::index_t M_, ck_tile::index_t N_, ck_tile::index_t K_, ck_tile::index_t stride_A_, ck_tile::index_t stride_B_, ck_tile::index_t stride_C_, ck_tile::index_t batch_stride_A_, ck_tile::index_t batch_stride_B_, ck_tile::index_t batch_stride_C_, ck_tile::index_t batch_count_)
Definition: batched_gemm_kernel.hpp:15
ck_tile::index_t batch_count
Definition: batched_gemm_kernel.hpp:51
CK_TILE_HOST BatchedGemmHostArgs()=default
Definition: batched_gemm_kernel.hpp:85
index_t batch_stride_E
Definition: batched_gemm_kernel.hpp:88
index_t batch_count
Definition: batched_gemm_kernel.hpp:89
index_t batch_stride_A
Definition: batched_gemm_kernel.hpp:86
index_t batch_stride_B
Definition: batched_gemm_kernel.hpp:87
Definition: batched_gemm_kernel.hpp:56
typename Base::BDataType BDataType
Definition: batched_gemm_kernel.hpp:62
typename Base::BLayout BLayout
Definition: batched_gemm_kernel.hpp:69
typename Base::ADataType ADataType
Definition: batched_gemm_kernel.hpp:61
static CK_TILE_HOST const std::string GetName()
Definition: batched_gemm_kernel.hpp:72
typename Base::ELayout CLayout
Definition: batched_gemm_kernel.hpp:70
static constexpr CK_TILE_HOST BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs &hostArgs)
Definition: batched_gemm_kernel.hpp:103
typename Base::TilePartitioner TilePartitioner
Definition: batched_gemm_kernel.hpp:65
typename Base::EDataType CDataType
Definition: batched_gemm_kernel.hpp:63
typename Base::ALayout ALayout
Definition: batched_gemm_kernel.hpp:68
static constexpr __host__ auto GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
Definition: batched_gemm_kernel.hpp:95
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
Definition: batched_gemm_kernel.hpp:128
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: batched_gemm_kernel.hpp:123
static constexpr __host__ auto BlockSize()
Definition: batched_gemm_kernel.hpp:100
typename ck_tile::GemmKernelArgs<> GemmKernelArgs
Definition: batched_gemm_kernel.hpp:59
typename Base::GemmPipeline GemmPipeline
Definition: batched_gemm_kernel.hpp:66
typename Base::EpiloguePipeline EpiloguePipeline
Definition: batched_gemm_kernel.hpp:67
The GEMM kernel host arguments.
Definition: gemm_kernel.hpp:30
index_t M
Definition: gemm_kernel.hpp:67
index_t K
Definition: gemm_kernel.hpp:69
index_t stride_E
Definition: gemm_kernel.hpp:75
const void * b_ptr
Definition: gemm_kernel.hpp:60
index_t k_batch
Definition: gemm_kernel.hpp:79
index_t stride_A
Definition: gemm_kernel.hpp:70
const void * a_ptr
Definition: gemm_kernel.hpp:59
index_t N
Definition: gemm_kernel.hpp:68
index_t stride_B
Definition: gemm_kernel.hpp:71
void * e_ptr
Definition: gemm_kernel.hpp:64
Definition: gemm_kernel.hpp:251
index_t b_k_split_offset
Definition: gemm_kernel.hpp:287
index_t a_k_split_offset
Definition: gemm_kernel.hpp:286
The GEMM kernel device arguments.
Definition: gemm_kernel.hpp:85
const void * a_ptr
The A input tensor's pointer to device memory.
Definition: gemm_kernel.hpp:87
const void * b_ptr
The B input tensor's pointer to device memory.
Definition: gemm_kernel.hpp:89
index_t N
GEMM's N dimension size.
Definition: gemm_kernel.hpp:97
void * e_ptr
The E output tensor's pointer to device memory.
Definition: gemm_kernel.hpp:93
index_t M
GEMM's M dimension size.
Definition: gemm_kernel.hpp:95
The GEMM kernel template.
Definition: gemm_kernel.hpp:153
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: gemm_kernel.hpp:183
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Definition: gemm_kernel.hpp:180
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: gemm_kernel.hpp:160
static CK_TILE_DEVICE void RunGemm(const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: gemm_kernel.hpp:794
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition: gemm_kernel.hpp:157
static constexpr index_t KernelBlockSize
Definition: gemm_kernel.hpp:163
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_kernel.hpp:181
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_kernel.hpp:155
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_kernel.hpp:158
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_kernel.hpp:154
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_kernel.hpp:156