/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/host/kernel_launch.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/host/kernel_launch.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/host/kernel_launch.hpp Source File
kernel_launch.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <numeric>
7 #include <functional>
12 #include "ck_tile/host/timer.hpp"
13 #include <cstddef>
14 #include <hip/hip_runtime.h>
15 
16 namespace ck_tile {
17 
// Generic __global__ entry point that forwards its arguments to a
// default-constructed Kernel functor. Kernel must therefore be a class with
// no non-static data members, i.e. instantiable as "Kernel{}".
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
// Constrain occupancy so the compiler can budget registers accordingly.
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
    __global__ void kentry(Args... args)
{
#if defined(__HIP_DEVICE_COMPILE__)
    // Device compilation pass: actually run the kernel body.
    Kernel{}(args...);
#else
    // Host compilation pass: reference every argument via ck_tile::ignore so
    // the host stub compiles without unused-parameter warnings; the fold
    // expression generates no runtime work.
    (..., (ignore = args, 0));
#endif
}
30 
31 //
32 // return an anonymous functor(lambda) to be called later
33 // the KernelImpl should be a class without non-static data members, or let's say
34 // it can be instantiated with "KernelImpl{}"
35 //
36 // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
37 //
38 template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
39  int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
40  typename KernelImpl,
41  typename... Args>
42 CK_TILE_HOST auto
43 make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
44 {
45  const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
46 
47  return [=](const stream_config& s) {
48  kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
49  };
50 }
51 
// Invoke each callable in order with the given stream config, checking the HIP
// runtime error state after every call; on the first failure the remaining
// callables are skipped and the error is reported via HIP_CHECK_ERROR.
template <typename... Callables>
CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
{
    // abort the sequence in case of intermediate error
    // The && fold short-circuits: each operand runs one callable (comma
    // expression) and then tests hipPeekAtLastError(), so the first callable
    // that leaves an error stops the chain without clearing the error code.
    if(!((static_cast<void>(callables(sc)), hipPeekAtLastError() == hipSuccess) && ...))
    {
        // hipGetLastError() both retrieves and clears the sticky error.
        HIP_CHECK_ERROR(hipGetLastError());
    }
}
61 
62 // Measure the preprocess time during the cold iterations
63 template <typename TimerType, typename PreprocessFunc>
64 CK_TILE_HOST double
65 preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
66 {
67  timer.start(s.stream_id_);
68  for(int i = 0; i < s.nrepeat_; i++)
69  {
70  if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
71  {
72  preprocess();
73  }
74  }
75  timer.stop(s.stream_id_);
76 
77  return timer.duration() / s.nrepeat_;
78 }
79 
80 template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
81 CK_TILE_HOST double timing_loop_impl(TimerType timer,
82  const stream_config& s,
83  CallablesFunc&& callables_func,
84  PreprocessFunc preprocess = nullptr)
85 {
86  for(int i = 0; i < s.cold_niters_; i++)
87  {
88  callables_func();
89  }
90  // Only profile preprocess if it's provided
91  auto preprocess_time = 0.0;
92  if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
93  {
94  preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
95  }
96 
97  int i = 0;
98  timer.start(s.stream_id_);
99  while(i < s.nrepeat_)
100  {
101  if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
102  {
103  preprocess();
104  }
105 
106  callables_func();
107  i++;
108  }
109  timer.stop(s.stream_id_);
110 
111  if(!i)
112  return 0.;
113  return (timer.duration() / s.nrepeat_) - preprocess_time;
114 }
115 
116 // clang-format off
117 /*
118  * launch_kernel()
119  *
120  * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config)
121  * the callables should have signature as "operator()(const stream_config& s){ ... }" to call
122  *
123  * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
124  * as signature, for the callable (pay attention to the capture list)
125  *
126  * e.g.
127  * ck_tile::launch_kernel(s,
128  * [=](const stream_config& s){ hipMemset(ptr, 0, size) },
129  * [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
130  * );
131  *
132  * if you use ck_tile kernel, or similar to this style (structure with "static __device__ operator()(...){}")
133  * you can pass your kernel to ck_tile::make_kernel(), which will create an anonymous functor for you,
134  * then pass it to ck_tile::launch_kernel()
135  *
136  * e.g.
137  * ck_tile::launch_kernel(s,
138  * ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
139  * ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
140  * ...);
141  **/
142 // clang-format on
143 template <typename... Callables>
144 CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
145 {
146  static_assert(sizeof...(callables) > 0, "At least one callable is required!");
147 
148  if(!s.time_kernel_)
149  {
150  launch_and_check(s, std::forward<Callables>(callables)...);
151  return 0;
152  }
153 
154  auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
155 
156  if(s.is_gpu_timer_)
157  {
158  return timing_loop_impl(gpu_timer{}, s, callables_func);
159  }
160  else
161  {
162  return timing_loop_impl(cpu_timer{}, s, callables_func);
163  }
164 }
165 
166 template <typename PreprocessFunc, typename... Callables>
167 CK_TILE_HOST float
168 launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Callables&&... callables)
169 {
170  static_assert(sizeof...(callables) > 0, "At least one callable is required!");
171 
172  if(!s.time_kernel_)
173  {
174  preprocess();
175  launch_and_check(s, std::forward<Callables>(callables)...);
176  return 0;
177  }
178 
179  auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
180 
181  if(s.is_gpu_timer_)
182  {
183  return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
184  }
185  else
186  {
187  return timing_loop_impl(cpu_timer{}, s, callables_func, preprocess);
188  }
189 }
190 } // namespace ck_tile
#define CK_TILE_MIN_BLOCK_PER_CU
Definition: config.hpp:114
#define CK_TILE_MAX_THREAD_PER_BLOCK
Definition: config.hpp:113
#define CK_TILE_HOST
Definition: config.hpp:39
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition: hip_check_error.hpp:21
Definition: cluster_descriptor.hpp:13
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:22
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
CK_TILE_HOST double timing_loop_impl(TimerType timer, const stream_config &s, CallablesFunc &&callables_func, PreprocessFunc preprocess=nullptr)
Definition: kernel_launch.hpp:81
CK_TILE_HOST auto make_kernel(KernelImpl, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:43
CK_TILE_HOST double preprocess_profiling_impl(TimerType timer, const stream_config &s, PreprocessFunc preprocess)
Definition: kernel_launch.hpp:65
CK_TILE_HOST void launch_and_check(const stream_config &sc, Callables &&... callables)
Definition: kernel_launch.hpp:53
CK_TILE_HOST float launch_kernel_time_mask(const stream_config &s, PreprocessFunc preprocess, Callables &&... callables)
Definition: kernel_launch.hpp:168
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:144
Definition: timer.hpp:52
Definition: timer.hpp:15
Definition: stream_config.hpp:30
hipStream_t stream_id_
Definition: stream_config.hpp:31
int cold_niters_
Definition: stream_config.hpp:34
bool time_kernel_
Definition: stream_config.hpp:32
int nrepeat_
Definition: stream_config.hpp:35
bool is_gpu_timer_
Definition: stream_config.hpp:36