/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.2.0/hipcub/include/hipcub/backend/rocprim/device/device_spmv.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.2.0/hipcub/include/hipcub/backend/rocprim/device/device_spmv.hpp Source File

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.2.0/hipcub/include/hipcub/backend/rocprim/device/device_spmv.hpp Source File
device_spmv.hpp
1 /******************************************************************************
2  * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
30 #ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_
31 #define HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_
32 
33 #include "../../../config.hpp"
34 
35 #include "../iterator/tex_ref_input_iterator.hpp"
36 
37 BEGIN_HIPCUB_NAMESPACE
38 
39 class DeviceSpmv
40 {
41 
42 public:
43 
44 template <
45  typename ValueT,
46  typename OffsetT>
47 struct SpmvParams
48 {
49  ValueT* d_values;
50  OffsetT* d_row_end_offsets;
51  OffsetT* d_column_indices;
52  ValueT* d_vector_x;
53  ValueT* d_vector_y;
54  int num_rows;
55  int num_cols;
56  int num_nonzeros;
57  ValueT alpha;
58  ValueT beta;
59 
61 };
62 
63 static constexpr uint32_t CsrMVKernel_MaxThreads = 256;
64 
65 template <typename ValueT>
66 static __global__ void
67 CsrMVKernel(SpmvParams<ValueT, int> spmv_params)
68 {
69  __shared__ ValueT partial;
70 
71  const int32_t row_id = hipBlockIdx_x;
72 
73  if(hipThreadIdx_x == 0)
74  {
75  partial = spmv_params.beta * spmv_params.d_vector_y[row_id];
76  }
77  __syncthreads();
78 
79  int32_t row_offset = (row_id == 0) ? (0) : (spmv_params.d_row_end_offsets[row_id - 1]);
80  for(uint32_t thread_offset = 0; thread_offset < spmv_params.num_cols / hipBlockDim_x; thread_offset++)
81  {
82  int32_t offset = row_offset + thread_offset * hipBlockDim_x + hipThreadIdx_x;
83 
84  if(offset < spmv_params.d_row_end_offsets[row_id])
85  {
86  ValueT t_value =
87  spmv_params.alpha *
88  spmv_params.d_values[offset] *
89  spmv_params.d_vector_x[spmv_params.d_column_indices[offset]];
90 
91  atomicAdd(&partial, t_value);
92 
93  __syncthreads();
94 
95  if(hipThreadIdx_x == 0)
96  {
97  spmv_params.d_vector_y[row_id] = partial;
98  }
99  }
100  }
101 }
102 
103 template <typename ValueT>
104  HIPCUB_RUNTIME_FUNCTION
105  static hipError_t CsrMV(
106  void* d_temp_storage,
107  size_t& temp_storage_bytes,
108  ValueT* d_values,
109  int* d_row_offsets,
110  int* d_column_indices,
111  ValueT* d_vector_x,
112  ValueT* d_vector_y,
113  int num_rows,
114  int num_cols,
115  int num_nonzeros,
116  hipStream_t stream = 0,
117  bool debug_synchronous = false)
118  {
119  SpmvParams<ValueT, int> spmv_params;
120  spmv_params.d_values = d_values;
121  spmv_params.d_row_end_offsets = d_row_offsets + 1;
122  spmv_params.d_column_indices = d_column_indices;
123  spmv_params.d_vector_x = d_vector_x;
124  spmv_params.d_vector_y = d_vector_y;
125  spmv_params.num_rows = num_rows;
126  spmv_params.num_cols = num_cols;
127  spmv_params.num_nonzeros = num_nonzeros;
128  spmv_params.alpha = 1.0;
129  spmv_params.beta = 0.0;
130 
131  hipError_t status;
132  if(d_temp_storage == nullptr)
133  {
134  temp_storage_bytes = 0;
135  return hipError_t(0);
136  }
137  else
138  {
139  size_t block_size = min(num_cols, DeviceSpmv::CsrMVKernel_MaxThreads);
140  size_t grid_size = num_rows;
141  CsrMVKernel<<<grid_size, block_size, 0, stream>>>(spmv_params);
142  status = hipGetLastError();
143  }
144  return status;
145  }
146 };
147 
148 END_HIPCUB_NAMESPACE
149 
150 #endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_
151 
static __host__ hipError_t CsrMV(void *d_temp_storage, size_t &temp_storage_bytes, ValueT *d_values, int *d_row_offsets, int *d_column_indices, ValueT *d_vector_x, ValueT *d_vector_y, int num_rows, int num_cols, int num_nonzeros, hipStream_t stream=0, bool debug_synchronous=false)
Definition: device_spmv.hpp:105
< Signed integer type for sequence offsets
Definition: device_spmv.hpp:49