/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp Source File
topk_softmax_warp_per_row_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include <string>
8 #include <type_traits>
9 
10 namespace ck_tile {
11 
// Compile-time problem description for the top-k softmax pipeline: it takes the
// element types and tuning knobs as template parameters and derives, purely at
// compile time, how rows are split across the lanes of a warp and the warps of a block.
//
// Template parameters:
//   InputType_     - input element type; its size sets the default BytesPerIssue_.
//   WeightType_    - weight element type (not used by the constants visible below).
//   IndexType_     - index element type (not used by the constants visible below).
//   Experts_       - row length in elements; must be a multiple of VectorSize.
//   IssuesPerCol_  - multiplier applied to the rows a warp covers (default 2).
//   BytesPerIssue_ - bytes per issue; must be a multiple of sizeof(InputType).
//   LaunchType_    - 0 selects streaming, >0 selects persistent with that occupancy.
//   BlockSize_     - block size (default 256); divided by WarpSize below.
//
// NOTE(review): the rendered listing skips header lines 20 and 23-25 (numbering jumps
// 19 -> 21 and 22 -> 26). The struct declaration and the InputType / WeightType /
// IndexType aliases that the cross-reference index places at lines 23-25 are therefore
// not visible here; InputType below presumably is remove_cvref_t<InputType_> - confirm
// against the original header.
12 template <typename InputType_,
13  typename WeightType_,
14  typename IndexType_,
15  index_t Experts_,
16  index_t IssuesPerCol_ = 2, // issues along the column direction; default keeps block_reduce() OK
17  index_t BytesPerIssue_ = sizeof(InputType_),
18  index_t LaunchType_ = 0, // 0: streaming, >0: persistent, value is the occupancy
19  index_t BlockSize_ = 256>
21 {
22  // TODO: this kernel only supports warp-per-row
26 
  // Re-export the template arguments as named constants.
27  static constexpr index_t LaunchType = LaunchType_;
28  static constexpr index_t Experts = Experts_;
29  static constexpr index_t BytesPerIssue = BytesPerIssue_;
30  static constexpr index_t IssuesPerCol = IssuesPerCol_;
31  static constexpr index_t BlockSize = BlockSize_;
32  static constexpr index_t WarpSize = get_warp_size();
33 
  // An issue must hold a whole number of input elements.
34  static_assert(BytesPerIssue % sizeof(InputType) == 0);
  // Number of input elements carried by one issue.
35  static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
  // A row must split into whole vectors.
36  static_assert(Experts % VectorSize == 0);
  // One lane per vector of the row, capped at the warp size.
37  static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
  // Rows must tile the warp evenly.
38  static_assert(WarpSize % LanesPerRow == 0);
  // Rows a warp covers in one column issue.
39  static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
  // Total rows per warp across all column issues.
40  static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
  // Vectors each lane handles per row; evaluates to 1 whenever Experts / VectorSize <= WarpSize.
  // NOTE(review): integer division - nothing visible here asserts that
  // Experts is a multiple of LanesPerRow * VectorSize when the cap applies.
41  static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
42 
  // NOTE(review): integer division - no assert visible here that BlockSize is a multiple of WarpSize.
43  static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
  // Rows processed by one block.
44  static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
45 };
46 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
Definition: topk_softmax_warp_per_row_problem.hpp:21
remove_cvref_t< IndexType_ > IndexType
Definition: topk_softmax_warp_per_row_problem.hpp:25
remove_cvref_t< InputType_ > InputType
Definition: topk_softmax_warp_per_row_problem.hpp:23
remove_cvref_t< WeightType_ > WeightType
Definition: topk_softmax_warp_per_row_problem.hpp:24