/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp Source File
device_put_element_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
#include <stdexcept>
8 
16 
17 namespace ck {
18 namespace tensor_operation {
19 namespace device {
20 
21 // output[indices] = input
22 template <typename InDataType,
23  typename IndexDataType,
24  typename OutDataType,
25  typename ElementwiseOperation,
27  ck::index_t InVectorSize>
29  : public DevicePutElement<InDataType, IndexDataType, OutDataType, ElementwiseOperation, MemOp>
30 {
// Returns a padded version of the 1-D descriptor `desc_m`.
//
// The pad amount is derived from the launch shape: `loop_step` is the product
// gridSize * blockSize * InVectorSize, and `pad` is the difference between
// math::integer_least_multiple(m, loop_step) and the current length `m`.
//
// NOTE(review): listing lines 40-43, which hold the initializer of `desc_m_pad`, are
// missing from this extract. Presumably they apply a right-pad transform of `pad`
// elements to `desc_m` — confirm against the original header.
31  template <typename Desc_M>
32  static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
33  {
34  constexpr auto I0 = Number<0>{};
35 
// Length of the descriptor's dimension 0.
36  const auto m = desc_m.GetLength(I0);
37  const index_t loop_step = gridSize * blockSize * InVectorSize;
// Presumably the number of elements needed to round `m` up to a multiple of `loop_step`;
// this relies on integer_least_multiple returning the least such multiple — TODO confirm.
38  const auto pad = math::integer_least_multiple(m, loop_step) - m;
39  const auto desc_m_pad =
44  return desc_m_pad;
45  }
46 
47  static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
48  {
49  const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
50  return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
51  }
52 
53  using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1));
54 
56  InDataType,
57  IndexDataType,
58  OutDataType,
59  ElementwiseOperation,
60  MemOp,
61  InVectorSize>;
62 
// Argument bundle for the put-element operation.
//
// The constructor stores the three raw pointers, the input length and the element-wise
// functor exactly as they are passed in; no validation or copying of the pointed-to data
// happens here.
63  struct Argument : public BaseArgument
64  {
// p_input / p_indices / p_output: typed pointers kept as-is for the kernel launch.
// input_length: stored unmodified in input_length_raw_.
// elementwise_op: functor stored by value.
65  Argument(const InDataType* p_input,
66  const IndexDataType* p_indices,
67  OutDataType* p_output,
68  index_t input_length,
69  ElementwiseOperation elementwise_op)
70  : p_input_{p_input},
71  p_indices_{p_indices},
72  p_output_{p_output},
73  input_length_raw_{input_length},
74  elementwise_op_{elementwise_op},
// The block size is fixed at 256; the constructor offers no way to override it.
75  blockSize_{256}
76  {
77  }
78 
79  const InDataType* p_input_;
80  const IndexDataType* p_indices_;
81  OutDataType* p_output_;
// NOTE(review): listing lines 82 and 84 are missing from this extract; they presumably
// declare input_length_raw_ and blockSize_, which the constructor initializes above.
83  ElementwiseOperation elementwise_op_;
85  };
86 
87  struct Invoker : public BaseInvoker
88  {
// Launches kernel_put_element_1d for `arg` and returns whatever launch_and_time_kernel
// reports for that launch.
//
// arg:           pointers, input length, block size and functor to pass to the kernel.
// stream_config: forwarded to getAvailableComputeUnitCount and launch_and_time_kernel.
89  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
90  {
// The grid size is taken from getAvailableComputeUnitCount, not from the input length.
91  index_t gridSize = getAvailableComputeUnitCount(stream_config);
// The descriptor is padded using the same grid and block sizes that the launch below uses.
92  InGrid1dDesc in_grid_desc =
93  MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_);
94 
// NOTE(review): listing line 96 (one template argument of kernel_put_element_1d) is
// missing from this extract — confirm the full argument list against the original header.
95  const auto kernel = kernel_put_element_1d<GridwisePutElement,
97  InDataType,
98  IndexDataType,
99  OutDataType,
100  ElementwiseOperation>;
101 
// 1-D launch: gridSize blocks of arg.blockSize_ threads; the literal 0 is passed in the
// argument position that follows the block dimension.
102  float elapsed_time = launch_and_time_kernel(stream_config,
103  kernel,
104  dim3(gridSize),
105  dim3(arg.blockSize_),
106  0,
107  in_grid_desc,
108  arg.p_input_,
109  arg.p_indices_,
110  arg.p_output_,
111  arg.elementwise_op_);
112  return elapsed_time;
113  }
114 
115  float Run(const BaseArgument* p_arg,
116  const StreamConfig& stream_config = StreamConfig{}) override
117  {
118  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
119  }
120  };
121 
122  bool IsSupportedArgument(const BaseArgument* p_arg) override
123  {
124  const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
125 
126  if(pArg->input_length_raw_ % InVectorSize != 0)
127  {
128  return false;
129  }
130  return true;
131  }
132 
133  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_input,
134  const void* p_indices,
135  void* p_output,
136  index_t input_length,
137  index_t,
138  ElementwiseOperation elementwise_op) override
139  {
140  return std::make_unique<Argument>(static_cast<const InDataType*>(p_input),
141  static_cast<const IndexDataType*>(p_indices),
142  static_cast<OutDataType*>(p_output),
143  input_length,
144  elementwise_op);
145  }
146 
147  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
148  {
149  return std::make_unique<Invoker>(Invoker{});
150  }
151 };
152 
153 } // namespace device
154 } // namespace tensor_operation
155 } // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
InMemoryDataOperationEnum
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation elementwise_op)
Definition: gridwise_put_element_1d.hpp:17
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_put_element_1d.hpp:36
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_put_element.hpp:22
Definition: device_put_element_impl.hpp:64
const IndexDataType * p_indices_
Definition: device_put_element_impl.hpp:80
Argument(const InDataType *p_input, const IndexDataType *p_indices, OutDataType *p_output, index_t input_length, ElementwiseOperation elementwise_op)
Definition: device_put_element_impl.hpp:65
ElementwiseOperation elementwise_op_
Definition: device_put_element_impl.hpp:83
index_t blockSize_
Definition: device_put_element_impl.hpp:84
index_t input_length_raw_
Definition: device_put_element_impl.hpp:82
OutDataType * p_output_
Definition: device_put_element_impl.hpp:81
const InDataType * p_input_
Definition: device_put_element_impl.hpp:79
Definition: device_put_element_impl.hpp:88
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_put_element_impl.hpp:115
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_put_element_impl.hpp:89
Definition: device_put_element_impl.hpp:30
decltype(MakeDescriptor_M(1, 1, 1)) InGrid1dDesc
Definition: device_put_element_impl.hpp:53
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_input, const void *p_indices, void *p_output, index_t input_length, index_t, ElementwiseOperation elementwise_op) override
Definition: device_put_element_impl.hpp:133
static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
Definition: device_put_element_impl.hpp:47
GridwisePutElement_1D< InGrid1dDesc, InDataType, IndexDataType, OutDataType, ElementwiseOperation, MemOp, InVectorSize > GridwisePutElement
Definition: device_put_element_impl.hpp:61
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_put_element_impl.hpp:147
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_put_element_impl.hpp:122
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
Definition: device_put_element_impl.hpp:32