6 #include <hip/hip_runtime.h> 
   18 template <
typename Argument, 
typename DsDataType>
 
   29                              std::size_t rotating_count_,
 
   32                              std::array<std::size_t, NumDs> size_ds_)
 
   34           rotating_count(rotating_count_),
 
   39         p_a_grids.push_back(arg.p_a_grid);
 
   40         p_b_grids.push_back(arg.p_b_grid);
 
   41         p_ds_grids.push_back(arg.p_ds_grid);
 
   42         for(
size_t i = 1; i < rotating_count; i++)
 
   48                                           const_cast<void*
>(p_a_grids[0]),
 
   50                                           hipMemcpyDeviceToDevice));
 
   51                 p_a_grids.push_back(pADeviceBuf);
 
   58                                           const_cast<void*
>(p_b_grids[0]),
 
   60                                           hipMemcpyDeviceToDevice));
 
   61                 p_b_grids.push_back(pBDeviceBuf);
 
   69                     hip_check_error(hipMalloc(
static_cast<void**
>(&pDDeviceBuf), size_ds_[j]));
 
   71                                               static_cast<const void*
>(p_ds_grids[0][j]),
 
   73                                               hipMemcpyDeviceToDevice));
 
   77                     ds_buffer(j) = 
static_cast<const DDataType*
>(pDDeviceBuf);
 
   80                 p_ds_grids.push_back(ds_buffer);
 
   87         if(rotating_count > 1)
 
   89             std::size_t idx = iter++ % rotating_count;
 
   90             arg.p_a_grid    = 
reinterpret_cast<ADataType>(p_a_grids[idx]);
 
   91             arg.p_b_grid    = 
reinterpret_cast<BDataType>(p_b_grids[idx]);
 
   92             arg.p_ds_grid   = p_ds_grids[idx];
 
   97         std::cout << 
"RotatingMemWrapperMultiD: { size_a: " << size_a << 
", size_b: " << size_b
 
   98                   << 
", rotating_count: " << rotating_count << 
"}" << std::endl;
 
  102         if(rotating_count > 1)
 
  105             arg.p_a_grid  = 
reinterpret_cast<ADataType>(p_a_grids[0]);
 
  106             arg.p_b_grid  = 
reinterpret_cast<BDataType>(p_b_grids[0]);
 
  107             arg.p_ds_grid = p_ds_grids[0];
 
  110             for(
size_t i = 1; i < rotating_count; i++)
 
  118                         hipFree(
static_cast<void*
>(
const_cast<DDataType*
>(p_ds_grids[i][j]))));
 
  126     std::size_t iter                       = 0;
 
  127     std::size_t rotating_count             = 1;
 
  128     std::size_t size_a                     = 0;
 
  129     std::size_t size_b                     = 0;
 
  130     std::array<std::size_t, NumDs> size_ds = {0};
 
  131     std::vector<const void*> p_a_grids;
 
  132     std::vector<const void*> p_b_grids;
 
  133     std::vector<DsGridPointer> p_ds_grids;
 
  136 template <
typename Argument>
 
  144                        std::size_t rotating_count_,
 
  147         : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
 
  149         p_a_grids.push_back(arg.p_a_grid);
 
  150         p_b_grids.push_back(arg.p_b_grid);
 
  151         for(
size_t i = 1; i < rotating_count; i++)
 
  155                 hip_check_error(hipMalloc(
static_cast<void**
>(&pADeviceBuf), size_a_));
 
  157                                           const_cast<void*
>(p_a_grids[0]),
 
  159                                           hipMemcpyDeviceToDevice));
 
  160                 p_a_grids.push_back(pADeviceBuf);
 
  165                 hip_check_error(hipMalloc(
static_cast<void**
>(&pBDeviceBuf), size_b_));
 
  167                                           const_cast<void*
>(p_b_grids[0]),
 
  169                                           hipMemcpyDeviceToDevice));
 
  170                 p_b_grids.push_back(pBDeviceBuf);
 
  177         if(rotating_count > 1)
 
  179             std::size_t idx = iter++ % rotating_count;
 
  180             arg.p_a_grid    = 
reinterpret_cast<ADataType>(p_a_grids[idx]);
 
  181             arg.p_b_grid    = 
