15           typename enable_if<is_scalar_type<S>::value, bool>::type = false>
 
   28     vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
 
   29     vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
 
   31     vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
 
   32     vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
 
   34     y0 = vy0.template AsType<half2_t>()[I0];
 
   35     y1 = vy1.template AsType<half2_t>()[I0];
 
   37     constexpr int32_t m0 = 0x05040100;
 
   38     constexpr int32_t m1 = 0x07060302;
 
   44     y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
 
   45     y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
 
   49 template <index_t NX, index_t NY>
 
   66         static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
 
   72                 const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
 
   73                 const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
 
   76                 auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
 
   77                 auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
 
   98     constexpr int32_t m0 = 0x05010400;
 
   99     constexpr int32_t m1 = 0x05040100;
 
  100     constexpr int32_t m2 = 0x07060302;
 
  101     constexpr int32_t m3 = 0x07030602;
 
  107     t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
 
  108     t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
 
  109     z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
 
  110     z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
 
  111     t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
 
  112     t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
 
  113     z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
 
  114     z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
 
  116     y0 = bit_cast<int8x4_t>(z0);
 
  117     y1 = bit_cast<int8x4_t>(z1);
 
  118     y2 = bit_cast<int8x4_t>(z2);
 
  119     y3 = bit_cast<int8x4_t>(z3);
 
  122 template <index_t NX, index_t NY>
 
  141         static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
 
  147                 const auto& x_s4_0 = vx_tuple[ix].template AsType<int8x4_t>()[iy / I4];
 
  148                 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<int8x4_t>()[iy / I4];
 
  149                 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<int8x4_t>()[iy / I4];
 
  150                 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<int8x4_t>()[iy / I4];
 
  153                 auto& y_s4_0 = vy_tuple(iy).template AsType<int8x4_t>()(ix / I4);
 
  154                 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<int8x4_t>()(ix / I4);
 
  155                 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<int8x4_t>()(ix / I4);
 
  156                 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<int8x4_t>()(ix / I4);
 
  177     constexpr int32_t m0 = 0x05010400;
 
  178     constexpr int32_t m1 = 0x05040100;
 
  179     constexpr int32_t m2 = 0x07060302;
 
  180     constexpr int32_t m3 = 0x07030602;
 
  186     t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
 
  187     t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
 
  188     z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
 
  189     z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
 
  190     t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
 
  191     t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
 
  192     z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
 
  193     z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
 
  195     y0 = bit_cast<f8x4_t>(z0);
 
  196     y1 = bit_cast<f8x4_t>(z1);
 
  197     y2 = bit_cast<f8x4_t>(z2);
 
  198     y3 = bit_cast<f8x4_t>(z3);
 
  201 template <index_t NX, index_t NY>
 
  220         static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
 
  226                 const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
 
  227                 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
 
  228                 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
 
  229                 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
 
  232                 auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
 
  233                 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
 
  234                 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
 
  235                 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
 
  238                 transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
 
int8_t int8_t
Definition: int8.hpp:20
 
int32_t int32_t
Definition: integer.hpp:10
 
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
 
_Float16 half_t
Definition: data_type.hpp:30
 
__device__ void transpose_f8_4x4(const f8x4_t &x0, const f8x4_t &x1, const f8x4_t &x2, const f8x4_t &x3, f8x4_t &y0, f8x4_t &y1, f8x4_t &y2, f8x4_t &y3)
Definition: transpose_vectors.hpp:166
 
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
 
int32_t index_t
Definition: ck.hpp:297
 
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
 
__device__ void transpose_int8_4x4(const int8x4_t &x0, const int8x4_t &x1, const int8x4_t &x2, const int8x4_t &x3, int8x4_t &y0, int8x4_t &y1, int8x4_t &y2, int8x4_t &y3)
Definition: transpose_vectors.hpp:87
 
__device__ void transpose_fp16_2x2(const half2_t &x0, const half2_t &x1, half2_t &y0, half2_t &y1)
Definition: transpose_vectors.hpp:19
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
f8_t S
Definition: transpose_vectors.hpp:208
 
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition: transpose_vectors.hpp:212
 
half_t S
Definition: transpose_vectors.hpp:56
 
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition: transpose_vectors.hpp:60
 
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition: transpose_vectors.hpp:133
 
int8_t S
Definition: transpose_vectors.hpp:129
 
Definition: transpose_vectors.hpp:16
 
Definition: dtype_vector.hpp:10