4 #ifndef CK_AMD_INLINE_ASM_HPP 
    5 #define CK_AMD_INLINE_ASM_HPP 
   17     asm volatile(
"v_and_b32 %0, %1, %2" : 
"=v"(c) : 
"v"(a), 
"v"(b));
 
   24     asm volatile(
"v_and_or_b32 %0, %1, %2, %3" : 
"=v"(c) : 
"v"(a), 
"v"(b), 
"v"(d));
 
   31     asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3" : 
"=v"(d) : 
"v"(a), 
"v"(b), 
"v"(c));
 
   38     asm volatile(
"v_pk_add_f16 %0, %1, %2" : 
"=v"(c) : 
"v"(a), 
"v"(b));
 
   45     asm volatile(
"v_cvt_off_f32_i4 %0, %1" : 
"=v"(a) : 
"v"(b));
 
   52     asm volatile(
"v_cvt_pk_fp8_f32 %0, %1, %2\n" 
   53                  "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n" 
   55                  : 
"v"(b0), 
"v"(b1), 
"v"(b2), 
"v"(b3));
 
   61     uint32_t i4x8 = 
static_cast<uint32_t
>(a);
 
   64     float tmp_0, tmp_1, tmp_2;
 
   66     asm volatile(
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" 
   67                  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n" 
   68                  "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n" 
   69                  "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" 
   70                  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n" 
   71                  "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n" 
   72                  "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n" 
   73                  "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n" 
   74                  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n" 
   75                  "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n" 
   76                  "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n" 
   77                  "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n" 
   78                  "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n" 
   79                  : [v_tmp_0] 
"+v"(tmp_0),
 
   80                    [v_tmp_1] 
"+v"(tmp_1),
 
   81                    [v_tmp_2] 
"+v"(tmp_2),
 
   82                    [v_dst_0] 
"+v"(fp8x4_0),
 
   83                    [v_dst_1] 
"+v"(fp8x4_1),
 
   87     return bit_cast<f8x8_t>(((
static_cast<uint64_t
>(fp8x4_1) << 32) | fp8x4_0));
 
   95             v_fmac_f32 %0, %2, %3 \n \ 
   96             v_fmac_f32 %1, %2, %4 \n \ 
   99                  : 
"v"(a), 
"v"(b0), 
"v"(b1), 
"0"(c0), 
"1"(c1));
 
  107     float a, 
float b0, 
float b1, 
float b2, 
float b3, 
float& c0, 
float& c1, 
float& c2, 
float& c3)
 
  110             v_fmac_f32 %0, %4, %5 \n \ 
  111             v_fmac_f32 %1, %4, %6 \n \ 
  112             v_fmac_f32 %2, %4, %7 \n \ 
  113             v_fmac_f32 %3, %4, %8 \n \ 
  115                  : 
"=v"(c0), 
"=v"(c1), 
"=v"(c2), 
"=v"(c3)
 
  116                  : 
"v"(a), 
"v"(b0), 
"v"(b1), 
"v"(b2), 
"v"(b3), 
"0"(c0), 
"1"(c1), 
"2"(c2), 
"3"(c3));
 
  125             v_dot2_f32_f16 %0, %2, %3, %0\n \ 
  126             v_dot2_f32_f16 %1, %2, %4, %1\n \ 
  129                  : 
"v"(a), 
"v"(b0), 
"v"(b1), 
"0"(c0), 
"1"(c1));
 
  138     const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
 
  139     const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
 
  140     const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
 
  144             v_dot2_f32_f16 %0, %2, %4, %0\n \ 
  145             v_dot2_f32_f16 %1, %2, %6, %1\n \ 
  146             v_dot2_f32_f16 %0, %3, %5, %0\n \ 
  147             v_dot2_f32_f16 %1, %3, %7, %1\n \ 
  175             v_dot2_f32_f16 %0, %4, %5, %0\n \ 
  176             v_dot2_f32_f16 %1, %4, %6, %1\n \ 
  177             v_dot2_f32_f16 %2, %4, %7, %2\n \ 
  178             v_dot2_f32_f16 %3, %4, %8, %3\n \ 
  180                  : 
"=v"(c0), 
"=v"(c1), 
"=v"(c2), 
"=v"(c3)
 
  181                  : 
"v"(a), 
"v"(b0), 
"v"(b1), 
"v"(b2), 
"v"(b3), 
"0"(c0), 
"1"(c1), 
"2"(c2), 
"3"(c3));
 
  199     const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
 
  200     const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
 
  201     const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
 
  202     const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
 
  203     const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
 
  207             v_dot2_f32_f16 %0, %4, %6,  %0\n \ 
  208             v_dot2_f32_f16 %1, %4, %8,  %1\n \ 
  209             v_dot2_f32_f16 %2, %4, %10, %2\n \ 
  210             v_dot2_f32_f16 %3, %4, %12, %3\n \ 
  211             v_dot2_f32_f16 %0, %5, %7,  %0\n \ 
  212             v_dot2_f32_f16 %1, %5, %9,  %1\n \ 
  213             v_dot2_f32_f16 %2, %5, %11, %2\n \ 
  214             v_dot2_f32_f16 %3, %5, %13, %3\n \ 
  216                  : 
