template <ck_tile::index_t NumDTensor = 0>
struct BatchedContractionHostArgs
{
    BatchedContractionHostArgs(const void* a_ptr_,
                               const void* b_ptr_,
                               const std::array<const void*, NumDTensor>& ds_ptr_,
                               void* e_ptr_,
                               const std::vector<ck_tile::index_t>& A_dims_,
                               const std::vector<ck_tile::index_t>& B_dims_,
                               const std::array<std::vector<ck_tile::index_t>, NumDTensor>& Ds_dims_,
                               const std::vector<ck_tile::index_t>& E_dims_,
                               const std::vector<ck_tile::index_t>& A_strides_,
                               const std::vector<ck_tile::index_t>& B_strides_,
                               const std::array<std::vector<ck_tile::index_t>, NumDTensor>& Ds_strides_,
                               const std::vector<ck_tile::index_t>& E_strides_,
                               ck_tile::index_t k_batch_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          k_batch(k_batch_),
          A_dims(A_dims_),
          B_dims(B_dims_),
          Ds_dims(Ds_dims_),
          E_dims(E_dims_),
          A_strides(A_strides_),
          B_strides(B_strides_),
          Ds_strides(Ds_strides_),
          E_strides(E_strides_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;

    // Dimension lists, ordered [G..., M..., K...] for A, [G..., N..., K...] for B and
    // [G..., M..., N...] for the D tensors and E.
    const std::vector<ck_tile::index_t> A_dims;
    const std::vector<ck_tile::index_t> B_dims;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_dims;
    const std::vector<ck_tile::index_t> E_dims;

    // Element strides matching the dimension ordering above.
    const std::vector<ck_tile::index_t> A_strides;
    const std::vector<ck_tile::index_t> B_strides;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_strides;
    const std::vector<ck_tile::index_t> E_strides;
};
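// Illustrative usage sketch, not part of the header: filling the host arguments for a
// kernel instantiated with NumDimG = 1, NumDimM = 2, NumDimN = 1, NumDimK = 1 and no
// D tensors, assuming packed row-major tensors. The function name and the concrete
// extents are made up for this example; the argument order follows the constructor above.
inline BatchedContractionHostArgs<0> make_example_host_args(const void* a_dev,
                                                            const void* b_dev,
                                                            void* e_dev)
{
    // A is [G0, M0, M1, K0], B is [G0, N0, K0], E is [G0, M0, M1, N0].
    const std::vector<ck_tile::index_t> A_dims{4, 8, 16, 32};
    const std::vector<ck_tile::index_t> B_dims{4, 64, 32};
    const std::vector<ck_tile::index_t> E_dims{4, 8, 16, 64};

    // Packed row-major strides: the innermost dimension is contiguous.
    const std::vector<ck_tile::index_t> A_strides{8 * 16 * 32, 16 * 32, 32, 1};
    const std::vector<ck_tile::index_t> B_strides{64 * 32, 32, 1};
    const std::vector<ck_tile::index_t> E_strides{8 * 16 * 64, 16 * 64, 64, 1};

    return BatchedContractionHostArgs<0>(a_dev, b_dev, /*ds_ptr_=*/{}, e_dev,
                                         A_dims, B_dims, /*Ds_dims_=*/{}, E_dims,
                                         A_strides, B_strides, /*Ds_strides_=*/{}, E_strides,
                                         /*k_batch_=*/1);
}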
template <ck_tile::index_t NumDimG,
          ck_tile::index_t NumDimM,
          ck_tile::index_t NumDimN,
          ck_tile::index_t NumDimK,
          ck_tile::index_t NumDTensor>
struct BatchedContractionKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;

    // Original per-group extents and the flattened extents handed to the GEMM.
    std::array<ck_tile::index_t, NumDimG> G_dims;
    std::array<ck_tile::index_t, NumDimM> M_dims;
    std::array<ck_tile::index_t, NumDimN> N_dims;
    std::array<ck_tile::index_t, NumDimK> K_dims;
    ck_tile::index_t G_total;
    ck_tile::index_t M_total;
    ck_tile::index_t N_total;
    ck_tile::index_t K_total;

    // Leading strides of the flattened 2D views and the distance between flattened batches.
    ck_tile::index_t stride_A;
    ck_tile::index_t stride_B;
    std::array<ck_tile::index_t, NumDTensor> stride_Ds;
    ck_tile::index_t stride_E;
    ck_tile::index_t batch_stride_A;
    ck_tile::index_t batch_stride_B;
    std::array<ck_tile::index_t, NumDTensor> batch_stride_Ds;
    ck_tile::index_t batch_stride_E;

    ck_tile::index_t k_batch;
};
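// For the example dimensions in make_example_host_args above (and a kernel instantiated
// with NumDimG = 1, NumDimM = 2, NumDimN = 1, NumDimK = 1), MakeKernelArgs below would
// flatten the shape into G_total = 4, M_total = 8 * 16 = 128, N_total = 64, K_total = 32,
// with stride_A = stride_B = 32, stride_E = 64 and batch_stride_A = 128 * 32 = 4096
// (the innermost G stride of A). These are the values operator() consumes per workgroup.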
template <typename Problem_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct BatchedContractionKernel
{
    using Problem          = remove_cvref_t<Problem_>;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    // The batched contraction is lowered onto the universal GEMM kernel.
    using UniversalGemmKernel =
        ck_tile::UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    // The contraction ranks (NumDimG/NumDimM/NumDimN/NumDimK), NumDTensor and the
    // ADataType/BDataType/DsDataType/EDataType element types used below come from the
    // problem description; their definitions are elided in this listing.

    using KernelArgs =
        BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>;
    CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }

    // Body elided in this listing; the checks are expected to defer to
    // UniversalGemmKernel::IsSupportedArgument on the flattened GEMM arguments.
    CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs);

    // 3D launch grid: x walks the output tiles, y the flattened G batches, z the split-K slices.
    CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
    }
    CK_TILE_HOST static auto MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
    {
        const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
        const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
        const auto expected_E_dims = NumDimG + NumDimM + NumDimN;

        if(host_args.A_dims.size() != expected_A_dims ||
           host_args.A_strides.size() != expected_A_dims)
        {
            throw std::invalid_argument("A dimension size mismatch");
        }
        if(host_args.B_dims.size() != expected_B_dims ||
           host_args.B_strides.size() != expected_B_dims)
        {
            throw std::invalid_argument("B dimension size mismatch");
        }
        if(host_args.E_dims.size() != expected_E_dims ||
           host_args.E_strides.size() != expected_E_dims)
        {
            throw std::invalid_argument("E dimension size mismatch");
        }
        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            if(host_args.Ds_dims[d].size() != expected_E_dims ||
               host_args.Ds_strides[d].size() != expected_E_dims)
            {
                throw std::invalid_argument("D dimension size mismatch");
            }
        }

        KernelArgs kargs;
        kargs.a_ptr   = host_args.a_ptr;
        kargs.b_ptr   = host_args.b_ptr;
        kargs.ds_ptr  = host_args.ds_ptr;
        kargs.e_ptr   = host_args.e_ptr;
        kargs.k_batch = host_args.k_batch;

        // G dimensions must agree across A, B and E.
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
        {
            if(host_args.A_dims[i] != host_args.B_dims[i] ||
               host_args.A_dims[i] != host_args.E_dims[i])
            {
                throw std::invalid_argument(
                    "All tensors must have identical G dimensions for valid contraction");
            }
            kargs.G_dims[i] = host_args.A_dims[i];
        }

        // The stride of the innermost G dimension is the distance between flattened batches.
        kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
        kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
        kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];

        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
            if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
            {
                throw std::invalid_argument("M dimension mismatch between A and E tensors");
            }
        }
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
            if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
            {
                throw std::invalid_argument("N dimension mismatch between B and E tensors");
            }
        }
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
            if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
            {
                throw std::invalid_argument("K dimension mismatch between A and B tensors");
            }
        }

        // Flatten the grouped dimensions into the 2D GEMM view.
        kargs.G_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
            kargs.G_total *= kargs.G_dims[i];
        kargs.M_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
            kargs.M_total *= kargs.M_dims[i];
        kargs.N_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
            kargs.N_total *= kargs.N_dims[i];
        kargs.K_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
            kargs.K_total *= kargs.K_dims[i];

        // Packed row-major flattened views: A is M_total x K_total, B is N_total x K_total,
        // E (and each D) is M_total x N_total.
        kargs.stride_A = kargs.K_total;
        kargs.stride_B = kargs.K_total;
        kargs.stride_E = kargs.N_total;

        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            for(ck_tile::index_t i = 0; i < NumDimG; ++i)
            {
                if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
                {
                    throw std::invalid_argument(
                        "D tensor G dimensions must match A/B/E tensor G dimensions");
                }
            }
            kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
            kargs.stride_Ds[d]       = kargs.N_total;
        }

        return kargs;
    }
    CK_TILE_DEVICE void operator()(KernelArgs kargs) const
    {
        // blockIdx.x selects an output tile, blockIdx.y a flattened G batch and
        // blockIdx.z a split-K slice (matching GridSize above).
        const auto [iM, iN] =
            TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
        const auto i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const auto i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
        const auto i_splitk     = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // Advance the base pointers to the current flattened batch.
        const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
        const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
        const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
        EDataType* e_ptr       = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;

        std::array<const void*, NumDTensor> ds_batch_ptr;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType = typename std::tuple_element<i.value, DsDataType>::type;
            const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
            ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
        });

        // Construction of gemm_kargs (the flattened UniversalGemmKernel::KernelArgs built
        // from kargs) is elided in this listing.

        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
                                                                                  i_splitk);

        // Apply the split-K offsets on top of the batch offsets.
        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];

        __shared__ char smem_ptr[GetSmemSize()];

        // The remainder of the body, which dispatches to UniversalGemmKernel::RunGemm with
        // {a_ptr_final}, {b_ptr_final}, ds_batch_ptr, e_ptr, smem_ptr, gemm_kargs,
        // splitk_batch_offset, i_m and i_n, is elided in this listing.
    }
};
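// Illustrative host-side sketch, not part of the header: the typical flow for combining the
// pieces above. `Kernel` stands for a concrete BatchedContractionKernel instantiation and
// `example_grid_for` is a name made up for this example.
template <typename Kernel, ck_tile::index_t NumD>
CK_TILE_HOST auto example_grid_for(const BatchedContractionHostArgs<NumD>& host_args)
{
    // MakeKernelArgs validates the multi-index shapes (throwing std::invalid_argument on
    // mismatch) and flattens them into the batched GEMM view.
    const auto kargs = Kernel::MakeKernelArgs(host_args);

    if(!Kernel::IsSupportedArguments(kargs))
    {
        throw std::runtime_error("unsupported batched contraction configuration");
    }

    // 3D grid: x = M_total/N_total output tiles, y = flattened G batches, z = split-K slices.
    return Kernel::GridSize(kargs);
}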