/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.2.0/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.2.0/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.2.0/include/ck_tile/host/tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 #include <stdexcept>
6 
7 namespace ck_tile {
8 template <typename T>
9 auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
10 {
11  if(t->get_lengths().size() != 2)
12  {
13  throw std::runtime_error("Host tensor is not rank 2 tensor.");
14  }
15  int m_ = t->get_lengths()[0];
16  int aqk_ = t->get_lengths()[1];
17  if(aqk_ % block_aq_k != 0)
18  {
19  throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
20  }
21  ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
22  std::copy(t->begin(), t->end(), t_view.begin());
23  return ck_tile::reference_permute(t_view, {1, 0, 2});
24 }
25 
26 template <typename GemmConfig, typename T>
28 {
29  assert(t.get_lengths().size() == 2);
30  int n_ = t.get_lengths()[1];
31  int k_ = t.get_lengths()[0];
32 
34  {
35  constexpr int divisor = 2;
36  constexpr int kABK1PerLane = 8;
37  constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
38  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
39  GemmConfig::N_Warp_Tile,
40  k_ / GemmConfig::K_Warp_Tile,
41  kABK0PerLane,
42  divisor,
43  kABK1PerLane});
44  std::copy(t.begin(), t.end(), t_view.begin());
45  return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
46  }
47  else
48  {
49  int divisor = 1;
51  {
52  divisor = 1;
53  }
54  else
55  {
56  assert(is_wave32() == false);
57  divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
58  }
59  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
60  GemmConfig::N_Warp_Tile,
61  k_ / GemmConfig::K_Warp_Tile,
62  divisor,
63  GemmConfig::K_Warp_Tile / divisor});
64  std::copy(t.begin(), t.end(), t_view.begin());
65  return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
66  }
67 }
68 
69 template <typename GemmConfig, typename T>
71 {
72  assert(t.get_lengths().size() == 2);
73 
74  int n_ = t.get_lengths()[1];
75  int bqk_ = t.get_lengths()[0];
76  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
77 
79  {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
80  std::copy(t.begin(), t.end(), t_view.begin());
81  return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
82 }
83 
84 template <typename GemmConfig, typename T>
86 {
87  assert(t.get_lengths().size() == 2);
88  int n_ = t.get_lengths()[1];
89  int k_ = t.get_lengths()[0];
90  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
92  {
93  constexpr int divisor = 2;
94  constexpr int kABK1PerLane = 8;
95  constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
96  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
97  GemmConfig::N_Warp,
98  GemmConfig::N_Warp_Tile,
99  NRepeat,
100  k_ / GemmConfig::K_Warp_Tile,
101  kABK0PerLane,
102  divisor,
103  kABK1PerLane});
104  std::copy(t.begin(), t.end(), t_view.begin());
105  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
106  }
107  else
108  {
109  int divisor = 1;
111  {
112  divisor = 1;
113  }
114  else
115  {
116  assert(is_wave32() == false);
117  divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
118  }
119  ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
120  GemmConfig::N_Warp,
121  GemmConfig::N_Warp_Tile,
122  NRepeat,
123  k_ / GemmConfig::K_Warp_Tile,
124  divisor,
125  GemmConfig::K_Warp_Tile / divisor});
126  std::copy(t.begin(), t.end(), t_view.begin());
127  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
128  }
129 }
130 } // namespace ck_tile
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
bool is_gfx12_supported()
Definition: device_prop.hpp:63
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:27
auto shuffle_bq_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:70
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:85
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
bool is_gfx11_supported()
Definition: device_prop.hpp:55
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:589
Data::iterator begin()
Definition: host_tensor.hpp:587