Composable Kernel: include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp Source File

batched_contraction_kernel.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <ck_tile::index_t NumDTensor = 0>
struct BatchedContractionHostArgs
{
    BatchedContractionHostArgs(
        const void* a_ptr_,
        const void* b_ptr_,
        const std::array<const void*, NumDTensor>& ds_ptr_,
        void* e_ptr_,
        ck_tile::index_t k_batch_,
        const std::vector<ck_tile::index_t>& A_dims_, // [G0, G1, ..., M0, M1, ..., K0, K1, ...]
        const std::vector<ck_tile::index_t>& B_dims_, // [G0, G1, ..., N0, N1, ..., K0, K1, ...]
        const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
            Ds_dims_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...] for each of the NumDTensor D tensors
        const std::vector<ck_tile::index_t>& E_dims_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...]

        const std::vector<ck_tile::index_t>& A_strides_, // [G0, G1, ..., M0, M1, ..., K0, K1, ...]
        const std::vector<ck_tile::index_t>& B_strides_, // [G0, G1, ..., N0, N1, ..., K0, K1, ...]
        const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
            Ds_strides_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...] for each of the NumDTensor D tensors
        const std::vector<ck_tile::index_t>&
            E_strides_) // [G0, G1, ..., M0, M1, ..., N0, N1, ...]

        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          k_batch(k_batch_),
          A_dims(A_dims_),
          B_dims(B_dims_),
          Ds_dims(Ds_dims_),
          E_dims(E_dims_),
          A_strides(A_strides_),
          B_strides(B_strides_),
          Ds_strides(Ds_strides_),
          E_strides(E_strides_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;

    const std::vector<ck_tile::index_t> A_dims;
    const std::vector<ck_tile::index_t> B_dims;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_dims;
    const std::vector<ck_tile::index_t> E_dims;

    const std::vector<ck_tile::index_t> A_strides;
    const std::vector<ck_tile::index_t> B_strides;
    const std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_strides;
    const std::vector<ck_tile::index_t> E_strides;
};

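// Illustrative sketch (not part of the original header): how the host-argument layout
// above can be filled in for the simplest case NumDimG = NumDimM = NumDimN = NumDimK = 1
// with no D tensors, i.e. a plain batched GEMM E[g, m, n] = sum_k A[g, m, k] * B[g, n, k].
// The shapes and the packed (row-major-like) strides chosen here are assumptions made for
// the example; the device pointers are supplied by the caller.
inline BatchedContractionHostArgs<0> make_example_batched_gemm_host_args(const void* a_dev,
                                                                         const void* b_dev,
                                                                         void* e_dev)
{
    constexpr ck_tile::index_t G = 4, M = 256, N = 128, K = 64;

    // dims/strides follow the [G..., M..., K...] / [G..., N..., K...] / [G..., M..., N...] ordering
    const std::vector<ck_tile::index_t> A_dims{G, M, K}, A_strides{M * K, K, 1};
    const std::vector<ck_tile::index_t> B_dims{G, N, K}, B_strides{N * K, K, 1};
    const std::vector<ck_tile::index_t> E_dims{G, M, N}, E_strides{M * N, N, 1};

    return BatchedContractionHostArgs<0>(a_dev,
                                         b_dev,
                                         {}, // no D tensor pointers
                                         e_dev,
                                         /*k_batch=*/1,
                                         A_dims,
                                         B_dims,
                                         {}, // no D tensor dims
                                         E_dims,
                                         A_strides,
                                         B_strides,
                                         {}, // no D tensor strides
                                         E_strides);
}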

template <ck_tile::index_t NumDimG,
          ck_tile::index_t NumDimM,
          ck_tile::index_t NumDimN,
          ck_tile::index_t NumDimK,
          ck_tile::index_t NumDTensor  = 0,
          ck_tile::index_t VectorSizeA = 1,
          ck_tile::index_t VectorSizeB = 1,
          ck_tile::index_t VectorSizeE = 1>
struct BatchedContractionKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    ck_tile::index_t k_batch;

    ck_tile::index_t M_dims[NumDimM];
    ck_tile::index_t N_dims[NumDimN];
    ck_tile::index_t K_dims[NumDimK];
    ck_tile::index_t G_dims[NumDimG];

    // Batch strides for efficient offset calculation
    ck_tile::index_t batch_stride_A;
    ck_tile::index_t batch_stride_B;
    ck_tile::index_t batch_stride_E;
    std::array<ck_tile::index_t, NumDTensor> batch_stride_Ds;

    ck_tile::index_t G_total;
    ck_tile::index_t M_total;
    ck_tile::index_t N_total;
    ck_tile::index_t K_total;

    ck_tile::index_t stride_A;
    ck_tile::index_t stride_B;
    std::array<ck_tile::index_t, NumDTensor> stride_Ds;
    ck_tile::index_t stride_E;

    // Tensor descriptors (encode full multi-dimensional stride information with vectorization)
    using AGridDesc_M_K_ =
        decltype(TensorDescriptorUtils<NumDimG,
                                       NumDimM,
                                       NumDimN,
                                       NumDimK,
                                       VectorSizeA,
                                       VectorSizeB,
                                       VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
    using BGridDesc_N_K_ =
        decltype(TensorDescriptorUtils<NumDimG,
                                       NumDimM,
                                       NumDimN,
                                       NumDimK,
                                       VectorSizeA,
                                       VectorSizeB,
                                       VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
    using EGridDesc_M_N_ =
        decltype(TensorDescriptorUtils<NumDimG,
                                       NumDimM,
                                       NumDimN,
                                       NumDimK,
                                       VectorSizeA,
                                       VectorSizeB,
                                       VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));

    AGridDesc_M_K_ a_grid_desc_m_k;
    BGridDesc_N_K_ b_grid_desc_n_k;
    EGridDesc_M_N_ e_grid_desc_m_n;
    std::array<EGridDesc_M_N_, NumDTensor> ds_grid_desc_m_n;
};

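// Illustrative sketch (not part of the original header): the batch_stride_* fields above
// are taken from the stride of the innermost G dimension, and the kernel later computes a
// tensor's batch offset as blockIdx.y * batch_stride. That is equivalent to the following
// row-major flattening of the G index space, assuming the G dimensions are packed relative
// to each other, which is what using a single batch stride per tensor implies.
template <ck_tile::index_t NumDimG_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t
example_flat_batch_offset(const ck_tile::index_t (&g_idx)[NumDimG_],
                          const ck_tile::index_t (&g_dims)[NumDimG_],
                          ck_tile::index_t batch_stride)
{
    ck_tile::index_t flat = 0;
    for(ck_tile::index_t i = 0; i < NumDimG_; ++i)
    {
        flat = flat * g_dims[i] + g_idx[i]; // row-major flatten of [G0, G1, ...]
    }
    return flat * batch_stride; // matches i_batch_flat * batch_stride_X in operator()
}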

template <typename Problem_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct BatchedContractionKernel
{
    // Type aliases for cleaner code and better readability
    using Problem    = ck_tile::remove_cvref_t<Problem_>;
    using ADataType  = remove_cvref_t<typename Problem::ADataType>;
    using BDataType  = remove_cvref_t<typename Problem::BDataType>;
    using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
    using EDataType  = remove_cvref_t<typename Problem::EDataType>;

    // Compile-time dimension constants extracted from problem specification
    static constexpr ck_tile::index_t NumDimG    = Problem::NumDimG;
    static constexpr ck_tile::index_t NumDimM    = Problem::NumDimM;
    static constexpr ck_tile::index_t NumDimN    = Problem::NumDimN;
    static constexpr ck_tile::index_t NumDimK    = Problem::NumDimK;
    static constexpr ck_tile::index_t NumDTensor = Problem::NumDTensor;

    // Pipeline and partitioning strategy types
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    // Underlying GEMM kernel that performs the actual computation
    using UniversalGemmKernel =
        ck_tile::UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    static constexpr ck_tile::index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    // Tensor descriptor utilities with vectorization support
    using DescriptorUtils = TensorDescriptorUtils<NumDimG,
                                                  NumDimM,
                                                  NumDimN,
                                                  NumDimK,
                                                  GemmPipeline::GetVectorSizeA(),
                                                  GemmPipeline::GetVectorSizeB(),
                                                  EpiloguePipeline::GetVectorSizeC()>;

    // Kernel arguments with vectorization support
    using KernelArgs = BatchedContractionKernelArgs<NumDimG,
                                                    NumDimM,
                                                    NumDimN,
                                                    NumDimK,
                                                    NumDTensor,
                                                    GemmPipeline::GetVectorSizeA(),
                                                    GemmPipeline::GetVectorSizeB(),
                                                    EpiloguePipeline::GetVectorSizeC()>;

    CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }

    CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs)
    {
        typename UniversalGemmKernel::KernelArgs gemm_kargs{{kargs.a_ptr},
                                                            {kargs.b_ptr},
                                                            kargs.ds_ptr,
                                                            kargs.e_ptr,
                                                            kargs.M_total,
                                                            kargs.N_total,
                                                            kargs.K_total,
                                                            {kargs.stride_A},
                                                            {kargs.stride_B},
                                                            kargs.stride_Ds,
                                                            kargs.stride_E,
                                                            kargs.k_batch};

        return UniversalGemmKernel::IsSupportedArgument(gemm_kargs) && kargs.G_total > 0;
    }

    CK_TILE_HOST static constexpr ck_tile::index_t GetSmemSize()
    {
        return UniversalGemmKernel::GetSmemSize();
    }

    CK_TILE_HOST static constexpr auto GetBlockSize()
    {
        return dim3(UniversalGemmKernel::kBlockSize);
    }

    CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
    {
        return dim3(
            TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
    }

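    // Note on the launch geometry (explanatory comment, derived from GridSize() above and
    // operator() below): blockIdx.x selects the output M/N tile via the TilePartitioner,
    // blockIdx.y selects the flattened batch index in [0, G_total), and blockIdx.z selects
    // the split-K slice in [0, k_batch).
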
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       EDataType* e_ptr,
                                       void* smem_ptr,
                                       const KernelArgs& kargs,
                                       const index_t k_size,
                                       const index_t i_m,
                                       const index_t i_n)
    {
        // Create tensor views from descriptors (supports arbitrary stride patterns)
        auto a_tensor_view =
            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
        auto b_tensor_view =
            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
        auto e_tensor_view =
            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);

        // Pad views for boundary handling and optimization (like UniversalGemmKernel)
        auto a_pad_view = pad_tensor_view(
            a_tensor_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, GemmPipeline::kPadK>{});

        auto b_pad_view = pad_tensor_view(
            b_tensor_view,
            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, GemmPipeline::kPadK>{});

        auto e_pad_view = pad_tensor_view(
            e_tensor_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            sequence<false, GemmPipeline::kPadN>{});

        // Create tile windows from PADDED views
        auto a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_m, 0});

        auto b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_n, 0});

        auto e_block_window = make_tile_window(
            e_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        // Calculate number of K loops
        const index_t num_loop =
            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));

        // Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
        using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
        using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;

        const auto& c_block_tile = GemmPipeline{}(
            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);

        // Create D windows from descriptors (for each D tensor)
        auto ds_block_windows = generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);

                auto d_tensor_view =
                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);

                return make_tile_window(d_tensor_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {i_m, i_n});
            },
            number<NumDTensor>{});

        // Run Epilogue Pipeline with descriptor-based D windows
        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
    }

    CK_TILE_HOST static constexpr KernelArgs
    MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
    {
        const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
        const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
        const auto expected_E_dims = NumDimG + NumDimM + NumDimN;

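        // For example (illustrative): with NumDimG = 2, NumDimM = 2, NumDimN = 2 and
        // NumDimK = 2, A_dims/A_strides and B_dims/B_strides must each hold 6 entries,
        // ordered [G0, G1, M0, M1, K0, K1] and [G0, G1, N0, N1, K0, K1] respectively,
        // while E_dims/E_strides hold 6 entries ordered [G0, G1, M0, M1, N0, N1].
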
        if(host_args.A_dims.size() != expected_A_dims ||
           host_args.A_strides.size() != expected_A_dims)
        {
            throw std::invalid_argument("A dimension size mismatch");
        }
        if(host_args.B_dims.size() != expected_B_dims ||
           host_args.B_strides.size() != expected_B_dims)
        {
            throw std::invalid_argument("B dimension size mismatch");
        }
        if(host_args.E_dims.size() != expected_E_dims ||
           host_args.E_strides.size() != expected_E_dims)
        {
            throw std::invalid_argument("E dimension size mismatch");
        }

        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            if(host_args.Ds_dims[d].size() != expected_E_dims ||
               host_args.Ds_strides[d].size() != expected_E_dims)
            {
                throw std::invalid_argument("D dimension size mismatch");
            }
        }

        KernelArgs kargs;
        kargs.a_ptr   = host_args.a_ptr;
        kargs.b_ptr   = host_args.b_ptr;
        kargs.ds_ptr  = host_args.ds_ptr;
        kargs.e_ptr   = host_args.e_ptr;
        kargs.k_batch = host_args.k_batch;

        // Validate and set G dimensions (must be identical across all tensors)
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
        {
            // All tensors must have the same G dimensions for a valid contraction
            if(host_args.A_dims[i] != host_args.B_dims[i] ||
               host_args.A_dims[i] != host_args.E_dims[i])
            {
                throw std::invalid_argument(
                    "All tensors must have identical G dimensions for valid contraction");
            }

            // Store G dimensions (same for all tensors)
            kargs.G_dims[i] = host_args.A_dims[i];
        }

        // Set batch strides from the stride of the last (innermost) G dimension
        kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
        kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
        kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];

        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
            if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
            {
                throw std::invalid_argument("M dimension mismatch between A and E tensors");
            }
        }
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
            if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
            {
                throw std::invalid_argument("N dimension mismatch between B and E tensors");
            }
        }
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
            if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
            {
                throw std::invalid_argument("K dimension mismatch between A and B tensors");
            }
        }

        // Calculate total (flattened) dimensions from the individual dimension arrays
        kargs.G_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimG; ++i)
        {
            kargs.G_total *= kargs.G_dims[i];
        }

        kargs.M_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimM; ++i)
        {
            kargs.M_total *= kargs.M_dims[i];
        }

        kargs.N_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimN; ++i)
        {
            kargs.N_total *= kargs.N_dims[i];
        }

        kargs.K_total = 1;
        for(ck_tile::index_t i = 0; i < NumDimK; ++i)
        {
            kargs.K_total *= kargs.K_dims[i];
        }

        // Create tensor descriptors on the host using the actual dims and strides
        kargs.a_grid_desc_m_k =
            DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
        kargs.b_grid_desc_n_k =
            DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
        kargs.e_grid_desc_m_n =
            DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);

        // Create D descriptors with their own strides (same shape as E, independent strides)
        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
                host_args.Ds_dims[d], host_args.Ds_strides[d]);
        }

        // Keep simple (packed) strides for backward compatibility
        kargs.stride_A = kargs.K_total;
        kargs.stride_B = kargs.K_total;
        kargs.stride_E = kargs.N_total;

        // Validate that D tensors have the same G dimensions and set their batch strides
        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
        {
            for(ck_tile::index_t i = 0; i < NumDimG; ++i)
            {
                if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
                {
                    throw std::invalid_argument(
                        "D tensor G dimensions must match A/B/E tensor G dimensions");
                }
            }
            // Set batch stride for the D tensor
            kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
            kargs.stride_Ds[d]       = kargs.N_total; // D tensors have the same shape as E
        }

        return kargs;
    }

    CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const
    {
        const auto [iM, iN] =
            TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
        const ck_tile::index_t i_m =
            __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const ck_tile::index_t i_n =
            __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const auto i_batch_flat              = __builtin_amdgcn_readfirstlane(blockIdx.y);
        [[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // Calculate batch offsets for each tensor
        const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
        const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
        const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
        EDataType* e_ptr       = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;

        std::array<const void*, NumDTensor> ds_batch_ptr;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType           = typename std::tuple_element<i.value, DsDataType>::type;
            const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
            ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
        });

        // Allocate shared memory
        __shared__ char smem_ptr[GetSmemSize()];

        // Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
        typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
                                                            {b_ptr},
                                                            ds_batch_ptr,
                                                            e_ptr,
                                                            kargs.M_total,
                                                            kargs.N_total,
                                                            kargs.K_total,
                                                            {kargs.stride_A},
                                                            {kargs.stride_B},
                                                            kargs.stride_Ds,
                                                            kargs.stride_E,
                                                            kargs.k_batch};

        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
                                                                                  i_splitk);

        // Apply K-split offsets and run descriptor-based RunGemm
        const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];

        RunGemm(a_ptr_split,
                b_ptr_split,
                ds_batch_ptr,
                e_ptr,
                smem_ptr,
                kargs,
                splitk_batch_offset.splitted_k,
                i_m,
                i_n);
    }
};

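// Illustrative sketch (not part of the original header): a minimal host-side helper that
// turns host arguments into kernel arguments and checks them before launch. "Kernel" is
// assumed to be a fully specialized BatchedContractionKernel; the actual grid launch
// (using GridSize()/GetBlockSize()) is omitted here.
template <typename Kernel>
CK_TILE_HOST auto make_and_validate_batched_contraction_args(
    const BatchedContractionHostArgs<Kernel::NumDTensor>& host_args)
{
    // MakeKernelArgs throws std::invalid_argument on rank/shape mismatches
    auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArguments(kargs))
    {
        throw std::invalid_argument("batched contraction arguments are not supported");
    }
    return kargs;
}
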
} // namespace ck_tile