reinterpret_cast<BDataType>(p_b_grids[idx]);
 
  186         std::cout << 
"RotatingMemWrapper: { size_a: " << size_a << 
", size_b: " << size_b
 
  187                   << 
", rotating_count: " << rotating_count << 
"}" << std::endl;
 
  191         if(rotating_count > 1)
 
  194             arg.p_a_grid = 
reinterpret_cast<ADataType>(p_a_grids[0]);
 
  195             arg.p_b_grid = 
reinterpret_cast<BDataType>(p_b_grids[0]);
 
  198             for(
size_t i = 1; i < rotating_count; i++)
 
  208     std::size_t iter           = 0;
 
  209     std::size_t rotating_count = 1;
 
  210     std::size_t size_a         = 0;
 
  211     std::size_t size_b         = 0;
 
  212     std::vector<const void*> p_a_grids;
 
  213     std::vector<const void*> p_b_grids;
 
  218     hipDeviceProp_t deviceProps;
 
  220     int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
 
  222     ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, 
nullptr>>>();
 
  226 template <
bool TimePreprocess,
 
  230           typename PreProcessFunc>
 
  232                                              PreProcessFunc preprocess,
 
  236                                              std::size_t lds_byte,
 
  246             printf(
"%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
 
  255             printf(
"Warm up %d times\n", stream_config.
cold_niters_);
 
  260             kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
  264         const int nrepeat = stream_config.
nrepeat_;
 
  271             printf(
"Start running %d times...\n", nrepeat);
 
  275         std::set<float> times;
 
  277         float total_time = 0;
 
  279         hipEvent_t start, stop;
 
  287         for(
int i = 0; i < nrepeat; ++i)
 
  289             if constexpr(!TimePreprocess)
 
  302             if constexpr(TimePreprocess)
 
  307             kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
  325                 printf(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
 
  326                        static_cast<const void*
>(gemm_args.p_a_grid),
 
  327                        static_cast<const void*
>(gemm_args.p_b_grid));
 
  335         times.insert(cur_time);
 
  337         total_time += cur_time;
 
  341         auto mid = times.begin();
 
  342         std::advance(mid, (nrepeat - 1) / 2);
 
  350             std::advance(mid_next, 1);
 
  351             return (*mid + *mid_next) / 2;
 
  355         hipDeviceProp_t deviceProps;
 
  357         float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
 
  358         return (total_time - preprocess_offset * nrepeat) / nrepeat;
 
  364         kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
  370     kernel<<<grid_dim, block_dim, lds_byte, stream_config.
stream_id_>>>(gemm_args, args...);
 
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
 
void flush_icache()
Definition: flush_cache.hpp:216
 
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition: flush_cache.hpp:231
 
int32_t int32_t
Definition: integer.hpp:10
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:300
 
Definition: stream_config.hpp:10
 
int cold_niters_
Definition: stream_config.hpp:14
 
bool time_kernel_
Definition: stream_config.hpp:12
 
int nrepeat_
Definition: stream_config.hpp:15
 
hipStream_t stream_id_
Definition: stream_config.hpp:11
 
Definition: functional2.hpp:33
 
Definition: flush_cache.hpp:138
 
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:139
 
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_)
Definition: flush_cache.hpp:143
 
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:140
 
~RotatingMemWrapper()
Definition: flush_cache.hpp:189
 
RotatingMemWrapper()=delete
 
void Print()
Definition: flush_cache.hpp:184
 
void Next()
Definition: flush_cache.hpp:175
 
Definition: flush_cache.hpp:20
 
decltype(Argument::p_b_grid) BDataType
Definition: flush_cache.hpp:24
 
void Print()
Definition: flush_cache.hpp:95
 
static constexpr index_t NumDs
Definition: flush_cache.hpp:21
 
decltype(Argument::p_a_grid) ADataType
Definition: flush_cache.hpp:23
 
RotatingMemWrapperMultiD()=delete
 
~RotatingMemWrapperMultiD()
Definition: flush_cache.hpp:100
 
decltype(Argument::p_ds_grid) DsGridPointer
Definition: flush_cache.hpp:25
 
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition: flush_cache.hpp:28
 
void Next()
Definition: flush_cache.hpp:85
 
#define CK_ENV(name)
Definition: env.hpp:128