/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp Source File
wmma_gfx12.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "wmma_traits.hpp"
7 
12 
13 namespace ck_tile::core::arch::mma {
14 
15 // NOTE: From this point forward, we specialize amdgcn_mma for each target id as needed,
16 // because some built-ins are only available on certain target ids.
17 // We can also do things such as adding padding specializations for cases where we need
18 // smaller values of K that aren't directly supported by the built-ins.
19 // For flexibility, it is recommended that each backend wrapper support at least
20 // one packed register for each input, so that smaller K values can be processed by padding.
21 
29 // TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
30 // TODO: c++20 requires
31 template <typename CtrlFlags, typename CompilerTarget>
33  fp16_t,
34  fp32_t,
35  16u,
36  16u,
37  16u,
38  CtrlFlags,
39  CompilerTarget,
40  enable_if_target_family_gfx12_t<CompilerTarget>>
41 {
42  // Wmma operation type
43  using OpType = WmmaOp;
44 
45  // Register types
49 
50  // Layout constants
51  static constexpr index_t kAMBlock = 1;
52  static constexpr index_t kBNBlock = 1;
53  static constexpr index_t kAMLane = 16;
54  static constexpr index_t kBNLane = 16;
55  static constexpr index_t kABKLane = 8;
56  static constexpr index_t kABKPerLane = 8;
57  static constexpr index_t kCMLane = 2;
58  static constexpr index_t kCNLane = 2;
59  static constexpr index_t kCM0PerLane = 4;
60  static constexpr index_t kCM1PerLane = 1;
61 
62  CK_TILE_DEVICE static auto
63  exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
64  {
65  return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
66  }
67 };
68 
69 } // namespace ck_tile::core::arch::mma
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: amdgcn_mma.hpp:10
_Float16 fp16_t
Definition: half.hpp:110
int32_t index_t
Definition: integer.hpp:9
float fp32_t
Definition: pk_fp4.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
Meta-tag for the WMMA operation. This will be used in the MmaOp struct to identify the operation as a...
static CK_TILE_DEVICE auto exec(AVecType const &aVec, BVecType const &bVec, CVecType const &cVec) -> CVecType
Definition: wmma_gfx12.hpp:63
This is the default MmaOp policy. Instances of this class are to be used as MmaOp policies....
Definition: amdgcn_mma.hpp:82
static constexpr index_t kCNLane
Definition: amdgcn_mma.hpp:101
static constexpr index_t kAMBlock
Definition: amdgcn_mma.hpp:92
static constexpr index_t kAMLane
Definition: amdgcn_mma.hpp:95
static constexpr index_t kCM1PerLane
Definition: amdgcn_mma.hpp:103
static constexpr index_t kBNBlock
Definition: amdgcn_mma.hpp:93
static constexpr index_t kCM0PerLane
Definition: amdgcn_mma.hpp:102
static constexpr index_t kABKLane
Definition: amdgcn_mma.hpp:97
static constexpr index_t kBNLane
Definition: amdgcn_mma.hpp:96
static constexpr index_t kABKPerLane
Definition: amdgcn_mma.hpp:98
static constexpr index_t kCMLane
Definition: amdgcn_mma.hpp:100