template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // all participating tensors must have the same per-thread buffer size
    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}
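
// Usage sketch (illustrative, not part of the original header): the functor receives one
// element reference per tensor argument, in the order the tensors are passed, so an
// in-place accumulation across two tiles is a one-liner. example_accumulate_tile is a
// hypothetical helper; both tiles are assumed to hold float and share one tile distribution.
template <typename AccDstrTensor, typename InDstrTensor>
CK_TILE_DEVICE void example_accumulate_tile(AccDstrTensor& acc_tile, const InDstrTensor& in_tile)
{
    // acc binds to acc_tile's element, x to in_tile's element at the same buffer index
    tile_elementwise_inout([](auto& acc, const auto& x) { acc += x; }, acc_tile, in_tile);
}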
 
template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    // the output element type is whatever the functor returns for the input element types
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // reuse the tile distribution of the first input for the result tensor
    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}
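
// Usage sketch (illustrative, not part of the original header): the output element type is
// deduced from the functor's return type and the result reuses the first input's tile
// distribution, so a cast or activation produces a brand-new distributed tensor.
// example_relu_tile is a hypothetical helper; x_tile is assumed to hold float.
template <typename InDstrTensor>
CK_TILE_DEVICE auto example_relu_tile(const InDstrTensor& x_tile)
{
    // returns a float tile because the lambda returns float
    return tile_elementwise_in([](float x) { return x > 0.f ? x : 0.f; }, x_tile);
}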
 
/// @brief Template function that "unpacks" a tuple and applies an element-wise operation.
template <typename InElementFunc, typename Tuple, size_t... I>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
                                                  const Tuple& t,
                                                  std::index_sequence<I...>)
{
    return tile_elementwise_inout(in_element_func, t.at(number<I>{})...);
}

/// @brief Overload that generates the index sequence from the tuple size.
template <typename InElementFunc, typename Tuple>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
                                                  const Tuple& t)
{
    static constexpr auto size = Tuple::size();
    return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
}
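
// Usage sketch (illustrative, not part of the original header): the _unpack overloads let a
// fused kernel keep a variable number of distributed tensors in a ck_tile tuple and still
// reuse the variadic element-wise API. example_scale_tuple is a hypothetical helper; the
// tuple is assumed to hold (references to) float tiles with identical distributions.
template <typename TileRefTuple>
CK_TILE_DEVICE void example_scale_tuple(TileRefTuple& tile_refs, float alpha)
{
    // each xs is one element reference per tensor stored in the tuple
    tile_elementwise_inout_unpack([&alpha](auto&... xs) { ((xs *= alpha), ...); }, tile_refs);
}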
 
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}

// no-op when the destination is a null_tensor
template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}
 
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
    using elem_type             = typename DstrTensors::DataType;
    constexpr index_t elem_size = sizeof(elem_type);

    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;

    // fast path: a zero fill of a dword-multiple buffer can be written one dword at a time
    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
    {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
        auto& buffer = dstr_tensor.get_thread_buffer();

        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
            if constexpr(elem_size == 1)
            {
                // 4 sub-dword elements per dword write
                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};

                buffer[i_write * 4 + 0] = values.x;
                buffer[i_write * 4 + 1] = values.y;
                buffer[i_write * 4 + 2] = values.z;
                buffer[i_write * 4 + 3] = values.w;
            }
            else if constexpr(elem_size == 2)
            {
                // 2 sub-dword elements per dword write
                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};

                buffer[i_write * 2 + 0] = values.x;
                buffer[i_write * 2 + 1] = values.y;
            }
            else if constexpr(elem_size == 4)
            {
                // 1 element per dword write
                constexpr elem_type value = 0;

                buffer[i_write] = value;
            }
            else
            {
                static_assert(false, "type not supported");
            }
        });
#else
        // reinterpret the thread buffer as an array of dwords and fill it directly
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.at(i) = v;
#endif
    }
    else
    {
        // generic path: element-wise broadcast of v converted to the tile's data type
        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type>(v); }, dstr_tensor);
    }
}
 
template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, number<0>{});
}
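
// Usage sketch (illustrative, not part of the original header): set_tile broadcasts a scalar
// into every element a thread owns (converting it to the tile's DataType), and clear_tile is
// the zero-fill shorthand built on the number<0> overload above. example_init_accumulator is
// a hypothetical helper for a float accumulator tile.
template <typename AccDstrTensor>
CK_TILE_DEVICE void example_init_accumulator(AccDstrTensor& acc_tile, bool to_identity)
{
    if(to_identity)
        set_tile(acc_tile, 1.0f); // runtime scalar, converted via type_convert
    else
        clear_tile(acc_tile);     // compile-time zero fill, may use the dword fast path
}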
 
namespace impl {

template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx94__)
    // packed f32 -> fp8 conversion, 4 elements at a time
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    // __builtin_amdgcn_cvt_pk_fp8_f32 takes an "old" destination operand; passing an
    // intentionally uninitialized value avoids an extra v_mov, hence the warning suppression
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        uint32_t dummy_old;

        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i.value + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i.value + 1>{}],
            dummy_old,
            false);

        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i.value + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i.value + 3>{}],
            dummy_old,
            false);

        // byte-select the two packed results into one dword holding 4 fp8 values
        constexpr int32_t m0 = 0x05040100;

        using vec_t = array<OutDataType, 4>;
        vec_t d     = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // fallback: plain element-wise conversion
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
 
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 2 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    // packed round-to-zero f32 -> f16 conversion, two elements per iteration
    for(index_t i = 0; i < thread_buffer_size_pk; i++)
    {
        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);

        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
    }

    return out_dstr_tensor;
#else
    // fallback: plain element-wise conversion
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
 
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// assumes the source or destination element type (or both) is smaller than one dword;
// sub-dword elements are packed and written one dword-sized chunk at a time so the
// compiler does not fall back to per-element sub-dword stores
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    using i_type = remove_cvref_t<typename InTensor::DataType>;
    using o_type = remove_cvref_t<OutDataType>;

    constexpr index_t i_elem_bytes = sizeof(i_type);
    constexpr index_t o_elem_bytes = sizeof(o_type);
    static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);

    // number of elements that together cover one dword on the smaller-element side
    constexpr index_t bulk_size =
        (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
    static_assert(bulk_size != 0);

    // packed view of bulk_size output elements, written back with a single set_as
    using o_bulk_type = array<o_type, bulk_size>;

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();

    constexpr index_t iters = thread_buffer_size / bulk_size;
    constexpr index_t rems  = thread_buffer_size % bulk_size;

    static_for<0, iters, 1>{}([&](auto i) {
        union
        {
            o_bulk_type bulk;
            o_type data[bulk_size];
        } o_bulk{};

        static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
            o_bulk.data[ib.value] = static_cast<o_type>(
                in_dstr_tensors.get_thread_buffer()
                    .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
        });

        out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
    });

    // convert the tail elements that do not fill a whole bulk
    static_for<0, rems, 1>{}([&](auto r) {
        auto idx = number<iters * bulk_size + r>{};
        out_dstr_tensor.get_thread_buffer().at(idx) =
            static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
    });

    return out_dstr_tensor;
}
#endif
} // namespace impl
 
template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
    {
        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
    }
#if CK_TILE_USE_PK_FP16_TILE_CAST
    else if constexpr(std::is_same_v<DstType, fp16_t> &&
                      std::is_same_v<typename SrcTensor::DataType, float> &&
                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
    {
        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
    }
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
    {
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
    }
#endif
    else
    {
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>,
                                   src_tensor);
    }
}
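
// Usage sketch (illustrative, not part of the original header): cast_tile is the entry point
// for down-converting a tile; it dispatches to the packed fp8/fp16 paths or the sub-dword
// path above when the types and sizes allow, and otherwise falls back to a plain
// element-wise type_convert. example_cast_accumulator_to_fp16 is a hypothetical helper
// assuming a float accumulator tile.
template <typename CDstrTensor>
CK_TILE_DEVICE auto example_cast_accumulator_to_fp16(const CDstrTensor& c_fp32_tile)
{
    // picks a packed conversion path when the thread-buffer size is even, else the generic one
    return cast_tile<fp16_t>(c_fp32_tile);
}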
 
// no-op overload selected when any argument is a null_tensor
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}

// no-op overload selected when any argument is a null_tensor; the "result" is a null_tensor
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}
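
// Usage sketch (illustrative, not part of the original header): the overloads above make a
// null_tensor argument turn the whole call into a no-op, so optional operands (for example a
// bias that may be absent) need no separate code path. example_add_optional_bias is a
// hypothetical helper; when BiasDstrTensor is null_tensor the call below compiles away.
template <typename CDstrTensor, typename BiasDstrTensor>
CK_TILE_DEVICE void example_add_optional_bias(CDstrTensor& c_tile, const BiasDstrTensor& bias)
{
    // assumes float elements when bias is a real distributed tensor
    tile_elementwise_inout([](auto& c, const auto& b) { c += b; }, c_tile, bias);
}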
 