/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp Source File
mfma_selector.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
11 
12 #include "mfma_traits.hpp"
13 #include "mfma_gfx9.hpp"
14 
15 namespace ck_tile::core::arch::mma {
16 
// MfmaDefaultSelector: recursive K-dimension search for a fixed BlockM x BlockN MFMA shape.
// Starting at BlockKTest, SelectedOp is the first amdgcn_mma<..., BlockKTest, ...> whose
// CandidateTraits report IsSupported; on a miss the search retries with BlockKTest / 2u.
// The BlockKTest == 1u partial specialization ends the recursion.
// NOTE(review): std::conditional_t names both branches, so the BlockKTest / 2u selector is
// instantiated even when the current candidate is supported -- every level down to 1u is
// instantiated. Nothing here rejects BlockKTest == 0u; since 0u / 2u == 0u the template would
// try to instantiate itself instead of reaching the 1u terminal case. The requires-clause in
// the TODO below is not enforced.
// NOTE(review): this rendered listing omits source lines 43 (presumably the
// `struct MfmaDefaultSelector` head), 56 (presumably the CandidateTraits alias) and 62-63
// (presumably the head of `using SelectedOp = std::conditional_t<...`) -- confirm in the header.
35 template <typename ADataType,
36  typename BDataType,
37  typename CDataType,
38  uint32_t BlockM,
39  uint32_t BlockN,
40  uint32_t BlockKTest,
41  typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
42 // TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
44 {
45  private:
46  // Candidate MFMA op for the current (BlockM, BlockN, BlockKTest) with default control flags
47  using CandidateOp =
48  amdgcn_mma<ADataType,
49  BDataType,
50  CDataType,
51  BlockM,
52  BlockN,
53  BlockKTest,
54  DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
55  CompilerTarget>;
57 
58  public:
59  // If the candidate is supported (i.e., a backend implementation exists), select it.
60  // Otherwise, retry with BlockKTest / 2u. If no larger BlockK is supported, the search ends at
61  // the BlockKTest == 1u specialization, which yields amdgcn_mma<..., 1u, ...> unchecked.
64  typename MfmaDefaultSelector<ADataType,
65  BDataType,
66  CDataType,
67  BlockM,
68  BlockN,
69  BlockKTest / 2u,
70  CompilerTarget>::SelectedOp>;
71 };
72 
84 template <typename ADataType,
85  typename BDataType,
86  typename CDataType,
87  uint32_t BlockM,
88  uint32_t BlockN,
89  typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
90 struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
91 {
92  // Default unsupported pass-through if no instruction is found
93  using SelectedOp =
94  amdgcn_mma<ADataType,
95  BDataType,
96  CDataType,
97  BlockM,
98  BlockN,
99  1u,
100  DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
101  CompilerTarget>;
102 };
103 
// gfx9 specialization of MmaDefaultSelector, enabled through
// enable_if_target_family_gfx9_t<CompilerTarget>.
// For each MFMA shape class (4x4, 16x16, 32x32) a candidate op comes from MfmaDefaultSelector's
// K search. A candidate qualifies when its traits report IsSupported and its BlockM / BlockN /
// BlockK evenly divide FragM / FragN / FragK. SelectedOp is the first qualifying candidate in
// the order 32x32, 16x16, 4x4; otherwise it is DefaultOp.
// NOTE(review): each candidate's BlockK is fixed by the K search alone (first supported K in
// 128u, 64u, 32u, ... for 16x16; 64u, 32u, ... for 32x32) and does not depend on FragK. If that
// K does not divide FragK (e.g. FragK is smaller than it), the whole shape class is rejected
// even when a smaller-K instruction of the same shape would divide FragK. Seeding the search
// with a FragK-bounded K would avoid this -- confirm whether this is intended.
// NOTE(review): this rendered listing omits source lines 138-139, 157-158, 161-163 and 181.
118 template <typename ADataType,
119  typename BDataType,
120  typename CDataType,
121  uint32_t FragM,
122  uint32_t FragN,
123  uint32_t FragK,
124  typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
125 struct MmaDefaultSelector<ADataType,
126  BDataType,
127  CDataType,
128  FragM,
129  FragN,
130  FragK,
131  CompilerTarget,
132  enable_if_target_family_gfx9_t<CompilerTarget>>
133 {
134  private:
135  // Provide the default depth-K search strategy for each class of common MFMA shapes.
136  // Start searching from the largest K dimension MFMA shape down to the smallest.
137  using CandidateOp4x4 =
140  using CandidateOp16x16 = typename MfmaDefaultSelector<ADataType,
141  BDataType,
142  CDataType,
143  16u,
144  16u,
145  128u,
146  CompilerTarget>::SelectedOp;
147  using CandidateOp32x32 = typename MfmaDefaultSelector<ADataType,
148  BDataType,
149  CDataType,
150  32u,
151  32u,
152  64u,
153  CompilerTarget>::SelectedOp;
154 
155  // Default operation triggers pass-through (used when no candidate qualifies)
156  using DefaultOp =
159 
160  // Traits for each candidate
164 
165  // Check if each candidate is supported for the given fragment sizes.
166  // For this case, we require the fragment sizes to be multiples of the MFMA shape (M, N and K).
167  static constexpr bool IsSupported4x4 =
168  CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
169  (FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
170  static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
171  (FragM % CandidateTraits16x16::BlockM == 0u) &&
172  (FragN % CandidateTraits16x16::BlockN == 0u) &&
173  (FragK % CandidateTraits16x16::BlockK == 0u);
174  static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
175  (FragM % CandidateTraits32x32::BlockM == 0u) &&
176  (FragN % CandidateTraits32x32::BlockN == 0u) &&
177  (FragK % CandidateTraits32x32::BlockK == 0u);
178 
179  public:
180  // Prefer the largest qualifying shape: 32x32, then 16x16, then 4x4, else DefaultOp.
182  IsSupported32x32,
183  CandidateOp32x32,
184  std::conditional_t<IsSupported16x16,
185  CandidateOp16x16,
186  std::conditional_t<IsSupported4x4, CandidateOp4x4, DefaultOp>>>;
187 };
188 
189 } // namespace ck_tile::core::arch::mma
Definition: amdgcn_mma.hpp:10
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned int uint32_t
Definition: stdint.h:126
Implements a default MFMA selector strategy for gfx9 target architectures. This implements the K dime...
Definition: mfma_selector.hpp:44
std::conditional_t< CandidateTraits::IsSupported, CandidateOp, typename MfmaDefaultSelector< ADataType, BDataType, CDataType, BlockM, BlockN, BlockKTest/2u, CompilerTarget >::SelectedOp > SelectedOp
Definition: mfma_selector.hpp:70
std::conditional_t< IsSupported32x32, CandidateOp32x32, std::conditional_t< IsSupported16x16, CandidateOp16x16, std::conditional_t< IsSupported4x4, CandidateOp4x4, DefaultOp > >> SelectedOp
Definition: mfma_selector.hpp:186
Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition....
Definition: mma_selector.hpp:38
amdgcn_mma< ADataType, BDataType, CDataType, FragM, FragN, FragK, void, amdgcn_target<> > SelectedOp
Definition: mma_selector.hpp:42
Reflects the template parameters and static members of a given MmaOp.
Definition: mma_traits.hpp:125
constexpr static bool IsSupported
Definition: mma_traits.hpp:147
This is the default MmaOp policy. Instances of this class are to be used as MmaOp policies....
Definition: amdgcn_mma.hpp:82