Composable Kernel: include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp Source File

batched_contraction_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <ck_tile::index_t NumDTensor = 0>
struct BatchedContractionHostArgs
{
    BatchedContractionHostArgs(
        const void* a_ptr_,
        const void* b_ptr_,
        const std::array<const void*, NumDTensor>& ds_ptr_,
        void* e_ptr_,
        ck_tile::index_t k_batch_,
        const std::vector<ck_tile::index_t>& A_dims_, // [G0, G1, ..., M0, M1, ..., K0, K1, ...]
        const std::vector<ck_tile::index_t>& B_dims_, // [G0, G1, ..., N0, N1, ..., K0, K1, ...]
        const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
            Ds_dims_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...][NumDTensor]
        const std::vector<ck_tile::index_t>& E_dims_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...]

        const std::vector<ck_tile::index_t>& A_strides_, // [G0, G1, ..., M0, M1, ..., K0, K1, ...]
        const std::vector<ck_tile::index_t>& B_strides_, // [G0, G1, ..., N0, N1, ..., K0, K1, ...]
        const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
            Ds_strides_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...][NumDTensor]
        const std::vector<ck_tile::index_t>&
            E_strides_) // [G0, G1, ..., M0, M1, ..., N0, N1, ...]
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          k_batch(k_batch_),
          A_dims(A_dims_),
          B_dims(B_dims_),
          Ds_dims(Ds_dims_),
          E_dims(E_dims_),
          A_strides(A_strides_),
          B_strides(B_strides_),
          Ds_strides(Ds_strides_),
          E_strides(E_strides_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;
    const std::vector<ck_tile::index_t> A_dims;
    const std::vector<ck_tile::index_t> B_dims;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_dims;
    const std::vector<ck_tile::index_t> E_dims;
    const std::vector<ck_tile::index_t> A_strides;
    const std::vector<ck_tile::index_t> B_strides;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_strides;
    const std::vector<ck_tile::index_t> E_strides;
};
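
// Example (illustrative sketch, not part of this header's interface): building host arguments for
// a contraction with one G, one M, one N and one K dimension and no D tensors. The device pointers
// a_dev/b_dev/e_dev and all sizes below are hypothetical, and the strides assume packed layouts
// with K innermost for A/B and N innermost for E.
//
// \code
// constexpr ck_tile::index_t G = 4, M = 256, N = 128, K = 64;
// ck_tile::BatchedContractionHostArgs<0> host_args(
//     a_dev, b_dev, {}, e_dev,
//     /*k_batch=*/1,
//     /*A_dims=*/{G, M, K}, /*B_dims=*/{G, N, K}, /*Ds_dims=*/{}, /*E_dims=*/{G, M, N},
//     /*A_strides=*/{M * K, K, 1},
//     /*B_strides=*/{N * K, K, 1},
//     /*Ds_strides=*/{},
//     /*E_strides=*/{M * N, N, 1});
// \endcode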

template <ck_tile::index_t NumDimG,
          ck_tile::index_t NumDimM,
          ck_tile::index_t NumDimN,
          ck_tile::index_t NumDimK,
          ck_tile::index_t NumDTensor = 0>
struct BatchedContractionKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;

    ck_tile::index_t M_dims[NumDimM];
    ck_tile::index_t N_dims[NumDimN];
    ck_tile::index_t K_dims[NumDimK];
    ck_tile::index_t G_dims[NumDimG];

    // Batch strides for efficient offset calculation
    ck_tile::index_t batch_stride_A;
    ck_tile::index_t batch_stride_B;
    ck_tile::index_t batch_stride_E;
    std::array<ck_tile::index_t, NumDTensor> batch_stride_Ds;

    ck_tile::index_t G_total;
    ck_tile::index_t M_total;
    ck_tile::index_t N_total;
    ck_tile::index_t K_total;

    ck_tile::index_t stride_A;
    ck_tile::index_t stride_B;
    std::array<ck_tile::index_t, NumDTensor> stride_Ds;
    ck_tile::index_t stride_E;
};
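
// Note on the flattened view used by the kernel (illustrative, with hypothetical sizes): each
// *_total field is the product of its dimension group, so the underlying GEMM sees an ordinary
// M_total x N_total x K_total problem for each flattened batch index in [0, G_total).
//
// \code
// // G_dims = {2, 3}, M_dims = {16, 8}, N_dims = {32}, K_dims = {4, 4}   (hypothetical)
// // => G_total = 2 * 3 = 6, M_total = 16 * 8 = 128, N_total = 32, K_total = 4 * 4 = 16,
// //    i.e. 6 batched GEMMs of shape 128 x 32 x 16.
// \endcode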

template <typename Problem_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct BatchedContractionKernel
{
    // Type aliases for cleaner code and better readability
    using Problem    = ck_tile::remove_cvref_t<Problem_>;
    using ADataType  = ck_tile::remove_cvref_t<typename Problem::ADataType>;
    using BDataType  = ck_tile::remove_cvref_t<typename Problem::BDataType>;
    using DsDataType = ck_tile::remove_cvref_t<typename Problem::DsDataType>;
    using EDataType  = ck_tile::remove_cvref_t<typename Problem::EDataType>;

    // Compile-time dimension constants extracted from problem specification
    static constexpr ck_tile::index_t NumDimG    = Problem::NumDimG;
    static constexpr ck_tile::index_t NumDimM    = Problem::NumDimM;
    static constexpr ck_tile::index_t NumDimN    = Problem::NumDimN;
    static constexpr ck_tile::index_t NumDimK    = Problem::NumDimK;
    static constexpr ck_tile::index_t NumDTensor = Problem::NumDTensor;

    // Pipeline and partitioning strategy types
    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = ck_tile::remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;

    // Underlying GEMM kernel that performs the actual computation
    using UniversalGemmKernel =
        ck_tile::UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    static constexpr ck_tile::index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    using KernelArgs =
        BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>;

    CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }

    CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs)
    {
        typename UniversalGemmKernel::KernelArgs gemm_kargs{{kargs.a_ptr},
                                                            {kargs.b_ptr},
                                                            kargs.ds_ptr,
                                                            kargs.e_ptr,
                                                            kargs.M_total,
                                                            kargs.N_total,
                                                            kargs.K_total,
                                                            {kargs.stride_A},
                                                            {kargs.stride_B},
                                                            kargs.stride_Ds,
                                                            kargs.stride_E,
                                                            kargs.k_batch};

        return UniversalGemmKernel::IsSupportedArgument(gemm_kargs) && kargs.G_total > 0;
    }

    CK_TILE_HOST static constexpr ck_tile::index_t GetSmemSize()
    {
        return UniversalGemmKernel::GetSmemSize();
    }

    CK_TILE_HOST static constexpr auto GetBlockSize()
    {
        return dim3(UniversalGemmKernel::kBlockSize);
    }

    CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
    }
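
    // Launch-geometry note (illustrative, hypothetical sizes): GridSize() maps blockIdx.x to the
    // output (M, N) tile selected by the tile partitioner, blockIdx.y to the flattened batch index
    // in [0, G_total), and blockIdx.z to the split-K slice in [0, k_batch).
    //
    // \code
    // // M_total = 512, N_total = 256, 128 x 128 output tiles, G_total = 6, k_batch = 2
    // // => grid = dim3(4 * 2, 6, 2), with GetBlockSize() threads per block
    // \endcode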

    CK_TILE_HOST static constexpr KernelArgs
    MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
    {
        const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
        const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
        const auto expected_E_dims = NumDimG + NumDimM + NumDimN;

        if(host_args.A_dims.size() != expected_A_dims ||
           host_args.A_strides.size() != expected_A_dims)
        {
            throw std::invalid_argument("A dimension size mismatch");
        }
        if(host_args.B_dims.size() != expected_B_dims ||
           host_args.B_strides.size() != expected_B_dims)
        {
            throw std::invalid_argument("B dimension size mismatch");
        }
        if(host_args.E_dims.size() != expected_E_dims ||
           host_args.E_strides.size() != expected_E_dims)
        {
            throw std::invalid_argument("E dimension size mismatch");
        }

        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            if(host_args.Ds_dims[d].size() != expected_E_dims ||
               host_args.Ds_strides[d].size() != expected_E_dims)
            {
                throw std::invalid_argument("D dimension size mismatch");
            }
        }

        KernelArgs kargs;
        kargs.a_ptr   = host_args.a_ptr;
        kargs.b_ptr   = host_args.b_ptr;
        kargs.ds_ptr  = host_args.ds_ptr;
        kargs.e_ptr   = host_args.e_ptr;
        kargs.k_batch = host_args.k_batch;

        // Validate and set G dimensions (must be identical across all tensors)
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
        {
            // All tensors must have the same G dimensions for a valid contraction
            if(host_args.A_dims[i] != host_args.B_dims[i] ||
               host_args.A_dims[i] != host_args.E_dims[i])
            {
                throw std::invalid_argument(
                    "All tensors must have identical G dimensions for valid contraction");
            }

            // Store G dimensions (same for all tensors)
            kargs.G_dims[i] = host_args.A_dims[i];
        }

        // Set batch strides from the stride of the last (innermost) G dimension
        kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
        kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
        kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];

        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
            if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
            {
                throw std::invalid_argument("M dimension mismatch between A and E tensors");
            }
        }
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
            if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
            {
                throw std::invalid_argument("N dimension mismatch between B and E tensors");
            }
        }
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
            if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
            {
                throw std::invalid_argument("K dimension mismatch between A and B tensors");
            }
        }

        // Calculate total dimensions from individual dimension arrays
        kargs.G_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
        {
            kargs.G_total *= kargs.G_dims[i];
        }

        kargs.M_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_total *= kargs.M_dims[i];
        }

        kargs.N_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_total *= kargs.N_dims[i];
        }

        kargs.K_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_total *= kargs.K_dims[i];
        }

        kargs.stride_A = kargs.K_total;
        kargs.stride_B = kargs.K_total;
        kargs.stride_E = kargs.N_total;

        // Validate that D tensors have the same G dimensions and set their batch strides
        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            for(ck_tile::index_t i = 0; i < NumDimG; ++i)
            {
                if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
                {
                    throw std::invalid_argument(
                        "D tensor G dimensions must match A/B/E tensor G dimensions");
                }
            }
            // Set batch stride for D tensor
            kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
            kargs.stride_Ds[d]       = kargs.N_total; // D tensors have the same shape as E
        }

        return kargs;
    }
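
    // Worked example of the dimension splitting above (hypothetical values): with NumDimG = 2,
    // NumDimM = 2, NumDimN = 1, NumDimK = 1, A_dims = {G0, G1, M0, M1, K0} is split into
    // G_dims = {G0, G1}, M_dims = {M0, M1}, K_dims = {K0}. batch_stride_A is taken from
    // A_strides[NumDimG - 1], i.e. the stride of the innermost G dimension, so advancing the
    // flattened batch index by one moves the A pointer by one G1-slice. The defaults
    // stride_A = stride_B = K_total and stride_E = stride_Ds[d] = N_total assume packed layouts
    // with K innermost for A/B and N innermost for E/D.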

    CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const
    {
        const auto [iM, iN] =
            TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
        const ck_tile::index_t i_m =
            __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const ck_tile::index_t i_n =
            __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
        const auto i_splitk     = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // Calculate batch offsets for each tensor
        const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
        const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
        const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
        EDataType* e_ptr       = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;

        std::array<const void*, NumDTensor> ds_batch_ptr;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType           = typename std::tuple_element<i.value, DsDataType>::type;
            const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
            ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
        });

        typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
                                                            {b_ptr},
                                                            ds_batch_ptr,
                                                            e_ptr,
                                                            kargs.M_total,
                                                            kargs.N_total,
                                                            kargs.K_total,
                                                            {kargs.stride_A},
                                                            {kargs.stride_B},
                                                            kargs.stride_Ds,
                                                            kargs.stride_E,
                                                            kargs.k_batch};

        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
                                                                                  i_splitk);

        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
        __shared__ char smem_ptr[GetSmemSize()];

        UniversalGemmKernel::RunGemm({a_ptr_final},
                                     {b_ptr_final},
                                     ds_batch_ptr,
                                     e_ptr,
                                     smem_ptr,
                                     gemm_kargs,
                                     splitk_batch_offset,
                                     i_m,
                                     i_n);
    }
};

} // namespace ck_tile
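
// Example host-side usage (sketch): converting host arguments into kernel arguments and deriving
// the launch geometry. `Kernel` below stands for a concrete, fully specialized
// BatchedContractionKernel instantiation and is a placeholder; the actual device launch goes
// through the project's usual kernel-launch utilities and is omitted here.
//
// \code
// using Kernel = /* concrete BatchedContractionKernel<...> instantiation */;
// const auto kargs = Kernel::MakeKernelArgs(host_args); // throws std::invalid_argument on
//                                                       // inconsistent dims/strides
// if(Kernel::IsSupportedArguments(kargs))
// {
//     const dim3 grids  = Kernel::GridSize(kargs);
//     const dim3 blocks = Kernel::GetBlockSize();
//     // launch Kernel{} over (grids, blocks); each block invokes Kernel::operator()(kargs)
// }
// \endcode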