/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp Source File
wmma_gfx11.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "wmma_traits.hpp"
7 
12 
13 namespace ck_tile::core::arch::mma {
14 // TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
15 // - Duplicating A and B inputs
16 // - Handling C / D, which is always in b32, even for f16 accumulation.
17 // NOTE: Two suggestions:
18 // 1) We could do it here in the wrappers by accepting packed inputs, then swizzling them to
19 // duplicate the inputs as needed before calling the actual built-in. This may introduce
20 // some instruction overhead and violate the single-responsibility principle, but keeps the
21 // logic contained within the backend wrapper.
22 // 2) We could do it at a higher level, e.g. in the Mma interface (workflow) by introducing
23 // pre-mma, mma and post-mma steps. The pre-mma step could handle the input duplication
24 // transform, and the post-mma step could implement the D-shuffle transform. This may be
25 // cleaner and more flexible than trying to handle everything in the backend wrappers.
26 //
27 // This current example assumes duplication has already been done, and that C data shuffles have
28 // already been completed (e.g. option 2 above). The wrappers below expect duplicated inputs and
29 // pre-shuffled data in C.
30 
31 // NOTE: From this point forward, we are specializing amdgcn_mma for each target id as needed.
32 // This is because some built-ins are only available on certain target ids.
33 // We can also do things such as adding padding specializations for when we need to use
34 // smaller values of K that aren't directly supported by the built-ins.
35 // For flexibility, it is recommended that each backend wrapper support at least
36 // one packed register for each input, so that it can process smaller K values by padding.
37 
// Compile-time default WmmaCtrlFlags values derived from the A, B and C data types.
// NOTE(review): the struct declaration (original line 46) and the AccumBits member
// (original line 59) are missing from this listing - confirm against the header.
45 template <typename ADataType, typename BDataType, typename CDataType>
47 {
48  // Generate default signedness flags: SIGNED if std::is_signed_v<T> is true, else UNSIGNED.
49  // Only used currently for integer inputs / accum in gfx11 / gfx12
50  constexpr static WmmaCtrlFlags InputSignA =
51  std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
52  constexpr static WmmaCtrlFlags InputSignB =
53  std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
54  constexpr static WmmaCtrlFlags AccumSign =
55  std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
56 
57  // Generate default flags for accumulator destination bits.
58  // Only used if accumulation size is 16-bit in gfx11
60 };
61 
69 // TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
70 // TODO: c++20 requires
// Backend wrapper around __builtin_amdgcn_wmma_f32_16x16x16_f16_w32, enabled through
// enable_if_target_family_gfx11_t<CompilerTarget>.
// NOTE(review): the struct declaration line (original line 72, presumably the start of
// the amdgcn_mma specialization) and the AVecType / BVecType / CVecType aliases
// (original lines 86-88) are missing from this listing - confirm against the header.
71 template <typename CtrlFlags, typename CompilerTarget>
73  fp16_t,
74  fp32_t,
75  16u,
76  16u,
77  16u,
78  CtrlFlags,
79  CompilerTarget,
80  enable_if_target_family_gfx11_t<CompilerTarget>>
81 {
82  // Wmma operation type
83  using OpType = WmmaOp;
84 
85  // Register types (duplicated input / b32 accum)
89 
90  // Layout constants
91  static constexpr index_t kAMBlock = 1;
92  static constexpr index_t kBNBlock = 1;
93  static constexpr index_t kAMLane = 16;
94  static constexpr index_t kBNLane = 16;
95  static constexpr index_t kABKLane = 8;
96  static constexpr index_t kABKPerLane = 8;
97  static constexpr index_t kCMLane = 2;
98  static constexpr index_t kCNLane = 2;
99  static constexpr index_t kCM0PerLane = 4;
100  static constexpr index_t kCM1PerLane = 1;
101 
 // Passes aVec, bVec and cVec unchanged to the WMMA builtin and returns its result
 // brace-initialized as CVecType. No input duplication or C-data shuffling is done
 // in this function.
102  CK_TILE_DEVICE static auto
103  exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
104  {
105  return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
106  }
107 };
108 
109 } // namespace ck_tile::core::arch::mma
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: amdgcn_mma.hpp:10
WmmaCtrlFlags
Common wmma control flags for gfx11 and gfx12.
Definition: wmma.hpp:13
_Float16 fp16_t
Definition: half.hpp:110
int32_t index_t
Definition: integer.hpp:9
float fp32_t
Definition: pk_fp4.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
Meta-tag for the WMMA operation. This will be used in the MmaOp struct to identify the operation as a...
constexpr static WmmaCtrlFlags InputSignA
Definition: wmma_gfx11.hpp:50
constexpr static WmmaCtrlFlags InputSignB
Definition: wmma_gfx11.hpp:52
constexpr static WmmaCtrlFlags AccumBits
Definition: wmma_gfx11.hpp:59
constexpr static WmmaCtrlFlags AccumSign
Definition: wmma_gfx11.hpp:54
static CK_TILE_DEVICE auto exec(AVecType const &aVec, BVecType const &bVec, CVecType const &cVec) -> CVecType
Definition: wmma_gfx11.hpp:103
This is the default MmaOp policy. Instances of this class are to be used as MmaOp policies....
Definition: amdgcn_mma.hpp:82
static constexpr index_t kCNLane
Definition: amdgcn_mma.hpp:101
static constexpr index_t kAMBlock
Definition: amdgcn_mma.hpp:92
static constexpr index_t kAMLane
Definition: amdgcn_mma.hpp:95
static constexpr index_t kCM1PerLane
Definition: amdgcn_mma.hpp:103
static constexpr index_t kBNBlock
Definition: amdgcn_mma.hpp:93
static constexpr index_t kCM0PerLane
Definition: amdgcn_mma.hpp:102
static constexpr index_t kABKLane
Definition: amdgcn_mma.hpp:97
static constexpr index_t kBNLane
Definition: amdgcn_mma.hpp:96
static constexpr index_t kABKPerLane
Definition: amdgcn_mma.hpp:98
static constexpr index_t kCMLane
Definition: amdgcn_mma.hpp:100