Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp Source File
device_grouped_gemm_xdl.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
10 #include "ck/utility/env.hpp"
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
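 // Grouped GEMM kernel: all groups are executed by a single grid. The per-group
 // problem descriptors are read from constant address space, and each workgroup
 // selects its group by locating the [BlockStart_, BlockEnd_) block range that
 // contains its block id before running the gridwise GEMM for that group.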
25 template <typename GridwiseGemm,
26  typename GemmDesc,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CDEElementwiseOperation,
30  bool HasMainKBlockLoop>
31 __global__ void
32 #if CK_USE_LAUNCH_BOUNDS
33  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
34 #endif
35  kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
36  const index_t group_count,
37  const AElementwiseOperation a_element_op,
38  const BElementwiseOperation b_element_op,
39  const CDEElementwiseOperation c_element_op)
40 {
41 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
42  if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
43  {
44  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
45 
46  const index_t block_id = get_block_1d_id();
47 
48  const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
49  cast_pointer_to_generic_address_space(gemm_descs_const));
50 
51  index_t left = 0;
52  index_t right = group_count;
53  index_t group_id = index_t((left + right) / 2);
54  while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
55  block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
56  left <= right)
57  {
58  if(block_id < gemm_desc_ptr[group_id].BlockStart_)
59  {
60  right = group_id;
61  }
62  else
63  {
64  left = group_id;
65  }
66  group_id = index_t((left + right) / 2);
67  }
68 
69  GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
70  gemm_desc_ptr[group_id].a_ptr_,
71  gemm_desc_ptr[group_id].b_ptr_,
72  gemm_desc_ptr[group_id].ds_ptr_,
73  gemm_desc_ptr[group_id].e_ptr_,
74  p_shared,
75  a_element_op,
76  b_element_op,
77  c_element_op,
78  gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
79  gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
80  gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
81  gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
82  gemm_desc_ptr[group_id].block_2_etile_map_);
83  }
84 #else
85  ignore = gemm_descs_const;
86  ignore = group_count;
87  ignore = a_element_op;
88  ignore = b_element_op;
89  ignore = c_element_op;
90 #endif
91 }
92 
93 template <typename ALayout,
94  typename BLayout,
95  typename DsLayout,
96  typename ELayout,
97  typename ADataType,
98  typename BDataType,
99  typename AccDataType,
100  typename CShuffleDataType,
101  typename DsDataType,
102  typename EDataType,
103  typename AElementwiseOperation,
104  typename BElementwiseOperation,
105  typename CDEElementwiseOperation,
106  GemmSpecialization GemmSpec,
107  ck::index_t NumPrefetch,
108  ck::index_t BlockSize,
109  ck::index_t MPerBlock,
110  ck::index_t NPerBlock,
111  ck::index_t KPerBlock,
112  ck::index_t AK1,
113  ck::index_t BK1,
114  ck::index_t MPerXDL,
115  ck::index_t NPerXDL,
116  ck::index_t MXdlPerWave,
117  ck::index_t NXdlPerWave,
118  typename ABlockTransferThreadClusterLengths_K0_M_K1,
119  typename ABlockTransferThreadClusterArrangeOrder,
120  typename ABlockTransferSrcAccessOrder,
121  ck::index_t ABlockTransferSrcVectorDim,
122  ck::index_t ABlockTransferSrcScalarPerVector,
123  ck::index_t ABlockTransferDstScalarPerVector_K1,
124  bool ABlockLdsExtraM,
125  typename BBlockTransferThreadClusterLengths_K0_N_K1,
126  typename BBlockTransferThreadClusterArrangeOrder,
127  typename BBlockTransferSrcAccessOrder,
128  ck::index_t BBlockTransferSrcVectorDim,
129  ck::index_t BBlockTransferSrcScalarPerVector,
130  ck::index_t BBlockTransferDstScalarPerVector_K1,
131  bool BBlockLdsExtraN,
132  index_t CShuffleMXdlPerWavePerShuffle,
133  index_t CShuffleNXdlPerWavePerShuffle,
134  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
135  index_t CDEBlockTransferScalarPerVector_NPerBlock,
137  typename ComputeDataType = ADataType>
138 struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
139  BLayout,
140  DsLayout,
141  ELayout,
142  ADataType,
143  BDataType,
144  DsDataType,
145  EDataType,
146  AElementwiseOperation,
147  BElementwiseOperation,
148  CDEElementwiseOperation,
149  ComputeDataType>
150 {
151  using DeviceOp = DeviceGroupedGemm_Xdl;
152  GET_NXDL_PER_WAVE_IMPL
153  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
154  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
155  static constexpr index_t NumDTensor = DsDataType::Size();
156 
157  static constexpr auto I0 = Number<0>{};
158  static constexpr auto I1 = Number<1>{};
159  static constexpr auto I2 = Number<2>{};
160 
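 // MatrixPadder pads the raw M/N/K extents up to multiples of the block tile sizes
 // according to GemmSpec, so the descriptors built below always cover whole
 // MPerBlock x NPerBlock x KPerBlock tiles.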
161  static constexpr auto matrix_padder =
162  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
163 
164  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
165  {
166  const auto a_grid_desc_mraw_kraw = [&]() {
167  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
168  {
169  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
170  make_tuple(StrideA, I1));
171  }
172  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
173  {
174  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
175  make_tuple(I1, StrideA));
176  }
177  }();
178 
179  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
180  }
181 
182  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
183  {
184  const auto b_grid_desc_nraw_kraw = [&]() {
185  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
186  {
187  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
188  make_tuple(I1, StrideB));
189  }
190  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
191  {
192  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
193  make_tuple(StrideB, I1));
194  }
195  }();
196 
197  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
198  }
199 
200  template <typename ELay>
201  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
202  {
203  const auto e_grid_desc_mraw_nraw = [&]() {
204  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELay>)
205  {
206  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
207  make_tuple(StrideE, I1));
208  }
209  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELay>)
210  {
211  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
212  make_tuple(I1, StrideE));
213  }
214  }();
215 
216  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
217  }
218 
219  static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
220  const std::array<index_t, NumDTensor>& NRaws,
221  const std::array<index_t, NumDTensor>& DsStride)
222  {
223  return generate_tuple(
224  [&](auto i) {
225  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
226 
227  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
228  },
229  Number<NumDTensor>{});
230  }
231 
232  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
233  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
234  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
235  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
236 
237  // GridwiseGemm
238  template <index_t NXdlPerWave_>
239  using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_CShuffle<
240  ADataType, // TODO: distinguish A/B datatype
241  BDataType,
242  ComputeDataType,
243  AccDataType,
244  CShuffleDataType,
245  DsDataType,
246  EDataType,
247  AElementwiseOperation,
248  BElementwiseOperation,
249  CDEElementwiseOperation,
250  NumPrefetch, // NumGemmKPrefetchStage
251  BlockSize,
252  MPerBlock,
253  NPerBlock,
254  KPerBlock,
255  AK1,
256  BK1,
257  MPerXDL,
258  NPerXDL,
259  MXdlPerWave,
260  NXdlPerWave_,
261  ABlockTransferThreadClusterLengths_K0_M_K1,
262  ABlockTransferThreadClusterArrangeOrder,
263  ABlockTransferSrcAccessOrder,
264  ABlockTransferSrcVectorDim,
265  ABlockTransferSrcScalarPerVector,
266  ABlockTransferDstScalarPerVector_K1,
267  false, // AThreadTransferSrcResetCoordinateAfterRun,
268  ABlockLdsExtraM,
269  BBlockTransferThreadClusterLengths_K0_N_K1,
270  BBlockTransferThreadClusterArrangeOrder,
271  BBlockTransferSrcAccessOrder,
272  BBlockTransferSrcVectorDim,
273  BBlockTransferSrcScalarPerVector,
274  BBlockTransferDstScalarPerVector_K1,
275  false, // BThreadTransferSrcResetCoordinateAfterRun,
276  BBlockLdsExtraN,
277  CShuffleMXdlPerWavePerShuffle,
278  CShuffleNXdlPerWavePerShuffle,
279  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
280  CDEBlockTransferScalarPerVector_NPerBlock,
281  LoopSched>;
282  using GridwiseGemm64 = GridwiseGemmBase<NXdlPerWave64>;
283  using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
284 
285  using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
286  GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
287  AGridDesc_M_K{}))>;
288  using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
289  GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
290  BGridDesc_N_K{}))>;
291  using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
292  GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
293  DsGridDesc_M_N{}))>;
294  using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
295  GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
296  EGridDesc_M_N{}))>;
297 
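 // Block-to-E-tile map for one group: wraps the gridwise GEMM's default map and
 // shifts the incoming block id by BlockStart_, so each group sees block ids that
 // start from zero within the shared grid.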
298  struct GroupedGemmBlock2ETileMap
299  {
300  using Block2ETileMap = remove_cvref_t<decltype(
301  GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
302 
303  GroupedGemmBlock2ETileMap()
304  {
305  block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
306  BlockStart_ = -1;
307  }
308 
309  GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
310  {
311  block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
312  BlockStart_ = BlockStart;
313  }
314 
315  template <typename TopIdx>
316  __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
317  {
318  return block_2_etile_map_.CalculateBottomIndex(
319  make_multi_index(idx_top[I0] - BlockStart_));
320  }
321 
322  // it's actually E-Tile
323  template <typename CTileIdx, typename CTileDim>
324  __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
325  const CTileDim& c_tile_dim) const
326  {
327  return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
328  }
329 
330  __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
331  {
332  return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
333  }
334 
335  Block2ETileMap block_2_etile_map_;
336  ck::index_t BlockStart_;
337  };
338 
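 // Per-group kernel argument: data pointers, problem and copy descriptors, the
 // group's block-to-E-tile map and its [BlockStart_, BlockEnd_) block range. One
 // instance per group is copied to the device workspace and looked up by the
 // kernel's group search.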
339  struct GemmBiasTransKernelArg
340  {
341  // pointers
342  const ADataType* a_ptr_;
343  const BDataType* b_ptr_;
344  typename GridwiseGemm64::DsGridPointer ds_ptr_;
345  EDataType* e_ptr_;
346 
347  // tensor descriptors for problem definition
348  AGridDesc_M_K a_grid_desc_m_k_;
349  BGridDesc_N_K b_grid_desc_n_k_;
350  DsGridDesc_M_N ds_grid_desc_m_n_;
351  EGridDesc_M_N e_grid_desc_m_n_;
352 
353  // tensor descriptors for block/thread-wise copy
354  AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
355  BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
356 
357  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_;
358  EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
359 
360  // block-to-e-tile map
361  GroupedGemmBlock2ETileMap block_2_etile_map_;
362  ck::index_t BlockStart_, BlockEnd_;
363  };
364 
365  // Argument
366  struct Argument : public BaseArgument
367  {
368  template <typename GridwiseGemm, typename DsPointer, typename Block2ETileMap>
369  void init_gridwise_gemm_desc(const ADataType* a_ptr,
370  const BDataType* b_ptr,
371  DsPointer ds_ptr,
372  EDataType* e_ptr,
373  const AGridDesc_M_K& a_grid_desc_m_k,
374  const BGridDesc_N_K& b_grid_desc_n_k,
375  const DsGridDesc_M_N& ds_grid_desc_m_n,
376  const EGridDesc_M_N& e_grid_desc_m_n,
377  const Block2ETileMap& block_2_etile_map,
378  index_t BlockStart,
379  index_t BlockEnd)
380  {
381  // tensor descriptors for block/thread-wise copy
382  const auto a_grid_desc_ak0_m_ak1 =
383  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
384 
385  const auto b_grid_desc_bk0_n_bk1 =
386  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
387 
388  if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
389  b_grid_desc_n_k,
390  ds_grid_desc_m_n,
391  e_grid_desc_m_n,
392  block_2_etile_map))
393  {
394  // tensor descriptors for block/thread-wise copy
395  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
396  ds_grid_desc_mblock_mperblock_nblock_nperblock;
397 
398  static_for<0, NumDTensor, 1>{}([&](auto j) {
399  ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
400  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
401  ds_grid_desc_m_n[j]);
402  });
403 
404  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
405  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
406  e_grid_desc_m_n);
407 
408  gemm_desc_kernel_arg_.push_back(
409  GemmBiasTransKernelArg{a_ptr,
410  b_ptr,
411  ds_ptr,
412  e_ptr,
413  a_grid_desc_m_k,
414  b_grid_desc_n_k,
415  ds_grid_desc_m_n,
416  e_grid_desc_m_n,
417  a_grid_desc_ak0_m_ak1,
418  b_grid_desc_bk0_n_bk1,
419  ds_grid_desc_mblock_mperblock_nblock_nperblock,
420  e_grid_desc_mblock_mperblock_nblock_nperblock,
421  block_2_etile_map,
422  BlockStart,
423  BlockEnd});
424  }
425  };
426  Argument(std::vector<const void*>& p_As,
427  std::vector<const void*>& p_Bs,
428  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
429  std::vector<void*>& p_Es,
430  std::vector<GemmDesc>& gemm_descs,
431  AElementwiseOperation a_element_op,
432  BElementwiseOperation b_element_op,
433  CDEElementwiseOperation c_element_op)
434  : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
435  {
436  grid_size_ = 0;
437 
438  group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
439 
440  if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
441  group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
442  group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
443  {
444  throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
445  }
446 
448 
449  skipped_group_count_ = 0;
450 
451  for(std::size_t i = 0; i < gemm_descs.size(); i++)
452  {
453  const index_t M = gemm_descs[i].M_;
454  const index_t N = gemm_descs[i].N_;
455  const index_t K = gemm_descs[i].K_;
456 
457  a_mtx_mraw_kraw_.emplace_back(M, K);
458  b_mtx_nraw_kraw_.emplace_back(N, K);
459 
460  if(M == 0)
461  {
462  skipped_group_count_++;
463  continue;
464  }
465 
466  const index_t StrideA = gemm_descs[i].stride_A_;
467  const index_t StrideB = gemm_descs[i].stride_B_;
468  const index_t StrideC = gemm_descs[i].stride_C_;
469 
470  // pointer
471  typename GridwiseGemm64::DsGridPointer p_ds_grid{};
472 
473  static_for<0, NumDTensor, 1>{}([&](auto j) {
474  using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
475 
476  p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
477  });
478 
479  // tensor descriptors for problem definition
480  const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
481  const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
482 
483  DsGridDesc_M_N ds_grid_desc_m_n;
484 
485  static_for<0, NumDTensor, 1>{}([&](auto j) {
486  using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
487 
488  ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
489  M, N, gemm_descs[i].stride_Ds_[j]);
490  });
491 
492  const auto e_grid_desc_m_n =
493  DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
494 
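 // Each group occupies a contiguous [BlockStart, BlockEnd) slice of the grid;
 // the total grid size is the sum of the per-group tile counts.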
495  const index_t grid_size_grp =
496  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
497  .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
498 
499  const index_t BlockStart = grid_size_;
500  const index_t BlockEnd = grid_size_ + grid_size_grp;
501 
502  grid_size_ += grid_size_grp;
503 
504  // block-to-e-tile map
505  const auto block_2_etile_map =
506  GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
507 
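 // Choose the gridwise GEMM variant that matches the runtime wavefront size
 // (GridwiseGemm64 for wave64, GridwiseGemm32 for wave32).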
508  if(get_warp_size() == 64)
509  {
510  if constexpr(NXdlPerWave64 > 0)
511  {
512  init_gridwise_gemm_desc<GridwiseGemm64>(
513  static_cast<const ADataType*>(p_As[i]),
514  static_cast<const BDataType*>(p_Bs[i]),
515  p_ds_grid,
516  static_cast<EDataType*>(p_Es[i]),
517  a_grid_desc_m_k,
518  b_grid_desc_n_k,
519  ds_grid_desc_m_n,
520  e_grid_desc_m_n,
521  block_2_etile_map,
522  BlockStart,
523  BlockEnd);
524  }
525  }
526  else
527  {
528  if constexpr(NXdlPerWave32 > 0)
529  {
530  init_gridwise_gemm_desc<GridwiseGemm32>(
531  static_cast<const ADataType*>(p_As[i]),
532  static_cast<const BDataType*>(p_Bs[i]),
533  p_ds_grid,
534  static_cast<EDataType*>(p_Es[i]),
535  a_grid_desc_m_k,
536  b_grid_desc_n_k,
537  ds_grid_desc_m_n,
538  e_grid_desc_m_n,
539  block_2_etile_map,
540  BlockStart,
541  BlockEnd);
542  }
543  }
544  }
545  }
546 
547  // private:
548  index_t group_count_;
549  index_t skipped_group_count_;
550 
551  AElementwiseOperation a_element_op_;
552  BElementwiseOperation b_element_op_;
553  CDEElementwiseOperation c_element_op_;
554 
555  std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
556  std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
557  std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
558 
559  index_t grid_size_;
560  void* gemm_kernel_host_args_ = nullptr;
561  };
562 
563  // Invoker
564  struct Invoker : public BaseInvoker
565  {
566  using Argument = DeviceOp::Argument;
567 
568  template <typename GridwiseGemm>
569  float RunImp(const Argument& arg,
570  const StreamConfig& stream_config = StreamConfig{},
571  hipStream_t cpy_stream = nullptr,
572  hipEvent_t cpy_event = nullptr)
573  {
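 // A single kernel instantiation is launched for all groups, so every group must
 // pass the gridwise GEMM validity check and must agree on whether the main
 // K-block loop is needed; this is verified before the arguments are copied to
 // the device.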
574  bool has_main_k_block_loop = true;
575 
576  for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
577  {
578  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
579  {
580  std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
581  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
582  << ", "
583  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
584  << ", "
585  << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
586  << "}";
587 
588  std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
589  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
590  << ", "
591  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
592  << ", "
593  << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
594  << "}";
595 
596  std::cout << ", arg.e_grid_desc_m_n_{ "
597  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
598  << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
599  << std::endl;
600  }
601 
602  if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
603  arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
604  arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
605  arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
606  arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
607  {
608  throw std::runtime_error(
609  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
610  }
611 
612  const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
613  arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
614 
615  if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
616  {
617  throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
618  }
619  }
620 
621  // If the user provides copy stream and copy event, we assume that they're also
622  // responsible for providing allocated host memory (e.g. pinned) which
623  // would be used to copy kernel arguments to the device.
624  if(cpy_stream && cpy_event)
625  {
626  if(arg.gemm_kernel_host_args_ == nullptr)
627  {
628  std::ostringstream err;
629  err << "No memory has been allocated for gemm kernel host args "
630  << "when providing the copy stream and copy event! In " << __FILE__ << ":"
631  << __LINE__ << ", in function: " << __func__;
632  throw std::runtime_error(err.str());
633  }
634  hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
635  arg.gemm_kernel_host_args_,
636  arg.group_count_ * sizeof(GemmBiasTransKernelArg),
637  hipMemcpyHostToDevice,
638  cpy_stream));
639  hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
640  hipGetErrorString(hipEventSynchronize(cpy_event));
641  }
642  else // In this case CK owns memory allocated on host.
643  {
644  hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
645  arg.gemm_desc_kernel_arg_.data(),
646  arg.gemm_desc_kernel_arg_.size() *
647  sizeof(GemmBiasTransKernelArg),
648  hipMemcpyHostToDevice,
649  stream_config.stream_id_));
650  }
651 
652  float ave_time = 0;
653 
654  auto launch_kernel = [&](auto has_main_k_block_loop_) {
655  const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
656  GemmBiasTransKernelArg,
657  AElementwiseOperation,
658  BElementwiseOperation,
659  CDEElementwiseOperation,
660  has_main_k_block_loop_>;
661 
662  return launch_and_time_kernel(
663  stream_config,
664  kernel,
665  dim3(arg.grid_size_),
666  dim3(BlockSize),
667  0,
668  cast_pointer_to_constant_address_space(arg.p_workspace_),
669  arg.gemm_desc_kernel_arg_.size(),
670  arg.a_element_op_,
671  arg.b_element_op_,
672  arg.c_element_op_);
673  };
674 
675  if(has_main_k_block_loop)
676  {
677  ave_time = launch_kernel(integral_constant<bool, true>{});
678  }
679  else
680  {
681  ave_time = launch_kernel(integral_constant<bool, false>{});
682  }
683 
684  return ave_time;
685  }
686 
687  float Run(const Argument& arg,
688  const StreamConfig& stream_config = StreamConfig{},
689  hipStream_t cpy_stream = nullptr,
690  hipEvent_t cpy_event = nullptr)
691  {
692  if(get_warp_size() == 64)
693  {
694  if constexpr(NXdlPerWave64 > 0)
695  {
696  return RunImp<GridwiseGemm64>(arg, stream_config, cpy_stream, cpy_event);
697  }
698  }
699  else
700  {
701  if constexpr(NXdlPerWave32 > 0)
702  {
703  return RunImp<GridwiseGemm32>(arg, stream_config, cpy_stream, cpy_event);
704  }
705  }
706  return 0;
707  }
708 
709  // polymorphic
710  float Run(const BaseArgument* p_arg,
711  const StreamConfig& stream_config = StreamConfig{}) override
712  {
713  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
714  }
715  };
716 
717  static bool IsSupportedArgument(const Argument& arg)
718  {
719  if(!ck::is_xdl_wmma_supported<ADataType, BDataType, MPerXDL, NPerXDL>())
720  {
721  return false;
722  }
723  if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
724  arg.skipped_group_count_) != arg.group_count_)
725  {
726  return false;
727  }
728 
729  bool supported = true;
730 
731  // If we use padding we do not support vector loads for dimensions not divisible by
732  // vector load size.
733  if constexpr(GemmSpec != GemmSpecialization::Default)
734  {
735  // [A|B]BlockTransferSrcVectorDim defines the vector-load dimension in the block {K0,M,K1}
736  // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
737  const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
738  const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
739 
740  for(index_t i = 0; i < arg.group_count_; ++i)
741  {
742  const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
743  const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
744 
745  supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
746  supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
747  }
748  }
749 
750  return supported;
751  }
752 
753  // polymorphic
754  bool IsSupportedArgument(const BaseArgument* p_arg) override
755  {
756  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
757  }
758 
759  static auto MakeArgument(std::vector<const void*>& p_As,
760  std::vector<const void*>& p_Bs,
761  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
762  std::vector<void*>& p_Es,
763  std::vector<GemmDesc> gemm_descs,
764  AElementwiseOperation a_element_op,
765  BElementwiseOperation b_element_op,
766  CDEElementwiseOperation c_element_op)
767  {
768  return Argument{
769  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
770  }
771 
772  static auto MakeInvoker() { return Invoker{}; }
773 
774  // polymorphic
775  std::unique_ptr<BaseArgument>
776  MakeArgumentPointer(std::vector<const void*>& p_As,
777  std::vector<const void*>& p_Bs,
778  std::vector<std::array<const void*, NumDTensor>>& p_Ds,
779  std::vector<void*>& p_Es,
780  std::vector<GemmDesc>& gemm_descs,
781  AElementwiseOperation a_element_op,
782  BElementwiseOperation b_element_op,
783  CDEElementwiseOperation c_element_op) override
784  {
785  return std::make_unique<Argument>(
786  p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
787  }
788 
789  // polymorphic
790  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
791  {
792  return std::make_unique<Invoker>(Invoker{});
793  }
794 
795  // polymorphic
796  std::string GetTypeString() const override
797  {
798  auto str = std::stringstream();
799 
800  // clang-format off
801  str << "DeviceGroupedGemm_Xdl"
802  << "<"
803  << BlockSize << ", "
804  << MPerBlock << ", "
805  << NPerBlock << ", "
806  << KPerBlock << ", "
807  << AK1 << ", "
808  << BK1 << ", "
809  << MPerXDL << ", "
810  << NPerXDL << ", "
811  << MXdlPerWave << ", "
812  << NXdlPerWave << ", "
813  << ABlockTransferSrcScalarPerVector << ", "
814  << BBlockTransferSrcScalarPerVector << ", "
815  << CShuffleMXdlPerWavePerShuffle << ", "
816  << CShuffleNXdlPerWavePerShuffle << ", "
817  << getGemmSpecializationString(GemmSpec)
818  << ">";
819  // clang-format on
820 
821  return str.str();
822  }
823 
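 // The workspace is the device-side copy of gemm_desc_kernel_arg_ (one
 // GemmBiasTransKernelArg per group); SetDeviceKernelArgs forwards to
 // SetWorkSpacePointer.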
824  size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
825  {
826  auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
827  if(p_arg_)
828  {
829  return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
830  }
831  else
832  throw std::runtime_error("The argument pointer is not an object of "
833  "DeviceGroupedGemm_Xdl::Argument structure!");
834  }
835 
836  size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
837  {
838  return GetWorkSpaceSize(p_arg);
839  }
840 
841  void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
842  {
843  return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
844  }
845 
846  size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
847 
848  //----------------------------------------------------------------------------------------------
858  void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
859  {
860  Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
861  if(!pArg_)
862  {
863  throw std::runtime_error("Failed to cast argument pointer!");
864  }
865 
866  pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
867  std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
868  pArg_->gemm_desc_kernel_arg_.end(),
869  static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
870  }
871 };
872 
873 } // namespace device
874 } // namespace tensor_operation
875 } // namespace ck
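
Host-side usage follows the DeviceGroupedGemm interface exposed above: build one GemmDesc per group, create an Argument, allocate device memory for the per-group kernel arguments, and run the Invoker. The sketch below is illustrative only: DeviceGemmInstance stands for a fully specialized DeviceGroupedGemm_Xdl (template arguments omitted) whose DsDataType tuple is assumed empty, and run_grouped_gemm_sketch is a hypothetical helper, not part of this header.

// Hypothetical host-side driver for the device op above. DeviceGemmInstance is a
// placeholder for a concrete DeviceGroupedGemm_Xdl<...> specialization with an
// empty DsDataType tuple (NumDTensor == 0); it is not defined in this header.
#include <array>
#include <stdexcept>
#include <vector>
#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using GemmDesc    = ck::tensor_operation::device::GemmDesc;

void run_grouped_gemm_sketch(std::vector<const void*>& p_As, // one A pointer per group
                             std::vector<const void*>& p_Bs, // one B pointer per group
                             std::vector<void*>& p_Es,       // one E (output) pointer per group
                             std::vector<GemmDesc>& gemm_descs)
{
    // No auxiliary D tensors in this sketch.
    std::vector<std::array<const void*, 0>> p_Ds(gemm_descs.size());

    auto device_op = DeviceGemmInstance{};
    auto argument  = device_op.MakeArgument(
        p_As, p_Bs, p_Ds, p_Es, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("unsupported grouped GEMM configuration");
    }

    // The per-group GemmBiasTransKernelArg array lives in device memory; allocate it
    // and register it with the op before launching.
    void* p_dev_args = nullptr;
    if(hipMalloc(&p_dev_args, device_op.GetDeviceKernelArgSize(&argument)) != hipSuccess)
    {
        throw std::runtime_error("hipMalloc failed for grouped GEMM kernel arguments");
    }
    device_op.SetDeviceKernelArgs(&argument, p_dev_args);

    auto invoker = device_op.MakeInvoker();
    invoker.Run(argument, StreamConfig{}); // copies the arguments and launches the kernel

    (void)hipFree(p_dev_args);
}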
#define CK_CONSTANT_ADDRESS_SPACE
Definition: ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:81
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
__global__ void kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:35
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:173
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:301
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:16
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:190
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:207
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:224
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:411
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:245
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:257
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: device_base.hpp:197
void * p_workspace_
Definition: device_base.hpp:204
Definition: device_base.hpp:208
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:249
Definition: device_grouped_gemm_xdl.hpp:367
AElementwiseOperation a_element_op_
Definition: device_grouped_gemm_xdl.hpp:551
CDEElementwiseOperation c_element_op_
Definition: device_grouped_gemm_xdl.hpp:553
BElementwiseOperation b_element_op_
Definition: device_grouped_gemm_xdl.hpp:552
void init_gridwise_gemm_desc(const ADataType *a_ptr, const BDataType *b_ptr, DsPointer ds_ptr, EDataType *e_ptr, const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &block_2_etile_map, index_t BlockStart, index_t BlockEnd)
Definition: device_grouped_gemm_xdl.hpp:369
index_t grid_size_
Definition: device_grouped_gemm_xdl.hpp:559
index_t group_count_
Definition: device_grouped_gemm_xdl.hpp:548
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:557
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition: device_grouped_gemm_xdl.hpp:555
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:556
void * gemm_kernel_host_args_
Definition: device_grouped_gemm_xdl.hpp:560
index_t skipped_group_count_
Definition: device_grouped_gemm_xdl.hpp:549
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:426
ck::index_t BlockEnd_
Definition: device_grouped_gemm_xdl.hpp:362
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:358
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_grouped_gemm_xdl.hpp:355
GridwiseGemm64::DsGridPointer ds_ptr_
Definition: device_grouped_gemm_xdl.hpp:344
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:357
const BDataType * b_ptr_
Definition: device_grouped_gemm_xdl.hpp:343
GroupedGemmBlock2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:361
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:362
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_grouped_gemm_xdl.hpp:348
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:351
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_grouped_gemm_xdl.hpp:349
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:350
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_grouped_gemm_xdl.hpp:354
EDataType * e_ptr_
Definition: device_grouped_gemm_xdl.hpp:345
const ADataType * a_ptr_
Definition: device_grouped_gemm_xdl.hpp:342
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition: device_grouped_gemm_xdl.hpp:301
__host__ bool CheckValidity(const EGridDesc_M_N &e_grid_desc_m_n) const
Definition: device_grouped_gemm_xdl.hpp:330
Block2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:335
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:336
GroupedGemmBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n, ck::index_t BlockStart)
Definition: device_grouped_gemm_xdl.hpp:309
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: device_grouped_gemm_xdl.hpp:324
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: device_grouped_gemm_xdl.hpp:316
GroupedGemmBlock2ETileMap()
Definition: device_grouped_gemm_xdl.hpp:303
Definition: device_grouped_gemm_xdl.hpp:565
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_gemm_xdl.hpp:710
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition: device_grouped_gemm_xdl.hpp:687
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition: device_grouped_gemm_xdl.hpp:569
Definition: device_grouped_gemm_xdl.hpp:150
static auto MakeInvoker()
Definition: device_grouped_gemm_xdl.hpp:772
static constexpr index_t NumDTensor
Definition: device_grouped_gemm_xdl.hpp:155
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:234
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:296
static constexpr auto matrix_padder
Definition: device_grouped_gemm_xdl.hpp:161
std::string GetTypeString() const override
Definition: device_grouped_gemm_xdl.hpp:796
size_t GetHostKernelArgSize(const BaseArgument *p_arg) const
Definition: device_grouped_gemm_xdl.hpp:846
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_grouped_gemm_xdl.hpp:219
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) override
Definition: device_grouped_gemm_xdl.hpp:776
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_grouped_gemm_xdl.hpp:182
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:759
static constexpr auto I2
Definition: device_grouped_gemm_xdl.hpp:159
void SetHostKernelArgsPointer(BaseArgument *p_arg, void *p_host_kernel_args) const
Sets the host kernel arguments pointer and copies that data on the host side. This function can be ut...
Definition: device_grouped_gemm_xdl.hpp:858
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_grouped_gemm_xdl.hpp:233
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition: device_grouped_gemm_xdl.hpp:836
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_gemm_xdl.hpp:754
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_gemm_xdl.hpp:790
static constexpr auto I0
Definition: device_grouped_gemm_xdl.hpp:157
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:293
static constexpr auto NXdlPerWave32
Definition: device_grouped_gemm_xdl.hpp:154
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_grouped_gemm_xdl.hpp:290
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_gemm_xdl.hpp:717
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:235
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_grouped_gemm_xdl.hpp:287
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_grouped_gemm_xdl.hpp:232
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm_xdl.hpp:841
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_grouped_gemm_xdl.hpp:153
static constexpr auto I1
Definition: device_grouped_gemm_xdl.hpp:158
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_grouped_gemm_xdl.hpp:824
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_grouped_gemm_xdl.hpp:201
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_grouped_gemm_xdl.hpp:164
Definition: device_grouped_gemm.hpp:100
Definition: device_grouped_gemm.hpp:80
Definition: matrix_padder.hpp:180
#define CK_ENV(name)
Definition: env.hpp:128