template <ck_tile::index_t NumDTensor = 0>
struct BatchedContractionHostArgs
{
    BatchedContractionHostArgs(const void* a_ptr_,
                               const void* b_ptr_,
                               const std::array<const void*, NumDTensor>& ds_ptr_,
                               void* e_ptr_,
                               ck_tile::index_t k_batch_,
                               const std::vector<ck_tile::index_t>& A_dims_,
                               const std::vector<ck_tile::index_t>& B_dims_,
                               const std::array<std::vector<ck_tile::index_t>, NumDTensor>& Ds_dims_,
                               const std::vector<ck_tile::index_t>& E_dims_,
                               const std::vector<ck_tile::index_t>& A_strides_,
                               const std::vector<ck_tile::index_t>& B_strides_,
                               const std::array<std::vector<ck_tile::index_t>, NumDTensor>& Ds_strides_,
                               const std::vector<ck_tile::index_t>& E_strides_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          k_batch(k_batch_),
          A_dims(A_dims_),
          B_dims(B_dims_),
          Ds_dims(Ds_dims_),
          E_dims(E_dims_),
          A_strides(A_strides_),
          B_strides(B_strides_),
          Ds_strides(Ds_strides_),
          E_strides(E_strides_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;

    const std::vector<ck_tile::index_t> A_dims;
    const std::vector<ck_tile::index_t> B_dims;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_dims;
    const std::vector<ck_tile::index_t> E_dims;
    const std::vector<ck_tile::index_t> A_strides;
    const std::vector<ck_tile::index_t> B_strides;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_strides;
    const std::vector<ck_tile::index_t> E_strides;
};
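// Illustrative usage sketch (not part of this header): filling the host arguments for a
// contraction with two G dimensions and one M, N and K dimension each, a single D tensor,
// packed row-major strides, and device pointers a_dev/b_dev/d_dev/e_dev allocated elsewhere.
// The constructor parameter order follows the reconstruction above; the helper name is made up.
inline BatchedContractionHostArgs<1> make_example_host_args(const void* a_dev,
                                                            const void* b_dev,
                                                            const void* d_dev,
                                                            void* e_dev)
{
    // A = [G0, G1, M0, K0], B = [G0, G1, N0, K0], D and E = [G0, G1, M0, N0]
    const std::vector<ck_tile::index_t> A_dims{4, 8, 256, 64};
    const std::vector<ck_tile::index_t> B_dims{4, 8, 128, 64};
    const std::vector<ck_tile::index_t> E_dims{4, 8, 256, 128};
    const std::vector<ck_tile::index_t> A_strides{8 * 256 * 64, 256 * 64, 64, 1};
    const std::vector<ck_tile::index_t> B_strides{8 * 128 * 64, 128 * 64, 64, 1};
    const std::vector<ck_tile::index_t> E_strides{8 * 256 * 128, 256 * 128, 128, 1};

    return BatchedContractionHostArgs<1>(a_dev,
                                         b_dev,
                                         {d_dev},
                                         e_dev,
                                         /*k_batch=*/1,
                                         A_dims,
                                         B_dims,
                                         {E_dims}, // the single D tensor shares E's shape here
                                         E_dims,
                                         A_strides,
                                         B_strides,
                                         {E_strides},
                                         E_strides);
}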
template <ck_tile::index_t NumDimG, ck_tile::index_t NumDimM, ck_tile::index_t NumDimN,
          ck_tile::index_t NumDimK, ck_tile::index_t NumDTensor, ck_tile::index_t VectorSizeA,
          ck_tile::index_t VectorSizeB, ck_tile::index_t VectorSizeE>
struct BatchedContractionKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;

    /* The remaining scalar members (k_batch, the per-axis G/M/N/K dims, the flattened
       G/M/N/K totals, stride_A/B/E and batch_stride_A/B/E) are elided in this listing;
       MakeKernelArgs below fills all of them. */
    std::array<ck_tile::index_t, NumDTensor> batch_stride_Ds;
    std::array<ck_tile::index_t, NumDTensor> stride_Ds;

    using AGridDesc_M_K_ =
        decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK, VectorSizeA,
                                       VectorSizeB, VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
    using BGridDesc_N_K_ =
        decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK, VectorSizeA,
                                       VectorSizeB, VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
    using EGridDesc_M_N_ =
        decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK, VectorSizeA,
                                       VectorSizeB, VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));

    AGridDesc_M_K_ a_grid_desc_m_k;
    BGridDesc_N_K_ b_grid_desc_n_k;
    EGridDesc_M_N_ e_grid_desc_m_n;
    std::array<EGridDesc_M_N_, NumDTensor> ds_grid_desc_m_n;
};
template <typename Problem_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct BatchedContractionKernel
{
    /* Problem_-derived aliases and compile-time sizes (ADataType, BDataType, DsDataType,
       EDataType, NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor, ...) are elided in this
       listing. */
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using UniversalGemmKernel =
        /* ... the underlying ck_tile universal GEMM kernel instantiation, elided in this listing ... */;

    using DescriptorUtils = TensorDescriptorUtils<NumDimG,
                                                  NumDimM,
                                                  NumDimN,
                                                  NumDimK,
                                                  GemmPipeline::GetVectorSizeA(),
                                                  GemmPipeline::GetVectorSizeB(),
                                                  EpiloguePipeline::GetVectorSizeC()>;

    using KernelArgs = BatchedContractionKernelArgs<NumDimG,
                                                    NumDimM,
                                                    NumDimN,
                                                    NumDimK,
                                                    NumDTensor,
                                                    GemmPipeline::GetVectorSizeA(),
                                                    GemmPipeline::GetVectorSizeB(),
                                                    EpiloguePipeline::GetVectorSizeC()>;
    CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }
    CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs)
    {
        /* ... argument validity checks, elided in this listing ... */
    }
    CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
    }
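    // Grid mapping used by this kernel (see operator() below): blockIdx.x walks the M x N output
    // tiles produced by the TilePartitioner, blockIdx.y is the flattened batch index covering all
    // G dimensions (G_total entries), and blockIdx.z selects the split-K slice (k_batch entries).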
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       EDataType* e_ptr, void* smem_ptr, const KernelArgs& kargs,
                                       index_t k_size, index_t i_m, index_t i_n)
    {
        // Global-memory views over A, B and E built from the precomputed grid descriptors,
        // padded up to whole block tiles where the pipeline requires it.
        const auto a_view =
            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
        const auto b_view =
            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
        auto e_view =
            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
        const auto a_pad_view = pad_tensor_view(a_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, GemmPipeline::kPadK>{});
        const auto b_pad_view = pad_tensor_view(b_view,
            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, GemmPipeline::kPadK>{});
        auto e_pad_view = pad_tensor_view(e_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            sequence<false, GemmPipeline::kPadN>{});
        // Per-block tile windows at this block's (i_m, i_n) tile origin.
        auto a_block_window = make_tile_window(a_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_m, 0});
        auto b_block_window = make_tile_window(b_pad_view,
            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_n, 0});
        auto e_block_window = make_tile_window(e_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});
        const index_t num_loop =
            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
        using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
        using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
        auto c_block_tile = GemmPipeline{}(
            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
        // Tile windows over the optional D tensors, then the epilogue writes E.
        const auto ds_block_windows = generate_tuple(
            [&](auto i) {
                using DDataType = typename std::tuple_element<i.value, DsDataType>::type;
                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
                const auto d_view =
                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
                return make_tile_window(d_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {i_m, i_n});
            },
            number<NumDTensor>{});
        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
    }
    CK_TILE_HOST static KernelArgs
    MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
    {
        const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
        const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
        const auto expected_E_dims = NumDimG + NumDimM + NumDimN;

        if(host_args.A_dims.size() != expected_A_dims ||
           host_args.A_strides.size() != expected_A_dims)
        {
            throw std::invalid_argument("A dimension size mismatch");
        }
        if(host_args.B_dims.size() != expected_B_dims ||
           host_args.B_strides.size() != expected_B_dims)
        {
            throw std::invalid_argument("B dimension size mismatch");
        }
        if(host_args.E_dims.size() != expected_E_dims ||
           host_args.E_strides.size() != expected_E_dims)
        {
            throw std::invalid_argument("E dimension size mismatch");
        }
        for(index_t d = 0; d < NumDTensor; ++d)
        {
            if(host_args.Ds_dims[d].size() != expected_E_dims ||
               host_args.Ds_strides[d].size() != expected_E_dims)
            {
                throw std::invalid_argument("D dimension size mismatch");
            }
        }

        KernelArgs kargs;
        kargs.a_ptr   = host_args.a_ptr;
        kargs.b_ptr   = host_args.b_ptr;
        kargs.ds_ptr  = host_args.ds_ptr;
        kargs.e_ptr   = host_args.e_ptr;
        kargs.k_batch = host_args.k_batch;

        // All tensors must share the leading G dimensions.
        for(index_t i = 0; i < NumDimG; ++i)
        {
            if(host_args.A_dims[i] != host_args.B_dims[i] ||
               host_args.A_dims[i] != host_args.E_dims[i])
            {
                throw std::invalid_argument(
                    "All tensors must have identical G dimensions for valid contraction");
            }
            kargs.G_dims[i] = host_args.A_dims[i];
        }

        // The per-tensor batch stride is the stride of the innermost G dimension.
        kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
        kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
        kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];

        for(index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
            if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
            {
                throw std::invalid_argument("M dimension mismatch between A and E tensors");
            }
        }
        for(index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
            if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
            {
                throw std::invalid_argument("N dimension mismatch between B and E tensors");
            }
        }
        for(index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
            if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
            {
                throw std::invalid_argument("K dimension mismatch between A and B tensors");
            }
        }

        // Flatten the multi-dimensional G/M/N/K extents into single sizes.
        kargs.G_total = 1;
        kargs.M_total = 1;
        kargs.N_total = 1;
        kargs.K_total = 1;
        for(index_t i = 0; i < NumDimG; ++i) { kargs.G_total *= kargs.G_dims[i]; }
        for(index_t i = 0; i < NumDimM; ++i) { kargs.M_total *= kargs.M_dims[i]; }
        for(index_t i = 0; i < NumDimN; ++i) { kargs.N_total *= kargs.N_dims[i]; }
        for(index_t i = 0; i < NumDimK; ++i) { kargs.K_total *= kargs.K_dims[i]; }

        kargs.a_grid_desc_m_k =
            DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
        kargs.b_grid_desc_n_k =
            DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
        kargs.e_grid_desc_m_n =
            DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
        for(index_t d = 0; d < NumDTensor; ++d)
        {
            kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
                host_args.Ds_dims[d], host_args.Ds_strides[d]);
        }

        // Flattened leading strides seen by the underlying GEMM.
        kargs.stride_A = kargs.K_total;
        kargs.stride_B = kargs.K_total;
        kargs.stride_E = kargs.N_total;

        for(index_t d = 0; d < NumDTensor; ++d)
        {
            for(index_t i = 0; i < NumDimG; ++i)
            {
                if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
                {
                    throw std::invalid_argument(
                        "D tensor G dimensions must match A/B/E tensor G dimensions");
                }
            }
            kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
            kargs.stride_Ds[d]       = kargs.N_total;
        }

        return kargs;
    }
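    // Worked example (illustrative numbers, not from the original header): with NumDimG = 2,
    // NumDimM = 2, NumDimN = 2 and NumDimK = 1, the host passes
    //   A_dims = {G0, G1, M0, M1, K0}, B_dims = {G0, G1, N0, N1, K0},
    //   E_dims = {G0, G1, M0, M1, N0, N1},
    // and MakeKernelArgs flattens them to G_total = G0 * G1, M_total = M0 * M1,
    // N_total = N0 * N1 and K_total = K0, taking each tensor's batch stride from the stride of
    // its innermost G dimension (index NumDimG - 1).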
    CK_TILE_DEVICE void operator()(KernelArgs kargs) const
    {
        const auto [iM, iN] =
            TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
        [[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // Offset each tensor to the current batch.
        const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
        const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
        const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
        EDataType* e_ptr       = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;

        std::array<const void*, NumDTensor> ds_batch_ptr;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType = typename std::tuple_element<i.value, DsDataType>::type;
            const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
            ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
        });

        __shared__ char smem_ptr[GetSmemSize()];

        /* A flattened universal-GEMM argument struct (gemm_kargs) is assembled here from the
           M/N/K totals, strides and k_batch; its construction is elided in this listing. */
        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
                                                                                  blockIdx.z);

        // Apply the split-K offsets to A and B, then run the per-tile GEMM and epilogue.
        const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];

        RunGemm(a_ptr_split,
                b_ptr_split,
                ds_batch_ptr,
                e_ptr,
                smem_ptr,
                kargs,
                splitk_batch_offset.splitted_k,
                i_m,
                i_n);
    }
}; // struct BatchedContractionKernel
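// Illustrative driver sketch (not part of this header): converting host arguments into kernel
// arguments and sizing the launch grid for a concrete instantiation `Kernel` of
// BatchedContractionKernel. `host_args` can be built as in the example following
// BatchedContractionHostArgs above; the launch itself is omitted and the helper name is made up.
template <typename Kernel, ck_tile::index_t NumDTensor>
bool check_and_size_grid(const BatchedContractionHostArgs<NumDTensor>& host_args)
{
    const auto kargs = Kernel::MakeKernelArgs(host_args); // validates dims/strides, may throw
    if(!Kernel::IsSupportedArguments(kargs))
    {
        return false;
    }
    [[maybe_unused]] const auto grid = Kernel::GridSize(kargs); // (M x N tiles, G_total, k_batch)
    // ... launch the kernel with `grid`, its block size and shared-memory size (omitted) ...
    return true;
}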