Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp Source File

gridwise_2d_reduction_multiblock.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
13 
14 namespace ck {
15 
16 template <typename GridwiseReduction,
17  bool OutputIndex,
18  bool HaveIndexInput,
19  typename InDataType,
20  typename OutDataType,
21  typename AccDataType,
22  typename IndexDataType,
23  typename InGridDesc_M_K,
24  typename OutGridDesc_M,
25  typename InElementwiseOperation,
26  typename AccElementwiseOperation>
27 __global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
28  const OutGridDesc_M out_grid_desc_m,
29  const InElementwiseOperation in_elementwise_op,
30  const AccElementwiseOperation acc_elementwise_op,
31  index_t block_group_size,
32  index_t num_k_block_tile_iteration,
33  AccDataType alpha,
34  const InDataType* const __restrict__ p_in_value_global,
35  const IndexDataType* const __restrict__ p_in_index_global,
36  AccDataType beta,
37  OutDataType* const __restrict__ p_out_value_global,
38  IndexDataType* const __restrict__ p_out_index_global)
39 {
40  if constexpr(!OutputIndex)
41  {
42  (void)p_in_index_global;
43  (void)p_out_index_global;
44 
45  GridwiseReduction::Run(in_grid_desc_m_k,
46  out_grid_desc_m,
47  in_elementwise_op,
48  acc_elementwise_op,
49  block_group_size,
50  num_k_block_tile_iteration,
51  alpha,
52  p_in_value_global,
53  beta,
54  p_out_value_global);
55  }
56  else
57  {
58  GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
59  out_grid_desc_m,
60  in_elementwise_op,
61  acc_elementwise_op,
62  num_k_block_tile_iteration,
63  alpha,
64  p_in_value_global,
65  p_in_index_global,
66  beta,
67  p_out_value_global,
68  p_out_index_global);
69  };
70 };
71 
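// The gridwise reduction below reduces a 2D (M, K) input tensor to a 1D (M) output.
// The launch grid is split into groups of block_group_size workgroups per M tile:
//   blkgroup_id    = block_id / block_group_size  (which M tile this workgroup handles)
//   block_local_id = block_id % block_group_size  (which K slice of that tile it reduces)
// Each workgroup reduces K_BlockTileSize * num_k_block_tile_iteration elements along K;
// the partial results of one group are combined in global memory according to
// OutMemoryDataOperation (e.g. accumulated atomically, or gathered by a follow-up pass).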
72 template <typename InDataType,
73  typename OutDataType,
74  typename AccDataType,
75  typename IndexDataType,
76  typename InGridDesc_M_K,
77  typename OutGridDesc_M,
78  typename ReduceOperation,
79  typename InElementwiseOperation,
80  typename AccElementwiseOperation,
81  InMemoryDataOperationEnum OutMemoryDataOperation,
82  bool PropagateNan,
83  index_t BlockSize,
84  index_t MThreadClusterSize,
85  index_t KThreadClusterSize,
86  index_t MThreadSliceSize,
87  index_t KThreadSliceSize,
88  index_t InSrcVectorDim,
89  index_t InSrcVectorSize,
90  index_t OutDstVectorSize>
91 struct GridwiseReduction_mk_to_m_multiblock
92 {
93  static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
94  (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
95  (MThreadSliceSize % OutDstVectorSize == 0),
96  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
97 
98  static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
99 
99 
100  using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
101 
102  using ThreadBufferDimAccessOrder =
103  typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
104 
105  using ThreadClusterArrangeOrder =
106  typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
107 
108  static constexpr auto thread_cluster_desc =
109  make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
110 
111  using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
112  make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
113  using ThreadReduceDstDesc_M =
114  decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
115 
116  using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
117  BlockSize,
118  ThreadClusterLengths_M_K,
119  ThreadClusterArrangeOrder,
120  ReduceOperation,
121  PropagateNan>;
122 
123  using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
124  ThreadReduceSrcDesc_M_K,
125  ThreadReduceDstDesc_M,
126  ReduceOperation,
127  PropagateNan>;
128 
129  using PassThroughOp = tensor_operation::element_wise::PassThrough;
130 
131  static constexpr auto I0 = Number<0>{};
132  static constexpr auto I1 = Number<1>{};
133 
134  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
135  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
136 
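// Example (hypothetical configuration): with MThreadClusterSize = 8, KThreadClusterSize = 32
// (BlockSize = 8 * 32 = 256), MThreadSliceSize = 4 and KThreadSliceSize = 8, one block tile
// covers M_BlockTileSize = 8 * 4 = 32 output rows and reduces K_BlockTileSize = 32 * 8 = 256
// elements along K per tile iteration.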
138 
139  __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
140  const OutGridDesc_M& out_grid_desc_m,
141  const InElementwiseOperation& in_elementwise_op,
142  const AccElementwiseOperation& acc_elementwise_op,
143  index_t block_group_size,
144  index_t num_k_block_tile_iteration,
145  AccDataType alpha,
146  const InDataType* const __restrict__ p_in_value_global,
147  AccDataType beta,
148  OutDataType* const __restrict__ p_out_value_global)
149  {
150  const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
151 
152  // LDS
153  __shared__ AccDataType p_reduce_work_buffer[BlockSize];
154 
155  const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
156  p_in_value_global,
157  in_grid_desc_m_k.GetElementSpaceSize(),
158  ReduceOperation::template GetIdentityValue<InDataType>());
159  auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
160  p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
161 
162  auto reduce_work_buf =
163  make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
164 
165  StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
166  in_thread_buf;
167 
168  StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
169 
170  static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
171 
172  const index_t thread_local_id = get_thread_local_1d_id();
173  const index_t block_global_id = get_block_1d_id();
174  const index_t blkgroup_id = block_global_id / block_group_size;
175  const index_t block_local_id = block_global_id % block_group_size;
176 
177  const auto thread_cluster_idx =
178  thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
179 
180  const auto thread_m_cluster_id = thread_cluster_idx[I0];
181  const auto thread_k_cluster_id = thread_cluster_idx[I1];
182 
183  const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
184 
185  using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
186  constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
187  make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
188 
189  auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
190  AccDataType,
191  InGridDesc_M_K,
192  decltype(thread_buffer_desc),
193  ThreadBufferLengths,
194  ThreadBufferDimAccessOrder,
195  InSrcVectorDim,
196  InSrcVectorSize,
197  1,
198  false>(
199  in_grid_desc_m_k,
200  make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
201  block_local_id * reduceSizePerBlock +
202  thread_k_cluster_id * KThreadSliceSize));
203 
204  constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
205 
206  index_t reducedTiles = 0;
207  do
208  {
209  threadwise_src_load.Run(in_grid_desc_m_k,
210  in_global_val_buf,
211  thread_buffer_desc,
212  make_tuple(I0, I0),
213  in_thread_buf);
214 
215  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
216  // do element-wise pre-reduction operation
217  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
218  constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
219  in_elementwise_op(in_thread_buf(Number<offset>{}),
220  in_thread_buf(Number<offset>{}));
221  });
222  });
223 
224  ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
225 
226  threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
227 
228  reducedTiles++;
229  } while(reducedTiles < num_k_block_tile_iteration);
230 
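// Each thread now holds MThreadSliceSize partial results accumulated over its K slices;
// the blockwise reduction below combines them across the thread cluster through the LDS
// work buffer, so only threads with thread_k_cluster_id == 0 carry final values to write out.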
231  constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
232 
233  static_for<0, MThreadSliceSize, 1>{}(
234  [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
235 
236  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
237  if(thread_k_cluster_id == 0)
238  {
239  acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
240 
241  accu_value_buf(I) *= alpha;
242  }
243  });
244 
245  if(thread_k_cluster_id == 0)
246  {
247  if(!float_equal_zero{}(beta))
248  {
249  StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
250  priorDstValueBuf;
251 
252  auto threadwise_dst_load =
253  ThreadwiseTensorSliceTransfer_v2<OutDataType,
254  OutDataType,
255  OutGridDesc_M,
256  decltype(reduced_data_desc),
257  Sequence<MThreadSliceSize>,
258  Sequence<0>,
259  0,
260  OutDstVectorSize,
261  1,
262  false>(
263  out_grid_desc_m,
264  make_multi_index(blkgroup_id * M_BlockTileSize +
265  thread_m_cluster_id * MThreadSliceSize));
266 
267  threadwise_dst_load.Run(out_grid_desc_m,
268  out_global_val_buf,
269  reduced_data_desc,
270  make_tuple(I0),
271  priorDstValueBuf);
272 
273  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
274  accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
275  });
276  };
277 
278  auto threadwise_dst_store =
279  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
280  OutDataType,
281  decltype(reduced_data_desc),
282  OutGridDesc_M,
283  PassThroughOp,
284  Sequence<MThreadSliceSize>,
285  Sequence<0>,
286  0,
287  OutDstVectorSize,
288  OutMemoryDataOperation,
289  1,
290  true>(
291  out_grid_desc_m,
292  make_multi_index(blkgroup_id * M_BlockTileSize +
293  thread_m_cluster_id * MThreadSliceSize),
294  PassThroughOp{});
295 
296  threadwise_dst_store.Run(reduced_data_desc,
297  make_tuple(I0),
298  accu_value_buf,
299  out_grid_desc_m,
300  out_global_val_buf);
301  }
302  };
303 
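// For each output m handled by this workgroup, Run() above produces
//   out[m] = alpha * acc_op(reduce_k(in_op(in[m, k]))) + beta * prior_out[m]
// over its K slice. RunWithIndex() below additionally tracks the index of the winning
// element (e.g. for argmax/argmin-style reductions): when HaveIndexInput is true the
// indices are loaded from p_in_index_global (an earlier stage already produced them),
// otherwise they are generated on the fly from the position along the K dimension.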
304  template <bool HaveIndexInput>
305  __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
306  const OutGridDesc_M& out_grid_desc_m,
307  const InElementwiseOperation in_elementwise_op,
308  const AccElementwiseOperation acc_elementwise_op,
309  index_t num_k_block_tile_iteration,
310  AccDataType alpha,
311  const InDataType* const __restrict__ p_in_value_global,
312  const IndexDataType* const __restrict__ p_in_index_global,
313  AccDataType beta,
314  OutDataType* const __restrict__ p_out_value_global,
315  IndexDataType* const __restrict__ p_out_index_global)
316  {
317  using BlockwiseReduceWithIndex =
318  PartitionedBlockwiseReductionWithIndex<AccDataType,
319  IndexDataType,
320  BlockSize,
321  ThreadClusterLengths_M_K,
322  ThreadClusterArrangeOrder,
323  ReduceOperation,
324  PropagateNan>;
325 
326  using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
327  ReduceOperation,
328  AccDataType,
329  IndexDataType>;
330 
331  (void)in_elementwise_op;
332 
333  // LDS
334  __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
335  __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
336 
337  const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
338 
339  const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
340  p_in_value_global,
341  in_grid_desc_m_k.GetElementSpaceSize(),
342  ReduceOperation::template GetIdentityValue<InDataType>());
343  const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
344  p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
345  auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
346  p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
347  auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
348  p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
349 
350  auto reduce_work_val_buf =
351  make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
352  auto reduce_work_idx_buf =
353  make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
354 
355  StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
356  in_thread_val_buf;
357 
358  StaticBuffer<AddressSpaceEnum::Vgpr,
359  IndexDataType,
360  MThreadSliceSize * KThreadSliceSize,
361  true>
362  in_thread_idx_buf;
363 
364  StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
365  StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
366 
367  const index_t thread_local_id = get_thread_local_1d_id();
368  const index_t block_global_1d_id = get_block_1d_id();
369 
370  const auto thread_cluster_idx =
371  thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
372 
373  const auto thread_m_cluster_id = thread_cluster_idx[I0];
374  const auto thread_k_cluster_id = thread_cluster_idx[I1];
375 
376  using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
377  constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
378  make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
379 
380  auto threadwise_src_val_load =
381  ThreadwiseTensorSliceTransfer_v2<InDataType,
382  AccDataType,
383  InGridDesc_M_K,
384  decltype(thread_buffer_desc),
385  ThreadBufferLengths,
386  ThreadBufferDimAccessOrder,
387  InSrcVectorDim,
388  InSrcVectorSize,
389  1,
390  false>(
391  in_grid_desc_m_k,
392  make_multi_index(block_global_1d_id * M_BlockTileSize +
393  thread_m_cluster_id * MThreadSliceSize,
394  thread_k_cluster_id * KThreadSliceSize));
395 
396  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
397  accu_value_buf(I) = identityVal;
398  accu_index_buf(I) = 0;
399  });
400 
401  constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
402 
403  index_t reducedTiles = 0;
404 
405  if constexpr(HaveIndexInput)
406  {
407  auto threadwise_src_idx_load =
408  ThreadwiseTensorSliceTransfer_v2<IndexDataType,
409  IndexDataType,
410  InGridDesc_M_K,
411  decltype(thread_buffer_desc),
412  ThreadBufferLengths,
413  ThreadBufferDimAccessOrder,
414  InSrcVectorDim,
415  InSrcVectorSize,
416  1,
417  false>(
418  in_grid_desc_m_k,
419  make_multi_index(block_global_1d_id * M_BlockTileSize +
420  thread_m_cluster_id * MThreadSliceSize,
421  thread_k_cluster_id * KThreadSliceSize));
422 
423  do
424  {
425  // load the thread slice
426  threadwise_src_val_load.Run(in_grid_desc_m_k,
427  in_global_val_buf,
428  thread_buffer_desc,
429  make_tuple(I0, I0),
430  in_thread_val_buf);
431  threadwise_src_idx_load.Run(in_grid_desc_m_k,
432  in_global_idx_buf,
433  thread_buffer_desc,
434  make_tuple(I0, I0),
435  in_thread_idx_buf);
436 
437  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
438  AccDataType tmpValue = identityVal;
439  IndexDataType tmpIndex = 0;
440 
441  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
442  constexpr auto offset =
443  thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
444 
445  AccumulationWithIndex::Calculate(tmpValue,
446  in_thread_val_buf[Number<offset>{}],
447  tmpIndex,
448  in_thread_idx_buf[Number<offset>{}]);
449  });
450 
451  BlockwiseReduceWithIndex::Reduce(
452  reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
453 
454  AccumulationWithIndex::Calculate(
455  accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
456  });
457 
458  threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
459  threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
460 
461  reducedTiles++;
462  } while(reducedTiles < num_k_block_tile_iteration);
463  }
464  else
465  {
466  index_t indexOffset = 0;
467 
468  do
469  {
470  // load the thread slice
471  threadwise_src_val_load.Run(in_grid_desc_m_k,
472  in_global_val_buf,
473  thread_buffer_desc,
474  make_tuple(I0, I0),
475  in_thread_val_buf);
476 
477  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
478  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
479  constexpr auto offset =
480  thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
481 
482  // initialize the indices of the values this thread is about to reduce
483  in_thread_idx_buf(Number<offset>{}) =
484  indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
485 
486  // do element-wise pre-reduction operation
487  in_elementwise_op(in_thread_val_buf(Number<offset>{}),
488  in_thread_val_buf(Number<offset>{}));
489  });
490 
491  AccDataType tmpValue = identityVal;
492  IndexDataType tmpIndex = 0;
493 
494  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
495  constexpr auto offset =
496  thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
497 
498  AccumulationWithIndex::Calculate(tmpValue,
499  in_thread_val_buf[Number<offset>{}],
500  tmpIndex,
501  in_thread_idx_buf[Number<offset>{}]);
502  });
503 
504  BlockwiseReduceWithIndex::Reduce(
505  reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
506 
507  AccumulationWithIndex::Calculate(
508  accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
509  });
510 
511  threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
512 
513  indexOffset += K_BlockTileSize;
514  reducedTiles++;
515  } while(reducedTiles < num_k_block_tile_iteration);
516  };
517 
518  constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
519 
520  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
521  if(thread_k_cluster_id == 0)
522  {
523  // for indexed operations, acc_elementwise_op should do nothing
524  acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
525 
526  accu_value_buf(I) *= alpha;
527  }
528  });
529 
530  if(thread_k_cluster_id == 0)
531  {
532  if(!float_equal_zero{}(beta))
533  {
534  StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
535  priorDstValueBuf;
536 
537  auto threadwise_dst_load =
538  ThreadwiseTensorSliceTransfer_v2<OutDataType,
539  OutDataType,
540  OutGridDesc_M,
541  decltype(reduced_data_desc),
542  Sequence<MThreadSliceSize>,
543  Sequence<0>,
544  0,
545  OutDstVectorSize,
546  1,
547  true>(
548  out_grid_desc_m,
549  make_multi_index(block_global_1d_id * M_BlockTileSize +
550  thread_m_cluster_id * MThreadSliceSize));
551 
552  threadwise_dst_load.Run(out_grid_desc_m,
553  out_global_val_buf,
554  reduced_data_desc,
555  make_tuple(I0),
556  priorDstValueBuf);
557 
558  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
559  accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
560  });
561  };
562 
563  auto threadwise_dst_val_store =
564  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
565  OutDataType,
566  decltype(reduced_data_desc),
567  OutGridDesc_M,
568  PassThroughOp,
569  Sequence<MThreadSliceSize>,
570  Sequence<0>,
571  0,
572  OutDstVectorSize,
573  InMemoryDataOperationEnum::Set,
574  1,
575  true>(
576  out_grid_desc_m,
577  make_multi_index(block_global_1d_id * M_BlockTileSize +
578  thread_m_cluster_id * MThreadSliceSize),
579  PassThroughOp{});
580 
581  auto threadwise_dst_idx_store =
582  ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
583  IndexDataType,
584  decltype(reduced_data_desc),
585  OutGridDesc_M,
586  PassThroughOp,
587  Sequence<MThreadSliceSize>,
588  Sequence<0>,
589  0,
590  OutDstVectorSize,
591  InMemoryDataOperationEnum::Set,
592  1,
593  true>(
594  out_grid_desc_m,
595  make_multi_index(block_global_1d_id * M_BlockTileSize +
596  thread_m_cluster_id * MThreadSliceSize),
597  PassThroughOp{});
598 
599  threadwise_dst_val_store.Run(reduced_data_desc,
600  make_tuple(I0),
601  accu_value_buf,
602  out_grid_desc_m,
603  out_global_val_buf);
604  threadwise_dst_idx_store.Run(reduced_data_desc,
605  make_tuple(I0),
606  accu_index_buf,
607  out_grid_desc_m,
608  out_global_idx_buf);
609  }
610  };
611 };
612 
613 } // namespace ck
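
For orientation, the kernel above reduces each row of an (M, K) tensor to a single value and
blends it with the previous output. The sketch below is a plain host-side C++ reference of the
value-only path (Run) under simplifying assumptions: an Add reduction, pass-through elementwise
operations, and a dense row-major input. The function name and types are illustrative and are
not part of the Composable Kernel API.

// reference_reduce_mk_to_m: out[m] = alpha * acc_op(reduce_k(in_op(in[m][k]))) + beta * out[m]
#include <cstddef>
#include <vector>

void reference_reduce_mk_to_m(const std::vector<std::vector<float>>& in,
                              std::vector<float>& out,
                              float alpha,
                              float beta)
{
    for(std::size_t m = 0; m < in.size(); ++m)
    {
        float acc = 0.0f; // identity value of the Add reduction
        for(std::size_t k = 0; k < in[m].size(); ++k)
        {
            acc += in[m][k]; // in_elementwise_op is a pass-through here
        }
        // acc_elementwise_op is also a pass-through; alpha scales the reduced value,
        // beta blends in the prior content of the output buffer
        out[m] = alpha * acc + beta * out[m];
    }
}

The device code distributes the inner k loop over thread slices, the thread cluster and, in the
multiblock case, several workgroups; the alpha/beta epilogue runs only in threads with
thread_k_cluster_id == 0.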