/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp Source File
mfma_gfx9.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "mfma_traits.hpp"
7 
12 
13 namespace ck_tile::core::arch::mma {
14 
15 // NOTE: From this point forward, we specialize amdgcn_mma for each target id as needed,
16 // because some built-ins are only available on certain target ids.
17 // We can also add padding specializations for cases where we need to use
18 // smaller values of K that aren't directly supported by the built-ins.
19 // For flexibility, it is recommended that each backend wrapper support at least
20 // one packed register for each input, so that smaller K values can be processed by padding.
21 
27 {
 // Control-flag set with all three MFMA modifier values fixed at 0. It provides the
 // static members Cbsz/Abid/Blgp that CtrlFlagsGfx9I requires. The amdgcn_mma
 // specializations in this file read CtrlFlags::Cbsz/Abid/Blgp, cast each one to int,
 // and pass them as the last three arguments of the __builtin_amdgcn_mfma_* call.
 // NOTE(review): presumably intended as the default CtrlFlags argument. Per the AMD ISA,
 // CBSZ/ABID control A-operand broadcast and BLGP selects the B lane-group pattern, so 0
 // should mean "no modification". Confirm against the ISA reference.
28  static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0
29  static constexpr uint32_t Abid = 0; // ABID flag, default 0
30  static constexpr uint32_t Blgp = 0; // BLGP flag, default 0
31 };
32 
33 #if defined(__cpp_concepts) && __cpp_concepts >= 201907L
34 
// CtrlFlagsGfx9I: satisfied when CtrlFlags exposes static members Cbsz, Abid and Blgp
// whose values are convertible to int. Those are the three flags the gfx9 amdgcn_mma
// specializations cast to int and forward to the MFMA builtins.
// The concept is only defined when the compiler advertises C++20 concepts
// (__cpp_concepts >= 201907L). In this file it is referenced only from the
// "TODO: c++20" comments on the specializations.
39 template <typename CtrlFlags>
40 concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
 // NOTE(review): the requires-parameter `ctrlFlags` is never used, because all checks
 // go through CtrlFlags:: static members. A parameterless `requires { ... }` would suffice.
41  // Flag members for Gfx9 MFMA instructions
42  { CtrlFlags::Cbsz } -> std::convertible_to<int>;
43  { CtrlFlags::Abid } -> std::convertible_to<int>;
44  { CtrlFlags::Blgp } -> std::convertible_to<int>;
45 };
46 
47 #endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
48 
// amdgcn_mma specialization for the 16x16x16 shape with fp16_t / fp32_t template
// arguments (f16 inputs, f32 accumulator, per the builtin name). It is enabled through
// enable_if_target_family_gfx9_t for any gfx9-family CompilerTarget.
59 // TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
60 // TODO: c++20 requires
61 template <typename CtrlFlags, typename CompilerTarget>
63  fp16_t,
64  fp32_t,
65  16u,
66  16u,
67  16u,
68  CtrlFlags,
69  CompilerTarget,
70  enable_if_target_family_gfx9_t<CompilerTarget>>
71 {
72  // Mfma operation type
73  using OpType = MfmaOp;
74 
75  // Register types
79 
80  // Layout constants
81  static constexpr index_t kAMBlock = 1;
82  static constexpr index_t kBNBlock = 1;
83 
 // A/B layout: kAMLane * kABKLane = 16 * 4 = 64 lanes (presumably one wave64; confirm).
 // kABKLane * kABKPerLane = 4 * 4 = 16, the K extent of this 16x16x16 shape,
 // assuming K is split as lane groups x elements per lane.
84  static constexpr index_t kAMLane = 16;
85  static constexpr index_t kBNLane = 16;
86  static constexpr index_t kABKLane = 4;
87  static constexpr index_t kABKPerLane = 4;
88 
 // C layout: kCMLane * kCNLane = 4 * 16 = 64 lanes, the same count as A/B.
 // kCMLane * kCM0PerLane * kCM1PerLane = 4 * 1 * 4 = 16, the M extent.
89  static constexpr index_t kCMLane = 4;
90  static constexpr index_t kCNLane = 16;
91  static constexpr index_t kCM0PerLane = 1;
92  static constexpr index_t kCM1PerLane = 4;
93 
 // Issues a single __builtin_amdgcn_mfma_f32_16x16x16f16 on (aVec, bVec, cVec).
 // CtrlFlags::Cbsz/Abid/Blgp are cast to int and passed as the builtin's last three
 // arguments. The builtin's result is returned brace-initialized as CVecType.
94  CK_TILE_DEVICE static auto
95  exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
96  {
97  return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec,
98  bVec,
99  cVec,
100  static_cast<int>(CtrlFlags::Cbsz),
101  static_cast<int>(CtrlFlags::Abid),
102  static_cast<int>(CtrlFlags::Blgp))};
103  }
104 };
105 
// amdgcn_mma specialization for the 16x16x32 shape with fp16_t / fp32_t template
// arguments (f16 inputs, f32 accumulator, per the builtin name). It is enabled only for
// amdgcn_target_id::GFX950, the target that provides __builtin_amdgcn_mfma_f32_16x16x32_f16.
116 // TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
117 // TODO: c++20 requires
118 template <typename CtrlFlags, typename CompilerTarget>
120  fp16_t,
121  fp32_t,
122  16u,
123  16u,
124  32u,
125  CtrlFlags,
126  CompilerTarget,
127  enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
128 {
129  using OpType = MfmaOp;
130 
131  // Packed register types
135 
136  // Layout constants
137  static constexpr index_t kAMBlock = 1;
138  static constexpr index_t kBNBlock = 1;
139 
 // A/B layout: K = 32 is split into kABKLane lane groups of kABKPerLane elements each.
 // Doubling K relative to the 16x16x16 variant doubles the elements per lane (4 -> 8),
 // while the lane split stays the same: 16 M-lanes x 4 K-lanes = 64 lanes, and
 // 4 K-lanes x 8 elements = 32.
140  static constexpr index_t kAMLane = 16;
141  static constexpr index_t kBNLane = 16;
142  static constexpr index_t kABKLane = 4;
143  static constexpr index_t kABKPerLane = 8;
144 
 // C layout: kCMLane * kCNLane = 64 lanes, and kCMLane * kCM0PerLane * kCM1PerLane = 16 (M).
145  static constexpr index_t kCMLane = 4;
146  static constexpr index_t kCNLane = 16;
147  static constexpr index_t kCM0PerLane = 1;
148  static constexpr index_t kCM1PerLane = 4;
 // Compile-time guards: the A/B K split must cover exactly K = 32, and the A/B operands
 // must span the same number of lanes as the C accumulator.
 static_assert(kABKLane * kABKPerLane == 32, "A/B K layout must cover K = 32");
 static_assert(kAMLane * kABKLane == kCMLane * kCNLane,
               "A/B and C layouts must span the same number of lanes");
149 
 // Issues a single __builtin_amdgcn_mfma_f32_16x16x32_f16 on (aVec, bVec, cVec).
 // CtrlFlags::Cbsz/Abid/Blgp are cast to int and passed as the builtin's last three
 // arguments. The builtin's result is returned brace-initialized as CVecType.
150  CK_TILE_DEVICE static auto
151  exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
152  {
153  return {__builtin_amdgcn_mfma_f32_16x16x32_f16(aVec,
154  bVec,
155  cVec,
156  static_cast<int>(CtrlFlags::Cbsz),
157  static_cast<int>(CtrlFlags::Abid),
158  static_cast<int>(CtrlFlags::Blgp))};
159  }
160 };
161 
162 } // namespace ck_tile::core::arch::mma
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: amdgcn_mma.hpp:10
_Float16 fp16_t
Definition: half.hpp:110
int32_t index_t
Definition: integer.hpp:9
float fp32_t
Definition: pk_fp4.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
unsigned int uint32_t
Definition: stdint.h:126
Meta-tag for the MFMA operation. This will be used in the MmaOp policies to identify the operation as...
static constexpr uint32_t Abid
Definition: mfma_gfx9.hpp:29
static constexpr uint32_t Blgp
Definition: mfma_gfx9.hpp:30
static constexpr uint32_t Cbsz
Definition: mfma_gfx9.hpp:28
static CK_TILE_DEVICE auto exec(AVecType const &aVec, BVecType const &bVec, CVecType const &cVec) -> CVecType
Definition: mfma_gfx9.hpp:151
static CK_TILE_DEVICE auto exec(AVecType const &aVec, BVecType const &bVec, CVecType const &cVec) -> CVecType
Definition: mfma_gfx9.hpp:95
This is the default MmaOp policy. Instances of this class are to be used as MmaOp policies....
Definition: amdgcn_mma.hpp:82
static constexpr index_t kCNLane
Definition: amdgcn_mma.hpp:101
static constexpr index_t kAMBlock
Definition: amdgcn_mma.hpp:92
static constexpr index_t kAMLane
Definition: amdgcn_mma.hpp:95
static constexpr index_t kCM1PerLane
Definition: amdgcn_mma.hpp:103
static constexpr index_t kBNBlock
Definition: amdgcn_mma.hpp:93
static constexpr index_t kCM0PerLane
Definition: amdgcn_mma.hpp:102
static constexpr index_t kABKLane
Definition: amdgcn_mma.hpp:97
static constexpr index_t kBNLane
Definition: amdgcn_mma.hpp:96
static constexpr index_t kABKPerLane
Definition: amdgcn_mma.hpp:98
static constexpr index_t kCMLane
Definition: amdgcn_mma.hpp:100