14 template <
typename DataType, 
typename ComputeDataType = 
float>
 
   20                                                               bool use_1_row_sin_cos = 
false)
 
   27     assert(
static_cast<std::size_t
>(rotary_dim) <= input_bsd.
get_length(2));
 
   29     output_bsd.
ForEach([&](
auto& 
self, 
auto i) {
 
   33             self(i) = input_bsd(i);
 
   36         assert(i_d < rotary_dim);
 
   39         const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
 
   41         const ComputeDataType 
cos = type_convert<ComputeDataType>(
 
   42             interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
 
   43                         : cos_sd(i_s_cos_sin, i_d % cos_sd.
get_length(1)));
 
   44         const ComputeDataType 
sin = type_convert<ComputeDataType>(
 
   45             interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
 
   46                         : sin_sd(i_s_cos_sin, i_d % sin_sd.
get_length(1)));
 
   48         const ComputeDataType half_rotated_input = [&] {
 
   53                 const bool is_even         = (i_d % 2 == 0);
 
   54                 const index_t pos          = i_d + (is_even ? 1 : -1);
 
   55                 const ComputeDataType sign = (is_even ? -1 : 1);
 
   56                 return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
 
   60                 const index_t half_rdim    = (rotary_dim / 2);
 
   61                 const index_t pos          = (i_d + half_rdim) % rotary_dim;
 
   62                 const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
 
   63                 return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
 
   66         ComputeDataType result =
 
   67             type_convert<ComputeDataType>(input_bsd(i)) * 
cos + half_rotated_input * 
sin;
 
   69         self(i) = type_convert<DataType>(result);
 
#define CK_TILE_HOST
Definition: config.hpp:39
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST T cos(T x)
Definition: math.hpp:752
 
int32_t index_t
Definition: integer.hpp:9
 
CK_TILE_HOST T sin(T x)
Definition: math.hpp:698
 
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor< DataType > &input_bsd, const HostTensor< DataType > &cos_sd, const HostTensor< DataType > &sin_sd, bool interleaved, HostTensor< DataType > &output_bsd, bool use_1_row_sin_cos=false)
Definition: reference_batched_rotary_position_embedding.hpp:15
 
Definition: host_tensor.hpp:336
 
void ForEach(F &&f)
Definition: host_tensor.hpp:431
 
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396
 
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:388