/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp Source File
device_sparse_embeddings_forward_layernorm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <iostream>
#include <sstream>
#include <stdexcept>
8 
16 
17 namespace ck {
18 namespace tensor_operation {
19 namespace device {
20 
21 template <typename EmbType,
22  typename IndexType,
23  typename GammaDataType,
24  typename BetaDataType,
25  typename AccDataType,
26  typename OutType,
27  typename EmbElementwiseOperation,
28  ck::index_t BlockSize,
29  ck::index_t DimClusterSize,
30  ck::index_t RowClusterSize,
31  ck::index_t DimPerBlock,
32  ck::index_t RowPerBlock,
33  ck::index_t DimThreadSize,
34  ck::index_t RowVectorSize,
35  ck::index_t NumEmbeddings>
37 {
38  static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
39  {
40  return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
41  }
42 
// Argument: host-side bundle of everything a launch of this operation reads.
//
// NOTE(review): this listing skips embedded lines 46-47, 68-69 and 72-73. The
// cross-reference section of the page shows that the constructor also takes
// `const ck::Array<EmbType*, NumEmbeddings>& p_embs` and
// `const ck::Array<IndexType*, NumEmbeddings>& p_indexs`, and that the members
// p_embs_, p_indexs_, EmbeddingDim_ and IndexLength_ are declared on those
// lines. Confirm against the real header before editing.
43  struct Argument : public BaseArgument
44  {
// Copies every parameter unchanged into the member of the same name with a
// trailing underscore, then derives grid_size_.
45  Argument(OutType* p_out,
48  const GammaDataType* p_gamma,
49  const BetaDataType* p_beta,
50  const ck::index_t EmbeddingDim,
51  const ck::index_t IndexLength,
52  const AccDataType epsilon,
53  const EmbElementwiseOperation emb_elementwise_op)
54  : p_out_(p_out),
55  p_embs_(p_embs),
56  p_indexs_(p_indexs),
57  p_gamma_(p_gamma),
58  p_beta_(p_beta),
59  EmbeddingDim_(EmbeddingDim),
60  IndexLength_(IndexLength),
61  epsilon_(epsilon),
62  emb_elementwise_op_(emb_elementwise_op)
63  {
// Ceiling division of IndexLength by the compile-time DimClusterSize.
// Invoker::Run passes this value as the launch grid dimension.
64  grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
65  }
66 
// Output buffer handed to the kernel as its first argument.
67  OutType* p_out_;
// Gamma and beta buffers, forwarded to the kernel untouched.
70  const GammaDataType* p_gamma_;
71  const BetaDataType* p_beta_;
// Epsilon value and element-wise functor, forwarded to the kernel untouched.
74  AccDataType epsilon_;
75  EmbElementwiseOperation emb_elementwise_op_;
76 
// Number of blocks to launch; computed once in the constructor.
77  size_t grid_size_;
78  };
79 
// Builds an Argument from untyped pointers and returns it as an owning
// std::unique_ptr<BaseArgument>.
//
// p_out, p_gamma and p_beta arrive as void pointers and are reinterpret_cast
// to OutType*, const GammaDataType* and const BetaDataType*; all other values
// are forwarded to the Argument constructor unchanged.
//
// NOTE(review): this listing skips embedded lines 82-83. According to the
// cross-reference section of the page they declare
// `const ck::Array<EmbType*, NumEmbeddings>& p_embs` and
// `const ck::Array<IndexType*, NumEmbeddings>& p_indexs`, which the body
// forwards below. Confirm against the real header.
80  std::unique_ptr<BaseArgument>
81  MakeArgumentPointer(void* p_out,
84  const void* p_gamma,
85  const void* p_beta,
86  ck::index_t EmbeddingDim,
87  ck::index_t IndexLength,
88  const AccDataType epsilon,
89  const EmbElementwiseOperation emb_elementwise_op)
90  {
// No validation happens here: the casts trust the caller to pass buffers of
// the element types this instance was instantiated with.
91  return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
92  p_embs,
93  p_indexs,
94  reinterpret_cast<const GammaDataType*>(p_gamma),
95  reinterpret_cast<const BetaDataType*>(p_beta),
96  EmbeddingDim,
97  IndexLength,
98  epsilon,
99  emb_elementwise_op);
100  }
101 
104  IndexType,
105  GammaDataType,
106  BetaDataType,
107  AccDataType,
108  OutType,
109  decltype(MakeOutputDescriptor(1, 1)),
110  EmbElementwiseOperation,
111  BlockSize,
112  DimClusterSize,
113  RowClusterSize,
114  DimPerBlock,
115  RowPerBlock,
116  DimThreadSize,
117  RowVectorSize,
118  NumEmbeddings>;
119 
120  struct Invoker : public BaseInvoker
121  {
// Launches the kernel for an already-typed Argument and returns the value
// reported by launch_and_time_kernel.
//
// arg:           pointers, sizes and functor to pass to the kernel.
// stream_config: forwarded to launch_and_time_kernel as its first argument.
122  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
123  {
// Output descriptor with lengths (IndexLength_, EmbeddingDim_).
124  auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
// NOTE(review): embedded line 126, which names the kernel template and its
// first template argument, is missing from this listing. The cross-reference
// section lists kernel_sparse_embeddings_forward_layernorm - presumably that is
// the kernel instantiated here; confirm against the real header.
125  const auto kernel_main =
127  EmbType,
128  IndexType,
129  GammaDataType,
130  BetaDataType,
131  AccDataType,
132  OutType,
133  decltype(out_desc),
134  EmbElementwiseOperation,
135  NumEmbeddings>;
136  float avg_time = 0;
// Launch configuration: grid_size_ blocks of BlockSize threads; the literal 0
// fills the lds_byte parameter of launch_and_time_kernel. The remaining
// arguments are passed through to the kernel.
137  avg_time += launch_and_time_kernel(stream_config,
138  kernel_main,
139  dim3(arg.grid_size_),
140  dim3(BlockSize),
141  0,
142  arg.p_out_,
143  arg.p_embs_,
144  arg.p_indexs_,
145  arg.p_gamma_,
146  arg.p_beta_,
147  out_desc,
148  arg.epsilon_,
149  arg.emb_elementwise_op_);
150 
151  return (avg_time);
152  }
153 
154  float Run(const BaseArgument* p_arg,
155  const StreamConfig& stream_config = StreamConfig{}) override
156  {
157  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
158  };
159  };
160 
161  static bool IsSupportedArgument(const Argument* p_arg)
162  {
163  return (RowPerBlock == p_arg->EmbeddingDim_);
164  }
165 
166  bool IsSupportedArgument(const BaseArgument* p_arg) override
167  {
168  return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
169  }
170 
171  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer()
172  {
173  return std::make_unique<Invoker>();
174  }
175 
176  std::string GetTypeString() const override
177  {
178  auto str = std::stringstream();
179 
180  // clang-format off
181  str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" <<
182  DimClusterSize << "x" << RowClusterSize << "_" <<
183  DimPerBlock << "x" << RowPerBlock << "_" <<
184  DimThreadSize << "x" << RowVectorSize;
185  // clang-format on
186 
187  return str.str();
188  }
189 };
190 
191 } // namespace device
192 } // namespace tensor_operation
193 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__global__ void kernel_sparse_embeddings_forward_layernorm(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc out_grid_desc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition: gridwise_sparse_embeddings_forward_layernorm.hpp:26
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
Definition: stream_config.hpp:10
Definition: gridwise_sparse_embeddings_forward_layernorm.hpp:57
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_base.hpp:76
Definition: device_sparse_embeddings_forward_layernorm.hpp:44
const GammaDataType * p_gamma_
Definition: device_sparse_embeddings_forward_layernorm.hpp:70
ck::index_t IndexLength_
Definition: device_sparse_embeddings_forward_layernorm.hpp:73
ck::Array< EmbType *, NumEmbeddings > p_embs_
Definition: device_sparse_embeddings_forward_layernorm.hpp:68
size_t grid_size_
Definition: device_sparse_embeddings_forward_layernorm.hpp:77
OutType * p_out_
Definition: device_sparse_embeddings_forward_layernorm.hpp:67
const BetaDataType * p_beta_
Definition: device_sparse_embeddings_forward_layernorm.hpp:71
ck::index_t EmbeddingDim_
Definition: device_sparse_embeddings_forward_layernorm.hpp:72
Argument(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const GammaDataType *p_gamma, const BetaDataType *p_beta, const ck::index_t EmbeddingDim, const ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition: device_sparse_embeddings_forward_layernorm.hpp:45
AccDataType epsilon_
Definition: device_sparse_embeddings_forward_layernorm.hpp:74
ck::Array< IndexType *, NumEmbeddings > p_indexs_
Definition: device_sparse_embeddings_forward_layernorm.hpp:69
EmbElementwiseOperation emb_elementwise_op_
Definition: device_sparse_embeddings_forward_layernorm.hpp:75
Definition: device_sparse_embeddings_forward_layernorm.hpp:121
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_sparse_embeddings_forward_layernorm.hpp:122
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_sparse_embeddings_forward_layernorm.hpp:154
Definition: device_sparse_embeddings_forward_layernorm.hpp:37
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const void *p_gamma, const void *p_beta, ck::index_t EmbeddingDim, ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition: device_sparse_embeddings_forward_layernorm.hpp:81
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()
Definition: device_sparse_embeddings_forward_layernorm.hpp:171
static bool IsSupportedArgument(const Argument *p_arg)
Definition: device_sparse_embeddings_forward_layernorm.hpp:161
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_sparse_embeddings_forward_layernorm.hpp:166
GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings > GridwiseSparseEmbedding
Definition: device_sparse_embeddings_forward_layernorm.hpp:118
std::string GetTypeString() const override
Definition: device_sparse_embeddings_forward_layernorm.hpp:176
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
Definition: device_sparse_embeddings_forward_layernorm.hpp:38