13 throw std::runtime_error(
"Host tensor is not rank 2 tensor.");
17 if(aqk_ % block_aq_k != 0)
19 throw std::runtime_error(
"shuffle_aq needs a aqk of multiple times of block_aq_k.");
26 template <
typename GemmConfig,
typename T>
35 constexpr
int divisor = 2;
36 constexpr
int kABK1PerLane = 8;
37 constexpr
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
39 GemmConfig::N_Warp_Tile,
40 k_ / GemmConfig::K_Warp_Tile,
56 assert(is_wave32() ==
false);
57 divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
60 GemmConfig::N_Warp_Tile,
61 k_ / GemmConfig::K_Warp_Tile,
63 GemmConfig::K_Warp_Tile / divisor});
69 template <
typename GemmConfig,
typename T>
76 constexpr
int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
79 {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
84 template <
typename GemmConfig,
typename T>
90 constexpr
int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
93 constexpr
int divisor = 2;
94 constexpr
int kABK1PerLane = 8;
95 constexpr
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
98 GemmConfig::N_Warp_Tile,
100 k_ / GemmConfig::K_Warp_Tile,
116 assert(is_wave32() ==
false);
117 divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
121 GemmConfig::N_Warp_Tile,
123 k_ / GemmConfig::K_Warp_Tile,
125 GemmConfig::K_Warp_Tile / divisor});
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
bool is_gfx12_supported()
Definition: device_prop.hpp:63
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:27
auto shuffle_bq_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:70
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition: tensor_shuffle_utils.hpp:85
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:9
bool is_gfx11_supported()
Definition: device_prop.hpp:55
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:589
Data::iterator begin()
Definition: host_tensor.hpp:587