/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp Source File
wmma_transforms.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
9 
10 namespace ck_tile::core::arch::mma {
11 
17 {
18  template <typename VecType>
19  CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
20  {
    // Placeholder implementation: v is perfectly forwarded back to the caller unchanged;
    // decltype(auto) keeps the reference category produced by std::forward.
21  // TODO: Implement the duplication logic that broadcasts the low register
22  // elements into the high elements: [0 .. N/2 - 1] -> [N/2 .. N - 1]
    // (N is presumably the element count of v -- confirm when implementing).
23  return std::forward<VecType>(v);
24  }
25 };
26 
32 {
33  template <typename VecType>
34  CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
35  {
    // Placeholder implementation: no padding is applied yet; v is perfectly
    // forwarded back to the caller unchanged.
36  // TODO: Implement the b32 padding logic.
37  // E.g., for fp16, pad each 16-bit element with 16 bits of 0 so that it becomes a 32-bit element.
38  return std::forward<VecType>(v);
39  }
40 };
41 
47 {
48  template <typename VecType>
49  CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
50  {
    // Placeholder implementation: no unpadding is applied yet; v is perfectly
    // forwarded back to the caller unchanged.
51  // TODO: Implement the b32 unpadding logic that converts 32-bit elements back to the original data type.
52  return std::forward<VecType>(v);
53  }
54 };
55 
61 {
66 };
67 
73 {
78 };
79 
86 template <typename MmaOp, typename CompilerTarget>
87 // TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
88 // TODO: c++20 requires
90  CompilerTarget,
91  enable_if_target_family_gfx11_t<CompilerTarget>>
92 {
94 };
95 
102 template <typename MmaOp, typename CompilerTarget>
103 // TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
104 // TODO: c++20 requires
106  CompilerTarget,
107  enable_if_target_family_gfx12_t<CompilerTarget>>
108 {
110 };
111 
112 } // namespace ck_tile::core::arch::mma
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: amdgcn_mma.hpp:10
Transform to duplicate low register elements to high register elements.
Definition: wmma_transforms.hpp:17
static decltype(auto) CK_TILE_DEVICE exec(VecType &&v)
Definition: wmma_transforms.hpp:19
Default MMA transforms for GFX11 architecture.
Definition: wmma_transforms.hpp:61
Default MMA transforms for GFX12 architecture.
Definition: wmma_transforms.hpp:73
Implements the default MMA transforms selection for gfx9 targets.
Definition: mma_transforms.hpp:28
Transform to pad data from original type to b32 type.
Definition: wmma_transforms.hpp:32
static decltype(auto) CK_TILE_DEVICE exec(VecType &&v)
Definition: wmma_transforms.hpp:34
A no-op transform that passes through the input as-is.
Definition: mma_transforms.hpp:11
Transform to unpad data from b32 type to original type.
Definition: wmma_transforms.hpp:47
static decltype(auto) CK_TILE_DEVICE exec(VecType &&v)
Definition: wmma_transforms.hpp:49