/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/host/kernel_launch.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/host/kernel_launch.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/host/kernel_launch.hpp Source File
kernel_launch.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
10 #include "ck_tile/host/timer.hpp"
11 #include <cstddef>
12 #include <hip/hip_runtime.h>
13 
14 namespace ck_tile {
15 
// Tuning constants consumed by launch_kernel_preprocess(): each CU-count
// threshold is paired with the per-iteration preprocess latency that is
// subtracted from the measured time on devices matching it.
// NOTE(review): latency units presumably match timer duration() (likely ms) — confirm in timer.hpp.
#define HIGH_CU_PROCESSORS 228
#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS 0.0015

#define LOW_CU_PROCESSORS 80
#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS 0.005

// Fallback offset for CU counts matching neither threshold above.
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
21 
22 template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
23 #if CK_TILE_USE_LAUNCH_BOUNDS
24 __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
25 #endif
26  __global__ void kentry(Args... args)
27 {
28 #if defined(__HIP_DEVICE_COMPILE__)
29  Kernel{}(args...);
30 #else
31  (..., (ignore = args, 0));
32 #endif
33 }
34 
35 //
36 // return an anonymous functor (lambda) to be called later
37 // the KernelImpl should be a class without non-static data members, i.e. one that
38 // can be instantiated with "KernelImpl{}"
39 //
40 // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
41 //
42 template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
43  int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
44  typename KernelImpl,
45  typename... Args>
46 CK_TILE_HOST auto
47 make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
48 {
49  const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
50 
51  return [=](const stream_config& s) {
52  kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
53  };
54 }
55 
56 template <typename... Callables>
57 CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
58 {
59  // abort the sequence in case of intermediate error
60  if(!((static_cast<void>(callables(sc)), hipPeekAtLastError() == hipSuccess) && ...))
61  {
62  HIP_CHECK_ERROR(hipGetLastError());
63  }
64 }
65 
66 // clang-format off
67 /*
68  * launch_kernel()
69  *
70 * this function launches an arbitrary number of kernels, with an optional timer (selected by stream_config)
71 * each callable must provide "operator()(const stream_config& s){ ... }" so that it can be invoked
72 *
73 * the simplest way is to pass in a lambda function with "[=](const stream_config& s){ call_your_kernel_here(); }"
74 * as its signature (pay attention to the capture list)
75 *
76 * e.g.
77 * ck_tile::launch_kernel(s,
78 * [=](const stream_config& s){ hipMemset(ptr, 0, size); },
79  * [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
80  * );
81  *
82 * if you use a ck_tile kernel, or one written in a similar style (a structure with "static __device__ operator()(...){}"),
83 * you can pass your kernel to ck_tile::make_kernel(), which will create an anonymous functor for you,
84  * then pass it to ck_tile::launch_kernel()
85  *
86  * e.g.
87  * ck_tile::launch_kernel(s,
88  * ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
89  * ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
90  * ...);
91  **/
92 // clang-format on
93 template <typename... Callables>
94 CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
95 {
96  static_assert(sizeof...(callables) > 0, "At least one callable is required!");
97 
98  if(!s.time_kernel_)
99  {
100  launch_and_check(s, std::forward<Callables>(callables)...);
101  return 0;
102  }
103 
104  auto time_launches = [&](auto timer) {
105  // Warmup
106  for(int i = 0; i < s.cold_niters_; i++)
107  {
108  launch_and_check(s, std::forward<Callables>(callables)...);
109  }
110 
111  timer.start(s.stream_id_);
112  for(int i = 0; i < s.nrepeat_; i++)
113  {
114  launch_and_check(s, std::forward<Callables>(callables)...);
115  }
116  timer.stop(s.stream_id_);
117 
118  return timer.duration() / s.nrepeat_;
119  };
120 
121  if(s.is_gpu_timer_)
122  {
123  return time_launches(gpu_timer{});
124  }
125  else
126  {
127  return time_launches(cpu_timer{});
128  }
129 }
130 
131 template <typename PreprocessFunc, typename... Callables>
133  PreprocessFunc preprocess,
134  Callables&&... callables)
135 {
136  static_assert(sizeof...(callables) > 0, "At least one callable is required!");
137 
138  if(!s.time_kernel_)
139  {
140  preprocess();
141  launch_and_check(s, std::forward<Callables>(callables)...);
142  return 0;
143  }
144 
145  auto time_launches = [&](auto timer) {
146  // Warmup
147  for(int i = 0; i < s.cold_niters_; i++)
148  {
149  launch_and_check(s, std::forward<Callables>(callables)...);
150  }
151 
152  timer.start(s.stream_id_);
153  for(int i = 0; i < s.nrepeat_; i++)
154  {
155  preprocess();
156  launch_and_check(s, std::forward<Callables>(callables)...);
157  }
158  timer.stop(s.stream_id_);
159 
160  hipDeviceProp_t deviceProps;
161  HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
162 
163  float preprocess_offset = (deviceProps.multiProcessorCount >= HIGH_CU_PROCESSORS)
165  : (deviceProps.multiProcessorCount == LOW_CU_PROCESSORS)
168  return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
169  };
170 
171  if(s.is_gpu_timer_)
172  {
173  return time_launches(gpu_timer{});
174  }
175  else
176  {
177  return time_launches(cpu_timer{});
178  }
179 }
180 } // namespace ck_tile
#define CK_TILE_MIN_BLOCK_PER_CU
Definition: config.hpp:114
#define CK_TILE_MAX_THREAD_PER_BLOCK
Definition: config.hpp:113
#define CK_TILE_HOST
Definition: config.hpp:39
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition: hip_check_error.hpp:22
Definition: cluster_descriptor.hpp:13
__global__ void kentry(Args... args)
Definition: kernel_launch.hpp:26
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
CK_TILE_HOST auto make_kernel(KernelImpl, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:47
CK_TILE_HOST void launch_and_check(const stream_config &sc, Callables &&... callables)
Definition: kernel_launch.hpp:57
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:94
CK_TILE_HOST float launch_kernel_preprocess(const stream_config &s, PreprocessFunc preprocess, Callables &&... callables)
Definition: kernel_launch.hpp:132
Definition: timer.hpp:52
Definition: timer.hpp:15
Definition: stream_config.hpp:26
hipStream_t stream_id_
Definition: stream_config.hpp:27
int cold_niters_
Definition: stream_config.hpp:30
bool time_kernel_
Definition: stream_config.hpp:28
int nrepeat_
Definition: stream_config.hpp:31
bool is_gpu_timer_
Definition: stream_config.hpp:32
#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS
Definition: kernel_launch.hpp:19
#define LOW_CU_PROCESSORS
Definition: kernel_launch.hpp:16
#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS
Definition: kernel_launch.hpp:18
#define HIGH_CU_PROCESSORS
Definition: kernel_launch.hpp:17
#define OPTIMAL_LATENCY_SAFE_MARGIN
Definition: kernel_launch.hpp:20