30 #ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_
31 #define HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_
33 #include "../../../config.hpp"
35 #include "../iterator/tex_ref_input_iterator.hpp"
37 BEGIN_HIPCUB_NAMESPACE
// Exclusive end offset of each row into the value / column-index arrays:
// the kernel treats row r as [d_row_end_offsets[r - 1], d_row_end_offsets[r]),
// with row 0 starting at offset 0.
50 OffsetT* d_row_end_offsets;
// Column index of each stored nonzero; used to gather the matching entry of x.
51 OffsetT* d_column_indices;
63 static constexpr uint32_t CsrMVKernel_MaxThreads = 256;
65 template <
typename ValueT>
66 static __global__
void
67 CsrMVKernel(SpmvParams<ValueT, int> spmv_params)
69 __shared__ ValueT partial;
71 const int32_t row_id = hipBlockIdx_x;
73 if(hipThreadIdx_x == 0)
75 partial = spmv_params.beta * spmv_params.d_vector_y[row_id];
79 int32_t row_offset = (row_id == 0) ? (0) : (spmv_params.d_row_end_offsets[row_id - 1]);
80 for(uint32_t thread_offset = 0; thread_offset < spmv_params.num_cols / hipBlockDim_x; thread_offset++)
82 int32_t offset = row_offset + thread_offset * hipBlockDim_x + hipThreadIdx_x;
84 if(offset < spmv_params.d_row_end_offsets[row_id])
88 spmv_params.d_values[offset] *
89 spmv_params.d_vector_x[spmv_params.d_column_indices[offset]];
91 atomicAdd(&partial, t_value);
95 if(hipThreadIdx_x == 0)
97 spmv_params.d_vector_y[row_id] = partial;
/// Host wrapper: computes y = A * x for a CSR matrix A by launching CsrMVKernel
/// with one block per row (grid size = num_rows).
///
/// Visible parameters (the declaration is only partially shown here):
///   d_temp_storage     - nullptr requests a temporary-storage size query;
///                        otherwise the kernel is launched.
///   temp_storage_bytes - set to 4 on a size query.
///   d_column_indices   - CSR column index of each nonzero.
///   stream             - HIP stream used for the launch (default 0).
///   debug_synchronous  - defaults to false; its use is outside the visible lines.
/// Returns hipError_t(0) on a size query; otherwise records the launch status
/// from hipGetLastError() (the remainder of the body is outside the visible lines).
103 template <
typename ValueT>
104 HIPCUB_RUNTIME_FUNCTION
106 void* d_temp_storage,
107 size_t& temp_storage_bytes,
110 int* d_column_indices,
116 hipStream_t stream = 0,
117 bool debug_synchronous =
false)
// Pack the CSR arrays and dense vectors into the kernel parameter struct.
120 spmv_params.d_values = d_values;
// The kernel reads end offsets (row r ends at d_row_end_offsets[r]). Shifting by
// one presumably assumes d_row_offsets is a standard CSR row-pointer array of
// num_rows + 1 entries starting at 0 -- TODO confirm against callers.
121 spmv_params.d_row_end_offsets = d_row_offsets + 1;
122 spmv_params.d_column_indices = d_column_indices;
123 spmv_params.d_vector_x = d_vector_x;
124 spmv_params.d_vector_y = d_vector_y;
125 spmv_params.num_rows = num_rows;
126 spmv_params.num_cols = num_cols;
127 spmv_params.num_nonzeros = num_nonzeros;
// Fixed scalars: this entry point computes y = 1 * A * x + 0 * y.
128 spmv_params.alpha = 1.0;
129 spmv_params.beta = 0.0;
// Size query: report a small nonzero size and succeed without launching
// (presumably so callers can still allocate a valid buffer -- TODO confirm).
132 if(d_temp_storage ==
nullptr)
136 temp_storage_bytes = 4;
137 return hipError_t(0);
// One block per row; the block size is capped at CsrMVKernel_MaxThreads.
// NOTE(review): num_cols == 0 or num_rows == 0 produces a zero-sized launch,
// which is an invalid configuration -- consider returning early instead.
141 size_t block_size = min(num_cols, DeviceSpmv::CsrMVKernel_MaxThreads);
142 size_t grid_size = num_rows;
143 CsrMVKernel<<<grid_size, block_size, 0, stream>>>(spmv_params);
// Captures launch-configuration errors; in-kernel faults only surface at a
// later synchronizing call.
144 status = hipGetLastError();
static __host__ hipError_t CsrMV(void *d_temp_storage, size_t &temp_storage_bytes, ValueT *d_values, int *d_row_offsets, int *d_column_indices, ValueT *d_vector_x, ValueT *d_vector_y, int num_rows, int num_cols, int num_nonzeros, hipStream_t stream=0, bool debug_synchronous=false)
Definition: device_spmv.hpp:105
< Signed integer type for sequence offsets
Definition: device_spmv.hpp:49