"=v"(c0), 
"=v"(c1), 
"=v"(c2), 
"=v"(c3)
 
  245     const half4_t* p_a_half4  = c_style_pointer_cast<const half4_t*>(&a);
 
  246     const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
 
  247     const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
 
  248     const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
 
  249     const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
 
  252         p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
 
  255         p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
 
  269     const half8_t* p_a_half8  = c_style_pointer_cast<const half8_t*>(&a);
 
  270     const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
 
  271     const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
 
  272     const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
 
  273     const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
 
  276         p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
 
  279         p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
 
  289             v_dot4_i32_i8 %0, %2, %3, %0\n \ 
  290             v_dot4_i32_i8 %1, %2, %4, %1\n \ 
  293                  : 
"v"(bit_cast<int32_t>(a)),
 
  294                    "v"(bit_cast<int32_t>(b0)),
 
  295                    "v"(bit_cast<int32_t>(b1)),
 
  299     c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, 
false);
 
  300     c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, 
false);
 
  320             v_dot4_i32_i8 %0, %4, %5, %0\n \ 
  321             v_dot4_i32_i8 %1, %4, %6, %1\n \ 
  322             v_dot4_i32_i8 %2, %4, %7, %2\n \ 
  323             v_dot4_i32_i8 %3, %4, %8, %3\n \ 
  325                  : 
"=v"(c0), 
"=v"(c1), 
"=v"(c2), 
"=v"(c3)
 
  326                  : 
"v"(bit_cast<int32_t>(a)),
 
  327                    "v"(bit_cast<int32_t>(b0)),
 
  328                    "v"(bit_cast<int32_t>(b1)),
 
  329                    "v"(bit_cast<int32_t>(b2)),
 
  330                    "v"(bit_cast<int32_t>(b3)),
 
  336     c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, 
false);
 
  337     c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, 
false);
 
  338     c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, 
false);
 
  339     c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, 
false);
 
  353     constexpr 
auto I0 = Number<0>{};
 
  354     constexpr 
auto I1 = Number<1>{};
 
  357                                    vector_type<int8_t, 8>{b0}.AsType<
int8x4_t>()[I0],
 
  358                                    vector_type<int8_t, 8>{b1}.AsType<
int8x4_t>()[I0],
 
  359                                    vector_type<int8_t, 8>{b2}.AsType<
int8x4_t>()[I0],
 
  360                                    vector_type<int8_t, 8>{b3}.AsType<
int8x4_t>()[I0],
 
  367                                    vector_type<int8_t, 8>{b0}.AsType<
int8x4_t>()[I1],
 
  368                                    vector_type<int8_t, 8>{b1}.AsType<
int8x4_t>()[I1],
 
  369                                    vector_type<int8_t, 8>{b2}.AsType<
int8x4_t>()[I1],
 
  370                                    vector_type<int8_t, 8>{b3}.AsType<
int8x4_t>()[I1],
 
  388     constexpr 
auto I0 = Number<0>{};
 
  389     constexpr 
auto I1 = Number<1>{};
 
  390     constexpr 
auto I2 = Number<2>{};
 
  391     constexpr 
auto I3 = Number<3>{};
 
  394                                    vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I0],
 
  395                                    vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I0],
 
  396                                    vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I0],
 
  397                                    vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I0],
 
  404                                    vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I1],
 
  405                                    vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I1],
 
  406                                    vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I1],
 
  407                                    vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I1],
 
  414                                    vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I2],
 
  415                                    vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I2],
 
  416                                    vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I2],
 
  417                                    vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I2],
 
  424                                    vector_type<int8_t, 16>{b0}.AsType<
int8x4_t>()[I3],
 
  425                                    vector_type<int8_t, 16>{b1}.AsType<
int8x4_t>()[I3],
 
  426                                    vector_type<int8_t, 16>{b2}.AsType<
int8x4_t>()[I3],
 
  427                                    vector_type<int8_t, 16>{b3}.AsType<
int8x4_t>()[I3],
 
int32_t int32_t
Definition: integer.hpp:10
 
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition: amd_inline_asm.hpp:35
 
__device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
Definition: amd_inline_asm.hpp:59
 
__device__ void amd_assembly_outer_product_1x4(float a, float b0, float b1, float b2, float b3, float &c0, float &c1, float &c2, float &c3)
Definition: amd_inline_asm.hpp:106
 
__device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
Definition: amd_inline_asm.hpp:49
 
__device__ int amd_assembly_and_b32(int a, int b)
Definition: amd_inline_asm.hpp:14
 
__device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
Definition: amd_inline_asm.hpp:28
 
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2164
 
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2140
 
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float &c0, float &c1)
Definition: amd_inline_asm.hpp:92
 
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2165
 
typename vector_type< half_t, 2 >::type half2_t
Definition: dtype_vector.hpp:2139
 
__device__ int amd_assembly_and_or_b32(int a, int b, int d)
Definition: amd_inline_asm.hpp:21
 
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2163
 
typename vector_type< half_t, 16 >::type half16_t
Definition: dtype_vector.hpp:2142
 
__device__ float amd_assemble_cvt_f32_i4(int b)
Definition: amd_inline_asm.hpp:42
 
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2141