15 static constexpr 
bool is_scale_mfma_data_type()
 
   17     using U = element_type_t<T>;
 
   18     return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
 
   19            is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
 
   26 static constexpr 
bool is_scale_mfma_scale_type()
 
   28     return is_same_v<T, e8m0_bexp_t>;
 
   34 template <
typename ADataType, 
typename BDataType, 
typename AScaleDataType, 
typename BScaleDataType>
 
   35 static constexpr 
bool scale_mfma_hw_support()
 
   37     return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
 
   38            is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
 
   82 template <MfmaInstr instr>
 
   89     static constexpr 
index_t num_groups_per_blk  = 4;
 
   90     static constexpr 
index_t num_regs_per_blk    = 16;
 
   91     static constexpr 
index_t num_threads_per_blk = 32;
 
   93     static constexpr 
index_t num_input_blks      = 2;
 
   94     static constexpr 
index_t num_output_blks     = 2;
 
   98     static constexpr 
bool is_k_reduction         = 
false;
 
  100     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  101     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  111     static constexpr 
index_t num_groups_per_blk  = 4;
 
  112     static constexpr 
index_t num_regs_per_blk    = 16;
 
  113     static constexpr 
index_t num_threads_per_blk = 32;
 
  120     static constexpr 
bool is_k_reduction         = 
true;
 
  122     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  123     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  133     static constexpr 
index_t num_groups_per_blk  = 1;
 
  134     static constexpr 
index_t num_regs_per_blk    = 4;
 
  135     static constexpr 
index_t num_threads_per_blk = 16;
 
  142     static constexpr 
bool is_k_reduction         = 
true;
 
  144     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  145     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  155     static constexpr 
index_t num_groups_per_blk  = 1;
 
  156     static constexpr 
index_t num_regs_per_blk    = 4;
 
  157     static constexpr 
index_t num_threads_per_blk = 16;
 
  164     static constexpr 
bool is_k_reduction         = 
false;
 
  166     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  167     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  178     static constexpr 
index_t num_groups_per_blk  = 1;
 
  179     static constexpr 
index_t num_regs_per_blk    = 4;
 
  180     static constexpr 
index_t num_threads_per_blk = 64;
 
  187     static constexpr 
bool is_k_reduction         = 
false;
 
  189     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  190     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  200     static constexpr 
index_t num_groups_per_blk  = 4;
 
  201     static constexpr 
index_t num_regs_per_blk    = 16;
 
  202     static constexpr 
index_t num_threads_per_blk = 32;
 
  209     static constexpr 
bool is_k_reduction         = 
false;
 
  211     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  212     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  222     static constexpr 
index_t num_groups_per_blk  = 4;
 
  223     static constexpr 
index_t num_regs_per_blk    = 16;
 
  224     static constexpr 
index_t num_threads_per_blk = 32;
 
  231     static constexpr 
bool is_k_reduction         = 
true;
 
  233     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  234     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  244     static constexpr 
index_t num_groups_per_blk  = 4;
 
  245     static constexpr 
index_t num_regs_per_blk    = 16;
 
  246     static constexpr 
index_t num_threads_per_blk = 32;
 
  253     static constexpr 
bool is_k_reduction         = 
true;
 
  255     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  256     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  266     static constexpr 
index_t num_groups_per_blk  = 1;
 
  267     static constexpr 
index_t num_regs_per_blk    = 4;
 
  268     static constexpr 
index_t num_threads_per_blk = 16;
 
  275     static constexpr 
bool is_k_reduction         = 
true;
 
  277     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  278     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  288     static constexpr 
index_t num_groups_per_blk  = 1;
 
  289     static constexpr 
index_t num_regs_per_blk    = 4;
 
  290     static constexpr 
index_t num_threads_per_blk = 16;
 
  297     static constexpr 
bool is_k_reduction         = 
true;
 
  299     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  300     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  310     static constexpr 
index_t num_groups_per_blk  = 1;
 
  311     static constexpr 
index_t num_regs_per_blk    = 4;
 
  312     static constexpr 
index_t num_threads_per_blk = 16;
 
  319     static constexpr 
bool is_k_reduction         = 
false;
 
  321     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  322     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  332     static constexpr 
index_t num_groups_per_blk  = 1;
 
  333     static constexpr 
index_t num_regs_per_blk    = 4;
 
  334     static constexpr 
