Composable Kernel: include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp Source File

universal_gemm_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
30 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
31 struct UniversalGemmHostArgs
32 {
33  CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
34  const std::array<const void*, NumBTensor>& bs_ptr_,
35  const std::array<const void*, NumDTensor>& ds_ptr_,
36  void* e_ptr_,
37  index_t k_batch_,
38  index_t M_,
39  index_t N_,
40  index_t K_,
41  const std::array<index_t, NumATensor>& stride_As_,
42  const std::array<index_t, NumBTensor>& stride_Bs_,
43  const std::array<index_t, NumDTensor>& stride_Ds_,
44  index_t stride_E_)
45  : as_ptr(as_ptr_),
46  bs_ptr(bs_ptr_),
47  ds_ptr(ds_ptr_),
48  e_ptr(e_ptr_),
49  M(M_),
50  N(N_),
51  K(K_),
52  stride_As(stride_As_),
53  stride_Bs(stride_Bs_),
54  stride_Ds(stride_Ds_),
55  stride_E(stride_E_),
56  k_batch(k_batch_)
57  {
58  }
59 
60  const std::array<const void*, NumATensor> as_ptr;
61  const std::array<const void*, NumBTensor> bs_ptr;
62  const std::array<const void*, NumDTensor> ds_ptr;
63  union
64  {
65  void* e_ptr;
66  void* c_ptr;
67  };
68  index_t M;
69  index_t N;
70  index_t K;
71  const std::array<index_t, NumATensor> stride_As;
72  const std::array<index_t, NumBTensor> stride_Bs;
73  const std::array<index_t, NumDTensor> stride_Ds;
74  union
75  {
76  index_t stride_E;
77  index_t stride_C;
78  };
79 
80  index_t k_batch;
81 };
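
Example (not part of the header): a minimal sketch of filling UniversalGemmHostArgs for a single-A, single-B GEMM with no D tensors. The device pointers, problem sizes, and strides below are placeholders chosen for illustration (packed row-major A and E, packed column-major B); real values come from the caller's allocations and layouts.

const void* a_dev = /* device buffer for A */ nullptr;
const void* b_dev = /* device buffer for B */ nullptr;
void*       e_dev = /* device buffer for E */ nullptr;

// M = 1024, N = 1024, K = 4096, k_batch = 1 (no split-K).
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({a_dev},            // as_ptr
                                                  {b_dev},            // bs_ptr
                                                  {},                 // ds_ptr (no D tensors)
                                                  e_dev,              // e_ptr
                                                  /*k_batch_=*/1,
                                                  /*M_=*/1024,
                                                  /*N_=*/1024,
                                                  /*K_=*/4096,
                                                  {/*stride_A=*/4096},
                                                  {/*stride_B=*/4096},
                                                  {},                 // stride_Ds
                                                  /*stride_E_=*/1024);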
82 
84 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
85 struct UniversalGemmKernelArgs
86 {
88  const std::array<const void*, NumATensor> as_ptr;
90  const std::array<const void*, NumBTensor> bs_ptr;
92  const std::array<const void*, NumDTensor> ds_ptr;
94  void* e_ptr;
96  index_t M;
98  index_t N;
100  index_t K;
103  std::array<index_t, NumATensor> stride_As;
106  std::array<index_t, NumBTensor> stride_Bs;
109  std::array<index_t, NumDTensor> stride_Ds;
112  index_t stride_E;
113  index_t k_batch;
114 };
115 
152 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
153 struct UniversalGemmKernel
154 {
155  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
156  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
157  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
158 
159  static constexpr bool ADataTypeIsTuple =
161  static constexpr bool BDataTypeIsTuple =
163  static constexpr bool DDataTypeIsTuple =
165  static constexpr bool ALayoutIsTuple =
167  static constexpr bool BLayoutIsTuple =
169  static constexpr bool DLayoutIsTuple =
171 
172  using AsLayout = std::conditional_t<ALayoutIsTuple,
173  remove_cvref_t<typename GemmPipeline::ALayout>,
174  remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
175  using BsLayout = std::conditional_t<BLayoutIsTuple,
176  remove_cvref_t<typename GemmPipeline::BLayout>,
177  remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
178 
179  using DsLayout = std::conditional_t<DLayoutIsTuple,
180  remove_cvref_t<typename EpiloguePipeline::DsLayout>,
181  remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
182 
183  using AsDataType = std::conditional_t<ADataTypeIsTuple,
184  remove_cvref_t<typename GemmPipeline::ADataType>,
185  remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
186 
187  using BsDataType = std::conditional_t<BDataTypeIsTuple,
188  remove_cvref_t<typename GemmPipeline::BDataType>,
189  remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
190 
191  using DsDataType =
192  std::conditional_t<DDataTypeIsTuple,
193  remove_cvref_t<typename EpiloguePipeline::DsDataType>,
194  remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
195 
196  using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
197  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
198 
199  static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
200 
201  // Get the persistent kernel if the pipeline has it available
203  {
204  template <typename T>
205  using has_persistent_type = decltype(T::UsePersistentKernel);
206 
207  static constexpr bool value = []() {
208  if constexpr(is_detected<has_persistent_type, GemmPipeline>::value)
209  return GemmPipeline::UsePersistentKernel;
210  else
211  return false;
212  }();
213  };
215 
216  static constexpr auto I0 = number<0>();
217  static constexpr auto I1 = number<1>();
218  static constexpr auto I2 = number<2>();
219  static constexpr auto I3 = number<3>{};
220 
221  static constexpr index_t NumATensor = AsDataType::size();
222  static constexpr index_t NumBTensor = BsDataType::size();
223  static constexpr index_t NumDTensor = DsDataType::size();
224 
225  using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
226  using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;
227 
228  static_assert(AsLayout::size() == AsDataType::size(),
229  "The size of AsLayout and AsDataType should be the same");
230 
231  static_assert(BsLayout::size() == BsDataType::size(),
232  "The size of BsLayout and BsDataType should be the same");
233 
234  static_assert(DsLayout::size() == DsDataType::size(),
235  "The size of DsLayout and DsDataType should be the same");
236 
237  using KernelArgs =
238  UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
239 
240  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
241  {
242  // clang-format off
243  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
244  // clang-format on
245  }
246 
247  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
248  {
249  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
250  }
251 
258  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
259  {
260  using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
261  const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
262  int occupancy;
263  hip_check_error(
264  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
265  const int grid_size = get_available_compute_units(s) * occupancy;
266  return dim3(grid_size, 1, 1);
267  }
268 
269  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
270 
271  CK_TILE_HOST static constexpr KernelArgs
272  MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
273  {
274  return KernelArgs{hostArgs.as_ptr,
275  hostArgs.bs_ptr,
276  hostArgs.ds_ptr,
277  hostArgs.e_ptr,
278  hostArgs.M,
279  hostArgs.N,
280  hostArgs.K,
281  hostArgs.stride_As,
282  hostArgs.stride_Bs,
283  hostArgs.stride_Ds,
284  hostArgs.stride_E,
285  hostArgs.k_batch};
286  }
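
Example (not part of the header): a hedged sketch of the host-side flow around MakeKernelArgs, IsSupportedArgument, GridSize and BlockSize. "Kernel" is an assumed alias for a concrete UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline> instantiation, and host_args is the object sketched earlier; the actual launch wrapper is omitted.

// Fragment assumed to live inside a host-side setup function.
const auto kargs = Kernel::MakeKernelArgs(host_args);
if(!Kernel::IsSupportedArgument(kargs))
{
    // This shape/layout/vector-size combination needs padding or a different pipeline.
    return;
}
const dim3 grids  = Kernel::GridSize(kargs.M, kargs.N, kargs.k_batch);
const dim3 blocks = Kernel::BlockSize();
// grids, blocks and Kernel::GetSmemSize() are what the launch wrapper consumes.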
287 
288  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
289  {
290  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
291  }
292 
293  struct SplitKBatchOffset
294  {
295  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
296  {
297  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
298  const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
299  const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
300 
301  static_for<0, NumATensor, 1>{}([&](auto index) {
302  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
303  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
304  {
305  as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead);
306  }
307  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
308  {
309  as_k_split_offset[index] =
310  __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]);
311  }
312  });
313 
314  static_for<0, NumBTensor, 1>{}([&](auto index) {
315  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
316  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
317  {
318  bs_k_split_offset[index] =
319  __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]);
320  }
321  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
322  {
323  bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead);
324  }
325  });
326 
327  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
328  {
329  splitted_k = __builtin_amdgcn_readfirstlane(KRead);
330  }
331  else
332  {
333  splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
334  }
335  }
336 
337  std::array<index_t, NumATensor> as_k_split_offset;
338  std::array<index_t, NumBTensor> bs_k_split_offset;
339  index_t splitted_k;
340  };
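
To make the split-K arithmetic above concrete, a small worked example with illustrative numbers (K = 1000, k_batch = 4, warp-tile K extent K1 = 32): every batch except the last reads KRead = ceil(K / (k_batch * K1)) * K1 elements along K, and the last batch takes the remainder.

// Host-side mirror of the arithmetic in SplitKBatchOffset (illustration only).
constexpr int K = 1000, k_batch = 4, K1 = 32;   // K1 ~ WarpTile::at(number<2>{})
constexpr int K_t   = k_batch * K1;             // 128
constexpr int KRead = (K + K_t - 1) / K_t * K1; // ceil(1000 / 128) * 32 = 256
// k_id 0..2 each get splitted_k = 256; k_id 3 gets K - KRead * (k_batch - 1) = 232.
static_assert(KRead == 256 && K - KRead * (k_batch - 1) == 232, "worked example");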
341 
342  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
343  {
344  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
346  {
347  if(kargs.k_batch != 1)
348  {
349  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
350  {
351  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
352  }
353  return false;
354  }
355  }
356 
357  bool AsTesnorIsValid = {true};
358  static_for<0, NumATensor, 1>{}([&](auto index) {
359  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
360  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
361  {
362  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
363  GemmPipeline::kPadK == false)
364  {
365  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
366  {
367  CK_TILE_ERROR(
368  "Can't support K that is not a multiple of k_batch * KPerBlock "
369  "without padding!");
370  }
371  AsTesnorIsValid = false;
372  }
373  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
374  {
375  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
376  {
377  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
378  }
379  AsTesnorIsValid = false;
380  }
381  }
382  else
383  {
384  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
385  {
386  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
387  {
388  CK_TILE_ERROR(
389  "Can't support M that is not a multiple of MPerBlock without padding!");
390  }
391  AsTesnorIsValid = false;
392  }
393  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
394  {
395  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
396  {
397  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
398  }
399  AsTesnorIsValid = false;
400  }
401  }
402  });
403 
404  bool BsTesnorIsValid = {true};
405  static_for<0, NumBTensor, 1>{}([&](auto index) {
406  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
407  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
408  {
409  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
410  {
411  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
412  {
413  CK_TILE_ERROR(
414  "Can't support N that is not a multiple of NPerBlock without padding!");
415  }
416  BsTesnorIsValid = false;
417  }
418  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
419  {
420  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
421  {
422  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
423  }
424  BsTesnorIsValid = false;
425  }
426  }
427  else
428  {
429  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
430  GemmPipeline::kPadK == false)
431  {
432  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
433  {
434  CK_TILE_ERROR(
435  "Can't support K that is not a multiple of k_batch * KPerBlock "
436  "without padding!");
437  }
438  BsTesnorIsValid = false;
439  }
440  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
441  {
442  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
443  {
444  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
445  }
446  BsTesnorIsValid = false;
447  }
448  }
449  });
450 
451  bool DTesnorIsValid = {true};
452  static_for<0, NumDTensor, 1>{}([&](auto index) {
453  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
454  if(std::is_same_v<DiLayout, ELayout> == false)
455  {
456  DTesnorIsValid = false;
457  }
458  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
459  {
460  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
461  {
462  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
463  {
464  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
465  "NPerBlock without padding!");
466  }
467  DTesnorIsValid = false;
468  }
469  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
470  {
471  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
472  {
473  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
474  }
475  DTesnorIsValid = false;
476  }
477  }
478  else
479  {
480  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
481  {
482  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
483  {
484  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
485  "MPerBlock without padding!");
486  }
487  DTesnorIsValid = false;
488  }
489  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
490  {
491  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
492  {
493  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
494  }
495  DTesnorIsValid = false;
496  }
497  }
498  });
499 
500  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
501  {
502  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
503  {
504  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
505  {
506  CK_TILE_ERROR(
507  "Can't support N that is not a multiple of NPerBlock without padding!");
508  }
509  return false;
510  }
511  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
512  {
513  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
514  {
515  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
516  }
517  return false;
518  }
519  }
520  else
521  {
522  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
523  {
524  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
525  {
526  CK_TILE_ERROR(
527  "Can't support M that is not a multiple of MPerBlock without padding!");
528  }
529  return false;
530  }
531  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
532  {
533  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
534  {
535  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
536  }
537  return false;
538  }
539  }
540  return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
541  }
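
Example (not part of the header): a rough host-side mirror of the main divisibility rules IsSupportedArgument enforces, using hypothetical tile extents; the real check also consults the pipeline's padding flags (kPadM/kPadN/kPadK), the per-tensor layouts, and the vector load sizes.

// Illustration only: MPerBlock/NPerBlock/KPerBlock are made-up values.
bool roughly_supported(int M, int N, int K, int k_batch)
{
    constexpr int MPerBlock = 256, NPerBlock = 128, KPerBlock = 64;
    const bool m_ok = (M % MPerBlock) == 0;             // else GemmPipeline::kPadM must be set
    const bool n_ok = (N % NPerBlock) == 0;             // else GemmPipeline::kPadN must be set
    const bool k_ok = (K % (KPerBlock * k_batch)) == 0; // else GemmPipeline::kPadK must be set
    return m_ok && n_ok && k_ok;
}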
542 
543  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
544  CK_TILE_DEVICE static auto
545  MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
546  const std::array<const BDataType*, NumBTensor>& bs_ptr,
547  const std::array<const void*, NumDTensor>& ds_ptr,
548  EDataType* e_ptr,
549  const KernelArgs& kargs,
550  const SplitKBatchOffset& splitk_batch_offset)
551  {
552  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
553 
554  const auto& as_tensor_view = generate_tuple(
555  [&](auto i) {
556  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
557  using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
558  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
559  {
560  return make_naive_tensor_view<address_space_enum::global>(
561  static_cast<const AiDataType*>(as_ptr[i]),
562  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
563  make_tuple(kargs.stride_As[i], 1),
564  number<GemmPipeline::GetVectorSizeA()>{},
565  number<1>{});
566  }
567  else
568  {
569  return make_naive_tensor_view<address_space_enum::global>(
570  static_cast<const AiDataType*>(as_ptr[i]),
571  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
572  make_tuple(kargs.stride_As[i], 1),
573  number<GemmPipeline::GetVectorSizeA()>{},
574  number<1>{});
575  }
576  },
577  number<NumATensor>{});
578 
579  const auto& bs_tensor_view = generate_tuple(
580  [&](auto i) {
581  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
582  using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
583  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
584  {
585  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
586  {
587  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
588  const index_t K0 = splitk_batch_offset.splitted_k / K1;
589  constexpr index_t VectorSizeB =
590  std::min(K1, GemmPipeline::GetVectorSizeB());
591  const auto b_k0_n_k1_desc =
593  make_tuple(kargs.N * K1, K1, I1),
595  number<1>{});
596  const auto b_n_k_desc = transform_tensor_descriptor(
597  b_k0_n_k1_desc,
602  return make_tensor_view<address_space_enum::global>(
603  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
604  }
605  else
606  {
607  return make_naive_tensor_view<address_space_enum::global>(
608  bs_ptr[i],
609  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
610  make_tuple(kargs.stride_Bs[i], 1),
611  number<GemmPipeline::GetVectorSizeB()>{},
612  number<1>{});
613  }
614  }
615  else
616  {
617  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
618  {
619  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
620  const index_t K0 = splitk_batch_offset.splitted_k / K1;
621  constexpr index_t VectorSizeB =
622  std::min(K1, GemmPipeline::GetVectorSizeB());
623  const auto b_k0_n_k1_desc =
625  make_tuple(kargs.N * K1, K1, I1),
627  number<1>{});
628  const auto b_n_k_desc = transform_tensor_descriptor(
629  b_k0_n_k1_desc,
634  return make_tensor_view<address_space_enum::global>(
635  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
636  }
637  else
638  {
639  if constexpr(GemmPipeline::Preshuffle)
640  {
641  index_t kFlatK =
642  GemmPipeline::BlockGemmShape::flatKPerWarp *
643  (splitk_batch_offset.splitted_k /
644  TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
645  index_t kFlatN = kargs.N * kargs.K / kFlatK;
646 
647  return make_naive_tensor_view<address_space_enum::global>(
648  bs_ptr[i],
649  make_tuple(kFlatN, kFlatK),
650  make_tuple(kFlatK, 1),
651  number<GemmPipeline::GetVectorSizeB()>{},
652  number<1>{});
653  }
654  else
655  {
656  return make_naive_tensor_view<address_space_enum::global>(
657  bs_ptr[i],
658  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
659  make_tuple(kargs.stride_Bs[i], 1),
660  number<GemmPipeline::GetVectorSizeB()>{},
661  number<1>{});
662  }
663  }
664  }
665  },
666  number<NumBTensor>{});
667 
668  const auto& ds_tensor_view = generate_tuple(
669  [&](auto i) {
670  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
671  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
672  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
673  {
674  return make_naive_tensor_view<address_space_enum::global>(
675  static_cast<const DDataType_*>(ds_ptr[i]),
676  make_tuple(kargs.M, kargs.N),
677  make_tuple(kargs.stride_Ds[i], 1),
678  number<EpiloguePipeline::GetVectorSizeD(i)>{},
679  number<1>{});
680  }
681  else
682  {
683  return make_naive_tensor_view<address_space_enum::global>(
684  static_cast<const DDataType_*>(ds_ptr[i]),
685  make_tuple(kargs.N, kargs.M),
686  make_tuple(kargs.stride_Ds[i], 1),
687  number<EpiloguePipeline::GetVectorSizeD(i)>{},
688  number<1>{});
689  }
690  },
691  number<NumDTensor>{});
692 
693  // TODO: enable vector write for C in ColMajor
694  const auto& e_tensor_view = [&]() {
695  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
696  {
697  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
698  e_ptr,
699  make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
700  make_tuple(kargs.stride_E, 1),
701  number<EpiloguePipeline::GetVectorSizeC()>{},
702  number<1>{});
703  }
704  else
705  {
706  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
707  e_ptr,
708  make_tuple(kargs.M, kargs.N),
709  make_tuple(1, kargs.stride_E),
710  number<1>{},
711  number<1>{});
712  }
713  }();
714 
715  return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
716  }
717 
718  template <typename TensorView>
719  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
720  {
721  const auto& as_pad_view = generate_tuple(
722  [&](auto i) {
723  const auto& a_tensor_view = views.at(I0);
724  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
725  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
726  {
727  return pad_tensor_view(a_tensor_view[i],
731  }
732  else
733  {
734  return pad_tensor_view(a_tensor_view[i],
738  }
739  },
740  number<NumATensor>{});
741 
742  const auto& b_flat_pad_view = views.at(I1);
743 
744  const auto& bs_pad_view = generate_tuple(
745  [&](auto i) {
746  const auto& b_tensor_view = views.at(I1);
747  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
748  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
749  {
750  return pad_tensor_view(b_tensor_view[i],
754  }
755  else
756  {
757  return pad_tensor_view(b_tensor_view[i],
761  }
762  },
763  number<NumBTensor>{});
764 
765  const auto& ds_pad_view = generate_tuple(
766  [&](auto i) {
767  const auto& d_tensor_view = views.at(I2);
768  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
769  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
770  {
771  return pad_tensor_view(d_tensor_view[i],
775  }
776  else
777  {
778  return pad_tensor_view(d_tensor_view[i],
782  }
783  },
784  number<NumDTensor>{});
785 
786  // TODO vector write in for C in ColMajor
787  const auto& e_pad_view = [&]() {
788  const auto& e_tensor_view = views.at(I3);
789  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
790  {
791  return pad_tensor_view(e_tensor_view,
795  }
796  else
797  {
798  return pad_tensor_view(e_tensor_view,
802  }
803  }();
804 
805  if constexpr(GemmPipeline::Preshuffle)
806  {
807  // For flatmm, we need to use the flat B tensor view
808  return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
809  }
810  else
811  {
812  return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
813  }
814  }
815 
816  template <typename PadView>
817  CK_TILE_DEVICE static auto
818  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
819  {
820  const auto& as_pad_view = views.at(I0);
821  const auto& bs_pad_view = views.at(I1);
822  const auto& ds_pad_view = views.at(I2);
823  const auto& e_pad_view = views.at(I3);
824 
825  const auto& as_block_window = generate_tuple(
826  [&](auto i) {
827  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
828  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
829  {
830  return make_tile_window(as_pad_view[i],
833  {i_m, 0});
834  }
835  else
836  {
837  return make_tile_window(as_pad_view[i],
840  {0, i_m});
841  }
842  },
843  number<NumATensor>{});
844 
845  const auto& bs_block_window = generate_tuple(
846  [&](auto i) {
847  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
848  if constexpr(GemmPipeline::Preshuffle)
849  {
850  return make_tile_window(
851  bs_pad_view[i],
854  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
855  0});
856  }
857  else
858  {
859  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
860  {
861  return make_tile_window(bs_pad_view[i],
864  {i_n, 0});
865  }
866  else
867  {
868  return make_tile_window(bs_pad_view[i],
871  {0, i_n});
872  }
873  }
874  },
875  number<NumBTensor>{});
876 
877  const auto ds_block_window = generate_tuple(
878  [&](auto i) {
879  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
880  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
881  {
882  return make_tile_window(ds_pad_view[i],
885  {i_m, i_n});
886  }
887  else
888  {
889  return make_tile_window(ds_pad_view[i],
892  {i_n, i_m});
893  }
894  },
895  number<NumDTensor>{});
896 
897  auto e_block_window = make_tile_window(
898  e_pad_view,
900  {i_m, i_n});
901 
902  return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
903  }
904 
919  template <bool UseDefaultScheduler = true>
920  CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
921  const std::array<const BDataType*, NumBTensor>& bs_ptr,
922  const std::array<const void*, NumDTensor>& ds_ptr,
923  EDataType* e_ptr,
924  void* smem_ptr_0,
925  const KernelArgs& kargs,
926  const SplitKBatchOffset& splitk_batch_offset,
927  const index_t block_idx_m,
928  const index_t block_idx_n)
929  {
930  // Create Gemm tensor views, pad views and tile windows
931  const auto& gemm_tensor_views_tuple =
932  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
933  as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
934 
935  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
936  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
937 
938  const index_t num_loop = __builtin_amdgcn_readfirstlane(
939  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
940 
941  // Run GEMM cooperatively by whole workgroup.
942  const auto& as_block_window = gemm_tile_windows.at(I0);
943  const auto& bs_block_window = gemm_tile_windows.at(I1);
944  const auto& ds_block_window = gemm_tile_windows.at(I2);
945 
946  const auto& c_block_tile = GemmPipeline{}.template operator()(
947  as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
948 
949  if(UseDefaultScheduler || (get_warp_id() == 0))
950  {
951  // Run Epilogue Pipeline
952  auto& c_block_window = gemm_tile_windows.at(I3);
953 
954  EpiloguePipeline{}.template
955  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
956  c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
957  }
958  }
959 
977  CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
978  const std::array<const BDataType*, NumBTensor>& bs_ptr,
979  const std::array<const void*, NumDTensor>& ds_ptr,
980  EDataType* e_ptr,
981  void* __restrict__ smem_ptr_0,
982  void* __restrict__ smem_ptr_1,
983  const KernelArgs& kargs,
984  const SplitKBatchOffset& splitk_batch_offset,
985  const index_t block_idx_m,
986  const index_t block_idx_n)
987  {
988  // Create Gemm tensor views, pad views and tile windows
989  const auto& gemm_tensor_views_tuple =
990  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
991  as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
992 
993  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
994  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
995 
996  const index_t num_loop = __builtin_amdgcn_readfirstlane(
997  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
998 
999  // Run GEMM cooperatively by whole workgroup.
1000  const auto& as_block_window = gemm_tile_windows.at(I0);
1001  const auto& bs_block_window = gemm_tile_windows.at(I1);
1002  const auto& ds_block_window = gemm_tile_windows.at(I2);
1003 
1004  const auto& c_block_tile = GemmPipeline{}.template operator()(
1005  as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
1006 
1007  // Run Epilogue Pipeline
1008  auto& c_block_window = gemm_tile_windows.at(I3);
1009 
1010  EpiloguePipeline{}.template
1011  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
1012  c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1013  }
1014 
1015  // Non-persistent kernel entry point
1016  template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1017  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1018  {
1019  const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
1020  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1021  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1022  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1023 
1024  const SplitKBatchOffset splitk_batch_offset(kargs);
1025 
1026  // options
1027  std::array<const ADataType*, NumATensor> as_ptr;
1028  static_for<0, NumATensor, 1>{}([&](auto i) {
1029  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1030  splitk_batch_offset.as_k_split_offset[i];
1031  });
1032 
1033  std::array<const BDataType*, NumBTensor> bs_ptr;
1034  static_for<0, NumBTensor, 1>{}([&](auto i) {
1035  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1036  splitk_batch_offset.bs_k_split_offset[i];
1037  });
1038 
1039  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1040 
1041  // allocate LDS
1042  __shared__ char smem_ptr_0[GetSmemSize()];
1043 
1044  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1045  {
1046  __shared__ char smem_ptr_1[GetSmemSize()];
1047  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1048  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1050  {
1051  RunGemm2LDS(as_ptr,
1052  bs_ptr,
1053  kargs.ds_ptr,
1054  e_ptr,
1055  smem_ptr_0,
1056  smem_ptr_1,
1057  kargs,
1058  splitk_batch_offset,
1059  i_m,
1060  i_n);
1061  }
1062  }
1063  else
1064  {
1065  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1066  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1068  {
1069  constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
1070  RunGemm<scheduler_type>(as_ptr,
1071  bs_ptr,
1072  kargs.ds_ptr,
1073  e_ptr,
1074  smem_ptr_0,
1075  kargs,
1076  splitk_batch_offset,
1077  i_m,
1078  i_n);
1079  }
1080  }
1081  }
1082 
1083  // Persistent kernel entry point
1084  template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1085  CK_TILE_DEVICE void operator()(KernelArgs kargs) const
1086  {
1087  const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
1088  const auto num_tiles =
1089  __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
1090  const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
1091  auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
1092 
1093  while(block_id < num_work)
1094  {
1095  // Get the tile index for this block
1096  const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
1097  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1098  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1099  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1100 
1101  // Get the SplitK offset for this block
1102  const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
1103  const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
1104 
1105  std::array<const ADataType*, NumATensor> as_ptr;
1106  static_for<0, NumATensor, 1>{}([&](auto i) {
1107  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1108  splitk_batch_offset.as_k_split_offset[i];
1109  });
1110 
1111  std::array<const BDataType*, NumBTensor> bs_ptr;
1112  static_for<0, NumBTensor, 1>{}([&](auto i) {
1113  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1114  splitk_batch_offset.bs_k_split_offset[i];
1115  });
1116 
1117  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1118 
1119  // allocate LDS
1120  __shared__ char smem_ptr_0[GetSmemSize()];
1121  // Run the GEMM
1122  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1123  {
1124  __shared__ char smem_ptr_1[GetSmemSize()];
1125  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1126  memory_operation_enum::atomic_add &&
1127  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1129  {
1130  RunGemm2LDS(as_ptr,
1131  bs_ptr,
1132  kargs.ds_ptr,
1133  e_ptr,
1134  smem_ptr_0,
1135  smem_ptr_1,
1136  kargs,
1137  splitk_batch_offset,
1138  i_m,
1139  i_n);
1140  }
1141  }
1142  else
1143  {
1144  if constexpr(!(EpiloguePipeline::MemoryOperation ==
1145  memory_operation_enum::atomic_add &&
1146  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1148  {
1149  RunGemm(as_ptr,
1150  bs_ptr,
1151  kargs.ds_ptr,
1152  e_ptr,
1153  smem_ptr_0,
1154  kargs,
1155  splitk_batch_offset,
1156  i_m,
1157  i_n);
1158  }
1159  }
1160  // Advance to the next work item
1161  block_id += grid_size;
1162  if(block_id >= num_work)
1163  {
1164  break;
1165  }
1166  }
1167  }
1168 };
1169 } // namespace ck_tile
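
For reference, the persistent entry point's scheduling reduces to the small mapping below (standalone sketch, not part of the header): each block walks the flattened (output tile, k-batch) work space with a stride equal to the launched grid size.

// Standalone illustration of the persistent-kernel work distribution loop.
void persistent_schedule(int grid_size, int num_tiles, int k_batch)
{
    const int num_work = num_tiles * k_batch;
    for(int block_id = 0; block_id < grid_size; ++block_id)      // one iteration per block
    {
        for(int work = block_id; work < num_work; work += grid_size)
        {
            const int tile_idx   = work % num_tiles; // maps to (iM, iN) via the tile partitioner
            const int k_batch_id = work / num_tiles; // which K split this block reduces
            // The kernel builds SplitKBatchOffset(kargs, k_batch_id) and calls RunGemm here.
            (void)tile_idx;
            (void)k_batch_id;
        }
    }
}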
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:184
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:412
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
__device__ index_t get_grid_size()
Definition: get_id.hpp:27
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ X atomic_add(X *p_dst, const X &x)
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
void * c_ptr
Definition: universal_gemm_kernel.hpp:66
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition: universal_gemm_kernel.hpp:33
index_t K
Definition: universal_gemm_kernel.hpp:70
void * e_ptr
Definition: universal_gemm_kernel.hpp:65
index_t M
Definition: universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
index_t N
Definition: universal_gemm_kernel.hpp:69
index_t stride_E
Definition: universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
index_t stride_C
Definition: universal_gemm_kernel.hpp:77
index_t k_batch
Definition: universal_gemm_kernel.hpp:80
Definition: universal_gemm_kernel.hpp:294
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:337
index_t splitted_k
Definition: universal_gemm_kernel.hpp:339
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition: universal_gemm_kernel.hpp:295
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:338
Definition: universal_gemm_kernel.hpp:203
static constexpr bool value
Definition: universal_gemm_kernel.hpp:207
decltype(T::UsePersistentKernel) has_persistent_type
Definition: universal_gemm_kernel.hpp:205
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition: universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition: universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
index_t k_batch
Definition: universal_gemm_kernel.hpp:113
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:98
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition: universal_gemm_kernel.hpp:112
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition: universal_gemm_kernel.hpp:109
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:96
The Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:154
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1017
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > >> BsDataType
Definition: universal_gemm_kernel.hpp:189
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: universal_gemm_kernel.hpp:156
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:240
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: universal_gemm_kernel.hpp:155
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1085
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:920
static constexpr bool BDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:161
static constexpr auto I2
Definition: universal_gemm_kernel.hpp:218
static constexpr bool BLayoutIsTuple
Definition: universal_gemm_kernel.hpp:167
static CK_TILE_DEVICE auto MakeGemmTensorViews(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: universal_gemm_kernel.hpp:545
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > >> BsLayout
Definition: universal_gemm_kernel.hpp:177
static constexpr index_t NumATensor
Definition: universal_gemm_kernel.hpp:221
static constexpr bool ALayoutIsTuple
Definition: universal_gemm_kernel.hpp:165
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition: universal_gemm_kernel.hpp:225
static CK_TILE_DEVICE void RunGemm2LDS(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:977
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:818
static constexpr auto I3
Definition: universal_gemm_kernel.hpp:219
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > >> DsDataType
Definition: universal_gemm_kernel.hpp:194
static constexpr bool ADataTypeIsTuple
Definition: universal_gemm_kernel.hpp:159
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: universal_gemm_kernel.hpp:719
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: universal_gemm_kernel.hpp:196
static constexpr index_t NumDTensor
Definition: universal_gemm_kernel.hpp:223
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition: universal_gemm_kernel.hpp:238
static constexpr bool DDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:163
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:214
static constexpr auto I1
Definition: universal_gemm_kernel.hpp:217
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:247
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::ADataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > >> AsDataType
Definition: universal_gemm_kernel.hpp:185
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition: universal_gemm_kernel.hpp:226
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:258
static constexpr index_t KernelBlockSize
Definition: universal_gemm_kernel.hpp:199
static constexpr index_t NumBTensor
Definition: universal_gemm_kernel.hpp:222
static constexpr auto I0
Definition: universal_gemm_kernel.hpp:216
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:342
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::ALayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > >> AsLayout
Definition: universal_gemm_kernel.hpp:174
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > >> DsLayout
Definition: universal_gemm_kernel.hpp:181
static constexpr bool DLayoutIsTuple
Definition: universal_gemm_kernel.hpp:169
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: universal_gemm_kernel.hpp:157
static constexpr CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:269
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:288
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:272
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: universal_gemm_kernel.hpp:197
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: stream_config.hpp:30
#define CK_TILE_ENV(name)
Definition: env.hpp:145