/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 #include <stdexcept>
6 
7 namespace ck_tile {
8 template <typename T>
9 auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
10 {
11  if(t->get_lengths().size() != 2)
12  {
13  throw std::runtime_error("Host tensor is not rank 2 tensor.");
14  }
15  int m_ = t->get_lengths()[0];
16  int aqk_ = t->get_lengths()[1];
17  if(aqk_ % block_aq_k != 0)
18  {
19  throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
20  }
21  ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
22  std::copy(t->begin(), t->end(), t_view.begin());
23  return ck_tile::reference_permute(t_view, {1, 0, 2});
24 }
25 
26 template <typename T>
27 auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
28 {
29  const auto& lengths = t->get_lengths();
30  const size_t rank = lengths.size();
31 
32  // Validate block_bq_k divisibility based on rank
33  int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
34 
35  if(bqk_dim < 0)
36  {
37  throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
38  std::to_string(rank));
39  }
40 
41  if(bqk_dim % block_bq_k != 0)
42  {
43  throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
44  }
45 
46  // For TilePermuteN
47  if(rank == 5)
48  {
49  // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
50  ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
51  static_cast<int>(lengths[1]),
52  static_cast<int>(lengths[2]),
53  static_cast<int>(lengths[3]),
54  bqk_dim / block_bq_k,
55  block_bq_k});
56  std::copy(t->begin(), t->end(), t_view.begin());
57  return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
58  }
59  else // rank == 2
60  {
61  // Handle 2D tensor: [bqk, n]
62  int n_ = lengths[1];
63  ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
64  std::copy(t->begin(), t->end(), t_view.begin());
65  return ck_tile::reference_permute(t_view, {1, 0, 2});
66  }
67 }
68 
// shuffle_b: copies a rank-2 host tensor `t` into a higher-rank HostTensor view
// whose extents are derived from GemmConfig's warp-tile sizes, then returns the
// result of ck_tile::reference_permute on that view. lengths[0] is read as k and
// lengths[1] as n.
//
// NOTE(review): this listing skips original source lines 70, 76 and 93 (the embedded
// numbering jumps 69->71, 75->77 and 92->94). The symbol index at the end of this page
// gives line 70 as `auto shuffle_b(const ck_tile::HostTensor< T > &t)`. Lines 76 and 93
// are presumably the conditions that open the two branch bodies below, possibly
// is_gfx12_supported() / is_gfx11_supported(), which the index also lists. Confirm
// against the repository before relying on this text.
69 template <typename GemmConfig, typename T>
71 {
 // Rank is only checked via assert, so the check disappears when NDEBUG is defined.
72  assert(t.get_lengths().size() == 2);
73  int n_ = t.get_lengths()[1];
74  int k_ = t.get_lengths()[0];
75 
 // NOTE(review): the condition selecting this branch (original line 76) is missing here.
77  {
 // 6-D view: [n / N_Warp_Tile, N_Warp_Tile, k / K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane].
 // The last three extents multiply back to K_Warp_Tile only when K_Warp_Tile is a
 // multiple of divisor * kABK1PerLane (16); nothing here checks that.
78  constexpr int divisor = 2;
79  constexpr int kABK1PerLane = 8;
80  constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
81  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
82  GemmConfig::N_Warp_Tile,
83  k_ / GemmConfig::K_Warp_Tile,
84  kABK0PerLane,
85  divisor,
86  kABK1PerLane});
 // Flat element-wise copy in iterator order.
 // NOTE(review): n_ and k_ are divided with truncation and no divisibility check; if they
 // are not multiples of the tile sizes the view would hold fewer elements than `t`.
 // Confirm callers guarantee this.
