/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <string>
6 
7 namespace ck_tile {
8 template <typename T>
9 auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
10 {
11  if(t->get_lengths().size() != 2)
12  {
13  throw std::runtime_error("Host tensor is not rank 2 tensor.");
14  }
15  int m_ = t->get_lengths()[0];
16  int aqk_ = t->get_lengths()[1];
17 
18  if(aqk_ % block_aq_k != 0)
19  {
20  throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
21  }
22  ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
23  std::copy(t->begin(), t->end(), t_view.begin());
24  return ck_tile::reference_permute(t_view, {1, 0, 2});
25 }
26 
27 template <typename T>
28 auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
29 {
30  const auto& lengths = t->get_lengths();
31  const size_t rank = lengths.size();
32 
33  // Validate block_bq_k divisibility based on rank
34  int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
35 
36  if(bqk_dim < 0)
37  {
38  throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
39  std::to_string(rank));
40  }
41 
42  if(bqk_dim % block_bq_k != 0)
43  {
44  throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
45  }
46 
47  // For TilePermuteN
48  if(rank == 5)
49  {
50  // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
51  ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
52  static_cast<int>(lengths[1]),
53  static_cast<int>(lengths[2]),
54  static_cast<int>(lengths[3]),
55  bqk_dim / block_bq_k,
56  block_bq_k});
57  std::copy(t->begin(), t->end(), t_view.begin());
58  return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
59  }
60  else // rank == 2
61  {
62  // Handle 2D tensor: [bqk, n]
63  int n_ = lengths[1];
64  ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
65  std::copy(t->begin(), t->end(), t_view.begin());
66  return ck_tile::reference_permute(t_view, {1, 0, 2});
67  }
68 }
69 
70 template <typename GemmConfig, typename T>
71 auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
72 {
73  assert(t.get_lengths().size() == 2);
74  int n_ = t.get_lengths()[1];
75  int k_ = t.get_lengths()[0];
76 
78  {
79  constexpr int divisor = 2;
80  constexpr int kABK1PerLane = 8;
81  int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
82  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
83  gemmConfig.N_Warp_Tile,
84  k_ / gemmConfig.K_Warp_Tile,
85  kABK0PerLane,
86  divisor,
87  kABK1PerLane});
88  std::copy(t.begin(), t.end(), t_view.begin());
89  return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
90  }
91  else
92  {
93  int divisor = 1;
95  {
96  divisor = 1;
97  }
98  else
99  {
100  assert(is_wave32() == false);
101  divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
102  }
103  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
104  gemmConfig.N_Warp_Tile,
105  k_ / gemmConfig.K_Warp_Tile,
106  divisor,
107  gemmConfig.K_Warp_Tile / divisor});
108  std::copy(t.begin(), t.end(), t_view.begin());
109  return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
110  }
111 }
112 
113 template <typename GemmConfig, typename T>
115 {
116  return shuffle_b(t, GemmConfig{});
117 }
118 
119 template <typename GemmConfig, typename T>
121 {
122  assert(t.get_lengths().size() == 2);
123 
124  int n_ = t.get_lengths()[1];
125  int bqk_ = t.get_lengths()[0];
126  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
127 
128  ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
129  GemmConfig::N_Warp,
130  GemmConfig::N_Warp_Tile / group_n,
131  NRepeat,
132  bqk_});
133  std::copy(t.begin(), t.end(), t_view.begin());
134  return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
135 }
136 
137 template <typename GemmConfig, typename T>
138 auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
139 {
140  assert(t.get_lengths().size() == 2);
141  int n_ = t.get_lengths()[1];
142  int k_ = t.get_lengths()[0];
143  int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
145  {
146  constexpr int divisor = 2;
147  constexpr int kABK1PerLane = 8;
148  int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
149  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
150  gemmConfig.N_Warp,
151  gemmConfig.N_Warp_Tile,
152  NRepeat,
153  k_ / gemmConfig.K_Warp_Tile,
154  kABK0PerLane,
155  divisor,
156  kABK1PerLane});
157  std::copy(t.begin(), t.end(), t_view.begin());
158  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
159  }
160  else
161  {
162  int divisor = 1;
164  {
165  divisor = 1;
166  }
167  else
168  {
169  assert(is_wave32() == false);
170  divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
171  }
172  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
173  gemmConfig.N_Warp,
174  gemmConfig.N_Warp_Tile,
175  NRepeat,
176  k_ / gemmConfig.K_Warp_Tile,
177  divisor,
178  gemmConfig.K_Warp_Tile / divisor});
179  std::copy(t.begin(), t.end(), t_view.begin());
180  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
181  }
182 }
183 
184 template <typename GemmConfig, typename T>
186 {
187  return shuffle_b_permuteN(t, GemmConfig{});
188 }
189 } // namespace ck_tile
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:28
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:138
bool is_gfx11_supported()
Definition: device_prop.hpp:55
auto bq_permuteN(const ck_tile::HostTensor< T > &t, index_t group_n)
Definition: tensor_shuffle_utils.hpp:120
auto shuffle_b(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:71
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586