Composable Kernel: grouped_gemm_kernel.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"

namespace ck_tile {

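/// Host-side description of one GEMM problem in a grouped GEMM. It forwards to
/// GemmHostArgs with the split-K factor (KBatch) fixed to 1.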
struct GroupedGemmHostArgs : public GemmHostArgs
{
    CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* c_ptr_,
                                     ck_tile::index_t M_,
                                     ck_tile::index_t N_,
                                     ck_tile::index_t K_,
                                     ck_tile::index_t stride_A_,
                                     ck_tile::index_t stride_B_,
                                     ck_tile::index_t stride_C_)
        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
    {
    }

    private:
    static constexpr index_t KBatch = 1;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using TilePartitioner         = remove_cvref_t<TilePartitioner_>;
    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
    using GemmPipeline            = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline        = remove_cvref_t<EpiloguePipeline_>;

    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using GemmKernelArgs = typename Base::GemmKernelArgs;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

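    /// Kernel arguments for one group, together with the half-open range
    /// [block_start, block_end) of flattened 1D block ids that this group owns.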
    struct GemmTransKernelArg
    {
        GemmKernelArgs group_karg;
        ck_tile::index_t block_start;
        ck_tile::index_t block_end;

        GemmTransKernelArg() = default;
        GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
            : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
        {
        }
    };

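    // The device-side argument buffer holds one GemmTransKernelArg per group;
    // the caller allocates GetWorkSpaceSize(...) bytes and copies the vector
    // produced by MakeKargs(...) into it before launching.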
    __host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
        -> std::size_t
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg);
    }

    __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }

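    // The grid is flattened to 1D in x: each group contributes
    // TilePartitioner::GridSize(M, N) * k_batch consecutive blocks. E.g. two
    // groups with 8 and 6 tiles (k_batch = 1) yield a 14-block grid, with the
    // groups owning the ranges [0, 8) and [8, 14).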
    __host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }

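    // Builds the flattened per-group argument vector. Groups with any zero
    // dimension are skipped; every emitted group records the half-open range
    // of block ids assigned to it in the 1D grid.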
    CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
        -> std::vector<GemmTransKernelArg>
    {
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size   = 0;
        gemm_kernel_args_.reserve(group_count);

        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;

            if(M == 0 || N == 0 || K == 0)
            {
                continue;
            }

            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_c = gemm_descs[i].stride_C;

            const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

            const index_t block_start = grid_size;
            const index_t block_end   = grid_size + grid_size_grp;

            grid_size += grid_size_grp;

            auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
                                       type_convert<const BDataType*>(gemm_descs[i].b_ptr),
                                       type_convert<CDataType*>(gemm_descs[i].c_ptr),
                                       M,
                                       N,
                                       K,
                                       stride_a,
                                       stride_b,
                                       stride_c,
                                       gemm_descs[i].k_batch};

            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }

        return gemm_kernel_args_;
    }

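    // The pipeline and the epilogue reuse the same LDS buffer, so the kernel
    // only needs the larger of the two footprints.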
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

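    // Executes one group's GEMM: translates this block's flattened id into
    // (M, N) tile coordinates relative to the group's block_start, resolves
    // the split-K offsets, and hands off to the base-class RunGemm.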
    CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
    {
        const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
            kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);

        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
        CDataType* c_ptr       = static_cast<CDataType*>(kargs.group_karg.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        this->RunGemm(
            a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
    }

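    // Kernel entry point. The group ranges produced by MakeKargs are sorted
    // and contiguous, so a binary search over [block_start, block_end) finds
    // the group that owns this block id; contiguity guarantees the search
    // lands on the owning group before the interval degenerates.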
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   index_t group_count) const
    {
        const index_t block_id   = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));

        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = index_t((left + right) >> 1);

        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) >> 1);
        }

        Run(gemm_desc_ptr[group_id]);
    }
};

} // namespace ck_tile
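
A minimal host-side usage sketch follows. It is illustrative only: the Partitioner, Pipeline, and Epilogue types, the device pointers (a0, b0, c0, ...), the leading dimensions (lda0, ...), and the launch step are hypothetical stand-ins for whatever the application provides; only MakeKargs, GetWorkSpaceSize, GridSize, BlockSize, and GemmTransKernelArg come from this header.

// Sketch: describe two GEMM problems, stage their kernel arguments in device
// memory, and compute the launch geometry. Assumes HIP and hypothetical
// Partitioner/Pipeline/Epilogue types.
using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>;

std::vector<ck_tile::GroupedGemmHostArgs> descs;
descs.emplace_back(a0, b0, c0, M0, N0, K0, lda0, ldb0, ldc0); // group 0
descs.emplace_back(a1, b1, c1, M1, N1, K1, lda1, ldb1, ldc1); // group 1

// Build the per-group kernel arguments and copy them into a device buffer.
auto kargs      = Kernel::MakeKargs(descs);
void* workspace = nullptr;
hipMalloc(&workspace, Kernel::GetWorkSpaceSize(descs));
hipMemcpy(workspace,
          kargs.data(),
          kargs.size() * sizeof(Kernel::GemmTransKernelArg),
          hipMemcpyHostToDevice);

// One flattened 1D grid covers all groups; each block binary-searches for its group.
const dim3 grid  = Kernel::GridSize(descs);
const dim3 block = Kernel::BlockSize();
// Launch Kernel::operator()(workspace, kargs.size()) with <grid, block>
// through your framework's kernel-launch wrapper.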