/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Source File
gemm_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
17 
18 namespace ck_tile {
19 
29 {
/// @brief Host-side constructor that stores every argument in the member of the
///        same name (the parameter name minus its trailing underscore).
///
/// Note the parameter order: `k_batch_` is the fourth argument, ahead of
/// `M_`/`N_`/`K_`, even though the `k_batch` member is the last one initialized.
/// `e_ptr_` is written to `e_ptr`, which sits in an anonymous union with `c_ptr`,
/// so the same pointer can be read back through either name.
/// NOTE(review): the dimension/stride meaning of M, N, K and stride_* is not
/// visible in this listing -- confirm against the member documentation.
31  CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
32  const void* b_ptr_,
33  void* e_ptr_,
34  index_t k_batch_,
35  index_t M_,
36  index_t N_,
37  index_t K_,
38  index_t stride_A_,
39  index_t stride_B_,
40  index_t stride_E_)
41  : a_ptr(a_ptr_),
42  b_ptr(b_ptr_),
43  e_ptr(e_ptr_),
44  M(M_),
45  N(N_),
46  K(K_),
47  stride_A(stride_A_),
48  stride_B(stride_B_),
49  stride_E(stride_E_),
50  k_batch(k_batch_)
51  {
// Intentionally empty: all state is set in the member-initializer list above.
52  }
53 
54  const void* a_ptr;
55  const void* b_ptr;
56  union
57  {
58  void* e_ptr;
59  void* c_ptr;
60  };
61 
67 
68  union
69  {
72  };
73 
75 };
76 
77 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
78 struct GemmKernel
79 {
84 
88 
93 
98 
100  static_assert(
102  "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
103 
105  static_assert(
107  "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
108 
110  static_assert(!is_detected<is_tuple, ELayout>::value &&
112  "C/ELayout and C/EDataType must be scalars.");
113 
// Tensor counts for the A and B operands. Both are fixed at 1, in line with the
// static_assert messages above stating that multiple parameters are not
// currently supported.
114  static constexpr index_t NumATensor = 1;
115  static constexpr index_t NumBTensor = 1;
116 
117  CK_TILE_HOST static auto GetName() -> const std::string
118  {
120  }
121 
/// @brief Host-side helper returning the launch grid as a dim3.
///
/// Pure pass-through: M, N and KBatch are forwarded unchanged, in the same
/// order, to UniversalGemmKernel::GridSize and its result is returned as-is.
122  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
123  {
124  return UniversalGemmKernel::GridSize(M, N, KBatch);
125  }
126 
127  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
128  {
130  }
131 
132  CK_TILE_HOST static constexpr auto BlockSize() -> dim3
133  {
135  }
136 
137  CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) ->
139  {
144  {hostArgs.a_ptr},
145  {hostArgs.b_ptr},
146  {/*hostArgs.ds_ptr*/},
147  hostArgs.e_ptr,
148  hostArgs.k_batch,
149  hostArgs.M,
150  hostArgs.N,
151  hostArgs.K,
152  {hostArgs.stride_A},
153  {hostArgs.stride_B},
154  {/*hostArgs.stride_Ds*/},
155  hostArgs.stride_E));
156  }
157 
158  CK_TILE_HOST static auto
160  {
162  }
163 
/// @brief Device-side entry point of the kernel functor.
///
/// Default-constructs a UniversalGemmKernel and invokes its call operator with
/// the given arguments; nothing else is done here and nothing is returned.
///
/// @param kargs Kernel arguments of type UniversalGemmKernel::KernelArgs,
///              taken by value and passed straight through.
164  CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
165  {
// The `template` disambiguator was removed: it is only valid when followed by
// a template-argument list (`.template f<...>()`), and none is supplied here.
// Without it the same operator() is still selected, with any template
// parameters deduced or defaulted exactly as before.
166  UniversalGemmKernel{}.operator()(kargs);
167  }
168 };
169 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
The GEMM kernel host arguments.
Definition: gemm_kernel.hpp:29
CK_TILE_HOST GemmHostArgs()=default
void * c_ptr
Definition: gemm_kernel.hpp:59
index_t stride_E
Definition: gemm_kernel.hpp:70
index_t stride_B
Definition: gemm_kernel.hpp:66
index_t stride_C
Definition: gemm_kernel.hpp:71
void * e_ptr
Definition: gemm_kernel.hpp:58
index_t K
Definition: gemm_kernel.hpp:64
index_t M
Definition: gemm_kernel.hpp:62
CK_TILE_HOST GemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_E_)
Definition: gemm_kernel.hpp:31
index_t stride_A
Definition: gemm_kernel.hpp:65
const void * a_ptr
Definition: gemm_kernel.hpp:54
const void * b_ptr
Definition: gemm_kernel.hpp:55
index_t N
Definition: gemm_kernel.hpp:63
index_t k_batch
Definition: gemm_kernel.hpp:74
Definition: gemm_kernel.hpp:79
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: gemm_kernel.hpp:97
static constexpr CK_TILE_HOST auto BlockSize() -> dim3
Definition: gemm_kernel.hpp:132
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition: gemm_kernel.hpp:83
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition: gemm_kernel.hpp:95
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: gemm_kernel.hpp:92
static constexpr CK_TILE_HOST auto MakeKernelArgs(const GemmHostArgs &hostArgs) -> typename UniversalGemmKernel::KernelArgs
Definition: gemm_kernel.hpp:137
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
Definition: gemm_kernel.hpp:122
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition: gemm_kernel.hpp:90
static constexpr index_t NumBTensor
Definition: gemm_kernel.hpp:115
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_kernel.hpp:96
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
Definition: gemm_kernel.hpp:164
static constexpr index_t NumATensor
ALayout and ADataType are expected to be scalars, not a tuple.
Definition: gemm_kernel.hpp:114
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Definition: gemm_kernel.hpp:127
static CK_TILE_HOST auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs) -> bool
Definition: gemm_kernel.hpp:159
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_kernel.hpp:86
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_kernel.hpp:91
static CK_TILE_HOST auto GetName() -> const std::string
Definition: gemm_kernel.hpp:117
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_kernel.hpp:85
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_kernel.hpp:87
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:240
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:247
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:258
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:342
static constexpr CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:269
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:272
Definition: stream_config.hpp:30