 13        throw std::runtime_error("Host tensor is not a rank-2 tensor.");
 17    if(aqk_ % block_aq_k != 0)
 19        throw std::runtime_error("shuffle_aq needs aqk to be a multiple of block_aq_k.");
 30    const size_t rank = lengths.size();
 33    int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
 37        throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
 38                                 std::to_string(rank));
 41    if(bqk_dim % block_bq_k != 0)
 43        throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
 51                           static_cast<int>(lengths[1]),
 52                           static_cast<int>(lengths[2]),
 53                           static_cast<int>(lengths[3]),
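Lines 30-53 belong to shuffle_bq (tensor_shuffle_utils.hpp:27): it accepts either a rank-2 or a rank-5 host tensor, reads the K extent from lengths[0] or lengths[4] respectively, and requires that extent to divide evenly by block_bq_k before reshaping over lengths[1..3]. A standalone sketch of that dispatch, assuming the invalid-rank guard that is not visible in the fragment; the helper name and return type are illustrative, not the library's:

    #include <cstddef>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Illustrative helper (not part of ck_tile) mirroring the checks visible above;
    // the guard in front of the throw on line 37 is presumably an invalid-rank test.
    inline int checked_bqk_dim(const std::vector<std::size_t>& lengths, int block_bq_k)
    {
        const std::size_t rank = lengths.size();
        const int bqk_dim = (rank == 5)   ? static_cast<int>(lengths[4])
                            : (rank == 2) ? static_cast<int>(lengths[0])
                                          : -1;
        if(bqk_dim < 0)
            throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
                                     std::to_string(rank));
        if(bqk_dim % block_bq_k != 0)
            throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
        return bqk_dim;
    }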
 69    template <typename GemmConfig, typename T>
 78        constexpr int divisor      = 2;
 79        constexpr int kABK1PerLane = 8;
 80        constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
 82                            GemmConfig::N_Warp_Tile,
 83                            k_ / GemmConfig::K_Warp_Tile,
 99            assert(is_wave32() == false);
100            divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
103                            GemmConfig::N_Warp_Tile,
104                            k_ / GemmConfig::K_Warp_Tile,
106                            GemmConfig::K_Warp_Tile / divisor});
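Lines 69-106 are from shuffle_b (tensor_shuffle_utils.hpp:70). Lines 78-80 factor each lane's share of the warp's K extent so that kABK0PerLane * kABK1PerLane * divisor == K_Warp_Tile; the later branch around lines 99-100 (presumably guarded by the is_gfx11_supported()/is_gfx12_supported() checks referenced in the symbol list) asserts that wave32 is not in use and widens the divisor to 4 unless N_Warp_Tile is 32. A small constexpr check of that arithmetic; the GemmConfigExample values are assumptions, not a real ck_tile config:

    // Illustrative warp-tile configuration; real GemmConfig types live in the
    // ck_tile GEMM examples and expose static members with these names.
    struct GemmConfigExample
    {
        static constexpr int N_Warp_Tile = 16;
        static constexpr int K_Warp_Tile = 64;
    };

    constexpr int divisor      = 2;
    constexpr int kABK1PerLane = 8;
    constexpr int kABK0PerLane = GemmConfigExample::K_Warp_Tile / divisor / kABK1PerLane; // = 64 / 2 / 8 = 4

    // The three factors tile the warp's K extent exactly, as derived on lines 78-80.
    static_assert(kABK0PerLane * kABK1PerLane * divisor == GemmConfigExample::K_Warp_Tile,
                  "K_Warp_Tile == kABK0PerLane * kABK1PerLane * divisor");

    // On the branch at line 100 the divisor depends on the warp-tile width.
    constexpr int divisor_wide = (GemmConfigExample::N_Warp_Tile == 32) ? 2 : 4; // = 4 for N_Warp_Tile = 16
    static_assert(GemmConfigExample::K_Warp_Tile % (divisor_wide * kABK1PerLane) == 0,
                  "the widened divisor still splits K_Warp_Tile evenly");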
112    template <typename GemmConfig, typename T>
119        constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
122            {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
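Lines 112-122 are bq_permuteN (tensor_shuffle_utils.hpp:113): the N dimension is viewed as N-tile blocks x warps x warp tiles x repeats, with NRepeat = N_Tile / N_Warp_Tile / N_Warp, and the last axis keeps the quantized-K extent bqk_. A worked check of that factorization with assumed tile sizes:

    // Assumed tile sizes, chosen so every division below is exact.
    constexpr int N_Tile      = 128;
    constexpr int N_Warp      = 2;
    constexpr int N_Warp_Tile = 16;
    constexpr int NRepeat     = N_Tile / N_Warp_Tile / N_Warp; // = 4, as on line 119

    constexpr int n_ = 512; // hypothetical full N extent of the bq tensor

    // The 5-D shape built on line 122 is { n_ / N_Tile, N_Warp, N_Warp_Tile, NRepeat, bqk_ };
    // its first four factors always multiply back to n_.
    static_assert(n_ % N_Tile == 0, "N must be tiled evenly");
    static_assert((n_ / N_Tile) * N_Warp * N_Warp_Tile * NRepeat == n_,
                  "the N factors reconstruct n_");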
127    template <typename GemmConfig, typename T>
133        constexpr int NRepeat      = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
136        constexpr int divisor      = 2;
137        constexpr int kABK1PerLane = 8;
138        constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
141                            GemmConfig::N_Warp_Tile,
143                            k_ / GemmConfig::K_Warp_Tile,
159            assert(is_wave32() == false);
160            divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
164                            GemmConfig::N_Warp_Tile,
166                            k_ / GemmConfig::K_Warp_Tile,
168                            GemmConfig::K_Warp_Tile / divisor});
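Lines 127-168 are shuffle_b_permuteN (tensor_shuffle_utils.hpp:128), which combines the permuted-N blocking of bq_permuteN with the per-lane K split of shuffle_b, including the same divisor adjustment on the non-wave32 branch. A hypothetical call site for both B-shuffle variants; the GemmConfigExample members, element type, and extents are assumptions, only the function signatures come from this page:

    #include "tensor_shuffle_utils.hpp"  // the header documented on this page (path assumed)

    // Illustrative config; it extends the earlier GemmConfigExample sketch with the
    // N-tiling members the fragments read from GemmConfig.
    struct GemmConfigExample
    {
        static constexpr int N_Tile      = 128;
        static constexpr int N_Warp      = 2;
        static constexpr int N_Warp_Tile = 16;
        static constexpr int K_Warp_Tile = 64;
    };

    void prepare_b_example()
    {
        // Hypothetical B tensor, rank 2; extents chosen to divide the tile sizes evenly.
        ck_tile::HostTensor<ck_tile::half_t> b({512, 1024}); // element type assumed

        auto b_shuffled  = shuffle_b<GemmConfigExample>(b);          // per-lane K split only
        auto b_permutedN = shuffle_b_permuteN<GemmConfigExample>(b); // adds the permuted-N view
        (void)b_shuffled;
        (void)b_permutedN;
    }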
Referenced symbols (Doxygen cross-references for this listing):

__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
    Get layout rank (number of elements in the shape). Definition: layout_utils.hpp:310
auto copy(InputRange&& range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward<InputRange>(range)), std::end(std::forward<InputRange>(range)), iter))
    Definition: algorithm.hpp:14
namespace ck_tile
    Definition: cluster_descriptor.hpp:13
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
    Definition: tensor_shuffle_utils.hpp:9
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
    Definition: tensor_shuffle_utils.hpp:27
auto shuffle_b(const ck_tile::HostTensor<T>& t)
    Definition: tensor_shuffle_utils.hpp:70
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
    Definition: tensor_shuffle_utils.hpp:113
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
    Definition: tensor_shuffle_utils.hpp:128
bool is_gfx11_supported()
    Definition: device_prop.hpp:55
bool is_gfx12_supported()
    Definition: device_prop.hpp:63
CK_TILE_HOST void reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
    Definition: reference_permute.hpp:19
struct HostTensor
    Definition: host_tensor.hpp:336
decltype(auto) HostTensor::get_lengths() const
    Definition: host_tensor.hpp:390
Data::iterator HostTensor::begin()
    Definition: host_tensor.hpp:586
Data::iterator HostTensor::end()
    Definition: host_tensor.hpp:588
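Several of the cross-referenced symbols above (HostTensor with get_lengths()/begin()/end(), copy, reference_permute) are the host-side building blocks the shuffles are composed from. As a minimal standalone example of the listed reference_permute signature, a 2-D transpose of a host tensor; the include path, shapes, element type, namespace qualification, and output-tensor convention are assumptions:

    #include <vector>
    #include <ck_tile/host.hpp> // assumed location of HostTensor and reference_permute

    void permute_example()
    {
        ck_tile::HostTensor<float> x({64, 128}); // source tensor
        ck_tile::HostTensor<float> y({128, 64}); // destination, axes swapped

        // reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y,
        //                   std::vector<index_t> perm): here a simple 2-D transpose.
        ck_tile::reference_permute(x, y, std::vector<ck_tile::index_t>{1, 0});
    }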