/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp Source File
wmma_selector.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
10 
11 namespace ck_tile::core::arch::mma {
12 
// Primary template of the default WMMA selector strategy for gfx11/12 target architectures.
// It forms a candidate amdgcn_mma operation for (BlockM, BlockN, BlockKTest). If that candidate
// is supported it becomes SelectedOp; otherwise the search recurses with BlockKTest / 2u until
// the BlockKTest == 1u specialization is reached.
//
// Template parameters:
//   ADataType, BDataType, CDataType - element types forwarded to amdgcn_mma
//   BlockM, BlockN                  - block sizes forwarded unchanged to every candidate
//   BlockKTest                      - K size under test; halved on each recursion step
//   CompilerTarget                  - target type forwarded to amdgcn_mma
//
// NOTE(review): this listing omits original lines 35, 39, 51 and 57-58 (presumably the struct
// header, the CtrlFlags and CandidateTraits aliases and the opening of the SelectedOp alias).
// Confirm against the real header before relying on this text.
26 template <typename ADataType,
27  typename BDataType,
28  typename CDataType,
29  uint32_t BlockM,
30  uint32_t BlockN,
31  uint32_t BlockKTest,
32  typename CompilerTarget>
33 // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
34 // TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
36 {
37  private:
38  // By default, let's assume no special flags for WMMA
40 
41  // Define our candidate WMMA implementation for the current parameters
42  using CandidateOp = amdgcn_mma<ADataType,
43  BDataType,
44  CDataType,
45  BlockM,
46  BlockN,
47  BlockKTest,
48  CtrlFlags,
49  CompilerTarget>;
50 
52 
53  public:
54  // If the candidate is supported (e.g., a backend implementation exists), then select it.
55  // Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
56  // and fall back to the unsupported pass-through implementation.
 // NOTE(review): the terminating specialization visible in this file matches BlockKTest == 1u;
 // confirm that "BlockK=0u" above refers to the BlockK reported by the pass-through op.
59  typename WmmaDefaultSelector<ADataType,
60  BDataType,
61  CDataType,
62  BlockM,
63  BlockN,
64  BlockKTest / 2u,
65  CompilerTarget>::SelectedOp>;
66 };
67 
// Terminating specialization of WmmaDefaultSelector for BlockKTest == 1u.
// It ends the BlockKTest / 2u search of the primary template and selects the unsupported
// pass-through operation as SelectedOp.
//
// NOTE(review): this listing omits original lines 91 and 95 (presumably the CtrlFlags alias
// and the right-hand side of the SelectedOp alias). Confirm against the real header.
81 template <typename ADataType,
82  typename BDataType,
83  typename CDataType,
84  uint32_t BlockM,
85  uint32_t BlockN,
86  typename CompilerTarget>
87 // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
88 struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
89 {
90  // By default, let's assume no special flags for WMMA
92 
93  // Default unsupported pass-through if no instruction is found
94  using SelectedOp =
96 };
97 
// MmaDefaultSelector specialization enabled through enable_if_target_arch_rdna_t<CompilerTarget>.
// It asks WmmaDefaultSelector for a 16x16 WMMA operation, starting the K search at 128, and
// selects that operation only when it is supported and FragM / FragN / FragK are multiples of
// its BlockM / BlockN / BlockK. Otherwise SelectedOp is DefaultOp.
//
// Template parameters:
//   ADataType, BDataType, CDataType - element types forwarded to WmmaDefaultSelector
//   FragM, FragN, FragK             - fragment sizes checked for divisibility by the candidate's block sizes
//   CompilerTarget                  - target type; also gates this specialization
//
// NOTE(review): this listing omits original lines 143-144 and 147 (presumably the right-hand side
// of DefaultOp and the CandidateTraits16x16 alias). Confirm against the real header.
112 template <typename ADataType,
113  typename BDataType,
114  typename CDataType,
115  uint32_t FragM,
116  uint32_t FragN,
117  uint32_t FragK,
118  typename CompilerTarget>
119 // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
120 // TODO: c++20 requires
121 struct MmaDefaultSelector<ADataType,
122  BDataType,
123  CDataType,
124  FragM,
125  FragN,
126  FragK,
127  CompilerTarget,
128  enable_if_target_arch_rdna_t<CompilerTarget>>
129 {
130  private:
131  // Provide the default depth-K search strategy for each class of common WMMA shapes.
132  // Start searching from the largest K dimension WMMA shape down to the smallest.
133  using CandidateOp16x16 = typename WmmaDefaultSelector<ADataType,
134  BDataType,
135  CDataType,
136  16u,
137  16u,
138  128u,
139  CompilerTarget>::SelectedOp;
140 
141  // Default operation triggers pass-through
142  using DefaultOp =
145 
146  // Traits for each candidate
148 
149  // Check if each candidate is supported for the given fragment sizes
150  // For this case, we require the fragment sizes to be multiples of the WMMA shape
151  static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
152  (FragM % CandidateTraits16x16::BlockM == 0u) &&
153  (FragN % CandidateTraits16x16::BlockN == 0u) &&
154  (FragK % CandidateTraits16x16::BlockK == 0u);
155 
156  public:
157  // Select the largest supported WMMA operation for the given fragment shape
158  using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
159 };
160 
161 } // namespace ck_tile::core::arch::mma
Definition: amdgcn_mma.hpp:10
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned int uint32_t
Definition: stdint.h:126
Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition....
Definition: mma_selector.hpp:38
amdgcn_mma< ADataType, BDataType, CDataType, FragM, FragN, FragK, void, amdgcn_target<> > SelectedOp
Definition: mma_selector.hpp:42
Reflects the template parameters and static members of a given MmaOp.
Definition: mma_traits.hpp:125
constexpr static bool IsSupported
Definition: mma_traits.hpp:147
Implements a default WMMA selector strategy for gfx11/12 target architectures. This implements the K ...
Definition: wmma_selector.hpp:36
std::conditional_t< CandidateTraits::IsSupported, CandidateOp, typename WmmaDefaultSelector< ADataType, BDataType, CDataType, BlockM, BlockN, BlockKTest/2u, CompilerTarget >::SelectedOp > SelectedOp
Definition: wmma_selector.hpp:65
This is the default MmaOp policy. Instances of this class are to be used as MmaOp policies....
Definition: amdgcn_mma.hpp:82