Composable Kernel: include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp Source File

batched_gemm_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"

namespace ck_tile {

struct BatchedGemmHostArgs : public GemmHostArgs
{
    CK_TILE_HOST BatchedGemmHostArgs() = default;
    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* c_ptr_,
                                     ck_tile::index_t k_batch_,
                                     ck_tile::index_t M_,
                                     ck_tile::index_t N_,
                                     ck_tile::index_t K_,
                                     ck_tile::index_t stride_A_,
                                     ck_tile::index_t stride_B_,
                                     ck_tile::index_t stride_C_,
                                     ck_tile::index_t batch_stride_A_,
                                     ck_tile::index_t batch_stride_B_,
                                     ck_tile::index_t batch_stride_C_,
                                     ck_tile::index_t batch_count_)
        : GemmHostArgs(
              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
          batch_stride_A(batch_stride_A_),
          batch_stride_B(batch_stride_B_),
          batch_stride_C(batch_stride_C_),
          batch_count(batch_count_)
    {
    }

    ck_tile::index_t batch_stride_A;
    ck_tile::index_t batch_stride_B;
    ck_tile::index_t batch_stride_C;
    ck_tile::index_t batch_count;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using GemmKernelArgs = typename Base::GemmKernelArgs;

    using ADataType = typename Base::ADataType;
    using BDataType = typename Base::BDataType;
    using CDataType = typename Base::CDataType;

    using TilePartitioner = typename Base::TilePartitioner;
    using GemmPipeline = typename Base::GemmPipeline;
    using EpiloguePipeline = typename Base::EpiloguePipeline;
    using ALayout = typename Base::ALayout;
    using BLayout = typename Base::BLayout;
    using CLayout = typename Base::CLayout;

    struct BatchedGemmKernelArgs : GemmKernelArgs
    {
        index_t batch_stride_A;
        index_t batch_stride_B;
        index_t batch_stride_C;
        index_t batch_count;
    };

    __host__ static constexpr auto
    GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
    {
        return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
    }

    __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }

    CK_TILE_HOST static constexpr BatchedGemmKernelArgs
    MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
    {
        return BatchedGemmKernelArgs{{hostArgs.a_ptr,
                                      hostArgs.b_ptr,
                                      hostArgs.c_ptr,
                                      hostArgs.M,
                                      hostArgs.N,
                                      hostArgs.K,
                                      hostArgs.stride_A,
                                      hostArgs.stride_B,
                                      hostArgs.stride_C,
                                      hostArgs.k_batch},
                                     hostArgs.batch_stride_A,
                                     hostArgs.batch_stride_B,
                                     hostArgs.batch_stride_C,
                                     hostArgs.batch_count};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
    {
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const auto i_batch  = __builtin_amdgcn_readfirstlane(blockIdx.y);
        const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);

        // options
        const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
        const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
                                 splitk_batch_offset.a_k_split_offset;

        const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
        const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
                                 splitk_batch_offset.b_k_split_offset;

        const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
        const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        if(kargs.k_batch == 1)
        {
            this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
        else
        {
            this->template RunGemm<memory_operation_enum::atomic_add>(
                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
    }
};

} // namespace ck_tile
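
For context, below is a minimal host-side sketch of how the pieces above fit together; it is not part of this header. The partitioner, pipeline and epilogue types, the device pointers, and the launch wrapper are hypothetical placeholders; concrete types and the actual kernel launch come from the ck_tile GEMM pipeline headers and the library's example code.

// Illustrative sketch only: MyTilePartitioner, MyGemmPipeline, MyEpiloguePipeline,
// a_dev/b_dev/c_dev and launch_batched_gemm() are hypothetical placeholders,
// not names provided by this header.
using Kernel = ck_tile::BatchedGemmKernel<MyTilePartitioner, MyGemmPipeline, MyEpiloguePipeline>;

ck_tile::BatchedGemmHostArgs host_args(a_dev, b_dev, c_dev,          // device buffers
                                       /*k_batch_*/ 1,               // 1 = no split-K reduction
                                       M, N, K,
                                       stride_A, stride_B, stride_C, // per-matrix leading strides
                                       batch_stride_A,               // elements between consecutive A matrices
                                       batch_stride_B,
                                       batch_stride_C,
                                       batch_count);

// Convert host-side arguments into the device-side struct consumed by operator().
auto kargs = Kernel::MakeKernelArgs(host_args);

// Grid mapping: x = output tiles over M x N, y = batch index, z = split-K slice.
dim3 grid  = Kernel::GridSize(M, N, /*KBatch*/ 1, batch_count);
dim3 block = Kernel::BlockSize();

// The real launch goes through the ck_tile kernel-launch helpers used in the
// library's examples; conceptually it runs Kernel{}(kargs) on every workgroup
// of the grid computed above. With k_batch > 1 the epilogue uses atomic adds,
// so C must be zero-initialized beforehand.
launch_batched_gemm(grid, block, kargs); // hypothetical wrapper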