index_t num_threads_per_blk = 64;
 
  341     static constexpr 
bool is_k_reduction         = 
false;
 
  343     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  344     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  354     static constexpr 
index_t num_groups_per_blk  = 4;
 
  355     static constexpr 
index_t num_regs_per_blk    = 16;
 
  356     static constexpr 
index_t num_threads_per_blk = 32;
 
  363     static constexpr 
bool is_k_reduction         = 
true;
 
  365     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  366     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  376     static constexpr 
index_t num_groups_per_blk  = 4;
 
  377     static constexpr 
index_t num_regs_per_blk    = 16;
 
  378     static constexpr 
index_t num_threads_per_blk = 32;
 
  385     static constexpr 
bool is_k_reduction         = 
true;
 
  387     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  388     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  398     static constexpr 
index_t num_groups_per_blk  = 1;
 
  399     static constexpr 
index_t num_regs_per_blk    = 4;
 
  400     static constexpr 
index_t num_threads_per_blk = 16;
 
  407     static constexpr 
bool is_k_reduction         = 
true;
 
  409     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  410     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  420     static constexpr 
index_t num_groups_per_blk  = 1;
 
  421     static constexpr 
index_t num_regs_per_blk    = 4;
 
  422     static constexpr 
index_t num_threads_per_blk = 16;
 
  429     static constexpr 
bool is_k_reduction         = 
true;
 
  431     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  432     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  442     static constexpr 
index_t num_groups_per_blk  = 4;
 
  443     static constexpr 
index_t num_regs_per_blk    = 16;
 
  444     static constexpr 
index_t num_threads_per_blk = 32;
 
  451     static constexpr 
bool is_k_reduction         = 
true;
 
  453     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  454     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  464     static constexpr 
index_t num_groups_per_blk  = 1;
 
  465     static constexpr 
index_t num_regs_per_blk    = 4;
 
  466     static constexpr 
index_t num_threads_per_blk = 16;
 
  473     static constexpr 
bool is_k_reduction         = 
true;
 
  475     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  476     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  486     static constexpr 
index_t num_groups_per_blk  = 4;
 
  487     static constexpr 
index_t num_regs_per_blk    = 16;
 
  488     static constexpr 
index_t num_threads_per_blk = 32;
 
  495     static constexpr 
bool is_k_reduction         = 
true;
 
  497     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  498     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  508     static constexpr 
index_t num_groups_per_blk  = 1;
 
  509     static constexpr 
index_t num_regs_per_blk    = 4;
 
  510     static constexpr 
index_t num_threads_per_blk = 16;
 
  517     static constexpr 
bool is_k_reduction         = 
true;
 
  519     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  520     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  530     static constexpr 
index_t num_groups_per_blk  = 4;
 
  531     static constexpr 
index_t num_regs_per_blk    = 16;
 
  532     static constexpr 
index_t num_threads_per_blk = 32;
 
  539     static constexpr 
bool is_k_reduction         = 
true;
 
  541     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  542     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  552     static constexpr 
index_t num_groups_per_blk  = 1;
 
  553     static constexpr 
index_t num_regs_per_blk    = 4;
 
  554     static constexpr 
index_t num_threads_per_blk = 16;
 
  561     static constexpr 
bool is_k_reduction         = 
true;
 
  563     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  564     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  574     static constexpr 
index_t num_groups_per_blk  = 4;
 
  575     static constexpr 
index_t num_regs_per_blk    = 16;
 
  576     static constexpr 
index_t num_threads_per_blk = 32;
 
  583     static constexpr 
bool is_k_reduction         = 
true;
 
  585     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  586     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  596     static constexpr 
index_t num_groups_per_blk  = 1;
 
  597     static constexpr 
index_t num_regs_per_blk    = 4;
 
  598     static constexpr 
index_t num_threads_per_blk = 16;
 
  605     static constexpr 
bool is_k_reduction         = 
true;
 
  607     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  608     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  618     static constexpr 
index_t num_groups_per_blk  = 4;
 
  619     static constexpr 
index_t num_regs_per_blk    = 4; 
 
  620     static constexpr 
index_t num_threads_per_blk = 16;
 
  627     static constexpr 
bool is_k_reduction         = 
true;
 
  629     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  630     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  640     static constexpr 
