Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp

gridwise_gemm_xdlops_streamk.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
20 
21 namespace ck {
22 
23 template <typename GridwiseGemm>
24 __global__ void
25 #if CK_USE_LAUNCH_BOUNDS
26  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
27 #endif
28  kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
29  const typename GridwiseGemm::FloatAB* p_b_grid,
30  typename GridwiseGemm::FloatC* p_c_grid,
31  void* p_workspace,
32  index_t M,
33  index_t N,
34  index_t K,
35  index_t StrideA,
36  index_t StrideB,
37  index_t StrideC,
38  typename GridwiseGemm::Block2CTileMap block_mapping)
39 {
40 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
41  defined(__gfx94__))
42  constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
43 
44  __shared__ uint8_t p_shared[shared_size];
45 
46  GridwiseGemm::Run(p_a_grid,
47  p_b_grid,
48  p_c_grid,
49  p_workspace,
50  M,
51  N,
52  K,
53  StrideA,
54  StrideB,
55  StrideC,
56  block_mapping,
57  static_cast<void*>(p_shared));
58 #else
59  ignore = p_a_grid;
60  ignore = p_b_grid;
61  ignore = p_c_grid;
62  ignore = p_workspace;
63  ignore = M;
64  ignore = N;
65  ignore = K;
66  ignore = StrideA;
67  ignore = StrideB;
68  ignore = StrideC;
69  ignore = block_mapping;
70 #endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
71 }
72 
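// A hedged host-side launch sketch (not part of the original header): it relies only on
// the kernel signature above and on the Argument constructor defined further below.
// The grid size and the `total_num_blocks()` accessor on the block mapping are
// placeholders/assumptions, not APIs defined in this file.
//
//   using Gridwise = /* a fully instantiated gridwise GEMM type from this header */;
//   auto karg = typename Gridwise::Argument{p_a, p_b, p_c, M, N, K,
//                                           StrideA, StrideB, StrideC,
//                                           num_cu, occupancy, num_sk_blocks};
//   dim3 grid(/* total workgroups, e.g. karg.block_mapping.total_num_blocks() (hypothetical) */);
//   dim3 block(/* BlockSize */ 256);
//   kernel_gemm_xdlops_streamk<Gridwise><<<grid, block, 0, stream>>>(
//       p_a, p_b, p_c, p_workspace, M, N, K, StrideA, StrideB, StrideC, karg.block_mapping);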
73 template <index_t BlockSize,
74  typename Block2CTileMap_,
75  typename FloatAB_,
76  typename FloatAcc_,
77  typename FloatC_,
78  typename ALayout,
79  typename BLayout,
80  typename CLayout,
81  typename AElementwiseOperation,
82  typename BElementwiseOperation,
83  typename CElementwiseOperation,
84  index_t MPerBlock,
85  index_t NPerBlock,
86  index_t K0PerBlock,
87  index_t MPerXDL,
88  index_t NPerXDL,
89  index_t K1Value,
90  index_t MRepeat,
91  index_t NRepeat,
92  typename ABlockTransferThreadClusterLengths_K0_M_K1,
93  typename ABlockTransferThreadClusterArrangeOrder,
94  typename ABlockTransferSrcAccessOrder,
95  index_t ABlockTransferSrcVectorDim,
96  index_t ABlockTransferSrcScalarPerVector,
97  index_t ABlockTransferDstScalarPerVector_K1,
98  bool AThreadTransferSrcResetCoordinateAfterRun,
99  index_t ABlockLdsExtraM,
100  typename BBlockTransferThreadClusterLengths_K0_N_K1,
101  typename BBlockTransferThreadClusterArrangeOrder,
102  typename BBlockTransferSrcAccessOrder,
103  index_t BBlockTransferSrcVectorDim,
104  index_t BBlockTransferSrcScalarPerVector,
105  index_t BBlockTransferDstScalarPerVector_K1,
106  bool BThreadTransferSrcResetCoordinateAfterRun,
107  index_t BBlockLdsExtraN,
108  index_t CShuffleMRepeatPerShuffle,
109  index_t CShuffleNRepeatPerShuffle,
110  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
111  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
113 {
114  static constexpr auto I0 = Number<0>{};
115  static constexpr auto I1 = Number<1>{};
116  static constexpr auto I2 = Number<2>{};
117  static constexpr auto I3 = Number<3>{};
118  static constexpr auto I4 = Number<4>{};
119  static constexpr auto I5 = Number<5>{};
120  static constexpr auto I6 = Number<6>{};
121  static constexpr auto I7 = Number<7>{};
122 
123  // K1 should be Number<...>
124  static constexpr auto K1 = Number<K1Value>{};
125  static constexpr auto M01 = 1;
126  static constexpr auto N01 = 1;
127  static constexpr auto KPerBlock = K0PerBlock * K1;
128 
129  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
130  using FloatAcc = FloatAcc_;
131  using FloatCShuffle = FloatAcc;
132 
133  using Block2CTileMap = Block2CTileMap_;
134  using FloatAB = FloatAB_;
135  using FloatC = FloatC_;
136 
137  struct Argument : public tensor_operation::device::BaseArgument
138  {
139  const FloatAB* p_a_grid;
140  const FloatAB* p_b_grid;
141  FloatC* p_c_grid;
142  index_t M;
143  index_t N;
144  index_t K;
145  index_t StrideA;
146  index_t StrideB;
147  index_t StrideC;
148  Block2CTileMap block_mapping;
149 
150  Argument(const FloatAB* p_a_grid_,
151  const FloatAB* p_b_grid_,
152  FloatC* p_c_grid_,
153  index_t M_,
154  index_t N_,
155  index_t K_,
156  index_t StrideA_,
157  index_t StrideB_,
158  index_t StrideC_,
159  uint32_t num_cu,
160  uint32_t occupancy,
161  uint32_t num_sk_blocks_)
162  : p_a_grid(p_a_grid_),
163  p_b_grid(p_b_grid_),
164  p_c_grid(p_c_grid_),
165  M(M_),
166  N(N_),
167  K(K_),
168  StrideA(StrideA_),
169  StrideB(StrideB_),
170  StrideC(StrideC_),
171  block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_)
172  {
173  }
174 
175  void Print() const
176  {
177  std::cout << "arg {"
178  << "M:" << M << ", "
179  << "N:" << N << ", "
180  << "K:" << K << ", "
181  << "SA:" << StrideA << ", "
182  << "SB:" << StrideB << ", "
183  << "SC:" << StrideC << "}" << std::endl;
184  }
185  };
186 
187  __host__ __device__ static auto CalculateGridSize(const Argument& karg)
188  {
189  return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
190  math::integer_divide_ceil(karg.M, MPerBlock),
191  karg.k_batch);
192  }
193 
194  __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
195 
196  __host__ __device__ static auto
197  MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
198  {
199  const index_t K0 = CalculateK0(KPad);
200 
201  const auto a_grid_desc_m_k = [&]() {
203  {
204  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
205  }
207  {
208  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
209  }
210  }();
211 
212  const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
213  a_grid_desc_m_k,
217 
218  return transform_tensor_descriptor(a_grid_desc_m_kpad,
220  make_right_pad_transform(M, MPad - M)),
223  }
224 
225  __host__ __device__ static auto
226  MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
227  {
228  const index_t K0 = CalculateK0(KPad);
229 
230  const auto b_grid_desc_k_n = [&]() {
232  {
233  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
234  }
236  {
237  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
238  }
239  }();
240 
241  const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
242  b_grid_desc_k_n,
246 
247  return transform_tensor_descriptor(b_grid_desc_kpad_n,
249  make_right_pad_transform(N, NPad - N)),
252  }
253 
254  __host__ __device__ static auto
255  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
256  {
257  const auto c_grid_desc_m_n = [&]() {
259  {
260  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
261  }
263  {
264  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
265  }
266  }();
267 
268  return transform_tensor_descriptor(c_grid_desc_m_n,
270  make_right_pad_transform(N, NPad - N)),
273  }
274 
275  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
276  {
277  // A matrix in LDS memory, dst of blockwise copy
281  }
282 
283  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
284  {
285  // B matrix in LDS memory, dst of blockwise copy
289  }
290 
291  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
292  {
293  constexpr auto max_lds_align = K1;
294 
295  // LDS allocation for A and B: be careful of alignment
296  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
297  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
298 
299  constexpr auto a_block_space_size_aligned =
300  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
301 
302  constexpr auto b_block_space_size_aligned =
303  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
304 
305  constexpr auto c_block_size =
307 
308  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
309  sizeof(FloatAB),
310  c_block_size * sizeof(FloatCShuffle));
311  }
312 
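// A minimal numeric sketch of the LDS budget computed above, assuming illustrative
// tile parameters (MPerBlock = NPerBlock = 128, K0PerBlock = 4, K1 = 8, 2-byte FloatAB),
// none of which are mandated by this file:
//
//   constexpr int a_elems = 4 * 128 * 8;                     // K0PerBlock * MPerBlock * K1
//   constexpr int b_elems = 4 * 128 * 8;                     // K0PerBlock * NPerBlock * K1
//   constexpr int gemm_lds_bytes = (a_elems + b_elems) * 2;  // 16384 bytes for the GEMM phase
//   // GetSharedMemoryNumberOfByte() returns max(gemm_lds_bytes, c_shuffle_staging_bytes).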
313  __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
314  {
316  {
317  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
318  return false;
319  }
320  else
321  {
322  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
323  return false;
324  }
325 
327  {
328  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
329  return false;
330  }
331  else
332  {
333  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
334  return false;
335  }
336 
338  {
339  if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
340  return false;
341  }
342  else
343  {
344  if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
345  return false;
346  }
347 
348  return true;
349  }
350 
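// The checks above enforce that the extent accessed along each operand's vectorized
// dimension is divisible by the configured ScalarPerVector, so vector loads/stores never
// straddle the matrix edge. Illustrative example (values are assumptions, not defaults):
// with ABlockTransferSrcScalarPerVector = 8, the checked extent (K or M, depending on
// ALayout) must be a multiple of 8 for CheckValidity to return true.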
351  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
352  {
353  const bool has_main_k0_block_loop = K0 > K0PerBlock;
354 
355  return has_main_k0_block_loop;
356  }
357 
358  template <typename CGridDesc>
359  __host__ __device__ static constexpr auto
360  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
361  {
362  const auto M = c_m_n_grid_desc.GetLength(I0);
363  const auto N = c_m_n_grid_desc.GetLength(I1);
364 
365  const auto MBlock = M / MPerBlock;
366  const auto NBlock = N / NPerBlock;
367 
369  c_m_n_grid_desc,
374  }
375 
376  // return block_id to C matrix tile idx (m0, n0) mapping
377  template <typename CGridDesc>
378  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
379  const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
380  {
382  c_m_n_grid_desc, 8, KBatch);
383  }
384 
385  __host__ __device__ static constexpr auto
386  GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
387  {
388  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
389  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
390 
392  make_tuple(I1,
394  I1,
396  }
397 
398  __host__ __device__ static constexpr auto
399  GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
400  {
401  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
402  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
403 
407  Number<NRepeat / CShuffleNRepeatPerShuffle>{},
409  }
410 
411  __host__ __device__ static constexpr auto GetClusterLengthReduction()
412  {
413  // TODO: assume C is row major
414  // TODO: we always first loop over N, then M
415  constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
416  constexpr auto NPerBlockReduction =
417  NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
418  constexpr auto MPerBlockReduction =
419  (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
421  }
422 
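// Numeric sketch of the reduction cluster computed above, with illustrative values
// (BlockSize = 256, NPerBlock = 128, CBlockTransferScalarPerVector_NWaveNPerXDL = 4):
//
//   // NPerBlockPow2      = 128 (already a power of two)
//   // NPerBlockReduction = 128 / 4 = 32 threads along N
//   // MPerBlockReduction = (256 + 32 - 1) / 32 = 8 rows of M per pass
//   // i.e. the block is arranged as an 8 x 32 thread cluster for the reduction.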
423  __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
424  {
425  const auto c_partial_acc_block_m_n = [&]() {
427  {
428  return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
429  make_tuple(NPerBlock, I1));
430  }
432  {
433  return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
434  make_tuple(I1, MPerBlock));
435  }
436  }();
437  return c_partial_acc_block_m_n;
438  }
439 
440  using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
441 
442  __device__ static void Run(const FloatAB* p_a_grid,
443  const FloatAB* p_b_grid,
444  FloatC* p_c_grid,
445  void* p_workspace,
446  index_t M,
447  index_t N,
448  index_t K,
449  index_t StrideA,
450  index_t StrideB,
451  index_t StrideC,
452  Block2CTileMap block_mapping,
453  void* __restrict__ p_shared_block)
454  {
455  uint32_t m = M;
456  uint32_t n = N;
457  uint32_t k = K;
458  uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
459  uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
460  uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;
461  uint32_t stride_a = StrideA;
462  uint32_t stride_b = StrideB;
463  uint32_t stride_c = StrideC;
464 
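// The padded extents above are plain round-ups to the tile size, e.g. (illustrative)
// M = 1000 with MPerBlock = 128 gives pad_m = (1000 + 127) / 128 * 128 = 1024.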
465  const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
466  const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
467  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
468 
469  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
471  const AElementwiseOperation a_element_op = AElementwiseOperation{};
472  const BElementwiseOperation b_element_op = BElementwiseOperation{};
473  const CElementwiseOperation c_element_op = CElementwiseOperation{};
474 
475  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
476  p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
477  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
478  p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
479  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
480  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
481 
482  // lds max alignment
483  constexpr auto max_lds_align = K1;
484 
485  // A matrix in LDS memory, dst of blockwise copy
486  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
487 
488  // B matrix in LDS memory, dst of blockwise copy
489  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
490 
491  auto blockwise_gemm =
493  FloatAB,
494  FloatAB,
495  FloatAcc,
496  decltype(a_block_desc_k0_m_k1),
497  decltype(b_block_desc_k0_n_k1),
498  MPerXDL,
499  NPerXDL,
500  MRepeat,
501  NRepeat,
502  K1>{};
503 
504  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
505 
506  // LDS allocation for A and B: be careful of alignment
507  constexpr auto a_block_space_size =
508  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
509 
510  FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
511  FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
512 
513  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
514  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
515 
516  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
517  p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
518  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
519  p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
520 
521  // gridwise GEMM pipeline
522  const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
523 
524  uint32_t block_idx = block_mapping.get_block_idx();
525  bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
526  bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
527  block_idx < block_mapping.reduction_start_block_idx;
528  bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
529  bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
530  block_idx < block_mapping.dp_start_block_idx;
531  uint32_t iter_start, iter_end;
532  block_mapping.get_block_itr(block_idx, iter_start, iter_end);
533  uint32_t total_iter_length = iter_end - iter_start;
534 
535  if(is_padding_block)
536  return;
537 
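// Workgroup roles, as implied by the classification above: indices in [0, sk_num_blocks)
// are stream-k blocks that may own only a partial K-range of a C tile, indices in
// [dp_start_block_idx, reduction_start_block_idx) are data-parallel blocks that own whole
// tiles, indices at or beyond reduction_start_block_idx combine partial results (Reduction
// strategy only), and anything between the stream-k and data-parallel ranges is padding and
// exits immediately. The workspace behind p_workspace holds the per-tile partial
// accumulation buffers first, followed by the semaphore counters used below.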
538  uint32_t* p_semaphore =
539  reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
540  block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
541 
542  if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
543  {
544  if(is_reduction_block)
545  {
546  // descriptors
547  constexpr auto cluster_length_reduce = GetClusterLengthReduction();
548  constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
549  const auto reduce_thread_cluster_idx =
550  reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
551  const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
552  const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
553 
554  constexpr auto MReduceIters =
555  math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
556  constexpr auto NReduceIters = math::integer_divide_ceil(
558  cluster_length_reduce.At(I1) *
560 
561  constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
563  constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
565 
566  constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
567 
568  constexpr auto partial_acc_load_step_n = make_multi_index(
569  0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
570  constexpr auto partial_acc_load_step_n_reverse =
572  -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
573  CBlockTransferScalarPerVector_NWaveNPerXDL);
574  constexpr auto partial_acc_load_step_m =
575  make_multi_index(cluster_length_reduce.At(I0), 0);
576 
577  constexpr auto partial_acc_store_step_n = make_multi_index(
578  0,
579  0,
580  0,
581  cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
582  constexpr auto partial_acc_store_step_n_reverse =
584  0,
585  0,
586  -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
587  CBlockTransferScalarPerVector_NWaveNPerXDL);
588  constexpr auto partial_acc_store_step_m =
589  make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
590 
592  FloatAcc,
593  CBlockTransferScalarPerVector_NWaveNPerXDL,
594  true>
595  partial_acc_buf;
597  FloatAcc,
598  CBlockTransferScalarPerVector_NWaveNPerXDL,
599  true>
600  acc_buf;
601 
602  // start to compute
603  auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
604  auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
605 
606  workgroup_barrier wg_barrier(p_semaphore);
607 
608  uint32_t tile_acc_offset_start =
609  block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
610  uint32_t tile_acc_offset_end =
611  block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
612 
613  auto acc_load = ThreadwiseTensorSliceTransfer_v2<
614  FloatAcc, // SrcData,
615  FloatAcc, // DstData,
616  decltype(c_partial_acc_block_m_n), // SrcDesc,
617  decltype(acc_thread_buf_load_desc), // DstDesc,
619  Sequence<0, 1>, // DimAccessOrder,
620  1, // SrcVectorDim,
621  CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
622  1, // SrcScalarStrideInVector,
623  false // SrcResetCoordinateAfterRun,
624  >{c_partial_acc_block_m_n,
625  make_multi_index(thread_m_cluster_id,
626  thread_n_cluster_id *
627  CBlockTransferScalarPerVector_NWaveNPerXDL)};
628 
629  auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
630  FloatAcc, // SrcData,
631  FloatC, // DstData,
632  decltype(acc_thread_buf_store_desc), // SrcDesc,
633  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
634  CElementwiseOperation, // ElementwiseOperation,
636  Sequence<0, 1, 2, 3>, // DimAccessOrder,
637  3, // DstVectorDim,
638  CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
639  InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
640  1, // DstScalarStrideInVector,
641  false // DstResetCoordinateAfterRun,
642  >{c_grid_desc_mblock_mperblock_nblock_nperblock,
643  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
644  thread_m_cluster_id,
645  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
646  thread_n_cluster_id *
647  CBlockTransferScalarPerVector_NWaveNPerXDL),
648  CElementwiseOperation{}};
649 
650  // block synchronization
651  wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
652 
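// Each contributing stream-k block increments this tile's semaphore once it has flushed
// its partial accumulation to the workspace (see wg_barrier.inc(tile_idx) near the end of
// the main loop), so waiting for the counter to reach
// (tile_acc_offset_end - tile_acc_offset_start) guarantees all partial buffers for this
// tile are visible before the reduction starts reading them.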
653 #if 0
654  if(threadIdx.x == 0) {
655  printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
656  reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
657  __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
658  __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
659  }
660 #endif
661 
662  using Accumulation = ck::detail::
663  AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
664 
665  for(int i_m = 0; i_m < MReduceIters; i_m++)
666  {
667  static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
668  acc_buf.Clear();
669  for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
670  {
671  auto c_partial_acc_buf =
674  reinterpret_cast<FloatAcc*>(p_workspace) +
675  i * c_partial_acc_block_m_n.GetElementSpaceSize(),
676  c_partial_acc_block_m_n.GetElementSpaceSize());
677 
678  acc_load.Run(c_partial_acc_block_m_n,
679  c_partial_acc_buf,
680  acc_thread_buf_load_desc,
681  make_tuple(I0, I0),
682  partial_acc_buf);
683 
685  [&](auto i_vec) {
686  constexpr auto offset =
687  acc_thread_buf_load_desc.CalculateOffset(
688  make_tuple(0, i_vec));
689  Accumulation::Calculate(acc_buf(Number<offset>{}),
690  partial_acc_buf[Number<offset>{}]);
691  });
692  }
693 
694  if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
695  NPerBlock)
696  {
697  acc_store.Run(acc_thread_buf_store_desc,
698  make_tuple(I0, I0, I0, I0),
699  acc_buf,
700  c_grid_desc_mblock_mperblock_nblock_nperblock,
701  c_grid_buf);
702  }
703  if constexpr(NReduceIters != 1)
704  {
705  if constexpr(i_n_reduce != (NReduceIters - 1))
706  {
707  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
708  partial_acc_load_step_n);
709  acc_store.MoveDstSliceWindow(
710  c_grid_desc_mblock_mperblock_nblock_nperblock,
711  partial_acc_store_step_n);
712  }
713  else
714  {
715  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
716  partial_acc_load_step_n_reverse);
717  acc_store.MoveDstSliceWindow(
718  c_grid_desc_mblock_mperblock_nblock_nperblock,
719  partial_acc_store_step_n_reverse);
720  }
721  }
722  });
723  {
724  acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
725  partial_acc_load_step_m);
726  acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
727  partial_acc_store_step_m);
728  }
729  }
730  return;
731  }
732  }
733 
734  // offset for last acc buffer of this block
735  uint32_t block_acc_offset =
736  (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
737  NPerBlock;
738 
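// The main loop below walks this block's iteration range backwards, starting from
// iter_end - 1 and peeling off current_iter_length iterations per pass, so the
// partial-accumulation offset starts at the block's last buffer and is decremented by one
// tile (MPerBlock * NPerBlock) after every stored tile.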
739  while(true)
740  {
741  uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
742  block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
743  uint32_t tile_idx, iter_offset;
744  block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
745  iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
746  auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);
747 
748  const index_t m_block_data_idx_on_grid =
749  __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
750 
751  const index_t n_block_data_idx_on_grid =
752  __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
753 
754  const index_t k0_block_data_idx_on_grid =
755  __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
756 
757  // A matrix blockwise copy
758  auto a_blockwise_copy =
760  AElementwiseOperation,
764  ABlockTransferThreadClusterLengths_K0_M_K1,
765  ABlockTransferThreadClusterArrangeOrder,
766  FloatAB,
767  FloatAB,
768  decltype(a_k0_m_k1_grid_desc),
769  decltype(a_block_desc_k0_m_k1),
770  ABlockTransferSrcAccessOrder,
772  ABlockTransferSrcVectorDim,
773  2,
774  ABlockTransferSrcScalarPerVector,
775  ABlockTransferDstScalarPerVector_K1,
776  1,
777  1,
778  AThreadTransferSrcResetCoordinateAfterRun,
779  true>(
780  a_k0_m_k1_grid_desc,
781  make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
782  a_element_op,
783  a_block_desc_k0_m_k1,
784  make_multi_index(0, 0, 0),
786 
787  // B matrix blockwise copy
788  auto b_blockwise_copy =
790  BElementwiseOperation,
794  BBlockTransferThreadClusterLengths_K0_N_K1,
795  BBlockTransferThreadClusterArrangeOrder,
796  FloatAB,
797  FloatAB,
798  decltype(b_k0_n_k1_grid_desc),
799  decltype(b_block_desc_k0_n_k1),
800  BBlockTransferSrcAccessOrder,
802  BBlockTransferSrcVectorDim,
803  2,
804  BBlockTransferSrcScalarPerVector,
805  BBlockTransferDstScalarPerVector_K1,
806  1,
807  1,
808  BThreadTransferSrcResetCoordinateAfterRun,
809  true>(
810  b_k0_n_k1_grid_desc,
811  make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
812  b_element_op,
813  b_block_desc_k0_n_k1,
814  make_multi_index(0, 0, 0),
816 
817  const index_t num_k_block_main_loop = current_iter_length;
818 
819  gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
820  a_block_desc_k0_m_k1,
821  a_blockwise_copy,
822  a_grid_buf,
823  a_block_buf,
824  a_block_slice_copy_step,
825  b_k0_n_k1_grid_desc,
826  b_block_desc_k0_n_k1,
827  b_blockwise_copy,
828  b_grid_buf,
829  b_block_buf,
830  b_block_slice_copy_step,
831  blockwise_gemm,
832  c_thread_buf,
833  num_k_block_main_loop);
834 
835  // output: register to global memory
836  {
837  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
838  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
839 
840  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
841  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
842 
843  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
844  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
845 
846  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
847  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
848  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
849  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
850  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
851  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
852  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
853  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
854 
855  constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
857 
858  constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
860 
861  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
862  reinterpret_cast<FloatCShuffle*>(p_shared_block),
863  c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
864 
865  auto c_partial_acc_buf =
866  make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
867  reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
868  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
869  .GetElementSpaceSize());
870 
871  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
872  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
873  make_tuple(make_freeze_transform(I0), // freeze mblock
875  make_tuple(CShuffleMRepeatPerShuffle,
876  M1,
877  M2,
878  M3,
879  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
880  make_freeze_transform(I0), // freeze nblock
882  make_tuple(CShuffleNRepeatPerShuffle,
883  N1,
884  N2))), // N1 = NWave, N2 = NPerXDL
888  Sequence<>{},
889  Sequence<1, 3, 7>{}));
890 
891  // calculate origin of thread output tensor on global memory
892  // blockwise GEMM c matrix starting index
893  const auto c_thread_mtx_on_block =
894  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
895 
896  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
897  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
898 
899  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
901  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
904 
905  const auto m_thread_data_on_block_idx =
906  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
907  make_multi_index(m_thread_data_on_block));
908 
909  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
914 
915  const auto n_thread_data_on_block_idx =
916  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
917  make_multi_index(n_thread_data_on_block));
918 
919  // VGPR to LDS
920  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
921  FloatAcc,
923  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
924  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
926  Sequence<CShuffleMRepeatPerShuffle,
927  CShuffleNRepeatPerShuffle,
928  I1,
929  I1,
930  M2,
931  I1,
932  M4,
933  I1>,
935  7,
936  1,
938  1,
939  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
941  0,
942  m_thread_data_on_block_idx[I1],
943  n_thread_data_on_block_idx[I1],
944  m_thread_data_on_block_idx[I2],
945  m_thread_data_on_block_idx[I3],
946  m_thread_data_on_block_idx[I4],
947  n_thread_data_on_block_idx[I2]),
949 
950  // LDS to global
951  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
952  ThisThreadBlock, // index_t BlockSize,
953  CElementwiseOperation, // ElementwiseOperation,
954  // InMemoryDataOperationEnum::Set, // DstInMemOp,
955  Sequence<1,
956  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
957  1,
958  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
959  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
960  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
961  FloatCShuffle, // typename SrcData,
962  FloatC, // typename DstData,
963  decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
964  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
965  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
966  3, // index_t VectorDim,
967  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
968  false, // bool ThreadTransferSrcResetCoordinateAfterRun,
969  false> // bool ThreadTransferDstResetCoordinateAfterRun
970  {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
971  make_multi_index(0, 0, 0, 0),
972  c_grid_desc_mblock_mperblock_nblock_nperblock,
973  make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
974  0,
975  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
976  0),
977  c_element_op};
978 
979  // LDS to global partial acc
980  auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
981  ThisThreadBlock, // index_t BlockSize,
982  CElementwiseOperation, // ElementwiseOperation,
983  // InMemoryDataOperationEnum::Set, // DstInMemOp,
984  Sequence<1,
985  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
986  1,
987  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
988  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
989  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
990  FloatCShuffle, // typename SrcData,
991  FloatCShuffle, // typename DstData,
992  decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
993  decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
994  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
995  3, // index_t VectorDim,
996  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
997  false, // bool ThreadTransferSrcResetCoordinateAfterRun => needs to be false,
998  // otherwise scratch is generated
999  false> // bool ThreadTransferDstResetCoordinateAfterRun => needs to be false,
1000  // otherwise scratch is generated
1001  {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1002  make_multi_index(0, 0, 0, 0),
1003  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1004  make_multi_index(0, 0, 0, 0),
1005  c_element_op};
1006 
1007  constexpr auto mxdlperwave_forward_step =
1008  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
1009  constexpr auto nxdlperwave_forward_step =
1010  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1011  constexpr auto nxdlperwave_backward_step =
1012  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
1013 
1014  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1015  constexpr auto mxdlperwave = mxdlperwave_iter;
1016 
1017  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1018  constexpr bool nxdlperwave_forward_sweep =
1019  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1020 
1021  constexpr index_t nxdlperwave_value =
1022  nxdlperwave_forward_sweep
1023  ? nxdlperwave_iter
1024  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1025 
1026  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1027 
1028  // make sure it's safe to do ds_write
1029  block_sync_lds();
1030 
1031  // VGPR to LDS
1032  c_thread_copy_vgpr_to_lds.Run(
1033  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1034  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1035  c_thread_buf,
1036  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1037  c_block_buf);
1038 
1039  // make sure it's safe to do ds_read
1040  block_sync_lds();
1041 
1042  c_block_copy_lds_to_global.SetSrcSliceOrigin(
1043  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1044  make_tuple(0, 0, 0, 0));
1045 
1046  // LDS to global
1047  if(is_dp_block)
1048  c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
1049  decltype(c_grid_buf),
1051  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1052  c_block_buf,
1053  c_grid_desc_mblock_mperblock_nblock_nperblock,
1054  c_grid_buf);
1055  else if(is_sk_block)
1056  {
1057  if constexpr(Block2CTileMap::ReductionStrategy ==
1059  {
1060  // constexpr offset
1061  c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
1062  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1063  make_tuple(0, 0, 0, 0));
1064 
1065  c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
1066  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1067  make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
1068 
1069  c_block_copy_lds_to_partial_acc
1070  .template Run<decltype(c_block_buf),
1071  decltype(c_partial_acc_buf),
1073  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1074  c_block_buf,
1075  c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1076  c_partial_acc_buf);
1077  }
1078  else if constexpr(Block2CTileMap::ReductionStrategy ==
1080  {
1081  c_block_copy_lds_to_global
1082  .template Run<decltype(c_block_buf),
1083  decltype(c_grid_buf),
1085  c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1086  c_block_buf,
1087  c_grid_desc_mblock_mperblock_nblock_nperblock,
1088  c_grid_buf);
1089  }
1090  }
1091 
1092  // move on nxdlperwave dimension
1093  if constexpr(nxdlperwave_forward_sweep &&
1094  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1095  {
1096  c_block_copy_lds_to_global.MoveDstSliceWindow(
1097  c_grid_desc_mblock_mperblock_nblock_nperblock,
1098  nxdlperwave_forward_step);
1099  }
1100  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1101  {
1102  c_block_copy_lds_to_global.MoveDstSliceWindow(
1103  c_grid_desc_mblock_mperblock_nblock_nperblock,
1104  nxdlperwave_backward_step);
1105  }
1106  });
1107 
1108  // move on mxdlperwave dimension
1109  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1110  {
1111  c_block_copy_lds_to_global.MoveDstSliceWindow(
1112  c_grid_desc_mblock_mperblock_nblock_nperblock,
1113  mxdlperwave_forward_step);
1114  }
1115  });
1116 
1117  if constexpr(Block2CTileMap::ReductionStrategy ==
1119  {
1120  if(is_sk_block)
1121  {
1122  // increase the counter for this tile
1123  workgroup_barrier wg_barrier(p_semaphore);
1124  wg_barrier.inc(tile_idx);
1125  }
1126  }
1127  }
1128 
1129  // exit condition
1130  iter_end -= current_iter_length;
1131  if(iter_end <= iter_start)
1132  break;
1133 
1134  if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
1135  {
1136  block_acc_offset -= MPerBlock * NPerBlock;
1137  }
1138  // make sure next loop LDS is ready for use
1139  block_sync_lds();
1140  }
1141  }
1142 
1143  template <typename Layout>
1144  struct LStr
1145  {
1146  static std::string Get() { return ""; }
1147  };
1148 
1149  template <>
1151  {
1152  static std::string Get() { return "R"; }
1153  };
1154 
1155  template <>
1157  {
1158  static std::string Get() { return "C"; }
1159  };
1160 
1161  static std::string GetTypeString()
1162  {
1163  auto str = std::stringstream();
1164 
1165  // clang-format off
1166  str << "GemmXdlStreamK_"
1167  << std::string(ALayout::name)[0]
1168  << std::string(BLayout::name)[0]
1169  << std::string(CLayout::name)[0]
1170  << "_"
1171  << "B" << BlockSize << "_"
1172  << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1173  << BBlockTransferSrcScalarPerVector << "x"
1174  << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1175  << MPerBlock << "x"
1176  << NPerBlock << "x"
1177  << K0PerBlock << "x"
1178  << K1 ;
1179  // clang-format on
1180 
1181  return str.str();
1182  }
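// Example of the resulting identifier for an illustrative instantiation (row-major
// A/B/C, BlockSize 256, vector widths 8/8/4, 128x128x4 tile, K1 = 8; all assumed values):
// "GemmXdlStreamK_RRR_B256_Vec8x8x4_128x128x4x8"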
1183 };
1184 
1185 } // namespace ck