/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp Source File
gridwise_set_multiple_buffer_value.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
// kernel_multiple_buffer_set_value
//
// Device kernel (`__global__`) that handles NumBuffer buffers in one launch.
// For every buffer index iB in [0, NumBuffer) it:
//   1. places value_tuple[iB] into element 0 of a per-buffer value buffer,
//   2. wraps p_global_tuple(iB) as a dynamic buffer in the Global address
//      space, sized by grid_1d_buffer_desc_tuple[iB].GetElementSpaceSize(),
//   3. constructs a `threadwise_store` object at the multi-index built from
//      this thread's global 1-D id and calls its Run(...) with the value
//      buffer as the first buffer argument and the global buffer as the last.
// Presumably the net effect is "each thread writes value_tuple[iB] to element
// thread_global_id of buffer iB" - TODO confirm against the transfer class,
// whose name is not visible in this extract.
//
// Template parameters:
//   Grid1dBufferDescTuple - tuple type indexed by iB; each element is used as
//                           a descriptor (GetElementSpaceSize() is called on it).
//   NumBuffer             - number of buffers; must equal the size of both
//                           DataTypePointerTuple and DataTypeTuple.
//   BlockSize             - not referenced anywhere in the visible body.
//   DataTypePointerTuple  - tuple type of the destination pointers.
//   DataTypeTuple         - tuple type of the values to store.
//
// Parameters:
//   grid_1d_buffer_desc_tuple - one descriptor per destination buffer.
//   p_global_tuple            - one destination pointer per buffer.
//   value_tuple               - one value per buffer.
//
// NOTE(review): this listing carries embedded line numbers, and numbers 30,
// 42, 44, 55, 64, 69, 70 and 73 are skipped. Several statements below are
// therefore incomplete as shown; consult the original header before relying
// on them.
// NOTE(review): no bounds check on thread_global_id is visible in this
// kernel. Whether out-of-range ids are handled elsewhere (by the launch
// configuration or inside the transfer) cannot be determined from here -
// confirm.
11 template <typename Grid1dBufferDescTuple,
12  index_t NumBuffer,
13  index_t BlockSize,
14  typename DataTypePointerTuple,
15  typename DataTypeTuple>
16 __global__ void
17 kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple,
18  DataTypePointerTuple p_global_tuple,
19  DataTypeTuple value_tuple)
20 
21 {
// Both tuples must contain exactly NumBuffer elements.
22  static_assert(NumBuffer == DataTypePointerTuple::Size() && NumBuffer == DataTypeTuple::Size(),
23  "The tuple size should be same as NumBuffer!");
24 
// Compile-time per-buffer check: derives the pointee type of pointer iB and
// the type of value iB.
// NOTE(review): the opening of the static_assert (listing line 30) is missing
// from this extract; judging by its message it presumably requires the two
// types to be the same - confirm.
25  static_for<0, NumBuffer, 1>{}([&](auto iB) {
26  using DataTypePointer = remove_cvref_t<decltype(DataTypePointerTuple{}[iB])>;
27  using DataTypeFromPointer = remove_pointer_t<DataTypePointer>;
28  using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;
29 
31  "Types in tuples does not match!");
32  });
33 
34  constexpr auto I0 = Number<0>{};
35 
// Global 1-D id of this thread; used below as the destination index for
// every buffer.
36  const index_t thread_global_id = get_thread_global_1d_id();
37 
// Builds one value buffer per entry of DataTypeTuple.
// NOTE(review): the lambda's return statement (listing line 42) and the
// second argument / closing of generate_tuple (listing line 44) are missing
// from this extract.
38  auto value_buf_tuple = generate_tuple(
39  [&](auto iB) {
40  using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;
41 
43  },
45 
// Copies each value from value_tuple into its value buffer; the inner loop
// runs over [0, 1), i.e. only J = 0.
46  static_for<0, NumBuffer, 1>{}([&](auto iB) {
47  static_for<0, 1, 1>{}([&](auto J) { value_buf_tuple(iB)(J) = value_tuple[iB]; });
48  });
49 
// Wraps each destination pointer as a dynamic buffer in the Global address
// space, sized by the matching descriptor's element space size.
// NOTE(review): the second argument / closing of generate_tuple (listing
// line 55) is missing from this extract.
50  auto global_buf_tuple = generate_tuple(
51  [&](auto iB) {
52  return make_dynamic_buffer<AddressSpaceEnum::Global>(
53  p_global_tuple(iB), grid_1d_buffer_desc_tuple[iB].GetElementSpaceSize());
54  },
56 
// Descriptor passed as the first argument to Run(): packed, one dimension of
// length 1.
57  constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
58 
// For each buffer: construct the transfer object positioned at
// thread_global_id with a PassThrough element-wise op, then run it with the
// value buffer (read at index I0) and the global buffer.
// NOTE(review): the transfer's class name and several of its template
// arguments (listing lines 64, 69, 70 and 73) are missing from this extract,
// so the meaning of the remaining literal arguments (0, 1, 1, true) cannot be
// stated from here.
59  static_for<0, NumBuffer, 1>{}([&](auto iB) {
60  using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;
61  using PassThroughOp = tensor_operation::element_wise::PassThrough;
62 
63  auto threadwise_store =
65  DataType,
66  decltype(val_buff_desc),
67  decltype(Grid1dBufferDescTuple{}[iB]),
68  PassThroughOp,
71  0,
72  1,
74  1,
75  true>(
76  grid_1d_buffer_desc_tuple[iB], make_multi_index(thread_global_id), PassThroughOp{});
77 
78  threadwise_store.Run(val_buff_desc,
79  make_tuple(I0),
80  value_buf_tuple(iB),
81  grid_1d_buffer_desc_tuple[iB],
82  global_buf_tuple(iB));
83  });
// The ';' after the closing brace is a redundant empty declaration.
84 };
85 
86 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:303
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__global__ void kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple, DataTypePointerTuple p_global_tuple, DataTypeTuple value_tuple)
Definition: gridwise_set_multiple_buffer_value.hpp:17
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:18
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
Definition: sequence.hpp:43
Definition: static_buffer.hpp:16
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: unary_element_wise_operation.hpp:241