/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp Source File
warp_gemm_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 namespace ck_tile {
8 
9 template <typename WarpGemmAttribute_>
11 {
13 
14  static constexpr index_t kM = WarpGemmAttribute::kM;
15  static constexpr index_t kN = WarpGemmAttribute::kN;
16  static constexpr index_t kK = WarpGemmAttribute::kK;
21  static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
22 
23  using ADataType = typename WarpGemmAttribute::ADataType;
24  using BDataType = typename WarpGemmAttribute::BDataType;
25  using CDataType = typename WarpGemmAttribute::CDataType;
26 
27  using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
28  using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
29  using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
30 
34 
38 
39  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
40  {
41  return WarpGemmAttribute_::get_num_of_access();
42  }
43 
44  template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
45  CK_TILE_DEVICE void
46  operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
47  {
48  static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
49  detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
50  detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
51  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
52  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
53  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
54 
55  constexpr auto I0 = number<0>{};
56 
57  const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
58  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
59  auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
60 
61  // c_vec += a_vec * b_vec
62  WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
63 
64  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
65  }
66 
67  template <typename CTensor,
68  typename ATensor,
69  typename BTensor,
70  index_t i_subk,
71  bool post_nop_ = false>
72  CK_TILE_DEVICE void operator()(CTensor& c,
73  const ATensor& a,
74  const BTensor& b,
76  bool_constant<post_nop_> = {}) const
77  {
78  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
79  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
80  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
81 
82  constexpr auto I0 = number<0>{};
83 
84  const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
85  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
86  auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
87 
88  // c_vec += a_vec * b_vec
89  WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
90 
91  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
92  }
93 
94  template <typename ATensor, typename BTensor>
95  CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
96  {
97  using CTensor = CWarpTensor;
98  static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
99  detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
100  CTensor c;
101 
102  using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
103  using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
104  using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
105 
106  constexpr auto I0 = number<0>{};
107 
108  const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
109  const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
110 
111  // c_vec = a_vec * b_vec
112  auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
113 
114  c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
115 
116  return c;
117  }
118 };
119 
120 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:54
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
Definition: warp_gemm_impl.hpp:11
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, number< i_subk >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_impl.hpp:72
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_impl.hpp:39
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition: warp_gemm_impl.hpp:29
CK_TILE_DEVICE auto operator()(const ATensor &a, const BTensor &b) const
Definition: warp_gemm_impl.hpp:95
typename WarpGemmAttribute::BWarpDstrEncoding BWarpDstrEncoding
Definition: warp_gemm_impl.hpp:28
remove_cvref_t< decltype(make_static_tile_distribution(BWarpDstrEncoding{}))> BWarpDstr
Definition: warp_gemm_impl.hpp:32
static constexpr index_t kKPerThread
The number of elements in the K dimension processed by a single thread in a wavefront.
Definition: warp_gemm_impl.hpp:21
typename WarpGemmAttribute::AWarpDstrEncoding AWarpDstrEncoding
Definition: warp_gemm_impl.hpp:27
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_impl.hpp:46
static constexpr index_t kM
Definition: warp_gemm_impl.hpp:14
static constexpr index_t kK
Definition: warp_gemm_impl.hpp:16
typename WarpGemmAttribute::CDataType CDataType
Definition: warp_gemm_impl.hpp:25
typename WarpGemmAttribute::ADataType ADataType
Definition: warp_gemm_impl.hpp:23
remove_cvref_t< decltype(make_static_tile_distribution(CWarpDstrEncoding{}))> CWarpDstr
Definition: warp_gemm_impl.hpp:33
static constexpr index_t kN
Definition: warp_gemm_impl.hpp:15
remove_cvref_t< WarpGemmAttribute_ > WarpGemmAttribute
Definition: warp_gemm_impl.hpp:12
remove_cvref_t< decltype(make_static_tile_distribution(AWarpDstrEncoding{}))> AWarpDstr
Definition: warp_gemm_impl.hpp:31
static_distributed_tensor< CDataType, CWarpDstr > CWarpTensor
Definition: warp_gemm_impl.hpp:37
typename WarpGemmAttribute::BDataType BDataType
Definition: warp_gemm_impl.hpp:24
Definition: integral_constant.hpp:13
Definition: static_distributed_tensor.hpp:21