index_t num_groups_per_blk  = 4;
 
  641     static constexpr 
index_t num_regs_per_blk    = 16;
 
  642     static constexpr 
index_t num_threads_per_blk = 32;
 
  649     static constexpr 
bool is_k_reduction         = 
true;
 
  651     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  652     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  662     static constexpr 
index_t num_groups_per_blk  = 1;
 
  663     static constexpr 
index_t num_regs_per_blk    = 4;
 
  664     static constexpr 
index_t num_threads_per_blk = 16;
 
  671     static constexpr 
bool is_k_reduction         = 
true;
 
  673     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  674     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  684     static constexpr 
index_t num_groups_per_blk  = 4;
 
  685     static constexpr 
index_t num_regs_per_blk    = 16;
 
  686     static constexpr 
index_t num_threads_per_blk = 32;
 
  693     static constexpr 
bool is_k_reduction         = 
true;
 
  695     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  696     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  706     static constexpr 
index_t num_groups_per_blk  = 1;
 
  707     static constexpr 
index_t num_regs_per_blk    = 4;
 
  708     static constexpr 
index_t num_threads_per_blk = 16;
 
  715     static constexpr 
bool is_k_reduction         = 
true;
 
  717     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  718     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  728     static constexpr 
index_t num_groups_per_blk  = 4;
 
  729     static constexpr 
index_t num_regs_per_blk    = 16;
 
  730     static constexpr 
index_t num_threads_per_blk = 32;
 
  737     static constexpr 
bool is_k_reduction         = 
true;
 
  739     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  740     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  750     static constexpr 
index_t num_groups_per_blk  = 1;
 
  751     static constexpr 
index_t num_regs_per_blk    = 4;
 
  752     static constexpr 
index_t num_threads_per_blk = 16;
 
  759     static constexpr 
bool is_k_reduction         = 
true;
 
  761     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  762     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  772     static constexpr 
index_t num_groups_per_blk  = 4;
 
  773     static constexpr 
index_t num_regs_per_blk    = 16;
 
  774     static constexpr 
index_t num_threads_per_blk = 32;
 
  781     static constexpr 
bool is_k_reduction         = 
true;
 
  783     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  784     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  794     static constexpr 
index_t num_groups_per_blk  = 1;
 
  795     static constexpr 
index_t num_regs_per_blk    = 4;
 
  796     static constexpr 
index_t num_threads_per_blk = 16;
 
  803     static constexpr 
bool is_k_reduction         = 
true;
 
  805     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  806     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  817     static constexpr 
index_t num_groups_per_blk  = 4;    
 
  818     static constexpr 
index_t num_regs_per_blk    = 16;   
 
  819     static constexpr 
index_t num_threads_per_blk = 32;   
 
  826     static constexpr 
bool is_k_reduction         = 
true; 
 
  829     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  830     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  841     static constexpr 
index_t num_groups_per_blk  = 1;    
 
  842     static constexpr 
index_t num_regs_per_blk    = 4;    
 
  843     static constexpr 
index_t num_threads_per_blk = 16;   
 
  850     static constexpr 
bool is_k_reduction         = 
true; 
 
  853     template <index_t MPerXdlops, index_t NPerXdlops, 
class FloatA, 
class FloatB, 
class FloatC>
 
  854     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  865     static constexpr 
index_t num_groups_per_blk  = 4;    
 
  866     static constexpr 
index_t num_regs_per_blk    = 16;   
 
  867     static constexpr 
index_t num_threads_per_blk = 32;   
 
  874     static constexpr 
bool is_k_reduction         = 
true; 
 
  886     __device__ 
