13 throw std::runtime_error(
"Host tensor is not rank 2 tensor.");
18 if(aqk_ % block_aq_k != 0)
20 throw std::runtime_error(
"shuffle_aq needs a aqk of multiple times of block_aq_k.");
31 const size_t rank = lengths.size();
34 int bqk_dim = (
rank == 5) ? lengths[4] : (
rank == 2) ? lengths[0] : -1;
38 throw std::runtime_error(
"shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
39 std::to_string(
rank));
42 if(bqk_dim % block_bq_k != 0)
44 throw std::runtime_error(
"shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
52 static_cast<int>(lengths[1]),
53 static_cast<int>(lengths[2]),
54 static_cast<int>(lengths[3]),
70 template <
typename GemmConfig,
typename T>
79 constexpr
int divisor = 2;
80 constexpr
int kABK1PerLane = 8;
81 int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
83 gemmConfig.N_Warp_Tile,
84 k_ / gemmConfig.K_Warp_Tile,
100 assert(is_wave32() ==
false);
101 divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
104 gemmConfig.N_Warp_Tile,
105 k_ / gemmConfig.K_Warp_Tile,
107 gemmConfig.K_Warp_Tile / divisor});
113 template <
typename GemmConfig,
typename T>
119 template <
typename GemmConfig,
typename T>
126 constexpr
int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
130 GemmConfig::N_Warp_Tile / group_n,
137 template <
typename GemmConfig,
typename T>
143 int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
146 constexpr
int divisor = 2;
147 constexpr
int kABK1PerLane = 8;
148 int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
151 gemmConfig.N_Warp_Tile,
153 k_ / gemmConfig.K_Warp_Tile,
169 assert(is_wave32() ==
false);
170 divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
174 gemmConfig.N_Warp_Tile,
176 k_ / gemmConfig.K_Warp_Tile,
178 gemmConfig.K_Warp_Tile / divisor});
184 template <
typename GemmConfig,
typename T>
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:28
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:138
bool is_gfx11_supported()
Definition: device_prop.hpp:55
auto bq_permuteN(const ck_tile::HostTensor< T > &t, index_t group_n)
Definition: tensor_shuffle_utils.hpp:120
auto shuffle_b(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:71
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586