device_gemm_reduce_xdl_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
9 #include "ck/utility/common_header.hpp"
10 #include "ck/tensor_description/tensor_descriptor.hpp"
11 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
12 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
13 #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
14 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
15 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
16 #include "ck/host_utility/device_prop.hpp"
17 #include "ck/host_utility/kernel_launch.hpp"
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24 // non-c-shuffle version currently has compiler issues with register spill, which in turn causes
25 // validation failures.
26 template <typename ALayout,
27  typename BLayout,
28  typename CLayout,
29  typename ADataType,
30  typename BDataType,
31  typename CDataType,
32  typename GemmAccDataType,
33  typename CShuffleDataType,
34  typename ReduceAccDataType,
35  typename ReducePtrsGlobal,
36  typename AElementwiseOperation,
37  typename BElementwiseOperation,
38  typename CElementwiseOperation,
39  typename ReduceOperations,
40  typename ReduceInElementwiseOperations,
41  typename ReduceAccElementwiseOperations,
42  typename ReduceGlobalMemoryDataOperation,
43  GemmSpecialization GemmSpec,
44  index_t NumGemmKPrefetchStage,
45  index_t BlockSize,
46  index_t MPerBlock,
47  index_t NPerBlock,
48  index_t KPerBlock,
49  index_t AK1,
50  index_t BK1,
51  index_t MPerXDL,
52  index_t NPerXDL,
53  index_t MXdlPerWave,
54  index_t NXdlPerWave,
55  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
56  typename ABlockTransferThreadClusterArrangeOrder,
57  typename ABlockTransferSrcAccessOrder,
58  index_t ABlockTransferSrcVectorDim,
59  index_t ABlockTransferSrcScalarPerVector,
60  index_t ABlockTransferDstScalarPerVector_AK1,
61  bool ABlockLdsExtraM,
62  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
63  typename BBlockTransferThreadClusterArrangeOrder,
64  typename BBlockTransferSrcAccessOrder,
65  index_t BBlockTransferSrcVectorDim,
66  index_t BBlockTransferSrcScalarPerVector,
67  index_t BBlockTransferDstScalarPerVector_BK1,
68  bool BBlockLdsExtraN,
69  index_t CShuffleMXdlPerWavePerShuffle,
70  index_t CShuffleNXdlPerWavePerShuffle,
71  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73  typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
74  index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
75  index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
76  LoopScheduler LoopSched = make_default_loop_scheduler()>
77 struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
78 {
79  using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
80 
81  static constexpr auto I0 = Number<0>{};
82  static constexpr auto I1 = Number<1>{};
83  static constexpr auto I2 = Number<2>{};
84 
85  static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
86  {
87  const auto a_grid_desc_mraw_kraw = [&]() {
88  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
89  {
90  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
91  make_tuple(StrideA, I1));
92  }
93  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
94  {
95  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
96  make_tuple(I1, StrideA));
97  }
98  }();
99 
100  const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
101  const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
102 
103  const auto MPad = M - MRaw;
104  const auto KPad = K - KRaw;
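 // Illustrative example (not from the original source): with MRaw = 1000 and
 // MPerBlock = 256, M = ceil(1000 / 256) * 256 = 1024 and MPad = 24, i.e. 24 rows of
 // padding are appended so every block works on a full MPerBlock tile; KPad is
 // computed the same way along K.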
105 
106  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
107  GemmSpec == GemmSpecialization::MNKPadding)
108  {
109  // pad both M and K
110  assert(K % AK1 == 0);
111 
112  const auto AK0 = K / AK1;
113 
114  const auto a_grid_desc_m_k =
115  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
117  make_right_pad_transform(KRaw, KPad)),
120 
121  const auto a_grid_desc_ak0_m_ak1 =
122  transform_tensor_descriptor(a_grid_desc_m_k,
127 
128  return a_grid_desc_ak0_m_ak1;
129  }
130  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
131  GemmSpec == GemmSpecialization::MNPadding)
132  {
133  // pad M, but not K
134  assert(KRaw % AK1 == 0);
135 
136  const auto AK0 = KRaw / AK1;
137 
138  const auto a_grid_desc_ak0_m_ak1 =
139  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
141  make_right_pad_transform(MRaw, MPad)),
144 
145  return a_grid_desc_ak0_m_ak1;
146  }
147  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
148  GemmSpec == GemmSpecialization::NKPadding)
149  {
150  // pad K, but not M
151  assert(K % AK1 == 0);
152 
153  const auto AK0 = K / AK1;
154 
155  const auto a_grid_desc_m_k = transform_tensor_descriptor(
156  a_grid_desc_mraw_kraw,
160 
161  const auto a_grid_desc_ak0_m_ak1 =
162  transform_tensor_descriptor(a_grid_desc_m_k,
167 
168  return a_grid_desc_ak0_m_ak1;
169  }
170  else
171  {
172  // do not pad M or K
173  assert(KRaw % AK1 == 0);
174 
175  const auto AK0 = KRaw / AK1;
176 
177  const auto a_grid_desc_ak0_m_ak1 =
178  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
183 
184  return a_grid_desc_ak0_m_ak1;
185  }
186  }
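 // Illustrative note (not from the original source): the AK0/AK1 split above unmerges
 // the (possibly padded) K dimension into (AK0, AK1) = (K / AK1, AK1). For example,
 // K = 64 with AK1 = 8 gives AK0 = 8, so AK1 contiguous elements along the
 // fastest-varying axis can be moved as one vector per thread.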
187 
188  static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
189  {
190  const auto b_grid_desc_nraw_kraw = [&]() {
191  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
192  {
193  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
194  make_tuple(I1, StrideB));
195  }
196  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
197  {
198  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
199  make_tuple(StrideB, I1));
200  }
201  }();
202 
203  const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
204  const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
205 
206  const auto NPad = N - NRaw;
207  const auto KPad = K - KRaw;
208 
209  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
210  GemmSpec == GemmSpecialization::MNKPadding)
211  {
212  // pad both N and K
213  assert(K % BK1 == 0);
214 
215  const auto BK0 = K / BK1;
216 
217  const auto b_grid_desc_n_k =
218  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
220  make_right_pad_transform(KRaw, KPad)),
223 
224  const auto b_grid_desc_bk0_n_bk1 =
225  transform_tensor_descriptor(b_grid_desc_n_k,
230 
231  return b_grid_desc_bk0_n_bk1;
232  }
233  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
234  GemmSpec == GemmSpecialization::MNPadding)
235  {
236  // pad N, but not K
237  assert(KRaw % BK1 == 0);
238 
239  const auto BK0 = KRaw / BK1;
240 
241  const auto b_grid_desc_bk0_n_bk1 =
242  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
244  make_right_pad_transform(NRaw, NPad)),
247 
248  return b_grid_desc_bk0_n_bk1;
249  }
250  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
251  GemmSpec == GemmSpecialization::MKPadding)
252  {
253  // pad K, but not N
254  assert(K % BK1 == 0);
255 
256  const auto BK0 = K / BK1;
257 
258  const auto b_grid_desc_n_k = transform_tensor_descriptor(
259  b_grid_desc_nraw_kraw,
263 
264  const auto b_grid_desc_bk0_n_bk1 =
265  transform_tensor_descriptor(b_grid_desc_n_k,
270 
271  return b_grid_desc_bk0_n_bk1;
272  }
273  else
274  {
275  // not pad N or K
276  assert(KRaw % BK1 == 0);
277 
278  const auto BK0 = KRaw / BK1;
279 
280  const auto b_grid_desc_bk0_n_bk1 =
281  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
286 
287  return b_grid_desc_bk0_n_bk1;
288  }
289  }
290 
291  static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
292  {
293  const auto c_grid_desc_mraw_nraw = [&]() {
294  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
295  {
296  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
297  make_tuple(StrideC, I1));
298  }
299  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
300  {
301  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
302  make_tuple(I1, StrideC));
303  }
304  }();
305 
306  const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
307  const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
308 
309  const auto MPad = M - MRaw;
310  const auto NPad = N - NRaw;
311 
312  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
313  GemmSpec == GemmSpecialization::MNKPadding)
314  {
315  // pad M and N
316  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
318  make_right_pad_transform(NRaw, NPad)),
321  }
322  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
323  GemmSpec == GemmSpecialization::MKPadding)
324  {
325  // pad M, but not N
327  c_grid_desc_mraw_nraw,
331  }
332  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
333  GemmSpec == GemmSpecialization::NKPadding)
334  {
335  // pad N, but not M
337  c_grid_desc_mraw_nraw,
341  }
342  else
343  {
344  // do not pad M or N
345  return c_grid_desc_mraw_nraw;
346  }
347  }
348 
349  // assume the reduction output is a packed tensor
350  static auto MakeReduceGridDescriptor_M(index_t MRaw)
351  {
352  const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
353 
354  const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
355  const auto MPad = M - MRaw;
356 
357  if constexpr(GemmSpec == GemmSpecialization::MPadding ||
358  GemmSpec == GemmSpecialization::MNPadding ||
359  GemmSpec == GemmSpecialization::MKPadding ||
360  GemmSpec == GemmSpecialization::MNKPadding)
361  {
362  // pad M
363  return transform_tensor_descriptor(d_grid_desc_mraw,
367  }
368  else
369  {
370  // do not pad M
371  return d_grid_desc_mraw;
372  }
373  }
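 // Illustrative note (not from the original source): the reduction output is a
 // length-M vector, one reduced value per row of C, so it only needs the packed
 // 1-D descriptor above and is padded on M in the same way as the C descriptor.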
374 
375  using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
376  using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
377  using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
378  using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
379 
380  // GridwiseGemm
381  using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
382  ADataType, // TODO: distinguish A/B datatype
383  GemmAccDataType,
384  CShuffleDataType,
385  CDataType,
386  ReduceAccDataType,
387  ReducePtrsGlobal,
388  AElementwiseOperation,
389  BElementwiseOperation,
390  CElementwiseOperation,
391  ReduceOperations,
392  ReduceInElementwiseOperations,
393  ReduceAccElementwiseOperations,
394  InMemoryDataOperationEnum::Set,
395  ReduceGlobalMemoryDataOperation,
396  AGridDesc_AK0_M_AK1,
397  BGridDesc_BK0_N_BK1,
398  CGridDesc_M_N,
399  ReduceGridDesc_M,
400  NumGemmKPrefetchStage,
401  BlockSize,
402  MPerBlock,
403  NPerBlock,
404  KPerBlock,
405  AK1,
406  BK1,
407  MPerXDL,
408  NPerXDL,
409  MXdlPerWave,
410  NXdlPerWave,
411  ABlockTransferThreadClusterLengths_AK0_M_AK1,
412  ABlockTransferThreadClusterArrangeOrder,
413  ABlockTransferSrcAccessOrder,
414  ABlockTransferSrcVectorDim,
415  ABlockTransferSrcScalarPerVector,
416  ABlockTransferDstScalarPerVector_AK1,
417  false,
418  ABlockLdsExtraM,
419  BBlockTransferThreadClusterLengths_BK0_N_BK1,
420  BBlockTransferThreadClusterArrangeOrder,
421  BBlockTransferSrcAccessOrder,
422  BBlockTransferSrcVectorDim,
423  BBlockTransferSrcScalarPerVector,
424  BBlockTransferDstScalarPerVector_BK1,
425  false,
426  BBlockLdsExtraN,
427  CShuffleMXdlPerWavePerShuffle,
428  CShuffleNXdlPerWavePerShuffle,
429  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
430  CShuffleBlockTransferScalarPerVector_NPerBlock,
431  CReduceThreadClusterLengths_MPerBlock_NPerBlock,
432  CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
433  CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
434  LoopSched>;
435 
436  // Argument
437  struct Argument : public BaseArgument
438  {
439  Argument(const ADataType* p_a_grid,
440  const BDataType* p_b_grid,
441  CDataType* p_c_grid,
442  ReducePtrsGlobal p_reduces_grid,
443  index_t MRaw,
444  index_t NRaw,
445  index_t KRaw,
446  index_t StrideA,
447  index_t StrideB,
448  index_t StrideC,
449  AElementwiseOperation a_element_op,
450  BElementwiseOperation b_element_op,
451  CElementwiseOperation c_element_op,
452  ReduceInElementwiseOperations reduce_in_element_ops,
453  ReduceAccElementwiseOperations reduce_out_element_ops)
454  : p_a_grid_{p_a_grid},
455  p_b_grid_{p_b_grid},
456  p_c_grid_{p_c_grid},
457  p_reduces_grid_{p_reduces_grid},
458  a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
459  b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
460  c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
461  reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
462  c_grid_desc_mblock_mperblock_nblock_nperblock_{},
463  reduce_grid_desc_mblock_mperblock_{},
464  block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
465  a_element_op_{a_element_op},
466  b_element_op_{b_element_op},
467  c_element_op_{c_element_op},
468  reduce_in_element_ops_{reduce_in_element_ops},
469  reduce_out_element_ops_{reduce_out_element_ops}
470  {
471  if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
472  b_grid_desc_bk0_n_bk1_,
473  c_grid_desc_m_n_,
474  block_2_ctile_map_))
475  {
476  c_grid_desc_mblock_mperblock_nblock_nperblock_ =
477  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
478  c_grid_desc_m_n_);
479 
480  reduce_grid_desc_mblock_mperblock_ =
481  GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
482  }
483  }
484 
485  // private:
486  const ADataType* p_a_grid_;
487  const BDataType* p_b_grid_;
488  CDataType* p_c_grid_;
489  ReducePtrsGlobal p_reduces_grid_;
490  AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
491  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
492  CGridDesc_M_N c_grid_desc_m_n_;
493  ReduceGridDesc_M reduce_grid_desc_m_;
494  typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
495  c_grid_desc_mblock_mperblock_nblock_nperblock_;
496  typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
497  reduce_grid_desc_mblock_mperblock_;
498  typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
499  AElementwiseOperation a_element_op_;
500  BElementwiseOperation b_element_op_;
501  CElementwiseOperation c_element_op_;
502  ReduceInElementwiseOperations reduce_in_element_ops_;
503  ReduceAccElementwiseOperations reduce_out_element_ops_;
504  };
505 
506  // Invoker
507  struct Invoker : public BaseInvoker
508  {
509  using Argument = DeviceOp::Argument;
510 
511  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
512  {
513  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
514  {
515  std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
516  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
517  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
518  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
519 
520  std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
521  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
522  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
523  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
524 
525  std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
526  << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
527 
528  std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
529  << "}" << std::endl;
530  }
531 
532  if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
533  arg.b_grid_desc_bk0_n_bk1_,
534  arg.c_grid_desc_m_n_,
535  arg.block_2_ctile_map_))
536  {
537  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
538  }
539 
540  const index_t grid_size =
541  arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
542 
543  const auto K =
544  arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
545 
546  float elapsed_time = 0.0f;
547  if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
548  {
549  const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
550  GridwiseGemm,
551  ADataType, // TODO: distinguish A/B datatype
552  CDataType,
553  ReducePtrsGlobal,
554  AElementwiseOperation,
555  BElementwiseOperation,
556  CElementwiseOperation,
557  ReduceInElementwiseOperations,
558  ReduceAccElementwiseOperations,
564  true>;
565 
566  elapsed_time =
567  launch_and_time_kernel(stream_config,
568  kernel,
569  dim3(grid_size),
570  dim3(BlockSize),
571  0,
572  arg.p_a_grid_,
573  arg.p_b_grid_,
574  arg.p_c_grid_,
575  arg.p_reduces_grid_,
576  arg.a_element_op_,
577  arg.b_element_op_,
578  arg.c_element_op_,
585  arg.block_2_ctile_map_);
586  }
587  else
588  {
589  const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
590  GridwiseGemm,
591  ADataType, // TODO: distinguish A/B datatype
592  CDataType,
593  ReducePtrsGlobal,
594  AElementwiseOperation,
595  BElementwiseOperation,
596  CElementwiseOperation,
597  ReduceInElementwiseOperations,
598  ReduceAccElementwiseOperations,
604  false>;
605 
606  elapsed_time =
607  launch_and_time_kernel(stream_config,
608  kernel,
609  dim3(grid_size),
610  dim3(BlockSize),
611  0,
612  arg.p_a_grid_,
613  arg.p_b_grid_,
614  arg.p_c_grid_,
615  arg.p_reduces_grid_,
616  arg.a_element_op_,
617  arg.b_element_op_,
618  arg.c_element_op_,
625  arg.block_2_ctile_map_);
626  }
627 
628  return elapsed_time;
629  }
630 
631  // polymorphic
632  float Run(const BaseArgument* p_arg,
633  const StreamConfig& stream_config = StreamConfig{}) override
634  {
635  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
636  }
637  };
638 
639  static constexpr bool IsValidCompilationParameter()
640  {
641  // TODO: properly implement this check
642  return true;
643  }
644 
645  static bool IsSupportedArgument(const Argument& arg)
646  {
647  if(!ck::is_xdl_supported())
648  {
649  return false;
650  }
651 
652  return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
653  arg.b_grid_desc_bk0_n_bk1_,
654  arg.c_grid_desc_m_n_,
655  arg.block_2_ctile_map_);
656  }
657 
658  // polymorphic
659  bool IsSupportedArgument(const BaseArgument* p_arg) override
660  {
661  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
662  }
663 
664  static constexpr int NumReduce = ReduceOperations::Size();
665  static auto MakeArgument(const void* p_a,
666  const void* p_b,
667  const void* p_bias,
668  std::array<const void*, 0> p_ds,
669  void* p_c,
670  std::array<void*, NumReduce> p_reduces,
671  ck::index_t M,
672  ck::index_t N,
673  ck::index_t K,
674  ck::index_t StrideA,
675  ck::index_t StrideB,
676  ck::index_t StrideC,
677  std::array<ck::index_t, 0> StrideDs,
678  std::array<void*, 3> gemm_element_ops,
679  std::array<void*, 0> d_element_ops,
680  std::array<void*, NumReduce> reduce_in_element_op,
681  std::array<void*, NumReduce> reduce_out_element_op)
682  {
683  (void)p_bias;
684  (void)p_ds;
685  (void)StrideDs;
686  (void)d_element_ops;
687 
688  ReducePtrsGlobal reduce_tuple = generate_tuple(
689  [&](auto I) {
690  auto tmp = ReducePtrsGlobal{}[I];
691  using T = remove_pointer_t<decltype(tmp)>;
692  return static_cast<T*>(p_reduces[I]);
693  },
694  Number<NumReduce>{});
695 
696  ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
697  [&](auto I) {
698  auto tmp = ReduceInElementwiseOperations{}[I];
699  using T = remove_pointer_t<decltype(tmp)>;
700  return *(static_cast<T*>(reduce_in_element_op[I]));
701  },
702  Number<NumReduce>{});
703  ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
704  [&](auto I) {
705  auto tmp = ReduceAccElementwiseOperations{}[I];
706  using T = remove_pointer_t<decltype(tmp)>;
707  return *(static_cast<T*>(reduce_out_element_op[I]));
708  },
709  Number<NumReduce>{});
710 
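 // Illustrative note (not from the original source): each generate_tuple call above
 // iterates the compile-time indices 0..NumReduce-1 and casts the corresponding void*
 // entry back to the statically known pointer/functor type of that reduction, turning
 // the type-erased host arrays into the typed tuples the kernel expects.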
711  AElementwiseOperation a_element_op =
712  *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
713  BElementwiseOperation b_element_op =
714  *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
715  CElementwiseOperation c_element_op =
716  *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
717 
718  return Argument{static_cast<const ADataType*>(p_a),
719  static_cast<const BDataType*>(p_b),
720  static_cast<CDataType*>(p_c),
721  reduce_tuple,
722  M,
723  N,
724  K,
725  StrideA,
726  StrideB,
727  StrideC,
728  a_element_op,
729  b_element_op,
730  c_element_op,
731  reduce_in_element_ops,
732  reduce_out_element_ops};
733  }
734 
735  static auto MakeInvoker() { return Invoker{}; }
736 
737  // polymorphic
738  std::unique_ptr<BaseArgument>
739  MakeArgumentPointer(const void* p_a,
740  const void* p_b,
741  const void* p_bias,
742  std::array<const void*, 0> p_ds,
743  void* p_c,
744  std::array<void*, NumReduce> p_reduces,
745  ck::index_t M,
746  ck::index_t N,
747  ck::index_t K,
748  ck::index_t StrideA,
749  ck::index_t StrideB,
750  ck::index_t StrideC,
751  std::array<ck::index_t, 0> StrideDs,
752  std::array<void*, 3> gemm_element_ops,
753  std::array<void*, 0> d_element_ops,
754  std::array<void*, NumReduce> reduce_in_element_op,
755  std::array<void*, NumReduce> reduce_out_element_op,
756  ck::index_t = 1) override
757  {
758  (void)p_bias;
759  (void)p_ds;
760  (void)StrideDs;
761  (void)d_element_ops;
762 
763  ReducePtrsGlobal reduce_tuple = generate_tuple(
764  [&](auto I) {
765  auto tmp = ReducePtrsGlobal{}[I];
766  using T = remove_pointer_t<decltype(tmp)>;
767  return static_cast<T*>(p_reduces[I]);
768  },
769  Number<NumReduce>{});
770 
771  ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
772  [&](auto I) {
773  auto tmp = ReduceInElementwiseOperations{}[I];
774  using T = remove_pointer_t<decltype(tmp)>;
775  return *(static_cast<T*>(reduce_in_element_op[I]));
776  },
777  Number<NumReduce>{});
778  ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
779  [&](auto I) {
780  auto tmp = ReduceAccElementwiseOperations{}[I];
781  using T = remove_pointer_t<decltype(tmp)>;
782  return *(static_cast<T*>(reduce_out_element_op[I]));
783  },
784  Number<NumReduce>{});
785 
786  AElementwiseOperation a_element_op =
787  *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
788  BElementwiseOperation b_element_op =
789  *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
790  CElementwiseOperation c_element_op =
791  *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
792 
793  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
794  static_cast<const BDataType*>(p_b),
795  static_cast<CDataType*>(p_c),
796  reduce_tuple,
797  M,
798  N,
799  K,
800  StrideA,
801  StrideB,
802  StrideC,
803  a_element_op,
804  b_element_op,
805  c_element_op,
806  reduce_in_element_ops,
807  reduce_out_element_ops);
808  }
809 
810  // polymorphic
811  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
812  {
813  return std::make_unique<Invoker>(Invoker{});
814  }
815 
816  // polymorphic
817  std::string GetTypeString() const override
818  {
819  auto str = std::stringstream();
820 
821  // clang-format off
822  str << "DeviceGemmReduce_Xdl_CShuffle"
823  << "<"
824  << BlockSize << ", "
825  << MPerBlock << ", "
826  << NPerBlock << ", "
827  << KPerBlock << ", "
828  << AK1 << ", "
829  << BK1 << ", "
830  << MPerXDL << ", "
831  << NPerXDL << ", "
832  << MXdlPerWave << ", "
833  << NXdlPerWave << ", "
834  << ABlockTransferSrcScalarPerVector << ", "
835  << BBlockTransferSrcScalarPerVector << ", "
836  << CShuffleMXdlPerWavePerShuffle << ", "
837  << CShuffleNXdlPerWavePerShuffle
838  << ">";
839  // clang-format on
840 
841  return str.str();
842  }
843 };
844 
845 } // namespace device
846 } // namespace tensor_operation
847 } // namespace ck
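
Usage sketch (illustrative, not part of the header): a fully specialized instance of this device op follows the usual Composable Kernel host-side flow of MakeArgument, IsSupportedArgument, MakeInvoker, and Invoker::Run. The alias DeviceOpInstance below stands for some concrete DeviceGemmReduce_Xdl_CShuffle<...> specialization and is assumed to be defined elsewhere, as are the element-wise operation objects reached through the void* arrays.

#include <array>
#include <stdexcept>

// DeviceOpInstance = a concrete DeviceGemmReduce_Xdl_CShuffle<...> specialization (assumed).
float run_gemm_reduce(const void* p_a,
                      const void* p_b,
                      void* p_c,
                      std::array<void*, DeviceOpInstance::NumReduce> p_reduces,
                      ck::index_t M, ck::index_t N, ck::index_t K,
                      ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                      std::array<void*, 3> gemm_element_ops,
                      std::array<void*, DeviceOpInstance::NumReduce> reduce_in_ops,
                      std::array<void*, DeviceOpInstance::NumReduce> reduce_out_ops)
{
    auto device_op = DeviceOpInstance{};

    // Bias/Ds slots are unused by this operator, so empty arrays are passed through.
    auto argument = device_op.MakeArgument(p_a, p_b, /*p_bias=*/nullptr, {}, p_c, p_reduces,
                                           M, N, K, StrideA, StrideB, StrideC,
                                           {}, gemm_element_ops, {},
                                           reduce_in_ops, reduce_out_ops);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("problem size/layout not supported by this instance");
    }

    // Run returns the elapsed time reported by launch_and_time_kernel.
    return device_op.MakeInvoker().Run(argument, StreamConfig{});
}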