void run(
const FloatA& a,
 
  887                         const ScaleA& scale_a,
 
  889                         const ScaleB& scale_b,
 
  893             a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
 
  902     static constexpr 
index_t num_groups_per_blk  = 1;    
 
  903     static constexpr 
index_t num_regs_per_blk    = 4;    
 
  904     static constexpr 
index_t num_threads_per_blk = 16;   
 
  911     static constexpr 
bool is_k_reduction         = 
true; 
 
  923     __device__ 
void run(
const FloatA& a,
 
  924                         const ScaleA& scale_a,
 
  926                         const ScaleB& scale_b,
 
  931             a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
 
  935 template <
typename base_type,
 
  938           typename additional_type = base_type,
 
  939           bool is_single_rate_mfma = 
false,
 
  940           bool is_scale_mfma       = 
false>
 
  943     template <
typename base_type_,
 
  946               typename additional_type_ = base_type_,
 
  947               bool is_single_rate_mfma_ = 
false,
 
  948               bool is_scale_mfma_       = 
false>
 
  952     constexpr 
auto GetMfma<double, 16, 16>()
 
  958     constexpr 
auto GetMfma<float, 64, 64>()
 
  964     constexpr 
auto GetMfma<float, 32, 64>()
 
  970     constexpr 
auto GetMfma<float, 16, 64>()
 
  976     constexpr 
auto GetMfma<float, 8, 64>()
 
  982     constexpr 
auto GetMfma<float, 4, 64>()
 
  988     constexpr 
auto GetMfma<float, 32, 32>()
 
  994     constexpr 
auto GetMfma<float, 16, 16>()
 
 1000     constexpr 
auto GetMfma<half_t, 64, 64>()
 
 1006     constexpr 
auto GetMfma<half_t, 32, 64>()
 
 1012     constexpr 
auto GetMfma<half_t, 32, 32, half_t, false>()
 
 1014 #if defined(__gfx950__) 
 1021     constexpr 
auto GetMfma<half_t, 32, 32, half_t, true>()
 
 1027     constexpr 
auto GetMfma<half_t, 16, 16, half_t, false>()
 
 1029 #if defined(__gfx950__) 
 1037     constexpr 
auto GetMfma<half_t, 16, 16, half_t, true>()
 
 1043     constexpr 
auto GetMfma<half_t, 16, 64>()
 
 1049     constexpr 
auto GetMfma<half_t, 8, 64>()
 
 1055     constexpr 
auto GetMfma<half_t, 4, 64>()
 
 1061     constexpr 
auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
 
 1063 #if defined(__gfx950__) 
 1065 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) 
 1073     constexpr 
auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
 
 1075 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) 
 1083     constexpr 
auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
 
 1085 #if defined(__gfx950__) 
 1087 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) 
 1095     constexpr 
auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
 
 1097 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) 
 1105     constexpr 
auto GetMfma<int8_t, 32, 32, int8_t, false>()
 
 1107 #if defined(__gfx950__) 
 1109 #elif defined(__gfx942__) 
 1117     constexpr 
auto GetMfma<int8_t, 32, 32, int8_t, true>()
 
 1119 #if defined(__gfx942__) || defined(__gfx950__) 
 1127     constexpr 
auto GetMfma<int8_t, 16, 16, int8_t, false>()
 
 1129 #if defined(__gfx950__) 
 1131 #elif defined(__gfx942__) 
 1139     constexpr 
auto GetMfma<int8_t, 16, 16, int8_t, true>()
 
 1141 #if defined(__gfx942__) || defined(__gfx950__) 
 1149     constexpr 
auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
 
 1155     constexpr 
auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
 
 1157 #if defined(__gfx950__) 
 1165     constexpr 
auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
 
 1171     constexpr 
auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
 
 1176     constexpr 
auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
 
 1181     constexpr 
auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
 
 1187     constexpr 
auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
 
 1193     constexpr 
auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
 
 1195 #if defined(__gfx950__) 
 1203     constexpr 
auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
 
 1209     constexpr 
auto GetMfma<bf8_t, 16, 16, bf8_t, false, true>()
 
 1215     constexpr 
auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
 
 1221     constexpr 
auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
 
 1227     constexpr 
auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
 
 1232     constexpr 
auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
 
 1237     constexpr 
auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
 
 1242     constexpr 
auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
 
 1248     constexpr 
auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
 
 1254     constexpr 
auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
 
 1256 #if defined(__gfx950__) 
 1264     constexpr 
auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
 
 1270     constexpr 
auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
 
 1272 #if defined(__gfx950__) 
 1280     constexpr 
auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
 
 1286     constexpr 
auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
 
 1288 #if defined(__gfx950__) 
 1296     constexpr 
auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
 
 1302     constexpr 
auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
 
 1304 #if defined(__gfx950__) 
 1312     constexpr 
auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
 
 1318     constexpr 
auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
 
 1320 #if defined(__gfx950__) 
 1328     constexpr 
auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
 
 1334     constexpr 
auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
 
 1336 #if defined(__gfx950__) 
 1347                                                             is_single_rate_mfma,
 
 1348                                                             is_scale_mfma>()>{};
 
 1354                       "wrong! num_regs_per_blk");
 
 1357                       "n_per_blk != num_threads_per_blk");
 
 1361                       "m_per_blk != num_input_blks * num_regs_per_blk");
 
 1365                       "incorrect num_output_blks");
 
 1369                       "num_regs_per_blk incorrect");
 
 1373                       "is_k_reduction wrong!");
 
 1378         static_assert(NPerXdlops >= MPerXdlops, 
"only support ABroadcast");
 
 1391 template <
typename base_type,
 
 1395           typename additional_type = base_type,
 
 1396           bool TransposeC          = 
false,
 
 1397           bool is_scale_mfma       = 
false>
 
 1414         return MPerXdlops * NPerXdlops /
 
 1420         static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
 
 1422                       "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
 
 1424         static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
 
 1426                       "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
 
 1428         static_assert(KPack % 
mfma_instr.k_per_blk == 0, 
"KPack should be a multiple of k_per_blk");
 
 1433     template <
typename CDesc_M0_N0_M1_N1_M2_N2>
 
 1434     __host__ __device__ 
static constexpr 
auto 
 1437         const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
 
 1438         const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
 
 1439         const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
 
 1440         const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
 
 1443             c_desc_m0_n0_m1_n1_m2_n2,
 
 1468     template <
typename CDesc_M0_N0_M1_N1_M2_N2>
 
 1470         const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
 
 1472         const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
 
 1473         const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
 
 1474         const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
 
 1475         const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
 
 1476         const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I4);
 
 1477         const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I5);
 
 1480             c_desc_m0_n0_m1_n1_m2_n2,
 
 1511     template <
typename CDesc_M0_N0_M1_N1_M2_N2>
 
 1512     __host__ __device__ 
static constexpr 
auto 
 1515         const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
 
 1516         const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
 
 1517         const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
 
 1518         const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
 
 1521             c_desc_m0_n0_m1_n1_m2_n2,
 
 1544     template <
typename CDesc_G_M0_N0_M1_N1_M2_N2>
 
 1546         const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
 
 1548         const auto G  = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I0);
 
 1549         const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I1);
 
 1550         const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I2);
 
 1551         const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I3);
 
 1552         const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I4);
 
 1555             c_desc_g_m0_n0_m1_n1_m2_n2,
 
 1583         return MPerXdlops * NPerXdlops / 
mfma_instr.wave_size;
 
 1588     template <
class FloatA, 
class FloatB, 
class FloatC>
 
 1589     __device__ 
void Run(
const FloatA& p_a_wave, 
const FloatB& p_b_wave, FloatC& p_c_thread)
 const 
 1598             "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
 
 1601             if constexpr(!TransposeC)
 
 1603                 mfma_instr.template run<MPerXdlops, NPerXdlops>(
 
 1604                     p_a_wave[k], p_b_wave[k], p_c_thread);
 
 1608                 mfma_instr.template run<MPerXdlops, NPerXdlops>(
 
 1609                     p_b_wave[k], p_a_wave[k], p_c_thread);
 
 1621     __device__ 
void Run(
const FloatA& p_a_wave,
 
 1622                         const ScaleA& a_scale_thread,
 
 1623                         const FloatB& p_b_wave,
 
 1624                         const ScaleB& b_scale_thread,
 
 1625                         FloatC& p_c_thread)
 const 
 1628             if constexpr(!TransposeC)
 
 1630                 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
 
 1631                     p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
 
 1635                 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
 
 1636                     p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
 
 1653         const auto blk_idx =
 
 1654             threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
 
 1656         const auto blk_id = blk_idx[
I1];
 
 1657         const auto blk_td = blk_idx[
I2];
 
 1667         const auto blk_id = blk_idx[
I0];
 
 1668         const auto blk_td = blk_idx[
I1];
 
 1685         const auto blk_id = blk_idx[
I0];
 
 1686         const auto blk_td = blk_idx[
I1];
 
 1702         const auto blk_id = blk_idx[
I0];
 
 1703         const auto blk_td = blk_idx[
I1];
 
 1708         return TransposeC ? 
