/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mma_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mma_traits.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mma_traits.hpp Source File
mma_traits.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 #pragma once
4 #include "amdgcn_mma.hpp"
5 #include "mfma/mfma_traits.hpp"
6 #include "wmma/wmma_traits.hpp"
7 
8 namespace ck_tile::core::arch::mma {
9 
// is_mma_op_supported<MmaOp>: trait reporting whether MmaOp is supported.
// The defaulted second template parameter lets the partial specialization
// below be selected through std::enable_if_t (SFINAE).
// NOTE(review): this rendered listing omits source line 17 (the
// `struct is_mma_op_supported ...` head and its base class). The page's symbol
// list references true_type / false_type, which suggests the primary template
// derives from true_type -- confirm against the real header.
15 // TODO: c++20 template <MmaOpI MmaOp, typename = void>
16 template <typename MmaOp, typename = void>
18 {
19 };
20 
// Partial specialization chosen only when MmaOp::OpType is exactly
// `Unsupported` (std::is_same_v check inside enable_if_t).
// NOTE(review): source line 31 (the base-class clause) is also missing from
// this listing; presumably it derives from false_type -- confirm.
26 // TODO: c++20 template <MmaOpI MmaOp>
27 template <typename MmaOp>
28 // TODO: c++20 requires
29 struct is_mma_op_supported<MmaOp,
30  std::enable_if_t<std::is_same_v<typename MmaOp::OpType, Unsupported>>>
32 {
33 };
34 
39 // TODO: c++20 template <MmaOpI MmaOp>
40 template <typename MmaOp>
41 static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
42 
// TODO: c++20 template <MmaOpI MmaOp>
/// @brief Reflects the template parameters of a given MmaOp.
///
/// Only declared here. A partial specialization supplies the members for each
/// supported MmaOp family (e.g. the amdgcn_mma specialization in this header).
template <class MmaOp>
struct MmaOpParams;
51 
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L

/// @brief Concept for an MmaOpParams-like type.
///
/// Requires member types ADataType / BDataType / CDataType / CtrlFlags,
/// BlockM / BlockN / BlockK convertible to unsigned int, and a GfxTargetId
/// convertible to amdgcn_target_arch_id.
/// NOTE(review): the amdgcn_mma specialization of MmaOpParams in this header
/// exposes `CompilerTarget` (a type) and only mentions GfxTargetId in a TODO,
/// so as written it does not satisfy this concept -- confirm intended.
template <typename Params>
concept MmaOpParamsI = requires(Params params) {
    // Member types that must be present
    typename Params::ADataType;
    typename Params::BDataType;
    typename Params::CDataType;
    typename Params::CtrlFlags;

    // Compile-time block dimensions
    { Params::BlockM } -> std::convertible_to<unsigned int>;
    { Params::BlockN } -> std::convertible_to<unsigned int>;
    { Params::BlockK } -> std::convertible_to<unsigned int>;

    // Target architecture identifier
    { Params::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
};

#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
73 
86 template <typename ADataType_,
87  typename BDataType_,
88  typename CDataType_,
89  uint32_t BlockM_,
90  uint32_t BlockN_,
91  uint32_t BlockK_,
92  typename CtrlFlags_,
93  typename CompilerTarget_>
94 // TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
95 struct MmaOpParams<amdgcn_mma<ADataType_,
96  BDataType_,
97  CDataType_,
98  BlockM_,
99  BlockN_,
100  BlockK_,
101  CtrlFlags_,
102  CompilerTarget_>>
103 {
104  // Capture incoming template parameters
105  using ADataType = ADataType_;
106  using BDataType = BDataType_;
107  using CDataType = CDataType_;
108  static constexpr uint32_t BlockM = BlockM_;
109  static constexpr uint32_t BlockN = BlockN_;
110  static constexpr uint32_t BlockK = BlockK_;
111  using CtrlFlags = CtrlFlags_;
112  using CompilerTarget = CompilerTarget_;
113  // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
114 };
115 
121 template <typename MmaOp>
122 // TODO: c++20 template <MmaOpI MmaOp>
123 // TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
124 struct MmaOpTraits : public MmaOpParams<MmaOp>
125 {
126  // Capture internal MmaOp static members
127  using OpType = typename MmaOp::OpType;
128  using AVecType = typename MmaOp::AVecType;
129  using BVecType = typename MmaOp::BVecType;
130  using CVecType = typename MmaOp::CVecType;
131 
132  // Capture layout parameters
133  static constexpr index_t kAMBlock = MmaOp::kAMBlock;
134  static constexpr index_t kBNBlock = MmaOp::kBNBlock;
135  static constexpr index_t kAMLane = MmaOp::kAMLane;
136  static constexpr index_t kBNLane = MmaOp::kBNLane;
137  static constexpr index_t kABKLane = MmaOp::kABKLane;
138  static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
139  static constexpr index_t kCMLane = MmaOp::kCMLane;
140  static constexpr index_t kCNLane = MmaOp::kCNLane;
141  static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
142  static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
143 
144  // Additional traits to identify the type of MmaOp at compile time
145  constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
146  constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
147  constexpr static bool IsSupported = is_mma_op_supported_v<MmaOp>;
148 };
149 
150 } // namespace ck_tile::core::arch::mma
Definition: amdgcn_mma.hpp:10
int32_t index_t
Definition: integer.hpp:9
bool_constant< false > false_type
Definition: integral_constant.hpp:63
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
bool_constant< true > true_type
Definition: integral_constant.hpp:62
unsigned int uint32_t
Definition: stdint.h:126
Reflects the template parameters of a given MmaOp.
Definition: mma_traits.hpp:50
Reflects the template parameters and static members of a given MmaOp.
Definition: mma_traits.hpp:125
static constexpr index_t kCM0PerLane
Definition: mma_traits.hpp:141
typename MmaOp::OpType OpType
Definition: mma_traits.hpp:127
static constexpr index_t kAMLane
Definition: mma_traits.hpp:135
typename MmaOp::CVecType CVecType
Definition: mma_traits.hpp:130
static constexpr index_t kABKPerLane
Definition: mma_traits.hpp:138
static constexpr index_t kCMLane
Definition: mma_traits.hpp:139
static constexpr index_t kCNLane
Definition: mma_traits.hpp:140
static constexpr index_t kABKLane
Definition: mma_traits.hpp:137
static constexpr index_t kAMBlock
Definition: mma_traits.hpp:133
static constexpr index_t kCM1PerLane
Definition: mma_traits.hpp:142
typename MmaOp::BVecType BVecType
Definition: mma_traits.hpp:129
typename MmaOp::AVecType AVecType
Definition: mma_traits.hpp:128
constexpr static bool IsMfma
Definition: mma_traits.hpp:145
static constexpr index_t kBNLane
Definition: mma_traits.hpp:136
constexpr static bool IsSupported
Definition: mma_traits.hpp:147
static constexpr index_t kBNBlock
Definition: mma_traits.hpp:134
constexpr static bool IsWmma
Definition: mma_traits.hpp:146
This is the default MmaOp policy. Instances of this class are to be used as MmaOp policies… (tooltip truncated)
Definition: amdgcn_mma.hpp:82
Trait to check if MmaOp is supported.
Definition: mma_traits.hpp:18