Composable Kernel: include/ck_tile/host/kernel_launch.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <numeric>
#include <functional>
#include "ck_tile/host/timer.hpp"
#include <cstddef>
#include <hip/hip_runtime.h>

namespace ck_tile {

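// kentry is the generic __global__ trampoline: on the device compilation pass it
// default-constructs the Kernel type and invokes its operator() with the forwarded
// arguments; on the host pass the arguments are simply consumed via ck_tile::ignore,
// so both passes see the same signature.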
template <int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#endif
    __global__ void kentry(Args... args)
{
#if defined(__HIP_DEVICE_COMPILE__)
    Kernel{}(args...);
#else
    (..., (ignore = args, 0));
#endif
}

//
// returns an anonymous functor (lambda) to be called later
// the KernelImpl should be a class without non-static data members, i.e. it can be
// instantiated as "KernelImpl{}"
//
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
// (see the usage sketch after make_kernel below)
//
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, typename KernelImpl, typename... Args>
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
    const auto kernel = kentry<MinBlockPerCu, KernelImpl, Args...>;
    return [=](const stream_config& s) {
        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
    };
}
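
// Usage sketch (illustrative only, not part of this header): a minimal kernel functor
// that satisfies the requirements above. "copy_kernel" and its arguments are hypothetical
// names chosen for the example; a static operator(), as described above, works equally
// well since kentry calls "Kernel{}(args...)".
//
//   struct copy_kernel
//   {
//       static constexpr int kBlockSize = 256; // consumed by __launch_bounds__ above
//       __device__ void operator()(const float* in, float* out, int n) const
//       {
//           const int i = blockIdx.x * blockDim.x + threadIdx.x;
//           if(i < n) { out[i] = in[i]; }
//       }
//   };
//
//   // nothing is launched yet; make_kernel() only captures the launch configuration
//   auto invoke = ck_tile::make_kernel(copy_kernel{}, grids, blocks, 0 /*lds*/, in, out, n);
//   invoke(s); // s is a stream_config selecting the HIP stream to launch on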

template <typename... Callables>
CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
{
    // abort the sequence in case of an intermediate error
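    // each callable is invoked in turn; after each call hipPeekAtLastError() is queried and
    // the && fold short-circuits, so once a launch error is observed the remaining callables
    // are skipped and HIP_CHECK_ERROR(hipGetLastError()) reports (and clears) that error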
    if(!((static_cast<void>(callables(sc)), hipPeekAtLastError() == hipSuccess) && ...))
    {
        HIP_CHECK_ERROR(hipGetLastError());
    }
}

// Measure the average time of the preprocess step (run nrepeat_ times before the timed
// loop) so it can later be subtracted from the measured kernel time
template <typename TimerType, typename PreprocessFunc>
CK_TILE_HOST double
preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
{
    timer.start(s.stream_id_);
    for(int i = 0; i < s.nrepeat_; i++)
    {
        if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
        {
            preprocess();
        }
    }
    timer.stop(s.stream_id_);

    return timer.duration() / s.nrepeat_;
}
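// Runs callables_func for s.cold_niters_ warm-up iterations, then times s.nrepeat_
// iterations; if a preprocess callable is supplied it is executed before every timed
// iteration and its separately measured average cost is subtracted from the result.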
template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
CK_TILE_HOST double timing_loop_impl(TimerType timer,
                                     const stream_config& s,
                                     CallablesFunc&& callables_func,
                                     PreprocessFunc preprocess = nullptr)
{
    for(int i = 0; i < s.cold_niters_; i++)
    {
        callables_func();
    }
    // Only profile preprocess if it's provided
    auto preprocess_time = 0.0;
    if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
    {
        preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
    }

    int i = 0;
    timer.start(s.stream_id_);
    while(i < s.nrepeat_)
    {
        if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
        {
            preprocess();
        }

        callables_func();
        i++;
    }
    timer.stop(s.stream_id_);

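    // if the timed loop did not run (nrepeat_ <= 0) there is nothing to report; otherwise
    // return the average time per repetition with the averaged preprocess cost removed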
    if(!i)
        return 0.;
    return (timer.duration() / s.nrepeat_) - preprocess_time;
}

// clang-format off
/*
 * launch_kernel()
 *
 * this function launches an arbitrary number of kernels, with an optional timer (selected by stream_config)
 * each callable must be invocable with the signature "operator()(const stream_config& s){ ... }"
 *
 * the simplest way is to pass in a lambda of the form "[=](const stream_config& s){ call_your_kernel_here(); }"
 * as the callable (pay attention to the capture list)
 *
 * e.g.
 *   ck_tile::launch_kernel(s,
 *       [=](const stream_config& s){ hipMemset(ptr, 0, size); },
 *       [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); });
 *
 * if you use a ck_tile kernel, or something similar in style (a structure with "static __device__ operator()(...){}"),
 * you can pass your kernel to ck_tile::make_kernel(), which creates an anonymous functor for you,
 * and then pass that to ck_tile::launch_kernel()
 *
 * e.g.
 *   ck_tile::launch_kernel(s,
 *       ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
 *       ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
 *       ...);
 **/
// clang-format on
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
{
    static_assert(sizeof...(callables) > 0, "At least one callable is required!");

    if(!s.time_kernel_)
    {
        launch_and_check(s, std::forward<Callables>(callables)...);
        return 0;
    }

    auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };

    if(s.is_gpu_timer_)
    {
        return timing_loop_impl(gpu_timer{}, s, callables_func);
    }
    else
    {
        return timing_loop_impl(cpu_timer{}, s, callables_func);
    }
}
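
// Timing sketch (illustrative): when s.time_kernel_ is set, launch_kernel performs
// s.cold_niters_ warm-up launches followed by s.nrepeat_ timed launches and returns the
// average time per repetition as reported by the selected timer (gpu_timer or cpu_timer).
// The kernel and argument names below are hypothetical.
//
//   float avg = ck_tile::launch_kernel(s,
//       ck_tile::make_kernel(kernel_0{}, grids0, blocks0, 0, kargs0),
//       ck_tile::make_kernel(kernel_1{}, grids1, blocks1, 0, kargs1));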

template <typename PreprocessFunc, typename... Callables>
CK_TILE_HOST float
launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Callables&&... callables)
{
    static_assert(sizeof...(callables) > 0, "At least one callable is required!");

    if(!s.time_kernel_)
    {
        preprocess();
        launch_and_check(s, std::forward<Callables>(callables)...);
        return 0;
    }

    auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };

    if(s.is_gpu_timer_)
    {
        return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
    }
    else
    {
        return timing_loop_impl(cpu_timer{}, s, callables_func, preprocess);
    }
}
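
// Usage sketch for launch_kernel_time_mask (illustrative only): "preprocess" runs once
// before every timed iteration and its separately measured average cost is subtracted
// from the reported time, so only the kernels themselves are accounted for. A typical
// use is resetting a workspace the kernel accumulates into; "ws_ptr"/"ws_bytes" and the
// kernel name below are hypothetical.
//
//   auto reset_ws = [&](){ HIP_CHECK_ERROR(hipMemsetAsync(ws_ptr, 0, ws_bytes, s.stream_id_)); };
//   float avg = ck_tile::launch_kernel_time_mask(s, reset_ws,
//       ck_tile::make_kernel(kernel_0{}, grids0, blocks0, 0, kargs0));
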
} // namespace ck_tile