CIndex{n_offset, m_offset} : 
CIndex{m_offset, n_offset};
 
 1715         const auto blk_id = blk_idx[
I0];
 
 1716         const auto blk_td = blk_idx[
I1];
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
MfmaInstr
Definition: xdlops_gemm.hpp:42
 
@ mfma_f32_32x32x64f8f6f4
 
@ mfma_scale_f32_32x32x64f8f6f4
 
@ mfma_f32_16x16x16bf16_1k
 
@ mfma_scale_f32_16x16x128f8f6f4
 
@ mfma_f32_16x16x32bf8bf8
 
@ mfma_f32_16x16x128f8f6f4
 
@ mfma_f32_32x32x16bf8bf8
 
@ mfma_f32_32x32x8bf16_1k
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
 
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:416
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:297
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
Definition: xdlops_gemm.hpp:942
 
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1350
 
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1376
 
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1388
 
static constexpr auto GetMfma()
 
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1343
 
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1382
 
Definition: sequence.hpp:43
 
Definition: xdlops_gemm.hpp:1399
 
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:1739
 
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1418
 
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1680
 
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1643
 
static constexpr auto I2
Definition: xdlops_gemm.hpp:1402
 
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1410
 
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1641
 
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:1743
 
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1469
 
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1412
 
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1662
 
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:1724
 
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:1711
 
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1586
 
static constexpr __device__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1581
 
static constexpr auto I5
Definition: xdlops_gemm.hpp:1405
 
static constexpr auto I3
Definition: xdlops_gemm.hpp:1403
 
static constexpr auto I0
Definition: xdlops_gemm.hpp:1400
 
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1621
 
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1435
 
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1545
 
static constexpr auto I1
Definition: xdlops_gemm.hpp:1401
 
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:1742
 
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:1741
 
static constexpr auto I4
Definition: xdlops_gemm.hpp:1404
 
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1589
 
static constexpr auto mfma
Definition: xdlops_gemm.hpp:1732
 
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:1698
 
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1513
 
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:1745
 
Definition: integral_constant.hpp:20
 
Definition: amd_xdlops.hpp:1202
 
Definition: amd_xdlops.hpp:303
 
Definition: amd_xdlops.hpp:193
 
Definition: amd_xdlops.hpp:70
 
Definition: amd_xdlops.hpp:269
 
Definition: amd_xdlops.hpp:1483
 
Definition: amd_xdlops.hpp:1609
 
Definition: amd_xdlops.hpp:159
 
Definition: amd_xdlops.hpp:1546
 
Definition: amd_xdlops.hpp:1420
 
Definition: amd_xdlops.hpp:207
 
Definition: amd_xdlops.hpp:56
 
Definition: amd_xdlops.hpp:331
 
Definition: amd_xdlops.hpp:249
 
Definition: amd_xdlops.hpp:1451
 
Definition: amd_xdlops.hpp:1577
 
Definition: amd_xdlops.hpp:139
 
Definition: amd_xdlops.hpp:1514
 
Definition: amd_xdlops.hpp:1388
 
Definition: amd_xdlops.hpp:15
 
Definition: amd_xdlops.hpp:42
 
Definition: amd_xdlops.hpp:317
 
Definition: amd_xdlops.hpp:112
 
Definition: amd_xdlops.hpp:481
 
Definition: amd_xdlops.hpp:289
 
Definition: amd_xdlops.hpp:179
 
Definition: amd_xdlops.hpp:84
 
Definition: amd_xdlops.hpp:221
 
Definition: amd_xdlops.hpp:461
 
Definition: amd_xdlops.hpp:364
 
Definition: amd_xdlops.hpp:442
 
Definition: amd_xdlops.hpp:403
 
Definition: amd_xdlops.hpp:423
 
Definition: amd_xdlops.hpp:383
 
Definition: amd_xdlops.hpp:345
 
Definition: amd_xdlops.hpp:886
 
Definition: amd_xdlops.hpp:666
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:854
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:432
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:300
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:167
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:410
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:718
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:806
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:278
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:762
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:674
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:322
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:145
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:476
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:366
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:696
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:784
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:256
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:740
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:652
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:101
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:123
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:454
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:212
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:830
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:388
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:234
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:190
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:344
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:630
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:520
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:564
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:608
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:542
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:586
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:498
 
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:923
 
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:886
 
Definition: xdlops_gemm.hpp:83
 
Definition: functional2.hpp:33