87  std::copy(t.begin(), t.end(), t_view.begin());
88  return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
89  }
90  else
91  {
 // 5-D view: [n / N_Warp_Tile, N_Warp_Tile, k / K_Warp_Tile, divisor, K_Warp_Tile / divisor].
92  int divisor = 1;
 // NOTE(review): the condition selecting this inner branch (original line 93) is missing here.
94  {
95  divisor = 1;
96  }
97  else
98  {
 // Debug-only check that is_wave32() is false; divisor is 2 when N_Warp_Tile == 32, else 4.
99  assert(is_wave32() == false);
100  divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
101  }
102  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
103  GemmConfig::N_Warp_Tile,
104  k_ / GemmConfig::K_Warp_Tile,
105  divisor,
106  GemmConfig::K_Warp_Tile / divisor});
107  std::copy(t.begin(), t.end(), t_view.begin());
108  return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
109  }
110 }
111 
// bq_permuteN: copies a rank-2 host tensor `t` (lengths[0] read as bqk, lengths[1] as n)
// into a 5-D view [n / N_Tile, N_Warp, N_Warp_Tile, NRepeat, bqk] and returns the result
// of ck_tile::reference_permute with the permutation {0, 3, 1, 2, 4}.
//
// NOTE(review): this listing skips original source line 113 (the embedded numbering jumps
// 112->114). The symbol index at the end of this page gives it as
// `auto bq_permuteN(const ck_tile::HostTensor< T > &t)`. Confirm against the repository.
112 template <typename GemmConfig, typename T>
114 {
 // Rank is only checked via assert, so the check disappears when NDEBUG is defined.
115  assert(t.get_lengths().size() == 2);
116 
117  int n_ = t.get_lengths()[1];
118  int bqk_ = t.get_lengths()[0];
 // Computed as N_Tile / N_Warp_Tile / N_Warp with integer division.
119  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
120 
 // NOTE(review): n_ / N_Tile truncates and there is no divisibility check; if n_ is not a
 // multiple of N_Tile the view would hold fewer elements than `t`. Confirm callers
 // guarantee this.
121  ck_tile::HostTensor<T> t_view(
122  {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
 // Flat element-wise copy in iterator order.
123  std::copy(t.begin(), t.end(), t_view.begin());
124  return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
125 }
126 
// shuffle_b_permuteN: copies a rank-2 host tensor `t` (lengths[0] read as k, lengths[1]
// as n) into a higher-rank view that splits n by N_Tile / N_Warp / N_Warp_Tile / NRepeat
// and k by K_Warp_Tile, then returns the result of ck_tile::reference_permute on it.
//
// NOTE(review): this listing skips original source lines 128, 134 and 153 (the embedded
// numbering jumps 127->129, 133->135 and 152->154). The symbol index at the end of this
// page gives line 128 as `auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)`.
// Lines 134 and 153 are presumably the conditions that open the two branch bodies below,
// possibly is_gfx12_supported() / is_gfx11_supported(), which the index also lists.
// Confirm against the repository before relying on this text.
127 template <typename GemmConfig, typename T>
129 {
 // Rank is only checked via assert, so the check disappears when NDEBUG is defined.
130  assert(t.get_lengths().size() == 2);
131  int n_ = t.get_lengths()[1];
132  int k_ = t.get_lengths()[0];
 // Computed as N_Tile / N_Warp_Tile / N_Warp with integer division.
133  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
 // NOTE(review): the condition selecting this branch (original line 134) is missing here.
135  {
 // 8-D view: [n / N_Tile, N_Warp, N_Warp_Tile, NRepeat, k / K_Warp_Tile,
 //            kABK0PerLane, divisor, kABK1PerLane].
136  constexpr int divisor = 2;
137  constexpr int kABK1PerLane = 8;
138  constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
139  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
140  GemmConfig::N_Warp,
141  GemmConfig::N_Warp_Tile,
142  NRepeat,
143  k_ / GemmConfig::K_Warp_Tile,
144  kABK0PerLane,
145  divisor,
146  kABK1PerLane});
 // Flat element-wise copy in iterator order.
 // NOTE(review): n_ and k_ are divided with truncation and no divisibility check; if they
 // are not multiples of the tile sizes the view would hold fewer elements than `t`.
 // Confirm callers guarantee this.
147  std::copy(t.begin(), t.end(), t_view.begin());
148  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
149  }
150  else
151  {
 // 7-D view: [n / N_Tile, N_Warp, N_Warp_Tile, NRepeat, k / K_Warp_Tile,
 //            divisor, K_Warp_Tile / divisor].
152  int divisor = 1;
 // NOTE(review): the condition selecting this inner branch (original line 153) is missing here.
154  {
155  divisor = 1;
156  }
157  else
158  {
 // Debug-only check that is_wave32() is false; divisor is 2 when N_Warp_Tile == 32, else 4.
159  assert(is_wave32() == false);
160  divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
161  }
162  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
163  GemmConfig::N_Warp,
164  GemmConfig::N_Warp_Tile,
165  NRepeat,
166  k_ / GemmConfig::K_Warp_Tile,
167  divisor,
168  GemmConfig::K_Warp_Tile / divisor});
169  std::copy(t.begin(), t.end(), t_view.begin());
170  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
171  }
172 }
173 } // namespace ck_tile
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:27
bool is_gfx12_supported()
Definition: device_prop.hpp:63
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:70
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:128
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
auto bq_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:113
bool is_gfx11_supported()
Definition: device_prop.hpp:55
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586