Composable Kernel: include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp Source File

block_to_ctile_map.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/math.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"

#ifndef CK_CODE_GEN_RTC
#include <limits>
#include <stdlib.h>
#endif

namespace ck {

// Rows of column-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N0_M01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                             index_t M01 = 1)
        : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);

        const index_t grid_size = M00 * M01_ * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                       const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        if(M0 % M01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);

        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                       make_unmerge_transform(make_tuple(M00, M01)),
                       make_pass_through_transform(N0)),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

        const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_n0_m01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
    UnderlyingMap underlying_map_;
};
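
// Illustrative note (not from the original source): with MPerBlock = 128,
// M = 640 and M01 = 2 we get M0 = 5 and M00 = ceil(5 / 2) = 3, so
// CalculateGridSize() returns 3 * 2 * N0 = 6 * N0 blocks for only 5 * N0 real
// tiles. The padded blocks map to out-of-range tile indices, which is why
// CheckValidity() requires M0 % M01 == 0 unless DeviceCTileIndexCheck is set
// and ValidCTileIndex() filters them out in the kernel.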

// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_M00_N0_M01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

    __host__
    __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__
    __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                         index_t M01 = 8)
        : BlockToCTileMap_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        // swizzle: within a group of M01 tile-rows, consecutive blocks walk down
        // an M01-tall column of C tiles before advancing in N; the last group
        // falls back to the smaller M01_adapt when M0 is not divisible by M01_
        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};

// keep the redundant type argument for backward compatibility
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
    using BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>::
        BlockToCTileMap_M00_N0_M01Adapt;
};
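
// Illustrative sketch (not from the original source): how the adaptive map
// walks tiles for M0 = 4, N0 = 3, M01 = 2 (e.g. M = 512, N = 384 with
// 128x128 tiles). Consecutive 1D block ids fill an M01-tall column of C
// tiles before stepping right in N:
//
//   BlockToCTileMap_M00_N0_M01Adapt<128, 128> map(512, 384, /*M01=*/2);
//   // block id : 0      1      2      3      4      5      6      ...
//   // (m0, n0) : (0,0)  (1,0)  (0,1)  (1,1)  (0,2)  (1,2)  (2,0)  ...
//
// The grid size is exactly M0 * N0, so every block id is valid; when M0 is
// not divisible by M01, the tail rows simply use the smaller column height
// M01_adapt = M0 % M01 instead of requiring a padded grid.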
264 
265 // Grouped Rows of column-vectors WGP mapping
266 // Optimized for gfx94x-like multipe-die chip
267 
268 template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
270 {
271  static constexpr auto I0 = Number<0>{};
272  static constexpr auto I1 = Number<1>{};
273 
275  index_t N,
276  index_t M01 = 8)
277  : M_(M), N_(N), M01_(M01)
278  {
279  }
280 
281  __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
282  {
283  const auto M0 = math::integer_divide_ceil(M, MPerBlock);
284  const auto N0 = math::integer_divide_ceil(N, NPerBlock);
285 
286  return M0 * N0;
287  }
288 
289  template <typename CGridDesc_M_N>
290  __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
291  {
292  return true;
293  }
294 
295  template <typename TopIdx>
296  __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
297  {
298  auto block_1d_id = idx_top[I0];
299 
300  const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
301  const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
302 
303  if(M0 == 1)
304  {
305  return make_tuple(0, block_1d_id);
306  }
307  else if(N0 == 1)
308  {
309  return make_tuple(block_1d_id, 0);
310  }
311  // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
312  else
313  {
314  const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
315  const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
316  auto group_id_x = block_1d_id % GroupNum;
317  auto group_id_y = block_1d_id / GroupNum;
318  auto remap_block_1d_id =
319  group_id_x <= big_group_num
320  ? group_id_x * group_size + group_id_y
321  : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
322 
323  index_t idx_N0 = remap_block_1d_id % N0;
324  index_t idx_M0 = remap_block_1d_id / N0;
325 
326  const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
327 
328  index_t idx_M00 = idx_M0 / M01_;
329  index_t idx_M01 = idx_M0 % M01_;
330  index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
331 
376  return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
377  idx_N0_M01_local / M01_adapt);
378  }
379  }
380 
381  template <typename CTileIdx, typename CTileDim>
382  __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
383  const CTileDim& /* c_tile_dim */) const
384  {
385  return true; // always valid provided that user gets grid size from CalculateGridSize()
386  }
387 
388  private:
389  index_t M_;
390  index_t N_;
391  index_t M01_;
392 };
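
// Illustrative note (not from the original source): for M0 * N0 = 10 tiles
// and GroupNum = 4, group_size = ceil(10 / 4) = 3 and
// big_group_num = 4 - (3 * 4 - 10) = 2, i.e. the groups own tile ranges
// {0..2}, {3..5}, {6,7}, {8,9}. Consecutive block ids are dealt round-robin
// across groups (block 0 -> tile 0, block 1 -> tile 3, block 2 -> tile 6,
// block 3 -> tile 8, block 4 -> tile 1, ...), spreading neighboring
// workgroups across dies before the usual M01 column swizzle is applied.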

// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
        : M_(M), N_(N), N01_(N01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t N01 = 8)
        : BlockToCTileMap_N00_M0_N01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
    {
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_M0 = block_1d_id % M0;
        index_t idx_N0 = block_1d_id / M0;

        const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;

        index_t idx_N00          = idx_N0 / N01_;
        index_t idx_N01          = idx_N0 % N01_;
        index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;

        // transpose of BlockToCTileMap_M00_N0_M01Adapt: consecutive blocks walk
        // an N01-wide row of C tiles before advancing in M
        return make_tuple(idx_M0_N01_local / N01_adapt,
                          idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t N01_;
};

// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                               index_t M01    = 8,
                                                               index_t KSplit = 1)
        : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const index_t grid_size = M0 * N0 * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups

        const index_t idx_ksplit = block_1d_id / (M0 * N0);
        block_1d_id              = block_1d_id % (M0 * N0);

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return make_tuple(idx_ksplit,
                          idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    index_t M01_;
    index_t KSplit_;
    CGridDesc_M_N c_grid_desc_m_n_;
};
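
// Illustrative note (not from the original source): with M0 = 4, N0 = 3 and
// KSplit = 2, CalculateGridSize() returns 4 * 3 * 2 = 24. For block id 15,
// idx_ksplit = 15 / 12 = 1 and the remainder 3 goes through the same
// M01-column swizzle as BlockToCTileMap_M00_N0_M01Adapt, yielding a 3D
// (ksplit, m0, n0) bottom index.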

// Blocks of row-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t M01 = 1,
                                                        index_t N01 = 1)
        : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_, N01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
    UnderlyingMap underlying_map_;
};
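
// Illustrative note (not from the original source): this map tiles the block
// grid into M01 x N01 rectangles in both dimensions. With M0 = 3, N0 = 5 and
// M01 = N01 = 2, CalculateGridSize() returns ceil(3/2) * 2 * ceil(5/2) * 2 =
// 24 blocks for 15 real tiles; either M0 % M01 == 0 && N0 % N01 == 0 must
// hold (host-side CheckValidity) or DeviceCTileIndexCheck must be enabled so
// the padded blocks are rejected by ValidCTileIndex in the kernel.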

// 2D slices of row-vectors in 3D space
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_KSplit_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default;

    __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                    index_t M01    = 1,
                                                    index_t N01    = 1,
                                                    index_t KSplit = 1)
        : c_grid_desc_m_n_(c_grid_desc_m_n),
          M01_(M01),
          N01_(N01),
          KSplit_(KSplit),
          underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
    {
    }

    __host__ __device__ constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::Size() == 1);

        return underlying_map_.CalculateBottomIndex(
            make_multi_index(idx_top[I0] % CalculateGridSize()));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __device__ constexpr index_t CalculateGridSize() const
    {
        return CalculateGridSize(c_grid_desc_m_n_);
    }

    __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
                                                      index_t M01,
                                                      index_t N01,
                                                      index_t KSplit)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KSplit),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
    }

    CGridDesc_M_N c_grid_desc_m_n_;
    index_t M01_, N01_, KSplit_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
    UnderlyingMap underlying_map_;
};

template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
                                                const CTileDim& c_tile_dim)
{
    bool is_valid = false;

    const index_t m_block = c_tile_dim[Number<0>{}];
    const index_t n_block = c_tile_dim[Number<1>{}];

    if constexpr(CTileIdx::Size() == 2)
    {
        const index_t m_block_idx = c_tile_idx[Number<0>{}];
        const index_t n_block_idx = c_tile_idx[Number<1>{}];
        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
    }
    else if constexpr(CTileIdx::Size() == 3)
    {
        const index_t ksplit_idx  = c_tile_idx[Number<0>{}];
        const index_t m_block_idx = c_tile_idx[Number<1>{}];
        const index_t n_block_idx = c_tile_idx[Number<2>{}];
        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
        ignore = ksplit_idx;
    }

    return is_valid;
}
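
// Illustrative sketch (not from the original source): a kernel instantiated
// with DeviceCTileIndexCheck = true typically exits early on padded blocks:
//
//   const auto c_tile_idx = map.CalculateBottomIndex(make_multi_index(blockIdx.x));
//   if(!map.ValidCTileIndex(c_tile_idx, make_tuple(M0, N0)))
//       return; // this block lies outside the real M0 x N0 tile grid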

// This wrapper class is for grouped gemm where it offsets blockIdx by a value so that
// the workgroups assigned to a given gemm problem see top indices in the range
// [0, grid_size_per_gemm)
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                 index_t block_start)
    {
        block_to_ctile_map_ = block_to_ctile_map;
        block_start_        = block_start;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] - block_start_));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t block_start_;
};
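
// Illustrative sketch (not from the original source; `map_i`, `block_start_i`
// and `grid_size_i` are hypothetical names): in grouped GEMM each problem i
// owns the block range [block_start_i, block_start_i + grid_size_i), so
// wrapping its map as
//
//   OffsettedBlockToCTileMap<decltype(map_i)> offsetted_map(map_i, block_start_i);
//
// lets the kernel keep using the global blockIdx.x; the wrapper subtracts
// block_start_i before delegating to the underlying map.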
// second version with 2 offsets
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap2
{
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                  index_t group_offset,
                                                  index_t tile_offset)
        : block_to_ctile_map_{block_to_ctile_map},
          group_offset_{group_offset},
          tile_offset_{tile_offset}
    {
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] - group_offset_ + tile_offset_));
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t group_offset_;
    index_t tile_offset_;
};

// Simple tile mapping which creates a 3D grid of thread blocks.
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit
{

    __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;

    __host__ __device__ constexpr auto
    CalculateGridSize(index_t M, index_t N, index_t k_split) const
    {
        // Create 3D grid
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
        return make_tuple(N0, M0, k_split);
    }

    template <typename TopIdx>
    __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
    {
        return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    template <typename CGridDesc_M_N>
    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }
};
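
// Illustrative note (not from the original source): unlike the 1D maps above,
// CalculateGridSize() here returns a tuple (N0, M0, k_split) meant to be used
// as a 3D launch grid (x = N0, y = M0, z = k_split), and
// CalculateBottomIndex() simply reads (blockIdx.z, blockIdx.y, blockIdx.x)
// back as the (k_split, m0, n0) tile coordinate.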

enum struct StreamKReductionStrategy
{
    Atomic = 0, // sk blocks use atomics to do the reduction
    Reduction,  // dedicated workgroups perform the reduction
};

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8>
struct BlockToCTileMap_GemmStreamK
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m                = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction

    // MDiv tile_swizzle_sub_m_rem;
    //--------------------------------------

    // prefer construct on host
    __host__ BlockToCTileMap_GemmStreamK(uint32_t m,
                                         uint32_t n,
                                         uint32_t k,
                                         uint32_t num_cu,
                                         uint32_t occupancy,
                                         uint32_t sk_blocks = 0xffffffff)
    {
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        // one cu can hold one wg at a time; from the whole chip's point of view,
        // if the number of wg equals num_cu, we call it 1 dispatch;
        // if the number of wg is 2x num_cu, we call it 2 dispatches.
        // one dispatch can deliver as many wg as num_cu (a full dispatch), or fewer
        // than num_cu (a partial dispatch)
        //
        uint32_t full_dispatches        = num_tiles / num_cu;
        uint32_t full_dispatch_tiles    = full_dispatches * num_cu;
        uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles     = full_dispatch_tiles;
        uint32_t sk_tiles     = partial_dispatch_tiles;

        if(full_dispatches < occupancy)
        {
            // in this case, we allocate all blocks as sk blocks
            // sk_occupancy = occupancy - full_dispatches;
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatch_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
            //      occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // leave 1 slot for sk occupancy
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatch_tiles;
        }
        else
        {
            // otherwise, take 1 dispatch away from dp and, together with the partial
            // dispatch, construct the sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles     = full_dispatch_tiles - num_cu;
            sk_tiles     = partial_dispatch_tiles + num_cu;
        }

        // uint32_t dp_iters_per_block = k_iters_per_tile.get();
        uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
        uint32_t dp_num_blocks  = 0;

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if we used dp blocks for the sk work, how many iters would we need
            uint32_t dp_for_sk_iters = k_iters_per_tile.get();

            uint32_t best_sk_score =
                NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                //       the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for an uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            // give the caller a chance to control the number of sk blocks
            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks      = num_tiles; // all tiles become dp blocks
                dp_start_block_idx = 0;
                sk_total_iters     = 0; // no sk iters left
            }
            else
            {
                // k_iters_per_sk_block is the floor of the average number of iters
                // each sk block loops over; we need to decide how many iters to give
                // each sk block.
                // let m = k_iters_per_sk_block
                // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
                // we have
                // 1) l + b = sk_blocks
                // 2) l * m + b * (m + 1) = sk_total_iters
                //  => (l + b) * m + b = sk_total_iters
                //  => sk_blocks * m + b = sk_total_iters
                //  => b = sk_total_iters - m * sk_blocks
                // NOTE: big could be zero
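                // Worked example (added for illustration): sk_total_iters = 10 and
                // sk_blocks = 4 give m = 10 / 4 = 2 and b = 10 - 2 * 4 = 2, i.e.
                // 2 big blocks covering 3 iters each plus 2 little blocks covering
                // 2 iters each (2 * 3 + 2 * 2 = 10).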
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks      = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }
        n_tiles                   = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
            eqav_tiles_big        = MDiv(upper_big / k_iters_per_tile.get());
            eqav_tiles_little     = MDiv(upper_little / k_iters_per_tile.get());
        }

#if 0
        printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, "
               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
               get_grid_dims().x,
               num_tiles,
               dp_tiles,
               sk_num_big_blocks,
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block,
               reduction_start_block_idx,
               get_sk_tiles(),
               get_workspace_size(sizeof(float)));
#endif
    }

    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ dim3 get_grid_dims() const
    {
        if(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& eqav_tiles_) const
    {
        uint32_t tile_idx_       = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
        uint32_t quo_, rem_;
        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_eqav_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
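
// Illustrative sketch (not from the original source): a plausible shape of a
// stream-k main loop pieced together from the accessors above; the exact
// iteration order used by the real kernels may differ. Each block owns the
// global K-iteration range [iter_start, iter_end), which may span several C
// tiles:
//
//   uint32_t iter_start, iter_end;
//   map.get_block_itr(map.get_block_idx(), iter_start, iter_end);
//   while(iter_start < iter_end)
//   {
//       uint32_t tile_idx, iter_offset;
//       map.get_tile_idx_with_offset(iter_start, tile_idx, iter_offset);
//       const auto spatial = map.tile_to_spatial(tile_idx, M, N); // (m0, n0)
//       // ... accumulate this block's K-iterations that fall inside tile_idx,
//       //     starting at iter_offset, then advance iter_start past them and
//       //     combine partial results according to ReductionStrategy ...
//   }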

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_                   = 8,
          index_t GroupNum                            = 8,
          index_t M01_                                = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m                = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    mutable uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv equiv_tiles_big;    // for reduction
    MDiv equiv_tiles_little; // for reduction

    // prefer construct on host
    __host__ __device__ BlockToCTileMap_GemmStreamK_v2(
        uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
    {
        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;

        // default to regular DP GEMM if sk blocks == 0
        if(streamk_sel == 0)
        {
            sk_num_blocks         = 0;
            dp_tiles              = num_tiles;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_num_blocks      = num_tiles; // all tiles become dp blocks
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // no sk iters left
        }
        // 2-tile sk + DP GEMM
        else
        {

            // check if there's enough work for DP + stream-k
            bool bigEnough = num_tiles > grid_size;
            // select between stream-k strategies
            uint32_t sk_tiles = 0;
            if(streamk_sel == 1) // 1-tile stream-k
            {
                sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
            }
            else if(streamk_sel == 2) // 2-tile stream-k
            {
                sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
            }
            else if(streamk_sel == 3) // 3-tile stream-k
            {
                sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;
            }
            else if(streamk_sel == 4) // 4-tile stream-k
            {
                sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;
            }
            sk_num_blocks = sk_tiles;
            // remaining tiles are DP tiles
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // k_iters_per_sk_block is the floor of the average number of iters
            // each sk block loops over; we need to decide how many iters to give
            // each sk block.
            // let m = k_iters_per_sk_block
            // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
            // we have
            // 1) l + b = sk_blocks
            // 2) l * m + b * (m + 1) = sk_total_iters
            //  => (l + b) * m + b = sk_total_iters
            //  => sk_blocks * m + b = sk_total_iters
            //  => b = sk_total_iters - m * sk_blocks
            // NOTE: big could be zero
            uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
            sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
            k_iters_per_big_block = k_iters_per_sk_block + 1;

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = sk_num_blocks;
        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        // using multiple blocks for parallel reduction
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            uint32_t upper_big    = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
            equiv_tiles_big       = MDiv(upper_big / k_iters_per_tile.get());
            equiv_tiles_little    = MDiv(upper_little / k_iters_per_tile.get());
        }
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }
    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ index_t get_grid_dims() const
    {
        if(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
            return reduction_start_block_idx + get_sk_tiles();
        }
        else
            return reduction_start_block_idx;
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& equiv_tiles_) const
    {
        uint32_t tile_idx_        = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
        uint32_t quo_, rem_;
        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_equiv_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
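
// Illustrative sketch (not from the original source): with the Reduction
// strategy the caller is expected to allocate scratch memory sized by the map
// itself before launching, e.g.
//
//   BlockToCTileMap_GemmStreamK_v2<256, 128, 32,
//                                  StreamKReductionStrategy::Reduction>
//       map(M, N, K, grid_size);
//   const uint32_t workspace_bytes = map.get_workspace_size(sizeof(float));
//
// where the workspace holds the 128-byte-aligned partial-accumulator tiles
// (get_workspace_size_for_acc) plus one uint32_t semaphore per sk tile
// (get_workspace_size_for_semaphore).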

} // namespace ck
Definition: block_to_ctile_map.hpp:1339
static constexpr uint32_t MPerBlock
Definition: block_to_ctile_map.hpp:1021
uint32_t dp_start_block_idx
Definition: block_to_ctile_map.hpp:1031
__host__ __device__ uint32_t get_sk_tiles() const
Definition: block_to_ctile_map.hpp:1217
static constexpr uint32_t KPerBlock
Definition: block_to_ctile_map.hpp:1023
__host__ __device__ uint32_t get_total_acc_buffers() const
Definition: block_to_ctile_map.hpp:1346
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition: block_to_ctile_map.hpp:1263
static constexpr uint32_t NPerBlock
Definition: block_to_ctile_map.hpp:1022
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition: block_to_ctile_map.hpp:1361
uint32_t reduction_start_block_idx
Definition: block_to_ctile_map.hpp:1032
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition: block_to_ctile_map.hpp:1311
MDiv k_iters_per_tile
Definition: block_to_ctile_map.hpp:1035
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition: block_to_ctile_map.hpp:1277
static constexpr uint32_t tile_swizzle_sub_m
Definition: block_to_ctile_map.hpp:1025
BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy, uint32_t sk_blocks=0xffffffff)
Definition: block_to_ctile_map.hpp:1043
static constexpr StreamKReductionStrategy ReductionStrategy
Definition: block_to_ctile_map.hpp:1024
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition: block_to_ctile_map.hpp:1282
__device__ uint32_t get_tile_idx(uint32_t iter) const
Definition: block_to_ctile_map.hpp:1274
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv &eqav_tiles_) const
Definition: block_to_ctile_map.hpp:1329
__device__ uint32_t get_block_idx() const
Definition: block_to_ctile_map.hpp:1234
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition: block_to_ctile_map.hpp:1241
MDiv eqav_tiles_little
Definition: block_to_ctile_map.hpp:1037
uint32_t sk_num_blocks
Definition: block_to_ctile_map.hpp:1029
MDiv2 n_tiles
Definition: block_to_ctile_map.hpp:1034
MDiv eqav_tiles_big
Definition: block_to_ctile_map.hpp:1036
static constexpr uint32_t min_k_iters_per_sk_block
Definition: block_to_ctile_map.hpp:1020
uint32_t sk_num_big_blocks
Definition: block_to_ctile_map.hpp:1030
__host__ __device__ dim3 get_grid_dims() const
Definition: block_to_ctile_map.hpp:1224
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
Definition: block_to_ctile_map.hpp:1319
Definition: block_to_ctile_map.hpp:270
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:296
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:281
static constexpr auto I1
Definition: block_to_ctile_map.hpp:272
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:290
static constexpr auto I0
Definition: block_to_ctile_map.hpp:271
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:382
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition: block_to_ctile_map.hpp:274
Definition: block_to_ctile_map.hpp:718
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:753
static constexpr auto I2
Definition: block_to_ctile_map.hpp:721
static constexpr auto I0
Definition: block_to_ctile_map.hpp:719
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1, index_t KSplit=1)
Definition: block_to_ctile_map.hpp:726
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01()=default
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:771
static constexpr auto I3
Definition: block_to_ctile_map.hpp:722
static constexpr auto I1
Definition: block_to_ctile_map.hpp:720
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:762
__host__ constexpr __device__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:739
Definition: block_to_ctile_map.hpp:539
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8, index_t KSplit=1)
Definition: block_to_ctile_map.hpp:547
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt()=default
static constexpr auto I0
Definition: block_to_ctile_map.hpp:540
static constexpr auto I1
Definition: block_to_ctile_map.hpp:541
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:592
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:554
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:565
static constexpr auto I2
Definition: block_to_ctile_map.hpp:542
static constexpr auto I3
Definition: block_to_ctile_map.hpp:543
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:598
Definition: block_to_ctile_map.hpp:615
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01()=default
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:659
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:644
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1)
Definition: block_to_ctile_map.hpp:623
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:630
static constexpr auto I0
Definition: block_to_ctile_map.hpp:616
static constexpr auto I3
Definition: block_to_ctile_map.hpp:619
static constexpr auto I1
Definition: block_to_ctile_map.hpp:617
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:650
static constexpr auto I2
Definition: block_to_ctile_map.hpp:618
__host__ constexpr __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:245
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:157
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt()=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt & operator=(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt &)=default
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt & operator=(const BlockToCTileMap_M00_N0_M01Adapt &)=default
static constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: block_to_ctile_map.hpp:166
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8)
Definition: block_to_ctile_map.hpp:150
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition: block_to_ctile_map.hpp:138
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:178
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:172
Definition: block_to_ctile_map.hpp:260
Definition: block_to_ctile_map.hpp:24
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:38
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:66
static constexpr auto I3
Definition: block_to_ctile_map.hpp:28
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1)
Definition: block_to_ctile_map.hpp:32
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:51
static constexpr auto I2
Definition: block_to_ctile_map.hpp:27
static constexpr auto I0
Definition: block_to_ctile_map.hpp:25
__host__ constexpr __device__ BlockToCTileMap_M00_N0_M01()=default
static constexpr auto I1
Definition: block_to_ctile_map.hpp:26
__host__ constexpr __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:57
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition: block_to_ctile_map.hpp:449
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t N01=8)
Definition: block_to_ctile_map.hpp:427
static constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: block_to_ctile_map.hpp:443
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:523
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:455
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01=8)
Definition: block_to_ctile_map.hpp:416
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:434
Definition: block_to_ctile_map.hpp:397
Definition: magic_division.hpp:207
__host__ __device__ void divmod(uint32_t dividend_, uint32_t divisor_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_division.hpp:229
Definition: magic_division.hpp:165
__host__ __device__ uint32_t get() const
Definition: magic_division.hpp:203
__host__ __device__ void divmod(uint32_t dividend_, uint32_t &quotient_, uint32_t &remainder_) const
Definition: magic_division.hpp:197
__host__ __device__ uint32_t div(uint32_t dividend_) const
Definition: magic_division.hpp:191
__host__ static constexpr __device__ T Max()
Definition: data_type.hpp:2833
Definition: block_to_ctile_map.hpp:917
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:937
index_t tile_offset_
Definition: block_to_ctile_map.hpp:957
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:944
UnderlyingBlockToCTileMap block_to_ctile_map_
Definition: block_to_ctile_map.hpp:955
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:930
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, index_t group_offset, index_t tile_offset)
Definition: block_to_ctile_map.hpp:920
UnderlyingBlockToCTileMap underlying_type
Definition: block_to_ctile_map.hpp:918
index_t group_offset_
Definition: block_to_ctile_map.hpp:956
__device__ void UpdateTileOffset(index_t offset)
Definition: block_to_ctile_map.hpp:954
__host__ constexpr __device__ index_t CalculateGridSize(index_t M, index_t N) const
Definition: block_to_ctile_map.hpp:949
Definition: block_to_ctile_map.hpp:870
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: block_to_ctile_map.hpp:888
__host__ constexpr __device__ index_t CalculateGridSize(index_t M, index_t N) const
Definition: block_to_ctile_map.hpp:906
constexpr __host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:895
constexpr __host__ index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition: block_to_ctile_map.hpp:901
index_t block_start_
Definition: block_to_ctile_map.hpp:912
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:881
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
Definition: block_to_ctile_map.hpp:873
UnderlyingBlockToCTileMap underlying_type
Definition: block_to_ctile_map.hpp:871
UnderlyingBlockToCTileMap block_to_ctile_map_
Definition: block_to_ctile_map.hpp:911
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10