#include <hip/hip_runtime.h>

#include <array>
#include <iostream>
#include <numeric>
#include <set>
#include <vector>

// Helpers for benchmarking kernels with realistic cache behaviour: each wrapper keeps several
// rotating device-side copies of the input tensors so that consecutive timed launches do not
// re-read data that is still resident in cache from the previous launch.

template <typename Argument,
          typename AsDataType,
          typename BsDataType,
          typename DsDataType>
struct RotatingMemWrapperMultiABD
{
    static constexpr index_t NumAs = AsDataType::Size();
    static constexpr index_t NumBs = BsDataType::Size();
    static constexpr index_t NumDs = DsDataType::Size();

    using AsGridPointer = decltype(Argument::p_as_grid);
    using BsGridPointer = decltype(Argument::p_bs_grid);
    using DsGridPointer = decltype(Argument::p_ds_grid);

    RotatingMemWrapperMultiABD() = delete;

    RotatingMemWrapperMultiABD(Argument& arg_,
                               std::size_t rotating_count_hint,
                               std::array<std::size_t, NumAs> size_as_,
                               std::array<std::size_t, NumBs> size_bs_,
                               std::array<std::size_t, NumDs> size_ds_)
        : arg(arg_),
          rotating_count(rotating_count_hint),
          size_as(size_as_),
          size_bs(size_bs_),
          size_ds(size_ds_)
    {
        // The original device pointers become rotation slot 0.
        p_as_grids.push_back(arg.p_as_grid);
        p_bs_grids.push_back(arg.p_bs_grid);
        p_ds_grids.push_back(arg.p_ds_grid);

        // Cap the number of copies so the total allocation stays below 2 GiB.
        const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) +
                                   std::accumulate(size_bs.begin(), size_bs.end(), 0UL) +
                                   std::accumulate(size_ds.begin(), size_ds.end(), 0UL);
        const uint64_t max_rotating_count = (1ULL << 31) / footprint;
        rotating_count                    = std::min(rotating_count, max_rotating_count);

        for(size_t i = 1; i < rotating_count; i++)
        {
            // Allocate and fill a device-side copy of every A tensor.
            {
                AsGridPointer as_buffer;
                static_for<0, NumAs, 1>{}([&](auto j) {
                    using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;

                    void* pADeviceBuf;
                    hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
                    hip_check_error(hipMemcpy(pADeviceBuf,
                                              static_cast<const void*>(p_as_grids[0][j]),
                                              size_as_[j],
                                              hipMemcpyDeviceToDevice));
                    as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
                });
                p_as_grids.push_back(as_buffer);
            }

            // Allocate and fill a device-side copy of every B tensor.
            {
                BsGridPointer bs_buffer;
                static_for<0, NumBs, 1>{}([&](auto j) {
                    using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;

                    void* pBDeviceBuf;
                    hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
                    hip_check_error(hipMemcpy(pBDeviceBuf,
                                              static_cast<const void*>(p_bs_grids[0][j]),
                                              size_bs_[j],
                                              hipMemcpyDeviceToDevice));
                    bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
                });
                p_bs_grids.push_back(bs_buffer);
            }

            // Allocate and fill a device-side copy of every D tensor.
            {
                DsGridPointer ds_buffer;
                static_for<0, NumDs, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    void* pDDeviceBuf;
                    hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
                    hip_check_error(hipMemcpy(pDDeviceBuf,
                                              static_cast<const void*>(p_ds_grids[0][j]),
                                              size_ds_[j],
                                              hipMemcpyDeviceToDevice));
                    ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
                });
                p_ds_grids.push_back(ds_buffer);
            }
        }
    }

    // Point the kernel argument at the next rotation slot.
    void Next()
    {
        if(rotating_count > 1)
        {
            std::size_t idx = iter++ % rotating_count;
            arg.p_as_grid   = p_as_grids[idx];
            arg.p_bs_grid   = p_bs_grids[idx];
            arg.p_ds_grid   = p_ds_grids[idx];
        }
    }

    void Print()
    {
        std::cout << "RotatingMemWrapperMultiABD: { size_a: {";
        static_for<0, NumAs, 1>{}(
            [&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
        std::cout << "}, size_b: {";
        static_for<0, NumBs, 1>{}(
            [&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
        std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
    }

    ~RotatingMemWrapperMultiABD()
    {
        if(rotating_count > 1)
        {
            // Restore the original pointers, then release the extra copies.
            arg.p_as_grid = p_as_grids[0];
            arg.p_bs_grid = p_bs_grids[0];
            arg.p_ds_grid = p_ds_grids[0];

            for(size_t i = 1; i < rotating_count; i++)
            {
                static_for<0, NumAs, 1>{}([&](auto j) {
                    using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
                    hip_check_error(
                        hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
                });
                static_for<0, NumBs, 1>{}([&](auto j) {
                    using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
                    hip_check_error(
                        hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
                });
                static_for<0, NumDs, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
                    hip_check_error(
                        hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
                });
            }
        }
    }

    Argument& arg;
    std::size_t iter                       = 0;
    std::size_t rotating_count             = 1;
    std::array<std::size_t, NumAs> size_as = {0};
    std::array<std::size_t, NumBs> size_bs = {0};
    std::array<std::size_t, NumDs> size_ds = {0};
    std::vector<AsGridPointer> p_as_grids;
    std::vector<BsGridPointer> p_bs_grids;
    std::vector<DsGridPointer> p_ds_grids;
};
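The (1ULL << 31) cap above bounds the extra device memory used by the rotating copies to 2 GiB. As a quick worked example of the clamp (the 256 MiB figure below is illustrative, not taken from the header): with a combined A/B/D footprint of 256 MiB, at most 8 copies are kept, so a larger hint is silently reduced.

#include <algorithm>
#include <cstdint>

int main()
{
    const uint64_t footprint          = 256ULL << 20;              // 256 MiB of A + B + D tensors
    const uint64_t max_rotating_count = (1ULL << 31) / footprint;  // 2 GiB / 256 MiB = 8
    const uint64_t rotating_count     = std::min<uint64_t>(10, max_rotating_count);
    return rotating_count == 8 ? 0 : 1;                            // a hint of 10 is clamped to 8
}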
template <typename Argument, typename DsDataType>
struct RotatingMemWrapperMultiD
{
    static constexpr index_t NumDs = DsDataType::Size();

    using ADataType     = decltype(Argument::p_a_grid);
    using BDataType     = decltype(Argument::p_b_grid);
    using DsGridPointer = decltype(Argument::p_ds_grid);

    RotatingMemWrapperMultiD() = delete;

    RotatingMemWrapperMultiD(Argument& arg_,
                             std::size_t rotating_count_hint,
                             std::size_t size_a_,
                             std::size_t size_b_,
                             std::array<std::size_t, NumDs> size_ds_)
        : arg(arg_),
          rotating_count(rotating_count_hint),
          size_a(size_a_),
          size_b(size_b_),
          size_ds(size_ds_)
    {
        // The original device pointers become rotation slot 0.
        p_a_grids.push_back(arg.p_a_grid);
        p_b_grids.push_back(arg.p_b_grid);
        p_ds_grids.push_back(arg.p_ds_grid);

        // Cap the number of copies so the total allocation stays below 2 GiB.
        const uint64_t footprint =
            std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b);
        const uint64_t max_rotating_count = (1ULL << 31) / footprint;
        rotating_count                    = std::min(rotating_count, max_rotating_count);

        for(size_t i = 1; i < rotating_count; i++)
        {
            // Device-side copy of A.
            {
                void* pADeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
                hip_check_error(hipMemcpy(pADeviceBuf,
                                          const_cast<void*>(p_a_grids[0]),
                                          size_a_,
                                          hipMemcpyDeviceToDevice));
                p_a_grids.push_back(pADeviceBuf);
            }

            // Device-side copy of B.
            {
                void* pBDeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
                hip_check_error(hipMemcpy(pBDeviceBuf,
                                          const_cast<void*>(p_b_grids[0]),
                                          size_b_,
                                          hipMemcpyDeviceToDevice));
                p_b_grids.push_back(pBDeviceBuf);
            }

            // Device-side copies of every D tensor.
            {
                DsGridPointer ds_buffer;
                static_for<0, NumDs, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    void* pDDeviceBuf;
                    hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
                    hip_check_error(hipMemcpy(pDDeviceBuf,
                                              static_cast<const void*>(p_ds_grids[0][j]),
                                              size_ds_[j],
                                              hipMemcpyDeviceToDevice));
                    ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
                });
                p_ds_grids.push_back(ds_buffer);
            }
        }
    }

    // Point the kernel argument at the next rotation slot.
    void Next()
    {
        if(rotating_count > 1)
        {
            std::size_t idx = iter++ % rotating_count;
            arg.p_a_grid    = reinterpret_cast<ADataType>(p_a_grids[idx]);
            arg.p_b_grid    = reinterpret_cast<BDataType>(p_b_grids[idx]);
            arg.p_ds_grid   = p_ds_grids[idx];
        }
    }

    void Print()
    {
        std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
                  << ", rotating_count: " << rotating_count << "}" << std::endl;
    }

    ~RotatingMemWrapperMultiD()
    {
        if(rotating_count > 1)
        {
            // Restore the original pointers, then release the extra copies.
            arg.p_a_grid  = reinterpret_cast<ADataType>(p_a_grids[0]);
            arg.p_b_grid  = reinterpret_cast<BDataType>(p_b_grids[0]);
            arg.p_ds_grid = p_ds_grids[0];

            for(size_t i = 1; i < rotating_count; i++)
            {
                hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
                hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));

                static_for<0, NumDs, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
                    hip_check_error(
                        hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
                });
            }
        }
    }

    Argument& arg;
    std::size_t iter                       = 0;
    std::size_t rotating_count             = 1;
    std::size_t size_a                     = 0;
    std::size_t size_b                     = 0;
    std::array<std::size_t, NumDs> size_ds = {0};
    std::vector<const void*> p_a_grids;
    std::vector<const void*> p_b_grids;
    std::vector<DsGridPointer> p_ds_grids;
};
template <typename Argument>
struct RotatingMemWrapper
{
    using ADataType = decltype(Argument::p_a_grid);
    using BDataType = decltype(Argument::p_b_grid);

    RotatingMemWrapper() = delete;

    RotatingMemWrapper(Argument& arg_,
                       std::size_t rotating_count_hint,
                       std::size_t size_a_,
                       std::size_t size_b_)
        : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_)
    {
        // The original device pointers become rotation slot 0.
        p_a_grids.push_back(arg.p_a_grid);
        p_b_grids.push_back(arg.p_b_grid);

        // Cap the number of copies so the total allocation stays below 2 GiB.
        const uint64_t footprint          = (size_a + size_b);
        const uint64_t max_rotating_count = (1ULL << 31) / footprint;
        rotating_count                    = std::min(rotating_count, max_rotating_count);

        for(size_t i = 1; i < rotating_count; i++)
        {
            // Device-side copy of A.
            {
                void* pADeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
                hip_check_error(hipMemcpy(pADeviceBuf,
                                          const_cast<void*>(p_a_grids[0]),
                                          size_a_,
                                          hipMemcpyDeviceToDevice));
                p_a_grids.push_back(pADeviceBuf);
            }

            // Device-side copy of B.
            {
                void* pBDeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
                hip_check_error(hipMemcpy(pBDeviceBuf,
                                          const_cast<void*>(p_b_grids[0]),
                                          size_b_,
                                          hipMemcpyDeviceToDevice));
                p_b_grids.push_back(pBDeviceBuf);
            }
        }
    }

    // Point the kernel argument at the next rotation slot.
    void Next()
    {
        if(rotating_count > 1)
        {
            std::size_t idx = iter++ % rotating_count;
            arg.p_a_grid    = reinterpret_cast<ADataType>(p_a_grids[idx]);
            arg.p_b_grid    = reinterpret_cast<BDataType>(p_b_grids[idx]);
        }
    }

    void Print()
    {
        std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
                  << ", rotating_count: " << rotating_count << "}" << std::endl;
    }

    ~RotatingMemWrapper()
    {
        if(rotating_count > 1)
        {
            // Restore the original pointers, then release the extra copies.
            arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
            arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);

            for(size_t i = 1; i < rotating_count; i++)
            {
                hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
                hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
            }
        }
    }

    Argument& arg;
    std::size_t iter           = 0;
    std::size_t rotating_count = 1;
    std::size_t size_a         = 0;
    std::size_t size_b         = 0;
    std::vector<const void*> p_a_grids;
    std::vector<const void*> p_b_grids;
};
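A minimal usage sketch for the single-A/B wrapper (the HypotheticalGemmArg type, buffer sizes, iteration count, and launch call are placeholders, not part of this header). The wrapper captures the argument's existing device pointers as slot 0, Next() swaps in a different copy before each launch so the previous launch's inputs are no longer cache-resident, and the destructor undoes everything.

struct HypotheticalGemmArg
{
    const float* p_a_grid; // must already point at device memory of size_a bytes
    const float* p_b_grid; // must already point at device memory of size_b bytes
};

void benchmark_sketch(HypotheticalGemmArg& arg, std::size_t size_a, std::size_t size_b)
{
    RotatingMemWrapper<HypotheticalGemmArg> rotating_mem(
        arg, /*rotating_count_hint=*/8, size_a, size_b);
    rotating_mem.Print();

    for(int i = 0; i < 10; ++i)
    {
        rotating_mem.Next();        // arg.p_a_grid / arg.p_b_grid now reference the next copy
        // launch_gemm_kernel(arg); // hypothetical kernel launch using the rotated pointers
    }
} // ~RotatingMemWrapper restores the original pointers and hipFree()s the extra copies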
// Launch an empty kernel on enough blocks to evict the instruction cache on every compute unit.
inline void flush_icache()
{
    hipDeviceProp_t deviceProps;
    hip_check_error(hipGetDeviceProperties(&deviceProps, 0));

    int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;

    ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
    hip_check_error(hipGetLastError());
}

// Time `kernel`, invoking `preprocess` (e.g. rotating input buffers and flushing caches) before
// every launch. TimePreprocess selects whether that call is inside the timed region.
template <bool TimePreprocess,
          typename GemmArgs,
          typename... Args,
          typename F,
          typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
                                             dim3 block_dim,
                                             std::size_t lds_byte,
                                             GemmArgs& gemm_args,
                                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
        {
            printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
                   __func__,
                   grid_dim.x, grid_dim.y, grid_dim.z,
                   block_dim.x, block_dim.y, block_dim.z);
            printf("Warm up %d times\n", stream_config.cold_niters_);
        }

        // Warm up.
        for(int i = 0; i < stream_config.cold_niters_; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
            hip_check_error(hipGetLastError());
        }

        const int nrepeat = stream_config.nrepeat_;
        printf("Start running %d times...\n", nrepeat);

        std::set<float> times;
        float total_time = 0;

        hipEvent_t start, stop;
        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        for(int i = 0; i < nrepeat; ++i)
        {
            if constexpr(!TimePreprocess)
            {
                preprocess(); // excluded from the measured interval
            }

            hip_check_error(hipEventRecord(start, stream_config.stream_id_));

            if constexpr(TimePreprocess)
            {
                preprocess(); // included in the measured interval
            }

            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
            hip_check_error(hipGetLastError());

            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
            hip_check_error(hipEventSynchronize(stop));

            float cur_time = 0;
            hip_check_error(hipEventElapsedTime(&cur_time, start, stop));

#if !defined(CK_USE_WMMA)
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
                       static_cast<const void*>(gemm_args.p_a_grid),
                       static_cast<const void*>(gemm_args.p_b_grid));
            }
#endif

            times.insert(cur_time);
            total_time += cur_time;
        }

        // Reduce the samples either to the median of the collected set ...
        if(/* condition elided */)
        {
            auto mid = times.begin();
            std::advance(mid, (nrepeat - 1) / 2);
            if(nrepeat % 2 == 1)
            {
                return *mid;
            }
            auto mid_next = mid;
            std::advance(mid_next, 1);
            return (*mid + *mid_next) / 2;
        }

        // ... or to the mean with an empirical per-iteration preprocess offset subtracted.
        hipDeviceProp_t deviceProps;
        hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
        float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
        return (total_time - preprocess_offset * nrepeat) / nrepeat;
    }
    else
    {
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
        hip_check_error(hipGetLastError());
        return 0;
    }
#else
    // Timing disabled at compile time: launch once and return.
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
    hip_check_error(hipGetLastError());
    return 0;
#endif
}
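Putting the pieces together, a hedged sketch of how these helpers are typically combined (the kernel handle, argument type, and launch geometry below are placeholders, not part of this header): the preprocess lambda rotates the input buffers and flushes the instruction cache, and TimePreprocess = false keeps that work out of the reported time.

template <typename Kernel, typename Argument>
float time_with_rotation(const StreamConfig& stream_config,
                         Kernel kernel,
                         Argument& arg,
                         std::size_t size_a,
                         std::size_t size_b,
                         std::size_t rotating_count_hint)
{
    RotatingMemWrapper<Argument> rotating_mem(arg, rotating_count_hint, size_a, size_b);
    rotating_mem.Print();

    auto preprocess = [&]() {
        rotating_mem.Next(); // point arg at the next buffer copy
        flush_icache();      // evict the instruction cache between runs
    };

    // TimePreprocess = false: rotation and cache flushing are excluded from the measurement.
    return launch_and_time_kernel_with_preprocess<false>(
        stream_config, preprocess, kernel, dim3(240), dim3(256), 0, arg);
}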