/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp Source File
reduce2d_default_policy.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
13 {
14  template <typename Problem>
16  {
17  using S = typename Problem::BlockShape;
20  sequence<>,
21  tuple<
28  }
29 
30  template <typename Problem>
31  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
32  {
33  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
34  typename Problem::ComputeDataType,
35  typename Problem::BlockShape>;
36  return BlockReduce2d<P_>{};
37  }
38 
39  template <typename Problem>
40  CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
41  {
42  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
43  typename Problem::ComputeDataType,
44  typename Problem::BlockShape>;
45  return BlockReduce2dSync<P_>{};
46  }
47 
48  template <typename Problem>
50  {
51  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
52  typename Problem::ComputeDataType,
53  typename Problem::BlockShape>;
55  }
56 
57  template <typename Problem>
59  {
60  if constexpr(Problem::kNeedCrossWarpSync)
61  {
62  using P_ = BlockReduce2dProblem<typename Problem::XDataType,
63  typename Problem::ComputeDataType,
64  typename Problem::BlockShape>;
65 
66  using block_reduce2d = BlockReduce2d<P_>;
67  using x_block_tile =
68  decltype(make_static_distributed_tensor<typename Problem::XDataType>(
69  MakeXBlockTileDistribution<Problem>()));
70  using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
71 
72  return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
73  }
74  else
75  {
76  return 1; // zero size arrays are an extension
77  }
78  }
79 };
80 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
Definition: block_reduce2d.hpp:200
Definition: block_reduce2d.hpp:45
Definition: block_reduce2d_problem.hpp:12
Definition: block_reduce2d.hpp:135
Definition: reduce2d_default_policy.hpp:13
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dSync()
Definition: reduce2d_default_policy.hpp:40
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: reduce2d_default_policy.hpp:58
static constexpr CK_TILE_DEVICE auto MakeXBlockTileDistribution()
Definition: reduce2d_default_policy.hpp:15
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2d()
Definition: reduce2d_default_policy.hpp:31
static constexpr CK_TILE_HOST_DEVICE auto GetBlockReduce2dCrossWarpSync()
Definition: reduce2d_default_policy.hpp:49
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192