Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp (source file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t MPerWmma,
          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                                                   BLayout,
                                                   CLayout,
                                                   ADataType,
                                                   BDataType,
                                                   CDataType,
                                                   AElementwiseOperation,
                                                   BElementwiseOperation,
                                                   CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    // K1 = max number of elements per vector access
    static constexpr auto K1Number = Number<K1>{};

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
    // A/B may bypass LDS only if a full 16-byte vector load is possible
    static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
    static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;

    static constexpr auto AEnableLds_auto =
        (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
         is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            ? false
            : true;
    static constexpr auto BEnableLds_auto =
        (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
         is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            ? false
            : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
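
    // Example of the auto decision (an illustrative reading, not part of the
    // original source): with fp16 A (sizeof(ADataType) == 2) and K1 = 8, each
    // access moves 8 * 2 = 16 bytes, so MaxVectorLoadA is true; if in addition
    // NWaves == 1 and A is row-major, AEnableLds_auto becomes false and A is
    // read straight from global memory into registers. Multi-buffer prefetch
    // (NumPrefetch > 1) forces the LDS path back on for both operands.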

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    // Describes how data is read from global memory
    static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
        }();

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        assert(K % K1 == 0);

        if constexpr(AEnableLds)
        {
            // Split K into (K0, K1) for the K0_M_K1 layout staged through LDS
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto A_KRow      = 2;
            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma         = K / WmmaK;

            const auto M0 = M / MPerBlock;
            //  0   1     0             1             2           3          4         5        6
            //  M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }
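
    // Worked shape example (illustrative, not from the original source): with
    // MRaw = 256, KRaw = 128, K1 = 8 and AEnableLds true, the LDS path yields
    // a (K0, M, K1) = (16, 256, 8) descriptor. With WmmaK = 16 on the direct
    // (LDS-bypass) path, K is instead unmerged into
    // (A_KWmma, A_K0PerWmma, A_KRow, A_K1) = (8, 1, 2, 8).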

    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_n_k = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
        }();

        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        assert(K % K1 == 0);

        if constexpr(BEnableLds)
        {
            // Split K into (K0, K1) for the K0_N_K1 layout staged through LDS
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto B_KRow      = 2;
            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma         = K / WmmaK;

            const auto N0 = N / NPerBlock;
            //  0   1     0             1             2           3          4         5        6
            //  N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }
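
    // Note on padding (an illustrative reading of MatrixPadder): GemmSpec
    // selects which raw extents are padded up to a multiple of the block tile.
    // For example, with GemmSpec = GemmSpecialization::MNKPadding,
    // MPerBlock = 128 and MRaw = 250, M is padded to 256;
    // GemmSpecialization::Default applies no padding and assumes the extents
    // already divide evenly.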

    // Gridwise descriptors, mapping to the whole given problem
    using AGridDesc     = decltype(MakeAGridDescriptor(1, 1, 1));
    using BGridDesc     = decltype(MakeBGridDescriptor(1, 1, 1));
    using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm =
        GridwiseGemm_Wmma<BlockSize,
                          ADataType,
                          BDataType,
                          AccDataType,
                          CShuffleDataType,
                          CDataType,
                          InMemoryDataOperationEnum::Set,
                          AGridDesc,
                          BGridDesc,
                          CGridDesc_M_N,
                          AElementwiseOperation,
                          BElementwiseOperation,
                          CElementwiseOperation,
                          MPerBlock,
                          NPerBlock,
                          KPerBlock,
                          MPerWmma,
                          NPerWmma,
                          K1,
                          MRepeat,
                          NRepeat,
                          ABlockTransferThreadClusterLengths_K0_M_K1,
                          ABlockTransferThreadClusterArrangeOrder,
                          ABlockTransferSrcAccessOrder,
                          ABlockTransferSrcVectorDim,
                          ABlockTransferSrcScalarPerVector,
                          ABlockTransferDstScalarPerVector_K1,
                          false, // AThreadTransferSrcResetCoordinateAfterRun,
                          AEnableLds,
                          ABlockLdsAddExtraM,
                          BBlockTransferThreadClusterLengths_K0_N_K1,
                          BBlockTransferThreadClusterArrangeOrder,
                          BBlockTransferSrcAccessOrder,
                          BBlockTransferSrcVectorDim,
                          BBlockTransferSrcScalarPerVector,
                          BBlockTransferDstScalarPerVector_K1,
                          false, // BThreadTransferSrcResetCoordinateAfterRun,
                          BEnableLds,
                          BBlockLdsAddExtraN,
                          CShuffleMRepeatPerShuffle,
                          CShuffleNRepeatPerShuffle,
                          CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                          CShuffleBlockTransferScalarPerVector_NPerBlock,
                          NumPrefetch,
                          LoopSched,
                          PipelineVer>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_{},
              b_grid_desc_k0_n_k1_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op},
              MRaw_{M},
              NRaw_{N},
              KRaw_{K}
        {
            a_grid_desc_         = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
            c_grid_desc_m_n_     = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);

            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
        // for checking vector load/store
        index_t MRaw_;
        index_t NRaw_;
        index_t KRaw_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGemmWmma_CShuffle::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                            arg.b_grid_desc_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

            // Recover the total K extent from whichever A-descriptor layout is in use
            const auto K = [&]() {
                if constexpr(AEnableLds)
                {
                    // (K0, M, K1) layout: K = K0 * K1
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
                }
                else
                {
                    // LDS-bypass layout: K = A_KWmma * A_K0PerWmma * A_KRow * A_K1
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                           arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
                }
            }();
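
            // Illustrative check (not in the original source): for the LDS
            // path with K0 = 16 and K1 = 8, K = 16 * 8 = 128; for the bypass
            // path with (A_KWmma, A_K0PerWmma, A_KRow, A_K1) = (8, 1, 2, 8),
            // K = 8 * 1 * 2 * 8 = 128. Both recover the same padded K.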
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_gemm_wmma<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    CDataType,
                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
                    remove_reference_t<
                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    has_main_k_block_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_c_grid_,
                                              arg.a_grid_desc_,
                                              arg.b_grid_desc_k0_n_k1_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.block_2_ctile_map_);
            };

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
            {
                printf("DeviceOp err: AccDataType");
                return false;
            }
        }
        else
        {
            printf("DeviceOp err: Arch");
            return false;
        }

        // check vector load/store
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

            // check vector load of A
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector load of B
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector store of C
            // only RowMajor is supported for now
            if constexpr(is_same_v<CLayout, Row>)
            {
                if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                           arg.b_grid_desc_k0_n_k1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }
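
    // Illustrative reading (not in the original source): with row-major A and
    // ABlockTransferSrcScalarPerVector = 8, KRaw = 1024 passes the A check
    // (1024 % 8 == 0) while KRaw = 1002 fails it (1002 % 8 == 2), so the
    // instance is rejected at argument-selection time instead of issuing a
    // misaligned vector load on the device.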

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        1, // M01
                        1, // N01
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1, // M01
                                          1, // N01
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceGemmWmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
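
A minimal host-side usage sketch follows. The tile, wave, and thread-cluster values in DeviceGemmInstance are hypothetical (chosen only to illustrate the parameter list, not taken from a tuned instance in the library); the call sequence itself, MakeArgument, IsSupportedArgument, MakeInvoker, and Invoker::Run, is the interface defined in the listing above.

// Hypothetical fp16 instance: A row-major, B column-major, C row-major.
#include <stdexcept>

#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using F16         = ck::half_t;
using F32         = float;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ck::tensor_operation::device::GemmSpecialization;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
    Row, Col, Row,                          // ALayout, BLayout, CLayout
    F16, F16, F16, F32, F16,                // ADataType .. CShuffleDataType
    PassThrough, PassThrough, PassThrough,  // elementwise ops
    GemmSpecialization::MNKPadding,
    1,                // NumPrefetch
    128,              // BlockSize
    128, 64, 64,      // MPerBlock, NPerBlock, KPerBlock
    8,                // K1
    16, 16,           // MPerWmma, NPerWmma
    4, 2,             // MRepeat, NRepeat
    ck::Sequence<4, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, true,
    ck::Sequence<4, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, true,
    1, 1,                       // CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle
    ck::Sequence<1, 32, 1, 4>,  // CShuffle block-transfer cluster lengths
    8>;                         // CShuffleBlockTransferScalarPerVector_NPerBlock

float RunGemm(const ck::half_t* p_a, const ck::half_t* p_b, ck::half_t* p_c,
              ck::index_t M, ck::index_t N, ck::index_t K,
              ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                      PassThrough{}, PassThrough{}, PassThrough{});

    // Reject problems this instance cannot handle (arch, layouts, vector widths).
    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error("unsupported problem for this DeviceGemmWmma_CShuffle instance");
    }

    // Returns the measured kernel time in ms when timing is enabled.
    return invoker.Run(argument, StreamConfig{nullptr, true});
}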